You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_rt.cpp 26 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597
  1. /**
  2. * \file imperative/python/src/graph_rt.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./graph_rt.h"
  12. #include "megbrain/graph/cg.h"
  13. #include "megbrain/serialization/serializer.h"
  14. #include "megbrain/imperative/opr_utility.h"
  15. #include "megbrain/opr/io.h"
  16. #include "megbrain/opr/utility.h"
  17. #include "megbrain/opr/basic_arith.h"
  18. #include "megbrain/imperative.h"
  19. #include "./helper.h"
  20. #include "megbrain/plugin/profiler.h"
  21. #include "./common.h"
  22. #include "./ops.h"
  23. #include "megbrain/gopt/inference.h"
  24. #include "megbrain/imperative/ops/utility.h"
  25. namespace py = pybind11;
  26. using namespace mgb;
  27. using namespace imperative;
  28. namespace ser = mgb::serialization;
  29. using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
  30. using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
  31. using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
  32. namespace {
  33. class _CompGraphProfilerImpl {
  34. std::shared_ptr<ComputingGraph> m_comp_graph;
  35. GraphProfiler m_profiler;
  36. public:
  37. _CompGraphProfilerImpl(std::shared_ptr<ComputingGraph> cg):
  38. m_comp_graph{cg},
  39. m_profiler{m_comp_graph.get()}
  40. {
  41. }
  42. std::string _get_result() {
  43. auto json = m_profiler.to_json_full(
  44. m_comp_graph->current_comp_seq());
  45. return json->to_string();
  46. }
  47. };
  48. struct WeakRendezvousArray:
  49. public std::vector<std::weak_ptr<RendezvousBase>>,
  50. public UserDataContainer::UserData {
  51. MGB_TYPEINFO_OBJ_DECL;
  52. };
  53. MGB_TYPEINFO_OBJ_IMPL(WeakRendezvousArray);
  54. }
  55. #define DEF_READWRITE(name) .def_readwrite(#name, &CURRENT_CLASS::name)
  56. template<typename T>
  57. auto def_rendezvous(py::object m, const char* name) {
  58. return py::class_<Rendezvous<T>, std::shared_ptr<Rendezvous<T>>>(m, name)
  59. .def(py::init([](){return Rendezvous<T>::make();}))
  60. .def("set", [](Rendezvous<T>& r, T v) {r.set(std::move(v));})
  61. .def("get", [](Rendezvous<T>& r) {return r.get();}, py::call_guard<py::gil_scoped_release>())
  62. .def("drop", &Rendezvous<T>::drop)
  63. .def("reset", &Rendezvous<T>::reset)
  64. .def("set_exception", [](Rendezvous<T>& r, std::string&& message) {
  65. r.set_exception(std::make_exception_ptr(
  66. std::runtime_error(std::move(message))));
  67. });
  68. }
  69. using TensorAttr = LogicalTensorDesc;
  70. using HostNDWithEvent = std::pair<HostTensorND, std::shared_ptr<CompNode::Event>>;
  71. std::vector<mgb::cg::VarNode*> _replace_vars(const std::vector<mgb::cg::VarNode*>& repl_src,
  72. const std::vector<mgb::cg::VarNode*>& repl_dst,
  73. const std::vector<mgb::cg::VarNode*>& vars) {
  74. mgb::ThinHashMap<SymbolVar, SymbolVar> varmap;
  75. for (size_t i = 0; i < repl_src.size(); ++i) {
  76. varmap[SymbolVar(repl_src[i])] = SymbolVar(repl_dst[i]);
  77. }
  78. SymbolVarArray symvars(vars.begin(), vars.end());
  79. auto sym_result = mgb::cg::replace_vars(symvars, varmap);
  80. std::vector<mgb::cg::VarNode*> result;
  81. for (auto symvar : sym_result){
  82. result.push_back(symvar.node());
  83. }
  84. return result;
  85. }
  86. typedef std::vector<mgb::cg::OperatorNodeBase*> OperatorArray;
  87. std::vector<mgb::cg::VarNode*> _replace_oprs(const OperatorArray& repl_src,
  88. const OperatorArray& repl_dst,
  89. const std::vector<mgb::cg::VarNode*>& vars) {
  90. mgb::ThinHashMap<mgb::cg::OperatorNodeBase*, mgb::cg::OperatorNodeBase*>
  91. oprmap;
  92. for (size_t i = 0; i < repl_src.size(); ++i) {
  93. oprmap[repl_src[i]] = repl_dst[i];
  94. }
  95. const SymbolVarArray symvars(vars.begin(), vars.end());
  96. auto sym_result = mgb::cg::replace_oprs(symvars, oprmap);
  97. std::vector<mgb::cg::VarNode*> result;
  98. for (auto symvar : sym_result){
  99. result.push_back(symvar.node());
  100. }
  101. return result;
  102. }
  103. void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) {
  104. auto on_opr = [](mgb::cg::OperatorNodeBase* opr) {
  105. if (opr->node_prop().attribute().priority == 0) {
  106. opr->node_prop().attribute().priority = opr->id();
  107. }
  108. };
  109. mgb::cg::DepOprIter dep_iter{on_opr};
  110. for (const auto& var : dest_vars) {
  111. dep_iter.add(SymbolVar(var));
  112. }
  113. }
  114. void init_graph_rt(py::module m) {
  115. static const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{std::make_unique<mgb::OprFootprint>()};
  116. def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous");
  117. def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous");
  118. def_rendezvous<TensorAttr>(m, "TensorAttrRendezvous");
  119. py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode")
  120. .def_property_readonly("owner", [](cg::VarNode* v) {return v->owner_opr();})
  121. .def_property_readonly("graph", [](cg::VarNode* v) {return v->owner_graph();})
  122. .def_property("name", py::overload_cast<>(&VarNode::name, py::const_),
  123. py::overload_cast<std::string>(&VarNode::name))
  124. .def_property_readonly("dtype", [](cg::VarNode* v) {return v->dtype();})
  125. .def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();})
  126. .def_property_readonly("shape", [](cg::VarNode* v) -> const TensorShape* {
  127. auto&& mgr = v->owner_graph()->static_infer_manager();
  128. return mgr.infer_shape_fallible(v);
  129. })
  130. .def_property_readonly("value", [](cg::VarNode* v) -> py::object {
  131. auto&& mgr = v->owner_graph()->static_infer_manager();
  132. auto&& type = mgr.get_infer_type(v);
  133. using InferType = cg::static_infer::InferType;
  134. if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
  135. return py::none();
  136. }
  137. auto* val = mgr.infer_value_fallible(v);
  138. if (!val) {
  139. return py::none();
  140. }
  141. return py::cast(*val).attr("numpy")();
  142. })
  143. .def_property_readonly("id",[](cg::VarNode* v){
  144. return (v->id());
  145. })
  146. .def("__repr__", [](cg::VarNode* v) {
  147. return "Var:" + v->name();
  148. });
  149. py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
  150. .def_property_readonly("graph", [](cg::OperatorNodeBase* opr) {return opr->owner_graph();})
  151. .def_property("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_),
  152. py::overload_cast<std::string>(&cg::OperatorNodeBase::name))
  153. .def_property_readonly("inputs", [](cg::OperatorNodeBase* opr) {
  154. return to_tuple(opr->input());
  155. })
  156. .def_property_readonly("outputs", [](cg::OperatorNodeBase* opr) {
  157. return to_tuple(opr->usable_output());
  158. })
  159. .def_property_readonly("id",[](cg::OperatorNodeBase* opr){
  160. return opr->id();
  161. })
  162. .def_property_readonly("params",[](cg::OperatorNodeBase* opr){
  163. return _imperative_sm_opr_footprint_ptr->calc_footprint(opr).param->to_string();
  164. })
  165. .def_property_readonly("type",[](cg::OperatorNodeBase* opr){
  166. return opr->dyn_typeinfo()->name;
  167. })
  168. .def("__repr__", [](cg::OperatorNodeBase* opr){
  169. return "Opr:" + opr->name();
  170. });
  171. py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
  172. .def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>())
  173. .def("wait", &cg::AsyncExecutable::wait, py::call_guard<py::gil_scoped_release>())
  174. .def("get_prev_exec_time", &cg::AsyncExecutable::get_prev_exec_time, py::call_guard<py::gil_scoped_release>())
  175. .def("_to_json", [](cg::AsyncExecutable* exec) {
  176. py::call_guard<py::gil_scoped_release>();
  177. // dump currently compiled computing graph for debugging
  178. return exec->to_json()->to_string();
  179. })
  180. // only used for exception handle
  181. .def_property_readonly("_all_rendezvous", [](cg::AsyncExecutable* exec) {
  182. auto ud = exec->owner_graph()->options().user_data
  183. .get_user_data<WeakRendezvousArray>();
  184. std::vector<std::shared_ptr<RendezvousBase>> ret;
  185. if (ud.second) {
  186. for (auto&& r: *ud.first[0]) {
  187. if (auto p = r.lock()) {
  188. ret.emplace_back(std::move(p));
  189. }
  190. }
  191. }
  192. return ret;
  193. });
  194. auto PyComputingGraph = py::class_<cg::ComputingGraph, std::shared_ptr<cg::ComputingGraph>>(m, "ComputingGraph")
  195. .def(py::init(py::overload_cast<>(&cg::ComputingGraph::make)))
  196. .def("compile", [](cg::ComputingGraph& graph, const std::vector<cg::VarNode*>& dest_vars) {
  197. mgb_assert(!dest_vars.empty());
  198. cg::ComputingGraph::OutputSpec spec;
  199. for (auto v : dest_vars) {
  200. spec.emplace_back(v, nullptr);
  201. }
  202. return graph.compile(spec);
  203. })
  204. .def_property_readonly("options", py::overload_cast<>(&cg::ComputingGraph::options));
  205. py::class_<_CompGraphProfilerImpl, std::shared_ptr<_CompGraphProfilerImpl>>(m, "GraphProfiler")
  206. .def(py::init([](std::shared_ptr<ComputingGraph> graph) {
  207. return std::make_shared<_CompGraphProfilerImpl>(graph);
  208. }))
  209. .def("get", [](_CompGraphProfilerImpl& profiler) { return profiler._get_result(); });
  210. auto GraphOptimizeOptions = py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions")
  211. .def(py::init())
  212. .def_readwrite("f16_io_f32_comp", &_OptimizeForInferenceOptions::f16_io_f32_comp)
  213. .def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
  214. .def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)
  215. .def_readwrite("fuse_conv_bias_with_z", &_OptimizeForInferenceOptions::fuse_conv_bias_with_z)
  216. .def_readwrite("layout_transform", &_OptimizeForInferenceOptions::layout_transform)
  217. ;
  218. py::enum_<_LayoutTransform>(GraphOptimizeOptions, "LayoutTransform")
  219. .value("DEFAULT", _LayoutTransform::DEFAULT)
  220. .value("NCHW4", _LayoutTransform::NCHW4)
  221. .value("NHWCD4", _LayoutTransform::NHWCD4)
  222. .value("NCHW88", _LayoutTransform::NCHW88)
  223. .value("NCHW44", _LayoutTransform::NCHW44)
  224. .value("NCHW44_DOT", _LayoutTransform::NCHW44_DOT)
  225. .value("NCHW32", _LayoutTransform::NCHW32)
  226. .value("CHWN4", _LayoutTransform::CHWN4)
  227. .export_values()
  228. ;
  229. m.def("optimize_for_inference", [](const VarNodeArray& dest_vars, const _OptimizeForInferenceOptions& opt) {
  230. SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
  231. auto res_symvars = mgb::gopt::optimize_for_inference(symvars, opt);
  232. VarNodeArray vars;
  233. for (auto& si: res_symvars)
  234. vars.push_back(si.node());
  235. return vars;
  236. });
  237. m.def("modify_opr_algo_strategy_inplace", [](const VarNodeArray& dest_vars,
  238. const _AlgoStrategy& strategy) {
  239. mgb::gopt::modify_opr_algo_strategy_inplace(dest_vars, strategy);
  240. });
  241. m.def("get_info_for_strip", [](const std::vector<VarNode*>& dest_vars) {
  242. std::unordered_set<const char*> opr_types, dtype_names, elemwise_modes;
  243. auto on_opr = [&](cg::OperatorNodeBase *opr) {
  244. if (ser::GraphDumper::should_remove_in_dump(opr))
  245. return;
  246. opr_types.insert(opr->dyn_typeinfo()->name);
  247. for (auto i : opr->output())
  248. dtype_names.insert(i->dtype().name());
  249. if (opr->same_type<opr::Elemwise>()) {
  250. auto mode = opr->cast_final<opr::Elemwise>().param().mode;
  251. elemwise_modes.insert(
  252. megdnn::Elemwise::ModeTrait::from_mode(mode).name);
  253. }
  254. };
  255. cg::DepOprIter opr_iter{on_opr};
  256. for (auto i : dest_vars)
  257. opr_iter.add(i->owner_opr());
  258. auto to_json = [](const std::unordered_set<const char*> &v) {
  259. std::vector<std::string> vs(v.begin(), v.end());
  260. std::sort(vs.begin(), vs.end());
  261. auto ret = json::Array::make();
  262. for (auto &&i : vs)
  263. ret->add(json::String::make(i));
  264. return ret;
  265. };
  266. return json::Object::make({
  267. {"opr_types", to_json(opr_types)},
  268. {"dtypes", to_json(dtype_names)},
  269. {"elemwise_modes", to_json(elemwise_modes)},
  270. })->to_string();
  271. });
  272. m.def("dump_graph", [](
  273. const std::vector<VarNode*>& dest_vars,
  274. int keep_var_name,
  275. bool keep_opr_name,
  276. bool keep_param_name,
  277. bool keep_opr_priority,
  278. py::list& stat,
  279. py::list& inputs,
  280. py::list& outputs,
  281. py::list& params
  282. ) {
  283. std::vector<uint8_t> buf;
  284. auto dumper = ser::GraphDumper::make(ser::OutputFile::make_vector_proxy(&buf));
  285. SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
  286. ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
  287. keep_opr_priority, keep_opr_name};
  288. auto rst = dumper->dump(symvars, config);
  289. for (auto i : rst.inputs) {
  290. inputs.append(py::cast(i));
  291. }
  292. for (auto i : rst.outputs) {
  293. outputs.append(py::cast(i));
  294. }
  295. for (auto i : rst.params) {
  296. params.append(py::cast(i));
  297. }
  298. auto rst_stat =
  299. std::vector{rst.nr_opr, rst.tot_bytes, rst.tensor_value_bytes,
  300. static_cast<size_t>(rst.content_hash)};
  301. for (auto i : rst_stat) {
  302. stat.append(py::cast(i));
  303. }
  304. return py::bytes(reinterpret_cast<const char*>(&buf[0]), buf.size());
  305. });
  306. m.def("load_graph", [](
  307. std::string& buf,
  308. py::list& output_var_map,
  309. py::list& output_var_list
  310. ) {
  311. auto file = ser::InputFile::make_mem_proxy(buf.c_str(), buf.length());
  312. auto format = ser::GraphLoader::identify_graph_dump_format(*file);
  313. auto loader = ser::GraphLoader::make(std::move(file), format.val());
  314. ser::GraphLoader::LoadConfig config;
  315. auto rst = loader->load(config);
  316. for (auto i : rst.output_var_map) {
  317. output_var_map.append(py::make_tuple(i.first, i.second.node()));
  318. }
  319. for (auto i : rst.output_var_list) {
  320. output_var_list.append(i.node());
  321. }
  322. std::unordered_map<HostTensorND*, const std::string*> tensor2name;
  323. for (const auto& pair : rst.tensor_map) {
  324. tensor2name[pair.second.get()] = &pair.first;
  325. }
  326. auto cb = [&tensor2name, graph=rst.graph](cg::OperatorNodeBase* opr) {
  327. if (!opr->same_type<opr::Host2DeviceCopy>())
  328. return;
  329. auto& h2d = opr->cast_final_safe<opr::Host2DeviceCopy>();
  330. auto it = tensor2name.find(h2d.host_data().get());
  331. mgb_throw_if(it == tensor2name.end(), GraphError,
  332. "unbound Host2DeviceCopy in loaded graph");
  333. h2d.output(0)->name(*it->second);
  334. };
  335. cg::DepOprIter iter{cb};
  336. for (const auto& var : rst.output_var_list) {
  337. iter.add(var);
  338. }
  339. return rst.graph;
  340. });
  341. #define CURRENT_CLASS cg::ComputingGraph::Options
  342. auto PyComputingGraphOptions = py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")
  343. // DEF_READWRITE(opr_attribute)
  344. DEF_READWRITE(seq_opt)
  345. DEF_READWRITE(graph_opt)
  346. DEF_READWRITE(graph_opt_level)
  347. DEF_READWRITE(log_level)
  348. DEF_READWRITE(async_exec_level)
  349. DEF_READWRITE(force_dynamic_alloc)
  350. DEF_READWRITE(var_sanity_check_first_run)
  351. DEF_READWRITE(allocate_static_mem_after_graph_compile)
  352. DEF_READWRITE(fake_next_exec)
  353. DEF_READWRITE(enable_sublinear_memory_opt)
  354. DEF_READWRITE(no_profiling_on_shape_change)
  355. DEF_READWRITE(enable_var_mem_defragment)
  356. DEF_READWRITE(enable_grad_var_static_reshape)
  357. DEF_READWRITE(enable_memory_swap)
  358. DEF_READWRITE(comp_node_seq_record_level)
  359. DEF_READWRITE(no_force_inplace)
  360. DEF_READWRITE(sublinear_mem_config)
  361. // DEF_READWRITE(eager_evaluation)
  362. // DEF_READWRITE(imperative_proxy_graph)
  363. // DEF_READWRITE(extra_vardeps)
  364. // DEF_READWRITE(user_data)
  365. ;
  366. #undef CURRENT_CLASS
  367. #define CURRENT_CLASS cg::ComputingGraph::Options::SeqOpt
  368. py::class_<cg::ComputingGraph::Options::SeqOpt>(PyComputingGraphOptions, "SeqOpt")
  369. DEF_READWRITE(enable_mem_plan_opt)
  370. DEF_READWRITE(enable_mem_reuse_alloc)
  371. DEF_READWRITE(enable_seq_comp_node_opt);
  372. #undef CURRENT_CLASS
  373. #define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt
  374. py::class_<cg::ComputingGraph::Options::GraphOpt>(PyComputingGraphOptions, "GraphOpt")
  375. DEF_READWRITE(jit)
  376. DEF_READWRITE(tensorrt);
  377. #undef CURRENT_CLASS
  378. #define CURRENT_CLASS cg::ComputingGraph::Options::SublinearMemConfig
  379. py::class_<cg::ComputingGraph::Options::SublinearMemConfig>(PyComputingGraphOptions, "SublinearMemConfig")
  380. DEF_READWRITE(thresh_nr_try)
  381. DEF_READWRITE(genetic_nr_iter)
  382. DEF_READWRITE(genetic_pool_size)
  383. DEF_READWRITE(lb_memory)
  384. DEF_READWRITE(num_worker);
  385. #undef CURRENT_CLASS
  386. auto common = rel_import("common", m, 1);
  387. common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) {
  388. cg::VarNodeArray vinputs(inputs.begin(), inputs.end());
  389. return to_tuple(OpDef::apply_on_var_node(def, vinputs));
  390. },
  391. py::arg(), py::arg(), py::arg("graph") = py::none());
  392. auto input_callback = [](auto callback,
  393. const CompNode& comp_node,
  394. const DType& dtype,
  395. const TensorShape& shape,
  396. const std::vector<cg::VarNode*>& inputs,
  397. cg::ComputingGraph* graph,
  398. bool use_static_shape) {
  399. if (!graph) {
  400. graph = inputs[0]->owner_graph();
  401. }
  402. SymbolVarArray sinputs;
  403. for (auto i : inputs) {
  404. sinputs.emplace_back(i);
  405. }
  406. static_assert(!std::is_reference<decltype(callback)>::value);
  407. auto soutputs = opr::InputCallback::make(*graph, std::move(callback),
  408. comp_node, dtype, shape,
  409. sinputs, use_static_shape);
  410. std::vector<VarNode*> outputs;
  411. outputs.reserve(soutputs.size());
  412. for (auto i : soutputs) {
  413. outputs.push_back(i.node());
  414. }
  415. return outputs;
  416. };
  417. m.def("make_shared", [](cg::ComputingGraph* graph, const DeviceTensorND& data) {
  418. return opr::SharedDeviceTensor::make(*graph, std::make_shared<DeviceTensorND>(data)).node();
  419. });
  420. m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype, std::optional<std::string> name) {
  421. if (!cn.valid()) {
  422. cn = CompNode::load(get_default_device());
  423. }
  424. OperatorNodeConfig config(cn);
  425. if (name) {
  426. config.name(*name);
  427. }
  428. auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
  429. return opr::ImmutableTensor::make(*graph, hv, config).node();
  430. }, py::arg(), py::arg(), py::arg(), py::arg(), py::arg() = py::none());
  431. m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape, std::optional<std::string> name) {
  432. if (!cn.valid()) {
  433. throw py::type_error("device must be valid");
  434. }
  435. if (!dtype.valid()) {
  436. throw py::type_error("dtype must be valid");
  437. }
  438. OperatorNodeConfig config;
  439. if (name) {
  440. config.name(*name);
  441. }
  442. return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, shape, dtype), config).node();
  443. }, py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::none());
  444. m.def("_replace_vars", &_replace_vars,py::arg(),py::arg(),py::arg());
  445. m.def("_replace_oprs", &_replace_oprs,py::arg(),py::arg(),py::arg());
  446. m.def("_set_priority_to_id",&_set_priority_to_id,py::arg());
  447. m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
  448. const CompNode& comp_node,
  449. const DType& dtype,
  450. const TensorShape& shape,
  451. const std::vector<cg::VarNode*>& inputs,
  452. cg::ComputingGraph* graph,
  453. bool use_static_shape) {
  454. return input_callback(
  455. [f=std::move(callback)](){py::gil_scoped_acquire _; return f();},
  456. comp_node, dtype, shape, inputs, graph, use_static_shape);
  457. },
  458. py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(),
  459. py::arg("graph") = py::none(), py::arg("use_static_shape") = false);
  460. m.def("input_callback", [input_callback](std::shared_ptr<Rendezvous<DeviceTensorND>> p,
  461. const CompNode& comp_node,
  462. const DType& dtype,
  463. const TensorShape& shape,
  464. const std::vector<cg::VarNode*>& inputs,
  465. cg::ComputingGraph* graph,
  466. bool use_static_shape) {
  467. auto f = [p]() -> DeviceTensorND {
  468. return p->get();
  469. };
  470. return input_callback(std::move(f), comp_node, dtype, shape, inputs, graph, use_static_shape);
  471. },
  472. py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(),
  473. py::arg("graph") = py::none(), py::arg("use_static_shape") = false);
  474. auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs,
  475. std::shared_ptr<RendezvousBase> r = {}, bool borrow = false, bool prefer_host_value = false) {
  476. if (r) {
  477. mgb_assert(inputs.size());
  478. auto cg = inputs[0]->owner_graph();
  479. cg->options().user_data.get_user_data_or_create<WeakRendezvousArray>()
  480. ->emplace_back(r);
  481. }
  482. SymbolVarArray sinputs;
  483. for (auto i : inputs) {
  484. sinputs.emplace_back(i);
  485. }
  486. static_assert(!std::is_reference<decltype(callback)>::value);
  487. opr::OutputCallback::Param param{std::move(callback), borrow, prefer_host_value};
  488. auto output = opr::OutputCallback::make(std::move(param), sinputs);
  489. return output.node();
  490. };
  491. m.def("output_callback", [output_callback](std::function<void(DeviceTensorND)> callback, std::vector<cg::VarNode*> inputs) {
  492. auto f = [f=std::move(callback)](DeviceTensorND dv) {
  493. auto task = [f=std::move(f), dv=std::move(dv)]() {
  494. f(dv);
  495. };
  496. py_task_q.add_task(std::move(task));
  497. };
  498. return output_callback(std::move(f), std::move(inputs));
  499. });
  500. m.def("output_callback", [output_callback](std::shared_ptr<Rendezvous<DeviceTensorND>> p, std::vector<cg::VarNode*> inputs) {
  501. auto f = [p](DeviceTensorND dv) {
  502. p->set(std::move(dv));
  503. };
  504. return output_callback(std::move(f), std::move(inputs), p);
  505. });
  506. m.def("value_output_callback", [output_callback](std::shared_ptr<Rendezvous<HostNDWithEvent>> p, std::vector<cg::VarNode*> inputs) {
  507. auto f = [p](DeviceTensorND dv) {
  508. HostNDWithEvent hv_with_event;
  509. hv_with_event.first.copy_from(dv);
  510. hv_with_event.second = dv.comp_node().create_event();
  511. hv_with_event.second->record();
  512. p->set(std::move(hv_with_event));
  513. };
  514. return output_callback(std::move(f), std::move(inputs), p, true, true);
  515. });
  516. m.def("attr_output_callback", [output_callback](std::shared_ptr<Rendezvous<TensorAttr>> p, std::vector<cg::VarNode*> inputs) {
  517. auto f = [p](DeviceTensorND dv) {
  518. p->set(TensorAttr{TensorLayout{dv.shape(), dv.dtype()}, dv.comp_node()});
  519. };
  520. return output_callback(std::move(f), std::move(inputs), p, true);
  521. });
  522. m.def("virtual_dep", [](std::vector<cg::VarNode*> inputs, std::string device) {
  523. auto&& graph = inputs[0]->owner_graph();
  524. VarNodeArray inps(inputs.begin(), inputs.end());
  525. cg::OperatorNodeConfig config;
  526. if (device.length() > 0) {
  527. config.comp_node(CompNode::load(device));
  528. }
  529. cg::OperatorNodeBase* opr = graph->insert_opr(
  530. std::make_unique<mgb::opr::VirtualDep>(inps, config));
  531. return opr;
  532. });
  533. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。