You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; can include dashes ('-'); and can be up to 35 characters long.

_inner_ops.py 33 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Inner operators."""
  16. from ..._checkparam import Rel
  17. from ..._checkparam import Validator as validator
  18. from ...common import dtype as mstype
  19. from ..._c_expression import signature_rw as sig_rw
  20. from ..._c_expression import signature_kind as sig_kind
  21. from ..._c_expression import signature_dtype as sig_dtype
  22. from ..primitive import PrimitiveWithInfer, prim_attr_register
  23. class ExtractImagePatches(PrimitiveWithInfer):
  24. """
  25. Extract patches from images.
  26. The input tensor must be a 4-D tensor and the data format is NHWC.
  27. Args:
  28. ksizes (Union[tuple[int], list[int]]): The size of sliding window, should be a tuple or list of int,
  29. and the format is [1, ksize_row, ksize_col, 1].
  30. strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
  31. should be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
  32. rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dim
  33. pixel positions, should be a tuple or list of int, and the format is [1, rate_row, rate_col, 1].
  34. padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
  35. not case sensitive. Default: "valid".
  36. - same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
  37. - valid: Means that the patch area taken must be completely contained in the original image.
  38. Inputs:
  39. - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
  40. data type is number.
  41. Outputs:
  42. Tensor, a 4-D tensor whose data type is same as 'input_x',
  43. and the shape is [out_batch, out_row, out_col, out_depth], the out_batch is same as the in_batch.
  44. """
  45. @prim_attr_register
  46. def __init__(self, ksizes, strides, rates, padding="valid"):
  47. """init"""
  48. def _check_tuple_or_list(arg_name, arg_val, prim_name):
  49. validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
  50. if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
  51. raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
  52. f"{arg_name}_col, 1], but got {arg_val}.")
  53. if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
  54. raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an "
  55. f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
  56. f"is {arg_val[2]}")
  57. _check_tuple_or_list("ksize", ksizes, self.name)
  58. _check_tuple_or_list("stride", strides, self.name)
  59. _check_tuple_or_list("rate", rates, self.name)
  60. self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
  61. self.add_prim_attr("padding", self.padding)
  62. def infer_shape(self, input_x):
  63. """infer shape"""
  64. in_batch, in_row, in_col, in_depth = input_x
  65. _, ksize_row, ksize_col, _ = self.ksizes
  66. _, stride_row, stride_col, _ = self.strides
  67. _, rate_row, rate_col, _ = self.rates
  68. if len(input_x) != 4:
  69. raise ValueError("The `input_x` should be a 4-D tensor, "
  70. f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
  71. out_batch = in_batch
  72. out_depth = ksize_row * ksize_col * in_depth
  73. if self.padding == "VALID":
  74. out_row = \
  75. (in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1
  76. out_col = \
  77. (in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1
  78. else:
  79. out_row = (in_row - 1) // stride_row + 1
  80. out_col = (in_col - 1) // stride_col + 1
  81. out_shape = [out_batch, out_row, out_col, out_depth]
  82. return out_shape
  83. def infer_dtype(self, input_x):
  84. """infer dtype"""
  85. validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
  86. return input_x
  87. class Range(PrimitiveWithInfer):
  88. r"""
  89. Creates a sequence of numbers.
  90. Set `input_x` as :math:`x_i` for each element, `output` as follows:
  91. .. math::
  92. \text{output}(x_i) = x_i * \text{delta} + \text{start}
  93. Args:
  94. start (float): If `limit` is `None`, the value acts as limit in the range and first entry
  95. defaults to `0`. Otherwise, it acts as first entry in the range.
  96. limit (float): Acts as upper limit of sequence. If `None`, defaults to the value of `start`
  97. while set the first entry of the range to `0`. It can not be equal to `start`.
  98. delta (float): Increment of the range. It can not be equal to zero. Default: 1.0.
  99. Inputs:
  100. - **input_x** (Tensor) - The assistant data. A `1-D` tensor of type float32 or int32.
  101. Outputs:
  102. Tensor, has the same shape and dtype as `input_x`.
  103. Examples:
  104. >>> range = P.Range(1.0, 8.0, 2.0)
  105. >>> x = Tensor(np.array([1, 2, 3, 2]), mindspore.int32)
  106. >>> range(x)
  107. [3, 5, 7, 5]
  108. """
  109. @prim_attr_register
  110. def __init__(self, start, limit=None, delta=1.0):
  111. self.init_prim_io_names(inputs=['x'], outputs=['y'])
  112. self.delta = validator.check_value_type("delta", delta, [float], self.name)
  113. validator.check_value_type("start", start, [float], self.name)
  114. if limit is None:
  115. self.start = 0.0
  116. self.limit = start
  117. self.add_prim_attr("start", self.start)
  118. self.add_prim_attr("limit", self.limit)
  119. else:
  120. validator.check_value_type("limit", limit, [float], self.name)
  121. validator.check('start', self.start, 'limit', self.limit, Rel.NE, self.name)
  122. if self.delta == 0.0:
  123. raise ValueError("The input of `delta` can not be equal to zero.")
  124. if self.delta > 0.0 and self.start > self.limit:
  125. raise ValueError(f"Limit should be greater than start when delta:{self.delta} is more than zero, "
  126. f"but got start:{self.start}, limit:{self.limit}")
  127. if self.delta < 0.0 and self.start < self.limit:
  128. raise ValueError(f"Start should be greater than limit when delta:{self.delta} is less than zero, "
  129. f"but got start:{self.start}, limit:{self.limit}")
  130. def infer_shape(self, x_shape):
  131. return x_shape
  132. def infer_dtype(self, x_dtype):
  133. validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name)
  134. return x_dtype
  135. class AscendQuant(PrimitiveWithInfer):
  136. r"""
  137. Returns the quantized value of input_x.
  138. If `sqrt_mode` is False:
  139. .. math::
  140. y = round(scale * x + offset)
  141. If `sqrt_mode` is True:
  142. .. math::
  143. y = round(scale * x * scale + offset)
  144. Note:
  145. This operation only support Ascend 310 inference environment.
  146. Args:
  147. scale (float) : Specifies the scaling ratio.
  148. offset (float): Specifies the offset.
  149. sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: False.
  150. round_mode (str): Specifies the way to round. Should be one of ["Round", "Floor", "Ceil", "Trunc"].
  151. Default: "Round".
  152. Inputs:
  153. - **input_x** (Tensor) : Input tensor. Its data type should be mindspore.float16 or mindspore.float32.
  154. Outputs:
  155. - Tensor: The quantized output tensor of type mindspore.int8.
  156. Examples:
  157. >>> input_x = Tensor([100.0, 150.0], mstype.float32)
  158. >>> quant = P.AscendQuant(80.0, 0.0, False, "Round")
  159. >>> y = quant(input_x)
  160. """
  161. @prim_attr_register
  162. def __init__(self, scale, offset, sqrt_mode=False, round_mode="Round"):
  163. self.scale = validator.check_value_type("scale", scale, [float], self.name)
  164. self.offset = validator.check_value_type("offset", offset, [float], self.name)
  165. self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
  166. self.round_mode = validator.check_string("round_mode", round_mode,
  167. ["Round", "Floor", "Ceil", "Trunc"], self.name)
  168. def infer_shape(self, x_shape):
  169. return x_shape
  170. def infer_dtype(self, x_type):
  171. validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
  172. validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
  173. return mstype.int8
  174. class AscendDequant(PrimitiveWithInfer):
  175. r"""
  176. Returns the dequantized value of input_x.
  177. This operation will do ReLU to the dequantized value if `relu_flag` is True.
  178. If `sqrt_mode` is False:
  179. .. math::
  180. y = x * deq\_scale
  181. If `sqrt_mode` is True:
  182. .. math::
  183. y = x * deq\_scale * deq\_scale
  184. Note:
  185. This operation only support Ascend 310 inference environment.
  186. Args:
  187. sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: False.
  188. relu_flag (bool): Specifies whether to perform ReLU. Default: False.
  189. Inputs:
  190. - **input_x** (Tensor) : Input tensor. Should be mindspore.int32.
  191. - **deq_scale** (Tensor) : Specifies the scaling ratio.
  192. Data type should be mindspore.float16 or mindspore.uint64
  193. Outputs:
  194. - Tensor: The quantized output tensor of type mindspore.float16.
  195. Examples:
  196. >>> input_x = Tensor([100.0, 150.0], mstype.float32)
  197. >>> dequant = P.AscendDequant(False, False)
  198. >>> y = dequant(input_x)
  199. """
  200. @prim_attr_register
  201. def __init__(self, sqrt_mode=False, relu_flag=False):
  202. self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
  203. self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
  204. def infer_shape(self, x_shape, deq_scale_shape):
  205. return x_shape
  206. def infer_dtype(self, x_type, deq_scale_type):
  207. validator.check_subclass("x", x_type, mstype.tensor, self.name)
  208. validator.check_type_name("x", x_type, [mstype.int32], self.name)
  209. validator.check_type_name("deq_scale", deq_scale_type, [mstype.float16, mstype.uint64], self.name)
  210. return mstype.float16
  211. class EmbeddingLookup(PrimitiveWithInfer):
  212. """
  213. Returns a slice of input tensor based on the specified indices.
  214. This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has three more inputs:
  215. `offset`, `reduce_scatter_flag` and `split_num`. This primitive runs on the host instead of devices.
  216. Inputs:
  217. - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  218. The Tensor slice, instead of the entire Tensor.
  219. - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
  220. Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
  221. and the exceeding part will be filled with 0 in the output.
  222. - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
  223. are equal to `input_indices` minus `offset`.
  224. - **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not.
  225. Only constant value is allowed.
  226. - **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable
  227. is used only if `reduce_scatter_flag` is True. Only constant value is allowed.
  228. Outputs:
  229. Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
  230. Examples:
  231. >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
  232. >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
  233. >>> offset = 4
  234. >>> reduce_scatter_flag = False
  235. >>> split_num = 1
  236. >>> out = P.EmbeddingLookup()(input_params, input_indices, offset, reduce_scatter_flag, split_num)
  237. [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
  238. """
  239. @prim_attr_register
  240. def __init__(self):
  241. """init index_select"""
  242. self.__setattr_flag__ = True
  243. self.init_prim_io_names(inputs=['params', 'indices', 'offset', 'reduce_scatter_flag', 'split_num'],
  244. outputs=['output'])
  245. self.add_prim_attr('primitive_target', 'CPU')
  246. def __infer__(self, params, indices, offset, reduce_scatter_flag=False, split_num=2):
  247. validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
  248. validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
  249. validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
  250. validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name)
  251. if split_num['value'] < 1:
  252. raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num)
  253. params_shp = params['shape']
  254. out_shape = indices['shape'] + params_shp[1:]
  255. if reduce_scatter_flag is None:
  256. raise ValueError("The value of 'reduce_scatter_flag' is None.")
  257. reduce_scatter_flag_value = reduce_scatter_flag['value']
  258. if split_num is None:
  259. raise ValueError("The value of 'split_num_value' is None.")
  260. split_num_value = split_num['value']
  261. if reduce_scatter_flag_value is True:
  262. # Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by
  263. # (split_num * 8)
  264. if out_shape[0] % (split_num_value * 8) != 0:
  265. raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." %
  266. (out_shape[0], (split_num_value * 8)))
  267. # After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8
  268. out_shape[0] = out_shape[0] // 8
  269. out = {'shape': out_shape,
  270. 'dtype': params['dtype'],
  271. 'value': None}
  272. return out
  273. class SparseApplyFtrlNoReturn(PrimitiveWithInfer):
  274. """
  275. Update relevant entries according to the FTRL-proximal scheme.
  276. Args:
  277. lr (float): The learning rate value, must be positive.
  278. l1 (float): l1 regularization strength, must be greater than or equal to zero.
  279. l2 (float): l2 regularization strength, must be greater than or equal to zero.
  280. lr_power (float): Learning rate power controls how the learning rate decreases during training,
  281. must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero.
  282. use_locking (bool): Use locks for update operation if True . Default: False.
  283. Inputs:
  284. - **var** (Parameter): The variable to be updated. The data type must be float32.
  285. - **accum** (Parameter): The accum to be updated, must be same type and shape as `var`.
  286. - **linear** (Parameter): The linear to be updated, must be same type and shape as `var`.
  287. - **grad** (Tensor): A tensor of the same type as `var`, for the gradient.
  288. - **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. The shape
  289. of `indices` must be the same as `grad` in first dimension. The type must be int32.
  290. Outputs:
  291. Tuple of 3 Tensor, this operator will update the input parameters directly, the outputs are useless.
  292. - **var** (Tensor) - A Tensor with shape (1,).
  293. - **accum** (Tensor) - A Tensor with shape (1,).
  294. - **linear** (Tensor) - A Tensor with shape (1,).
  295. Examples:
  296. >>> import mindspore
  297. >>> import mindspore.nn as nn
  298. >>> import numpy as np
  299. >>> from mindspore import Parameter
  300. >>> from mindspore import Tensor
  301. >>> from mindspore.ops import operations as P
  302. >>> class SparseApplyFtrlNet(nn.Cell):
  303. >>> def __init__(self):
  304. >>> super(SparseApplyFtrlNet, self).__init__()
  305. >>> self.sparse_apply_ftrl = P.SparseApplyFtrlV2(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
  306. >>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
  307. >>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
  308. >>> self.linear = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="linear")
  309. >>>
  310. >>> def construct(self, grad, indices):
  311. >>> out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
  312. >>> return out
  313. >>>
  314. >>> net = SparseApplyFtrlNet()
  315. >>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
  316. >>> indices = Tensor(np.array([0, 1]).astype(np.int32))
  317. >>> output = net(grad, indices)
  318. """
  319. __mindspore_signature__ = (
  320. ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
  321. ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
  322. ('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
  323. ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
  324. ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
  325. )
  326. @prim_attr_register
  327. def __init__(self, lr, l1, l2, lr_power, use_locking=False):
  328. self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'],
  329. outputs=['output'])
  330. validator.check_value_type("lr", lr, [float], self.name)
  331. validator.check_value_type("l1", l1, [float], self.name)
  332. validator.check_value_type("l2", l2, [float], self.name)
  333. validator.check_value_type("lr_power", lr_power, [float], self.name)
  334. self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
  335. self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
  336. self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
  337. self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
  338. self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
  339. self.add_prim_attr('primitive_target', 'CPU')
  340. def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
  341. validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
  342. validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
  343. if len(var_shape) > 1:
  344. validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
  345. validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
  346. validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
  347. return [1], [1], [1]
  348. def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
  349. args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
  350. "linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
  351. validator.check_tensor_type_same(args, [mstype.float32], self.name)
  352. validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
  353. return var_dtype, accum_dtype, linear_dtype
  354. class SparseApplyProximalAdagradNoReturn(PrimitiveWithInfer):
  355. r"""
  356. Updates relevant entries according to the proximal adagrad algorithm.
  357. .. math::
  358. accum += grad * grad
  359. .. math::
  360. \text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}
  361. .. math::
  362. var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
  363. Args:
  364. use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.
  365. Inputs:
  366. - **var** (Parameter) - Variable tensor to be updated. The data type must be float32.
  367. - **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
  368. - **lr** (Tensor): The learning rate value. The data type must be float32.
  369. - **l1** (Tensor): l1 regularization strength. The data type must be float32.
  370. - **l2** (Tensor): l2 regularization strength. The data type must be float32.
  371. - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32.
  372. - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The data type
  373. must be int32.
  374. Outputs:
  375. Tuple of 2 Tensor, this operator will update the input parameters directly, the outputs are useless.
  376. - **var** (Tensor) - A Tensor with shape (1,).
  377. - **accum** (Tensor) - A Tensor with shape (1,).
  378. Examples:
  379. >>> import numpy as np
  380. >>> import mindspore.nn as nn
  381. >>> from mindspore import Tensor, Parameter
  382. >>> from mindspore.ops import operations as P
  383. >>> class Net(nn.Cell):
  384. >>> def __init__(self):
  385. >>> super(Net, self).__init__()
  386. >>> self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagradV2()
  387. >>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
  388. >>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
  389. >>> self.lr = Tensor(0.01, mstype.float32)
  390. >>> self.l1 = Tensor(0.0, mstype.float32)
  391. >>> self.l2 = Tensor(0.0, mstype.float32)
  392. >>> def construct(self, grad, indices):
  393. >>> out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1,
  394. >>> self.l2, grad, indices)
  395. >>> return out
  396. >>> net = Net()
  397. >>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
  398. >>> indices = Tensor(np.array([0, 1]).astype(np.int32))
  399. >>> output = net(grad, indices)
  400. """
  401. __mindspore_signature__ = (
  402. ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
  403. ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
  404. ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
  405. ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
  406. ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
  407. ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
  408. ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
  409. )
  410. @prim_attr_register
  411. def __init__(self, use_locking=False):
  412. self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
  413. outputs=['output'])
  414. self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
  415. self.add_prim_attr('primitive_target', 'CPU')
  416. def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
  417. validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
  418. return [1], [1]
  419. def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
  420. args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
  421. validator.check_tensor_type_same(args, [mstype.float32], self.name)
  422. validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float32], self.name)
  423. validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, [mstype.float32], self.name)
  424. validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, [mstype.float32], self.name)
  425. valid_types = [mstype.int16, mstype.int32, mstype.int64,
  426. mstype.uint16, mstype.uint32, mstype.uint64]
  427. validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
  428. return var_dtype, accum_dtype
  429. class LinSpace(PrimitiveWithInfer):
  430. r"""
  431. Generates values in an interval. And return the corresponding interpolation accroding to assist.
  432. Inputs:
  433. - **assist** (Tensor[float32]) - The assist value, With shape of 0-D or 1-D.
  434. - **start** (Tensor[float32]) - The start of interval, With shape of 0-D.
  435. - **stop** (Tensor[float32]) - The end of interval, With shape of 0-D.
  436. - **num** (Tensor[int32]) - ticks number in the interval, the ticks include start and stop value.
  437. With shape of 0-D.
  438. Outputs:
  439. Tensor, has the same shape as `assist`.
  440. Examples:
  441. >>> linspace = P.LinSpace()
  442. >>> assist = Tensor([5, 5.5], mindspore.float32)
  443. >>> start = Tensor(1, mindspore.float32)
  444. >>> stop = Tensor(10, mindspore.float32)
  445. >>> num = Tensor(5, mindspore.int32)
  446. >>> output = linspace(assist, start, stop, num)
  447. [12.25, 13.375]
  448. """
  449. @prim_attr_register
  450. def __init__(self):
  451. pass
  452. def infer_shape(self, assist, start, stop, num):
  453. return assist
  454. def infer_dtype(self, assist, start, stop, num):
  455. args = {"num": num}
  456. validator.check_tensor_type_same(args, (mstype.int32,), self.name)
  457. args = {"assist": assist, "start": start, "stop": stop}
  458. validator.check_tensor_type_same(args, (mstype.float32,), self.name)
  459. return assist
  460. class MatrixDiag(PrimitiveWithInfer):
  461. """
  462. Returns a batched diagonal tensor with a given batched diagonal values.
  463. Inputs:
  464. - **x** (Tensor) - A tensor which to be element-wise multi by `assist`. It can be of the following data types:
  465. float32, float16, int32, int8, uint8.
  466. - **assist** (Tensor) - A eye tensor of the same type as `x`. It's rank must greater than or equal to 2 and
  467. it's last dimension must equal to the second to last dimension.
  468. Outputs:
  469. Tensor, has the same type and shape as input `assist`.
  470. Examples:
  471. >>> x = Tensor(np.array([1, -1]), mstype.float32)
  472. >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
  473. >>> matrix_diag = P.MatrixDiag()
  474. >>> result = matrix_diag(x, assist)
  475. [[[-12. 11.]
  476. [-10. 9.]]
  477. [[ -8. 7.]
  478. [ -6. 5.]]
  479. [[ -4. 3.]
  480. [ -2. 1.]]]
  481. """
  482. @prim_attr_register
  483. def __init__(self):
  484. """init MatrixDiag"""
  485. def infer_dtype(self, x_dtype, assist_dtype):
  486. valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
  487. args = {"x": x_dtype, "assist": assist_dtype}
  488. validator.check_tensor_type_same(args, valid_type, self.name)
  489. return x_dtype
  490. def infer_shape(self, x_shape, assist_shape):
  491. validator.check_integer("assist rank", len(assist_shape), 2, Rel.GE, self.name)
  492. validator.check('rank of x', len(x_shape)+1,
  493. 'rank of assist', len(assist_shape), Rel.LE, self.name)
  494. validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
  495. assist_shape[-1], Rel.EQ, self.name)
  496. r_end_dim = -len(x_shape)
  497. r_idx = -1
  498. while r_idx >= r_end_dim:
  499. if x_shape[r_idx] != 1:
  500. validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" %
  501. assist_shape[r_idx-1], assist_shape[r_idx-1], Rel.EQ, self.name)
  502. r_idx = r_idx - 1
  503. return assist_shape
  504. class MatrixDiagPart(PrimitiveWithInfer):
  505. r"""
  506. Returns the batched diagonal part of a batched tensor.
  507. Inputs:
  508. - **x** (Tensor) - The batched tensor. It can be of the following data types:
  509. float32, float16, int32, int8, uint8.
  510. - **assist** (Tensor) - A eye tensor of the same type as `x`. With shape same as `x`.
  511. Outputs:
  512. Tensor, data type same as input `x`. The shape should be x.shape[:-2] + [min(x.shape[-2:])].
  513. Examples:
  514. >>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
  515. >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
  516. >>> matrix_diag_part = P.MatrixDiagPart()
  517. >>> result = matrix_diag_part(x, assist)
  518. [[12., -9.], [8., -5.], [4., -1.]]
  519. """
  520. @prim_attr_register
  521. def __init__(self):
  522. """init MatrixDiagPart"""
  523. def infer_dtype(self, x_dtype, assist_dtype):
  524. valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
  525. args = {"x": x_dtype, "assist": assist_dtype}
  526. validator.check_tensor_type_same(args, valid_type, self.name)
  527. return x_dtype
  528. def infer_shape(self, x_shape, assist_shape):
  529. validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name)
  530. validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)
  531. if assist_shape[-2] < assist_shape[-1]:
  532. out_shape = assist_shape[:-1]
  533. else:
  534. out_shape = assist_shape[:-2] + assist_shape[-1:]
  535. return out_shape
  536. class MatrixSetDiag(PrimitiveWithInfer):
  537. r"""
  538. Modify the batched diagonal part of a batched tensor.
  539. Inputs:
  540. - **x** (Tensor) - The batched tensor. It can be of the following data types:
  541. float32, float16, int32, int8, uint8.
  542. - **assist** (Tensor) - A eye tensor of the same type as `x`. With shape same as `x`.
  543. - **diagonal** (Tensor) - The diagonal values.
  544. Outputs:
  545. Tensor, data type same as input `x`. The shape same as `x`.
  546. Examples:
  547. >>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
  548. >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
  549. >>> matrix_set_diag = P.MatrixSetDiag()
  550. >>> result = matrix_set_diag(x, diagonal)
  551. [[[-1, 0], [0, 2]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]]
  552. """
  553. @prim_attr_register
  554. def __init__(self):
  555. """init MatrixSetDiag"""
  556. def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype):
  557. valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
  558. args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype}
  559. validator.check_tensor_type_same(args, valid_type, self.name)
  560. return x_dtype
  561. def infer_shape(self, x_shape, diagonal_shape, assist_shape):
  562. validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name)
  563. validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)
  564. if x_shape[-2] < x_shape[-1]:
  565. validator.check("x shape excluding the last dimension", x_shape[:-1], "diagnoal shape",
  566. diagonal_shape, Rel.EQ, self.name)
  567. else:
  568. validator.check("x shape excluding the second to last dimension", x_shape[:-2]+x_shape[-1:],
  569. "diagonal shape", diagonal_shape, Rel.EQ, self.name)
  570. return assist_shape