
cg_impl.cpp (26 kB)

/**
 * \file src/core/impl/graph/cg_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./cg_impl.h"
#include "./cg_impl_partial.h"
#include "./cg_impl_seq.h"

#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/graph/cg.h"
#include "megbrain/graph/event.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/utility.h"

#if MGB_ENABLE_TENSOR_RT
#include "megbrain/tensorrt/opr_replace.h"
#endif
#if MGB_JIT
#include "megbrain/jit/fusion_pass.h"
#endif

#include "megbrain/gopt/weights_preprocess.h"

using namespace mgb;
using namespace cg;

namespace {
void check_opr_not_cross_mem(OperatorNodeBase* opr) {
    if (opr->node_prop().contain(
                OperatorNodeBase::NodeProp::Flag::CROSS_COMP_NODE_MEMORY))
        return;

    MemNode mem_node_id;
    bool first = true;
    auto check = [&](VarNode* var) {
        auto cur = var->comp_node().mem_node();
        mgb_assert(cur);
        if (first) {
            first = false;
            mem_node_id = cur;
        } else
            mgb_assert(mem_node_id == cur,
                       "for non cross-memory oprs, "
                       "all vars should reside on the same memory node");
    };
    for (auto i : opr->input()) {
        check(i);
    }
    for (auto i : opr->output()) {
        check(i);
    }
}
void update_output_shapes(static_infer::StaticInferManagerImpl& infer_mgr,
                          OperatorNodeBase* opr, bool add_freeze_flag) {
    for (auto i : opr->output()) {
        if (add_freeze_flag) {
            i->add_flag(VarNode::Flag::FLAG_FREEZED);
        }
        if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            using namespace static_infer;
            if (infer_mgr.get_infer_type(i).shape &
                (InferType::CONST | InferType::RT_STATIC)) {
                auto shp = infer_mgr.infer_shape_fallible(i);
                if (shp) {
                    i->shape(*shp);
                } else {
                    i->shape({});
                }
            } else {
                i->shape({});
            }
        }
    }
}
}  // anonymous namespace

/* ========================== global helpers ========================== */

void cg::update_output_var_shapes(OperatorNodeBase* opr) {
    update_output_shapes(static_cast<static_infer::StaticInferManagerImpl&>(
                                 opr->owner_graph()->static_infer_manager()),
                         opr, false);
}
/* ========================= DeviceMemoryAllocator ========================= */

void DeviceMemoryAllocator::alloc_static(ComputingGraph*,
                                         DeviceTensorStorage& dest,
                                         size_t size) {
    dest.ensure_size(size);
}

void DeviceMemoryAllocator::alloc_dynamic(VarNode*, DeviceTensorStorage& dest,
                                          size_t size) {
    dest.ensure_size(size);
}

void DeviceMemoryAllocator::defrag_prealloc_contig(ComputingGraph* graph,
                                                   CompNode comp_node,
                                                   size_t size) {
    MGB_TRY { comp_node.free_device(comp_node.alloc_device(size)); }
    MGB_CATCH(MemAllocError&, {})
}

size_t DeviceMemoryAllocator::static_alloc_version(ComputingGraph*) const {
    return 0;
}
/* ========================== ComputingGraph ========================== */

ComputingGraph::ComputingGraph() {
    static std::atomic_size_t tot_id{0};
    m_id = (tot_id++);
}

void ComputingGraph::assert_destroy(std::shared_ptr<ComputingGraph>& ptr) {
    mgb_assert(ptr.use_count() == 1, "unexpected use_count: %zu",
               size_t(ptr.use_count()));
    ptr.reset();
}

#if !MGB_THREAD_SAFE
size_t ComputingGraph::prealloc_static_storage(size_t size) {
    // note that in single-threaded mode, all cpus map to the same comp node
    static int version = 0;
    auto cn = CompNode::load("cpu0");
    mgb_assert(cn == CompNode::load("cpu1"));
    auto inst = StaticDeviceMemoryManager::make_default_impl();
    auto ret = inst->get_size(cn);
    inst->alloc(nullptr, cn, size, version).ptr();
    version = inst->version(nullptr);
    return ret;
}
#endif
/* ========================== CallbackCaller ========================== */

MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
                     SingleCNOperatorNodeBase) // {
    std::vector<ComputingGraph::Callback> m_cb;

    void scn_do_execute() override {
        auto&& dv = input(0)->dev_tensor();
        for (auto&& i : m_cb) {
            // const cast for backward API compatibility
            i(const_cast<DeviceTensorND&>(dv));
        }
    }

    void init_output_static_infer_desc() override {
        using namespace cg::static_infer;
        owner_graph()->static_infer_manager().register_shape_infer(
                output(0), ShapeInferDesc::make_const({}));
    }

    void add_input_layout_constraint() override {
        if (owner_graph()->options().comp_node_seq_record_level) {
            // the user callback usually copies from device to host, which
            // involves tmp alloc if input is not contiguous
            input(0)->add_layout_constraint_contiguous();
        }
    }

    NodeProp* do_make_node_prop() const override {
        auto ret = Super::do_make_node_prop();
        ret->add_dep_type_existing_var(input(0),
                                       NodeProp::DepType::VALUE_ALLOW_EMPTY);
        return ret;
    }

    bool update_priority() const override {
        node_prop().attribute().priority = std::numeric_limits<int>::min();
        return true;
    }

public:
    CallbackCaller(VarNode* inp)
            : Super{inp->owner_graph(), {}, "callback", {inp}} {
        add_input({inp});
        using F = VarNode::Flag;
        add_output(None)
                ->add_flag(F::ALLOW_EMPTY_SHAPE)
                .add_flag(F::VOLATILE_CONTENT);
    }

    static SymbolVar make(SymbolVar inp) {
        return inp.insert_single_output_opr<CallbackCaller>(inp.node());
    }

    void add_callback(const ComputingGraph::Callback& cb) {
        mgb_assert(cb);
        m_cb.push_back(cb);
    }

    void clear_callback() { m_cb.clear(); }
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ComputingGraphImpl::CallbackCaller);
/* ========================== ComputingGraphImpl ========================== */

ComputingGraphImpl::Components::Components(ComputingGraphImpl* owner)
        : topo_sorter{owner},
          var_node_mem_manager{owner},
          seq_comp_node_opt{owner},
          static_infer_manager{owner},
          static_infer_comp_seq_manager{owner},
          grad_manager{owner},
#if MGB_ENABLE_SUBLINEAR
          seq_modifier_for_sublinear_memory{
                  owner, &(owner->options().sublinear_mem_config)},
#endif
#if MGB_ENABLE_MEMORY_SWAP
          memory_swap_support{owner},
#endif
          eager_eval_manager{owner}
{
}
ComputingGraphImpl::ComputingGraphImpl() {
    auto ptr = new (&m_components_storage) Components{this};
    mgb_assert(ptr == &components());
}

ComputingGraphImpl::~ComputingGraphImpl() {
    if (!is_finalized()) {
        cleanup();
    }
}

std::shared_ptr<void> ComputingGraphImpl::on_comp_node_finalize() {
    // hold a reference because the object itself may be deleted by user data
    // or oprs
    std::shared_ptr<void> ref = shared_from_this();
    cleanup();
    return ref;
}

void ComputingGraphImpl::cleanup() {
    if (m_recorded_seq_level2_dtor_chk) {
        m_recorded_seq_level2_dtor_chk->enable();
    }
    // clear device memory storage and return them to comp node
    clear_device_memory();
    // so opr dtors would incur no overhead when deleting vars
    m_var_node_pool.disable_freelist();
    // TODO: call this after each graph exec when we have faster impl
    CompNode::try_coalesce_all_free_memory();

    options().user_data.clear_all_user_data();
    components().~Components();
    m_var_receiver.clear();
    m_opr_refkeeper.clear();
}
OperatorNodeBase* ComputingGraphImpl::insert_opr(
        std::unique_ptr<OperatorNodeBase> opr_uniqp) {
    auto opr = opr_uniqp.get();

    if (opr->inserted_in_graph()) {
        // FIXME: it's just a trick used for re-evaluation in eager evaluation
        // mode. Since comp_graph has already taken an ownership of the opr,
        // we can release it directly.
        mgb_throw_if(
#if MGB_BUILD_SLIM_SERVING
                true,
#else
                !options().eager_evaluation,
#endif
                GraphError,
                "an inserted opr %s was re-inserted into graph "
                "with eager evaluation mode OFF.",
                opr->cname());
        opr_uniqp.release();
        // No need to do the insert_post under eager mode
        eager_eval_manager().on_opr_insert(opr);
        return opr;
    }

    auto&& infer_mgr = static_infer_manager_impl();
    auto cleanup = [&]() {
        infer_mgr.set_register_allowed_opr(nullptr);
        for (auto i : opr->output()) {
            infer_mgr.clear_tag_handler(i);
            var_node_mem_manager().remove_var_node_mem_trait(i);
        }
    };

    if (auto ret = graph_optimizer().insert_pre(opr)) {
        bool should_update_shape = true;
#if !MGB_BUILD_SLIM_SERVING
        // in normal mode, we update the shape in deduplication in case shape
        // changes; in eager evaluation mode, shape is set by EagerEvalManager
        // and should not be modified
        should_update_shape = !options().eager_evaluation;
#endif
        if (should_update_shape) {
            update_output_shapes(infer_mgr, ret, false);
        }
        cleanup();
        event().signal_inplace<cg::event::OprInserted>(true, ret, nullptr);
        ret = graph_optimizer().insert_post(ret);
        eager_eval_manager().on_opr_insert(ret);
        return ret;
    }

    // record opr early, since exceptions may refer to the opr
    m_opr_refkeeper.emplace_back(std::move(opr_uniqp));

    MGB_TRY {
        mgb_assert(!opr->inserted_in_graph());
        mgb_assert(!opr->output().empty(),
                   "operator must have at least one output");
        opr->set_inserted_in_graph();

        // basic init
        opr->init_output_comp_node();
        opr->init_output_dtype();
        opr->init_output_format();

        // check output initialized
        for (auto i : opr->output()) {
            mgb_assert(i->comp_node().valid() && i->dtype().valid());
        }

        // register static infer
        {
            auto old = infer_mgr.set_register_allowed_opr(opr);
            opr->init_output_static_infer_desc();
            infer_mgr.set_register_allowed_opr(old);
        }

        // more init
        opr->init_rt_force_dynamic_mem_alloc_imply_chain();

        // freeze output flag and static infer shape eagerly
        update_output_shapes(infer_mgr, opr, true);

        check_opr_not_cross_mem(opr);
    }
    MGB_CATCH(MegBrainError& exc, {
        cleanup();
        if (!exc.extra_info())
            OperatorNodeExcExtraInfo::record(opr, exc);
        event().signal_inplace<cg::event::OprInserted>(false, opr, &exc);
        throw;
    })

    // add to receiver list if above succeeds
    for (auto&& i : opr->input()) {
        auto iter = m_var_receiver.find(i);
        mgb_assert(iter != m_var_receiver.end());
        auto&& arr = iter->second;
        if (arr.empty() || arr.back() != opr) {
            // check if added, because opr may have identical inputs
            arr.push_back(opr);
        }
    }

    // alloc var receiver for the outputs
    for (auto&& i : opr->output()) {
        bool em = m_var_receiver[i].empty();
        mgb_assert(em);
    }

    event().signal_inplace<cg::event::OprInserted>(false, opr, nullptr);
    opr = graph_optimizer().insert_post(opr);
    eager_eval_manager().on_opr_insert(opr);
    return opr;
}
std::shared_ptr<ComputingGraph> ComputingGraph::make() {
    return std::make_shared<ComputingGraphImpl>();
}

std::unique_ptr<AsyncExecutable> ComputingGraphImpl::compile(
        const OutputSpec& out_spec) {
    return compile_commit(compile_prepare(out_spec));
}

SmallVector<std::unique_ptr<AsyncExecutable>>
ComputingGraphImpl::compile_multi_part(
        const SmallVector<OutputSpec>& out_specs) {
#if MGB_ENABLE_PARTIAL_EXECUTION
    return MultiPartCompiler{this}.compile(out_specs);
#else
    mgb_throw(MegBrainError, "partial execution disabled at compile time");
#endif
}
ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
        const OutputSpec& out_spec) {
    auto&& cmpnt = components();

    mgb_throw_if(m_recorded_seq_level2_dtor_chk, GraphError,
                 "graphs with comp_node_seq_record_level==2 can only be "
                 "compiled once");
    mgb_throw_if(out_spec.empty(), GraphError,
                 "empty output spec given to ComputingGraph::compile");

    // topo sorter may have modified opr properties; restore them before this
    // new compiling
    topo_sorter().restore_opr_prop();
    cmpnt.seq_comp_node_opt.restore_comp_nodes();

    SpecialOprStat sopr_stat;
    auto dest_vars = get_dest_vars_from_out_spec(out_spec, sopr_stat);

#if MGB_ENABLE_SUBLINEAR
    if (options().enable_sublinear_memory_opt) {
        if (!sopr_stat.has_virtual_grad) {
            mgb_log_warn(
                    "no virtual grad var; sublinear memory may produce "
                    "unsatisfying result");
        }
        seq_modifier_for_sublinear_memory().set_priority_before_opt(dest_vars);
    }
#else
    mgb_assert(!options().enable_sublinear_memory_opt);
#endif  // MGB_ENABLE_SUBLINEAR

#if !MGB_BUILD_SLIM_SERVING
    mgb_assert(!options().eager_evaluation,
               "attempt to compile eager_evaluation graph");
    {
        bool need_opt = std::abs(options().graph_opt_level) >= 2;
        gopt::GraphOptimizer optimizer;
        optimizer.verbosity(options().log_level);
        optimizer.enable_check_result(options().graph_opt_level < 0);
        if (sopr_stat.has_virtual_grad) {
            if (need_opt)
                optimizer.add_preset_passes(false, nullptr, &options());
            optimizer.add_pass<gopt::ExpandVirtualGradPass>();
        }
        if (need_opt)
            optimizer.add_preset_passes(true, nullptr, &options());
        optimizer.apply_inplace(dest_vars);
    }
#endif

#if MGB_ENABLE_TENSOR_RT
    if (options().graph_opt.tensorrt) {
        options().graph_opt.tensorrt = false;
        tensorrt::transform_dest_vars_inplace(dest_vars);
    }
#endif

    if (options().graph_opt.winograd_transform) {
        options().graph_opt.winograd_transform = false;
        gopt::transform_vars_inplace_with_winograd(dest_vars);
    }

    if (options().graph_opt.transform_chwn4()) {
        gopt::GraphOptimizer optimizer;
        optimizer.apply_optimize_options(options().graph_opt);
        options().graph_opt.layout_transform =
                cg::GraphCommonOptimizeOptions::LayoutTransform::DEFAULT;
        optimizer.apply_inplace(dest_vars);
    }

#if MGB_JIT
    if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) {
        setenv("MGB_JIT_BACKEND", "NVRTC", 1);
        gopt::GraphOptimizer optimizer;
        optimizer.add_pass<gopt::JITFusionPass>(
                sopr_stat.has_virtual_grad,
                std::max<uint8_t>(options().graph_opt.jit, 1));
        optimizer.apply_inplace(dest_vars);
    }
#endif

    const OprNodeArray* opr_seq = nullptr;
    CompSeqExtraInfo extra_info;
    cmpnt.seq_comp_node_opt.optimize_comp_nodes(dest_vars);

    auto init_opr_seq = [&]() {
        ThinHashMap<VarNode*, CallbackCaller*> var2cb_caller;
        for (size_t i = 0; i < out_spec.size(); ++i) {
            auto&& cb = out_spec[i].second;
            if (cb) {
                auto var = dest_vars[i];
                auto&& cb_caller = var2cb_caller[var];
                if (!cb_caller) {
                    auto dvar = CallbackCaller::make(var);
                    cb_caller = &dvar.node()
                                         ->owner_opr()
                                         ->cast_final_safe<CallbackCaller>();
                    ++extra_info.var2recvinfo[dvar.node()].nr_direct_comp_req;
                    cb_caller->clear_callback();
                }
                cb_caller->add_callback(cb);
                dest_vars[i] = cb_caller->output(0);
            }
        }
        opr_seq = topo_sorter().get_comp_seq(extra_info, dest_vars);
    };

#if MGB_ENABLE_MEMORY_SWAP
    bool enable_swap_memory_after_sublinear =
            options().enable_sublinear_memory_opt &&
            options().enable_memory_swap;
    bool enable_swap_memory_without_sublinear =
            !(options().enable_sublinear_memory_opt) &&
            options().enable_memory_swap;
    if (enable_swap_memory_without_sublinear) {
        components().memory_swap_support.modify_dest_var_inplace(dest_vars);
    }
#else
    mgb_assert(!options().enable_memory_swap);
#endif

#if MGB_ENABLE_SUBLINEAR
    if (options().enable_sublinear_memory_opt) {
        MGB_TRY {
            seq_modifier_for_sublinear_memory().modify_endpoint_vars(dest_vars);
#if MGB_ENABLE_MEMORY_SWAP
            if (enable_swap_memory_after_sublinear) {
                cmpnt.memory_swap_support.modify_dest_var_inplace(dest_vars);
            }
#endif
            init_opr_seq();
        }
        MGB_FINALLY(
                /*
                 * restore graph option immediately because it may be
                 * read/modified by user
                 */
                seq_modifier_for_sublinear_memory().restore_graph_option());
        seq_modifier_for_sublinear_memory().sanity_check(*opr_seq);
    } else {
        init_opr_seq();
    }
#else
    init_opr_seq();
#endif  // MGB_ENABLE_SUBLINEAR

    return {std::move(extra_info), opr_seq};
}
std::unique_ptr<AsyncExecutable> ComputingGraphImpl::compile_commit(
        CompileState state) {
    auto comp_seq = std::make_unique<ComputingSequence>(shared_from_this());
    comp_seq->extra_info = std::move(state.extra_info);
    auto opr_seq = state.opr_seq;
    auto&& cmpnt = components();

    comp_seq->setup_opr_seq(opr_seq);
    for (auto&& i : *opr_seq) {
        for (auto&& j : i->node_prop().dep_map()) {
            if (OperatorNodeBase::NodeProp::is_device_value_dep(j.second)) {
                comp_seq->extra_info.var2recvinfo.at(j.first)
                        .last_dev_value_reader = i;
            }
        }
    }
    comp_seq->attach_to_graph();

    MGB_TRY {
        var_node_mem_manager().reset_opr_seq(comp_seq->extra_info, opr_seq);
        static_infer_comp_seq_manager().reset_dest(comp_seq->extra_info);
        cmpnt.seq_comp_node_opt.init_ready_event(comp_seq->extra_info,
                                                 *opr_seq);
        if (options().allocate_static_mem_after_graph_compile)
            var_node_mem_manager().alloc_var_node_mem_static();
    }
    MGB_FINALLY({ var_node_mem_manager().on_graph_compile_finished(); });

    event().signal_inplace<event::CompSeqOrderDetermined>(this,
                                                          comp_seq.get());

    if (options().comp_node_seq_record_level > 1) {
        mgb_assert(options().comp_node_seq_record_level <= 2,
                   "invalid comp_node_seq_record_level: %u",
                   options().comp_node_seq_record_level);
        mgb_assert(!options().fake_next_exec &&
                           !options().var_sanity_check_first_run,
                   "both fake_next_exec and var_sanity_check_first_run "
                   "must be false when comp_node_seq_record_level is 2");
        return comp_seq->as_recorded_seq();
    }
    return comp_seq;
}
VarNodeArray ComputingGraphImpl::get_dest_vars_from_out_spec(
        const OutputSpec& spec, SpecialOprStat& sopr_stat) {
    SymbolVarArray sym_vars;
    for (auto&& i : spec) {
        sym_vars.push_back(i.first);
    }
    return to_var_node_array(
            get_dest_vars_with_extra_deps(sym_vars, &sopr_stat));
}

const ComputingGraph::VarReceiverInfo&
ComputingGraphImpl::var_receiver_in_current_comp_seq(const VarNode* var) const {
    static VarReceiverInfo empty;
    if (auto ret = components().eager_eval_manager.var_receiver_info(var)) {
        return *ret;
    }
    if (!m_current_comp_seq)
        return empty;
    auto cseq = static_cast<ComputingSequence*>(m_current_comp_seq);
    auto iter = cseq->extra_info.var2recvinfo.find(var);
    if (iter == cseq->extra_info.var2recvinfo.end())
        return empty;
    return iter->second;
}
VarNode* ComputingGraphImpl::find_var_by_id(size_t id) const {
    for (auto&& i : m_opr_refkeeper) {
        for (auto j : i->output()) {
            if (j->id() == id)
                return j;
        }
    }
    for (auto&& i : m_subgraphs) {
        auto sub = i->find_var_by_id(id);
        if (sub)
            return sub;
    }
    return nullptr;
}

#if MGB_ENABLE_SUBLINEAR
SeqModifierForSublinearMemory&
ComputingGraphImpl::seq_modifier_for_sublinear_memory() {
    return components().seq_modifier_for_sublinear_memory;
}
#endif
void ComputingGraphImpl::share_device_memory_with(ComputingGraph& other) {
    mgb_assert(
            !m_current_comp_seq,
            "share_device_memory_with must be called before compiling graph");
    auto&& oimpl = static_cast<ComputingGraphImpl&>(other);
    var_node_mem_manager().static_device_memory_manager(
            oimpl.var_node_mem_manager().static_device_memory_manager());
}

void ComputingGraphImpl::set_device_memory_allocator(
        std::shared_ptr<DeviceMemoryAllocator> allocator) {
    var_node_mem_manager().static_device_memory_manager()->set_allocator(
            std::move(allocator));
}

size_t ComputingGraphImpl::get_device_memory_size(CompNode cn) {
    return var_node_mem_manager().static_device_memory_manager()->get_size(cn);
}

size_t ComputingGraphImpl::clear_device_memory() {
#if !MGB_BUILD_SLIM_SERVING
    if (options().eager_evaluation) {
        for (auto& opr : m_opr_refkeeper) {
            if (!opr->same_type<mgb::opr::SharedDeviceTensor>() &&
                !opr->same_type<mgb::opr::ImmutableTensor>()) {
                for (auto& var : opr->output()) {
                    if (var->mem_plan().valid())
                        var->mem_plan().release_chunk();
                }
            }
        }
    }
#endif
    return var_node_mem_manager().clear_static_device_memory();
}
void ComputingGraphImpl::set_as_subgraph(ComputingGraph& par_graph) {
    m_parent_graph = static_cast<ComputingGraphImpl*>(&par_graph);
    m_parent_graph->m_subgraphs.emplace_back(this);
    m_node_id_counter = m_parent_graph->m_node_id_counter;
    options().var_sanity_check_first_run =
            par_graph.options().var_sanity_check_first_run;
    par_graph.event().signal_inplace<event::SubgraphAssociated>(&par_graph,
                                                                this);
}

void ComputingGraphImpl::record_async_error(
        std::unique_ptr<MegBrainError> async_exc) {
    mgb_assert(m_current_comp_seq);
    static_cast<ComputingSequence*>(m_current_comp_seq)
            ->set_async_error(std::move(async_exc));
}

const CompSeqExtraInfo& ComputingGraphImpl::current_comp_seq_extra_info() {
    if (auto ret = eager_eval_manager().comp_seq_extra_info()) {
        return *ret;
    }
    mgb_assert(m_current_comp_seq);
    return static_cast<ComputingSequence*>(m_current_comp_seq)->extra_info;
}

GraphExecutable::ExecEnv* ComputingGraphImpl::current_exec_env() {
    if (auto ret = eager_eval_manager().exec_env()) {
        return ret;
    }
    if (m_current_comp_seq) {
        return &static_cast<ComputingSequence*>(m_current_comp_seq)->exec_env();
    }
    return nullptr;
}

Maybe<size_t> ComputingGraphImpl::opr_step_num_in_cur_comp_seq(
        OperatorNodeBase* opr) {
    mgb_assert(m_current_comp_seq && opr->owner_graph() == this);
    return static_cast<ComputingSequence*>(m_current_comp_seq)
            ->opr2stepnum(opr);
}
std::string ComputingGraphImpl::VarReceiverInfo::to_string() const {
    return mgb_ssprintf_log(
            "VarReceiverInfo("
            "nr_direct_comp_req=%zu dev_value=%zu, host_value=%zu, shape=%zu, "
            "allow_empty_value=%zu)",
            nr_direct_comp_req, dev_value, host_value, shape,
            allow_empty_value);
}
std::string ComputingGraphImpl::get_mem_allocation_info() const {
#if MGB_ENABLE_JSON
    auto make_var_json = [](VarNode* single_var) {
        auto&& cur_mem_plan = single_var->mem_plan();
        if (cur_mem_plan.valid())
            return json::Object::make(
                    {{"name", json::String::make(single_var->name())},
                     {"memory", json::Number::make(cur_mem_plan.chunk().size())},
                     {"dev_ptr",
                      json::NumberInt::make(reinterpret_cast<size_t>(
                              single_var->dev_tensor().raw_ptr()))}});
        else
            return json::Object::make(
                    {{"name", json::String::make(single_var->name())},
                     {"memory", json::Null::make()},
                     {"dev_ptr", json::Null::make()}});
    };

    auto objlist = json::Array::make();
    for (auto& opri : m_opr_refkeeper) {
        auto cur_opr = opri.get();
        auto objptr = json::Object::make();
        auto&& objbody = *objptr;
        objbody["name"] = json::String::make(cur_opr->name());

        auto jvars = json::Array::make();
        for (auto& outputi : cur_opr->output()) {
            jvars->add(make_var_json(outputi));
        }
        objbody["output"] = jvars;

        auto obj = json::Object::make({{std::to_string(cur_opr->id()), objptr}});
        objlist->add(obj);
    }
    return objlist->to_string();
#endif  // MGB_ENABLE_JSON
    mgb_log_warn("mgb is not configured with MGB_ENABLE_JSON on; "
                 "get_mem_allocation_info returns an empty string");
    return std::string();
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
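
For context, below is a minimal usage sketch (not part of cg_impl.cpp) of the public API implemented above: ComputingGraph::make(), compile() with an OutputSpec entry whose callback becomes a CallbackCaller opr in compile_prepare(), and the resulting AsyncExecutable. The Host2DeviceCopy opr, the header paths, and the "xpu0" comp node name are assumptions based on the public MegBrain API rather than anything defined in this file.

// Sketch only: build a tiny graph, compile it with one output callback,
// then run the compiled sequence. Assumes the public MegBrain headers.
#include "megbrain/graph.h"
#include "megbrain/opr/io.h"                   // opr::Host2DeviceCopy (assumed)
#include "megbrain/opr/basic_arith_wrapper.h"  // SymbolVar arithmetic (assumed)

using namespace mgb;

void run_sketch() {
    // creates a ComputingGraphImpl under the hood (see ComputingGraph::make)
    auto graph = cg::ComputingGraph::make();
    auto host_x = std::make_shared<HostTensorND>(
            CompNode::load("xpu0"), TensorShape{4}, dtype::Float32());
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    auto y = x + x;

    HostTensorND host_y;
    // the {var, callback} pair is one OutputSpec entry consumed by compile();
    // the callback receives the DeviceTensorND produced for y
    auto func = graph->compile(
            {{y, [&](DeviceTensorND& dv) { host_y.copy_from(dv); }}});
    func->execute();
    func->wait();
}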

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.