
executor_opr.cpp

/**
 * \file src/jit/impl/executor_opr.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/jit/executor_opr.h"

#include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/gopt/framework.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/graph/helper.h"
#include "megbrain/jit/compiler.h"
#include "megbrain/jit/param_elem_visitor.h"
#include "megbrain/jit/placeholder_opr.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/utils/hash.h"
#include "megbrain/serialization/opr_shallow_copy.h"

#if MGB_JIT

using namespace mgb;
using namespace jit;

using CPFlag = Compiler::Property::Flag;

/* =================== Fusion ==================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(JITExecutor);
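
// The constructor binds each input var to the JITPlaceholder with the same
// index, marks which inputs may have been broadcast (used when taking the
// gradient), records REDUCE/DIMSHUFFLE feature bits of the internal graph,
// and verifies that the fused output depends on every placeholder.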
JITExecutor::JITExecutor(const InternalGraphPtr& internal_graph,
                         const VarNodeArray& inputs,
                         const OperatorNodeConfig& config)
        : Super(internal_graph->output()->owner_graph(), config,
                ssprintf("JIT-Fusion{%zu}",
                         internal_graph->placeholders().size()),
                inputs),
          m_internal_graph{internal_graph},
          m_compiler{Compiler::get(*inputs[0]->owner_graph(),
                                   inputs[0]->comp_node())} {
    for (auto inp : inputs) {
        add_input({inp});
    }
    m_input_broadcastable.resize(inputs.size());
    auto&& placeholders = m_internal_graph->placeholders();
    mgb_assert(placeholders.size() == inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        mgb_assert(placeholders[i]->output(0) != internal_graph->output());
        if (placeholders[i]->is_host_value_shape_input() ||
            input()[i]
                    ->owner_opr()
                    ->same_type<opr::MarkNoBroadcastElemwise>()) {
            m_input_broadcastable[i] = false;
        } else {
            m_input_broadcastable[i] = true;
        }
    }
    if (inputs.size() == 1) {
        m_input_broadcastable[0] = false;
    } else {
        Maybe<size_t> non_scalar;
        for (size_t i = 0; i < input().size(); ++i) {
            if (placeholders[i]->is_host_value_shape_input())
                continue;
            if (!(cg::is_const_var_shape(input(i)) &&
                  input(i)->shape().is_scalar())) {
                if (non_scalar.valid()) {
                    non_scalar.invalidate();
                    break;
                }
                non_scalar = i;
            }
        }
        if (non_scalar.valid()) {
            // exactly one input is non-scalar
            m_input_broadcastable[non_scalar.val()] = false;
        }
    }
    add_output(None)->dtype(m_internal_graph->output()->dtype());
    add_equivalence_component<ScalarHash<void*>>(internal_graph->output());
    for (size_t i = 0, it = m_compiler->get_nr_workspace_outputs(this); i < it;
         ++i) {
        cg::add_workspace_output(this);
    }
    // check that the output of internal_graph depends on all placeholders
    size_t nr_placeholders = internal_graph_ptr()->placeholders().size();
    std::vector<bool> used(nr_placeholders, false);
    // also record whether there is any Reduce or Dimshuffle opr
    cg::DepOprIter{[this, nr_placeholders, &used](cg::OperatorNodeBase* opr) {
        if (opr->same_type<opr::Reduce>()) {
            m_feature_bits |= JITFeatureBits::REDUCE;
        }
        if (opr->same_type<opr::Dimshuffle>()) {
            m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
        }
        if (auto ph = opr->try_cast_final<JITPlaceholder>()) {
            mgb_assert(ph->input_id() < nr_placeholders,
                       "bad placeholder %s in JITExecutor %s",
                       ph->cname(), cname());
            used[ph->input_id()] = true;
        }
    }}.add(internal_graph->output());
    for (size_t i = 0; i < nr_placeholders; ++i) {
        mgb_assert(used[i],
                   "placeholder %s is not used by the output of %s",
                   internal_graph_ptr()->placeholders()[i]->cname(), cname());
    }
    if (has_dimshuffle()) {
        prepare_dimshuffle();
    }
}
void JITExecutor::add_input_layout_constraint() {
    if (m_compiler->property().contain_flag(CPFlag::NEED_INPUT_CONTIG)) {
        for (auto i : input()) {
            i->add_layout_constraint_contiguous();
        }
    } else {
        for (auto i : input()) {
            i->add_layout_constraint_monotone();
        }
    }
}
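
// Re-planning output memory marks the cached args as stale, so the compiled
// executable is refreshed on the next execution.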
void JITExecutor::init_output_mem_plan(bool dynamic) {
    Super::init_output_mem_plan(dynamic);
    m_args.need_update = true;
}
SymbolVar JITExecutor::make(const InternalGraphPtr& internal_graph,
                            const VarNodeArray& inputs,
                            const OperatorNodeConfig& config) {
    return internal_graph->output()
            ->owner_graph()
            ->insert_opr(std::make_unique<JITExecutor>(internal_graph, inputs,
                                                       config))
            ->output(0);
}
void JITExecutor::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    mgr.register_shape_infer(
            output(0),
            ShapeInferDesc::make_identity(m_internal_graph->shape_infer()));
    m_compiler->init_workspace_size_infer(this);
    if (m_internal_graph->value_infer()) {
        mgr.register_value_infer(
                output(0),
                ValueInferDesc::make_identity(m_internal_graph->value_infer()));
    }
}
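
// Compilation is lazy: a new executable is requested from the compiler only
// when none exists yet or the cached args have become stale.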
void JITExecutor::scn_do_execute() {
    if (m_executable == nullptr || m_args.need_update) {
        m_executable = m_compiler->compile(this);
    }
    m_executable->execute(this);
}
//! Rewrite the layouts of inputs that are consumed through Dimshuffle oprs,
//! so that the dimshuffles themselves can be ignored by the fused kernel
void JITExecutor::do_dimshuffle() {
    static auto get_dimshuffled_layout = [](const TensorLayout& ily,
                                            std::vector<int> pattern) {
        TensorLayout oly{ily.dtype};
        oly.ndim = pattern.size();
        bool input_used[TensorLayout::MAX_NDIM] = {0};
        for (uint32_t idx = 0; idx < pattern.size(); ++idx) {
            auto i = pattern[idx];
            if (i < 0) {
                oly.shape[idx] = 1;
                oly.stride[idx] = 1;
            } else {
                input_used[i] = true;
                oly.shape[idx] = ily.shape[i];
                oly.stride[idx] = ily.stride[i];
            }
        }
        for (size_t i = 0; i < ily.ndim; ++i) {
            mgb_assert(input_used[i] || ily.shape[i] == 1,
                       "non-1 dim discarded in Dimshuffle: ishp=%s dim=%zd",
                       static_cast<const TensorShape&>(ily).to_string().c_str(),
                       i);
        }
        return oly;
    };
    for (auto&& i : m_internal_graph->placeholders()) {
        auto&& input = m_args.inputs[i->input_id()];
        auto&& iter = m_jitph2dimshuffle.find(i);
        if (iter == m_jitph2dimshuffle.end())
            continue;
        auto&& param = iter->second;
        mgb_assert(input.layout.ndim == param.second,
                   "input ndim mismatch for Dimshuffle: "
                   "expect=%u actual=%zu",
                   param.second, input.layout.ndim);
        auto dimshuffled_layout = get_dimshuffled_layout(input.layout,
                                                         param.first);
        input.layout = dimshuffled_layout;
    }
}
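
// Rebuild the cached kernel arguments: snapshot the input/output layouts
// (shape inputs are materialized from static inference with zero strides),
// fold pending dimshuffles into the input layouts, optionally collapse the
// layouts, and compute a hash plus a fresh version number so Args can be
// compared cheaply.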
void JITExecutor::update_args() {
    m_args.outputs.clear();
    for (auto out : output()) {
        m_args.outputs.push_back({out, out->layout(), -1});
    }
    m_args.inputs.resize(input().size());
    auto is_host_value_shape_input = [this](size_t idx) {
        return m_internal_graph->placeholders()
                .at(idx)
                ->is_host_value_shape_input();
    };
    for (size_t i = 0; i < input().size(); i++) {
        auto&& dst_data = m_args.inputs[i];
        dst_data.from = input(i);
        dst_data.idx = i;
        if (is_host_value_shape_input(i)) {
            auto&& mgr = owner_graph()->static_infer_manager();
            auto&& shpval_inp_val = &mgr.infer_value(input(i));
            cg::copy_tensor_value_to_shape(dst_data.layout, *shpval_inp_val);
            dst_data.layout.dtype = {};
            for (size_t j = 0; j < dst_data.layout.ndim; ++j) {
                dst_data.layout.stride[j] = 0;
            }
        } else {
            dst_data.layout = input(i)->layout();
        }
    }
    //! inputs consumed through Dimshuffle oprs need their layouts rewritten
    if (has_dimshuffle()) {
        do_dimshuffle();
    }
    if (m_compiler->property().contain_flag(CPFlag::NEED_INPUT_COLLAPSE)) {
        // collectively collapse the input layouts, trying to reduce the
        // output ndim
        opr::Elemwise::TensorLayoutPtrArray inp_layouts;
        inp_layouts.reserve(m_args.inputs.size());
        for (size_t i = 0; i < m_args.inputs.size(); i++) {
            if (!is_host_value_shape_input(i)) {
                inp_layouts.push_back(&m_args.inputs[i].layout);
            }
        }
        opr::Elemwise::broadcast_collective_collapse(inp_layouts,
                                                     &m_args.outputs[0].layout);
    }
    // compute and update the hash over layout info
    XXHash hstate;
    auto prop = m_compiler->property();
    if (prop.contain_flag(CPFlag::BIND_NDIM | CPFlag::BIND_SHAPE)) {
        mgb_assert(prop.contain_flag(CPFlag::BIND_NDIM),
                   "BIND_NDIM must be set if BIND_SHAPE is set");
        std::vector<size_t> buf;
        buf.reserve(1024);
        buf.push_back(m_args.inputs.size());
        for (auto&& i : m_args.inputs) {
            buf.push_back(i.layout.ndim);
            if (prop.contain_flag(CPFlag::BIND_SHAPE)) {
                for (size_t j = 0; j < i.layout.ndim; ++j) {
                    buf.push_back(i.layout[j]);
                }
            }
        }
        hstate.update(buf.data(), sizeof(buf[0]) * buf.size());
    }
    m_args.hash = hstate.digest();
    // update version number
    static std::atomic_uint_fast64_t global_version;
    m_args.version = global_version.fetch_add(1);
    m_args.need_update = false;
}
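
// Depth-first walk from the fused output, maintaining a stack of the
// Dimshuffle patterns composed along the current path; when a JITPlaceholder
// is reached, the composed pattern is recorded in m_jitph2dimshuffle so that
// do_dimshuffle() can apply it directly to that input's layout.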
void JITExecutor::prepare_dimshuffle() {
    std::unordered_set<OperatorNodeBase*> visited;
    std::vector<OperatorNodeBase*> stack(0);
    std::vector<uint8_t> idx(0);  // input index
    using Param = DimshuffleParam;
    std::vector<Param> dimshuffle_stack;
    auto merge_dimshuffle = [&](const opr::Dimshuffle::Param& p) {
        if (dimshuffle_stack.empty()) {
            dimshuffle_stack.emplace_back();
            auto&& param = dimshuffle_stack.back();
            param.first.insert(param.first.end(), p.pattern,
                               p.pattern + p.pattern_len);
            param.second = p.ndim;
        } else {
            // merge(p, src) -> param, so that dimshuffle(dimshuffle(x, p), src)
            // is equivalent to dimshuffle(x, param)
            dimshuffle_stack.emplace_back();
            auto&& param = dimshuffle_stack.back();
            auto&& src = dimshuffle_stack[dimshuffle_stack.size() - 2];
            mgb_assert(p.pattern_len == src.second);
            param.first.resize(src.first.size());
            for (size_t i = 0; i < src.first.size(); ++i) {
                if (src.first[i] == -1) {
                    param.first[i] = -1;
                } else {
                    param.first[i] = p.pattern[src.first[i]];
                }
            }
            param.second = p.ndim;
        }
    };
    auto push_back = [&](cg::OperatorNodeBase* op) {
        mgb_assert(!op->same_type<jit::JITPlaceholder>());
        if (auto o = op->try_cast_final<opr::Dimshuffle>()) {
            merge_dimshuffle(o->param());
        }
        stack.push_back(op);
        idx.push_back(0);
    };
    auto pop_back = [&]() {
        auto&& op = stack.back();
        if (op->same_type<opr::Dimshuffle>()) {
            dimshuffle_stack.pop_back();
        }
        stack.pop_back();
        idx.pop_back();
    };
    push_back(m_internal_graph->output()->owner_opr());
    while (!stack.empty()) {
        if (idx.back() < stack.back()->input().size()) {
            auto cur_opr = stack.back()->input(idx.back())->owner_opr();
            if (visited.insert(cur_opr).second) {
                if (auto jitph = cur_opr->try_cast_final<jit::JITPlaceholder>()) {
                    if (!dimshuffle_stack.empty()) {
                        mgb_assert(
                                m_jitph2dimshuffle
                                        .emplace(jitph, dimshuffle_stack.back())
                                        .second,
                                "already visited JITPlaceholder %s",
                                jitph->cname());
                    }
                    ++idx.back();
                } else {
                    push_back(cur_opr);
                }
            } else {
                ++idx.back();
            }
        } else {
            pop_back();
            if (!stack.empty())
                ++idx.back();
        }
    }
}
const JITExecutor::Args& JITExecutor::args() const {
    if (m_args.need_update) {
        const_cast<JITExecutor*>(this)->update_args();
    }
    return m_args;
}
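
// Args equality is layout equality under the compiler's binding flags: the
// hash gives a cheap reject, a shared version number gives a cheap accept,
// and only otherwise are the individual layouts compared.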
bool JITExecutor::Args::operator==(const Args& rhs) const {
    auto&& lhs = *this;
    mgb_assert(!lhs.need_update && !rhs.need_update);
    if (lhs.hash != rhs.hash) {
        return false;
    }
    if (lhs.version == rhs.version) {
        return true;
    }
    if (lhs.outputs.size() != rhs.outputs.size())
        return false;
    if (lhs.inputs.size() != rhs.inputs.size())
        return false;
    auto prop = owner->m_compiler->property();
    if (prop.contain_flag(CPFlag::BIND_NDIM | CPFlag::BIND_SHAPE)) {
        bool (*chk_layout)(const TensorLayout&, const TensorLayout&);
        if (prop.contain_flag(CPFlag::BIND_SHAPE)) {
            chk_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) {
                return lhs.eq_shape(rhs);
            };
        } else {
            chk_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) {
                return lhs.ndim == rhs.ndim;
            };
        }
        for (size_t i = 0; i < lhs.inputs.size(); i++) {
            if (!chk_layout(lhs.inputs[i].layout, rhs.inputs[i].layout))
                return false;
        }
        for (size_t i = 0; i < lhs.outputs.size(); i++) {
            if (!chk_layout(lhs.outputs[i].layout, rhs.outputs[i].layout))
                return false;
        }
    }
    // elect a common version so the next check can be fast
    lhs.version = rhs.version = std::min(lhs.version, rhs.version);
    return true;
}
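
// Shape inputs are consumed as host values (their contents are read on the
// host to form a shape), so they are registered as HOST_VALUE dependencies;
// all other inputs remain DEV_VALUE dependencies.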
JITExecutor::NodeProp* JITExecutor::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    using DepType = NodeProp::DepType;
    SmallVector<DepType> dt(input().size());
    auto&& placeholders = internal_graph().placeholders();
    for (size_t i = 0; i < dt.size(); ++i) {
        dt[i] = placeholders[i]->is_host_value_shape_input()
                        ? DepType::HOST_VALUE
                        : DepType::DEV_VALUE;
    }
    ret->reset_dep_type(input(), dt);
    return ret;
}
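
// Deduce the shape obtained by broadcasting all device-value inputs against
// each other, following the same rule as Elemwise shape deduction.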
megdnn::TensorShape JITExecutor::broadcasted_input_shape() const {
    megdnn::TensorShapeArray inp_shps;
    megdnn::TensorShape brdcast_shp;
    auto placeholders = m_internal_graph->placeholders();
    for (auto ph : placeholders) {
        if (!ph->is_host_value_shape_input()) {
            inp_shps.push_back(input(ph->input_id())->shape());
        }
    }
    megdnn::Elemwise::deduce_shape(inp_shps, brdcast_shp);
    return brdcast_shp;
}

#if MGB_ENABLE_GRAD

namespace {
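
// Helper for rebuilding the gradient's internal graph: vars are remapped
// through m_var_map, and any opr whose inputs were remapped is shallow-copied
// with the new inputs.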
class InternalGraphRewriter {
    ThinHashMap<VarNode*, VarNode*> m_var_map;
    VarNode* m_dest_var;
    VarNodeArray m_new_inp;

    VarNode* get_var(VarNode* var) {
        auto&& iter = m_var_map.find(var);
        if (iter != m_var_map.end()) {
            return iter->second;
        }
        return var;
    }

public:
    InternalGraphRewriter(VarNode* dest_var) : m_dest_var{dest_var} {}

    void iter(thin_function<void(cg::OperatorNodeBase*)>&& cb) {
        m_var_map.clear();
        cg::DepOprIter{std::move(cb)}.add(m_dest_var->owner_opr());
        m_dest_var = get_var(m_dest_var);
    }

    VarNode* dest_var() { return m_dest_var; }

    void replace_var(VarNode* src, VarNode* dst) {
        // Note: var replacement is not applied recursively. When used
        // placeholders are extracted from the internal graph, replacement
        // pairs (a -> b) and (b -> c) are not treated as a chain
        // (a -> b -> c) but as independent mappings; in all other cases,
        // each var node is passed as \p src or \p dst at most once.
        m_var_map[src] = dst;
    }

    void auto_replace_outputs(cg::OperatorNodeBase* opr) {
        // in the JIT internal graph, every opr has exactly one usable output
        mgb_assert(opr->usable_output().size() == 1);
        m_new_inp.clear();
        bool need_replace = false;
        for (auto&& i : opr->input()) {
            auto inp = get_var(i);
            m_new_inp.push_back(inp);
            need_replace |= (inp != i);
        }
        if (need_replace) {
            auto new_op = serialization::copy_opr_shallow(*opr, m_new_inp);
            replace_var(opr->output(0), new_op->output(0));
        }
    }
};

}  // anonymous namespace
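
// Gradient of the fused operator: take the symbolic gradient inside the
// internal graph via a VirtualLoss on the fused output, then either expand
// the result into the original graph or wrap it into a new JITExecutor; the
// forward output and the incoming gradient are passed through extra
// placeholders so they are not recomputed.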
MGB_IMPL_OPR_GRAD(JITExecutor) {
    VarNodeArray grad_inputs;
    for (auto input : opr.input())
        grad_inputs.push_back(input);
    mgb_assert(out_grad[0]);
    grad_inputs.push_back(opr.output(0));
    grad_inputs.push_back(out_grad[0]);
    auto fwd_igraph_ptr = opr.internal_graph_ptr();
    auto output_ph = JITPlaceholder::make(
            fwd_igraph_ptr->output(), fwd_igraph_ptr->placeholders().size());
    auto og_ph = JITPlaceholder::make(
            out_grad[0], fwd_igraph_ptr->placeholders().size() + 1);
    auto loss = opr::VirtualLoss::make({fwd_igraph_ptr->output()}, {og_ph});
    auto gx = cg::grad(loss, fwd_igraph_ptr->placeholders()[wrt_idx]->output(0),
                       false, false);
    if (!gx.node()) {
        return nullptr;
    }
    if (gx.node()->owner_opr()->same_type<opr::InvalidGrad>()) {
        return opr::InvalidGrad::make(opr, wrt_idx);
    }
    // early return if the grad expression is a single node
    for (size_t i = 0; i < fwd_igraph_ptr->placeholders().size(); ++i) {
        if (gx.node() == fwd_igraph_ptr->placeholders()[i]->output(0)) {
            return grad_inputs[i];
        }
    }
    if (gx.node() == og_ph.node()) {
        return out_grad[0];
    }
    if (gx.node() == fwd_igraph_ptr->output()) {
        return opr.output(0);
    }
    if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(
                gx.node()->owner_opr())) {
        HostTensorND hval{grad_inputs[0]->comp_node()};
        hval.copy_from(imm->value()).sync();
        return opr::ImmutableTensor::make(*imm->owner_graph(), hval).node();
    }
    // replace the output var in the internal graph with the output
    // placeholder, so opr.output(0) (already computed by the forward
    // JITExecutor) can be forwarded into the placeholder to avoid redundant
    // computation
    InternalGraphRewriter rewriter{gx.node()};
    rewriter.iter([&rewriter, &fwd_igraph_ptr,
                   &output_ph](cg::OperatorNodeBase* opr) {
        if (opr == fwd_igraph_ptr->output()->owner_opr()) {
            rewriter.replace_var(opr->output(0), output_ph.node());
            return;
        }
        rewriter.auto_replace_outputs(opr);
    });
    auto expand_into_origin_graph = [&rewriter](
            cg::OperatorNodeBase* opr, const VarNodeArray& grad_inputs) {
        if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
            rewriter.replace_var(opr->output(0),
                                 grad_inputs.at(ph->input_id()));
            return;
        }
        if (auto imm = gopt::try_cast_as_op<opr::ImmutableTensor>(opr)) {
            HostTensorND hval{grad_inputs[0]->comp_node()};
            hval.copy_from(imm->value()).sync();
            rewriter.replace_var(
                    opr->output(0),
                    opr::ImmutableTensor::make(*opr->owner_graph(), hval)
                            .node());
            return;
        }
        rewriter.auto_replace_outputs(opr);
    };
    if (opr.compiler()->property().feature_bits & JITFeatureBits::REDUCE) {
        // expand the gradient graph into the original graph to handle
        // broadcasting oprs
        using namespace std::placeholders;
        rewriter.iter(std::bind(expand_into_origin_graph, _1,
                                std::cref(grad_inputs)));
        return rewriter.dest_var();
    } else {
        VarNodeArray new_grad_inputs;
        PlaceholderArray placeholders;
        bool all_inp_const = true;
        // gx may not depend on all JITPlaceholders, so extract the used
        // placeholders and build a new internal graph
        rewriter.iter([&rewriter, &grad_inputs, &new_grad_inputs,
                       &placeholders, &all_inp_const](cg::OperatorNodeBase* opr) {
            if (auto ph = gopt::try_cast_as_op<JITPlaceholder>(opr)) {
                new_grad_inputs.push_back(grad_inputs[ph->input_id()]);
                auto new_ph = JITPlaceholder::make(new_grad_inputs.back(),
                                                   placeholders.size())
                                      .node()
                                      ->owner_opr();
                placeholders.push_back(
                        new_ph->try_cast_final<JITPlaceholder>());
                mgb_assert(placeholders.back());
                rewriter.replace_var(opr->output(0), new_ph->output(0));
                if (!cg::is_const_var_value(new_grad_inputs.back())) {
                    all_inp_const = false;
                }
                return;
            }
            rewriter.auto_replace_outputs(opr);
        });
        if (all_inp_const) {
            // if all inputs are const, expand the grad graph into the
            // original graph by replacing placeholders with the const
            // inputs, so it can benefit from static infer and constant
            // folding
            using namespace std::placeholders;
            rewriter.iter(std::bind(expand_into_origin_graph, _1,
                                    std::cref(new_grad_inputs)));
            return rewriter.dest_var();
        }
        gx = rewriter.dest_var();
        auto shape_infer = fwd_igraph_ptr->shape_infer();
        if (opr.has_dimshuffle()) {
            auto&& iter = opr.dimshuffle_params().find(
                    fwd_igraph_ptr->placeholders()[wrt_idx]);
            if (iter != opr.dimshuffle_params().end()) {
                auto&& pattern = iter->second.first;
                auto&& ndim = iter->second.second;
                std::vector<int> back(ndim, -1);
                for (size_t i = 0; i < pattern.size(); i++) {
                    // outdim[i] is indim[j]
                    auto j = pattern[i];
                    if (j >= 0) {
                        mgb_assert(back[j] == -1,
                                   "taking grad for Dimshuffle with "
                                   "duplicated input axis is unsupported");
                        back[j] = i;
                    }
                }
                shape_infer = opr::Dimshuffle::make(shape_infer, back,
                                                    pattern.size())
                                      .node();
            }
        }
        auto grad_ig = std::make_shared<InternalGraph>(
                gx.node(), shape_infer, nullptr, std::move(placeholders));
        auto grad_jit = JITExecutor::make(grad_ig, new_grad_inputs);
        if (opr.input_broadcastable()[wrt_idx]) {
            grad_jit = opr::reduce_sum(
                    grad_jit, opr::GetVarShape::make(opr.input(wrt_idx)));
        }
        return grad_jit.node();
    }
}

#endif  // MGB_ENABLE_GRAD
#endif  // MGB_JIT

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
