You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

specializations.cpp 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664
  1. // FIXME: split this file into separate files for each specialized op
  2. #include "megbrain/imperative/ops/autogen.h"
  3. #include "megbrain/opr/basic_arith.h"
  4. #include "megbrain/opr/blas.h"
  5. #include "megbrain/opr/dnn/adaptive_pooling.h"
  6. #include "megbrain/opr/dnn/convolution.h"
  7. #include "megbrain/opr/dnn/correlation.h"
  8. #include "megbrain/opr/dnn/fake_quant.h"
  9. #include "megbrain/opr/dnn/images2neibs.h"
  10. #include "megbrain/opr/dnn/local.h"
  11. #include "megbrain/opr/dnn/lrn.h"
  12. #include "megbrain/opr/dnn/lsq.h"
  13. #include "megbrain/opr/dnn/pooling.h"
  14. #include "megbrain/opr/dnn/roi_align.h"
  15. #include "megbrain/opr/dnn/roi_pooling.h"
  16. #include "megbrain/opr/dnn/sliding_window_transpose.h"
  17. #include "megbrain/opr/dnn/tqt.h"
  18. #include "megbrain/opr/imgproc.h"
  19. #include "megbrain/opr/indexing.h"
  20. #include "megbrain/opr/io.h"
  21. #include "megbrain/opr/misc.h"
  22. #include "megbrain/opr/nn_int.h"
  23. #include "megbrain/opr/rand.h"
  24. #include "megbrain/opr/tensor_gen.h"
  25. #include "megbrain/opr/tensor_manip.h"
  26. #include "megbrain/opr/utility.h"
  27. #include "../blob_manager_impl.h"
  28. #include "../op_trait.h"
  29. namespace mgb::imperative {
namespace {
namespace dimshuffle {

// Reconstruct a Dimshuffle OpDef from an existing computing-graph node
// (reverse direction of apply_on_var_node below).
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
    // param().pattern is a fixed-capacity array; only the first
    // pattern_len entries are meaningful.
    std::vector<int> pattern(node->param().pattern_len);
    for (size_t i = 0; i < node->param().pattern_len; ++i) {
        pattern[i] = node->param().pattern[i];
    }
    return Dimshuffle::make(pattern);
}

// Lower the OpDef onto graph variables.
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& ds = static_cast<const Dimshuffle&>(def);
    OperatorNodeConfig config{ds.make_name()};
    // NOTE(review): third argument appears to be ndim, with 0UL meaning
    // "derive from the pattern" — confirm against opr::Dimshuffle::make.
    return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
}

// Fallible shape inference: returns {output descs, validated}; the flag is
// false when the input shape is still unknown and inference must be retried.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& ds = static_cast<const Dimshuffle&>(def);
    mgb_assert(
            ds.pattern.size() <= TensorShape::MAX_NDIM,
            "Dimshuffle pattern exceeds max length of %zd", TensorShape::MAX_NDIM);
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "Dimshuffle expects 1 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    TensorShape out_shape;
    // Input shape unknown yet: emit a placeholder desc, not validated.
    if (src.layout.ndim == 0) {
        return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
    }
    // Pattern entries are input-axis indices (or negative for a new unit
    // axis), so the required input ndim is max(pattern) + 1.
    size_t pattern_ndim = *std::max_element(ds.pattern.begin(), ds.pattern.end()) + 1;
    mgb_assert(
            src.layout.ndim == pattern_ndim,
            "input ndim mismatch for Dimshuffle: expect=%zd actual=%zd", pattern_ndim,
            src.layout.ndim);
    size_t idx = 0;
    bool input_used[TensorLayout::MAX_NDIM] = {0};
    out_shape.ndim = ds.pattern.size();
    for (auto i : ds.pattern) {
        if (i < 0) {
            // Negative entry inserts a new axis of extent 1.
            out_shape[idx] = 1;
        } else {
            input_used[i] = true;
            out_shape[idx] = src.layout.shape[i];
        }
        ++idx;
    }
    // An input axis absent from the pattern may only be dropped if its
    // extent is 1.
    for (size_t i = 0; i < pattern_ndim; ++i) {
        mgb_assert(
                input_used[i] || src.layout.shape[i] == 1,
                "non-1 dim discarded in Dimshuffle: ishp=%s dim=%zd",
                src.layout.megdnn::TensorShape::to_string().c_str(), i);
    }
    return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}

// Eager execution: a dimshuffle is realized as a zero-copy view by
// permuting shapes and strides over the same underlying blob.
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& ds = static_cast<const Dimshuffle&>(def);
    mgb_assert(
            ds.pattern.size() <= TensorShape::MAX_NDIM,
            "Dimshuffle pattern exceeds max length of %zd", TensorShape::MAX_NDIM);
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "Dimshuffle expects 1 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto inp_layout = src->layout();
    size_t pattern_ndim = *std::max_element(ds.pattern.begin(), ds.pattern.end()) + 1;
    mgb_assert(
            inp_layout.ndim == pattern_ndim,
            "input ndim mismatch for Dimshuffle: expect=%zd actual=%zd", pattern_ndim,
            inp_layout.ndim);
    TensorLayout out_layout{inp_layout.dtype};
    out_layout.ndim = ds.pattern.size();
    size_t idx = 0;
    bool input_used[TensorLayout::MAX_NDIM] = {0};
    for (auto i : ds.pattern) {
        if (i < 0) {
            // New unit axis: extent 1, stride 1.
            out_layout.shape[idx] = 1;
            out_layout.stride[idx] = 1;
        } else {
            input_used[i] = true;
            out_layout.shape[idx] = inp_layout.shape[i];
            out_layout.stride[idx] = inp_layout.stride[i];
        }
        ++idx;
    }
    // NOTE(review): when the permuted layout is already contiguous this
    // re-derives canonical contiguous strides (normalizing strides of
    // extent-1 axes) — confirm this normalization is intended.
    if (out_layout.is_contiguous()) {
        out_layout.init_contiguous_stride();
    }
    for (size_t i = 0; i < pattern_ndim; ++i) {
        mgb_assert(
                input_used[i] || inp_layout.shape[i] == 1,
                "non-1 dim discarded in Dimshuffle: ishp=%s dim=%zd",
                inp_layout.megdnn::TensorShape::to_string().c_str(), i);
    }
    // memory forward
    return {Tensor::make(src->blob(), src->offset(), out_layout)};
}

OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .fallback();
} // namespace dimshuffle
} // namespace
  134. namespace {
  135. namespace add_axis {
  136. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  137. auto&& add_axis = static_cast<const AddAxis&>(def);
  138. using Desc = opr::AxisAddRemove::AxisDesc;
  139. std::vector<Desc> param;
  140. for (auto&& i : add_axis.axis) {
  141. param.push_back(Desc::make_add(i));
  142. }
  143. OperatorNodeConfig config{add_axis.make_name()};
  144. return opr::AxisAddRemove::make(inputs[0], param, config);
  145. }
  146. std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
  147. const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
  148. auto&& op_def = def.cast_final_safe<AddAxis>();
  149. size_t nr_inp = inputs.size();
  150. mgb_assert(nr_inp == 1, "AddAxis expects 1 inputs; got %lu actually", nr_inp);
  151. auto&& src = inputs[0];
  152. auto olayout = src.layout;
  153. if (src.layout.ndim == 0) {
  154. return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
  155. }
  156. for (auto&& i : op_def.axis) {
  157. olayout.add_axis_cont_inplace(i);
  158. }
  159. return {{{olayout, src.comp_node}}, true};
  160. }
  161. SmallVector<TensorPtr> apply_on_physical_tensor(
  162. const OpDef& def, const SmallVector<TensorPtr>& inputs,
  163. SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
  164. auto&& op_def = def.cast_final_safe<AddAxis>();
  165. size_t nr_inp = inputs.size();
  166. mgb_assert(nr_inp == 1, "AddAxis expects 1 inputs; got %lu actually", nr_inp);
  167. auto&& src = inputs[0];
  168. auto tlayout = src->layout();
  169. for (auto&& i : op_def.axis) {
  170. tlayout.add_axis_cont_inplace(i);
  171. }
  172. // memory forward
  173. return {Tensor::make(src->blob(), src->offset(), tlayout)};
  174. }
  175. OP_TRAIT_REG(AddAxis, AddAxis)
  176. .apply_on_var_node(apply_on_var_node)
  177. .apply_on_physical_tensor(apply_on_physical_tensor)
  178. .infer_output_attrs_fallible(infer_output_attrs_fallible)
  179. .fallback();
  180. } // namespace add_axis
  181. } // namespace
namespace {
namespace remove_axis {

// Lower RemoveAxis to an AxisAddRemove graph operator.
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& remove_axis = static_cast<const RemoveAxis&>(def);
    using Desc = opr::AxisAddRemove::AxisDesc;
    std::vector<Desc> param;
    for (auto&& i : remove_axis.axis) {
        param.push_back(Desc::make_remove(i));
    }
    OperatorNodeConfig config{remove_axis.make_name()};
    return opr::AxisAddRemove::make(inputs[0], param, config);
}

// Eager execution: removing extent-1 axes never moves data, so the result
// is a zero-copy view over the input blob. Axes are removed in the order
// given by op_def.axis, each index interpreted against the current layout.
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op_def = def.cast_final_safe<RemoveAxis>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "RemoveAxis expects 1 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto tlayout = src->layout();
    for (auto&& i : op_def.axis) {
        if (tlayout.ndim == 1) {
            // ndim cannot drop below 1: for a {1} tensor the axis removal
            // is validated but the layout is left as-is (presumably the
            // scalar representation — confirm against TensorLayout docs).
            mgb_assert(
                    tlayout.shape[0] == 1 && i == 0,
                    "can not remove axis %u from tensor of shape=%s", i,
                    tlayout.megdnn::TensorShape::to_string().c_str());
        } else {
            mgb_assert(
                    i < tlayout.ndim && tlayout.shape[i] == 1,
                    "can not remove axis %u from tensor of shape=%s", i,
                    tlayout.megdnn::TensorShape::to_string().c_str());
            tlayout.remove_axis_inplace(i);
        }
    }
    // memory forward
    return {Tensor::make(src->blob(), src->offset(), tlayout)};
}

// Fallible shape inference counterpart of the above; validated=false while
// the input shape is still unknown.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<RemoveAxis>();
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "RemoveAxis expects 1 inputs; got %lu actually", nr_inp);
    auto&& src = inputs[0];
    auto olayout = src.layout;
    if (src.layout.ndim == 0) {
        return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
    }
    for (auto&& i : op_def.axis) {
        if (olayout.ndim == 1) {
            // Same ndim==1 special case as apply_on_physical_tensor above.
            mgb_assert(
                    olayout.shape[0] == 1 && i == 0,
                    "can not remove axis %u from tensor of shape=%s", i,
                    olayout.megdnn::TensorShape::to_string().c_str());
        } else {
            mgb_assert(
                    i < olayout.ndim && olayout.shape[i] == 1,
                    "can not remove axis %u from tensor of shape=%s", i,
                    olayout.megdnn::TensorShape::to_string().c_str());
            olayout.remove_axis_inplace(i);
        }
    }
    return {{{olayout, src.comp_node}}, true};
}

OP_TRAIT_REG(RemoveAxis, RemoveAxis)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .fallback();
} // namespace remove_axis
} // namespace
  252. namespace {
  253. namespace top_k {
  254. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  255. auto&& topk = static_cast<const TopK&>(def);
  256. OperatorNodeConfig config{topk.make_name()};
  257. return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0]
  258. .node()
  259. ->owner_opr();
  260. }
  261. OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback();
  262. } // namespace top_k
  263. } // namespace
  264. namespace {
  265. namespace batch_conv_bias {
  266. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  267. auto&& conv = static_cast<const BatchConvBias&>(def);
  268. cg::OperatorNodeConfig config{conv.dtype};
  269. config.name(conv.make_name());
  270. if (inputs.size() == 2) {
  271. return opr::BatchConvBias::make(
  272. inputs[0], inputs[1], conv.param(), conv.policy(), config);
  273. } else if (inputs.size() == 3) {
  274. return opr::BatchConvBias::make(
  275. inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
  276. } else if (inputs.size() == 4) {
  277. return opr::BatchConvBias::make(
  278. inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(),
  279. config);
  280. }
  281. mgb_assert(0);
  282. }
  283. OP_TRAIT_REG(BatchConvBias, BatchConvBias)
  284. .apply_on_var_node(apply_on_var_node)
  285. .fallback();
  286. } // namespace batch_conv_bias
  287. } // namespace
  288. namespace {
  289. namespace argsort {
  290. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  291. auto&& argsort = static_cast<const Argsort&>(def);
  292. OperatorNodeConfig config{argsort.make_name()};
  293. return opr::Argsort::make(inputs[0], argsort.param(), config);
  294. }
  295. OP_TRAIT_REG(Argsort, Argsort).apply_on_var_node(apply_on_var_node).fallback();
  296. } // namespace argsort
  297. } // namespace
  298. namespace {
  299. namespace argmax {
  300. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  301. auto&& argmax = static_cast<const Argmax&>(def);
  302. OperatorNodeConfig config{argmax.make_name()};
  303. return opr::Argmax::make(inputs[0], argmax.param(), config);
  304. }
  305. OP_TRAIT_REG(Argmax, Argmax).apply_on_var_node(apply_on_var_node).fallback();
  306. } // namespace argmax
  307. } // namespace
  308. namespace {
  309. namespace argmin {
  310. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  311. auto&& argmin = static_cast<const Argmin&>(def);
  312. OperatorNodeConfig config{argmin.make_name()};
  313. return opr::Argmin::make(inputs[0], argmin.param(), config);
  314. }
  315. OP_TRAIT_REG(Argmin, Argmin).apply_on_var_node(apply_on_var_node).fallback();
  316. } // namespace argmin
  317. } // namespace
  318. namespace {
  319. namespace warp_perspective {
  320. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  321. auto&& warp = static_cast<const WarpPerspective&>(def);
  322. OperatorNodeConfig config{warp.make_name()};
  323. if (inputs.size() == 3) {
  324. return opr::WarpPerspective::make(
  325. inputs[0], inputs[1], inputs[2], warp.param(), config);
  326. } else {
  327. mgb_assert(inputs.size() == 4);
  328. return opr::WarpPerspective::make(
  329. inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config);
  330. }
  331. }
  332. OP_TRAIT_REG(WarpPerspective, WarpPerspective)
  333. .apply_on_var_node(apply_on_var_node)
  334. .fallback();
  335. } // namespace warp_perspective
  336. } // namespace
  337. namespace {
  338. namespace group_local {
  339. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  340. auto&& local = static_cast<const GroupLocal&>(def);
  341. mgb_assert(inputs.size() == 2);
  342. OperatorNodeConfig config{local.make_name()};
  343. return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config);
  344. }
  345. OP_TRAIT_REG(GroupLocal, GroupLocal).apply_on_var_node(apply_on_var_node).fallback();
  346. } // namespace group_local
  347. } // namespace
  348. namespace {
  349. namespace typecvt {
  350. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  351. auto&& op = static_cast<const TypeCvt&>(def);
  352. mgb_assert(inputs.size() == 1);
  353. OperatorNodeConfig config{op.make_name()};
  354. return opr::TypeCvt::make(inputs[0], op.dtype, config);
  355. }
  356. OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback();
  357. } // namespace typecvt
  358. } // namespace
  359. namespace {
  360. namespace concat {
  361. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  362. auto&& op = static_cast<const Concat&>(def);
  363. cg::OperatorNodeConfig config{op.comp_node};
  364. config.name(op.make_name());
  365. return opr::Concat::make(inputs, op.axis, config);
  366. }
  367. OP_TRAIT_REG(Concat, Concat).apply_on_var_node(apply_on_var_node).fallback();
  368. } // namespace concat
  369. } // namespace
  370. namespace {
  371. namespace copy {
  372. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  373. auto&& op = static_cast<const Copy&>(def);
  374. mgb_assert(inputs.size() == 1);
  375. cg::OperatorNodeConfig config{op.comp_node};
  376. config.name(op.make_name());
  377. return opr::Copy::make(inputs[0], config);
  378. }
  379. OP_TRAIT_REG(Copy, Copy).apply_on_var_node(apply_on_var_node).fallback();
  380. } // namespace copy
  381. } // namespace
  382. namespace {
  383. namespace assert_equal {
  384. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  385. auto&& op = def.cast_final<AssertEqual>();
  386. if (inputs.size() == 2) {
  387. return opr::AssertEqual::make(inputs[0], inputs[1], op.param());
  388. } else {
  389. // workaround for MiniGraph, which only allow one opr in the graph
  390. mgb_assert(inputs.size() == 3);
  391. return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {});
  392. }
  393. }
  394. OP_TRAIT_REG(AssertEqual, AssertEqual).apply_on_var_node(apply_on_var_node).fallback();
  395. } // namespace assert_equal
  396. } // namespace
  397. namespace {
  398. namespace correlation {
  399. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  400. auto&& op = static_cast<const Correlation&>(def);
  401. mgb_assert(inputs.size() == 2);
  402. OperatorNodeConfig config{op.make_name()};
  403. return opr::Correlation::make(inputs[0], inputs[1], op.param(), config);
  404. }
  405. OP_TRAIT_REG(Correlation, Correlation).apply_on_var_node(apply_on_var_node).fallback();
  406. } // namespace correlation
  407. } // namespace
  408. #if MGB_CUDA
  409. namespace {
  410. namespace nvof {
  411. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  412. auto&& op = static_cast<const NvOf&>(def);
  413. mgb_assert(inputs.size() == 1);
  414. OperatorNodeConfig config{op.make_name()};
  415. return opr::NvOf::make(inputs[0], op.param(), config);
  416. }
  417. OP_TRAIT_REG(NvOf, NvOf).apply_on_var_node(apply_on_var_node).fallback();
  418. } // namespace nvof
  419. } // namespace
  420. #endif
  421. namespace {
  422. namespace linspace {
  423. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  424. auto&& op = static_cast<const Linspace&>(def);
  425. mgb_assert(inputs.size() == 3);
  426. cg::OperatorNodeConfig config{op.comp_node};
  427. config.name(op.make_name());
  428. return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config);
  429. }
  430. OP_TRAIT_REG(Linspace, Linspace).apply_on_var_node(apply_on_var_node).fallback();
  431. } // namespace linspace
  432. } // namespace
  433. namespace {
  434. namespace eye {
  435. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  436. auto&& op = static_cast<const Eye&>(def);
  437. mgb_assert(inputs.size() == 1);
  438. cg::OperatorNodeConfig config{op.comp_node};
  439. config.name(op.make_name());
  440. opr::Eye::Param param{op.k, op.dtype.enumv()};
  441. return opr::Eye::make(inputs[0], param, config);
  442. }
  443. OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback();
  444. } // namespace eye
  445. } // namespace
  446. namespace {
  447. namespace diag {
  448. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  449. auto&& op = static_cast<const Diag&>(def);
  450. mgb_assert(inputs.size() == 1);
  451. cg::OperatorNodeConfig config{op.make_name()};
  452. opr::Diag::Param param{op.k};
  453. return opr::Diag::make(inputs[0], param, config);
  454. }
  455. OP_TRAIT_REG(Diag, Diag).apply_on_var_node(apply_on_var_node).fallback();
  456. } // namespace diag
  457. } // namespace
  458. namespace {
  459. namespace remap {
  460. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  461. auto&& op = static_cast<const Remap&>(def);
  462. mgb_assert(inputs.size() == 2);
  463. OperatorNodeConfig config{op.make_name()};
  464. return opr::Remap::make(inputs[0], inputs[1], op.param(), config);
  465. }
  466. OP_TRAIT_REG(Remap, Remap).apply_on_var_node(apply_on_var_node).fallback();
  467. } // namespace remap
  468. } // namespace
namespace {
// Translate the (axis, begin, end, step, idx) mask of a fancy-indexing
// OpDef into an opr::Subtensor::IndexDesc, consuming graph vars from
// `inputs` starting at position `vidx` (the vars before vidx are the
// tensor operands themselves).
auto get_index(
        const VarNodeArray& inputs, size_t vidx,
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
    size_t length = mask.size();
    opr::Subtensor::IndexDesc ret(length);
    for (size_t i = 0; i < length; ++i) {
        auto&& [axis, begin, end, step, idx] = mask[i];
        ret[i].axis = axis;
        if (idx) {
            // Point indexing: one var holds the index.
            ret[i].idx = inputs[vidx++];
        } else {
            // Slice indexing: each enabled flag consumes one var,
            // in begin / end / step order.
            mgb_assert(begin || end || step);
            if (begin)
                ret[i].begin = inputs[vidx++];
            if (end)
                ret[i].end = inputs[vidx++];
            if (step)
                ret[i].step = inputs[vidx++];
        }
    }
    // All trailing index vars must have been consumed exactly.
    mgb_assert(vidx == inputs.size());
    return ret;
}
// IN1/IN2 expand to the leading tensor operand(s) preceding the index vars.
#define IN1 inputs[0]
#define IN2 inputs[0], inputs[1]
// Stamp out apply_on_var_node + trait registration for each fancy-indexing
// op; NR_INPUT is the number of leading tensor operands (1 for read-style
// ops, 2 for set/incr-style ops that also take a value operand).
#define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \
namespace NAME##_impl { \
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { \
auto&& op = static_cast<const NAME&>(def); \
OperatorNodeConfig config{op.make_name()}; \
return opr::NAME::make( \
IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \
} \
OP_TRAIT_REG(NAME, NAME).apply_on_var_node(apply_on_var_node).fallback(); \
}
FANCY_INDEXING_IMPL(Subtensor, 1)
FANCY_INDEXING_IMPL(SetSubtensor, 2)
FANCY_INDEXING_IMPL(IncrSubtensor, 2)
FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
FANCY_INDEXING_IMPL(IndexingSetMultiAxisVec, 2)
FANCY_INDEXING_IMPL(IndexingIncrMultiAxisVec, 2)
FANCY_INDEXING_IMPL(MeshIndexing, 1)
FANCY_INDEXING_IMPL(IncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(SetMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedMeshIndexing, 1)
FANCY_INDEXING_IMPL(BatchedIncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2)
#undef FANCY_INDEXING_IMPL
#undef IN1
#undef IN2
} // anonymous namespace
  521. namespace {
  522. namespace fake_quant {
  523. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  524. auto&& op = static_cast<const FakeQuant&>(def);
  525. mgb_assert(inputs.size() == 3);
  526. OperatorNodeConfig config{op.make_name()};
  527. return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config);
  528. }
  529. OP_TRAIT_REG(FakeQuant, FakeQuant).apply_on_var_node(apply_on_var_node).fallback();
  530. } // namespace fake_quant
  531. } // namespace
  532. namespace {
  533. namespace tqt {
  534. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  535. auto&& op = static_cast<const TQT&>(def);
  536. mgb_assert(inputs.size() == 2);
  537. OperatorNodeConfig config{op.make_name()};
  538. return opr::TQT::make(inputs[0], inputs[1], op.param(), config);
  539. }
  540. OP_TRAIT_REG(TQT, TQT).apply_on_var_node(apply_on_var_node).fallback();
  541. } // namespace tqt
  542. } // namespace
  543. namespace {
  544. namespace elemwise_multi_type {
  545. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  546. auto&& op = static_cast<const ElemwiseMultiType&>(def);
  547. OperatorNodeConfig config{op.dtype};
  548. config.name(op.make_name());
  549. return opr::ElemwiseMultiType::make(inputs, op.param(), config);
  550. }
  551. OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
  552. .apply_on_var_node(apply_on_var_node)
  553. .fallback();
  554. } // namespace elemwise_multi_type
  555. } // namespace
  556. namespace {
  557. namespace svd {
  558. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  559. auto&& op = static_cast<const SVD&>(def);
  560. mgb_assert(inputs.size() == 1);
  561. OperatorNodeConfig config{op.make_name()};
  562. return opr::SVD::make(inputs[0], op.param(), config)[0]
  563. .node()
  564. ->owner_opr()
  565. ->usable_output();
  566. }
  567. OP_TRAIT_REG(SVD, SVD).apply_on_var_node(apply_on_var_node).fallback();
  568. } // namespace svd
  569. } // namespace
  570. namespace {
  571. namespace images2neibs {
  572. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  573. auto&& op = static_cast<const Images2Neibs&>(def);
  574. OperatorNodeConfig config{op.make_name()};
  575. return opr::Images2Neibs::make(inputs[0], op.param(), config);
  576. }
  577. OP_TRAIT_REG(Images2Neibs, Images2Neibs)
  578. .apply_on_var_node(apply_on_var_node)
  579. .fallback();
  580. } // namespace images2neibs
  581. } // namespace
  582. namespace {
  583. namespace lsq {
  584. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  585. auto&& op = static_cast<const LSQ&>(def);
  586. mgb_assert(inputs.size() == 4);
  587. OperatorNodeConfig config{op.make_name()};
  588. return opr::LSQ::make(
  589. inputs[0], inputs[1], inputs[2], inputs[3], op.param(), config);
  590. }
  591. OP_TRAIT_REG(LSQ, LSQ).apply_on_var_node(apply_on_var_node).fallback();
  592. } // namespace lsq
  593. } // namespace
  594. namespace {
  595. namespace sliding_window_transpose {
  596. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  597. auto&& op = static_cast<const SlidingWindowTranspose&>(def);
  598. OperatorNodeConfig config{op.make_name()};
  599. return opr::SlidingWindowTranspose::make(inputs[0], op.param(), config);
  600. }
  601. OP_TRAIT_REG(SlidingWindowTranspose, SlidingWindowTranspose)
  602. .apply_on_var_node(apply_on_var_node)
  603. .fallback();
  604. } // namespace sliding_window_transpose
  605. } // namespace
  606. namespace lrn {
  607. auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
  608. auto&& op = static_cast<const LRN&>(def);
  609. mgb_assert(inputs.size() == 1);
  610. return opr::LRN::make(inputs[0], op.param());
  611. }
  612. OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback();
  613. } // namespace lrn
  614. } // namespace mgb::imperative