You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ops.td 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. #ifndef MGB_OPS
  2. #define MGB_OPS
  3. include "base.td"
  4. include "param_defs.td"
  5. include "mlir/Interfaces/SideEffectInterfaces.td"
  // Elemwise: configured by ElemwiseParam and marked NoSideEffect.
  // Accepts a variable number of operands of any type and yields one result.
  def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
    let inputs = (ins Variadic<AnyType>:$input);
    let results = (outs AnyType);
    // The op's display name is the string form of its `mode` field.
    let nameFunction = [{
      return to_string($_self.mode);
    }];
  }
  // Reduce: configured by ReduceParam, plus an extra boolean `keepdim`
  // argument whose default value is "true".
  def Reduce : MgbHashableOp<"Reduce", [ReduceParam]> {
    let extraArguments = (ins MgbDefaultValuedAttr<MgbBoolAttr, "true">:$keepdim);
  }
  // TypeCvt: side-effect-free op with no param struct; one operand, one result.
  // Carries a TypeAttr `idtype` and an MgbDTypeAttr `dtype`.
  // NOTE(review): these look like the source and destination dtypes — confirm
  // against the TypeCvt implementation.
  def TypeCvt : MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
    // The single operand is named `inputs` (plural); left as-is in case
    // generated code refers to it by that name.
    let inputs = (ins AnyType:$inputs);
    let extraArguments = (ins TypeAttr:$idtype, MgbDTypeAttr:$dtype);
    let results = (outs AnyType);
  }
  // ---- Linear-algebra ops ----
  def MatrixInverse : MgbHashableOp<"MatrixInverse", [EmptyParam]>;

  // MatrixMul and BatchedMatrixMul both take MatrixMulParam plus an
  // execution-policy param named `policy`. Each also adds two unsigned 32-bit
  // attributes, `dimA` and `dimB`.
  // NOTE(review): presumably these are the ranks of the two operands — confirm.
  def MatrixMul : MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> {
    let extraArguments = (ins MgbUI32Attr:$dimA, MgbUI32Attr:$dimB);
  }

  // The op name string is "BatchedMatmul", which differs from the record name.
  // Keep the string unchanged, since external code may refer to it.
  def BatchedMatrixMul : MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> {
    let extraArguments = (ins MgbUI32Attr:$dimA, MgbUI32Attr:$dimB);
  }

  def Dot : MgbHashableOp<"Dot", [EmptyParam]>;
  def SVD : MgbHashableOp<"SVD", [SVDParam]>;
  // ---- Convolution family and Pooling ----
  // Every op in this group except GroupLocal also takes an execution-policy
  // param named `policy`.
  def Convolution : MgbHashableOp<"Convolution", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;

  // Adds a `dtype` attribute on top of ConvolutionParam.
  def ConvolutionBackwardData : MgbHashableOp<"ConvolutionBackwardData", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]> {
    let extraArguments = (ins MgbDTypeAttr:$dtype);
  }

  def Convolution3D : MgbHashableOp<"Convolution3D", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;
  def Convolution3DBackwardData : MgbHashableOp<"Convolution3DBackwardData", [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;

  // DeformableConv reuses ConvolutionParam.
  def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;

  // GroupLocal reuses ConvolutionParam but has no execution policy.
  def GroupLocal : MgbHashableOp<"GroupLocal", [ConvolutionParam]>;

  def Pooling : MgbHashableOp<"Pooling", [PoolingParam, ExecutionPolicyParamBase<"policy">]>;
  // AdaptivePooling: AdaptivePoolingParam plus an explicit i32 `shape` array.
  def AdaptivePooling : MgbHashableOp<"AdaptivePooling", [AdaptivePoolingParam]> {
    let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$shape);
  }

  // ROI-based pooling ops: configured entirely by their param structs.
  def ROIPooling : MgbHashableOp<"ROIPooling", [ROIPoolingParam]>;
  def DeformablePSROIPooling : MgbHashableOp<"DeformablePSROIPooling", [DeformablePSROIPoolingParam]>;
  // ConvBias and BatchConvBias: each pairs its param struct with an
  // execution-policy param named `policy` and adds a `dtype` attribute.
  def ConvBias : MgbHashableOp<"ConvBias", [ConvBiasParam, ExecutionPolicyParamBase<"policy">]> {
    let extraArguments = (ins MgbDTypeAttr:$dtype);
  }

  def BatchConvBias : MgbHashableOp<"BatchConvBias", [BatchConvBiasParam, ExecutionPolicyParamBase<"policy">]> {
    let extraArguments = (ins MgbDTypeAttr:$dtype);
  }
  // Ops configured solely by their param struct (no extra arguments).
  def Images2Neibs : MgbHashableOp<"Images2Neibs", [Images2NeibsParam]>;
  def SlidingWindowTranspose : MgbHashableOp<"SlidingWindowTranspose", [SlidingWindowTransposeParam]>;

  // Forward and backward batch norm share BNParam.
  def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>;
  def BatchNormBackward : MgbHashableOp<"BatchNormBackward", [BNParam]>;

  def ROIAlign : MgbHashableOp<"ROIAlign", [ROIAlignParam]>;
  def Correlation : MgbHashableOp<"Correlation", [CorrelationParam]>;
  def WarpPerspective : MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>;
  def WarpAffine : MgbHashableOp<"WarpAffine", [WarpAffineParam]>;
  def Remap : MgbHashableOp<"Remap", [RemapParam]>;
  def Resize : MgbHashableOp<"Resize", [ResizeParam]>;
  // One-hot indexing ops: AxisParam plus an i32 `ndim` attribute.
  def IndexingOneHot : MgbHashableOp<"IndexingOneHot", [AxisParam]> {
    let extraArguments = (ins MgbI32Attr:$ndim);
  }

  def IndexingSetOneHot : MgbHashableOp<"IndexingSetOneHot", [AxisParam]> {
    let extraArguments = (ins MgbI32Attr:$ndim);
  }
  // Copy, Borrow and Barrier take no param struct; each carries a `comp_node`
  // attribute. Barrier also carries an unsigned 32-bit `nr_outputs`.
  def Copy : MgbHashableOp<"Copy"> {
    let extraArguments = (ins MgbCompNodeAttr:$comp_node);
  }

  def Borrow : MgbHashableOp<"Borrow"> {
    let extraArguments = (ins MgbCompNodeAttr:$comp_node);
  }

  def Barrier : MgbHashableOp<"Barrier"> {
    let extraArguments = (ins MgbCompNodeAttr:$comp_node, MgbUI32Attr:$nr_outputs);
  }
  // Sorting / selection ops. Argmax and Argmin share AxisParam;
  // CondTake takes no param struct.
  def Argsort : MgbHashableOp<"Argsort", [ArgsortParam]>;
  def Argmax : MgbHashableOp<"Argmax", [AxisParam]>;
  def Argmin : MgbHashableOp<"Argmin", [AxisParam]>;
  def CondTake : MgbHashableOp<"CondTake">;
  def TopK : MgbHashableOp<"TopK", [TopKParam]>;
  def NvOf : MgbHashableOp<"NvOf", [NvOfParam]>;
  // UniformRNG: UniformRNGParam plus an opaque `handle` (MgbSizeTAddr).
  // Hashing and equality consider only `handle` and `dtype`; any other
  // UniformRNGParam fields do not participate.
  // NOTE(review): presumably the handle already captures the seed — confirm.
  def UniformRNG : MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);
    let hashFunction = [{
      // Combine (handle, dtype) first, then fold in the dynamic type info.
      size_t field_hash = mgb::hash_pair_combine(
          mgb::hash($_self.handle), mgb::hash($_self.dtype.enumv()));
      return mgb::hash_pair_combine(mgb::hash($_self.dyn_typeinfo()), field_hash);
    }];
    let cmpFunction = [{
      return $0.handle == $1.handle && $0.dtype == $1.dtype;
    }];
  }
  // GaussianRNG: GaussianRNGParam plus an opaque `handle` (MgbSizeTAddr).
  // Hashing and equality use `handle`, `mean`, `std` and `dtype`.
  def GaussianRNG : MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);
    let hashFunction = [{
      // Fold from the innermost pair outward: (std, dtype), then mean,
      // then handle, then the dynamic type info.
      size_t field_hash = mgb::hash_pair_combine(
          mgb::hash($_self.std), mgb::hash($_self.dtype.enumv()));
      field_hash = mgb::hash_pair_combine(mgb::hash($_self.mean), field_hash);
      field_hash = mgb::hash_pair_combine(mgb::hash($_self.handle), field_hash);
      return mgb::hash_pair_combine(mgb::hash($_self.dyn_typeinfo()), field_hash);
    }];
    let cmpFunction = [{
      return $0.handle == $1.handle && $0.mean == $1.mean &&
             $0.std == $1.std && $0.dtype == $1.dtype;
    }];
  }
  // GammaRNG, PoissonRNG and BetaRNG: each is its param struct plus an opaque
  // `handle` (MgbSizeTAddr). Hashing and equality consider only the handle
  // (hashed together with the dynamic type info).
  def GammaRNG : MgbHashableOp<"GammaRNG", [GammaRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);
    let hashFunction = [{
      size_t type_hash = mgb::hash($_self.dyn_typeinfo());
      return mgb::hash_pair_combine(type_hash, mgb::hash($_self.handle));
    }];
    let cmpFunction = [{
      return $0.handle == $1.handle;
    }];
  }

  def PoissonRNG : MgbHashableOp<"PoissonRNG", [PoissonRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);
    let hashFunction = [{
      size_t type_hash = mgb::hash($_self.dyn_typeinfo());
      return mgb::hash_pair_combine(type_hash, mgb::hash($_self.handle));
    }];
    let cmpFunction = [{
      return $0.handle == $1.handle;
    }];
  }

  def BetaRNG : MgbHashableOp<"BetaRNG", [BetaRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);
    let hashFunction = [{
      size_t type_hash = mgb::hash($_self.dyn_typeinfo());
      return mgb::hash_pair_combine(type_hash, mgb::hash($_self.handle));
    }];
    let cmpFunction = [{
      return $0.handle == $1.handle;
    }];
  }
  // PermutationRNG: PermutationRNGParam plus an opaque `handle` (MgbSizeTAddr).
  // Hashing and equality consider only `handle` and `dtype`.
  def PermutationRNG : MgbHashableOp<"PermutationRNG", [PermutationRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);
    let hashFunction = [{
      // Combine (handle, dtype) first, then fold in the dynamic type info.
      size_t field_hash = mgb::hash_pair_combine(
          mgb::hash($_self.handle), mgb::hash($_self.dtype.enumv()));
      return mgb::hash_pair_combine(mgb::hash($_self.dyn_typeinfo()), field_hash);
    }];
    let cmpFunction = [{
      return $0.handle == $1.handle && $0.dtype == $1.dtype;
    }];
  }
  // ShuffleRNG: ShuffleRNGParam plus an opaque `handle` (MgbSizeTAddr).
  // Hashing and equality consider only the handle.
  def ShuffleRNG : MgbHashableOp<"ShuffleRNG", [ShuffleRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);
    let hashFunction = [{
      size_t type_hash = mgb::hash($_self.dyn_typeinfo());
      return mgb::hash_pair_combine(type_hash, mgb::hash($_self.handle));
    }];
    let cmpFunction = [{
      return $0.handle == $1.handle;
    }];
  }
  // Linspace and Eye each add a `comp_node` attribute to their param struct.
  def Linspace : MgbHashableOp<"Linspace", [LinspaceParam]> {
    let extraArguments = (ins MgbCompNodeAttr:$comp_node);
  }

  def Eye : MgbHashableOp<"Eye", [EyeParam]> {
    let extraArguments = (ins MgbCompNodeAttr:$comp_node);
  }

  def Diag : MgbHashableOp<"Diag", [DiagParam]>;
  def GetVarShape : MgbHashableOp<"GetVarShape", [OptionalAxisV1Param]>;
  // Concat: AxisParam plus a `comp_node` attribute.
  def Concat : MgbHashableOp<"Concat", [AxisParam]> {
    let extraArguments = (ins MgbCompNodeAttr:$comp_node);
  }

  // Broadcast: no param fields; the shape is given as an i32 array attribute.
  def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]> {
    let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$shape);
  }

  def Identity : MgbHashableOp<"Identity">;
  // CollectiveComm: CollectiveCommParam plus key/address/port/rank/backend
  // settings.
  // NOTE(review): `comp_node` here is an MgbStringAttr, unlike the
  // MgbCompNodeAttr used by other ops in this file; kept as-is.
  def CollectiveComm : MgbHashableOp<"CollectiveComm", [CollectiveCommParam]> {
    let extraArguments = (ins
      MgbStringAttr:$key,
      MgbUI32Attr:$nr_devices,
      MgbUI32Attr:$rank,
      MgbBoolAttr:$is_root,
      MgbBoolAttr:$local_grad,
      MgbStringAttr:$addr,
      MgbUI32Attr:$port,
      MgbDTypeAttr:$dtype,
      MgbStringAttr:$backend,
      MgbStringAttr:$comp_node
    );
  }

  // RemoteSend: names its peer via the unsigned `rank_to`.
  def RemoteSend : MgbHashableOp<"RemoteSend"> {
    let extraArguments = (ins
      MgbStringAttr:$key,
      MgbStringAttr:$addr,
      MgbUI32Attr:$port,
      MgbUI32Attr:$rank_to,
      MgbStringAttr:$backend
    );
  }

  // RemoteRecv: names its peer via `rank_from`. It also carries a comp node
  // (`cn`), an i32 `shape` array and a `dtype`.
  def RemoteRecv : MgbHashableOp<"RemoteRecv"> {
    let extraArguments = (ins
      MgbStringAttr:$key,
      MgbStringAttr:$addr,
      MgbUI32Attr:$port,
      MgbUI32Attr:$rank_from,
      MgbCompNodeAttr:$cn,
      MgbArrayAttr<MgbI32Attr>:$shape,
      MgbDTypeAttr:$dtype,
      MgbStringAttr:$backend
    );
  }
  // NMSKeep: f32 `iou_thresh` and unsigned 32-bit `max_output`.
  def NMSKeep : MgbHashableOp<"NMSKeep"> {
    let extraArguments = (ins MgbF32Attr:$iou_thresh, MgbUI32Attr:$max_output);
  }

  // ParamPackSplit: i32 `offsets`, plus `shapes` as a list of size_t lists.
  def ParamPackSplit : MgbHashableOp<"ParamPackSplit"> {
    let extraArguments = (ins
      MgbArrayAttr<MgbI32Attr>:$offsets,
      MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$shapes
    );
  }

  // ParamPackConcat: i32 `offsets` only.
  def ParamPackConcat : MgbHashableOp<"ParamPackConcat"> {
    let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$offsets);
  }
  // Dimshuffle: one AnyMemRef operand and one AnyMemRef result, reordered
  // according to the i32 `pattern` array.
  def Dimshuffle : MgbHashableOp<"Dimshuffle"> {
    let inputs = (ins AnyMemRef:$input);
    let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$pattern);
    let results = (outs AnyMemRef);
  }

  // Reshape: OptionalAxisV1Param plus a target i32 `shape` array.
  def Reshape : MgbHashableOp<"Reshape", [OptionalAxisV1Param]> {
    let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$shape);
  }

  // TODO: merge AddAxis/RemoveAxis into a single AxisAddRemove op, as in MegBrain?
  def AddAxis : MgbHashableOp<"AddAxis"> {
    let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$axis);
  }

  def RemoveAxis : MgbHashableOp<"RemoveAxis"> {
    let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$axis);
  }
  // Shared base for the indexing ops below. Each op carries an `items` array
  // whose entries are tuples of one i8 followed by four bools.
  // NOTE(review): presumably (axis, begin/end/step/idx presence flags) —
  // confirm against the indexing implementation.
  class FancyIndexingBase<string name> : MgbHashableOp<name> {
    let extraArguments = (ins
      MgbArrayAttr<MgbTupleAttr<[
        MgbI8Attr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr
      ]>>:$items
    );
  }

  // Subtensor variants.
  def Subtensor : FancyIndexingBase<"Subtensor">;
  def SetSubtensor : FancyIndexingBase<"SetSubtensor">;
  def IncrSubtensor : FancyIndexingBase<"IncrSubtensor">;

  // Multi-axis vector indexing variants.
  def IndexingMultiAxisVec : FancyIndexingBase<"IndexingMultiAxisVec">;
  def IndexingSetMultiAxisVec : FancyIndexingBase<"IndexingSetMultiAxisVec">;
  def IndexingIncrMultiAxisVec : FancyIndexingBase<"IndexingIncrMultiAxisVec">;

  // Mesh indexing variants.
  def MeshIndexing : FancyIndexingBase<"MeshIndexing">;
  def IncrMeshIndexing : FancyIndexingBase<"IncrMeshIndexing">;
  def SetMeshIndexing : FancyIndexingBase<"SetMeshIndexing">;

  // Batched mesh indexing variants.
  def BatchedMeshIndexing : FancyIndexingBase<"BatchedMeshIndexing">;
  def BatchedIncrMeshIndexing : FancyIndexingBase<"BatchedIncrMeshIndexing">;
  def BatchedSetMeshIndexing : FancyIndexingBase<"BatchedSetMeshIndexing">;
  // Quantization-related ops, configured solely by their param structs.
  def FakeQuant : MgbHashableOp<"FakeQuant", [FakeQuantParam]>;
  def AssertEqual : MgbHashableOp<"AssertEqual", [AssertEqualParam]>;
  def TQT : MgbHashableOp<"TQT", [TQTParam]>;
  def LSQ : MgbHashableOp<"LSQ", [LSQParam]>;
  def Softmax : MgbHashableOp<"Softmax", [SoftmaxParam]>;

  // ElemwiseMultiType: ElemwiseMultiTypeParam plus a `dtype` attribute.
  // Its display name is the string form of its `mode` field.
  def ElemwiseMultiType : MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> {
    let extraArguments = (ins MgbDTypeAttr:$dtype);
    let nameFunction = [{
      return to_string($_self.mode);
    }];
  }

  def InplaceAdd : MgbHashableOp<"InplaceAdd", [EmptyParam]>;
  // Runtime-wrapper ops: each carries a `buf` string attribute and a
  // `buf_size` (MgbSizeTAddr).
  def TensorRTRuntime : MgbHashableOp<"TensorRTRuntime"> {
    let extraArguments = (ins MgbStringAttr:$buf, MgbSizeTAddr:$buf_size);
  }

  def AtlasRuntime : MgbHashableOp<"AtlasRuntime"> {
    let extraArguments = (ins MgbStringAttr:$buf, MgbSizeTAddr:$buf_size);
  }

  // CambriconRuntime additionally takes a `symbol` string and a
  // `tensor_dim_mutable` flag.
  def CambriconRuntime : MgbHashableOp<"CambriconRuntime"> {
    let extraArguments = (ins
      MgbStringAttr:$buf,
      MgbSizeTAddr:$buf_size,
      MgbStringAttr:$symbol,
      MgbBoolAttr:$tensor_dim_mutable
    );
  }

  def MagicMindRuntime : MgbHashableOp<"MagicMindRuntime"> {
    let extraArguments = (ins MgbStringAttr:$buf, MgbSizeTAddr:$buf_size);
  }
  def CvtColor : MgbHashableOp<"CvtColor", [CvtColorParam]>;
  def CheckNonFinite : MgbHashableOp<"CheckNonFinite", [CheckNonFiniteParam]>;
  def FastpathCopy : MgbHashableOp<"FastpathCopy">;

  // PixelShuffle and its backward op each take an i32 `factor`.
  def PixelShuffle : MgbHashableOp<"PixelShuffle"> {
    let extraArguments = (ins MgbI32Attr:$factor);
  }

  def PixelShuffleBackward : MgbHashableOp<"PixelShuffleBackward"> {
    let extraArguments = (ins MgbI32Attr:$factor);
  }
  // ExternOpr: declared output shapes and dtypes, plus a `name`, a `data`
  // string and its `data_len`.
  // Only `name` and `data` feed the hash. No cmpFunction is given, so equality
  // falls back to the MgbHashableOp default.
  // NOTE(review): presumably that default compares every field (see base.td).
  // If so, hashing this subset stays consistent with equality.
  def ExternOpr : MgbHashableOp<"ExternOpr"> {
    let extraArguments = (ins
      MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$output_shapes,
      MgbStringAttr:$name,
      MgbStringAttr:$data,
      MgbSizeTAddr:$data_len,
      MgbArrayAttr<MgbDTypeAttr>:$output_dtypes
    );
    let hashFunction = [{
      // Combine (name, data) first, then fold in the dynamic type info.
      size_t payload_hash = mgb::hash_pair_combine(
          mgb::hash($_self.name), mgb::hash($_self.data));
      return mgb::hash_pair_combine(mgb::hash($_self.dyn_typeinfo()), payload_hash);
    }];
  }
  def Cumsum : MgbHashableOp<"Cumsum", [CumsumParam]>;

  // Split: no param fields; i32 `axis` and `nsections` attributes.
  def Split : MgbHashableOp<"Split", [EmptyParam]> {
    let extraArguments = (ins MgbI32Attr:$axis, MgbI32Attr:$nsections);
  }

  def Padding : MgbHashableOp<"Padding", [PaddingParam]>;
  def LRN : MgbHashableOp<"LRN", [LRNParam]>;
  def LayerNorm : MgbHashableOp<"LayerNorm", [LayerNormParam]>;
  def LAMBUpdate : MgbHashableOp<"LAMBUpdate", [LAMBUpdateParam]>;

  // Recurrent ops. Unlike RNNCell, LSTMCell takes EmptyParam.
  def RNNCell : MgbHashableOp<"RNNCell", [RNNCellParam]>;
  def LSTMCell : MgbHashableOp<"LSTMCell", [EmptyParam]>;
  def RNN : MgbHashableOp<"RNN", [RNNParam]>;
  def LSTM : MgbHashableOp<"LSTM", [LSTMParam]>;
  // Dropout: DropoutParam plus an opaque `handle` (MgbSizeTAddr).
  // Hashing and equality consider `drop_prob` and `handle`.
  def Dropout : MgbHashableOp<"Dropout", [DropoutParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);
    let hashFunction = [{
      // Combine (drop_prob, handle) first, then fold in the dynamic type info.
      size_t field_hash = mgb::hash_pair_combine(
          mgb::hash($_self.drop_prob), mgb::hash($_self.handle));
      return mgb::hash_pair_combine(mgb::hash($_self.dyn_typeinfo()), field_hash);
    }];
    let cmpFunction = [{
      return $0.handle == $1.handle && $0.drop_prob == $1.drop_prob;
    }];
  }
  423. #endif // MGB_OPS