
proxy_graph.cpp 31 kB

/**
 * \file imperative/src/impl/proxy_graph.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./blob_manager_impl.h"
#include "./proxy_graph.h"
#include "megbrain/graph/static_infer.h"
#include "megbrain/graph/operator_node.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/backward_graph.h"

#if __cplusplus >= 201703L
#include <optional>
#endif
namespace mgb {
namespace imperative {

using cg::OperatorNodeBase;

template<bool p, typename T, typename F>
constexpr auto&& select(T&& t, F&& f) {
    if constexpr (p) {
        return std::forward<T>(t);
    } else {
        return std::forward<F>(f);
    }
}

MGB_DEFINE_OPR_CLASS(
        ProxyGraph::InputPlaceholder,
        cg::OperatorNodeBase) // {
    void on_output_comp_node_stream_changed() override {
        mgb_assert(0);
    }
    // TODO: consider implementing the following initialization methods,
    // so InputPlaceholder can be initialized correctly during
    // operator insertion
    void init_output_comp_node() override {
    }
    void init_output_format() override {
    }
    void init_output_dtype() override {
    }
    void init_output_static_infer_desc() override {
    }
    void init_output_mem_plan(bool dynamic) override {
        MGB_MARK_USED_VAR(dynamic);
        mgb_assert(0);
    }
    void do_execute(ExecEnv &env) override {
        mgb_assert(0);
    }

public:
    Tensor* m_tensor;

    InputPlaceholder(ComputingGraph& graph, Tensor* tensor = nullptr,
                     const DeviceTensorND& static_infer_value = {})
            : Super(&graph, {}, "device_value", {}), m_tensor(tensor),
              m_static_infer_value(static_infer_value) {
        mgb_assert(m_static_infer_value.empty() ||
                   m_static_infer_value.comp_node() == CompNode::default_cpu());
        add_output(None)->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC);
        // never dedup
        add_equivalence_component<ScalarHash<void*>>(this);
    }

    static SymbolVar make(ComputingGraph& graph, Tensor& tensor) {
        auto opr = graph.insert_opr(
                std::make_unique<InputPlaceholder>(graph, &tensor));
        auto var = opr->output(0);
        auto&& dev_tensor = tensor.dev_tensor();
        var->m_comp_node = dev_tensor.comp_node();
        var->m_shape = dev_tensor.shape();
        var->m_dev_tensor = dev_tensor;
        var->m_mem_plan.reset_from_owner_var().chunk()
                .mem_alloc_status.set_from_owner_var();
        return var;
    }

    static SymbolVar make(ComputingGraph& graph, const LogicalTensorDesc& desc) {
        auto opr = graph.insert_opr(
                std::make_unique<InputPlaceholder>(graph, nullptr, desc.value));
        auto var = opr->output(0);
        var->m_comp_node = desc.comp_node;
        var->m_shape = desc.layout;
        var->m_dev_tensor.reset({}, TensorLayout(desc.layout.dtype));
        return var;
    }

    const DeviceTensorND* get_static_infer_value(bool may_sync) {
        if (!m_static_infer_value.empty()) {
            return &m_static_infer_value;
        }
        if (m_tensor && (may_sync || m_tensor->try_get_value())) {
            auto&& hv = m_tensor->get_value();
            mgb_assert(!hv.empty());
            m_static_infer_value = hv.proxy_to_default_cpu();
            // steal ownership from shared_ptr
            using SP = std::shared_ptr<dt_byte>;
            auto& sp = const_cast<SP&>(m_static_infer_value.storage().raw_storage());
            static auto dummy = std::make_shared<dt_byte>();
            sp = SP(dummy, sp.get());
            return &m_static_infer_value;
        }
        return nullptr;
    }

private:
    DeviceTensorND m_static_infer_value;
};

MGB_DYN_TYPE_OBJ_FINAL_IMPL(
        ProxyGraph::InputPlaceholder);
class ProxyGraph::ExecEnv final : public cg::GraphExecutable::ExecEnv {
public:
    void dispatch_on_comp_node(CompNode, Task&& task) override {
        task();
    }

    void dispatch_on_comp_node_with_mask(CompNode, Task&& task,
                                         cg::ExecutionMask* mask) override {
        mgb_throw_if(mask, GraphError,
                     "ExecutionMask not supported in imperative mode");
        task();
    }

    void pause_exec() override {}

    void resume_exec() override {}
};
class ProxyGraph::StaticInferManager : public cg::static_infer::StaticInferManager {
public:
    using Tag = cg::static_infer::Tag;
    using ShapeInferDesc = cg::static_infer::ShapeInferDesc;
    using ValueInferDesc = cg::static_infer::ValueInferDesc;
    using InferType = cg::static_infer::InferType;
    using DepVal = cg::static_infer::DepVal;
    using DepElement = cg::static_infer::DepElement;
    using DepType = cg::static_infer::DepType;
    using InpElement = cg::static_infer::InpElement;

    struct Result {
        TensorShape shape;
        DeviceTensorND value;
    };

    ProxyGraph* owner;
    cg::OperatorNodeBase* cur_opr = nullptr;
    std::vector<std::optional<ShapeInferDesc>> shape_descs;
    std::vector<std::optional<ValueInferDesc>> value_descs;
    std::vector<Result> inferred_outputs;

    StaticInferManager(ProxyGraph* owner_) : owner(owner_) {}

    size_t locate_output(VarNode* var) {
        mgb_assert(cur_opr);
        auto&& output_vars = cur_opr->output();
        mgb_assert(shape_descs.size() == output_vars.size());
        auto&& it = std::find(output_vars.begin(), output_vars.end(), var);
        mgb_assert(it != output_vars.end());
        return it - output_vars.begin();
    }

    void register_shape_infer(Tag dest, const ShapeInferDesc &desc) override {
        auto i = locate_output(dest);
        mgb_assert(!shape_descs[i]);
        shape_descs[i].emplace(desc);
    }

    void register_value_infer(Tag dest, const ValueInferDesc &desc) override {
        auto i = locate_output(dest);
        mgb_assert(!value_descs[i]);
        value_descs[i].emplace(desc);
    }

    InferType get_infer_type(Tag var) override {
        // may be called during get_proxy_opr or make_backward_graph;
        // don't let the opr apply any immediate optimization
        return {InferType::MISSING_INP, InferType::MISSING_INP};

        if (auto opr = var->owner_opr()->try_cast_final<InputPlaceholder>()) {
            return {var->shape().ndim ? InferType::CONST : InferType::MISSING_INP,
                    opr->m_tensor ? InferType::CONST : InferType::MISSING_INP};
        }
        if (cur_opr) {
            auto&& outputs = cur_opr->output();
            auto&& it = std::find(outputs.begin(), outputs.end(), var);
            if (it != outputs.end()) {
                return {infer_shape_fallible(var) ? InferType::CONST : InferType::MISSING_INP,
                        // value inference could be expensive
                        InferType::MISSING_INP};
            }
        }
        return {InferType::MISSING_INP, InferType::MISSING_INP};
    }

    void update() {
        if (cur_opr != owner->m_cur_opr) {
            clear();
            cur_opr = owner->m_cur_opr;
            if (cur_opr) {
                auto nout = cur_opr->output().size();
                shape_descs.resize(nout);
                value_descs.resize(nout);
                inferred_outputs.resize(nout);
                cur_opr->init_output_static_infer_desc();
            }
        }
    }

    void clear() {
        cur_opr = nullptr;
        shape_descs.clear();
        value_descs.clear();
        inferred_outputs.clear();
    }

    template<bool is_shape>
    auto do_infer(Tag dest, bool may_sync)
            -> const std::conditional_t<is_shape, TensorShape, DeviceTensorND>* {
        // Some infer_funcs do not use the InpVal passed to them, but
        // call infer_* on their inputs instead, so dest could be an input.
        // It is also possible that an opr calls infer_* on its inputs before
        // it is inserted.
        if (auto opr = dest->owner_opr()->try_cast_final<InputPlaceholder>()) {
            if constexpr (is_shape) {
                auto* shp = &dest->shape();
                return shp->ndim ? shp : nullptr;
            } else {
                return opr->get_static_infer_value(may_sync);
            }
        }

        mgb_assert(cur_opr);
        mgb_assert(cur_opr->output().size() == shape_descs.size());

        // dest must be an output now
        auto i = locate_output(dest);
        auto& result = inferred_outputs[i];
        auto& desc = select<is_shape>(shape_descs[i], value_descs[i]);

        // return if no need to call infer_func
        if constexpr (is_shape) {
            if (result.shape.ndim != 0) {
                return &result.shape;
            }
        } else {
            if (!result.value.empty()) {
                return &result.value;
            }
        }
        if (!desc) {
            return nullptr;
        }

        // fill args for infer_func
        cg::static_infer::InpVal args{1};
        args.val.reserve(desc->deps.size());
        auto push_shape = [&args](const TensorShape* shape) {
            args.val.emplace_back();
            args.val.back().m_shape = shape;
        };
        auto push_value = [&args](const DeviceTensorND* value) {
            args.val.emplace_back();
            args.val.back().m_value = value;
        };

        for (auto&& dep : desc->deps) {
            if (auto opr = dep.dest->owner_opr()->template try_cast_final<InputPlaceholder>()) {
                if (dep.type == DepType::SHAPE) {
                    if (dep.dest->shape().ndim) {
                        push_shape(&dep.dest->shape());
                    } else {
                        return nullptr;
                    }
                } else {
                    if (auto* p = opr->get_static_infer_value(may_sync)) {
                        push_value(p);
                    } else {
                        return nullptr;
                    }
                }
                continue;
            }
            // dep must be an output
            if (dep.type == DepType::SHAPE) {
                if (auto* p = do_infer<true>(dep.dest, may_sync)) {
                    push_shape(p);
                } else {
                    return nullptr;
                }
            } else {
                if (auto* p = do_infer<false>(dep.dest, may_sync)) {
                    push_value(p);
                } else {
                    return nullptr;
                }
            }
        }

        // call infer_func
        if constexpr (is_shape) {
            if (!desc->infer_func(result.shape, args)) {
                mgb_log_warn("something is missing for shape inference of %s",
                             cur_opr->dyn_typeinfo()->name);
                return nullptr;
            }
            return &result.shape;
        } else {
            if (!desc->infer_func(result.value, args)) {
                mgb_log_warn("something is missing for value inference of %s",
                             cur_opr->dyn_typeinfo()->name);
                return nullptr;
            }
            return &result.value;
        }
    }

    const TensorShape& infer_shape(Tag var) override {
        auto* p = do_infer<true>(var, true);
        mgb_assert(p, "failed to infer shape for %s", var->name().c_str());
        return *p;
    }

    const TensorShape* infer_shape_fallible(Tag var) override {
        return do_infer<true>(var, false);
    }

    const DeviceTensorND& infer_value(Tag var) override {
        auto* p = do_infer<false>(var, true);
        mgb_assert(p, "failed to infer value for %s", var->name().c_str());
        return *p;
    }

    const DeviceTensorND* infer_value_fallible(Tag var) override {
        return do_infer<false>(var, false);
    }

    DepVal get_rt_static_source_deps(const DepElement&) override { mgb_assert(0); }
};
class ProxyGraph::SeqCompNodeOptimizer : public cg::SeqCompNodeOptimizer {
    void register_stream_var(VarNode*, StreamPropType) override {}
    void register_propagate_function(VarNode*, PropFunction) override {}
    StreamPropType stream_prop_type(VarNode*) override { mgb_assert(0); }
};

class ProxyGraph::ProxyGraphImpl : public cg::ComputingGraph {
    static std::atomic<size_t> m_node_id;
    ProxyGraph* m_owner;
    MemPool<VarNode> m_var_node_pool;
    std::vector<std::unique_ptr<OperatorNodeBase>> m_opr_refkeeper;
    std::mutex m_opr_refkeeper_mtx;
    CompNode::UnorderedSet m_used_comp_node;
    VarReceiverInfo m_var_receiver_info;

public:
    ~ProxyGraphImpl() {
        mgb_assert(!m_owner->m_cur_opr);
        if (is_finalized()) return;
        for (auto&& i : m_used_comp_node) {
            if (i.device_type() == CompNode::DeviceType::CUDA) continue;
            i.sync();
        }
    }

    ProxyGraphImpl(ProxyGraph* owner) : m_owner(owner) {
        options().imperative_proxy_graph = true;
        options().no_force_inplace = true;
        options().log_level = 0;
        m_var_receiver_info.dev_value = 1;
        m_var_receiver_info.allow_empty_value = 1;
    }

    static std::unique_ptr<ProxyGraphImpl> make(ProxyGraph* owner) {
        return std::make_unique<ProxyGraphImpl>(owner);
    }

    void add_used_comp_node(CompNode cn) {
        m_used_comp_node.insert(cn);
    }

    bool invalid() const {
        return is_finalized() || nr_oprs_in_graph() > m_owner->m_max_op_cnt;
    }

    size_t next_node_id() override {
        return m_node_id.fetch_add(1);
    }

    void* alloc_varnode_storage() override {
        return m_var_node_pool.alloc_raw();
    }

    void free_varnode_storage(void* ptr) override {
        m_var_node_pool.free_raw(ptr);
    }

    OperatorNodeBase* insert_opr(std::unique_ptr<OperatorNodeBase> opr_uniqp) override {
        mgb_assert(!is_finalized());
        auto opr = opr_uniqp.get();
        if (!opr->inserted_in_graph()) {
            m_opr_refkeeper.emplace_back(std::move(opr_uniqp));
            opr->set_inserted_in_graph();
            opr->init_output_comp_node();
            opr->init_output_dtype();
            opr->init_output_format();
        }
        return opr;
    }

    cg::static_infer::StaticInferManager& static_infer_manager() override {
        return *m_owner->m_static_infer_manager;
    }

    cg::SeqCompNodeOptimizer& seq_comp_node_optimizer() override {
        return *m_owner->m_seq_comp_node_optimizer;
    }

    std::shared_ptr<void> on_comp_node_finalize() override {
        MGB_LOCK_GUARD(m_opr_refkeeper_mtx);
        mgb_assert(!m_owner->m_cur_opr);
        // finalize would do sync first
        m_opr_refkeeper.clear();
        return {};
    }

    const VarReceiverInfo& var_receiver_in_current_comp_seq(
            const VarNode *var) const override {
        return m_var_receiver_info;
    }

    size_t nr_oprs_in_graph() const override { return m_opr_refkeeper.size(); }

    void record_async_error(std::unique_ptr<MegBrainError> async_exc) override {
        if (!ProxyGraph::tm_async_error) {
            std::swap(async_exc, tm_async_error);
        }
    }

    std::unique_ptr<cg::AsyncExecutable> compile(const OutputSpec &out_spec) override { mgb_assert(0); }
    SmallVector<std::unique_ptr<cg::AsyncExecutable>> compile_multi_part(
            const SmallVector<OutputSpec>& out_specs) override { mgb_assert(0); }
    cg::AsyncExecutable* current_comp_seq() override { mgb_assert(0); }
    std::string get_mem_allocation_info() const override { mgb_assert(0); }
    VarNode* find_var_by_id(size_t id) const override { mgb_assert(0); }
    void share_device_memory_with(ComputingGraph &other) override { mgb_assert(0); }
    void set_device_memory_allocator(
            std::shared_ptr<cg::DeviceMemoryAllocator> allocator) override { mgb_assert(0); }
    size_t get_device_memory_size(CompNode cn) override { mgb_assert(0); }
    size_t clear_device_memory() override { mgb_assert(0); }
    void set_as_subgraph(ComputingGraph &par_graph) override { mgb_assert(0); }
};
std::atomic<size_t> ProxyGraph::ProxyGraphImpl::m_node_id = 0;

ProxyGraph::ProxyGraph() :
        m_graph(ProxyGraphImpl::make(this)),
        m_env{new ExecEnv},
        m_static_infer_manager(new StaticInferManager(this)),
        m_seq_comp_node_optimizer(new SeqCompNodeOptimizer()) {
}

void ProxyGraph::reset() {
    mgb_assert(!m_cur_opr);
    m_graph = ProxyGraphImpl::make(this);
}

ProxyGraph* ProxyGraph::get_default_graph() {
    static thread_local ProxyGraph inst;
    if (inst.m_graph->invalid()) {
        inst.reset();
    }
    return &inst;
}

class ProxyGraph::CurOprGuard {
public:
    CurOprGuard(ProxyGraph* owner, OperatorNodeBase* opr) : m_owner(owner) {
        mgb_assert(!owner->m_cur_opr);
        owner->m_cur_opr = opr;
    }
    CurOprGuard(const CurOprGuard&) = delete;
    ~CurOprGuard() {
        m_owner->cleanup();
    }
private:
    ProxyGraph* m_owner;
};

#define CUR_OPR_GUARD(opr) CurOprGuard MGB_TOKENPASTE2(__cur_opr_guard_, __LINE__)(this, opr)
/*********************** Physical Tensor Impl ***********************/

SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs(
        const OpDef& opdef,
        const SmallVector<Tensor*>& inputs) {
    SmallVector<LogicalTensorDesc> ret;
    CUR_OPR_GUARD(get_proxy_opr(opdef, inputs));
    do_shape_infer(true);
    for (auto&& i: m_cur_opr->usable_output()) {
        mgb_assert(i->dtype().valid() && i->comp_node().valid());
        mgb_assert(i->shape().ndim || i->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC));
        ret.push_back({{i->shape(), i->dtype()}, i->comp_node()});
    }
    return ret;
}

void ProxyGraph::invoke_op(const OpDef& opdef,
        const SmallVector<Tensor*>& inputs,
        const SmallVector<Tensor*>& outputs,
        const SmallVector<Tensor*>& workspaces) {
    CUR_OPR_GUARD(get_proxy_opr(opdef, inputs));
    init_output_tensor(outputs, workspaces);
    for (auto oup : m_cur_opr->output()) {
        m_graph->add_used_comp_node(oup->comp_node());
    }
    m_cur_opr->execute(*m_env);
}

void ProxyGraph::cleanup() {
    if (m_cur_opr) {
        for (auto&& i : m_cur_opr->input()) {
            i->m_dev_tensor.storage({});
        }
        for (auto&& i : m_cur_opr->output()) {
            i->m_dev_tensor.storage({});
        }
        m_static_infer_manager->clear();
    }
    m_cur_opr = nullptr;
}

void ProxyGraph::init_output_tensor(const SmallVector<Tensor*>& outputs, const SmallVector<Tensor*>& workspaces) {
    // get proxy opr
    auto proxy = m_cur_opr;

    do_shape_infer(true);

    size_t j = 0;
    size_t k = 0;
    for (auto&& var : proxy->output()) {
        auto&& chk = var->m_mem_plan.reset_from_owner_var().chunk();
        if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            // workspace
            if (workspaces.size()) {
                mgb_assert(k < workspaces.size());
                auto&& layout = workspaces[k]->layout();
                mgb_assert(var->comp_node() == workspaces[k]->comp_node() &&
                           var->shape().eq_shape(layout) &&
                           var->dtype() == layout.dtype);
                var->m_dev_tensor = workspaces[k]->dev_tensor();
                ++ k;
            } else {
                TensorLayout layout{var->shape(), var->dtype(), var->format()};
                var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(var->comp_node(), layout);
            }
        } else {
            mgb_assert(j < outputs.size());
            auto&& tensor = outputs[j];
            auto&& layout = tensor->layout();
            mgb_assert(var->comp_node() == tensor->comp_node() &&
                       var->shape().eq_shape(layout) &&
                       var->dtype() == layout.dtype);
            var->assign_dev_tensor_from_tensor(tensor->dev_tensor());
            ++ j;
        }
        chk.mem_alloc_status.set_from_owner_var();
    }
    mgb_assert(j == outputs.size());
    mgb_assert(k == workspaces.size());

    // Memory forwarding is bypassed in megbrain when the graph option
    // imperative_proxy_graph is on; here we call mem_plan_fwd_in2out_readonly
    // to initialize the internal state of some oprs (e.g. Subtensor).
    // TODO: implement memory forwarding
    proxy->mem_plan_fwd_in2out_readonly();
    {
        // some oprs (e.g. Reduce) rely on on_mem_status_changed to set
        // input/output tensors correctly; since we bypass var_node_mem_mgr,
        // on_mem_status_changed should be called here
        auto&& cb = proxy->get_opr_event_callback().on_mem_status_changed;
        if (cb.valid()) {
            cb.val()();
        }
    }
}
cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
        const OpDef& opdef,
        const SmallVector<Tensor*>& inputs) {
    VarNodeArray vinputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++ i) {
        vinputs[i] = InputPlaceholder::make(*m_graph, *inputs[i]).node();
    }
    auto opr = OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr();
    mgb_assert(!opr->same_type<InputPlaceholder>());
    for (auto &&i : opr->input()) {
        mgb_assert(i->owner_opr()->same_type<InputPlaceholder>());
    }
    return opr;
}

/*********************** Logical Tensor Impl ***********************/

size_t ProxyGraph::get_opr_output_size(const OpDef& opdef,
        const SmallVector<LogicalTensorDesc>& inputs) {
    return get_proxy_opr(opdef, inputs)->usable_output().size();
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::infer_output_attrs_fallible(
        const OpDef& opdef,
        const SmallVector<LogicalTensorDesc>& inputs) {
    auto opr = get_proxy_opr(opdef, inputs);
    CUR_OPR_GUARD(opr);
    SmallVector<LogicalTensorDesc> outputs;
    bool validated = do_shape_infer(false);
    for (auto&& i : opr->usable_output()) {
        outputs.push_back({{i->shape(), i->dtype()}, i->comp_node()});
    }
    bool need_check = opr->same_type<opr::Reshape>();
    return {outputs, validated && !need_check};
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::infer_output_mem_desc(
        const OpDef& def,
        const SmallVector<Tensor*>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    auto opr = get_proxy_opr(def, inputs_tensors);
    CUR_OPR_GUARD(opr);
    do_shape_infer(true);
    SmallVector<MemoryDesc> outputs;
    SmallVector<MemoryDesc> workspaces;
    size_t cur_id = 0;
    for (auto&& i : opr->output()) {
        if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            workspaces.push_back({{i->shape(), i->dtype(), i->format()}, 0, i->comp_node(), StorageIdentifier::make(++ cur_id)});
        } else {
            outputs.push_back({{i->shape(), i->dtype()}, 0, i->comp_node(), StorageIdentifier::make(++ cur_id)});
        }
    }
    return {outputs, workspaces};
}
struct ProxyGraph::GradGraph {
    cg::VarNodeArray inputs;
    cg::VarNodeArray outputs;
    cg::VarNodeArray output_grads;
    cg::VarNode* grad;
};

BackwardGraphResult
ProxyGraph::make_backward_graph(
        const OpDef& opdef,
        const SmallVector<LogicalTensorDesc>& input_descs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    ThinHashMap<VarNode*, size_t> var2idx;
    auto push = [&var2idx, cnt=0](VarNode* var) mutable {
        auto&& ret = var2idx.emplace(var, cnt ++);
        mgb_assert(ret.second, "var %s has already been inserted", var->cname());
        return ret.first->second;
    };
    auto inputs = make_input_place_holders(input_descs);
    auto fwd = OpDef::apply_on_var_node(opdef, inputs)[0]->owner_opr();
    auto&& outputs = fwd->usable_output();
    SmallVector<LogicalTensorDesc> output_descs;
    for (auto&& i : outputs) {
        output_descs.push_back({TensorLayout{i->dtype()}, i->comp_node()});
    }
    auto output_grads = make_input_place_holders(output_descs);
    mgb_assert(output_grads.size() == output_has_grad.size());
    bool any_input_has_grad = false;
    for (size_t i = 0; i < output_grads.size(); ++ i) {
        if (!output_has_grad[i]) {
            output_grads[i] = nullptr;
        } else {
            any_input_has_grad = true;
        }
    }
    if (!any_input_has_grad) {
        return {};
    }
    auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo());

    BackwardGraphResult result;
    auto&& igraph = result.backward;

    size_t nr_backward_graph_inputs = 0;
    auto gen_expr = [this, &var2idx, &igraph, &push, &fwd,
                     &nr_backward_graph_inputs](cg::OperatorNodeBase* op) {
        if (auto t = as_tensor(op)) {
            mgb_assert(op->output().size() == 1);
            igraph.constants.emplace_back(push(op->output(0)), std::move(t));
        } else if (op->same_type<InputPlaceholder>()) {
            ++ nr_backward_graph_inputs;
            push(op->output(0));
        } else {
            SmallVector<size_t> inputs, outputs;
            for (auto &&i : op->input()) {
                if (i->owner_opr() == fwd) {
                    if (var2idx.find(i) == var2idx.end()) {
                        ++ nr_backward_graph_inputs;
                        push(i);
                    }
                }
                inputs.push_back(var2idx.at(i));
            }
            for (auto &&i : op->usable_output()) {
                outputs.push_back(push(i));
            }
            igraph.exprs.push_back({OpDef::make_from_op_node(op), inputs, outputs});
        }
    };

    // set backward graph outputs
    cg::DepOprIter iter{gen_expr};
    iter.set_visited(fwd);
    result.input_has_grad.resize(inputs.size());

    VarNodeArray output_grads_with_unused_var;
    {
        auto iter = output_grads.begin();
        for (auto&& i : fwd->output()) {
            if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                // a var node with VOLATILE_CONTENT (e.g. a workspace
                // or an empty var) is not considered a normal
                // output, so its grad is always NULL
                output_grads_with_unused_var.push_back(nullptr);
            } else {
                output_grads_with_unused_var.push_back(*iter);
                ++ iter;
            }
        }
        mgb_assert(iter == output_grads.end());
    }

    Maybe<VarNodeArray> grad_results;
    for (size_t i = 0; i < inputs.size(); ++ i) {
        VarNode* grad;
        if (grad_results.valid()) {
            grad = grad_results.val()[i];
        } else {
            auto res = (*gfunc)(fwd, i, output_grads_with_unused_var);
            if (res.from_single()) {
                grad = res.single();
            } else {
                grad_results.emplace(res.all(fwd));
                grad = grad_results.val()[i];
            }
        }
        if (grad && !grad->owner_opr()->same_type<opr::InvalidGrad>()
                && input_requires_grad[i]) {
            mgb_assert(!grad->owner_opr()->same_type<opr::InvalidGrad>(),
                       "gradient of operator %s w.r.t. input #%lu is "
                       "either not well defined or not implemented",
                       fwd->dyn_typeinfo()->name, i);
            iter.add(grad);
            igraph.outputs.push_back(var2idx.at(grad));
            result.input_has_grad[i] = true;
        } else {
            result.input_has_grad[i] = false;
        }
    }
    if (igraph.outputs.empty()) {
        return {};
    }

    // set backward graph inputs
    igraph.inputs.reserve(nr_backward_graph_inputs);
    result.save_for_backward.reserve(nr_backward_graph_inputs);
    auto write_inputs = [&igraph, &var2idx, &result](const VarNodeArray& vars) {
        for (auto&& i: vars) {
            auto&& iter = var2idx.find(i);
            if (iter != var2idx.end()) {
                igraph.inputs.push_back(iter->second);
                result.save_for_backward.push_back(true);
            } else {
                result.save_for_backward.push_back(false);
            }
        }
    };
    write_inputs(inputs);
    write_inputs(outputs);
    write_inputs(output_grads);
    mgb_assert(igraph.inputs.size() == nr_backward_graph_inputs);
    return result;
}
cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(const OpDef& opdef,
        const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(!m_cur_opr);
    auto vinputs = make_input_place_holders(inputs);
    return OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr();
}

VarNodeArray ProxyGraph::make_input_place_holders(const SmallVector<LogicalTensorDesc>& inputs) {
    VarNodeArray vinputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++ i) {
        vinputs[i] = InputPlaceholder::make(*m_graph, inputs[i]).node();
    }
    return vinputs;
}

/*********************** Common Impl ***********************/

bool ProxyGraph::do_shape_infer(bool sync_value) {
    m_static_infer_manager->update();

    bool validated = true;
    for (auto* var : m_cur_opr->output()) {
        if (sync_value) {
            var->shape(m_static_infer_manager->infer_shape(var));
        } else if (auto* shape = m_static_infer_manager->infer_shape_fallible(var)) {
            var->shape(*shape);
        } else {
            validated = false;
        }
    }
    return validated;
}

TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) {
    // TODO: maybe some tensors should copy the value from the origin opr
    // rather than share the RawStorage
    mgb_assert(share, "can't share memory with opr %s", opr->cname());
    if (opr->same_type<opr::ImmutableTensor>()) {
        auto&& dv = opr->cast_final_safe<opr::ImmutableTensor>().value();
        HostTensorND hv(dv.comp_node(), dv.shape(), dv.dtype());
        const DeviceTensorND* cpu_value;
        // get host value
        if (opr->owner_graph() == m_graph.get()) {
            CUR_OPR_GUARD(opr);
            m_static_infer_manager->update();
            cpu_value = m_static_infer_manager->infer_value_fallible(opr->output(0));
        } else {
            cpu_value = opr->owner_graph()->static_infer_manager().infer_value_fallible(opr->output(0));
        }
        mgb_assert(cpu_value);
        mgb_assert(cpu_value->comp_node() == CompNode::default_cpu());
        // default_cpu is synchronous with respect to caller
        hv.proxy_to_default_cpu().copy_from_fixlayout(*cpu_value);
        return Tensor::make(dv, hv);
    } else if (opr->same_type<opr::SharedDeviceTensor>()) {
        return Tensor::make(opr->cast_final_safe<opr::SharedDeviceTensor>().get_dev_tensor());
    } else {
        return {};
    }
}

thread_local std::unique_ptr<MegBrainError> ProxyGraph::tm_async_error;

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
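
Editorial note: the short sketch below is not part of proxy_graph.cpp. It is a minimal illustration, based only on the signatures defined above (get_default_graph, infer_output_attrs, invoke_op), of how a caller in the imperative runtime might drive ProxyGraph; the function name and the preparation of opdef, inputs, outputs and workspaces are assumptions.

// Hypothetical driver sketch (editorial, not upstream code). Assumes the
// caller has built `opdef` and allocated `outputs`/`workspaces` whose layouts
// match the inferred descriptors, as init_output_tensor asserts above.
void run_through_proxy_graph(const mgb::imperative::OpDef& opdef,
                             const mgb::SmallVector<mgb::imperative::Tensor*>& inputs,
                             const mgb::SmallVector<mgb::imperative::Tensor*>& outputs,
                             const mgb::SmallVector<mgb::imperative::Tensor*>& workspaces) {
    using mgb::imperative::ProxyGraph;
    // one ProxyGraph per thread; get_default_graph() resets it once it has
    // accumulated too many operators (see ProxyGraphImpl::invalid above)
    auto* graph = ProxyGraph::get_default_graph();
    // static inference of output shape/dtype/comp_node for this operator
    auto descs = graph->infer_output_attrs(opdef, inputs);
    MGB_MARK_USED_VAR(descs);
    // execute the proxy operator against the caller-provided storage
    graph->invoke_op(opdef, inputs, outputs, workspaces);
}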

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.