
proxy_graph.cpp

/**
 * \file imperative/src/impl/proxy_graph.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./proxy_graph.h"
#include "./blob_manager_impl.h"
#include "megbrain/graph/operator_node.h"
#include "megbrain/graph/static_infer.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

#if __cplusplus >= 201703L
#include <optional>
#endif

namespace mgb {
namespace imperative {

using cg::OperatorNodeBase;
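// Compile-time selector: select<p>(t, f) forwards t when the predicate p is
// true and f otherwise. Used below to pick between the shape and value
// inference descriptors inside a single do_infer() template.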
template <bool p, typename T, typename F>
constexpr auto&& select(T&& t, F&& f) {
    if constexpr (p) {
        return std::forward<T>(t);
    } else {
        return std::forward<F>(f);
    }
}
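// InputPlaceholder is a dummy operator that feeds an imperative Tensor (or a
// statically inferred value) into the proxy graph as a graph input. It is
// never executed; all execution-related hooks simply assert.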
MGB_DEFINE_OPR_CLASS(ProxyGraph::InputPlaceholder, cg::OperatorNodeBase) // {
    void on_output_comp_node_stream_changed() override { mgb_assert(0); }
    // TODO: consider implementing the following initialization methods,
    // so that InputPlaceholder can be initialized correctly during
    // operator insertion
    void init_output_comp_node() override {}
    void init_output_format() override {}
    void init_output_dtype() override {}
    void init_output_static_infer_desc() override {}
    void init_output_mem_plan(bool dynamic) override {
        MGB_MARK_USED_VAR(dynamic);
        mgb_assert(0);
    }
    void do_execute(ExecEnv& env) override { mgb_assert(0); }

public:
    Tensor* m_tensor;

    InputPlaceholder(
            ComputingGraph& graph, Tensor* tensor = nullptr,
            const DeviceTensorND& static_infer_value = {})
            : Super(&graph, {}, "device_value", {}),
              m_tensor(tensor),
              m_static_infer_value(static_infer_value) {
        mgb_assert(
                m_static_infer_value.empty() ||
                m_static_infer_value.comp_node() == CompNode::default_cpu());
        add_output(None)->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC);
        // never dedup
        add_equivalence_component<ScalarHash<void*>>(this);
    }

    static SymbolVar make(ComputingGraph& graph, Tensor& tensor) {
        auto opr = graph.insert_opr(std::make_unique<InputPlaceholder>(graph, &tensor));
        auto var = opr->output(0);
        auto&& dev_tensor = tensor.dev_tensor(false);
        var->m_comp_node = dev_tensor.comp_node();
        var->m_shape = dev_tensor.shape();
        if (dev_tensor.empty()) {
            auto layout = dev_tensor.layout();
            layout.init_contiguous_stride();
            dev_tensor.reset(dev_tensor.storage(), layout);
        }
        var->force_assign_dev_tensor_from_tensor(dev_tensor);
        return var;
    }

    static SymbolVar make(ComputingGraph& graph, const LogicalTensorDesc& desc) {
        auto opr = graph.insert_opr(
                std::make_unique<InputPlaceholder>(graph, nullptr, desc.value));
        auto var = opr->output(0);
        var->m_comp_node = desc.comp_node;
        var->m_shape = desc.layout;
        var->m_dev_tensor.reset({}, TensorLayout(desc.layout.dtype));
        return var;
    }

    const DeviceTensorND* get_static_infer_value(bool may_sync) {
        if (!m_static_infer_value.empty()) {
            return &m_static_infer_value;
        }
        if (m_tensor && (may_sync || m_tensor->try_get_value())) {
            auto&& hv = m_tensor->get_value();
            mgb_assert(!hv.empty());
            m_static_infer_value = hv.proxy_to_default_cpu();
            // steal ownership from shared_ptr
            using SP = std::shared_ptr<dt_byte>;
            auto& sp = const_cast<SP&>(m_static_infer_value.storage().raw_storage());
            static auto dummy = std::make_shared<dt_byte>();
            sp = SP(dummy, sp.get());
            return &m_static_infer_value;
        }
        return nullptr;
    }

private:
    DeviceTensorND m_static_infer_value;
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ProxyGraph::InputPlaceholder);
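// A minimal StaticInferManager that only serves the operator currently being
// constructed (owner->m_cur_opr). Shape/value inference descriptors are
// recorded per output and evaluated on demand by do_infer() below.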
class ProxyGraph::StaticInferManager : public cg::static_infer::StaticInferManager {
public:
    using Tag = cg::static_infer::Tag;
    using ShapeInferDesc = cg::static_infer::ShapeInferDesc;
    using ValueInferDesc = cg::static_infer::ValueInferDesc;
    using InferType = cg::static_infer::InferType;
    using DepVal = cg::static_infer::DepVal;
    using DepElement = cg::static_infer::DepElement;
    using DepType = cg::static_infer::DepType;
    using InpElement = cg::static_infer::InpElement;

    struct Result {
        TensorShape shape;
        DeviceTensorND value;
    };

    ProxyGraph* owner;
    cg::OperatorNodeBase* cur_opr = nullptr;
    std::vector<std::optional<ShapeInferDesc>> shape_descs;
    std::vector<std::optional<ValueInferDesc>> value_descs;
    std::vector<Result> inferred_outputs;

    StaticInferManager(ProxyGraph* owner_) : owner(owner_) {}

    size_t locate_output(VarNode* var) {
        mgb_assert(cur_opr);
        auto&& output_vars = cur_opr->output();
        mgb_assert(shape_descs.size() == output_vars.size());
        auto&& it = std::find(output_vars.begin(), output_vars.end(), var);
        mgb_assert(it != output_vars.end());
        return it - output_vars.begin();
    }

    void register_shape_infer(Tag dest, const ShapeInferDesc& desc) override {
        auto i = locate_output(dest);
        mgb_assert(!shape_descs[i]);
        shape_descs[i].emplace(desc);
    }

    void register_value_infer(Tag dest, const ValueInferDesc& desc) override {
        auto i = locate_output(dest);
        mgb_assert(!value_descs[i]);
        value_descs[i].emplace(desc);
    }

    InferType get_infer_type(Tag var) override {
        // don't let opr apply any immediate optimization
        return {InferType::MISSING_INP, InferType::MISSING_INP};
    }

    void update() {
        if (cur_opr != owner->m_cur_opr) {
            clear();
            cur_opr = owner->m_cur_opr;
            if (cur_opr) {
                auto nout = cur_opr->output().size();
                shape_descs.resize(nout);
                value_descs.resize(nout);
                inferred_outputs.resize(nout);
                cur_opr->init_output_static_infer_desc();
            }
        }
    }

    void clear() {
        cur_opr = nullptr;
        shape_descs.clear();
        value_descs.clear();
        inferred_outputs.clear();
    }
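    // Recursively evaluate the registered shape/value inference descriptor for
    // `dest`. Dependencies owned by InputPlaceholders are read directly; other
    // dependencies must be outputs of the current operator and are inferred
    // recursively. Returns nullptr whenever a required input is unavailable.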
    template <bool is_shape>
    auto do_infer(Tag dest, bool may_sync)
            -> const std::conditional_t<is_shape, TensorShape, DeviceTensorND>* {
        // Some infer_funcs do not use the InpVal passed to them, but
        // call infer_* on their inputs instead, so dest could be an input.
        // It is also possible that an opr calls infer_* on its inputs before it
        // is inserted.
        if (auto opr = dest->owner_opr()->try_cast_final<InputPlaceholder>()) {
            if constexpr (is_shape) {
                auto* shp = &dest->shape();
                return shp->ndim ? shp : nullptr;
            } else {
                return opr->get_static_infer_value(may_sync);
            }
        }
        mgb_assert(cur_opr);
        mgb_assert(cur_opr->output().size() == shape_descs.size());
        // dest must be an output now
        auto i = locate_output(dest);
        auto& result = inferred_outputs[i];
        auto& desc = select<is_shape>(shape_descs[i], value_descs[i]);

        // return early if there is no need to call infer_func
        if constexpr (is_shape) {
            if (result.shape.ndim != 0) {
                return &result.shape;
            }
        } else {
            if (!result.value.empty()) {
                return &result.value;
            }
        }
        if (!desc) {
            return nullptr;
        }

        // fill args for infer_func
        cg::static_infer::InpVal args{1};
        auto push_shape = [&args](const TensorShape* shape) {
            args.val.emplace_back();
            args.val.back().m_shape = shape;
        };
        auto push_value = [&args](const DeviceTensorND* value) {
            args.val.emplace_back();
            args.val.back().m_value = value;
        };
        for (auto&& dep : desc->deps) {
            if (auto opr = dep.dest->owner_opr()
                                   ->template try_cast_final<InputPlaceholder>()) {
                if (dep.type == DepType::SHAPE) {
                    if (dep.dest->shape().ndim) {
                        push_shape(&dep.dest->shape());
                    } else {
                        return nullptr;
                    }
                } else {
                    if (auto* p = opr->get_static_infer_value(may_sync)) {
                        push_value(p);
                    } else {
                        return nullptr;
                    }
                }
                continue;
            }
            // dep must be an output
            if (dep.type == DepType::SHAPE) {
                if (auto* p = do_infer<true>(dep.dest, may_sync)) {
                    push_shape(p);
                } else {
                    return nullptr;
                }
            } else {
                if (auto* p = do_infer<false>(dep.dest, may_sync)) {
                    push_value(p);
                } else {
                    return nullptr;
                }
            }
        }

        // call infer_func
        if constexpr (is_shape) {
            if (!desc->infer_func(result.shape, args)) {
                mgb_log_warn(
                        "something is missing for shape inference of %s",
                        cur_opr->dyn_typeinfo()->name);
                return nullptr;
            }
            return &result.shape;
        } else {
            if (!desc->infer_func(result.value, args)) {
                mgb_log_warn(
                        "something is missing for value inference of %s",
                        cur_opr->dyn_typeinfo()->name);
                return nullptr;
            }
            return &result.value;
        }
    }

    const TensorShape& infer_shape(Tag var) override {
        auto* p = do_infer<true>(var, true);
        mgb_assert(p, "failed to infer shape for %s", var->name().c_str());
        return *p;
    }
    const TensorShape* infer_shape_fallible(Tag var) override {
        return do_infer<true>(var, false);
    }
    const DeviceTensorND& infer_value(Tag var) override {
        auto* p = do_infer<false>(var, true);
        mgb_assert(p, "failed to infer value for %s", var->name().c_str());
        return *p;
    }
    const DeviceTensorND* infer_value_fallible(Tag var) override {
        return do_infer<false>(var, false);
    }

    DepVal get_rt_static_source_deps(const DepElement&) override { mgb_assert(0); }
};
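// No-op SeqCompNodeOptimizer: the proxy graph never builds a computing
// sequence, so stream registrations are ignored and queries are not allowed.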
class ProxyGraph::SeqCompNodeOptimizer : public cg::SeqCompNodeOptimizer {
    void register_stream_var(VarNode*, StreamPropType) override {}
    void register_propagate_function(VarNode*, PropFunction) override {}
    StreamPropType stream_prop_type(VarNode*) override { mgb_assert(0); }
};
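// The ComputingGraph implementation backing ProxyGraph. It only supports
// operator insertion and static inference; compilation and memory-management
// entry points assert, since execution is driven by the imperative runtime.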
class ProxyGraph::ProxyGraphImpl : public cg::ComputingGraph {
    static std::atomic<size_t> m_node_id;
    ProxyGraph* m_owner;
    MemPool<VarNode> m_var_node_pool;
    std::vector<std::unique_ptr<OperatorNodeBase>> m_opr_refkeeper;
    std::mutex m_opr_refkeeper_mtx;
    CompNode::UnorderedSet m_used_comp_node;
    VarReceiverInfo m_var_receiver_info;

public:
    ~ProxyGraphImpl() {
        mgb_assert(!m_owner->m_cur_opr);
        if (is_finalized())
            return;
        for (auto&& i : m_used_comp_node) {
            if (i.device_type() == CompNode::DeviceType::CUDA)
                continue;
            if (i.device_type() == CompNode::DeviceType::ROCM)
                continue;
            i.sync();
        }
    }

    ProxyGraphImpl(ProxyGraph* owner) : m_owner(owner) {
        options().imperative_proxy_graph = true;
        options().no_force_inplace = true;
        options().log_level = 0;
        m_var_receiver_info.dev_value = 1;
        m_var_receiver_info.allow_empty_value = 1;
    }

    static std::unique_ptr<ProxyGraphImpl> make(ProxyGraph* owner) {
        return std::make_unique<ProxyGraphImpl>(owner);
    }

    void add_used_comp_node(CompNode cn) { m_used_comp_node.insert(cn); }

    bool invalid() const {
        return is_finalized() || nr_oprs_in_graph() > m_owner->m_max_op_cnt;
    }

    size_t next_node_id() override { return m_node_id.fetch_add(1); }

    void* alloc_varnode_storage() override { return m_var_node_pool.alloc_raw(); }

    void free_varnode_storage(void* ptr) override { m_var_node_pool.free_raw(ptr); }

    OperatorNodeBase* insert_opr(std::unique_ptr<OperatorNodeBase> opr_uniqp) override {
        mgb_assert(!is_finalized());
        auto opr = opr_uniqp.get();
        if (!opr->inserted_in_graph()) {
            m_opr_refkeeper.emplace_back(std::move(opr_uniqp));
            opr->set_inserted_in_graph();
            opr->init_output_comp_node();
            opr->init_output_dtype();
            opr->init_output_format();
        }
        return opr;
    }

    cg::static_infer::StaticInferManager& static_infer_manager() override {
        return *m_owner->m_static_infer_manager;
    }

    cg::SeqCompNodeOptimizer& seq_comp_node_optimizer() override {
        return *m_owner->m_seq_comp_node_optimizer;
    }

    std::shared_ptr<void> on_comp_node_finalize() override {
        MGB_LOCK_GUARD(m_opr_refkeeper_mtx);
        mgb_assert(!m_owner->m_cur_opr);
        // finalize would do sync first
        m_opr_refkeeper.clear();
        return {};
    }

    const VarReceiverInfo& var_receiver_in_current_comp_seq(
            const VarNode* var) const override {
        return m_var_receiver_info;
    }

    size_t nr_oprs_in_graph() const override { return m_opr_refkeeper.size(); }

    void record_async_error(std::unique_ptr<MegBrainError> async_exc) override {
        if (!ProxyGraph::tm_async_error) {
            std::swap(async_exc, tm_async_error);
        }
    }

    std::unique_ptr<cg::AsyncExecutable> compile(const OutputSpec& out_spec) override {
        mgb_assert(0);
    }
    SmallVector<std::unique_ptr<cg::AsyncExecutable>> compile_multi_part(
            const SmallVector<OutputSpec>& out_specs) override {
        mgb_assert(0);
    }
    cg::AsyncExecutable* current_comp_seq() override { mgb_assert(0); }
    std::string get_mem_allocation_info() const override { mgb_assert(0); }
    VarNode* find_var_by_id(size_t id) const override { mgb_assert(0); }
    void share_device_memory_with(ComputingGraph& other) override { mgb_assert(0); }
    void set_device_memory_allocator(
            std::shared_ptr<cg::DeviceMemoryAllocator> allocator) override {
        mgb_assert(0);
    }
    size_t get_device_memory_size(CompNode cn) override { mgb_assert(0); }
    size_t clear_device_memory() override { mgb_assert(0); }
    void set_as_subgraph(ComputingGraph& par_graph) override { mgb_assert(0); }
};
std::atomic<size_t> ProxyGraph::ProxyGraphImpl::m_node_id = 0;

ProxyGraph::ProxyGraph()
        : m_graph(ProxyGraphImpl::make(this)),
          m_static_infer_manager(new StaticInferManager(this)),
          m_seq_comp_node_optimizer(new SeqCompNodeOptimizer()) {}

void ProxyGraph::reset() {
    mgb_assert(!m_cur_opr);
    m_graph = ProxyGraphImpl::make(this);
}

ProxyGraph* ProxyGraph::get_default_graph() {
    static thread_local ProxyGraph inst;
    if (inst.m_graph->invalid()) {
        inst.reset();
    }
    return &inst;
}

class ProxyGraph::CurOprGuard {
public:
    CurOprGuard(ProxyGraph* owner, OperatorNodeBase* opr) : m_owner(owner) {
        mgb_assert(!owner->m_cur_opr);
        owner->m_cur_opr = opr;
    }
    CurOprGuard(const CurOprGuard&) = delete;
    ~CurOprGuard() { m_owner->cleanup(); }

private:
    ProxyGraph* m_owner;
};

#define CUR_OPR_GUARD(opr) \
    CurOprGuard MGB_TOKENPASTE2(__cur_opr_guard_, __LINE__)(this, opr)
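// CUR_OPR_GUARD(opr) scopes ProxyGraph::m_cur_opr to the enclosing block; the
// guard's destructor invokes cleanup(), which drops the device storage attached
// to the current operator's inputs/outputs and clears the static infer manager.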
/*********************** Physical Tensor Impl ***********************/

void ProxyGraph::cleanup() {
    if (m_cur_opr) {
        for (auto&& i : m_cur_opr->input()) {
            i->m_dev_tensor.storage({});
        }
        for (auto&& i : m_cur_opr->output()) {
            i->m_dev_tensor.storage({});
        }
        m_static_infer_manager->clear();
    }
    m_cur_opr = nullptr;
}

/*********************** Logical Tensor Impl ***********************/
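// Build the backward graph for `opdef`:
//   1. insert placeholders for the forward inputs and apply the op;
//   2. insert placeholders for the output gradients, dropping outputs whose
//      grad is not requested;
//   3. call the registered grad function and, via DepOprIter, record every
//      operator reachable from the resulting grad vars as an expression (or
//      constant/input) of the EncodedSubgraph;
//   4. emit input/output masks describing which vars are actually used.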
EncodedSubgraph ProxyGraph::make_backward_graph(
        const OpDef& opdef, const SmallVector<LogicalTensorDesc>& input_descs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    ThinHashMap<VarNode*, size_t> var2idx;
    auto push = [&var2idx,
                 cnt = 1](VarNode* var) mutable { // cnt is always greater than zero
        auto&& ret = var2idx.emplace(var, cnt++);
        mgb_assert(ret.second, "var %s has already been inserted", var->cname());
        return ret.first->second;
    };
    auto inputs = make_input_place_holders(input_descs);
    auto fwd = OpDef::apply_on_var_node(opdef, inputs)[0]->owner_opr();
    auto&& outputs = fwd->usable_output();
    SmallVector<LogicalTensorDesc> output_descs;
    for (auto&& i : outputs) {
        output_descs.push_back({TensorLayout{i->dtype()}, i->comp_node()});
    }
    auto output_grads = make_input_place_holders(output_descs);
    mgb_assert(
            output_grads.size() == output_has_grad.size(), "%zu vs %zu",
            output_grads.size(), output_has_grad.size());
    bool any_input_has_grad = false;
    for (size_t i = 0; i < output_grads.size(); ++i) {
        if (!output_has_grad[i]) {
            output_grads[i] = nullptr;
        } else {
            any_input_has_grad = true;
        }
    }
    if (!any_input_has_grad) {
        return {};
    }
    auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo());

    EncodedSubgraph result;
    auto&& igraph = result.graph;

    size_t nr_backward_graph_inputs = 0;
    auto gen_expr = [this, &var2idx, &igraph, &push, &fwd,
                     &nr_backward_graph_inputs](cg::OperatorNodeBase* op) {
        if (auto t = as_tensor(op)) {
            mgb_assert(op->output().size() == 1);
            igraph.constants.emplace_back(push(op->output(0)), std::move(t));
        } else if (op->same_type<InputPlaceholder>()) {
            ++nr_backward_graph_inputs;
            push(op->output(0));
        } else {
            SmallVector<size_t> inputs, outputs;
            for (auto&& i : op->input()) {
                if (i->owner_opr() == fwd) {
                    if (var2idx.find(i) == var2idx.end()) {
                        ++nr_backward_graph_inputs;
                        push(i);
                    }
                }
                inputs.push_back(var2idx.at(i));
            }
            for (auto&& i : op->usable_output()) {
                outputs.push_back(push(i));
            }
            igraph.exprs.push_back({OpDef::make_from_op_node(op), inputs, outputs});
        }
    };

    // set backward graph outputs
    cg::DepOprIter iter{gen_expr};
    iter.set_visited(fwd);
    result.output_mask.resize(inputs.size());

    VarNodeArray output_grads_with_unused_var;
    {
        auto iter = output_grads.begin();
        for (auto&& i : fwd->output()) {
            if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                // a var node with VOLATILE_CONTENT (e.g. a workspace
                // or an empty var) is not considered a normal
                // output, so its grad is always NULL
                output_grads_with_unused_var.push_back(nullptr);
            } else {
                output_grads_with_unused_var.push_back(*iter);
                ++iter;
            }
        }
        mgb_assert(iter == output_grads.end());
    }

    Maybe<VarNodeArray> grad_results;
    for (size_t i = 0; i < inputs.size(); ++i) {
        VarNode* grad;
        if (grad_results.valid()) {
            grad = grad_results.val()[i];
        } else {
            mgb_assert(gfunc, "could not find grad function");
            auto res = (*gfunc)(fwd, i, output_grads_with_unused_var);
            if (res.from_single()) {
                grad = res.single();
            } else {
                grad_results.emplace(res.all(fwd));
                grad = grad_results.val()[i];
            }
        }
        if (grad && !grad->owner_opr()->same_type<opr::InvalidGrad>() &&
            input_requires_grad[i]) {
            mgb_assert(
                    !grad->owner_opr()->same_type<opr::InvalidGrad>(),
                    "gradient of operator %s w.r.t. input #%zu is "
                    "either not well defined or not implemented",
                    fwd->dyn_typeinfo()->name, i);
            iter.add(grad);
            igraph.outputs.push_back(var2idx.at(grad));
            result.output_mask[i] = true;
        } else {
            result.output_mask[i] = false;
        }
    }

    if (igraph.outputs.empty()) {
        return {};
    }

    // set backward graph inputs
    auto write_inputs = [&igraph, &var2idx, &result](const VarNodeArray& vars) {
        for (auto&& i : vars) {
            auto&& iter = var2idx.find(i);
            if (iter != var2idx.end()) {
                igraph.inputs.push_back(iter->second);
                result.input_mask.push_back(true);
            } else {
                result.input_mask.push_back(false);
            }
        }
    };
    write_inputs(inputs);
    write_inputs(outputs);
    write_inputs(output_grads);
    mgb_assert(igraph.inputs.size() == nr_backward_graph_inputs);

    return result;
}
VarNodeArray ProxyGraph::make_input_place_holders(
        const SmallVector<LogicalTensorDesc>& inputs) {
    VarNodeArray vinputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        vinputs[i] = InputPlaceholder::make(*m_graph, inputs[i]).node();
    }
    return vinputs;
}

/*********************** Common Impl ***********************/
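// Convert a constant-producing operator into an imperative Tensor. Only
// ImmutableTensor (via static value inference) and SharedDeviceTensor are
// handled; any other operator yields an empty TensorPtr.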
TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) {
    // TODO: maybe some tensors should copy the value from the origin opr
    // rather than share the RawStorage
    mgb_assert(share, "can't share memory with opr %s", opr->cname());
    if (opr->same_type<opr::ImmutableTensor>()) {
        auto&& dv = opr->cast_final_safe<opr::ImmutableTensor>().value();
        HostTensorND hv(dv.comp_node(), dv.shape(), dv.dtype());
        const DeviceTensorND* cpu_value;
        // get host value
        if (opr->owner_graph() == m_graph.get()) {
            CUR_OPR_GUARD(opr);
            m_static_infer_manager->update();
            cpu_value = m_static_infer_manager->infer_value_fallible(opr->output(0));
        } else {
            cpu_value = opr->owner_graph()->static_infer_manager().infer_value_fallible(
                    opr->output(0));
        }
        mgb_assert(cpu_value);
        mgb_assert(cpu_value->comp_node() == CompNode::default_cpu());
        // default_cpu is synchronous with respect to the caller
        hv.proxy_to_default_cpu().copy_from_fixlayout(*cpu_value);
        return Tensor::make(dv, hv);
    } else if (opr->same_type<opr::SharedDeviceTensor>()) {
        return Tensor::make(
                opr->cast_final_safe<opr::SharedDeviceTensor>().get_dev_tensor());
    } else {
        return {};
    }
}
thread_local std::unique_ptr<MegBrainError> ProxyGraph::tm_async_error;

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}