
indexing.cpp 20 kB

/**
 * \file src/opr/impl/indexing.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/utility.h"
#include "megbrain/graph/grad_impl.h"

#include "./internal/megdnn_opr_wrapper.inl"

using namespace mgb;
using namespace opr;
namespace {
    void check_index_dtype(std::initializer_list<SymbolVar*> &inputs) {
        mgb_assert(inputs.size() >= 2);
        auto iter = inputs.begin();
        ++ iter;
        SymbolVar &index = **iter;
        if (index.dtype() != dtype::Int32()) {
            mgb_log_warn("dtype of index in IndexingOneHot must be Int32, "
                    "got %s for variable %s; convert to Int32 implicitly",
                    index.dtype().name(), index.node()->cname());
            index = opr::TypeCvt::make(index, dtype::Int32());
        }
    }

    // map each modify-style megdnn indexing operator to whether it overwrites
    // (SET) or accumulates into (INCR) its destination
    enum IndexingModifyType {
        SET, INCR
    };

    template<typename Opr>
    struct IndexingModifyTypeGetter {};

#define REG(op, type) \
    template<> \
    struct IndexingModifyTypeGetter<megdnn::op> { \
        static constexpr IndexingModifyType value = IndexingModifyType::type; \
    };
    REG(IndexingIncrMultiAxisVec, INCR)
    REG(IncrMeshIndexing, INCR)
    REG(BatchedIncrMeshIndexing, INCR)
    REG(IndexingSetMultiAxisVec, SET)
    REG(SetMeshIndexing, SET)
    REG(BatchedSetMeshIndexing, SET)
#undef REG
} // anonymous namespace
namespace mgb {
namespace opr {
namespace intl {
    template<>
    struct MegDNNOprInitInputsModifier<IndexingOneHot> {
        static void apply(const IndexingOneHot::Param &param,
                std::initializer_list<SymbolVar*> inputs) {
            MGB_MARK_USED_VAR(param);
            check_index_dtype(inputs);
        }
    };

    template<>
    struct MegDNNOprInitInputsModifier<IndexingSetOneHot>:
        public MegDNNOprInitInputsModifier<IndexingOneHot> {};
} // namespace intl
} // namespace opr
} // namespace mgb
/* ==================== IndexingOneHot ==================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingOneHot);
MEGDNN_OPR_INIT2(IndexingOneHot, "indexing_one_hot")

void IndexingOneHot::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
}

MGB_IMPL_OPR_GRAD(IndexingOneHot) {
    if (wrt_idx == 0) {
        return IndexingSetOneHot::make(
                SymbolVar{opr.input(0)}.fill_retain_dtype(0),
                opr.input(1), out_grad[0], opr.param()).node();
    }
    return InvalidGrad::make(opr, wrt_idx);
}
/* ==================== IndexingSetOneHot ==================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingSetOneHot);
MEGDNN_OPR_INIT3(IndexingSetOneHot, "indexing_set_one_hot")

void IndexingSetOneHot::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
}

void IndexingSetOneHot::add_input_layout_constraint() {
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}

void IndexingSetOneHot::mem_plan_fwd_in2out_writable() {
    cg::request_fwd_in2out_writable_if_no_mem_ovelap(this, 0, 0);
}

void IndexingSetOneHot::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto &&mgr = owner_graph()->static_infer_manager();
    mgr.register_shape_infer(output(0),
            ShapeInferDesc::make_identity(input(0)));
    init_output_static_infer_desc_workspace(false);
}

void IndexingSetOneHot::scn_do_execute() {
    auto &&idata = input(0)->dev_tensor(), &&index = input(1)->dev_tensor(),
         &&odata = output(0)->dev_tensor();

    if (idata.raw_ptr() != odata.raw_ptr()) {
        odata.copy_from_fixlayout(idata);
    } else {
        mgb_assert(odata.layout().eq_layout(idata.layout()));
    }
    mgb_assert(odata.layout().is_contiguous());

    megdnn_opr()->exec(odata.as_megdnn(), index.as_megdnn(),
            input(2)->dev_tensor().as_megdnn(),
            intl::get_megdnn_workspace_from_var(output(1)));
}

MGB_IMPL_OPR_GRAD(IndexingSetOneHot) {
    SymbolVar index{opr.input(1)}, sub{opr.input(2)}, og{out_grad.at(0)};
    if (wrt_idx == 0) {
        return IndexingSetOneHot::make(og, index, sub.fill_retain_dtype(0),
                opr.param()).node();
    }
    if (wrt_idx == 2) {
        return IndexingOneHot::make(og, index, opr.param()).node();
    }
    return InvalidGrad::make(opr, wrt_idx);
}

size_t IndexingSetOneHot::get_workspace_size_bytes(
        const TensorShapeArray &input_shapes,
        const TensorShapeArray &output_shapes) const {
    return megdnn_opr()->get_workspace_in_bytes(
            {input_shapes[0], input(0)->dtype()},
            {input_shapes[1], input(1)->dtype()},
            {input_shapes[2], input(2)->dtype()});
}
/* ==================== IndexingRemap ==================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingRemap);
MEGDNN_OPR_INIT2(IndexingRemap, "indexing_remap")

void IndexingRemap::init_output_dtype() {
    mgb_throw_if(input(1)->dtype() != dtype::Int32(), GraphError,
            "IndexingRemap requires map input to be int32");
    output(0)->dtype(input(0)->dtype());
}

MGB_IMPL_OPR_GRAD(IndexingRemap) {
    if (wrt_idx == 1)
        return InvalidGrad::make(opr, wrt_idx);
    mgb_assert(wrt_idx == 0 && out_grad[0]);
    return IndexingRemapBackward::make(
            out_grad[0], opr.input(1), opr.input(0), opr.param()).node();
}

MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingRemapBackward);
MEGDNN_OPR_INIT3(IndexingRemapBackward, "indexing_remap_bwd", 2, false);
/* ================= IndexingMultiAxisVecMegDNNOprHolder ================= */

template<class Opr>
Opr& mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::megdnn_opr(
        cg::SingleCNOperatorNodeBase& self) {
    auto comp_node = self.comp_node();
    if (!m_megdnn_opr || m_megdnn_opr.comp_node() != comp_node) {
        m_megdnn_opr = intl::create_megdnn_opr<Opr>(comp_node);
        m_megdnn_opr->set_error_tracker(
                static_cast<cg::OperatorNodeBase*>(&self));
    }
    return *m_megdnn_opr;
}

template<class Opr>
void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::register_workspace_infer(
        const indexing::IndexDesc &index_desc,
        cg::SingleCNOperatorNodeBase &opr, VarNode *data, VarNode *value) {
    using namespace cg::static_infer;
    auto infer_shape = [this, &index_desc, &opr](
            TensorShape &dest, const InpVal &inp) {
        size_t axes[TensorShape::MAX_NDIM], nr_axes = 0;
        auto ndim = inp.val[0].shape().ndim;
        for (auto &&i: reverse_adaptor(index_desc)) {
            if (i.idx.node()) {
                axes[nr_axes ++] = i.axis.get(ndim);
            }
        }
        if (!nr_axes) {
            // no vector indexer: the megdnn kernel is not invoked, so no
            // workspace is needed
            dest = {0};
        } else {
            dest = {megdnn_opr(opr).get_workspace_in_bytes(
                    inp.val[1].shape(), axes, nr_axes)};
        }
        return true;
    };
    opr.owner_graph()->static_infer_manager().register_shape_infer(
            opr.output(1),
            {SourceType::DEP,
             {{data, DepType::SHAPE}, {value, DepType::SHAPE}},
             infer_shape});
}

template <class Opr>
void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::record_megdnn_opr(
        mgb::cg::GraphExecutable::ExecDependencyArray& deps) {
    deps.emplace_back(
            std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr)));
}
/* ==================== MultiAxisVecFancyIndexingHelper ==================== */

std::pair<const megdnn::IndexingMultiAxisVec::IndexDesc&, bool>
intl::MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc(
        size_t inp_ndim, bool warn_all_scalar) {
    auto &&index = m_megdnn_index_cache;
    index.clear();
    bool is_empty_shape = false;
    for (auto i: reverse_adaptor(m_input2idxonly_axis_indexer)) {
        if (i) {
            index.push_back({
                    i->axis.get(inp_ndim),
                    i->idx.node()->dev_tensor().as_megdnn()});
            is_empty_shape |= index.back().vec.layout.is_empty();
        }
    }

    if (!m_scalar_idx_warn_printed && warn_all_scalar) {
        bool all_scalar = true;
        for (auto &&i: index) {
            if (!i.vec.layout.is_scalar()) {
                all_scalar = false;
                break;
            }
        }
        if (all_scalar) {
            mgb_log_warn("%s{%s}: no vector indexer; consider using Subtensor "
                    "family for better performance; you can set "
                    "MGB_THROW_ON_SCALAR_IDX to throw an exception to help "
                    "tracking the related operator",
                    cname(), dyn_typeinfo()->name);
            mgb_throw_if(MGB_GETENV("MGB_THROW_ON_SCALAR_IDX"),
                    MegBrainError, "vector-indexing operator used with all "
                    "scalar indices");
        }

        // always set m_scalar_idx_warn_printed to be true, so we do not print
        // this warning in the future
        m_scalar_idx_warn_printed = true;
    }

    return {index, is_empty_shape};
}
/* ==================== IndexingMultiAxisVecBase ==================== */

template<class Opr>
cg::OperatorNodeBase::NodeProp*
IndexingMultiAxisVecBase<Opr>::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    // TODO: should also allow input shape is empty if any
    // indexer's shape is empty
    for (auto i: m_input2idxonly_axis_indexer) {
        if (i) {
            prop->add_dep_type_existing_var(
                    i->idx.node(), NodeProp::DepType::VALUE_ALLOW_EMPTY);
        }
    }
    return prop;
}

template <class Opr>
void IndexingMultiAxisVecBase<Opr>::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    DepVal deps;

    // shape inference only needs slices
    deps.push_back({input(0), DepType::SHAPE});

    // loop in reverse order because megdnn opr needs ascending axes
    for (size_t i = m_input2idxonly_axis_indexer.size() - 1; i; -- i) {
        if (m_input2idxonly_axis_indexer[i]) {
            deps.push_back({input(i), DepType::SHAPE});
        }
    }
    size_t inp_interval_start = deps.size();
    for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++ i) {
        if (!m_input2idxonly_axis_indexer[i]) {
            deps.push_back({input(i), DepType::VALUE});
        }
    }

    auto infer_shape = [this, inp_interval_start](
            TensorShape &dest, const InpVal &inp) {
        auto &&ishp = inp.val[0].shape();
        auto subspec = fancy_indexing_make_sub_spec(
                {ishp, input(0)->dtype()}, inp, inp_interval_start);
        dest = subspec.layout();
        typename Opr::IndexDescLayoutOnly index_layout;
        size_t indexer_pos = 1;
        for (auto i: reverse_adaptor(m_input2idxonly_axis_indexer)) {
            if (i) {
                index_layout.push_back({i->axis.get(dest.ndim),
                        {inp.val.at(indexer_pos ++).shape(), dtype::Int32()}});
            }
        }
        mgb_assert(indexer_pos == inp_interval_start);
        if (!index_layout.empty()) {
            // index_layout is empty if all indices are intervals
            TensorLayout tmp;
            Opr::deduce_layout(
                    {dest, input(0)->dtype()}, index_layout, tmp);
            dest = tmp;
        }
        return true;
    };
    owner_graph()->static_infer_manager().register_shape_infer(
            output(0), {SourceType::DEP, deps, infer_shape});

    this->register_workspace_infer(index_desc(), *this, input(0), output(0));
}

template <class Opr>
void IndexingMultiAxisVecBase<Opr>::record_execute_deps(
        mgb::cg::GraphExecutable::ExecDependencyArray& deps) {
    this->record_megdnn_opr(deps);
}
namespace {
    template <class Opr>
    struct ShouldWarnOnScalarIndexer {
        static constexpr bool val = false;
    };

#define WARN(opr) \
    template <> \
    struct ShouldWarnOnScalarIndexer<megdnn::opr> { \
        static constexpr bool val = true; \
    }
    WARN(IndexingMultiAxisVec);
    WARN(IndexingSetMultiAxisVec);
    WARN(IndexingIncrMultiAxisVec);
#undef WARN
} // anonymous namespace

template <class Opr>
void IndexingMultiAxisVecBase<Opr>::scn_do_execute() {
    auto inp = input(0)->dev_tensor();
    inp = inp.sub(fancy_indexing_make_sub_spec(inp.layout()));
    auto &&index_desc = make_megdnn_index_desc(
            inp.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val);
    auto &&odev = output(0)->dev_tensor();
    if (index_desc.first.empty()) {
        odev.copy_from_fixlayout(inp);
    } else {
        if (!index_desc.second) {
            // only call megdnn exec if result is not empty
            this->megdnn_opr(*this).exec(
                    inp.as_megdnn(), index_desc.first, odev.as_megdnn(),
                    intl::get_megdnn_workspace_from_var(output(1)));
        } else {
            mgb_assert(odev.empty());
        }
    }
}
/* ==================== IndexingModifyMultiAxisVecHelper ==================== */

template<class Opr>
void intl::IndexingModifyMultiAxisVecHelper<Opr>::
init_output_static_infer_desc() {
    using namespace cg::static_infer;
    this->owner_graph()->static_infer_manager().register_shape_infer(
            this->output(0), ShapeInferDesc::make_identity(this->input(0)));

    this->register_workspace_infer(index_desc(), *this, input(0), input(1));
}

template<class Opr>
void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() {
    auto inp = this->fancy_indexing_get_tensors_for_modify_in_scn_do_execute();
    auto index_desc = this->make_megdnn_index_desc(
            inp.first.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val);
    if (index_desc.second) {
        mgb_assert(inp.second.shape().is_empty());
        return;
    }
    if (index_desc.first.empty()) {
        using IMT = IndexingModifyType;
        static constexpr auto modify_type =
            IndexingModifyTypeGetter<Opr>::value;
        switch (modify_type) {
            case IMT::SET: {
                inp.first.copy_from_fixlayout(inp.second);
                break;
            } case IMT::INCR: {
                megdnn::AddUpdate* add_update = intl::get_megdnn_global_opr<
                    megdnn::AddUpdate>(comp_node());
                add_update->exec(inp.first.as_megdnn(), inp.second.as_megdnn());
                break;
            } default:
                mgb_throw(MegBrainError, "bad modify type");
        }
    } else {
        this->megdnn_opr(*this).exec(
                inp.first.as_megdnn(), inp.second.as_megdnn(),
                index_desc.first,
                intl::get_megdnn_workspace_from_var(output(1)));
    }
}

template<class Opr>
cg::OperatorNodeBase::NodeProp*
intl::IndexingModifyMultiAxisVecHelper<Opr>::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    using DT = NodeProp::DepType;
    // TODO: should also allow input shape is empty if any
    // indexer's shape is empty
    prop->add_dep_type_existing_var(input(1), DT::VALUE_ALLOW_EMPTY);
    for (auto i: m_input2idxonly_axis_indexer) {
        if (i) {
            prop->add_dep_type_existing_var(
                    i->idx.node(), DT::VALUE_ALLOW_EMPTY);
        }
    }
    return prop;
}

template<class Opr>
void intl::IndexingModifyMultiAxisVecHelper<Opr>::
add_input_layout_constraint() {
    auto check_cont1 = [](const TensorLayout &ly) {
        return ly.collapse_contiguous().ndim == 1;
    };
    this->input(1)->add_layout_constraint(check_cont1);
}
/* ==================== MultiAxisVec misc ==================== */

MGB_IMPL_FANCY_INDEXING_OPR_GET(
        IndexingMultiAxisVec, "indexing_multi_axis_vec", false,
        output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
);
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
        IndexingSetMultiAxisVec, "indexing_set_multi_axis_vec", false);
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
        IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false);

MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) {
    if (wrt_idx)
        return InvalidGrad::make(opr, wrt_idx);
    return IndexingIncrMultiAxisVec::make(
            SymbolVar{opr.input(0)}.fill_retain_dtype(0),
            out_grad.at(0), opr.index_desc()).node();
}

MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) {
    if (wrt_idx >= 2)
        return InvalidGrad::make(opr, wrt_idx);
    if (wrt_idx == 0) {
        return IndexingSetMultiAxisVec::make(out_grad.at(0),
                SymbolVar{opr.input(1)}.fill_retain_dtype(0),
                opr.index_desc()).node();
    }
    return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node();
}

MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) {
    if (wrt_idx >= 2)
        return InvalidGrad::make(opr, wrt_idx);
    if (wrt_idx == 0) {
        return out_grad.at(0);
    }
    return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node();
}
/* ============================= Mesh Indexing ============================ */

MGB_IMPL_FANCY_INDEXING_OPR_GET(
        MeshIndexing, "mesh_indexing", false,
        output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
MGB_IMPL_FANCY_INDEXING_OPR_GET(
        BatchedMeshIndexing, "batched_mesh_indexing", false,
        output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););

MGB_IMPL_OPR_GRAD(MeshIndexing) {
    if (wrt_idx != 0) {
        return InvalidGrad::make(opr, wrt_idx);
    }
    return IncrMeshIndexing::make(
                   SymbolVar{opr.input(0)}.fill_retain_dtype(0), out_grad.at(0),
                   opr.index_desc())
            .node();
}

MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) {
    if (wrt_idx != 0) {
        return InvalidGrad::make(opr, wrt_idx);
    }
    return BatchedIncrMeshIndexing::make(
                   SymbolVar{opr.input(0)}.fill_retain_dtype(0), out_grad.at(0),
                   opr.index_desc())
            .node();
}
/* ========================= IncrMeshIndexing ========================= */

MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrMeshIndexing, "incr_mesh_indexing",
                                   false);

MGB_IMPL_OPR_GRAD(IncrMeshIndexing) {
    if (wrt_idx > 2) {
        return opr::InvalidGrad::make(opr, wrt_idx);
    }
    if (wrt_idx == 0) {
        return out_grad.at(0);
    }
    return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
}

MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedIncrMeshIndexing,
                                   "batched_incr_mesh_indexing", false);

MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) {
    if (wrt_idx > 2) {
        return opr::InvalidGrad::make(opr, wrt_idx);
    }
    if (wrt_idx == 0) {
        return out_grad.at(0);
    }
    return BatchedMeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
}
/* ======================== SetMeshIndexing =========================== */

MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetMeshIndexing, "set_mesh_indexing", false);

MGB_IMPL_OPR_GRAD(SetMeshIndexing) {
    if (wrt_idx >= 2) {
        return opr::InvalidGrad::make(opr, wrt_idx);
    }
    if (wrt_idx == 0) {
        return SetMeshIndexing::make(
                       out_grad.at(0),
                       SymbolVar{opr.input(1)}.fill_retain_dtype(0),
                       opr.index_desc())
                .node();
    } else {
        return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
    }
}

MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedSetMeshIndexing,
                                   "batched_set_mesh_indexing", false);

MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) {
    if (wrt_idx > 2) {
        return opr::InvalidGrad::make(opr, wrt_idx);
    }
    if (wrt_idx == 0) {
        return BatchedSetMeshIndexing::make(
                       out_grad.at(0),
                       SymbolVar{opr.input(1)}.fill_retain_dtype(0),
                       opr.index_desc())
                .node();
    } else {
        return BatchedMeshIndexing::make(out_grad.at(0), opr.index_desc())
                .node();
    }
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU device and that the driver is properly installed. If you would like to try deep-learning development on cloud GPU resources, you are welcome to visit the MegStudio platform.
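As a quick sanity check after installation, the snippet below (a minimal Python sketch; the `megengine.is_cuda_available()` helper is assumed to be present in the installed version) reports whether the bundled CUDA environment can actually reach a GPU:

    # Minimal check: does MegEngine see a usable CUDA device?
    # Assumes megengine.is_cuda_available() exists in the installed version.
    import megengine as mge

    if mge.is_cuda_available():
        print("GPU detected; code can run on a CUDA device.")
    else:
        print("No usable GPU or driver found; MegEngine runs on CPU.")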