You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

_inner_ops.py 21 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Inner operators."""
  16. from ..._checkparam import Rel
  17. from ..._checkparam import Validator as validator
  18. from ... import context
  19. from ...common import dtype as mstype
  20. from ..primitive import PrimitiveWithInfer, prim_attr_register
  21. from ..operations.math_ops import _infer_shape_reduce
  22. class ExtractImagePatches(PrimitiveWithInfer):
  23. """
  24. Extracts patches from images.
  25. The input tensor must be a 4-D tensor and the data format is NHWC.
  26. Args:
  27. ksizes (Union[tuple[int], list[int]]): The size of sliding window, must be a tuple or a list of integers,
  28. and the format is [1, ksize_row, ksize_col, 1].
  29. strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
  30. must be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
  31. rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dimension
  32. pixel positions, must be a tuple or a list of integers, and the format is [1, rate_row, rate_col, 1].
  33. padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
  34. not case sensitive. Default: "valid".
  35. - same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
  36. - valid: Means that the taken patch area must be completely covered in the original image.
  37. Inputs:
  38. - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
  39. data type is number.
  40. Outputs:
  41. Tensor, a 4-D tensor whose data type is same as 'input_x',
  42. and the shape is [out_batch, out_row, out_col, out_depth], the out_batch is the same as the in_batch.
  43. """
  44. @prim_attr_register
  45. def __init__(self, ksizes, strides, rates, padding="valid"):
  46. """init"""
  47. def _check_tuple_or_list(arg_name, arg_val, prim_name):
  48. validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
  49. if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
  50. raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
  51. f"{arg_name}_col, 1], but got {arg_val}.")
  52. if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
  53. raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an "
  54. f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
  55. f"is {arg_val[2]}")
  56. _check_tuple_or_list("ksize", ksizes, self.name)
  57. _check_tuple_or_list("stride", strides, self.name)
  58. _check_tuple_or_list("rate", rates, self.name)
  59. self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
  60. self.add_prim_attr("padding", self.padding)
  61. self.add_prim_attr("io_format", "NHWC")
  62. self.is_ge = context.get_context("enable_ge")
  63. def infer_shape(self, input_x):
  64. """infer shape"""
  65. in_batch, in_depth, in_row, in_col = input_x
  66. if self.is_ge:
  67. in_batch, in_row, in_col, in_depth = input_x
  68. _, ksize_row, ksize_col, _ = self.ksizes
  69. _, stride_row, stride_col, _ = self.strides
  70. _, rate_row, rate_col, _ = self.rates
  71. if len(input_x) != 4:
  72. raise ValueError("The `input_x` should be a 4-D tensor, "
  73. f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
  74. out_batch = in_batch
  75. out_depth = ksize_row * ksize_col * in_depth
  76. if self.padding == "VALID":
  77. out_row = \
  78. (in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1
  79. out_col = \
  80. (in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1
  81. else:
  82. out_row = (in_row - 1) // stride_row + 1
  83. out_col = (in_col - 1) // stride_col + 1
  84. out_shape = [out_batch, out_depth, out_row, out_col]
  85. if self.is_ge:
  86. out_shape = [out_batch, out_row, out_col, out_depth]
  87. return out_shape
  88. def infer_dtype(self, input_x):
  89. """infer dtype"""
  90. validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
  91. return input_x
class Range(PrimitiveWithInfer):
    r"""
    Creates a sequence of numbers.

    Set `input_x` as :math:`x_i` for each element, `output` as follows:

    .. math::
        \text{output}(x_i) = x_i * \text{delta} + \text{start}

    Args:
        start (float): If `limit` is `None`, the value acts as limit in the range and first entry
            defaults to `0`. Otherwise, it acts as first entry in the range.
        limit (float): Acts as upper limit of sequence. If `None`, defaults to the value of `start`
            while set the first entry of the range to `0`. It can not be equal to `start`.
        delta (float): Increment of the range. It can not be equal to zero. Default: 1.0.

    Inputs:
        - **input_x** (Tensor) - The assistant data. A `1-D` tensor of type float32 or int32.

    Outputs:
        Tensor, has the same shape and dtype as `input_x`.

    Examples:
        >>> range = P.Range(1.0, 8.0, 2.0)
        >>> x = Tensor(np.array([1, 2, 3, 2]), mindspore.int32)
        >>> range(x)
        [3, 5, 7, 5]
    """

    @prim_attr_register
    def __init__(self, start, limit=None, delta=1.0):
        """Validate and record the (start, limit, delta) triple of the range."""
        self.init_prim_io_names(inputs=['x'], outputs=['y'])
        self.delta = validator.check_value_type("delta", delta, [float], self.name)
        validator.check_value_type("start", start, [float], self.name)
        if limit is None:
            # Single-argument form: `start` is reinterpreted as the limit and
            # the range begins at 0.
            self.start = 0.0
            self.limit = start
            self.add_prim_attr("start", self.start)
            self.add_prim_attr("limit", self.limit)
        else:
            validator.check_value_type("limit", limit, [float], self.name)
            # NOTE(review): `self.start` / `self.limit` are never assigned in
            # this branch — presumably @prim_attr_register sets them from the
            # __init__ arguments before this line runs; confirm against the
            # decorator's implementation.
            validator.check('start', self.start, 'limit', self.limit, Rel.NE, self.name)
        # A zero step would never terminate / produce an empty range.
        if self.delta == 0.0:
            raise ValueError("The input of `delta` can not be equal to zero.")
        # Direction of the step must agree with the ordering of start/limit.
        if self.delta > 0.0 and self.start > self.limit:
            raise ValueError(f"Limit should be greater than start when delta:{self.delta} is more than zero, "
                             f"but got start:{self.start}, limit:{self.limit}")
        if self.delta < 0.0 and self.start < self.limit:
            raise ValueError(f"Start should be greater than limit when delta:{self.delta} is less than zero, "
                             f"but got start:{self.start}, limit:{self.limit}")

    def infer_shape(self, x_shape):
        """Output shape equals the assistant tensor's shape."""
        return x_shape

    def infer_dtype(self, x_dtype):
        """Assistant tensor must be float32 or int32; dtype passes through."""
        validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name)
        return x_dtype
  140. class Quant(PrimitiveWithInfer):
  141. r"""
  142. Returns the quantized value of input_x.
  143. If `sqrt_mode` is False:
  144. .. math::
  145. y = round(scale * x + offset)
  146. If `sqrt_mode` is True:
  147. .. math::
  148. y = round(scale * x * scale + offset)
  149. Note:
  150. This operation only support Ascend 310 inference environment.
  151. Args:
  152. scale (float) : Specifies the scaling ratio.
  153. offset (float): Specifies the offset.
  154. sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: False.
  155. round_mode (str): Specifies the way to round. Must be one of ["Round", "Floor", "Ceil", "Trunc"].
  156. Default: "Round".
  157. Inputs:
  158. - **input_x** (Tensor) : Input tensor. Its data type must be mindspore.float16 or mindspore.float32.
  159. Outputs:
  160. - Tensor: The quantized output tensor of type mindspore.int8.
  161. Examples:
  162. >>> input_x = Tensor([100.0, 150.0], mstype.float32)
  163. >>> quant = P.Quant(80.0, 0.0, False, "Round")
  164. >>> y = quant(input_x)
  165. """
  166. @prim_attr_register
  167. def __init__(self, scale, offset, sqrt_mode=False, round_mode="Round"):
  168. self.scale = validator.check_value_type("scale", scale, [float], self.name)
  169. self.offset = validator.check_value_type("offset", offset, [float], self.name)
  170. self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
  171. self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
  172. "round_mode", self.name)
  173. self.add_prim_attr("io_format", "ND")
  174. def infer_shape(self, x_shape):
  175. return x_shape
  176. def infer_dtype(self, x_type):
  177. validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
  178. validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
  179. return mstype.int8
  180. class Dequant(PrimitiveWithInfer):
  181. r"""
  182. Returns the dequantized value of input_x.
  183. This operation will do ReLU to the dequantized value if `relu_flag` is True.
  184. If `sqrt_mode` is False:
  185. .. math::
  186. y = x * deq\_scale
  187. If `sqrt_mode` is True:
  188. .. math::
  189. y = x * deq\_scale * deq\_scale
  190. Note:
  191. This operation only support Ascend 310 inference environment.
  192. Args:
  193. sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: False.
  194. relu_flag (bool): Specifies whether to perform ReLU. Default: False.
  195. Inputs:
  196. - **input_x** (Tensor) : Input tensor. Must be mindspore.int32.
  197. - **deq_scale** (Tensor) : Specifies the scaling ratio.
  198. Data type must be mindspore.float16 or mindspore.uint64
  199. Outputs:
  200. - Tensor: The quantized output tensor of type mindspore.float16.
  201. Examples:
  202. >>> input_x = Tensor([100.0, 150.0], mstype.float32)
  203. >>> dequant = P.Dequant(False, False)
  204. >>> y = dequant(input_x)
  205. """
  206. @prim_attr_register
  207. def __init__(self, sqrt_mode=False, relu_flag=False):
  208. self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
  209. self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
  210. self.add_prim_attr("dtype", mstype.float16)
  211. self.add_prim_attr("io_format", "ND")
  212. def infer_shape(self, x_shape, deq_scale_shape):
  213. return x_shape
  214. def infer_dtype(self, x_type, deq_scale_type):
  215. validator.check_subclass("x", x_type, mstype.tensor, self.name)
  216. validator.check_type_name("x", x_type, [mstype.int32], self.name)
  217. validator.check_type_name("deq_scale", deq_scale_type, [mstype.float16, mstype.uint64], self.name)
  218. return mstype.float16
  219. class LinSpace(PrimitiveWithInfer):
  220. r"""
  221. Generates values in an interval. And return the corresponding interpolation accroding to assist.
  222. Inputs:
  223. - **assist** (Tensor[float32]) - The assist value, With shape of 0-D or 1-D.
  224. - **start** (Tensor[float32]) - The start of interval, With shape of 0-D.
  225. - **stop** (Tensor[float32]) - The end of interval, With shape of 0-D.
  226. - **num** (Tensor[int32]) - ticks number in the interval, the ticks include start and stop value.
  227. With shape of 0-D.
  228. Outputs:
  229. Tensor, has the same shape as `assist`.
  230. Examples:
  231. >>> linspace = P.LinSpace()
  232. >>> assist = Tensor([5, 5.5], mindspore.float32)
  233. >>> start = Tensor(1, mindspore.float32)
  234. >>> stop = Tensor(10, mindspore.float32)
  235. >>> num = Tensor(5, mindspore.int32)
  236. >>> output = linspace(assist, start, stop, num)
  237. [12.25, 13.375]
  238. """
  239. @prim_attr_register
  240. def __init__(self):
  241. pass
  242. def infer_shape(self, assist, start, stop, num):
  243. return assist
  244. def infer_dtype(self, assist, start, stop, num):
  245. args = {"num": num}
  246. validator.check_tensor_type_same(args, (mstype.int32,), self.name)
  247. args = {"assist": assist, "start": start, "stop": stop}
  248. validator.check_tensor_type_same(args, (mstype.float32,), self.name)
  249. return assist
  250. class MatrixDiag(PrimitiveWithInfer):
  251. """
  252. Returns a batched diagonal tensor with a given batched diagonal values.
  253. Inputs:
  254. - **x** (Tensor) - A tensor which to be element-wise multi by `assist`. It can be one of the following data
  255. types: float32, float16, int32, int8, and uint8.
  256. - **assist** (Tensor) - A eye tensor of the same type as `x`. It's rank must greater than or equal to 2 and
  257. it's last dimension must equal to the second to last dimension.
  258. Outputs:
  259. Tensor, has the same type and shape as input `assist`.
  260. Examples:
  261. >>> x = Tensor(np.array([1, -1]), mstype.float32)
  262. >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
  263. >>> matrix_diag = P.MatrixDiag()
  264. >>> result = matrix_diag(x, assist)
  265. [[[-12. 11.]
  266. [-10. 9.]]
  267. [[ -8. 7.]
  268. [ -6. 5.]]
  269. [[ -4. 3.]
  270. [ -2. 1.]]]
  271. """
  272. @prim_attr_register
  273. def __init__(self):
  274. """Initialize MatrixDiag"""
  275. def infer_dtype(self, x_dtype, assist_dtype):
  276. valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
  277. args = {"x": x_dtype, "assist": assist_dtype}
  278. validator.check_tensor_type_same(args, valid_type, self.name)
  279. return x_dtype
  280. def infer_shape(self, x_shape, assist_shape):
  281. validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name)
  282. validator.check('rank of x', len(x_shape)+1,
  283. 'rank of assist', len(assist_shape), Rel.LE, self.name)
  284. validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
  285. assist_shape[-1], Rel.EQ, self.name)
  286. r_end_dim = -len(x_shape)
  287. r_idx = -1
  288. while r_idx >= r_end_dim:
  289. if x_shape[r_idx] != 1:
  290. validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" %
  291. assist_shape[r_idx-1], assist_shape[r_idx-1], Rel.EQ, self.name)
  292. r_idx = r_idx - 1
  293. return assist_shape
  294. class MatrixDiagPart(PrimitiveWithInfer):
  295. r"""
  296. Returns the batched diagonal part of a batched tensor.
  297. Inputs:
  298. - **x** (Tensor) - The batched tensor. It can be one of the following data types:
  299. float32, float16, int32, int8, uint8.
  300. - **assist** (Tensor) - A eye tensor of the same type as `x`. With shape same as `x`.
  301. Outputs:
  302. Tensor, data type same as input `x`. The shape must be x.shape[:-2] + [min(x.shape[-2:])].
  303. Examples:
  304. >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
  305. >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
  306. >>> matrix_diag_part = P.MatrixDiagPart()
  307. >>> result = matrix_diag_part(x, assist)
  308. [[12., -9.], [8., -5.], [4., -1.]]
  309. """
  310. @prim_attr_register
  311. def __init__(self):
  312. """Initialize MatrixDiagPart"""
  313. def infer_dtype(self, x_dtype, assist_dtype):
  314. valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
  315. args = {"x": x_dtype, "assist": assist_dtype}
  316. validator.check_tensor_type_same(args, valid_type, self.name)
  317. return x_dtype
  318. def infer_shape(self, x_shape, assist_shape):
  319. validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
  320. validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)
  321. if assist_shape[-2] < assist_shape[-1]:
  322. out_shape = assist_shape[:-1]
  323. else:
  324. out_shape = assist_shape[:-2] + assist_shape[-1:]
  325. return out_shape
  326. class MatrixSetDiag(PrimitiveWithInfer):
  327. r"""
  328. Modifies the batched diagonal part of a batched tensor.
  329. Inputs:
  330. - **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types:
  331. float32, float16, int32, int8, uint8.
  332. - **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k, where k >= 1.
  333. - **assist** (Tensor) - A eye tensor of the same type as `x`. With shape same as `x`.
  334. Outputs:
  335. Tensor, data type same as input `x`. The shape same as `x`.
  336. Examples:
  337. >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
  338. >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
  339. >>> matrix_set_diag = P.MatrixSetDiag()
  340. >>> result = matrix_set_diag(x, diagonal)
  341. [[[-1, 0], [0, 2]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]]
  342. """
  343. @prim_attr_register
  344. def __init__(self):
  345. """Initialize MatrixSetDiag"""
  346. def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype):
  347. valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
  348. args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype}
  349. validator.check_tensor_type_same(args, valid_type, self.name)
  350. return x_dtype
  351. def infer_shape(self, x_shape, diagonal_shape, assist_shape):
  352. validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
  353. validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)
  354. if x_shape[-2] < x_shape[-1]:
  355. validator.check("diagnoal shape", diagonal_shape, "x shape excluding the last dimension",
  356. x_shape[:-1], Rel.EQ, self.name)
  357. else:
  358. validator.check("diagonal shape", diagonal_shape, "x shape excluding the second last dimension",
  359. x_shape[:-2] + x_shape[-1:], Rel.EQ, self.name)
  360. return assist_shape
  361. class ConfusionMulGrad(PrimitiveWithInfer):
  362. """
  363. `output0` is the dot product result of input0 and input1.
  364. `output1` is the dot product result of input0 and input1, then apply the reducesum operation on it.
  365. Args:
  366. axis (Union[int, tuple[int], list[int]]): The dimensions to reduce.
  367. Default:(), reduce all dimensions. Only constant value is allowed.
  368. keep_dims (bool):
  369. - If true, keep these reduced dimensions and the length as 1.
  370. - If false, don't keep these dimensions. Default:False.
  371. Inputs:
  372. - **input_0** (Tensor) - The input Tensor.
  373. - **input_1** (Tensor) - The input Tensor.
  374. - **input_2** (Tensor) - The input Tensor.
  375. Outputs:
  376. - **output_0** (Tensor) - The same shape as `input0`.
  377. - **output_1** (Tensor)
  378. - If axis is (), and keep_dims is false, the output is a 0-D array representing
  379. the sum of all elements in the input array.
  380. - If axis is int, set as 2, and keep_dims is false,
  381. the shape of output is :math:`(x_1,x_3,...,x_R)`.
  382. - If axis is tuple(int), set as (2,3), and keep_dims is false,
  383. the shape of output is :math:`(x_1,x_4,...x_R)`.
  384. Examples:
  385. >>> confusion_mul_grad = P.ConfusionMulGrad()
  386. >>> input_0 = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
  387. >>> input_1 = Tensor(np.random.randint(0, 4, (2, 3)), mindspore.float32)
  388. >>> input_2 = Tensor(np.random.randint(-4, 0, (2, 3)), mindspore.float32)
  389. >>> output_0, output_1 = confusion_mul_grad(input_0, input_1, input_2)
  390. output_0:
  391. [[ 3. 1. 0.]
  392. [-6. 2. -2.]]
  393. output_1:
  394. -3.0
  395. """
  396. @prim_attr_register
  397. def __init__(self, axis=(), keep_dims=False):
  398. self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"])
  399. self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name)
  400. self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
  401. def infer_shape(self, input0_shape, input1_shape, input2_shape):
  402. outshape0 = input0_shape
  403. outshape1 = _infer_shape_reduce(input1_shape, self.axis_, self.keep_dims_, self.name)
  404. return outshape0, outshape1
  405. def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype):
  406. validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor, self.name)
  407. validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name)
  408. validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name)
  409. return input0_dtype, input1_dtype