cg_impl.cpp
/**
 * \file src/core/impl/graph/cg_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./cg_impl.h"
#include "./cg_impl_partial.h"
#include "./cg_impl_seq.h"

#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/graph/cg.h"
#include "megbrain/graph/event.h"
#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/utility.h"

#if MGB_ENABLE_TENSOR_RT
#include "megbrain/tensorrt/opr_replace.h"
#endif

#if MGB_JIT
#include "megbrain/jit/fusion_pass.h"
#endif

using namespace mgb;
using namespace cg;
namespace {

void check_opr_not_cross_mem(OperatorNodeBase* opr) {
    if (opr->node_prop().contain(
                OperatorNodeBase::NodeProp::Flag::CROSS_COMP_NODE_MEMORY))
        return;

    MemNode mem_node_id;
    bool first = true;
    auto check = [&](VarNode* var) {
        auto cur = var->comp_node().mem_node();
        mgb_assert(cur);
        if (first) {
            first = false;
            mem_node_id = cur;
        } else
            mgb_assert(mem_node_id == cur,
                       "for non cross-memory oprs, "
                       "all vars should reside on the same memory node");
    };
    for (auto i : opr->input()) {
        check(i);
    }
    for (auto i : opr->output()) {
        check(i);
    }
}

void update_output_shapes(static_infer::StaticInferManagerImpl& infer_mgr,
                          OperatorNodeBase* opr, bool add_freeze_flag) {
    for (auto i : opr->output()) {
        if (add_freeze_flag) {
            i->add_flag(VarNode::Flag::FLAG_FREEZED);
        }
        if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            using namespace static_infer;
            if (infer_mgr.get_infer_type(i).shape &
                (InferType::CONST | InferType::RT_STATIC)) {
                auto shp = infer_mgr.infer_shape_fallible(i);
                if (shp) {
                    i->shape(*shp);
                } else {
                    i->shape({});
                }
            } else {
                i->shape({});
            }
        }
    }
}

}  // anonymous namespace
/* ========================== global helpers ========================== */

void cg::update_output_var_shapes(OperatorNodeBase* opr) {
    update_output_shapes(static_cast<static_infer::StaticInferManagerImpl&>(
                                 opr->owner_graph()->static_infer_manager()),
                         opr, false);
}

/* ========================= DeviceMemoryAllocator ========================= */

void DeviceMemoryAllocator::alloc_static(ComputingGraph*,
                                         DeviceTensorStorage& dest,
                                         size_t size) {
    dest.ensure_size(size);
}

void DeviceMemoryAllocator::alloc_dynamic(VarNode*, DeviceTensorStorage& dest,
                                          size_t size) {
    dest.ensure_size(size);
}

void DeviceMemoryAllocator::defrag_prealloc_contig(ComputingGraph* graph,
                                                   CompNode comp_node,
                                                   size_t size) {
    MGB_TRY { comp_node.free_device(comp_node.alloc_device(size)); }
    MGB_CATCH(MemAllocError&, {})
}

size_t DeviceMemoryAllocator::static_alloc_version(ComputingGraph*) const {
    return 0;
}
/* ========================== ComputingGraph ========================== */

ComputingGraph::ComputingGraph() {
    static std::atomic_size_t tot_id{0};
    m_id = (tot_id++);
}

void ComputingGraph::assert_destroy(std::shared_ptr<ComputingGraph>& ptr) {
    mgb_assert(ptr.use_count() == 1, "unexpected use_count: %zu",
               size_t(ptr.use_count()));
    ptr.reset();
}

#if !MGB_THREAD_SAFE
size_t ComputingGraph::prealloc_static_storage(size_t size) {
    // note that in single-threaded mode, all cpus map to the same comp node
    static int version = 0;
    auto cn = CompNode::load("cpu0");
    mgb_assert(cn == CompNode::load("cpu1"));
    auto inst = StaticDeviceMemoryManager::make_default_impl();
    auto ret = inst->get_size(cn);
    inst->alloc(nullptr, cn, size, version).ptr();
    version = inst->version(nullptr);
    return ret;
}
#endif

/* ========================== CallbackCaller ========================== */
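// An internal operator that wraps user-provided output callbacks: it takes the
// requested vars as inputs and, when executed, passes each input's device
// tensor to the callbacks registered for that input.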
MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
                     SingleCNOperatorNodeBase) // {
    std::vector<std::vector<ComputingGraph::Callback>> m_cb;

    void scn_do_execute() override {
        for (size_t i = 0; i < input().size(); ++i) {
            auto&& in = input(i)->dev_tensor();
            for (auto&& callback : m_cb[i]) {
                // const cast for backward API compatibility
                callback(const_cast<DeviceTensorND&>(in));
            }
        }
    }

    void init_output_static_infer_desc() override {
        using namespace cg::static_infer;
        owner_graph()->static_infer_manager().register_shape_infer(
                output(0), ShapeInferDesc::make_const({}));
    }

    void add_input_layout_constraint() override {
        if (owner_graph()->options().comp_node_seq_record_level) {
            // the user callback usually copies from device to host, which
            // involves tmp alloc if input is not contiguous
            for (auto&& inp : input()) {
                inp->add_layout_constraint_contiguous();
            }
        }
    }

    void init_output_dtype() override {
        if (output(0)->dtype().valid()) {
            return;
        }
        mgb_assert(!input().empty());
        DType dtype = input(0)->dtype();
        mgb_assert(dtype.valid() && dtype != dtype::Byte());
        output(0)->dtype(dtype);
    }

    NodeProp* do_make_node_prop() const override {
        auto ret = Super::do_make_node_prop();
        for (auto&& inp : input()) {
            ret->add_dep_type_existing_var(
                    inp, NodeProp::DepType::VALUE_ALLOW_EMPTY);
        }
        return ret;
    }

    bool update_priority() const override {
        node_prop().attribute().priority = std::numeric_limits<int>::min();
        return true;
    }

public:
    CallbackCaller(const VarNodeArrayView& inp)
            : Super{inp[0]->owner_graph(), {}, "callback", inp} {
        mgb_assert(!inp.empty());
        m_cb.resize(inp.size());
        for (auto&& i : inp) {
            add_input({i});
        }
        using F = VarNode::Flag;
        add_output(None)
                ->add_flag(F::ALLOW_EMPTY_SHAPE)
                .add_flag(F::VOLATILE_CONTENT);
    }

    static SymbolVar make(const VarNodeArrayView& inp) {
        mgb_assert(!inp.empty());
        return SymbolVar{inp[0]}
                .node()
                ->owner_graph()
                ->insert_opr(std::make_unique<CallbackCaller>(inp))
                ->output(0);
    }

    void add_callback(const ComputingGraph::Callback& cb, size_t i = 0) {
        mgb_assert(cb && i < m_cb.size());
        m_cb[i].push_back(cb);
    }

    void clear_callback() {
        for (size_t i = 0; i < m_cb.size(); ++i) {
            m_cb[i].clear();
        }
    }
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ComputingGraphImpl::CallbackCaller);
/* ========================== ComputingGraphImpl ========================== */

ComputingGraphImpl::Components::Components(ComputingGraphImpl* owner)
        : topo_sorter{owner},
          var_node_mem_manager{owner},
          seq_comp_node_opt{owner},
          static_infer_manager{owner},
          static_infer_comp_seq_manager{owner},
          grad_manager{owner},
#if MGB_ENABLE_SUBLINEAR
          seq_modifier_for_sublinear_memory{owner,
                                            &(owner->options().sublinear_mem_config)},
#endif
#if MGB_ENABLE_MEMORY_SWAP
          memory_swap_support{owner},
#endif
          eager_eval_manager{owner}
{
}

ComputingGraphImpl::ComputingGraphImpl() {
    auto ptr = new (&m_components_storage) Components{this};
    mgb_assert(ptr == &components());
}

ComputingGraphImpl::~ComputingGraphImpl() {
    if (!is_finalized()) {
        cleanup();
    }
}

std::shared_ptr<void> ComputingGraphImpl::on_comp_node_finalize() {
    // hold a reference because the object itself may be deleted by user data or
    // oprs
    std::shared_ptr<void> ref = shared_from_this();
    cleanup();
    return ref;
}

void ComputingGraphImpl::cleanup() {
    if (m_recorded_seq_level2_dtor_chk) {
        m_recorded_seq_level2_dtor_chk->enable();
    }
    // clear device memory storage and return them to comp node
    clear_device_memory();
    // so opr dtors would incur no overhead when deleting vars
    m_var_node_pool.disable_freelist();
    // TODO: call this after each graph exec when we have faster impl
    CompNode::try_coalesce_all_free_memory();
    options().user_data.clear_all_user_data();
    components().~Components();
    m_var_receiver.clear();
    m_opr_refkeeper.clear();
}
void* ComputingGraphImpl::alloc_varnode_storage() {
    return m_var_node_pool.alloc_raw();
}

void ComputingGraphImpl::free_varnode_storage(void* ptr) {
    m_var_node_pool.free_raw(ptr);
}
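// Insert an operator into the graph. The imperative proxy-graph path only does
// basic output initialization; otherwise graph_optimizer().insert_pre() is
// tried first for deduplication, and a genuinely new opr gets full output and
// static-infer initialization, eager shape freezing, and OprInserted events.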
OperatorNodeBase* ComputingGraphImpl::insert_opr(
        std::unique_ptr<OperatorNodeBase> opr_uniqp) {
    auto opr = opr_uniqp.get();

    if (options().imperative_proxy_graph) {
        if (!opr->inserted_in_graph()) {
            m_opr_refkeeper.emplace_back(std::move(opr_uniqp));
            opr->set_inserted_in_graph();
            opr->init_output_comp_node();
            opr->init_output_dtype();
            opr->init_output_format();
            // register static infer
            {
                auto&& mgr = static_infer_manager_impl();
                auto old = mgr.set_register_allowed_opr(opr);
                opr->init_output_static_infer_desc();
                mgr.set_register_allowed_opr(old);
            }
        }
        return opr;
    }

    if (opr->inserted_in_graph()) {
        // FIXME: it's just a trick used for re-evaluation in eager evaluation
        // mode. Since comp_graph has already taken an ownership of the opr,
        // we can release it directly.
        mgb_throw_if(
#if MGB_BUILD_SLIM_SERVING
                true,
#else
                !options().eager_evaluation,
#endif
                GraphError,
                "an inserted opr %s re-inserted into graph "
                "with eager evaluation mode OFF.",
                opr->cname());
        opr_uniqp.release();
        // No need to do the insert_post under eager mode
        eager_eval_manager().on_opr_insert(opr);
        return opr;
    }

    auto&& infer_mgr = static_infer_manager_impl();
    auto cleanup = [&]() {
        infer_mgr.set_register_allowed_opr(nullptr);
        for (auto i : opr->output()) {
            infer_mgr.clear_tag_handler(i);
            var_node_mem_manager().remove_var_node_mem_trait(i);
        }
    };

    if (auto ret = graph_optimizer().insert_pre(opr)) {
        bool should_update_shape = true;
#if !MGB_BUILD_SLIM_SERVING
        // in normal mode, we update the shape in deduplication in case shape
        // changes; in eager evaluation mode, shape is set by EagerEvalManager
        // and should not be modified
        should_update_shape = !options().eager_evaluation;
#endif
        if (should_update_shape) {
            update_output_shapes(infer_mgr, ret, false);
        }
        cleanup();
        event().signal_inplace<cg::event::OprInserted>(true, ret, nullptr);
        ret = graph_optimizer().insert_post(ret);
        eager_eval_manager().on_opr_insert(ret);
        return ret;
    }

    // record opr early, since exceptions may refer to the opr
    m_opr_refkeeper.emplace_back(std::move(opr_uniqp));

    MGB_TRY {
        mgb_assert(!opr->inserted_in_graph());
        mgb_assert(!opr->output().empty(),
                   "operator must have at least one output");
        opr->set_inserted_in_graph();

        // basic init
        opr->init_output_comp_node();
        opr->init_output_dtype();
        opr->init_output_format();

        // check output initialized
        for (auto i : opr->output()) {
            mgb_assert(i->comp_node().valid() && i->dtype().valid());
        }

        // register static infer
        {
            auto old = infer_mgr.set_register_allowed_opr(opr);
            opr->init_output_static_infer_desc();
            infer_mgr.set_register_allowed_opr(old);
        }

        // more init
        opr->init_rt_force_dynamic_mem_alloc_imply_chain();

        // freeze output flag and static infer shape eagerly
        update_output_shapes(infer_mgr, opr, true);
        check_opr_not_cross_mem(opr);
    }
    MGB_CATCH(MegBrainError& exc, {
        cleanup();
        if (!exc.extra_info())
            OperatorNodeExcExtraInfo::record(opr, exc);
        event().signal_inplace<cg::event::OprInserted>(false, opr, &exc);
        throw;
    })

    // add to receiver list if above succeeds
    for (auto&& i : opr->input()) {
        auto iter = m_var_receiver.find(i);
        mgb_assert(iter != m_var_receiver.end());
        auto&& arr = iter->second;
        if (arr.empty() || arr.back() != opr) {
            // check if added, because opr may have identical inputs
            arr.push_back(opr);
        }
    }

    // alloc var receiver for the outputs
    for (auto&& i : opr->output()) {
        bool em = m_var_receiver[i].empty();
        mgb_assert(em);
    }

    event().signal_inplace<cg::event::OprInserted>(false, opr, nullptr);
    opr = graph_optimizer().insert_post(opr);
    eager_eval_manager().on_opr_insert(opr);
    return opr;
}
std::shared_ptr<ComputingGraph> ComputingGraph::make() {
    return std::make_shared<ComputingGraphImpl>();
}

std::unique_ptr<AsyncExecutable> ComputingGraphImpl::compile(
        const OutputSpec& out_spec) {
    return compile_commit(compile_prepare(out_spec));
}

SmallVector<std::unique_ptr<AsyncExecutable>>
ComputingGraphImpl::compile_multi_part(
        const SmallVector<OutputSpec>& out_specs) {
#if MGB_ENABLE_PARTIAL_EXECUTION
    return MultiPartCompiler{this}.compile(out_specs);
#else
    mgb_throw(MegBrainError, "partial execution disabled at compile time");
#endif
}
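// First phase of compilation: restore opr properties, run the configured
// optimization passes on the requested output vars, attach CallbackCaller oprs
// for the output callbacks, and build the topologically sorted operator
// sequence together with its CompSeqExtraInfo.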
ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
        const OutputSpec& out_spec) {
    auto&& cmpnt = components();
    mgb_throw_if(m_recorded_seq_level2_dtor_chk, GraphError,
                 "graphs with comp_node_seq_record_level==2 can only be "
                 "compiled once");
    mgb_throw_if(out_spec.empty(), GraphError,
                 "empty output spec given to ComputingGraph::compile");

    // topo sorter may have modified opr properties; restore them before this
    // new compiling
    topo_sorter().restore_opr_prop();
    cmpnt.seq_comp_node_opt.restore_comp_nodes();

    SpecialOprStat sopr_stat;
    auto dest_vars = get_dest_vars_from_out_spec(out_spec, sopr_stat);

#if MGB_ENABLE_SUBLINEAR
    if (options().enable_sublinear_memory_opt) {
        if (!sopr_stat.has_virtual_grad) {
            mgb_log_debug(
                    "no virtual grad var; sublinear memory may produce "
                    "unsatisfying result");
        }
        seq_modifier_for_sublinear_memory().set_priority_before_opt(
                dest_vars);
    }
#else
    mgb_assert(!options().enable_sublinear_memory_opt);
#endif  // MGB_ENABLE_SUBLINEAR

#if !MGB_BUILD_SLIM_SERVING
    mgb_assert(!options().eager_evaluation,
               "attempt to compile eager_evaluation graph");
    {
        bool need_opt = std::abs(options().graph_opt_level) >= 2;
        gopt::GraphOptimizer optimizer;
        optimizer.verbosity(options().log_level);
        optimizer.enable_check_result(options().graph_opt_level < 0);
        if (sopr_stat.has_virtual_grad) {
            if (need_opt) {
#if MGB_ENABLE_OPR_MM
                optimizer.add_pass<gopt::PackAllReduceScanPass>();
#endif
                optimizer.add_preset_passes(false, nullptr, &options());
            }
            optimizer.add_pass<gopt::ExpandVirtualGradPass>();
        }
        if (need_opt) {
            optimizer.add_preset_passes(true, nullptr, &options());
#if MGB_ENABLE_OPR_MM
            if (sopr_stat.has_virtual_grad) {
                optimizer.add_pass<gopt::PackAllReduceReplacePass>();
            }
#endif
        }
        optimizer.apply_inplace(dest_vars);
    }
#endif

#if MGB_ENABLE_TENSOR_RT
    if (options().graph_opt.tensorrt) {
        options().graph_opt.tensorrt = false;
        tensorrt::transform_dest_vars_inplace(dest_vars, options().graph_opt);
    }
#endif

#if MGB_JIT
    if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) {
        setenv("MGB_JIT_BACKEND", "NVRTC", 1);
        gopt::GraphOptimizer optimizer;
        optimizer.add_pass<gopt::JITFusionPass>(
                sopr_stat.has_virtual_grad,
                std::max<uint8_t>(options().graph_opt.jit, 1));
        optimizer.apply_inplace(dest_vars);
    }
#endif

    gopt::GraphOptimizer optimizer;
    /**
     * \note We should reset the options when adding the passes indicated by
     * the optimize options, because the `ParamFuse` pass compiles a subgraph,
     * which may otherwise cause circular invocation; \see
     * https://git-core.megvii-inc.com/brain-sdk/MegBrain/merge_requests/1717
     * for details
     */
    optimizer.add_passes_for_optimize_options(options().graph_opt, true);
    optimizer.apply_inplace(dest_vars);

    if (sopr_stat.has_shape_hint) {
        // FIXME(zhangxuanrun): strictly speaking, ShapeHints could and should
        // be removed even when they occur in a subgraph
        mgb_assert(!m_parent_graph, "can not use ShapeHint in subgraph");
        // always need to remove shape hints
        gopt::GraphOptimizer opt;
        opt.add_pass<gopt::RemoveShapeHintPass>();
        opt.apply_inplace(dest_vars);
    }
    const OprNodeArray* opr_seq = nullptr;
    CompSeqExtraInfo extra_info;
    cmpnt.seq_comp_node_opt.optimize_comp_nodes(dest_vars);

    auto init_opr_seq = [&]() {
        ThinHashMap<VarNode*, size_t> var2idx;
        std::unordered_map<CallbackCallerKey, CallbackCallerVal,
                           CallbackCallerKey::Hash>
                opr2vars;
        using F = VarNode::Flag;
        if (dest_vars[0]->owner_graph()->options().force_output_dynamic_alloc) {
            for (auto&& i : dest_vars) {
                if (!i->contain_flag(F::NO_SYS_MEM_ALLOC |
                                     F::NO_SYS_STATIC_MEM_ALLOC)) {
                    mgb_assert(
                            !i->contain_flag(
                                    F::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC),
                            "Can not force graph output dynamic alloc with "
                            "DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC flag, var: %s",
                            i->cname());
                    i->add_flag(F::NO_SYS_STATIC_MEM_ALLOC);
                }
                i->add_flag(F::NO_MEM_RECLAIM);
            }
        }
        for (size_t i = 0; i < out_spec.size(); ++i) {
            auto&& cb = out_spec[i].second;
            if (cb) {
                auto var = dest_vars[i];
                CallbackCallerKey key{var->owner_opr(), var->comp_node()};
                auto&& vals = opr2vars[key];
                auto&& var2idx_iter = var2idx.find(var);
                if (var2idx_iter == var2idx.end()) {
                    vals.vars.push_back(var);
                    vals.indexs.push_back({i});
                    var2idx[var] = vals.vars.size() - 1;
                } else {
                    vals.indexs[var2idx_iter->second].push_back(i);
                }
            }
        }
        for (auto& item : opr2vars) {
            auto&& val = item.second;
            auto dvar = CallbackCaller::make(val.vars);
            CallbackCaller* cb_caller = &dvar.node()
                                                 ->owner_opr()
                                                 ->cast_final_safe<CallbackCaller>();
            ++extra_info.var2recvinfo[dvar.node()].nr_direct_comp_req;
            cb_caller->clear_callback();
            for (size_t i = 0; i < val.vars.size(); ++i) {
                for (auto&& idx : val.indexs[i]) {
                    cb_caller->add_callback(out_spec[idx].second, i);
                    dest_vars[idx] = cb_caller->output(0);
                }
            }
        }
        opr_seq = topo_sorter().get_comp_seq(extra_info, dest_vars);
    };
#if MGB_ENABLE_MEMORY_SWAP
    bool enable_swap_memory_after_sublinear =
            options().enable_sublinear_memory_opt &&
            options().enable_memory_swap;
    bool enable_swap_memory_without_sublinear =
            !(options().enable_sublinear_memory_opt) &&
            options().enable_memory_swap;
    if (enable_swap_memory_without_sublinear) {
        components().memory_swap_support.modify_dest_var_inplace(dest_vars);
    }
#else
    mgb_assert(!options().enable_memory_swap);
#endif

#if MGB_ENABLE_SUBLINEAR
    if (options().enable_sublinear_memory_opt) {
        MGB_TRY {
            seq_modifier_for_sublinear_memory().modify_endpoint_vars(
                    dest_vars);
#if MGB_ENABLE_MEMORY_SWAP
            if (enable_swap_memory_after_sublinear) {
                cmpnt.memory_swap_support.modify_dest_var_inplace(dest_vars);
            }
#endif
            init_opr_seq();
        }
        MGB_FINALLY(
                /*
                 * restore graph option immediately because it may be
                 * read/modified by user
                 */
                seq_modifier_for_sublinear_memory().restore_graph_option());
        seq_modifier_for_sublinear_memory().sanity_check(*opr_seq);
    } else {
        init_opr_seq();
    }
#else
    init_opr_seq();
#endif  // MGB_ENABLE_SUBLINEAR

    return {std::move(extra_info), opr_seq, std::move(dest_vars)};
}
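// Second phase of compilation: wrap the prepared operator sequence into a
// ComputingSequence, reset the memory and static-infer managers for it, and
// return it (as a recorded sequence when comp_node_seq_record_level == 2).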
std::unique_ptr<AsyncExecutable> ComputingGraphImpl::compile_commit(
        CompileState state) {
    auto comp_seq = std::make_unique<ComputingSequence>(shared_from_this());
    comp_seq->extra_info = std::move(state.extra_info);
    comp_seq->set_output_vars(state.dest_vars);
    auto opr_seq = state.opr_seq;
    auto&& cmpnt = components();
    comp_seq->setup_opr_seq(opr_seq);
    for (auto&& i : *opr_seq) {
        for (auto&& j : i->node_prop().dep_map()) {
            if (OperatorNodeBase::NodeProp::is_device_value_dep(j.second)) {
                comp_seq->extra_info.var2recvinfo.at(j.first)
                        .last_dev_value_reader = i;
            }
        }
    }
    comp_seq->attach_to_graph();

    MGB_TRY {
        var_node_mem_manager().reset_opr_seq(comp_seq->extra_info, opr_seq);
        static_infer_comp_seq_manager().reset_dest(comp_seq->extra_info);
        cmpnt.seq_comp_node_opt.init_ready_event(comp_seq->extra_info, *opr_seq);
        if (options().allocate_static_mem_after_graph_compile)
            var_node_mem_manager().alloc_var_node_mem_static();
    }
    MGB_FINALLY({ var_node_mem_manager().on_graph_compile_finished(); });

    event().signal_inplace<event::CompSeqOrderDetermined>(this, comp_seq.get());

    if (options().comp_node_seq_record_level > 1) {
        mgb_assert(options().comp_node_seq_record_level <= 2,
                   "invalid comp_node_seq_record_level: %u",
                   options().comp_node_seq_record_level);
        mgb_assert(!options().fake_next_exec &&
                           !options().var_sanity_check_first_run,
                   "both fake_next_exec and var_sanity_check_first_run "
                   "must be false when comp_node_seq_record_level is 2");
        return comp_seq->as_recorded_seq();
    }
    return comp_seq;
}
VarNodeArray ComputingGraphImpl::get_dest_vars_from_out_spec(
        const OutputSpec& spec, SpecialOprStat& sopr_stat) {
    SymbolVarArray sym_vars;
    for (auto&& i : spec) {
        sym_vars.push_back(i.first);
    }
    return to_var_node_array(
            get_dest_vars_with_extra_deps(sym_vars, &sopr_stat));
}

const ComputingGraph::VarReceiverInfo&
ComputingGraphImpl::var_receiver_in_current_comp_seq(const VarNode* var) const {
    static VarReceiverInfo empty;
    if (auto ret = components().eager_eval_manager.var_receiver_info(var)) {
        return *ret;
    }
    if (!m_current_comp_seq)
        return empty;
    auto cseq = static_cast<ComputingSequence*>(m_current_comp_seq);
    auto iter = cseq->extra_info.var2recvinfo.find(var);
    if (iter == cseq->extra_info.var2recvinfo.end())
        return empty;
    return iter->second;
}

VarNode* ComputingGraphImpl::find_var_by_id(size_t id) const {
    for (auto&& i : m_opr_refkeeper) {
        for (auto j : i->output()) {
            if (j->id() == id)
                return j;
        }
    }
    for (auto&& i : m_subgraphs) {
        auto sub = i->find_var_by_id(id);
        if (sub)
            return sub;
    }
    return nullptr;
}
#if MGB_ENABLE_SUBLINEAR
SeqModifierForSublinearMemory&
ComputingGraphImpl::seq_modifier_for_sublinear_memory() {
    return components().seq_modifier_for_sublinear_memory;
}
#endif

void ComputingGraphImpl::share_device_memory_with(ComputingGraph& other) {
    mgb_assert(
            !m_current_comp_seq,
            "share_device_memory_with must be called before compiling graph");
    auto&& oimpl = *ComputingGraphImpl::downcast(&other);
    var_node_mem_manager().static_device_memory_manager(
            oimpl.var_node_mem_manager().static_device_memory_manager());
}

void ComputingGraphImpl::set_device_memory_allocator(
        std::shared_ptr<DeviceMemoryAllocator> allocator) {
    var_node_mem_manager().static_device_memory_manager()->set_allocator(
            std::move(allocator));
}

size_t ComputingGraphImpl::get_device_memory_size(CompNode cn) {
    return var_node_mem_manager().static_device_memory_manager()->get_size(cn);
}

size_t ComputingGraphImpl::clear_device_memory() {
#if !MGB_BUILD_SLIM_SERVING
    if (options().eager_evaluation) {
        for (auto& opr : m_opr_refkeeper) {
            if (!opr->same_type<mgb::opr::SharedDeviceTensor>() &&
                !opr->same_type<mgb::opr::ImmutableTensor>()) {
                for (auto& var : opr->output()) {
                    if (var->mem_plan().valid())
                        var->mem_plan().release_chunk();
                }
            }
        }
    }
#endif
    return var_node_mem_manager().clear_static_device_memory();
}
void ComputingGraphImpl::set_as_subgraph(ComputingGraph& par_graph) {
    m_parent_graph = ComputingGraphImpl::downcast(&par_graph);
    m_parent_graph->m_subgraphs.emplace_back(this);
    m_node_id_counter = m_parent_graph->m_node_id_counter;
    options().var_sanity_check_first_run =
            par_graph.options().var_sanity_check_first_run;
    par_graph.event().signal_inplace<event::SubgraphAssociated>(&par_graph,
                                                                this);
}

void ComputingGraphImpl::record_async_error(
        std::unique_ptr<MegBrainError> async_exc) {
    mgb_assert(m_current_comp_seq);
    static_cast<ComputingSequence*>(m_current_comp_seq)
            ->set_async_error(std::move(async_exc));
}

const CompSeqExtraInfo& ComputingGraphImpl::current_comp_seq_extra_info() {
    if (auto ret = eager_eval_manager().comp_seq_extra_info()) {
        return *ret;
    }
    mgb_assert(m_current_comp_seq);
    return static_cast<ComputingSequence*>(m_current_comp_seq)->extra_info;
}

GraphExecutable::ExecEnv* ComputingGraphImpl::current_exec_env() {
    if (auto ret = eager_eval_manager().exec_env()) {
        return ret;
    }
    if (m_current_comp_seq) {
        return &static_cast<ComputingSequence*>(m_current_comp_seq)->exec_env();
    }
    return nullptr;
}

Maybe<size_t> ComputingGraphImpl::opr_step_num_in_cur_comp_seq(
        OperatorNodeBase* opr) {
    mgb_assert(m_current_comp_seq && opr->owner_graph() == this);
    return static_cast<ComputingSequence*>(m_current_comp_seq)
            ->opr2stepnum(opr);
}

std::string ComputingGraphImpl::VarReceiverInfo::to_string() const {
    return mgb_ssprintf_log(
            "VarReceiverInfo("
            "nr_direct_comp_req=%zu dev_value=%zu, host_value=%zu, shape=%zu, "
            "allow_empty_value=%zu)",
            nr_direct_comp_req, dev_value, host_value, shape,
            allow_empty_value);
}
std::string ComputingGraphImpl::get_mem_allocation_info() const {
#if MGB_ENABLE_JSON
    auto make_var_json = [](VarNode* single_var) {
        auto&& cur_mem_plan = single_var->mem_plan();
        if (cur_mem_plan.valid())
            return json::Object::make({
                    {"name", json::String::make(single_var->name())},
                    {"memory", json::Number::make(cur_mem_plan.chunk().size())},
                    {"dev_ptr", json::NumberInt::make(reinterpret_cast<size_t>(
                                        single_var->dev_tensor().raw_ptr()))}
            });
        else
            return json::Object::make({
                    {"name", json::String::make(single_var->name())},
                    {"memory", json::Null::make()},
                    {"dev_ptr", json::Null::make()}
            });
    };

    auto objlist = json::Array::make();
    for (auto& opri : m_opr_refkeeper) {
        auto cur_opr = opri.get();
        auto objptr = json::Object::make();
        auto&& objbody = *objptr;
        objbody["name"] = json::String::make(cur_opr->name());
        auto jvars = json::Array::make();
        for (auto& outputi : cur_opr->output()) {
            jvars->add(make_var_json(outputi));
        }
        objbody["output"] = jvars;
        auto obj = json::Object::make({{std::to_string(cur_opr->id()), objptr}});
        objlist->add(obj);
    }
    return objlist->to_string();
#endif  // MGB_ENABLE_JSON
    mgb_log_warn("target is not configured with JSON build on; "
                 "get_mem_allocation_info returns null string");
    return std::string();
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
