
specializations.cpp 20 kB

/**
 * \file imperative/src/impl/ops/specializations.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
// FIXME: split this file into separate files for each specialized op

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/rand.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

#include "../op_trait.h"
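
// Each specialization below translates between the imperative OpDef
// representation and the corresponding computing-graph operator:
//  * apply_on_var_node(def, inputs) builds the graph operator from an OpDef;
//  * make_from_op_node(node), where present, recovers an OpDef from a graph node.
// OP_TRAIT_REG(...) registers these hooks for the op; .fallback() fills the
// trait entries that are not explicitly provided with default implementations.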
namespace mgb::imperative {

namespace { namespace convolution {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Convolution>();
    return Convolution::make(node->param(), node->execution_policy());
}

auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& conv = static_cast<const Convolution&>(def);
    return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy());
}

OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // convolution

namespace { namespace convolution_backward_data {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
    cg::OperatorNodeConfig config;
    if (inputs.size() == 2) {
        return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
    } else {
        mgb_assert(inputs.size() == 3);
        return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    }
}

OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // convolution_backward_data

namespace { namespace dimshuffle {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
    std::vector<int> pattern(node->param().pattern_len);
    for (size_t i = 0; i < node->param().pattern_len; ++i) {
        pattern[i] = node->param().pattern[i];
    }
    return Dimshuffle::make(pattern);
}

auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& ds = static_cast<const Dimshuffle&>(def);
    return opr::Dimshuffle::make(inputs[0], ds.pattern);
}

OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // dimshuffle

namespace { namespace add_axis {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& add_axis = static_cast<const AddAxis&>(def);
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param;
    for (auto&& i : add_axis.axis) {
        param.push_back(Desc::make_add(i));
    }
    return opr::AxisAddRemove::make(inputs[0], param);
}

OP_TRAIT_REG(AddAxis, AddAxis)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // add_axis

namespace { namespace remove_axis {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& remove_axis = static_cast<const RemoveAxis&>(def);
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param;
    for (auto&& i : remove_axis.axis) {
        param.push_back(Desc::make_remove(i));
    }
    return opr::AxisAddRemove::make(inputs[0], param);
}

OP_TRAIT_REG(RemoveAxis, RemoveAxis)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // remove_axis

namespace { namespace top_k {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& topk = static_cast<const TopK&>(def);
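    // TopK has multiple output vars; returning the owner operator
    // (rather than a single var) exposes all of them to the caller.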
    return opr::TopK::make(inputs[0], inputs[1], topk.param())[0]
            .node()->owner_opr();
}

OP_TRAIT_REG(TopK, TopK)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // top_k

namespace { namespace reduce {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& reduce = static_cast<const Reduce&>(def);
    if (inputs.size() > 1) {
        return opr::Reduce::make(inputs[0], reduce.param(), inputs[1]);
    } else {
        return opr::Reduce::make(inputs[0], reduce.param());
    }
}

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Reduce>();
    return Reduce::make(node->param());
}

OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // reduce

namespace { namespace adaptive_pooling {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& pool = static_cast<const AdaptivePooling&>(def);
    return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param());
}

OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // adaptive_pooling

namespace { namespace conv_bias {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& conv = static_cast<const ConvBias&>(def);
    cg::OperatorNodeConfig config{conv.dtype};
    if (inputs.size() == 2) {
        return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 3) {
        return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 4) {
        return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config);
    }
    mgb_assert(0);
}

OP_TRAIT_REG(ConvBias, ConvBias)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // conv_bias

namespace { namespace batch_conv_bias {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& conv = static_cast<const BatchConvBias&>(def);
    cg::OperatorNodeConfig config{conv.dtype};
    if (inputs.size() == 2) {
        return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 3) {
        return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 4) {
        return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config);
    }
    mgb_assert(0);
}

OP_TRAIT_REG(BatchConvBias, BatchConvBias)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // batch_conv_bias

namespace { namespace pooling {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& pool = static_cast<const Pooling&>(def);
    return opr::Pooling::make(inputs[0], pool.param());
}

OP_TRAIT_REG(Pooling, Pooling)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // pooling

namespace { namespace matrix_mul {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& matmul = static_cast<const MatrixMul&>(def);
    mgb_assert(inputs.size() == 2);
    return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(),
                                matmul.policy());
}

OP_TRAIT_REG(MatrixMul, MatrixMul)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // matrix_mul

namespace { namespace batched_matrix_mul {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
    mgb_assert(inputs.size() == 2);
    return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(),
                                       matmul.policy());
}

OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // batched_matrix_mul

namespace { namespace dot {
auto apply_on_var_node(
        const OpDef&,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 2);
    return opr::Dot::make(inputs[0], inputs[1]);
}

OP_TRAIT_REG(Dot, Dot)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // dot

namespace { namespace argsort {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& argsort = static_cast<const Argsort&>(def);
    return opr::Argsort::make(inputs[0], argsort.param());
}

OP_TRAIT_REG(Argsort, Argsort)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // argsort

namespace { namespace argmax {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& argmax = static_cast<const Argmax&>(def);
    return opr::Argmax::make(inputs[0], argmax.param());
}

OP_TRAIT_REG(Argmax, Argmax)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // argmax

namespace { namespace argmin {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& argmin = static_cast<const Argmin&>(def);
    return opr::Argmin::make(inputs[0], argmin.param());
}

OP_TRAIT_REG(Argmin, Argmin)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // argmin

namespace { namespace warp_perspective {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& warp = static_cast<const WarpPerspective&>(def);
    if (inputs.size() == 3) {
        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param());
    } else {
        mgb_assert(inputs.size() == 4);
        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], inputs[3], warp.param());
    }
}

OP_TRAIT_REG(WarpPerspective, WarpPerspective)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // warp_perspective

namespace { namespace group_local {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& local = static_cast<const GroupLocal&>(def);
    mgb_assert(inputs.size() == 2);
    return opr::GroupLocal::make(inputs[0], inputs[1], local.param());
}

OP_TRAIT_REG(GroupLocal, GroupLocal)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // group_local

namespace { namespace indexing_one_hot {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const IndexingOneHot&>(def);
    mgb_assert(inputs.size() == 2);
    return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param());
}

OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // indexing_one_hot

namespace { namespace indexing_set_one_hot {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const IndexingSetOneHot&>(def);
    mgb_assert(inputs.size() == 3);
    return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param());
}

OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // indexing_set_one_hot

namespace { namespace typecvt {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const TypeCvt&>(def);
    mgb_assert(inputs.size() == 1);
    return opr::TypeCvt::make(inputs[0], op.dtype);
}

OP_TRAIT_REG(TypeCvt, TypeCvt)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // typecvt

namespace { namespace concat {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const Concat&>(def);
    cg::OperatorNodeConfig config{op.comp_node};
    return opr::Concat::make(inputs, op.axis, config);
}

OP_TRAIT_REG(Concat, Concat)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // concat

namespace { namespace copy {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const Copy&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.comp_node};
    return opr::Copy::make(inputs[0], config);
}

OP_TRAIT_REG(Copy, Copy)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // copy

namespace { namespace identity {
auto apply_on_var_node(
        const OpDef&,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 1);
    return opr::Identity::make(inputs[0]);
}

OP_TRAIT_REG(Identity, Identity)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // identity

namespace { namespace assert_equal {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const AssertEqual&>(def);
    mgb_assert(inputs.size() == 2);
    return opr::AssertEqual::make(inputs[0], inputs[1], op.param());
}

OP_TRAIT_REG(AssertEqual, AssertEqual)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // assert_equal

namespace { namespace uniform_rng {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const UniformRNG&>(def);
    mgb_assert(inputs.size() == 1);
    return opr::UniformRNG::make(inputs[0], op.param());
}

OP_TRAIT_REG(UniformRNG, UniformRNG)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // uniform_rng

namespace { namespace gaussian_rng {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const GaussianRNG&>(def);
    mgb_assert(inputs.size() == 1);
    return opr::GaussianRNG::make(inputs[0], op.param());
}

OP_TRAIT_REG(GaussianRNG, GaussianRNG)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // gaussian_rng

namespace { namespace roi_align {
VarNodeArray apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const ROIAlign&>(def);
    mgb_assert(inputs.size() == 2);
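    // ROIAlign produces two vars: the pooled values and an index var
    // consumed by the backward pass, so both outputs are returned.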
    auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param()).node()->owner_opr();
    return {opr->output(0), opr->output(1)};
}

OP_TRAIT_REG(ROIAlign, ROIAlign)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // roi_align
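
// NvOf (NVIDIA optical flow) is only compiled when CUDA support is enabled.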
#if MGB_CUDA
namespace { namespace nvof {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const NvOf&>(def);
    mgb_assert(inputs.size() == 1);
    return opr::NvOf::make(inputs[0], op.param());
}

OP_TRAIT_REG(NvOf, NvOf)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // nvof
#endif

namespace { namespace linspace {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const Linspace&>(def);
    mgb_assert(inputs.size() == 3);
    cg::OperatorNodeConfig config{op.comp_node};
    return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}

OP_TRAIT_REG(Linspace, Linspace)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // linspace

namespace { namespace eye {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const Eye&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.comp_node};
    opr::Eye::Param param{op.k, op.dtype.enumv()};
    return opr::Eye::make(inputs[0], param, config);
}

OP_TRAIT_REG(Eye, Eye)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // eye

namespace { namespace roi_pooling {
VarNodeArray apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const ROIPooling&>(def);
    mgb_assert(inputs.size() == 3);
    auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()).node()->owner_opr();
    return {opr->output(0), opr->output(1)};
}

OP_TRAIT_REG(ROIPooling, ROIPooling)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // roi_pooling

namespace { namespace remap {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const Remap&>(def);
    mgb_assert(inputs.size() == 2);
    return opr::Remap::make(inputs[0], inputs[1], op.param());
}

OP_TRAIT_REG(Remap, Remap)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // remap

namespace {
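// get_index rebuilds a Subtensor-style IndexDesc from the flattened input
// list. Each mask entry is a tuple (axis, begin, end, step, idx): when idx
// is set, the next input var is a direct index on that axis; otherwise the
// begin/end/step flags say which slice components consume an input var.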
auto get_index(
        const VarNodeArray& inputs, size_t vidx,
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
    size_t length = mask.size();
    opr::Subtensor::IndexDesc ret(length);
    for (size_t i = 0; i < length; ++i) {
        auto&& [axis, begin, end, step, idx] = mask[i];
        ret[i].axis = axis;
        if (idx) {
            ret[i].idx = inputs[vidx++];
        } else {
            mgb_assert(begin || end || step);
            if (begin) ret[i].begin = inputs[vidx++];
            if (end) ret[i].end = inputs[vidx++];
            if (step) ret[i].step = inputs[vidx++];
        }
    }
    mgb_assert(vidx == inputs.size());
    return ret;
}
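
// FANCY_INDEXING_IMPL stamps out apply_on_var_node and the trait
// registration for each fancy-indexing op. NR_INPUT selects how many
// leading inputs are data tensors (IN1/IN2); the remaining inputs are
// consumed by get_index as index/slice vars.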
#define IN1 inputs[0]
#define IN2 inputs[0], inputs[1]

#define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \
namespace NAME##_impl { \
auto apply_on_var_node( \
        const OpDef& def, \
        const VarNodeArray& inputs) { \
    auto&& op = static_cast<const NAME&>(def); \
    return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items)); \
} \
OP_TRAIT_REG(NAME, NAME) \
    .apply_on_var_node(apply_on_var_node) \
    .fallback(); \
}

FANCY_INDEXING_IMPL(Subtensor, 1)
FANCY_INDEXING_IMPL(SetSubtensor, 2)
FANCY_INDEXING_IMPL(IncrSubtensor, 2)
FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
FANCY_INDEXING_IMPL(IndexingSetMultiAxisVec, 2)
FANCY_INDEXING_IMPL(IndexingIncrMultiAxisVec, 2)
FANCY_INDEXING_IMPL(MeshIndexing, 1)
FANCY_INDEXING_IMPL(IncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(SetMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedMeshIndexing, 1)
FANCY_INDEXING_IMPL(BatchedIncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2)

#undef FANCY_INDEXING_IMPL
#undef IN1
#undef IN2
} // anonymous namespace

namespace { namespace fake_quant {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const FakeQuant&>(def);
    mgb_assert(inputs.size() == 3);
    return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param());
}

OP_TRAIT_REG(FakeQuant, FakeQuant)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // fake_quant

namespace { namespace tqt {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const TQT&>(def);
    mgb_assert(inputs.size() == 2);
    return opr::TQT::make(inputs[0], inputs[1], op.param());
}

OP_TRAIT_REG(TQT, TQT)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // tqt

namespace { namespace elemwise_multi_type {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const ElemwiseMultiType&>(def);
    OperatorNodeConfig config{op.dtype};
    return opr::ElemwiseMultiType::make(inputs, op.param(), config);
}

OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // elemwise_multi_type

namespace { namespace svd {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const SVD&>(def);
    mgb_assert(inputs.size() == 1);
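    // usable_output() filters out the operator's internal (e.g. workspace)
    // outputs, if any, before returning.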
    return opr::SVD::make(inputs[0], op.param())[0].node()->owner_opr()->usable_output();
}

OP_TRAIT_REG(SVD, SVD)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // svd

} // namespace mgb::imperative

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU device and the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.