
cg_impl.cpp 31 kB

/**
 * \file src/core/impl/graph/cg_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./cg_impl.h"
#include "./cg_impl_partial.h"
#include "./cg_impl_seq.h"

#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/graph/cg.h"
#include "megbrain/graph/event.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/utility.h"

#if MGB_ENABLE_TENSOR_RT
#include "megbrain/tensorrt/opr_replace.h"
#endif

#if MGB_JIT
#include "megbrain/jit/fusion_pass.h"
#endif

using namespace mgb;
using namespace cg;
namespace {
void check_opr_not_cross_mem(OperatorNodeBase* opr) {
    if (opr->node_prop().contain(
                OperatorNodeBase::NodeProp::Flag::CROSS_COMP_NODE_MEMORY))
        return;

    MemNode mem_node_id;
    bool first = true;
    auto check = [&](VarNode* var) {
        auto cur = var->comp_node().mem_node();
        mgb_assert(cur);
        if (first) {
            first = false;
            mem_node_id = cur;
        } else
            mgb_assert(mem_node_id == cur,
                       "for non cross-memory oprs, "
                       "all vars should reside on the same memory node");
    };
    for (auto i : opr->input()) {
        check(i);
    }
    for (auto i : opr->output()) {
        check(i);
    }
}

void update_output_shapes(static_infer::StaticInferManagerImpl& infer_mgr,
                          OperatorNodeBase* opr, bool add_freeze_flag) {
    for (auto i : opr->output()) {
        if (add_freeze_flag) {
            i->add_flag(VarNode::Flag::FLAG_FREEZED);
        }
        if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            using namespace static_infer;
            if (infer_mgr.get_infer_type(i).shape &
                (InferType::CONST | InferType::RT_STATIC)) {
                auto shp = infer_mgr.infer_shape_fallible(i);
                if (shp) {
                    i->shape(*shp);
                } else {
                    i->shape({});
                }
            } else {
                i->shape({});
            }
        }
    }
}
}  // anonymous namespace
/* ========================== global helpers ========================== */
void cg::update_output_var_shapes(OperatorNodeBase* opr) {
    update_output_shapes(static_cast<static_infer::StaticInferManagerImpl&>(
                                 opr->owner_graph()->static_infer_manager()),
                         opr, false);
}

/* ========================= DeviceMemoryAllocator ========================= */
void DeviceMemoryAllocator::alloc_static(ComputingGraph*,
                                         DeviceTensorStorage& dest,
                                         size_t size) {
    dest.ensure_size(size);
}

void DeviceMemoryAllocator::alloc_dynamic(VarNode*, DeviceTensorStorage& dest,
                                          size_t size) {
    dest.ensure_size(size);
}

void DeviceMemoryAllocator::defrag_prealloc_contig(ComputingGraph* graph,
                                                   CompNode comp_node,
                                                   size_t size) {
    MGB_TRY { comp_node.free_device(comp_node.alloc_device(size)); }
    MGB_CATCH(MemAllocError&, {})
}

size_t DeviceMemoryAllocator::static_alloc_version(ComputingGraph*) const {
    return 0;
}
/* ========================== ComputingGraph ========================== */
ComputingGraph::ComputingGraph() {
    static std::atomic_size_t tot_id{0};
    m_id = (tot_id++);
}

void ComputingGraph::assert_destroy(std::shared_ptr<ComputingGraph>& ptr) {
    mgb_assert(ptr.use_count() == 1, "unexpected use_count: %zu",
               size_t(ptr.use_count()));
    ptr.reset();
}

#if !MGB_THREAD_SAFE
size_t ComputingGraph::prealloc_static_storage(size_t size) {
    // note that in single-threaded mode, all cpus map to the same comp node
    static int version = 0;
    auto cn = CompNode::load("cpu0");
    mgb_assert(cn == CompNode::load("cpu1"));
    auto inst = StaticDeviceMemoryManager::make_default_impl();
    auto ret = inst->get_size(cn);
    inst->alloc(nullptr, cn, size, version).ptr();
    version = inst->version(nullptr);
    return ret;
}
#endif
/* ========================== CallbackCaller ========================== */
MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
                     SingleCNOperatorNodeBase) // {
    std::vector<std::vector<ComputingGraph::Callback>> m_cb;

    void scn_do_execute() override {
        for (size_t i = 0; i < input().size(); ++i) {
            auto&& in = input(i)->dev_tensor();
            for (auto&& callback : m_cb[i]) {
                // const cast for backward API compatibility
                callback(const_cast<DeviceTensorND&>(in));
            }
        }
    }

    void init_output_static_infer_desc() override {
        using namespace cg::static_infer;
        owner_graph()->static_infer_manager().register_shape_infer(
                output(0), ShapeInferDesc::make_const({}));
    }

    void add_input_layout_constraint() override {
        if (owner_graph()->options().comp_node_seq_record_level) {
            // the user callback usually copies from device to host, which
            // involves tmp alloc if input is not contiguous
            for (auto&& inp : input()) {
                inp->add_layout_constraint_contiguous();
            }
        }
    }

    void init_output_dtype() override {
        if (output(0)->dtype().valid()) {
            return;
        }
        mgb_assert(!input().empty());
        DType dtype = input(0)->dtype();
        mgb_assert(dtype.valid() && dtype != dtype::Byte());
        output(0)->dtype(dtype);
    }

    NodeProp* do_make_node_prop() const override {
        auto ret = Super::do_make_node_prop();
        for (auto&& inp : input()) {
            ret->add_dep_type_existing_var(
                    inp, NodeProp::DepType::VALUE_ALLOW_EMPTY);
        }
        return ret;
    }

    bool update_priority() const override {
        node_prop().attribute().priority = std::numeric_limits<int>::min();
        return true;
    }

public:
    CallbackCaller(const VarNodeArrayView& inp)
            : Super{inp[0]->owner_graph(), {}, "callback", inp} {
        mgb_assert(!inp.empty());
        m_cb.resize(inp.size());
        for (auto&& i : inp) {
            add_input({i});
        }
        using F = VarNode::Flag;
        add_output(None)
                ->add_flag(F::ALLOW_EMPTY_SHAPE)
                .add_flag(F::VOLATILE_CONTENT);
    }

    static SymbolVar make(const VarNodeArrayView& inp) {
        mgb_assert(!inp.empty());
        return SymbolVar{inp[0]}
                .node()
                ->owner_graph()
                ->insert_opr(std::make_unique<CallbackCaller>(inp))
                ->output(0);
    }

    void add_callback(const ComputingGraph::Callback& cb, size_t i = 0) {
        mgb_assert(cb && i < m_cb.size());
        m_cb[i].push_back(cb);
    }

    void clear_callback() {
        for (size_t i = 0; i < m_cb.size(); ++i) {
            m_cb[i].clear();
        }
    }
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ComputingGraphImpl::CallbackCaller);
/* ========================== ComputingGraphImpl ========================== */
ComputingGraphImpl::Components::Components(ComputingGraphImpl* owner)
        : topo_sorter{owner},
          var_node_mem_manager{owner},
          seq_comp_node_opt{owner},
          static_infer_manager{owner},
          static_infer_comp_seq_manager{owner},
          grad_manager{owner},
#if MGB_ENABLE_SUBLINEAR
          seq_modifier_for_sublinear_memory{
                  owner, &(owner->options().sublinear_mem_config)},
#endif
#if MGB_ENABLE_DTR
          seq_modifier_for_dtr{owner, &(owner->options().dtr_config)},
#endif
#if MGB_ENABLE_MEMORY_SWAP
          memory_swap_support{owner},
#endif
          eager_eval_manager{owner} {
}

ComputingGraphImpl::ComputingGraphImpl() {
    auto ptr = new (&m_components_storage) Components{this};
    mgb_assert(ptr == &components());
}

ComputingGraphImpl::~ComputingGraphImpl() {
    if (!is_finalized()) {
        cleanup();
    }
}

std::shared_ptr<void> ComputingGraphImpl::on_comp_node_finalize() {
    // hold a reference because the object itself may be deleted by user data
    // or oprs
    std::shared_ptr<void> ref = shared_from_this();
    cleanup();
    return ref;
}
void ComputingGraphImpl::cleanup() {
    if (m_recorded_seq_level2_dtor_chk) {
        m_recorded_seq_level2_dtor_chk->enable();
    }
    // clear device memory storage and return them to comp node
    clear_device_memory();

    // so opr dtors would incur no overhead when deleting vars
    m_var_node_pool.disable_freelist();

    // TODO: call this after each graph exec when we have faster impl
    CompNode::try_coalesce_all_free_memory();

    options().user_data.clear_all_user_data();
    components().~Components();
    m_var_receiver.clear();
    m_opr_refkeeper.clear();
}

void* ComputingGraphImpl::alloc_varnode_storage() {
    return m_var_node_pool.alloc_raw();
}

void ComputingGraphImpl::free_varnode_storage(void* ptr) {
    m_var_node_pool.free_raw(ptr);
}
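
// insert_opr() below is the single entry point through which operators enter
// the graph. In imperative_proxy_graph mode the opr is merely initialized and
// kept alive; re-insertion of an already inserted opr is only tolerated in
// eager evaluation mode; otherwise graph_optimizer().insert_pre() may return a
// deduplicated opr (signalled as OprInserted with is_dedup == true), while a
// genuinely new opr gets its outputs initialized, its static shape inference
// registered, its output shapes frozen, and its var-receiver entries allocated
// before insert_post() and the eager-eval manager get a chance to run.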
OperatorNodeBase* ComputingGraphImpl::insert_opr(
        std::unique_ptr<OperatorNodeBase> opr_uniqp) {
    auto opr = opr_uniqp.get();

    if (options().imperative_proxy_graph) {
        if (!opr->inserted_in_graph()) {
            m_opr_refkeeper.emplace_back(std::move(opr_uniqp));
            opr->set_inserted_in_graph();
            opr->init_output_comp_node();
            opr->init_output_dtype();
            opr->init_output_format();
            // register static infer
            {
                auto&& mgr = static_infer_manager_impl();
                auto old = mgr.set_register_allowed_opr(opr);
                opr->init_output_static_infer_desc();
                mgr.set_register_allowed_opr(old);
            }
        }
        return opr;
    }

    if (opr->inserted_in_graph()) {
        // FIXME: it's just a trick used for re-evaluation in eager evaluation
        // mode. Since comp_graph has already taken ownership of the opr,
        // we can release it directly.
        mgb_throw_if(
#if MGB_BUILD_SLIM_SERVING
                true,
#else
                !options().eager_evaluation,
#endif
                GraphError,
                "an inserted opr %s is re-inserted into the graph "
                "with eager evaluation mode OFF.",
                opr->cname());
        opr_uniqp.release();
        // no need to do the insert_post under eager mode
        eager_eval_manager().on_opr_insert(opr);
        return opr;
    }

    auto&& infer_mgr = static_infer_manager_impl();
    auto cleanup = [&]() {
        infer_mgr.set_register_allowed_opr(nullptr);
        for (auto i : opr->output()) {
            infer_mgr.clear_tag_handler(i);
            var_node_mem_manager().remove_var_node_mem_trait(i);
        }
    };

    if (auto ret = graph_optimizer().insert_pre(opr)) {
        bool should_update_shape = true;
#if !MGB_BUILD_SLIM_SERVING
        // in normal mode, we update the shape in deduplication in case shape
        // changes; in eager evaluation mode, shape is set by EagerEvalManager
        // and should not be modified
        should_update_shape = !options().eager_evaluation;
#endif
        if (should_update_shape) {
            update_output_shapes(infer_mgr, ret, false);
        }
        cleanup();
        event().signal_inplace<cg::event::OprInserted>(true, ret, nullptr);
        ret = graph_optimizer().insert_post(ret);
        eager_eval_manager().on_opr_insert(ret);
        return ret;
    }

    // record opr early, since exceptions may refer to the opr
    m_opr_refkeeper.emplace_back(std::move(opr_uniqp));

    MGB_TRY {
        mgb_assert(!opr->inserted_in_graph());
        mgb_assert(!opr->output().empty(),
                   "operator must have at least one output");
        opr->set_inserted_in_graph();

        // basic init
        opr->init_output_comp_node();
        opr->init_output_dtype();
        opr->init_output_format();

        // check output initialized
        for (auto i : opr->output()) {
            mgb_assert(i->comp_node().valid() && i->dtype().valid());
        }

        // register static infer
        {
            auto old = infer_mgr.set_register_allowed_opr(opr);
            opr->init_output_static_infer_desc();
            infer_mgr.set_register_allowed_opr(old);
        }

        // more init
        opr->init_rt_force_dynamic_mem_alloc_imply_chain();

        // freeze output flag and static infer shape eagerly
        update_output_shapes(infer_mgr, opr, true);
        check_opr_not_cross_mem(opr);
    }
    MGB_CATCH(MegBrainError& exc, {
        cleanup();
        if (!exc.extra_info())
            OperatorNodeExcExtraInfo::record(opr, exc);
        event().signal_inplace<cg::event::OprInserted>(false, opr, &exc);
        throw;
    })

    // add to receiver list if above succeeds
    for (auto&& i : opr->input()) {
        auto iter = m_var_receiver.find(i);
        mgb_assert(iter != m_var_receiver.end());
        auto&& arr = iter->second;
        if (arr.empty() || arr.back() != opr) {
            // check if added, because opr may have identical inputs
            arr.push_back(opr);
        }
    }

    // alloc var receiver for the outputs
    for (auto&& i : opr->output()) {
        bool em = m_var_receiver[i].empty();
        mgb_assert(em);
    }

    event().signal_inplace<cg::event::OprInserted>(false, opr, nullptr);
    opr = graph_optimizer().insert_post(opr);
    eager_eval_manager().on_opr_insert(opr);
    return opr;
}
std::shared_ptr<ComputingGraph> ComputingGraph::make() {
    return std::make_shared<ComputingGraphImpl>();
}

std::unique_ptr<AsyncExecutable> ComputingGraphImpl::compile(
        const OutputSpec& out_spec) {
    return compile_commit(compile_prepare(out_spec));
}

SmallVector<std::unique_ptr<AsyncExecutable>>
ComputingGraphImpl::compile_multi_part(
        const SmallVector<OutputSpec>& out_specs) {
#if MGB_ENABLE_PARTIAL_EXECUTION
    return MultiPartCompiler{this}.compile(out_specs);
#else
    mgb_throw(MegBrainError, "partial execution disabled at compile time");
#endif
}
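
// Compilation is split into two phases (see compile() above): compile_prepare()
// rewrites dest_vars in place -- graph optimization passes, optional TensorRT /
// JIT replacement, and sublinear / DTR / memory-swap sequence modification --
// wraps the user callbacks from the OutputSpec into CallbackCaller oprs, and
// finally obtains the operator sequence from the topo sorter; compile_commit()
// then builds the ComputingSequence, wires up var2recvinfo and the memory
// managers and, if requested, allocates static memory.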
ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
        const OutputSpec& out_spec) {
    auto&& cmpnt = components();
    mgb_throw_if(m_recorded_seq_level2_dtor_chk, GraphError,
                 "graphs with comp_node_seq_record_level==2 can only be "
                 "compiled once");
    mgb_throw_if(out_spec.empty(), GraphError,
                 "empty output spec given to ComputingGraph::compile");

    // topo sorter may have modified opr properties; restore them before this
    // new compiling
    topo_sorter().restore_opr_prop();
    cmpnt.seq_comp_node_opt.restore_comp_nodes();

    SpecialOprStat sopr_stat;
    auto dest_vars = get_dest_vars_from_out_spec(out_spec, sopr_stat);

#if MGB_ENABLE_SUBLINEAR
    if (options().enable_sublinear_memory_opt) {
        mgb_assert(!options().enable_dtr_memory_opt);
        if (!sopr_stat.has_virtual_grad) {
            mgb_log_debug(
                    "no virtual grad var; sublinear memory may produce "
                    "unsatisfying result");
        }
        seq_modifier_for_sublinear_memory().set_priority_before_opt(dest_vars);
    }
#else
    mgb_assert(!options().enable_sublinear_memory_opt);
#endif  // MGB_ENABLE_SUBLINEAR

#if MGB_ENABLE_DTR
    if (options().enable_dtr_memory_opt) {
        mgb_assert(!options().enable_sublinear_memory_opt);
        seq_modifier_for_dtr().set_priority_before_opt(dest_vars);
    }
#else
    mgb_assert(!options().enable_dtr_memory_opt);
#endif  // MGB_ENABLE_DTR

#if !MGB_BUILD_SLIM_SERVING
    mgb_assert(!options().eager_evaluation,
               "attempt to compile eager_evaluation graph");
    {
        bool need_opt = std::abs(options().graph_opt_level) >= 2;
        gopt::GraphOptimizer optimizer;
        optimizer.verbosity(options().log_level);
        optimizer.enable_check_result(options().graph_opt_level < 0);
        if (sopr_stat.has_virtual_grad) {
            if (need_opt) {
#if MGB_ENABLE_OPR_MM
                optimizer.add_pass<gopt::PackAllReduceScanPass>();
#endif
                optimizer.add_preset_passes(false, nullptr, &options());
            }
            optimizer.add_pass<gopt::ExpandVirtualGradPass>();
        }
        if (need_opt) {
            optimizer.add_preset_passes(true, nullptr, &options());
#if MGB_ENABLE_OPR_MM
            if (sopr_stat.has_virtual_grad) {
                optimizer.add_pass<gopt::PackAllReduceReplacePass>();
            }
#endif
        }
        optimizer.apply_inplace(dest_vars);
    }
#endif

#if MGB_ENABLE_TENSOR_RT
    if (options().graph_opt.tensorrt) {
        options().graph_opt.tensorrt = false;
        tensorrt::transform_dest_vars_inplace(dest_vars, options().graph_opt);
    }
#endif

#if MGB_JIT
    if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) {
        setenv("MGB_JIT_BACKEND", "NVRTC", 1);
        gopt::GraphOptimizer optimizer;
        optimizer.add_pass<gopt::JITFusionPass>(
                sopr_stat.has_virtual_grad,
                std::max<uint8_t>(options().graph_opt.jit, 1));
        optimizer.apply_inplace(dest_vars);
    }
#endif
    gopt::GraphOptimizer optimizer;
    /**
     * \note Options should be reset when adding the passes indicated by the
     * optimize options, because the ParamFuse pass compiles a subgraph, which
     * could otherwise lead to recursive invocation; \see
     * https://git-core.megvii-inc.com/brain-sdk/MegBrain/merge_requests/1717
     * for details
     */
    optimizer.add_passes_for_optimize_options(options().graph_opt, true);
    optimizer.apply_inplace(dest_vars);

    if (sopr_stat.has_shape_hint) {
        // FIXME(zhangxuanrun): strictly speaking, ShapeHints would also have
        // to be removed even when they occur in a subgraph
        mgb_assert(!m_parent_graph, "can not use ShapeHint in subgraph");
        // always need to remove shape hints
        gopt::GraphOptimizer opt;
        opt.add_pass<gopt::RemoveShapeHintPass>();
        opt.apply_inplace(dest_vars);
    }
    const OprNodeArray* opr_seq = nullptr;
    CompSeqExtraInfo extra_info;
    cmpnt.seq_comp_node_opt.optimize_comp_nodes(dest_vars);

    bool init_flag = false;
    auto init_opr_seq = [&]() {
        mgb_assert(!init_flag);
        init_flag = true;
        ThinHashMap<VarNode*, size_t> var2idx;
        std::unordered_map<CallbackCallerKey, CallbackCallerVal,
                           CallbackCallerKey::Hash>
                opr2vars;
        using F = VarNode::Flag;
        if (dest_vars[0]->owner_graph()->options().force_output_dynamic_alloc) {
            for (auto&& i : dest_vars) {
                if (!i->contain_flag(F::NO_SYS_MEM_ALLOC |
                                     F::NO_SYS_STATIC_MEM_ALLOC)) {
                    mgb_assert(
                            !i->contain_flag(
                                    F::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC),
                            "Can not force graph output dynamic alloc with "
                            "DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC flag, var: %s",
                            i->cname());
                    i->add_flag(F::NO_SYS_STATIC_MEM_ALLOC);
                }
                i->add_flag(F::NO_MEM_RECLAIM);
            }
        }
        for (size_t i = 0; i < out_spec.size(); ++i) {
            auto&& cb = out_spec[i].second;
            if (cb) {
                auto var = dest_vars[i];
                CallbackCallerKey key{var->owner_opr(), var->comp_node()};
                auto&& vals = opr2vars[key];
                auto&& var2idx_iter = var2idx.find(var);
                if (var2idx_iter == var2idx.end()) {
                    vals.vars.push_back(var);
                    vals.indexs.push_back({i});
                    var2idx[var] = vals.vars.size() - 1;
                } else {
                    vals.indexs[var2idx_iter->second].push_back(i);
                }
            }
        }
        for (auto& item : opr2vars) {
            auto&& val = item.second;
            auto dvar = CallbackCaller::make(val.vars);
            CallbackCaller* cb_caller =
                    &dvar.node()->owner_opr()->cast_final_safe<CallbackCaller>();
            ++extra_info.var2recvinfo[dvar.node()].nr_direct_comp_req;
            cb_caller->clear_callback();
            for (size_t i = 0; i < val.vars.size(); ++i) {
                for (auto&& idx : val.indexs[i]) {
                    cb_caller->add_callback(out_spec[idx].second, i);
                    dest_vars[idx] = cb_caller->output(0);
                }
            }
        }
        opr_seq = topo_sorter().get_comp_seq(extra_info, dest_vars);
    };

#if MGB_ENABLE_MEMORY_SWAP
    bool enable_swap_memory_after_sublinear =
            options().enable_sublinear_memory_opt &&
            options().enable_memory_swap;
    bool enable_swap_memory_without_sublinear =
            !(options().enable_sublinear_memory_opt) &&
            options().enable_memory_swap;
    if (enable_swap_memory_without_sublinear) {
        components().memory_swap_support.modify_dest_var_inplace(dest_vars);
    }
#else
    mgb_assert(!options().enable_memory_swap);
#endif

#if MGB_ENABLE_DTR
    if (options().enable_dtr_memory_opt) {
        MGB_TRY {
            seq_modifier_for_dtr().modify_endpoint_vars(dest_vars);
            init_opr_seq();
        }
        MGB_FINALLY(seq_modifier_for_dtr().restore_graph_option());
    }
#endif

#if MGB_ENABLE_SUBLINEAR
    if (options().enable_sublinear_memory_opt) {
        MGB_TRY {
            seq_modifier_for_sublinear_memory().modify_endpoint_vars(
                    dest_vars);
#if MGB_ENABLE_MEMORY_SWAP
            if (enable_swap_memory_after_sublinear) {
                cmpnt.memory_swap_support.modify_dest_var_inplace(dest_vars);
            }
#endif
            init_opr_seq();
        }
        MGB_FINALLY(
                /*
                 * restore graph option immediately because it may be
                 * read/modified by user
                 */
                seq_modifier_for_sublinear_memory().restore_graph_option());
        seq_modifier_for_sublinear_memory().sanity_check(*opr_seq);
    }
#endif  // MGB_ENABLE_SUBLINEAR

    if (!init_flag) {
        init_opr_seq();
    }
    return {std::move(extra_info), opr_seq, std::move(dest_vars)};
}
std::unique_ptr<AsyncExecutable> ComputingGraphImpl::compile_commit(
        CompileState state) {
    auto comp_seq = std::make_unique<ComputingSequence>(shared_from_this());
    comp_seq->extra_info = std::move(state.extra_info);
    comp_seq->set_output_vars(state.dest_vars);
    auto opr_seq = state.opr_seq;
    auto&& cmpnt = components();
    comp_seq->setup_opr_seq(opr_seq);
    for (auto&& i : *opr_seq) {
        for (auto&& j : i->node_prop().dep_map()) {
            if (OperatorNodeBase::NodeProp::is_device_value_dep(j.second)) {
                comp_seq->extra_info.var2recvinfo.at(j.first)
                        .last_dev_value_reader = i;
            }
        }
    }
    comp_seq->attach_to_graph();

    MGB_TRY {
        var_node_mem_manager().reset_opr_seq(comp_seq->extra_info, opr_seq);
        static_infer_comp_seq_manager().reset_dest(comp_seq->extra_info);
        cmpnt.seq_comp_node_opt.init_ready_event(comp_seq->extra_info,
                                                 *opr_seq);
        if (options().allocate_static_mem_after_graph_compile)
            var_node_mem_manager().alloc_var_node_mem_static();
    }
    MGB_FINALLY({ var_node_mem_manager().on_graph_compile_finished(); });
    event().signal_inplace<event::CompSeqOrderDetermined>(this, comp_seq.get());

    if (options().comp_node_seq_record_level > 1) {
        mgb_assert(options().comp_node_seq_record_level <= 2,
                   "invalid comp_node_seq_record_level: %u",
                   options().comp_node_seq_record_level);
        mgb_assert(!options().fake_next_exec &&
                           !options().var_sanity_check_first_run,
                   "both fake_next_exec and var_sanity_check_first_run "
                   "must be false when comp_node_seq_record_level is 2");
        return comp_seq->as_recorded_seq();
    }
    return comp_seq;
}
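
// Illustrative usage sketch (not part of this file): a client builds symbolic
// vars on a graph and compiles an OutputSpec whose callbacks receive the
// computed device tensors. The opr::Host2DeviceCopy input setup below is an
// assumption based on common MegBrain client code, not something defined here.
//
//     auto graph = ComputingGraph::make();
//     auto host_x = std::make_shared<HostTensorND>(
//             CompNode::load("xpu0"), TensorShape{4, 8}, dtype::Float32());
//     auto x = opr::Host2DeviceCopy::make(*graph, host_x);
//     auto y = x * 2 + 1;
//     HostTensorND host_y;
//     auto func = graph->compile(
//             {{y, [&](DeviceTensorND& dev) { host_y.copy_from(dev); }}});
//     func->execute().wait();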
VarNodeArray ComputingGraphImpl::get_dest_vars_from_out_spec(
        const OutputSpec& spec, SpecialOprStat& sopr_stat) {
    SymbolVarArray sym_vars;
    for (auto&& i : spec) {
        sym_vars.push_back(i.first);
    }
    return to_var_node_array(
            get_dest_vars_with_extra_deps(sym_vars, &sopr_stat));
}

const ComputingGraph::VarReceiverInfo&
ComputingGraphImpl::var_receiver_in_current_comp_seq(const VarNode* var) const {
    static VarReceiverInfo empty;
    if (auto ret = components().eager_eval_manager.var_receiver_info(var)) {
        return *ret;
    }
    if (!m_current_comp_seq)
        return empty;
    auto cseq = static_cast<ComputingSequence*>(m_current_comp_seq);
    auto iter = cseq->extra_info.var2recvinfo.find(var);
    if (iter == cseq->extra_info.var2recvinfo.end())
        return empty;
    return iter->second;
}

VarNode* ComputingGraphImpl::find_var_by_id(size_t id) const {
    for (auto&& i : m_opr_refkeeper) {
        for (auto j : i->output()) {
            if (j->id() == id)
                return j;
        }
    }
    for (auto&& i : m_subgraphs) {
        auto sub = i->find_var_by_id(id);
        if (sub)
            return sub;
    }
    return nullptr;
}
#if MGB_ENABLE_SUBLINEAR
SeqModifierForSublinearMemory&
ComputingGraphImpl::seq_modifier_for_sublinear_memory() {
    return components().seq_modifier_for_sublinear_memory;
}
#endif

#if MGB_ENABLE_DTR
SeqModifierForDTR& ComputingGraphImpl::seq_modifier_for_dtr() {
    return components().seq_modifier_for_dtr;
}
#endif

void ComputingGraphImpl::share_device_memory_with(ComputingGraph& other) {
    mgb_assert(
            !m_current_comp_seq,
            "share_device_memory_with must be called before compiling graph");
    auto&& oimpl = *ComputingGraphImpl::downcast(&other);
    var_node_mem_manager().static_device_memory_manager(
            oimpl.var_node_mem_manager().static_device_memory_manager());
}

void ComputingGraphImpl::set_device_memory_allocator(
        std::shared_ptr<DeviceMemoryAllocator> allocator) {
    var_node_mem_manager().static_device_memory_manager()->set_allocator(
            std::move(allocator));
}

size_t ComputingGraphImpl::get_device_memory_size(CompNode cn) {
    return var_node_mem_manager().static_device_memory_manager()->get_size(cn);
}

size_t ComputingGraphImpl::clear_device_memory() {
#if !MGB_BUILD_SLIM_SERVING
    if (options().eager_evaluation) {
        for (auto& opr : m_opr_refkeeper) {
            if (!opr->same_type<mgb::opr::SharedDeviceTensor>() &&
                !opr->same_type<mgb::opr::ImmutableTensor>()) {
                for (auto& var : opr->output()) {
                    if (var->mem_plan().valid())
                        var->mem_plan().release_chunk();
                }
            }
        }
    }
#endif
    return var_node_mem_manager().clear_static_device_memory();
}

void ComputingGraphImpl::set_as_subgraph(ComputingGraph& par_graph) {
    m_parent_graph = ComputingGraphImpl::downcast(&par_graph);
    m_parent_graph->m_subgraphs.emplace_back(this);
    m_node_id_counter = m_parent_graph->m_node_id_counter;
    options().var_sanity_check_first_run =
            par_graph.options().var_sanity_check_first_run;
    par_graph.event().signal_inplace<event::SubgraphAssociated>(&par_graph,
                                                                this);
}

void ComputingGraphImpl::record_async_error(
        std::unique_ptr<MegBrainError> async_exc) {
    mgb_assert(m_current_comp_seq);
    static_cast<ComputingSequence*>(m_current_comp_seq)
            ->set_async_error(std::move(async_exc));
}
const CompSeqExtraInfo& ComputingGraphImpl::current_comp_seq_extra_info() {
    if (auto ret = eager_eval_manager().comp_seq_extra_info()) {
        return *ret;
    }
    mgb_assert(m_current_comp_seq);
    return static_cast<ComputingSequence*>(m_current_comp_seq)->extra_info;
}

GraphExecutable::ExecEnv* ComputingGraphImpl::current_exec_env() {
    if (auto ret = eager_eval_manager().exec_env()) {
        return ret;
    }
    if (m_current_comp_seq) {
        return &static_cast<ComputingSequence*>(m_current_comp_seq)->exec_env();
    }
    return nullptr;
}

Maybe<size_t> ComputingGraphImpl::opr_step_num_in_cur_comp_seq(
        OperatorNodeBase* opr) {
    mgb_assert(m_current_comp_seq && opr->owner_graph() == this);
    return static_cast<ComputingSequence*>(m_current_comp_seq)
            ->opr2stepnum(opr);
}

std::string ComputingGraphImpl::VarReceiverInfo::to_string() const {
    return mgb_ssprintf_log(
            "VarReceiverInfo("
            "nr_direct_comp_req=%zu dev_value=%zu, host_value=%zu, shape=%zu, "
            "allow_empty_value=%zu)",
            nr_direct_comp_req, dev_value, host_value, shape,
            allow_empty_value);
}

std::string ComputingGraphImpl::get_mem_allocation_info() const {
#if MGB_ENABLE_JSON
    auto make_var_json = [](VarNode* single_var) {
        auto&& cur_mem_plan = single_var->mem_plan();
        if (cur_mem_plan.valid())
            return json::Object::make(
                    {{"name", json::String::make(single_var->name())},
                     {"memory", json::Number::make(cur_mem_plan.chunk().size())},
                     {"dev_ptr",
                      json::NumberInt::make(reinterpret_cast<size_t>(
                              single_var->dev_tensor().raw_ptr()))}});
        else
            return json::Object::make(
                    {{"name", json::String::make(single_var->name())},
                     {"memory", json::Null::make()},
                     {"dev_ptr", json::Null::make()}});
    };

    auto objlist = json::Array::make();
    for (auto& opri : m_opr_refkeeper) {
        auto cur_opr = opri.get();
        auto objptr = json::Object::make();
        auto&& objbody = *objptr;
        objbody["name"] = json::String::make(cur_opr->name());
        auto jvars = json::Array::make();
        for (auto& outputi : cur_opr->output()) {
            jvars->add(make_var_json(outputi));
        }
        objbody["output"] = jvars;
        auto obj =
                json::Object::make({{std::to_string(cur_opr->id()), objptr}});
        objlist->add(obj);
    }
    return objlist->to_string();
#endif  // MGB_ENABLE_JSON
    mgb_log_warn("target is not configured with JSON build enabled; "
                 "get_mem_allocation_info returns an empty string");
    return std::string();
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
