/**
 * \file imperative/src/impl/ops/specializations.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

// FIXME: split this file into separate files for each specialized op

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/dnn/sliding_window_transpose.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/rand.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

#include "../blob_manager_impl.h"
#include "../op_trait.h"
namespace mgb::imperative {

namespace {
namespace dimshuffle {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
    std::vector<int> pattern(node->param().pattern_len);
    for (size_t i = 0; i < node->param().pattern_len; ++i) {
        pattern[i] = node->param().pattern[i];
    }
    return Dimshuffle::make(pattern);
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& ds = static_cast<const Dimshuffle&>(def);
    OperatorNodeConfig config{ds.make_name()};
    return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& ds = static_cast<const Dimshuffle&>(def);
    mgb_assert(
            ds.pattern.size() <= TensorShape::MAX_NDIM,
            "Dimshuffle pattern exceeds max length of %zd", TensorShape::MAX_NDIM);
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "Dimshuffle expects 1 input; got %lu", nr_inp);
    auto&& src = inputs[0];
    TensorShape out_shape;
    if (src.layout.ndim == 0) {
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }
    size_t pattern_ndim = *std::max_element(ds.pattern.begin(), ds.pattern.end()) + 1;
    mgb_assert(
            src.layout.ndim == pattern_ndim,
            "input ndim mismatch for Dimshuffle: expect=%zd actual=%zd", pattern_ndim,
            src.layout.ndim);
    out_shape.ndim = ds.pattern.size();
    size_t idx = 0;
    bool input_used[TensorLayout::MAX_NDIM] = {0};
    for (auto i : ds.pattern) {
        if (i < 0) {
            out_shape[idx] = 1;
        } else {
            input_used[i] = true;
            out_shape[idx] = src.layout.shape[i];
        }
        ++idx;
    }
    for (size_t i = 0; i < pattern_ndim; ++i) {
        mgb_assert(
                input_used[i] || src.layout.shape[i] == 1,
                "non-1 dim discarded in Dimshuffle: ishp=%s dim=%zd",
                src.layout.megdnn::TensorShape::to_string().c_str(), i);
    }
    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}
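// A worked example of the pattern semantics implemented above: each pattern
// entry picks the input axis for the corresponding output axis, and -1
// inserts a new axis of extent 1. For a (2, 3) input, pattern {1, -1, 0}
// yields a (3, 1, 2) output; any input axis the pattern drops must have
// extent 1, as the second loop asserts.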
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& ds = static_cast<const Dimshuffle&>(def);
    mgb_assert(
            ds.pattern.size() <= TensorShape::MAX_NDIM,
            "Dimshuffle pattern exceeds max length of %zd", TensorShape::MAX_NDIM);
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "Dimshuffle expects 1 input; got %lu", nr_inp);
    auto&& src = inputs[0];
    auto inp_layout = src->layout();
    size_t pattern_ndim = *std::max_element(ds.pattern.begin(), ds.pattern.end()) + 1;
    mgb_assert(
            inp_layout.ndim == pattern_ndim,
            "input ndim mismatch for Dimshuffle: expect=%zd actual=%zd", pattern_ndim,
            inp_layout.ndim);
    TensorLayout out_layout{inp_layout.dtype};
    out_layout.ndim = ds.pattern.size();
    size_t idx = 0;
    bool input_used[TensorLayout::MAX_NDIM] = {0};
    for (auto i : ds.pattern) {
        if (i < 0) {
            out_layout.shape[idx] = 1;
            out_layout.stride[idx] = 1;
        } else {
            input_used[i] = true;
            out_layout.shape[idx] = inp_layout.shape[i];
            out_layout.stride[idx] = inp_layout.stride[i];
        }
        ++idx;
    }
    if (out_layout.is_contiguous()) {
        out_layout.init_contiguous_stride();
    }
    for (size_t i = 0; i < pattern_ndim; ++i) {
        mgb_assert(
                input_used[i] || inp_layout.shape[i] == 1,
                "non-1 dim discarded in Dimshuffle: ishp=%s dim=%zd",
                inp_layout.megdnn::TensorShape::to_string().c_str(), i);
    }
    // memory forward
    return {Tensor::make(src->blob(), src->offset(), out_layout)};
}

OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .fallback();
}  // namespace dimshuffle
}  // namespace

namespace {
namespace add_axis {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& add_axis = static_cast<const AddAxis&>(def);
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param;
    for (auto&& i : add_axis.axis) {
        param.push_back(Desc::make_add(i));
    }
    OperatorNodeConfig config{add_axis.make_name()};
    return opr::AxisAddRemove::make(inputs[0], param, config);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<AddAxis>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "AddAxis expects 1 input; got %lu", nr_inp);
    auto&& src = inputs[0];
    auto olayout = src.layout;
    if (src.layout.ndim == 0) {
        return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
    }
    for (auto&& i : op_def.axis) {
        olayout.add_axis_cont_inplace(i);
    }
    return {{{olayout, src.comp_node}}, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<AddAxis>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "AddAxis expects 1 input; got %lu", nr_inp);
    auto&& src = inputs[0];
    auto tlayout = src->layout();
    for (auto&& i : op_def.axis) {
        tlayout.add_axis_cont_inplace(i);
    }
    // memory forward
    return {Tensor::make(src->blob(), src->offset(), tlayout)};
}
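// Like Dimshuffle, AddAxis (and RemoveAxis below) never copies data: the
// result re-wraps the source blob at the same offset with an adjusted
// layout ("memory forward"), so the output aliases its input.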
OP_TRAIT_REG(AddAxis, AddAxis)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .fallback();
}  // namespace add_axis
}  // namespace

namespace {
namespace remove_axis {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& remove_axis = static_cast<const RemoveAxis&>(def);
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param;
    for (auto&& i : remove_axis.axis) {
        param.push_back(Desc::make_remove(i));
    }
    OperatorNodeConfig config{remove_axis.make_name()};
    return opr::AxisAddRemove::make(inputs[0], param, config);
}

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<RemoveAxis>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "RemoveAxis expects 1 input; got %lu", nr_inp);
    auto&& src = inputs[0];
    auto tlayout = src->layout();
    for (auto&& i : op_def.axis) {
        if (tlayout.ndim == 1) {
            mgb_assert(
                    tlayout.shape[0] == 1 && i == 0,
                    "can not remove axis %u from tensor of shape=%s", i,
                    tlayout.megdnn::TensorShape::to_string().c_str());
        } else {
            mgb_assert(
                    i < tlayout.ndim && tlayout.shape[i] == 1,
                    "can not remove axis %u from tensor of shape=%s", i,
                    tlayout.megdnn::TensorShape::to_string().c_str());
            tlayout.remove_axis_inplace(i);
        }
    }
    // memory forward
    return {Tensor::make(src->blob(), src->offset(), tlayout)};
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<RemoveAxis>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "RemoveAxis expects 1 input; got %lu", nr_inp);
    auto&& src = inputs[0];
    auto olayout = src.layout;
    if (src.layout.ndim == 0) {
        return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
    }
    for (auto&& i : op_def.axis) {
        if (olayout.ndim == 1) {
            mgb_assert(
                    olayout.shape[0] == 1 && i == 0,
                    "can not remove axis %u from tensor of shape=%s", i,
                    olayout.megdnn::TensorShape::to_string().c_str());
        } else {
            mgb_assert(
                    i < olayout.ndim && olayout.shape[i] == 1,
                    "can not remove axis %u from tensor of shape=%s", i,
                    olayout.megdnn::TensorShape::to_string().c_str());
            olayout.remove_axis_inplace(i);
        }
    }
    return {{{olayout, src.comp_node}}, true};
}

OP_TRAIT_REG(RemoveAxis, RemoveAxis)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .fallback();
}  // namespace remove_axis
}  // namespace

namespace {
namespace top_k {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& topk = static_cast<const TopK&>(def);
    OperatorNodeConfig config{topk.make_name()};
    return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0]
            .node()
            ->owner_opr();
}

OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace top_k
}  // namespace

namespace {
namespace adaptive_pooling {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& pool = static_cast<const AdaptivePooling&>(def);
    OperatorNodeConfig config{pool.make_name()};
    return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config);
}

OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace adaptive_pooling
}  // namespace

namespace {
namespace batch_conv_bias {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& conv = static_cast<const BatchConvBias&>(def);
    cg::OperatorNodeConfig config{conv.dtype};
    config.name(conv.make_name());
    if (inputs.size() == 2) {
        return opr::BatchConvBias::make(
                inputs[0], inputs[1], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 3) {
        return opr::BatchConvBias::make(
                inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    } else if (inputs.size() == 4) {
        return opr::BatchConvBias::make(
                inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(),
                config);
    }
    mgb_assert(0);
}
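// The three arities map onto the opr::BatchConvBias::make overloads; by
// analogy with ConvBias the extra inputs are presumably the bias and the
// residual z term, i.e. (src, filter[, bias[, z]]).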
OP_TRAIT_REG(BatchConvBias, BatchConvBias)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace batch_conv_bias
}  // namespace

namespace {
namespace pooling {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& pool = static_cast<const Pooling&>(def);
    OperatorNodeConfig config{pool.make_name()};
    return opr::Pooling::make(inputs[0], pool.param(), pool.policy(), config);
}

OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace pooling
}  // namespace

namespace {
namespace matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& matmul = static_cast<const MatrixMul&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{matmul.make_name()};
    return opr::MatrixMul::make(
            inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
}

OP_TRAIT_REG(MatrixMul, MatrixMul).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace matrix_mul
}  // namespace

namespace {
namespace batched_matrix_mul {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{matmul.make_name()};
    return opr::BatchedMatrixMul::make(
            inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
}

OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace batched_matrix_mul
}  // namespace

namespace {
namespace dot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<Dot>();
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Dot::make(inputs[0], inputs[1], config);
}

// std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
//     auto* node = &node_->cast_final_safe<opr::Dot>();
//     return Dot::make(node->param());
// }

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto a = inputs[0]->layout();
    auto comp_node = inputs[0]->comp_node();
    using TensorND = megdnn::TensorND;
    SmallVector<TensorND> inp_tensornds;
    inp_tensornds.reserve(inputs.size());
    auto dnn_opr = opr::intl::create_megdnn_opr<megdnn::Dot>(comp_node);
    for (unsigned i = 0; i < inputs.size(); ++i) {
        auto dnn_ten = inputs[i]->dnn_tensor();
        inp_tensornds.push_back(dnn_ten);
    }
    TensorLayout oup_layout{inputs[0]->dtype()};
    auto inp1_tensor = inputs[0]->dnn_tensor();
    auto inp2_tensor = inputs[1]->dnn_tensor();
    dnn_opr->deduce_layout(inp1_tensor.layout, inp2_tensor.layout, oup_layout);
    if (inputs[0]->layout().is_empty() || inputs[1]->layout().is_empty()) {
        auto fill_opr = opr::intl::create_megdnn_opr<megdnn::Fill>(comp_node);
        DeviceTensorND out =
                BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout);
        fill_opr->param() = 0;
        fill_opr->exec(out.as_megdnn(), {});
        return {Tensor::make(out)};
    }
    auto wk_size = dnn_opr->get_workspace_in_bytes(
            inp_tensornds[0].layout, inp_tensornds[1].layout, output_descs[0].layout);
    DeviceTensorND out_devtensor =
            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout);
    TensorLayout wk_layout{TensorShape{wk_size}, inputs[0]->dtype()};
    DeviceTensorND workspace =
            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, wk_layout);
    megdnn::Workspace dnn_wk(workspace.raw_ptr(), wk_size);
    dnn_opr->exec(
            inp_tensornds[0], inp_tensornds[1], out_devtensor.as_megdnn(), dnn_wk);
    return {Tensor::make(out_devtensor)};
}
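// Note on the empty-input branch above: the dot product of empty vectors is
// defined to be 0 here, so the deduced (scalar) output is filled with zeros
// via megdnn::Fill rather than running the Dot kernel on empty tensors.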
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<Dot>();
    SmallVector<LogicalTensorDesc> dests(1);
    dests[0].layout = TensorLayout(TensorShape{1}, inputs[0].layout.dtype);
    dests[0].comp_node = inputs[0].comp_node;
    return {dests, true};
}

OP_TRAIT_REG(Dot, Dot, opr::Dot)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}  // namespace dot
}  // namespace

namespace {
namespace argsort {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& argsort = static_cast<const Argsort&>(def);
    OperatorNodeConfig config{argsort.make_name()};
    return opr::Argsort::make(inputs[0], argsort.param(), config);
}

OP_TRAIT_REG(Argsort, Argsort).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace argsort
}  // namespace

namespace {
namespace argmax {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& argmax = static_cast<const Argmax&>(def);
    OperatorNodeConfig config{argmax.make_name()};
    return opr::Argmax::make(inputs[0], argmax.param(), config);
}

OP_TRAIT_REG(Argmax, Argmax).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace argmax
}  // namespace

namespace {
namespace argmin {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& argmin = static_cast<const Argmin&>(def);
    OperatorNodeConfig config{argmin.make_name()};
    return opr::Argmin::make(inputs[0], argmin.param(), config);
}

OP_TRAIT_REG(Argmin, Argmin).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace argmin
}  // namespace

namespace {
namespace warp_perspective {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& warp = static_cast<const WarpPerspective&>(def);
    OperatorNodeConfig config{warp.make_name()};
    if (inputs.size() == 3) {
        return opr::WarpPerspective::make(
                inputs[0], inputs[1], inputs[2], warp.param(), config);
    } else {
        mgb_assert(inputs.size() == 4);
        return opr::WarpPerspective::make(
                inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config);
    }
}

OP_TRAIT_REG(WarpPerspective, WarpPerspective)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace warp_perspective
}  // namespace

namespace {
namespace group_local {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& local = static_cast<const GroupLocal&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{local.make_name()};
    return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config);
}

OP_TRAIT_REG(GroupLocal, GroupLocal).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace group_local
}  // namespace

namespace {
namespace indexing_set_one_hot {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const IndexingSetOneHot&>(def);
    mgb_assert(inputs.size() == 3);
    OperatorNodeConfig config{op.make_name()};
    return opr::IndexingSetOneHot::make(
            inputs[0], inputs[1], inputs[2], op.param(), config);
}

OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace indexing_set_one_hot
}  // namespace

namespace {
namespace typecvt {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const TypeCvt&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::TypeCvt::make(inputs[0], op.dtype, config);
}

OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace typecvt
}  // namespace

namespace {
namespace concat {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Concat&>(def);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    return opr::Concat::make(inputs, op.axis, config);
}

OP_TRAIT_REG(Concat, Concat).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace concat
}  // namespace

namespace {
namespace copy {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Copy&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    return opr::Copy::make(inputs[0], config);
}

OP_TRAIT_REG(Copy, Copy).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace copy
}  // namespace

namespace {
namespace assert_equal {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = def.cast_final<AssertEqual>();
    if (inputs.size() == 2) {
        return opr::AssertEqual::make(inputs[0], inputs[1], op.param());
    } else {
        // workaround for MiniGraph, which only allows one opr in the graph
        mgb_assert(inputs.size() == 3);
        return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {});
    }
}

OP_TRAIT_REG(AssertEqual, AssertEqual).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace assert_equal
}  // namespace

namespace {
namespace roi_align {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const ROIAlign&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config)
                        .node()
                        ->owner_opr();
    return {opr->output(0), opr->output(1)};
}

OP_TRAIT_REG(ROIAlign, ROIAlign).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace roi_align
}  // namespace

namespace {
namespace correlation {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Correlation&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Correlation::make(inputs[0], inputs[1], op.param(), config);
}

OP_TRAIT_REG(Correlation, Correlation).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace correlation
}  // namespace

#if MGB_CUDA
namespace {
namespace nvof {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const NvOf&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::NvOf::make(inputs[0], op.param(), config);
}

OP_TRAIT_REG(NvOf, NvOf).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace nvof
}  // namespace
#endif

namespace {
namespace linspace {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Linspace&>(def);
    mgb_assert(inputs.size() == 3);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}

OP_TRAIT_REG(Linspace, Linspace).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace linspace
}  // namespace

namespace {
namespace eye {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Eye&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.comp_node};
    config.name(op.make_name());
    opr::Eye::Param param{op.k, op.dtype.enumv()};
    return opr::Eye::make(inputs[0], param, config);
}

OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace eye
}  // namespace

namespace {
namespace diag {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Diag&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.make_name()};
    opr::Diag::Param param{op.k};
    return opr::Diag::make(inputs[0], param, config);
}

OP_TRAIT_REG(Diag, Diag).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace diag
}  // namespace

namespace {
namespace roi_pooling {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const ROIPooling&>(def);
    mgb_assert(inputs.size() == 3);
    OperatorNodeConfig config{op.make_name()};
    auto* opr =
            opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param(), config)
                    .node()
                    ->owner_opr();
    return {opr->output(0), opr->output(1)};
}

OP_TRAIT_REG(ROIPooling, ROIPooling).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace roi_pooling
}  // namespace

namespace {
namespace remap {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Remap&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::Remap::make(inputs[0], inputs[1], op.param(), config);
}

OP_TRAIT_REG(Remap, Remap).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace remap
}  // namespace

namespace {
auto get_index(
        const VarNodeArray& inputs, size_t vidx,
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
    size_t length = mask.size();
    opr::Subtensor::IndexDesc ret(length);
    for (size_t i = 0; i < length; ++i) {
        auto&& [axis, begin, end, step, idx] = mask[i];
        ret[i].axis = axis;
        if (idx) {
            ret[i].idx = inputs[vidx++];
        } else {
            mgb_assert(begin || end || step);
            if (begin)
                ret[i].begin = inputs[vidx++];
            if (end)
                ret[i].end = inputs[vidx++];
            if (step)
                ret[i].step = inputs[vidx++];
        }
    }
    mgb_assert(vidx == inputs.size());
    return ret;
}
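// Each mask entry is a (axis, begin, end, step, idx) tuple: the bools record
// which components of the index expression are present, and get_index
// consumes one trailing input VarNode per present component, in that order.
// The final assert verifies that every input has been consumed.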
#define IN1 inputs[0]
#define IN2 inputs[0], inputs[1]

#define FANCY_INDEXING_IMPL(NAME, NR_INPUT)                                    \
    namespace NAME##_impl {                                                    \
    auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {     \
        auto&& op = static_cast<const NAME&>(def);                             \
        OperatorNodeConfig config{op.make_name()};                             \
        return opr::NAME::make(                                                \
                IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config);  \
    }                                                                          \
    OP_TRAIT_REG(NAME, NAME).apply_on_var_node(apply_on_var_node).fallback(); \
    }
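// For reference, FANCY_INDEXING_IMPL(Subtensor, 1) expands to roughly:
//
//     namespace Subtensor_impl {
//     auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
//         auto&& op = static_cast<const Subtensor&>(def);
//         OperatorNodeConfig config{op.make_name()};
//         return opr::Subtensor::make(
//                 inputs[0], get_index(inputs, 1, op.items), config);
//     }
//     OP_TRAIT_REG(Subtensor, Subtensor)
//             .apply_on_var_node(apply_on_var_node)
//             .fallback();
//     }
//
// i.e. inputs[0] (and inputs[1] for the 2-input ops) is the tensor operand,
// and the remaining inputs are the index variables described by op.items.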
FANCY_INDEXING_IMPL(Subtensor, 1)
FANCY_INDEXING_IMPL(SetSubtensor, 2)
FANCY_INDEXING_IMPL(IncrSubtensor, 2)
FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
FANCY_INDEXING_IMPL(IndexingSetMultiAxisVec, 2)
FANCY_INDEXING_IMPL(IndexingIncrMultiAxisVec, 2)
FANCY_INDEXING_IMPL(MeshIndexing, 1)
FANCY_INDEXING_IMPL(IncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(SetMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedMeshIndexing, 1)
FANCY_INDEXING_IMPL(BatchedIncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2)
#undef FANCY_INDEXING_IMPL
#undef IN1
#undef IN2
}  // anonymous namespace

namespace {
namespace fake_quant {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const FakeQuant&>(def);
    mgb_assert(inputs.size() == 3);
    OperatorNodeConfig config{op.make_name()};
    return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}

OP_TRAIT_REG(FakeQuant, FakeQuant).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace fake_quant
}  // namespace

namespace {
namespace tqt {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const TQT&>(def);
    mgb_assert(inputs.size() == 2);
    OperatorNodeConfig config{op.make_name()};
    return opr::TQT::make(inputs[0], inputs[1], op.param(), config);
}

OP_TRAIT_REG(TQT, TQT).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace tqt
}  // namespace

namespace {
namespace elemwise_multi_type {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const ElemwiseMultiType&>(def);
    OperatorNodeConfig config{op.dtype};
    config.name(op.make_name());
    return opr::ElemwiseMultiType::make(inputs, op.param(), config);
}

OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace elemwise_multi_type
}  // namespace

namespace {
namespace svd {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const SVD&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::SVD::make(inputs[0], op.param(), config)[0]
            .node()
            ->owner_opr()
            ->usable_output();
}

OP_TRAIT_REG(SVD, SVD).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace svd
}  // namespace

namespace {
namespace images2neibs {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Images2Neibs&>(def);
    OperatorNodeConfig config{op.make_name()};
    return opr::Images2Neibs::make(inputs[0], op.param(), config);
}

OP_TRAIT_REG(Images2Neibs, Images2Neibs)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace images2neibs
}  // namespace

namespace {
namespace lsq {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const LSQ&>(def);
    mgb_assert(inputs.size() == 4);
    OperatorNodeConfig config{op.make_name()};
    return opr::LSQ::make(
            inputs[0], inputs[1], inputs[2], inputs[3], op.param(), config);
}

OP_TRAIT_REG(LSQ, LSQ).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace lsq
}  // namespace

namespace {
namespace sliding_window_transpose {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const SlidingWindowTranspose&>(def);
    OperatorNodeConfig config{op.make_name()};
    return opr::SlidingWindowTranspose::make(inputs[0], op.param(), config);
}

OP_TRAIT_REG(SlidingWindowTranspose, SlidingWindowTranspose)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}  // namespace sliding_window_transpose
}  // namespace

namespace {
namespace cumsum {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Cumsum&>(def);
    OperatorNodeConfig config{op.make_name()};
    return opr::Cumsum::make(inputs[0], op.param(), config);
}

OP_TRAIT_REG(Cumsum, Cumsum).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace cumsum
}  // namespace

namespace padding {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Padding&>(def);
    mgb_assert(inputs.size() == 1);
    return opr::Padding::make(inputs[0], op.param());
}

OP_TRAIT_REG(Padding, Padding).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace padding

namespace lrn {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const LRN&>(def);
    mgb_assert(inputs.size() == 1);
    return opr::LRN::make(inputs[0], op.param());
}

OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace lrn

namespace layer_norm {
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const LayerNorm&>(def);
    size_t nr_inp = inputs.size();
    auto p = op.param();
    mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
    OperatorNodeConfig config{op.make_name()};
    if (nr_inp == 3) {
        return opr::LayerNorm::make(
                       inputs[0], inputs[1], inputs[2], op.param(), config)[0]
                .node()
                ->owner_opr();
    } else {
        return opr::LayerNorm::make(inputs[0], op.param(), config)[0]
                .node()
                ->owner_opr();
    }
}
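// The assert above ties the arity to the affine flag: with affine
// normalization the inputs are presumably (data, scale, bias), otherwise
// just (data). The owner operator node is returned because LayerNorm
// produces multiple outputs.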
OP_TRAIT_REG(LayerNorm, LayerNorm).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace layer_norm

}  // namespace mgb::imperative