You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

_thor_ops.py 24 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """thor_ops"""
  16. import math
  17. from ..primitive import prim_attr_register, PrimitiveWithInfer
  18. from ...common import dtype as mstype
  19. from ..._checkparam import Validator as validator
  20. from ..._checkparam import Rel
# Public symbols exported by this module: the custom THOR
# (trace-based second-order optimizer) operators.
__all__ = ["CusBatchMatMul",
           "CusCholeskyTrsm",
           "CusFusedAbsMax1",
           "CusImg2Col",
           "CusMatMulCubeDenseLeft",
           "CusMatMulCubeFraczRightMul",
           "CusMatMulCube",
           "CusMatrixCombine",
           "CusTranspose02314",
           "CusMatMulCubeDenseRight",
           "CusMatMulCubeFraczLeftCast",
           ]
  33. def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False):
  34. """
  35. Checks whether an argument is a positive int or tuple with 2 or 4(when allow_four is True) positive int elements.
  36. """
  37. def _raise_message():
  38. raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two "
  39. f"{'or four ' if allow_four else ''}positive int numbers, but got {arg_value}")
  40. def _get_return_value():
  41. if isinstance(arg_value, int):
  42. ret = (1, 1, arg_value, arg_value) if ret_four else (arg_value, arg_value)
  43. elif len(arg_value) == 2:
  44. ret = (1, 1, arg_value[0], arg_value[1]) if ret_four else arg_value
  45. elif len(arg_value) == 4:
  46. if not allow_four:
  47. _raise_message()
  48. ret = arg_value if ret_four else (arg_value[2], arg_value[3])
  49. else:
  50. _raise_message()
  51. return ret
  52. validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
  53. ret_value = _get_return_value()
  54. for item in ret_value:
  55. if isinstance(item, int) and item > 0:
  56. continue
  57. _raise_message()
  58. return ret_value
  59. class CusBatchMatMul(PrimitiveWithInfer):
  60. """
  61. Multiplies matrix `a` by matrix `b` in batch.
  62. The rank of input tensors must be `3`.
  63. Inputs:
  64. - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`.
  65. - **input_y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`. If
  66. `transpose_b` is True.
  67. Outputs:
  68. Tensor, the shape of the output tensor is :math:`(N, D, D)`.
  69. Examples:
  70. >>> input_x = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32)
  71. >>> input_y = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32)
  72. >>> cus_batch_matmul = P.CusBatchMatMul()
  73. >>> output = cus_batch_matmul(input_x, input_y)
  74. """
  75. @prim_attr_register
  76. def __init__(self):
  77. """Initialize CusBatchMatMul"""
  78. self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
  79. from mindspore.ops._op_impl._custom_op.batch_matmul_impl import CusBatchMatMul
  80. def infer_shape(self, data1_shape, data2_shape):
  81. return data1_shape
  82. def infer_dtype(self, data1_dtype, data2_dtype):
  83. return data1_dtype
  84. class CusCholeskyTrsm(PrimitiveWithInfer):
  85. """
  86. L * LT = A.
  87. LT * (LT)^-1 = I.
  88. return (LT)^-1.
  89. Only compute the res of the diag part of input matrix with dim 128.
  90. The rank of input tensors must be `2`.
  91. Inputs:
  92. - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, N)`.
  93. Outputs:
  94. Tensor, the shape of the output tensor is :math:`(N // Split_dim, Split_dim, Split_dim)`.
  95. Examples:
  96. >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float32)
  97. >>> cus_choleskytrsm = P.CusCholeskyTrsm()
  98. >>> output = matmul(input_x)
  99. """
  100. @prim_attr_register
  101. def __init__(self):
  102. """Initialize CusCholeskyTrsm"""
  103. self.init_prim_io_names(inputs=['x1'], outputs=['y'])
  104. from mindspore.ops._op_impl._custom_op.cholesky_trsm_impl import CusCholeskyTrsm
  105. def infer_shape(self, data1_shape):
  106. ll = []
  107. m, _ = data1_shape
  108. if m >= 128:
  109. ll = [m // 128, 128, 128]
  110. else:
  111. ll = [1, 64, 64]
  112. return ll
  113. def infer_dtype(self, data1_dtype):
  114. return data1_dtype
  115. class CusFusedAbsMax1(PrimitiveWithInfer):
  116. """
  117. Computes the abs max of Tensor input.
  118. The rank of input tensors must be `4` or `2`.
  119. Inputs:
  120. - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N0, M0, N1, M1)`
  121. or math:`(32, 64)`.
  122. Outputs:
  123. Tensor, the shape of the output tensor is :math:`(32, 64)` or math:`(1, )`.
  124. Examples:
  125. >>> input_x = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
  126. >>> cus_fused_abs_max1 = P.CusFusedAbsMax1()
  127. >>> output = cus_fused_abs_max1(input_x)
  128. """
  129. @prim_attr_register
  130. def __init__(self, origin_shape=[-1, -1]):
  131. """Initialize CusFusedAbsMax1"""
  132. self.init_prim_io_names(inputs=['x1'], outputs=['y'])
  133. self.origin_shape = origin_shape
  134. from mindspore.ops._op_impl._custom_op.fused_abs_max1_impl import CusFusedAbsMax1
  135. def infer_shape(self, data1_shape):
  136. ll = []
  137. if len(data1_shape) == 2:
  138. ll = [1,]
  139. else:
  140. ll = [32, 64]
  141. return ll
  142. def infer_dtype(self, data1_dtype):
  143. return data1_dtype
  144. class CusImg2Col(PrimitiveWithInfer):
  145. """
  146. Img2cols the feature map and the result in reorganized in NC1HWC0.
  147. Args:
  148. - **strides** (listInt) - the stride of the ops.
  149. - **ksizes** (listInt) - the kernel size of the ops.
  150. Inputs:
  151. - **input_x** (Tensor) - The shape of the tensor is :math:`(N, C, H, W)`.
  152. Outputs:
  153. Tensor, the shape of the output tensor is :math:`(N * H_O * W_O, C1 * K_W * K_H * C0)`.
  154. Examples:
  155. >>> input_x = Tensor(np.ones(shape=[32, 3, 224, 224]), mindspore.float16)
  156. >>> cusimg2col = P.CusImg2Col()
  157. >>> output = cusimg2col(input_x)
  158. """
  159. @prim_attr_register
  160. def __init__(self, ksizes, strides, dilates=(1, 1, 1, 1), mode="NC1HWC0"):
  161. """Initialize CusImg2Col"""
  162. self.init_prim_io_names(inputs=['x1'], outputs=['y'])
  163. self.ksizes = ksizes
  164. self.strides = strides
  165. self.dilates = dilates
  166. self.mode = mode
  167. from mindspore.ops._op_impl._custom_op.img2col_impl import CusImg2Col
  168. def infer_shape(self, data1_shape):
  169. bs, c, h, w = data1_shape
  170. _, stride_h, stride_w, _ = self.strides
  171. _, k_w, k_h, _ = self.ksizes
  172. # assert m == n
  173. c0 = 16
  174. c1 = c // 16
  175. if c1 == 0:
  176. c1 = 1
  177. shape = [bs * int(h // stride_h) * int(w // stride_w), k_w * k_h * c1 * c0]
  178. return shape
  179. def infer_dtype(self, data1_dtype):
  180. return data1_dtype
  181. class CusMatMulCubeDenseLeft(PrimitiveWithInfer):
  182. """
  183. Multiplies matrix `a` by matrix `b`.
  184. The rank of input_x1 must be `4`, the fractal format of the normal matrix.
  185. The rank of input_x2 must be `2`.
  186. Inputs:
  187. - **input_x1** (Tensor) - The first tensor to be multiplied.
  188. The shape of the tensor is :math:`(N0, M0, N1, M1)`.
  189. - **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(M, C)`.
  190. Outputs:
  191. Tensor, the shape of the output tensor is :math:`(N, C)`.
  192. Examples:
  193. >>> input_x = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16)
  194. >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
  195. >>> matmulcubedenseleft = P.CusMatMulCubeDenseLeft()
  196. >>> output = matmulcubedenseleft(input_x, input_y)
  197. """
  198. @prim_attr_register
  199. def __init__(self):
  200. """Initialize CusMatMulCubeDenseLeft"""
  201. self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
  202. from mindspore.ops._op_impl._custom_op.matmul_cube_dense_left_impl import CusMatMulCubeDenseLeft
  203. def infer_shape(self, data1_shape, data2_shape):
  204. return data2_shape
  205. def infer_dtype(self, data1_dtype, data2_dtype):
  206. return mstype.float16
  207. class CusMatMulCubeFraczRightMul(PrimitiveWithInfer):
  208. """
  209. Multiplies matrix `a` by matrix `b` and muls the result by scalar `c`.
  210. The rank of input_x1 tensors must be `2`.
  211. The rank of input_x2 tensors must be `4`.
  212. Inputs:
  213. - **input_x1** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`.
  214. - **input_x2** (Tensor) - The second tensor to be multiplied.
  215. The shape of the tensor is :math:`(C1, M1, C0, M0)`.
  216. - **input_x3** (Tensor) - The third tensor to be multiplied. The shape of the tensor if :math`(1, )`.
  217. Outputs:
  218. Tensor, the shape of the output tensor is :math:`(N, M)`.
  219. Examples:
  220. >>> input_x1 = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
  221. >>> input_x2 = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16)
  222. >>> input_x3 = Tensor(np.ones(shape=[1, ]), mindspore.float16)
  223. >>> cusmatmulfraczrightmul = P.CusMatMulCubeFraczRightMul()
  224. >>> output = cusmatmulfraczrightmul(input_x1, input_x2, input_x3)
  225. """
  226. @prim_attr_register
  227. def __init__(self):
  228. """Initialize CusMatMulCubeFraczRightMul"""
  229. self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y'])
  230. from mindspore.ops._op_impl._custom_op.matmul_cube_fracz_right_mul_impl import CusMatMulCubeFraczRightMul
  231. def infer_shape(self, data1_shape, data2_shape, data3_shape):
  232. return data1_shape
  233. def infer_dtype(self, data1_dtype, data2_dtype, data3_dtype):
  234. return mstype.float32
  235. class CusMatMulCube(PrimitiveWithInfer):
  236. """
  237. Multiplies matrix `a` by matrix `b`.
  238. The rank of input tensors must be `2`.
  239. Args:
  240. transpose_a (bool): If true, `a` is transposed before multiplication. Default: False.
  241. transpose_b (bool): If true, `b` is transposed before multiplication. Default: False.
  242. Inputs:
  243. - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`. If
  244. `transpose_a` is True, its shape must be :math:`(N, C)` after transposing.
  245. - **input_y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(C, M)`. If
  246. `transpose_b` is True, its shape must be :math:`(C, M)` after transpose.
  247. Outputs:
  248. Tensor, the shape of the output tensor is :math:`(N, M)`.
  249. Examples:
  250. >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
  251. >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
  252. >>> cusmatmulcube = P.CusMatMulCube()
  253. >>> output = matmul(input_x, input_y)
  254. """
  255. @prim_attr_register
  256. def __init__(self, transpose_a=False, transpose_b=False):
  257. """Initialize CusMatMulCube"""
  258. self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
  259. self.transpose_a = transpose_a
  260. self.transpose_b = transpose_b
  261. from mindspore.ops._op_impl._custom_op.matmul_cube_impl import CusMatMulCube
  262. def infer_shape(self, data1_shape, data2_shape):
  263. if self.transpose_a:
  264. k1, m = data1_shape
  265. else:
  266. m, k1 = data1_shape
  267. if self.transpose_b:
  268. n, k2 = data2_shape
  269. else:
  270. k2, n = data2_shape
  271. assert k1 == k2
  272. shape = [m, n]
  273. return shape
  274. def infer_dtype(self, data1_dtype, data2_dtype):
  275. return mstype.float32
  276. class CusMatrixCombine(PrimitiveWithInfer):
  277. """
  278. move the batch matrix to result matrix diag part.
  279. The rank of input tensors must be `3`.
  280. Inputs:
  281. - **input_x** (Tensor) - The shape of the tensor is :math:`(N, D, D)`.
  282. Outputs:
  283. Tensor, the shape of the output tensor is :math:`(N * D, N * D)`.
  284. Examples:
  285. >>> input_x = Tensor(np.ones(shape=[2, 128, 128]), mindspore.float32)
  286. >>> cusmatrixcombine = P.CusMatrixCombine()
  287. >>> output = cusmatrixcombine(input_x)
  288. """
  289. @prim_attr_register
  290. def __init__(self):
  291. """Initialize CusMatrixCombine"""
  292. self.init_prim_io_names(inputs=['x'], outputs=['y'])
  293. from mindspore.ops._op_impl._custom_op.matrix_combine_impl import CusMatrixCombine
  294. def infer_shape(self, data_shape):
  295. a, b, c = data_shape
  296. shape = [a * b, a * c]
  297. return shape
  298. def infer_dtype(self, data_dtype):
  299. return data_dtype
class CusTranspose02314(PrimitiveWithInfer):
    """
    Permute input tensor with perm (0, 2, 3, 1, 4)
    The rank of input tensors must be `5` with format NC1HWC0.
    Inputs:
        - **input_x** (Tensor) - The shape of the tensor is :math:`(N, C1, H, W, C0)`.
    Outputs:
        Tensor, the shape of the output tensor is :math:`(N, H, W, C1, C0)`.
    Examples:
        >>> input_x = Tensor(np.ones(shape=[32, 1, 224, 224, 16]), mindspore.float16)
        >>> custranspose02314 = P.CusTranspose02314()
        >>> output = custranspose02314(input_x)
    """
    @prim_attr_register
    def __init__(self):
        """Initialize CusTranspose02314"""
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
        # Imported only for its side effect of registering the custom op
        # implementation; the imported name itself is unused.
        from mindspore.ops._op_impl._custom_op.transpose02314_impl import CusTranspose02314
    def get_bprop(self):
        # NOTE(review): `C` is not imported anywhere in this file's visible
        # scope (presumably mindspore.ops.composite or functional); as written
        # this bprop would raise NameError when executed — confirm the import.
        def bprop(x, out, dout):
            # Gradient is defined as all zeros for the input.
            return (C.zeros_like(x),)
        return bprop
    def infer_shape(self, data1_shape):
        # NOTE(review): the class docstring describes a rank-5 NC1HWC0 input
        # and a rank-5 output, but this method asserts a rank-4 (N, C, H, W)
        # input and returns a rank-2 shape — confirm which is intended.
        assert len(data1_shape) == 4
        n, c, h, w = data1_shape
        # Channels are packed into C1 blocks of C0 = 16.
        c0 = 16
        c1 = c // 16
        shape = (n * h * w, c1 * c0)
        return shape
    def infer_dtype(self, data1_dtype):
        # Output dtype matches the input.
        return data1_dtype
class CusMatMulCubeDenseRight(PrimitiveWithInfer):
    """
    Multiplies matrix `a` by matrix `b`.
    The rank of input_x1 tensor must be `2`.
    The rank of input_x2 tensor must be `4`.
    Inputs:
        - **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, C)`.
        - **input_y** (Tensor) - The second tensor to be multiplied.
          The shape of the tensor is :math:`(C1, M1, M0, C0)`.
    Outputs:
        Tensor, the shape of the output tensor is :math:`(N, M)`.
    Examples:
        >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
        >>> input_y = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16)
        >>> cusmatmulcubedenseright = P.CusMatMulCubeDenseRight()
        >>> output = cusmatmulcubedenseright(input_x, input_y)
    """
    @prim_attr_register
    def __init__(self):
        """Initialize CusMatMulCubeDenseRight"""
        # NOTE(review): three inputs (x1, x2, x3) are declared here and in
        # infer_shape/infer_dtype, but the docstring documents only two —
        # presumably x3 is an auxiliary input; confirm against the kernel impl.
        self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y'])
        # Imported only for its side effect of registering the custom op
        # implementation; the imported name itself is unused.
        from mindspore.ops._op_impl._custom_op.matmul_cube_dense_right_impl import CusMatMulCubeDenseRight
    def infer_shape(self, data1_shape, data2_shape, data3_shape):
        # The result takes the shape of the first (dense) operand.
        return data1_shape
    def infer_dtype(self, data1_dtype, data2_dtype, data3_dtype):
        # Output dtype is fixed to float32 regardless of the input dtypes.
        return mstype.float32
  357. class CusMatMulCubeFraczLeftCast(PrimitiveWithInfer):
  358. """
  359. Multiplies matrix `a` by matrix `b`.
  360. The rank of input_x1 tensor must be `4`.
  361. The rank of input_x2 tensors must be `2`.
  362. Inputs:
  363. - **input_x1** (Tensor) - The first tensor to be multiplied.
  364. The shape of the tensor is :math:`(C1, N1, N0, C0)`.
  365. - **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(C, M)`.
  366. Outputs:
  367. Tensor, the shape of the output tensor is :math:`(N, M)`.
  368. Examples:
  369. >>> input_x = Tensor(np.ones(shape=[16, 16, 16, 16]), mindspore.float16)
  370. >>> input_y = Tensor(np.ones(shape=[256, 256]), mindspore.float16)
  371. >>> cusmatmulcubefraczleftcast = P.CusMatMulCubeFraczLeftCast()
  372. >>> output = cusmatmulcubefraczleftcast(input_x, input_y)
  373. """
  374. @prim_attr_register
  375. def __init__(self):
  376. """Initialize CusMatMulCubeFraczLeftCast"""
  377. self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
  378. from mindspore.ops._op_impl._custom_op.matmul_cube_fracz_left_cast_impl import CusMatMulCubeFraczLeftCast
  379. def infer_shape(self, data1_shape, data2_shape):
  380. return data2_shape
  381. def infer_dtype(self, data1_dtype, data2_dtype):
  382. return mstype.float16
class Im2Col(PrimitiveWithInfer):
    """
    Extracts image patches from an image.

    The rank of input_x1 must be `4`, data_format is "NCHW".
    Inputs:
        - **input_x1** (Tensor) - The feature map.
          The shape of the tensor is :math:`(N, C, H, W)`.
    Outputs:
        Tensor.
    Examples:
        >>> input_x = Tensor(np.random.rand(32, 3, 224, 224).astype(np.float16))
        >>> img2col = P.Im2Col(kernel_size=7, pad_mode="pad", pad=3, stride=2)
        >>> output = img2col(input_x)
    """
    @prim_attr_register
    def __init__(self,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 stride=1,
                 dilation=1):
        """Initialize Im2Col"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        # Normalize kernel_size to a 2-tuple, stride/dilation to 4-tuples.
        self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
        self.add_prim_attr('kernel_size', self.kernel_size)
        self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
        self.add_prim_attr('stride', self.stride)
        self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
        self.add_prim_attr('dilation', self.dilation)
        validator.check_value_type('pad', pad, (int,), self.name)
        self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
        self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
        if self.pad_mode == 'pad':
            # Explicit padding is only meaningful in 'pad' mode.
            validator.check_non_negative_int(self.pad, 'pad', self.name)
        self.add_prim_attr('data_format', "NCHW")
    def infer_shape(self, x_shape):
        """Infer the output shape from the (N, C, H, W) input shape."""
        validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
        kernel_size_h = self.kernel_size[0]
        kernel_size_w = self.kernel_size[1]
        # stride/dilation are stored as 4-tuples; indices 2 and 3 are H and W.
        stride_h = self.stride[2]
        stride_w = self.stride[3]
        dilation_h = self.dilation[2]
        dilation_w = self.dilation[3]
        if self.pad_mode == "valid":
            # No padding: output covers only full kernel placements.
            h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
            w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
            pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
        elif self.pad_mode == "same":
            # Pad just enough so output size is ceil(input / stride);
            # the extra row/column of padding goes to the bottom/right.
            h_out = math.ceil(x_shape[2] / stride_h)
            w_out = math.ceil(x_shape[3] / stride_w)
            pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2])
            pad_top = math.floor(pad_needed_h / 2)
            pad_bottom = pad_needed_h - pad_top
            pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3])
            pad_left = math.floor(pad_needed_w / 2)
            pad_right = pad_needed_w - pad_left
        elif self.pad_mode == 'pad':
            # Symmetric explicit padding on all four sides.
            pad_top, pad_bottom, pad_left, pad_right = self.pad, self.pad, self.pad, self.pad
            h_out = 1 + (x_shape[2] + 2 * self.pad - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h
            w_out = 1 + (x_shape[3] + 2 * self.pad - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w
            h_out = math.floor(h_out)
            w_out = math.floor(w_out)
        self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
        self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
        batch_size = x_shape[0]
        channel = x_shape[1]
        k_h = kernel_size_h
        k_w = kernel_size_w
        out_shape = [channel, k_h, k_w, batch_size, h_out, w_out]
        return out_shape
    def infer_dtype(self, x_dtype):
        """Validate the input dtype (float16/float32) and pass it through."""
        args = {'x': x_dtype}
        valid_types = [mstype.float16, mstype.float32]
        validator.check_tensor_type_same(args, valid_types, self.name)
        return x_dtype
  458. class UpdateThorGradient(PrimitiveWithInfer):
  459. """
  460. Updates Thor Gradient with Approximate Fisher info matrix(for GPU backend).
  461. The rank of input_x1 must be `3`, which indicates the A matrix.
  462. The rank of input_x2 must be `2`, which indicates the 1st-order gradient.
  463. The rank of input_x3 must be `4`, which indicates the G matrix.
  464. Inputs:
  465. - **input_x1** (Tensor) - The first input is the diag part of the cov matrix of feature map.
  466. Supported dtype [float32].
  467. - **input_x2** (Tensor) - The second input is the corresponding 1st-order grad. Supported dtype [float32].
  468. - **input_x3** (Tensor) - The third input is the diag part of the cov matrix of dout. Supported dtype [float32].
  469. Outputs:
  470. Tensor, the shape is the same as the shape of input_x2, it will be used to update the weights.
  471. Examples:
  472. >>> input_x1 = Tensor(np.random.rand(16, 128, 128).astype(np.float32))
  473. >>> input_x2 = Tensor(np.random.rand(2048, 1024).astype(np.float32))
  474. >>> temp_x3 = np.random.rand(8, 128, 128).astype(np.float32)
  475. >>> input_x3 = np.zeros(16,8,128,128).astype(np.float32)
  476. >>> for i in range(16):
  477. >>> input_x3[i,:,:,:] = temp_x3
  478. >>> input_x3 = Tensor(input_x3)
  479. >>> update_thor_gradient = P.UpdateThorGradient(split_dim=128)
  480. >>> output = update_thor_gradient(input_x1, input_x2, input_x3)
  481. """
  482. @prim_attr_register
  483. def __init__(self, split_dim=0):
  484. """Initialize UpdateThorGradient"""
  485. self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y'])
  486. self.split_dim = split_dim
  487. self.add_prim_attr('split_dim', self.split_dim)
  488. def infer_shape(self, x1_shape, x2_shape, x3_shape):
  489. return x2_shape
  490. def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype):
  491. validator.check_tensor_type_same({'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype},
  492. [mstype.float32], self.name)
  493. return x2_dtype
  494. class Cholesky(PrimitiveWithInfer):
  495. """
  496. Inner API for resnet50 THOR GPU backend
  497. """
  498. @prim_attr_register
  499. def __init__(self, split_dim=0):
  500. self.init_prim_io_names(inputs=['x1'], outputs=['y'])
  501. self.split_dim = split_dim
  502. self.add_prim_attr('split_dim', self.split_dim)
  503. def infer_shape(self, x1_shape):
  504. if self.split_dim != 0:
  505. assert len(x1_shape) == 2
  506. height = x1_shape[0]
  507. width = x1_shape[1]
  508. assert height == width
  509. if height <= self.split_dim:
  510. out_shape = [1, height, width]
  511. else:
  512. batch = height // self.split_dim
  513. if height != batch * self.split_dim:
  514. batch += 1
  515. out_shape = [batch, self.split_dim, self.split_dim]
  516. else:
  517. out_shape = x1_shape
  518. return out_shape
  519. def infer_dtype(self, x1_dtype):
  520. validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name)
  521. return x1_dtype