
general.h 46 kB

/**
 * \file dnn/include/megdnn/oprs/general.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "megdnn/internal/opr_header_prologue.h"
#include "megdnn/thin/small_vector.h"

namespace megdnn {

/*!
 * \brief standard element-wise operator
 *
 * Inputs must have the same dtype, and their shapes must be broadcastable
 * into a final shape. They can have arbitrary layouts, but non-contiguous and
 * non-broadcast layouts may harm performance seriously.
 *
 * Output dtype is the same as input dtype (note that even for compare oprs
 * this is true, e.g. float == float returns a value of float). Output layout
 * must be contiguous.
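 *
 * Illustrative example (hypothetical, not part of the original doc): input
 * shapes (3, 1) and (1, 4) broadcast into the final shape (3, 4), so an
 * ADD-mode opr with these two inputs produces a contiguous (3, 4) output.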
 */
class ElemwiseForward : public OperatorBase {
    DEF_OPR_PARAM(Elemwise);
    DEF_OPR_IMPL(ElemwiseForward, OperatorBase, -1, 1);

public:
    using Mode = Param::Mode;

    //! information about a mode
    struct ModeTrait {
        uint32_t arity;      //!< number of inputs needed
        bool commutable;     //!< whether arity == 2 and inputs commutable
        bool allow_int;      //!< whether int inputs allowed
        bool allow_float;    //!< whether float inputs allowed
        const char* name;    //!< name of the mode

        ModeTrait()
                : arity(0), commutable(0), allow_int(0), allow_float(0),
                  name(NULL) {}

        //! get trait from a mode; this function is thread safe
        static const ModeTrait& from_mode(Mode mode);
    };

    //! get trait of current mode
    const ModeTrait& mode_trait() const {
        return ModeTrait::from_mode(m_param.mode);
    }

    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     *
     * src and dst should have the same shape;
     * layouts should be contiguous;
     * the underlying data pointer can point to the same memory region for
     * src and dst.
     */
    virtual void exec(_megdnn_in const TensorNDArray& src,
                      _megdnn_tensor_out dst) = 0;

    //! deduce output shape (do not check whether arity matches)
    static void deduce_shape(const TensorShapeArray& src, TensorShape& dst);

    static void deduce_format(const TensorFormatArray& src,
                              TensorFormat& dst);

    //! deduce output layout
    void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst);

protected:
    //! throw exception if incorrect layout; broadcast input shape to
    //! output shape
    void check_layout_and_broadcast(const TensorLayoutPtrArray& src,
                                    const TensorLayout& dst);

private:
    void check_dtype(DType dtype);
};
using Elemwise = ElemwiseForward;

/*!
 * \brief compute ``x**a`` where ``a`` is a constant from the Param
 *
 * This opr is usually not directly accessible by the end user and it is
 * created by the mgb optimizer, aiming to work around numerical stability
 * issues with pow. For example ``powf(x, 2.f)`` with ``x < 0`` in fast math
 * mode may return NaN.
 *
 * Like elemwise, this opr supports arbitrary strides. But it should only be
 * used with monotone strides. Input and output should have the same
 * float-category dtype.
 */
class PowC : public OperatorBase {
    DEF_OPR_PARAM(PowC);
    DEF_OPR_IMPL(PowC, OperatorBase, 1, 1);

public:
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst);

    //! compatible API for mgb; workspace is not used
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
              _megdnn_workspace) {
        return exec(src, dst);
    }

    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) {
        // the impls should require no workspace; this can be later changed
        // to a virtual function if this situation changes
        return 0;
    }

    void deduce_layout(const TensorLayout& src, TensorLayout& dst) {
        dst.dtype = src.dtype;
        dst.init_contiguous_stride(src);
    }

protected:
    /*!
     * Perform the computing where layouts have been verified.
     *
     * \p src can have arbitrary layout, and \p dst is contiguous. They have
     * the same shape and dtype.
     *
     * The implementation should not access param(). It should check \p exp_f
     * and \p exp_i for the exponent value. Exactly one of them would be
     * non-null.
     *
     * Note: \p exp_f and \p exp_i must be dereferenced before dispatching any
     * kernel. They are allocated on the caller's stack.
     */
    virtual void do_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                         const float* exp_f, const int* exp_i) = 0;
};

/*!
 * \brief modify a tensor inplace by adding another tensor to it
 *
 * dst and delta can have arbitrary layouts but must have the same shape.
 */
class AddUpdateForward : public OperatorBase {
    DEF_OPR_PARAM(AddUpdate);
    DEF_OPR_IMPL(AddUpdateForward, OperatorBase, -1, 1);

public:
    virtual void exec(_megdnn_tensor_inout dst, _megdnn_tensor_in delta) = 0;

protected:
    void check_exec(const TensorLayout& dst, const TensorLayout& delta);
};
using AddUpdate = AddUpdateForward;

class ReduceForward : public OperatorBase {
    DEF_OPR_PARAM(Reduce);
    DEF_OPR_IMPL(ReduceForward, OperatorBase, 1, 1);

public:
    using Mode = Param::Mode;
    using DataType = Param::DataType;

    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     *
     * src and dst should be contiguous.
     * src and dst should be of the same shape for all dimensions except
     * param().axis.
     * the param().axis-th dimension shape for dst should be one.
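     *
     * Illustrative example (not part of the original doc): with src shape
     * (2, 3, 4) and param().axis == 1, dst must have shape (2, 1, 4).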
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using Reduce = ReduceForward;

class CumsumForward : public OperatorBase {
    DEF_OPR_PARAM(Cumsum);
    DEF_OPR_IMPL(CumsumForward, OperatorBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor
     *
     * src and dst should be contiguous.
     * src and dst should have the same shape.
     *
     * The exclusive flag specifies whether the current element is taken
     * into account when calculating results.
     *
     * The reverse flag specifies whether cumsum is forward (
     * from 0 to n) or backward (from n down to 0).
     *
     * Example:
     * exclusive && reverse:
     *      dst_i = src_{i+1} + src_{i+2} + ... + src_{n-1}
     * exclusive && !reverse:
     *      dst_i = src_0 + src_1 + ... + src_{i-1}
     * !exclusive && reverse:
     *      dst_i = src_i + src_{i+1} + ... + src_{n-1}
     * !exclusive && !reverse:
     *      dst_i = src_0 + src_1 + ... + src_i
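     *
     * A concrete illustration (added for clarity, not in the original doc):
     * for src = {1, 2, 3},
     *      exclusive && !reverse yields dst = {0, 1, 3};
     *      !exclusive && !reverse yields dst = {1, 3, 6}.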
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using Cumsum = CumsumForward;

// mxx can be max or min
class ArgmxxBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ArgmxxBase, OperatorBase);
    DEF_OPR_PARAM(Axis);

protected:
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
};

class ArgmaxForward : public ArgmxxBase {
    DEF_OPR_IMPL(ArgmaxForward, ArgmxxBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor containing the argmax indices
     *
     * src and dst should be contiguous.
     * src and dst should be of the same shape for all dimensions except
     * param().axis.
     * the param().axis-th dimension shape for dst should be one.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using Argmax = ArgmaxForward;

class ArgminForward : public ArgmxxBase {
    DEF_OPR_IMPL(ArgminForward, ArgmxxBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dst output tensor containing the argmin indices
     *
     * src and dst should be contiguous.
     * src and dst should be of the same shape for all dimensions except
     * param().axis.
     * the param().axis-th dimension shape for dst should be one.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using Argmin = ArgminForward;

/*!
 * \brief take values from input according to given condition
 *
 * Outputs two tensors:
 * 1. values copied from *data*, with the same dtype as *data*
 * 2. selected indices with dtype int32; note that it is 1-dimensional and
 *    based on the flattened input.
 *
 * Require data and mask to have the same shape and both be contiguous.
 */
class CondTake : public OperatorBase {
    DEF_OPR_IMPL(CondTake, OperatorBase, 2, 2);
    DEF_OPR_PARAM(CondTake);

public:
    using Output = std::array<TensorND, 2>;
    using OutputDType = std::array<DType, 2>;
    OutputDType infer_dtype(DType data, DType mask);
    virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0;
    virtual Output exec(_megdnn_tensor_in data, _megdnn_tensor_in mask,
                        _megdnn_workspace workspace,
                        DynOutMallocPolicyCall malloc_policy) = 0;

protected:
    //! check input layouts and get flattened size
    size_t check_exec_get_size(const TensorLayout& data,
                               const TensorLayout& mask,
                               size_t workspace_in_bytes);
};

class TransposeForward : public OperatorBase {
    DEF_OPR_IMPL(TransposeForward, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Empty);

public:
    /**
     * \param[in] src (m, n) stride[0] >= n && stride[1] == 1
     * \param[out] dst (n, m) stride[0] >= m && stride[1] == 1
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using Transpose = TransposeForward;

/**
 * Change a tensor to another layout that has the same dtype, the same total
 * number of elements, and a non-overlapping stride.
 *
 * ON CPU:
 * This operator is optimized for some cases (e.g. both dst and the last dim
 * of src are contiguous).
 *
 * ON CUDA:
 * The more contiguous the input/output layouts, the higher the performance.
 * There is also a special optimization for the broadcast case.
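 *
 * Illustrative example (hypothetical, not from the original doc): copying
 * src with shape (2, 3) and stride (1, 2) (a transposed view) into a
 * contiguous dst with shape (2, 3) and stride (3, 1) materializes the
 * viewed data contiguously.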
 */
class RelayoutForward : public OperatorBase {
    DEF_OPR_IMPL(RelayoutForward, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Empty);

public:
    /*!
     * \brief execute relayout opr
     *
     * This operator should be placed on the same computing device as *dst*.
     *
     * \param src_handle handle of input tensor; for CUDA d2d copy, the
     *      src handle can be on a different GPU when copying tensors with
     *      non-contiguous dims <= 2
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      Handle* src_handle = nullptr) = 0;

protected:
    //! check layout and collapse contiguous
    void check_layout_and_canonize(TensorLayout& src, TensorLayout& dst);
};
using Relayout = RelayoutForward;

/**
 * \brief Base class for Concat and Split operators
 */
class ConcatSplitBase : public OperatorBase {
public:
    using Param = param::Axis;

    ConcatSplitBase(Handle* handle);
    const Param& param() const { return m_param; }
    Param& param() { return m_param; }

protected:
    void check_layout_common(const TensorLayoutArray& srcs,
                             const TensorLayout& dst);
    Param m_param;

    /**
     * \brief a helper function
     *
     * A = shape[0] * shape[1] * ... * shape[axis-1]
     * B = {srcs[0].shape[axis], srcs[1].shape[axis], ...}
     * C = shape[axis+1] * shape[axis+2] * ... * shape[ndim-1]
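     *
     * Worked example (added for illustration): for srcs with shapes
     * (2, 3, 4) and (2, 5, 4) joined on axis == 1, A = 2, B = {3, 5}
     * and C = 4.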
     */
    void get_ABC(const TensorShapeArray& srcs, size_t& A, size_t* B,
                 size_t& C);
    thin_function<TensorLayout(const TensorND& tensor)> m_get_layout;
    thin_function<TensorShape(const TensorLayout& layout)> m_get_shape;
};

class ConcatForward : public ConcatSplitBase {
    DEF_OPR_IMPL(ConcatForward, ConcatSplitBase, 1, 1);

public:
    /**
     * \param[in] srcs a vector containing all inputs to be concatenated
     * \param[out] dst the output tensor
     *
     * All tensors in srcs and dst should be contiguous.
     * All tensors should have the same shape for all axes except
     * param().axis.
     * For the param().axis-th axis, the axis shape for dst should be the
     * sum of corresponding axis shapes for all srcs.
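     *
     * Illustrative example (not part of the original doc): concatenating
     * srcs with shapes (2, 3) and (2, 5) on param().axis == 1 requires dst
     * to have shape (2, 8).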
     */
    virtual void exec(_megdnn_in const TensorNDArray& srcs,
                      _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayoutArray& srcs, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayoutArray& srcs,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayoutArray& srcs, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using Concat = ConcatForward;

class SplitForward : public ConcatSplitBase {
    DEF_OPR_IMPL(SplitForward, ConcatSplitBase, 1, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[out] dsts a vector containing all split results
     *
     * All tensors in src and dsts should be contiguous.
     * All tensors should have the same shape for all axes except
     * param().axis.
     * For the param().axis-th axis, the axis shape for src should be the
     * sum of corresponding axis shapes for all dsts.
     */
    virtual void exec(_megdnn_tensor_in src, const TensorNDArray& dsts,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayoutArray& dsts) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayoutArray& dsts,
                    size_t workspace_in_bytes);
};
using Split = SplitForward;

/**
 * \brief Base class for the ParamPackConcat and ParamPackSplit operators.
 *
 * ParamPack oprs act like Concat and Split, but they are also optimized for
 * a large number of inputs and can handle alignment requirements. The axis
 * parameter is not supported.
 *
 * The table can be generated by gen_table(). The \p srcs in ParamPackSplit
 * and \p dsts in ParamPackConcat must be on CPU, and must remain valid until
 * the execution stream is synchronized.
 */
class ParamPackConcatSplitBase : public OperatorBase {
protected:
    void check_exec(const TensorLayout& concated, const TensorLayout& table,
                    const TensorLayout& parts);

public:
    using Param = megdnn::param::Empty;

    ParamPackConcatSplitBase(Handle* handle) : OperatorBase(handle) {}

    //! generate table to be used with ParamPackConcat and ParamPackSplit
    static std::vector<dt_int32> gen_table(const TensorShapeArray& shapes,
                                           size_t alignment,
                                           size_t dtype_size);
};

/**
 * \brief ParamPackConcat, used for calculating the gradient of
 * ParamPackSplit.
 *
 * Combines multiple gradient tensors into a single large tensor; a copy
 * strategy is used because of AddUpdate and other dynamic situations.
 */
class ParamPackConcat : public ParamPackConcatSplitBase {
    DEF_OPR_IMPL(ParamPackConcat, ParamPackConcatSplitBase, 2, 1);

public:
    /*
     * \param[in] srcs: TensorND on CPU; srcs[i] corresponds to the
     *      address of the i-th tensor.
     * \param[in] table: with size `2 * srcs.nr_total_elems()`;
     *      table[addr] corresponds to outer_idx, and
     *      table[addr + srcs.nr_total_elems()] corresponds to the
     *      inner_idx of dsts.
     * \param[out] dst: output TensorND, living on CPU or GPU
     */
    virtual void exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table,
                      _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorShapeArray& srcs,
                                          const TensorShape& table,
                                          const TensorShape& dst) = 0;
};

/**
 * \brief ParamPackSplit, used for network forwarding.
 *
 * Splits a single large param into several small tensors, also using a copy
 * strategy.
 */
class ParamPackSplit : public ParamPackConcatSplitBase {
    DEF_OPR_IMPL(ParamPackSplit, ParamPackConcatSplitBase, 2, 1);

public:
    /*
     * \param[in] src: input TensorND, living on CPU or GPU
     * \param[in] table: with size `2 * srcs.nr_total_elems()`;
     *      table[addr] corresponds to outer_idx, and
     *      table[addr + srcs.nr_total_elems()] corresponds to the
     *      inner_idx of dsts.
     * \param[out] dsts: TensorND on CPU; dsts[i] corresponds to the
     *      address of the i-th tensor
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in table,
                      _megdnn_tensor_out dsts,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorShape& src,
                                          const TensorShape& table,
                                          const TensorShapeArray& dsts) = 0;
};

/**
 * \brief base class for Tile and Repeat
 */
class TileRepeatBase : public OperatorBase {
public:
    TileRepeatBase(Handle* handle) : OperatorBase(handle) {}

    struct Param {
        TensorShape times;
    };
    Param& param() { return m_param; }
    const Param& param() const { return m_param; }

protected:
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    /**
     * Assuming src/dst/times are already simplified on entrance.
     */
    size_t get_workspace_in_bytes_fwd(const TensorShape& src,
                                      const TensorShape& dst,
                                      const TensorShape& times, DType dtype);
    Param m_param;
};

class TileBase : public TileRepeatBase {
public:
    TileBase(Handle* handle) : TileRepeatBase(handle) {}

protected:
    void simplify_shape(const TensorShape& src, const TensorShape& dst,
                        const TensorShape& times, TensorShape& src2,
                        TensorShape& dst2, TensorShape& times2);
    /**
     * This is a helper function that would facilitate other backends'
     * implementations.
     */
    size_t get_workspace_in_bytes_fwd(const TensorLayout& src,
                                      const TensorLayout& dst);
};

class TileForward : public TileBase {
    DEF_OPR_IMPL(TileForward, TileBase, 1, 1);

public:
    /**
     * \brief Tile src times to get dst.
     * \param[in] src input tensor
     * \param[out] dst output tensor
     * \param[out] workspace temporary workspace
     *
     * src and dst must be contiguous.
     * dst.shape should be {src.shape[0]*param().times[0],
     * src.shape[1]*param().times[1], ...}
     *
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html
     *
     * Difference between Tile and Repeat:
     * Tiling `abc' twice yields `abcabc', whereas repeating `abc' twice
     * yields `aabbcc'.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using Tile = TileForward;

class TileBackward : public TileBase {
    DEF_OPR_IMPL(TileBackward, TileBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     * \param[out] workspace temporary workspace
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& diff, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};

class RepeatBase : public TileRepeatBase {
public:
    RepeatBase(Handle* handle) : TileRepeatBase(handle) {}

protected:
    void simplify_shape(const TensorShape& src, const TensorShape& dst,
                        const TensorShape& times, TensorShape& src2,
                        TensorShape& dst2, TensorShape& times2);
    /**
     * This is a helper function that would facilitate other backends'
     * implementations.
     */
    size_t get_workspace_in_bytes_fwd(const TensorLayout& src,
                                      const TensorLayout& dst);
};

class RepeatForward : public RepeatBase {
    DEF_OPR_IMPL(RepeatForward, RepeatBase, 1, 1);

public:
    /**
     * \brief Repeat src times to get dst.
     * \param[in] src input tensor
     * \param[out] dst output tensor
     * \param[out] workspace temporary workspace
     *
     * src and dst must be contiguous.
     * dst.shape should be {src.shape[0]*param().times[0],
     * src.shape[1]*param().times[1], ...}
     *
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html
     * \see TileForward
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
};
using Repeat = RepeatForward;

class RepeatBackward : public RepeatBase {
    DEF_OPR_IMPL(RepeatBackward, RepeatBase, 1, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[out] grad the backpropagated gradient wrt. src
     * \param[out] workspace temporary workspace
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& diff, const TensorLayout& grad,
                    size_t workspace_in_bytes);
};

class ArgsortForward : public OperatorBase {
    DEF_OPR_IMPL(ArgsortForward, OperatorBase, 1, 2);
    DEF_OPR_PARAM(Argsort);

public:
    using Order = Param::Order;

    /**
     * \param[in] src (m, n)
     * \param[out] dst (m, n)
     * \param[out] indices (m, n)
     *
     * src, dst and indices should be contiguous.
     * Performs m independent sorts on m arrays of length n, storing the
     * sorted arrays in `dst' and the corresponding indices in `indices'.
     *
     * Indices range from 0 to n-1.
     *
     * Note that indices is a TensorND of type int.
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_tensor_out indices,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst,
                       TensorLayout& indices);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst,
                                          const TensorLayout& indices) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    const TensorLayout& indices,
                    size_t workspace_in_bytes);
};
using Argsort = ArgsortForward;

/*!
 * \brief backward opr for Argsort
 *
 * Note: the name is kept for backward compatibility. This opr is actually a
 * batched value setter. It is used for gradient computing of Argsort and
 * TopK.
 */
class ArgsortBackward : public OperatorBase {
    DEF_OPR_IMPL(ArgsortBackward, OperatorBase, 2, 1);
    DEF_OPR_PARAM(Empty);

public:
    /**
     * \param[in] diff (m, k) the backpropagated gradient wrt. dst
     * \param[in] indices (m, k) the `indices' parameter in
     *      ArgsortForward::exec
     * \param[out] grad (m, n) the backpropagated gradient wrt. src
     *
     * Constraint: n >= k. Untouched values would be initialized as zero.
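     *
     * In other words (a clarifying restatement, not in the original doc):
     * grad[i][indices[i][j]] = diff[i][j] for all valid (i, j), and all
     * other entries of grad are set to zero.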
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
                      _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& indices,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& diff, const TensorLayout& indices,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};

class TopK : public OperatorBase {
    DEF_OPR_IMPL(TopK, OperatorBase, 1, 2);
    DEF_OPR_PARAM(TopK);

protected:
    //! impl exec; inputs have been validated
    virtual void do_exec(int k, _megdnn_tensor_in data,
                         _megdnn_tensor_out values, int32_t* indices,
                         _megdnn_workspace workspace) = 0;

public:
    /*!
     * \param[in] k if positive, compute the smallest top-k values;
     *      otherwise compute the largest top-k values
     * \param[in] data (m, n) input data, where top-k is computed on the
     *      second axis. The second dimension must be contiguous, and the
     *      first dimension can have arbitrary stride.
     * \param[out] values (m, ) or (m, k) output values; its shape depends
     *      on mode
     * \param[out] indices () or (m, ) or (m, k) output indices; its shape
     *      depends on mode
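     *
     * Sign convention example (illustrative, not from the original doc):
     * with data = {{5, 1, 3, 2}}, k == 2 selects the two smallest values
     * {1, 2}, while k == -2 selects the two largest values {5, 3}; the
     * ordering within the output depends on mode.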
     */
    void exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
              _megdnn_tensor_out indices, _megdnn_workspace workspace);
    virtual size_t get_workspace_in_bytes(int k, const TensorLayout& data,
                                          const TensorLayout& values,
                                          const TensorLayout& indices) = 0;
    void deduce_layout(int k, const TensorLayout& data, TensorLayout& values,
                       TensorLayout& indices);
};

/*!
 * \brief convert dtype of *src* to match dtype of *dst*; *src* may have
 * arbitrary layout and *dst* must be contiguous.
 */
class TypeCvtForward : public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL(TypeCvtForward, OperatorBase, 1, 1);

public:
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& dst);
};
using TypeCvt = TypeCvtForward;

class IndexingRemapBase : public OperatorBase {
public:
    using Param = param::IndexingRemap;

    IndexingRemapBase(Handle* handle) : OperatorBase(handle) {}
    Param& param() { return m_param; }
    const Param& param() const { return m_param; }

protected:
    Param m_param;

    void check_layout_fwd(const TensorLayout& src, const TensorLayout& map,
                          const TensorLayout& dst);
};

class IndexingRemapForward : public IndexingRemapBase {
    DEF_OPR_IMPL(IndexingRemapForward, IndexingRemapBase, 2, 1);

public:
    /**
     * \param[in] src input tensor
     * \param[in] map input map
     * \param[out] dst output tensor
     *
     * Suppose:
     * the shape of src is \f$(s_0, s_1, ..., s_{m-1})\f$;
     * the shape of dst is \f$(d_0, d_1, ..., d_{n-1})\f$;
     * then:
     * the shape of map must be \f$(d_0, d_1, ..., d_{n-1}, m)\f$.
     *
     * The last dimension of map indicates the src indices for the
     * corresponding dst entry.
     *
     * src and dst can be non-contiguous in a non-overlapping manner.
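     *
     * Put as a formula (a restatement for clarity, not in the original
     * doc): dst[i_0, ..., i_{n-1}] =
     *      src[map[i_0, ..., i_{n-1}, 0], ..., map[i_0, ..., i_{n-1}, m-1]].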
     */
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map,
                      _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, const TensorLayout& map,
                       TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& map,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& map,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
using IndexingRemap = IndexingRemapForward;
// The using directives preserve backward compatibility.
using TensorRemapForward = IndexingRemap;
using TensorRemap = TensorRemapForward;

class IndexingRemapBackward : public IndexingRemapBase {
    DEF_OPR_IMPL(IndexingRemapBackward, IndexingRemapBase, 2, 1);

public:
    /**
     * \param[in] diff the backpropagated gradient wrt. dst
     * \param[in] map the `map' parameter in IndexingRemapForward::exec
     * \param[out] grad the backpropagated gradient wrt. src
     */
    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in map,
                      _megdnn_tensor_out grad,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
                                          const TensorLayout& map,
                                          const TensorLayout& grad) = 0;

protected:
    void check_exec(const TensorLayout& diff, const TensorLayout& map,
                    const TensorLayout& grad, size_t workspace_in_bytes);
};
// The using directive preserves backward compatibility.
using TensorRemapBackward = IndexingRemapBackward;

class Linspace : public OperatorBase {
    DEF_OPR_IMPL(Linspace, OperatorBase, 0, 1);
    DEF_OPR_PARAM(LinspaceFull);

public:
    /**
     * \param[out] dst must be 1d.
     *
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.linspace.html
     */
    virtual void exec(_megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};

class Eye : public OperatorBase {
    DEF_OPR_IMPL(Eye, OperatorBase, 0, 1);
    DEF_OPR_PARAM(Eye);

public:
    /**
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.eye.html
     */
    virtual void exec(_megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};

class IndexingOneHotBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(IndexingOneHotBase, OperatorBase);
    DEF_OPR_PARAM(Axis);

protected:
    void deduce_layout_fwd(const TensorLayout& src,
                           const TensorLayout& index, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& index,
                          const TensorLayout& dst);
};

/*!
 * \brief Indexing for one-hot encoding
 *
 * Given src, axis and index,
 * for all valid (n-1)-dimensional subscript tuples i iterating through
 * index:
 *      dst[i[0], ..., i[axis-1], 0, i[axis], ..., i[n-2]] =
 *          inp[i[0], ..., i[axis-1], index[i], i[axis], ..., i[n-2]]
 *
 * \param[in] src n-dimensional input data
 * \param[in] index (n-1)-dimensional index, must be int
 * \param[out] dst n-dimensional output data
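 *
 * Worked example (added for illustration): with src of shape (3, 4),
 * param().axis == 1 and index of shape (3,), dst has shape (3, 1) and
 * dst[i, 0] = src[i, index[i]].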
 */
class IndexingOneHotForward : public IndexingOneHotBase {
    DEF_OPR_IMPL(IndexingOneHotForward, IndexingOneHotBase, 2, 1);

public:
    void deduce_layout(const TensorLayout& src, const TensorLayout& index,
                       TensorLayout& dst) {
        deduce_layout_fwd(src, index, dst);
    }
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in index,
                      _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& index,
                                          const TensorLayout& dst) = 0;

protected:
    void check_exec(const TensorLayout& src, const TensorLayout& index,
                    const TensorLayout& dst, size_t workspace_in_bytes);
};
using IndexingOneHot = IndexingOneHotForward;

/*!
 * \brief set-subtensor corresponding to IndexingOneHotForward
 *
 * \param[in,out] data n-dimensional input and output data, whose sub part
 *      corresponding to *index* would be replaced by *sub*
 * \param[in] index (n-1)-dimensional index, must be int
 * \param[in] sub n-dimensional sub tensor to be filled in *data*
 */
class IndexingSetOneHotForward : public IndexingOneHotBase {
    DEF_OPR_IMPL(IndexingSetOneHotForward, IndexingOneHotBase, -1, 1);

public:
    virtual void exec(_megdnn_tensor_inout data, _megdnn_tensor_in index,
                      _megdnn_tensor_in sub,
                      _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(const TensorLayout& data,
                                          const TensorLayout& index,
                                          const TensorLayout& sub) = 0;

protected:
    void check_exec(const TensorLayout& data, const TensorLayout& index,
                    const TensorLayout& sub, size_t workspace_in_bytes);
};
using IndexingSetOneHot = IndexingSetOneHotForward;

/*!
 * \brief base class for indexing on multiple axes using vector indices
 *
 * Note that the indexing axes are required to be sorted in ascending order
 */
class IndexingMultiAxisVecBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(IndexingMultiAxisVecBase, OperatorBase);
    DEF_OPR_PARAM(Empty);

public:
    struct AxisIndexer {
        size_t axis;
        TensorND vec;
    };

    struct AxisIndexerLayoutOnly {
        size_t axis;
        TensorLayout layout;
    };

    using IndexDesc = std::vector<AxisIndexer>;
    using IndexDescLayoutOnly = std::vector<AxisIndexerLayoutOnly>;

    /*!
     * \brief convert IndexDesc to IndexDescLayoutOnly
     */
    static IndexDescLayoutOnly extract_index_layout(const IndexDesc& index);

    /*!
     * \brief get the axes on src that are not used in index
     * \param[out] out output buffer; suggested size is
     *      TensorLayout::MAX_NDIM
     * \return number of elements written to *out*
     */
    static size_t get_nonindex_axes(size_t src_ndim, const IndexDesc& index,
                                    size_t* out);

    /*!
     * \brief get contiguous-collapsed layout for indexing on value
     * \param idx_axis indexer axis on value (i.e. ExecInfo::idx_axis)
     * \return a tensor layout and an axis to iterate over *value* and also
     *      access *data*; stride of layout on that axis would be zero, and
     *      strides on other axes correspond to the strides in *data*
     */
    static std::pair<TensorLayout, size_t> get_value_iter_optimized_layout(
            const TensorLayout& data, const TensorLayout& value,
            const IndexDesc& index, size_t idx_axis);

    //! helper info for kernel implementation
    struct ExecInfo {
        //! axis in value used by indexer
        size_t idx_axis;
        ptrdiff_t value_stride;
        void* error_tracker;
        megcore::AsyncErrorInfo* error_info;
    };

protected:
    /*!
     * \return axis on dst used by indexer (i.e. ExecInfo::idx_axis)
     */
    static size_t deduce_layout_fwd(const TensorLayout& data,
                                    const IndexDescLayoutOnly& index,
                                    TensorLayout& dst);

    static ExecInfo check_exec_noworkspace(
            const TensorLayout& data, const TensorLayout& value,
            const IndexDesc& index, IndexDescLayoutOnly& index_layout);
};

/*!
 * \brief compute indexing result, like numpy advanced indexing
 *
 * src can have arbitrary layout, but dst must be dim1-contig
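 *
 * Illustrative example (hypothetical, in numpy-like terms): indexing data
 * of shape (5, 6) on axis 0 with an index vector {0, 2, 4} yields a dst of
 * shape (3, 6), analogous to numpy's data[[0, 2, 4], :].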
 */
class IndexingMultiAxisVec : public IndexingMultiAxisVecBase {
    DEF_OPR_IMPL(IndexingMultiAxisVec, IndexingMultiAxisVecBase, 0, 1);

public:
    virtual void exec(_megdnn_tensor_in src, const IndexDesc& index,
                      _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;

    /*!
     * \brief get workspace size based on output shape and indexing axes
     */
    size_t get_workspace_in_bytes(const TensorShape& dst, const size_t* axes,
                                  size_t nr_axes);

    static void deduce_layout(const TensorLayout& data,
                              const IndexDescLayoutOnly& index,
                              TensorLayout& dst) {
        deduce_layout_fwd(data, index, dst);
    }

protected:
    virtual size_t get_workspace_in_bytes(size_t dst_idx_size) = 0;

    ExecInfo check_exec(const TensorLayout& src, const IndexDesc& index,
                        const TensorLayout& dst,
                        size_t workspace_in_bytes);
};

/*!
 * \brief base class for modifying data by given index
 *
 * data can have arbitrary layout, but value must be dim1-contig
 */
class IndexingModifyMultiAxisVecBase : public IndexingMultiAxisVecBase {
    DEF_OPR_IMPL_CTOR(IndexingModifyMultiAxisVecBase,
                      IndexingMultiAxisVecBase);

public:
    virtual void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value,
                      const IndexDesc& index,
                      _megdnn_workspace workspace) = 0;

    /*!
     * \brief get workspace size based on shape of value input and indexing
     *      axes
     */
    size_t get_workspace_in_bytes(const TensorShape& value,
                                  const size_t* axes, size_t nr_axes);

protected:
    ExecInfo check_exec(const TensorLayout& data, const TensorLayout& value,
                        const IndexDesc& index,
                        size_t workspace_in_bytes);

    virtual size_t get_workspace_in_bytes(size_t value_idx_size) = 0;
};

//! set value to indexed locations; index values must be non-overlapping
class IndexingSetMultiAxisVec : public IndexingModifyMultiAxisVecBase {
    DEF_OPR_IMPL(IndexingSetMultiAxisVec, IndexingModifyMultiAxisVecBase, 0,
                 0);
};

//! add value to indexed locations; index values must be non-overlapping
class IndexingIncrMultiAxisVec : public IndexingModifyMultiAxisVecBase {
    DEF_OPR_IMPL(IndexingIncrMultiAxisVec, IndexingModifyMultiAxisVecBase, 0,
                 0);
};

class MeshBase : public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL_CTOR(MeshBase, OperatorBase);

public:
    using AxisIndexer = IndexingMultiAxisVecBase::AxisIndexer;
    using IndexDesc = IndexingMultiAxisVecBase::IndexDesc;
    using AxisIndexerLayoutOnly =
            IndexingMultiAxisVecBase::AxisIndexerLayoutOnly;
    using IndexDescLayoutOnly = IndexingMultiAxisVecBase::IndexDescLayoutOnly;

    size_t get_workspace_in_bytes(const TensorShape&, const size_t*,
                                  size_t) {
        return 0;
    }

protected:
    virtual void check_exec(const TensorLayout& origin,
                            const TensorLayout& indexed,
                            const IndexDesc& desc);
};

class NormalMeshBase : public MeshBase {
    DEF_OPR_IMPL(NormalMeshBase, MeshBase, 0, 0);

protected:
    virtual void check_exec(const TensorLayout& origin,
                            const TensorLayout& indexed,
                            const IndexDesc& desc) override final;
};

class NormalMeshModifyBase : public NormalMeshBase {
    DEF_OPR_IMPL_CTOR(NormalMeshModifyBase, NormalMeshBase);

public:
    virtual void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value,
                      const IndexDesc& desc,
                      _megdnn_workspace workspace) = 0;
};

class BatchedMeshBase : public MeshBase {
    DEF_OPR_IMPL_CTOR(BatchedMeshBase, MeshBase);

protected:
    virtual void check_exec(const TensorLayout& origin,
                            const TensorLayout& indexed,
                            const IndexDesc& desc) override final;
};

class BatchedMeshModifyBase : public BatchedMeshBase {
    DEF_OPR_IMPL_CTOR(BatchedMeshModifyBase, BatchedMeshBase);

public:
    virtual void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value,
                      const IndexDesc& desc,
                      _megdnn_workspace workspace) = 0;
};

class MeshIndexing : public NormalMeshBase {
    DEF_OPR_IMPL(MeshIndexing, NormalMeshBase, 0, 0);

public:
    virtual void exec(_megdnn_tensor_in src, const IndexDesc& desc,
                      _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    static void deduce_layout(const TensorLayout& inp,
                              const IndexDescLayoutOnly& layouts,
                              TensorLayout& out_layout);
};

class IncrMeshIndexing : public NormalMeshModifyBase {
    DEF_OPR_IMPL(IncrMeshIndexing, NormalMeshModifyBase, 0, 0);
};

class SetMeshIndexing : public NormalMeshModifyBase {
    DEF_OPR_IMPL(SetMeshIndexing, NormalMeshModifyBase, 0, 0);
};

class BatchedMeshIndexing : public BatchedMeshBase {
    DEF_OPR_IMPL(BatchedMeshIndexing, BatchedMeshBase, 0, 0);

public:
    virtual void exec(_megdnn_tensor_in src, const IndexDesc& desc,
                      _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    static void deduce_layout(const TensorLayout& inp,
                              const IndexDescLayoutOnly& layouts,
                              TensorLayout& out_layout);
};

class BatchedIncrMeshIndexing : public BatchedMeshModifyBase {
    DEF_OPR_IMPL(BatchedIncrMeshIndexing, BatchedMeshModifyBase, 0, 0);
};

class BatchedSetMeshIndexing : public BatchedMeshModifyBase {
    DEF_OPR_IMPL(BatchedSetMeshIndexing, BatchedMeshModifyBase, 0, 0);
};

class RelayoutFormat : public OperatorBase {
    DEF_OPR_PARAM(RelayoutFormat);
    DEF_OPR_IMPL(RelayoutFormat, OperatorBase, 1, 1);

public:
    virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                      _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    void deduce_format(TensorFormat src, TensorFormat& dst);
    virtual size_t get_workspace_in_bytes(const TensorLayout& src,
                                          const TensorLayout& dst) = 0;

protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
    void check_exec(const TensorLayout& src, const TensorLayout& dst,
                    size_t workspace_in_bytes);
    void deduce_exec_layout(const TensorLayout& src, const TensorLayout& dst,
                            TensorLayout& exec_src, TensorLayout& exec_dst);
};

}  // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"

// vim: syntax=cpp.doxygen

The MegEngine installation package comes with the CUDA environment required to run code on a GPU, so there are no separate CPU and GPU builds. To run GPU programs, make sure the machine has a GPU device and its driver is properly installed. If you would like to experience deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
