You cannot select more than 25 topics. A topic must start with a Chinese character, a letter, or a number; it may include dashes ('-') and can be up to 35 characters long.

opr_helper.cpp 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. /**
  2. * \file python_module/src/cpp/opr_helper.cpp
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  7. *
  8. */
  9. #include "./opr_helper.h"
  10. #include "./megbrain_wrap.h"
  11. #include "megbrain/opr/indexing.h"
  12. #include "megbrain/opr/io.h"
  13. #include "megbrain/serialization/opr_load_dump.h"
  14. using namespace mgb;
namespace {

/*!
 * \brief serialization::OprLoadContext that reads raw operator params from a
 *      Python list of bytes objects instead of a serialized file stream.
 *
 * Each list item is one chunk of raw POD bytes. read_raw() consumes the items
 * sequentially, and the destructor asserts that every item has been fully
 * consumed. Tensor loading is not supported in this context.
 */
class OprParamsLoadContext final: public serialization::OprLoadContextRawPOD {
    PyObject *m_params;      // borrowed reference to a Python list of bytes
    ComputingGraph *m_graph; // graph into which loaded oprs are inserted
    size_t m_nr_used_params = 0, m_param_size = 0;
    // read offset within the list item currently being consumed
    size_t m_item_bytes_consumed = 0;

    //! copy \p size bytes from the current list item into \p dest; advance to
    //! the next item once the current one is fully consumed
    void read_raw(void *dest, size_t size) override final {
        mgb_assert(m_nr_used_params < m_param_size);
        auto item = PyList_GetItem(m_params, m_nr_used_params);
        mgb_assert(item, "failed to get item %zu", m_nr_used_params);
        mgb_assert(PyBytes_Check(item), "list item must be bytes");
        auto item_size = PyBytes_Size(item);
        // sanity bound on the requested read size (rejects bogus huge reads)
        mgb_assert(size < (SIZE_MAX >> 3));
        // a single read must not straddle two list items
        mgb_assert(m_item_bytes_consumed + size <= size_t(item_size));
        auto item_buf = PyBytes_AsString(item);
        mgb_assert(item_size > 0 && item_buf);
        memcpy(dest, item_buf + m_item_bytes_consumed, size);
        m_item_bytes_consumed += size;
        if (m_item_bytes_consumed == size_t(item_size)) {
            // current item exhausted: move to the next one
            ++ m_nr_used_params;
            m_item_bytes_consumed = 0;
        }
    }

    //! unsupported: params in this context are raw POD only
    std::shared_ptr<HostTensorND> load_tensor() override {
        mgb_assert(0);
    }

    //! unsupported: params in this context are raw POD only
    std::shared_ptr<DeviceTensorND> load_tensor_shared() override {
        mgb_assert(0);
    }

    //! unsupported: no graph-load config is associated with a params list
    const serialization::GraphLoadConfig& config() const override {
        mgb_assert(0);
    }

public:
    //! \param params borrowed reference to a Python list of bytes objects
    //! \param graph  graph used to answer graph() queries during loading
    OprParamsLoadContext(PyObject *params, ComputingGraph *graph):
        m_params{params}, m_graph{graph}
    {
        mgb_assert(PyList_Check(params), "params must be a list");
        m_param_size = PyList_Size(params);
    }

    //! all params must have been consumed by the time the context dies
    ~OprParamsLoadContext() {
        mgb_assert(m_nr_used_params == m_param_size,
                "number of params mismatch");
    }

    ComputingGraph& graph() override {
        return *m_graph;
    }
};

} // anonymous namespace
  63. _SplitPartCallback::callback_t _SplitPartCallback::make_callback() {
  64. mgb_assert(!m_cb_created);
  65. m_cb_created = true;
  66. std::shared_ptr<_SplitPartCallback> cb_ptr(this);
  67. auto cb = [cb_ptr](size_t sz) {
  68. return cb_ptr->call(sz);
  69. };
  70. return cb;
  71. }
  72. _SetGradCallback::callback_t _SetGradCallback::make_callback() {
  73. mgb_assert(!m_cb_created);
  74. m_cb_created = true;
  75. if (empty()) {
  76. return {};
  77. }
  78. std::shared_ptr<_SetGradCallback> cb_ptr(this);
  79. auto cb = [cb_ptr](const opr::SetGrad& opr) {
  80. auto graph = CompGraph::make_from_weak_ptr(
  81. opr.owner_graph()->shared_from_this());
  82. return cb_ptr->call(graph);
  83. };
  84. return cb;
  85. }
  86. _TimeoutCallback::callback_t _TimeoutCallback::make_callback() {
  87. mgb_assert(!m_cb_created);
  88. m_cb_created = true;
  89. std::shared_ptr<_TimeoutCallback> cb_ptr(this);
  90. auto cb = [cb_ptr]() {
  91. return cb_ptr->call();
  92. };
  93. return cb;
  94. }
/*!
 * \brief create one of the subtensor-like indexing operators, selected by
 *      \p name
 * \param name opr name key; see the CHK1/CHK2 invocations below for the
 *      accepted values
 * \param inputs one var for read-only oprs (CHK1), two vars
 *      (data, value) for write/accumulate oprs (CHK2)
 * \param idx axis indexers forwarded to the opr's make()
 * \param config node config forwarded to the opr's make()
 * \return the output var of the created operator
 *
 * Throws MegBrainError if \p name matches none of the known oprs.
 */
mgb::SymbolVar _create_subtensor_like_opr(
        const std::string &name,
        const SymbolVarArray& inputs,
        const std::vector<AxisIndexer> &idx,
        const mgb::OperatorNodeConfig &config) {
// CHK1: oprs that read from a single input var
#define CHK1(_name, _opr) \
    if (name == _name) { \
        mgb_assert(inputs.size() == 1); \
        return opr::_opr::make(inputs[0], idx, config); \
    }
// CHK2: oprs that write/accumulate a value var into a data var
#define CHK2(_name, _opr) \
    if (name == _name) { \
        mgb_assert(inputs.size() == 2); \
        return opr::_opr::make(inputs[0], inputs[1], idx, config); \
    }
    CHK1("subtensor", Subtensor);
    CHK2("set_subtensor", SetSubtensor);
    CHK2("incr_subtensor", IncrSubtensor);
    CHK1("mavi", IndexingMultiAxisVec);
    CHK2("set_mavi", IndexingSetMultiAxisVec);
    CHK2("incr_mavi", IndexingIncrMultiAxisVec);
    CHK1("mesh_indexing", MeshIndexing);
    CHK1("batched_mesh_indexing", BatchedMeshIndexing);
    CHK2("incr_mesh_indexing", IncrMeshIndexing);
    CHK2("set_mesh_indexing", SetMeshIndexing);
    CHK2("batched_incr_mesh_indexing", BatchedIncrMeshIndexing);
    CHK2("batched_set_mesh_indexing", BatchedSetMeshIndexing);
    mgb_throw(MegBrainError, "bad subtensor opr name: %s", name.c_str());
#undef CHK1
#undef CHK2
}
  126. SymbolVar _make_immutable(CompGraph &comp_graph, PyObject *npyarr,
  127. PyObject *dtype, const mgb::cg::OperatorNodeConfig &config) {
  128. auto cn = config.get_single_comp_node();
  129. mgb_assert(cn.valid(), "invalid comp node given to make_tensor");
  130. DType dtype_mgb;
  131. if (dtype && dtype != Py_None)
  132. dtype_mgb = npy::dtype_np2mgb(dtype);
  133. auto hv = npy::np2tensor(npyarr, npy::Meth::borrow(cn), dtype_mgb);
  134. return opr::ImmutableTensor::make(comp_graph.get(), hv, config);
  135. }
  136. SymbolVarArray _create_opr(
  137. const char *name, const SymbolVarArray &inputs,
  138. PyObject *params, const OperatorNodeConfig &config) {
  139. mgb_assert(!inputs.empty());
  140. auto registry = serialization::OprRegistry::find_by_name(name);
  141. mgb_assert(registry, "operator %s not found", name);
  142. OprParamsLoadContext ctx{params, inputs[0].node()->owner_graph()};
  143. VarNodeArray vinputs(inputs.size());
  144. for (size_t i = 0; i < inputs.size(); ++ i)
  145. vinputs[i] = inputs[i].node();
  146. auto opr = registry->loader(ctx, vinputs, config);
  147. SymbolVarArray ret;
  148. for (auto i: opr->output()) {
  149. if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
  150. ret.push_back(i);
  151. }
  152. return ret;
  153. }
#if MGB_ENABLE_OPR_MM
//! deserialize a CollectiveComm::Param from a Python list of bytes objects
//! \param params Python list of bytes holding the serialized param
//! \param graph graph associated with the load context
mgb::opr::CollectiveComm::Param load_collective_comm_params(
        PyObject* params, mgb::ComputingGraph* graph) {
    OprParamsLoadContext ctx{params, graph};
    return ctx.read_param<mgb::opr::CollectiveComm::Param>();
}
#endif
  161. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台