
basic.py

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""basic"""
import math
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
from mindspore.common.seed import _get_graph_seed
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.functional import identity
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.primitive import constexpr, Primitive
from mindspore.common.parameter import Parameter
from mindspore._extends import cell_attr_register
from mindspore._checkparam import Rel, Validator
from ..cell import Cell
from .activation import get_activation

__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold',
           'Tril', 'Triu', 'ResizeBilinear', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag', 'L1Regularizer']


class L1Regularizer(Cell):
    r"""
    Applies l1 regularization to weights.

    L1 regularization makes the weights sparse.

    .. math::
        \text{loss} = \lambda * \text{reduce\_sum}(\text{abs}(\omega))

    Note:
        `scale` (the regularization factor) should be a number greater than 0.

    Args:
        scale (int, float): The l1 regularization factor, which must be greater than 0.

    Inputs:
        - **weights** (Tensor) - The input tensor.

    Outputs:
        Tensor, whose dtype is the higher-precision data type between mindspore.float32 and the dtype of
        `weights`, and whose shape is ().

    Raises:
        TypeError: If `scale` is neither an int nor a float.
        ValueError: If `scale` is not greater than 0.
        ValueError: If `scale` is math.inf or math.nan.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> scale = 0.5
        >>> net = nn.L1Regularizer(scale)
        >>> weights = Tensor(np.array([[1.0, -2.0], [-3.0, 4.0]]).astype(np.float32))
        >>> output = net(weights)
        >>> print(output.asnumpy())
        5.0
    """
    def __init__(self, scale):
        super(L1Regularizer, self).__init__()
        Validator.check_value_type("scale", scale, [int, float], self.cls_name)
        if scale <= 0:
            raise ValueError("scale should be a number greater than 0")
        if math.isinf(scale) or math.isnan(scale):
            raise ValueError("scale can not be INF or NAN")
        self.abs = P.Abs()
        self.reduce_sum = P.ReduceSum()
        self.scale = Tensor(scale, dtype=mstype.float32)

    def construct(self, weights):
        const_utils.check_type_valid(F.dtype(weights), mstype.number_type, 'weights')
        l1_regularization = self.scale * self.reduce_sum(self.abs(weights))
        return l1_regularization


class Dropout(Cell):
    r"""
    Dropout layer for the input.

    Randomly sets some elements of the input tensor to zero with probability :math:`1 - keep\_prob` during training,
    using samples from a Bernoulli distribution.

    The outputs are scaled by a factor of :math:`\frac{1}{keep\_prob}` during training so
    that the output layer remains at a similar scale. During inference, this
    layer returns the same tensor as the input.

    This technique is proposed in the paper `Dropout: A Simple Way to Prevent Neural Networks from Overfitting
    <http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf>`_ and has proven effective at reducing
    over-fitting and preventing neurons from co-adapting. See more details in `Improving neural networks by
    preventing co-adaptation of feature detectors
    <https://arxiv.org/pdf/1207.0580.pdf>`_.

    Note:
        Each channel will be zeroed out independently on every construct call.

    Args:
        keep_prob (float): The keep rate, greater than 0 and less than or equal to 1. E.g. rate=0.9,
            dropping out 10% of input units. Default: 0.5.
        dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.

    Inputs:
        - **input** (Tensor) - The input of Dropout with data type of float16 or float32.

    Outputs:
        Tensor, output tensor with the same shape as the input.

    Raises:
        TypeError: If `keep_prob` is not a float.
        TypeError: If dtype of `input` is neither float16 nor float32.
        ValueError: If `keep_prob` is not in range (0, 1].
        ValueError: If length of shape of `input` is less than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
        >>> net = nn.Dropout(keep_prob=0.8)
        >>> net.set_train()
        Dropout<keep_prob=0.8>
        >>> output = net(x)
        >>> print(output.shape)
        (2, 2, 3)
    """
    def __init__(self, keep_prob=0.5, dtype=mstype.float32):
        super(Dropout, self).__init__()
        if keep_prob <= 0 or keep_prob > 1:
            raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
        Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
        Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
        self.keep_prob = keep_prob
        seed0, seed1 = _get_graph_seed(0, "dropout")
        self.seed0 = seed0
        self.seed1 = seed1
        self.dropout = P.Dropout(keep_prob, seed0, seed1)

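    # Dropout is the identity mapping in eval mode and when keep_prob == 1,
    # so the primitive is skipped entirely in those cases.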
    def construct(self, x):
        if not self.training:
            return x
        if self.keep_prob == 1:
            return x
        out, _ = self.dropout(x)
        return out

    def extend_repr(self):
        return 'keep_prob={}'.format(self.keep_prob)


class Flatten(Cell):
    r"""
    Flatten layer for the input.

    Flattens a tensor without changing the batch size on the 0-th axis.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, \ldots)` to be flattened.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N, X)`, where :math:`X` is
        the product of the remaining dimensions.

    Raises:
        TypeError: If `input` is not a subclass of Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
        >>> net = nn.Flatten()
        >>> output = net(input)
        >>> print(output)
        [[1.2 1.2 2.1 2.1]
         [2.2 2.2 3.2 3.2]]
    """
    def __init__(self):
        super(Flatten, self).__init__()

    def construct(self, x):
        return F.reshape(x, (F.shape(x)[0], -1))


class Dense(Cell):
    r"""
    The dense connected layer.

    Applies a dense connected layer to the input. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),

    where :math:`\text{activation}` is the activation function passed as the activation
    argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
    data type as the inputs created by the layer, and :math:`\text{bias}` is a bias vector
    with the same data type as the inputs created by the layer (only if has_bias is True).

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            the same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (Union[str, Cell, Primitive]): Activation function applied to the output of the fully connected
            layer, e.g. 'ReLU'. Default: None.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(*, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(*, out\_channels)`.

    Raises:
        TypeError: If `in_channels` or `out_channels` is not an int.
        TypeError: If `has_bias` is not a bool.
        TypeError: If `activation` is not one of str, Cell, Primitive, None.
        ValueError: If length of shape of `weight_init` is not equal to 2 or shape[0] of `weight_init`
            is not equal to `out_channels` or shape[1] of `weight_init` is not equal to `in_channels`.
        ValueError: If length of shape of `bias_init` is not equal to 1
            or shape[0] of `bias_init` is not equal to `out_channels`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
        >>> net = nn.Dense(3, 4)
        >>> output = net(input)
        >>> print(output.shape)
        (2, 4)
    """
    @cell_attr_register(attrs=['has_bias', 'activation'])
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True,
                 activation=None):
        super(Dense, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels)
        self.out_channels = Validator.check_positive_int(out_channels)
        self.has_bias = Validator.check_bool(has_bias)
        self.reshape = P.Reshape()
        self.shape_op = P.Shape()
        if isinstance(weight_init, Tensor):
            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError("Weight init shape error.")
        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
        self.bias = None
        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError("Bias init shape error.")
            self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
            self.bias_add = P.BiasAdd()
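        # The weight is stored as (out_channels, in_channels), so MatMul with
        # transpose_b=True computes x @ weight^T, matching inputs * kernel in the docstring.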
        self.matmul = P.MatMul(transpose_b=True)
        self.activation = get_activation(activation) if isinstance(activation, str) else activation
        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
            raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
        self.activation_flag = self.activation is not None

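    # construct flattens any leading dimensions into one so that MatMul sees a
    # 2-D matrix, then restores the original leading shape on the way out.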
    def construct(self, x):
        x_shape = self.shape_op(x)
        if len(x_shape) != 2:
            x = self.reshape(x, (-1, x_shape[-1]))
        x = self.matmul(x, self.weight)
        if self.has_bias:
            x = self.bias_add(x, self.bias)
        if self.activation_flag:
            x = self.activation(x)
        if len(x_shape) != 2:
            out_shape = x_shape[:-1] + (-1,)
            x = self.reshape(x, out_shape)
        return x
    def extend_repr(self):
        s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
        if self.has_bias:
            s += ', has_bias={}'.format(self.has_bias)
        if self.activation_flag:
            s += ', activation={}'.format(self.activation)
        return s

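
# The helpers below are decorated with @constexpr, so they run once at graph
# compile time and their results are folded into the graph as constants.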
@constexpr
def _is_equal_one(x):
    if x is None:
        return False
    return bool(x.asnumpy().mean() == 1.0)


@constexpr
def _dtype_check(x_dtype):
    if x_dtype not in [mstype.float32, mstype.float16]:
        raise TypeError("The input type must be float32 or float16.")


@constexpr
def _is_float_dtype(dtype):
    if dtype in [mstype.float32, mstype.float16]:
        return True
    return False


@constexpr
def _need_reduce_all(axis):
    if axis == ():
        return True
    return False


class ClipByNorm(Cell):
    r"""
    Clips tensor values to a maximum :math:`L_2`-norm.

    The output of this layer remains the same if the :math:`L_2`-norm of the input tensor
    is not greater than the argument clip_norm. Otherwise the tensor will be normalized as:

    .. math::
        \text{output}(X) = \frac{\text{clip\_norm} * X}{L_2(X)},

    where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.

    Args:
        axis (Union[None, int, tuple(int)]): Compute the L2-norm along the specified dimensions.
            Default: None, meaning all dimensions are used in the calculation.

    Inputs:
        - **input** (Tensor) - Tensor of shape N-D. The type must be float32 or float16.
        - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`,
          or a tensor whose shape can be broadcast to the shape of `input`.

    Outputs:
        Tensor, clipped tensor with the same shape as the input, whose type is float32.

    Raises:
        TypeError: If `axis` is not one of None, int, tuple.
        TypeError: If dtype of `input` is neither float32 nor float16.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = nn.ClipByNorm()
        >>> input = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
        >>> clip_norm = Tensor(np.array([100]).astype(np.float32))
        >>> output = net(input, clip_norm)
        >>> print(output.shape)
        (4, 16)
    """
    def __init__(self, axis=None):
        super(ClipByNorm, self).__init__()
        if axis is None:
            axis = ()
        if isinstance(axis, tuple):
            for idx, item in enumerate(axis):
                Validator.check_value_type("axis[%d]" % idx, item, [int], self.cls_name)
        self.axis = Validator.check_value_type('axis', axis, [int, tuple], self.cls_name)
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.select_ = P.Select()
        self.greater_ = P.Greater()
        self.cast = P.Cast()
        self.sqrt = P.Sqrt()
        self.max_op = P.Maximum()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.fill = P.Fill()
        self.expand_dims = P.ExpandDims()
        self.dtype = P.DType()
    def construct(self, x, clip_norm):
        mul_x = F.square(x)
        l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32)
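        # Guard against sqrt(0): where the squared sum is zero, substitute 1.0
        # before taking the square root, then restore the zeros via Select.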
        cond = self.greater_(l2sum, 0)
        ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
        l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
        l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum)
        _dtype_check(self.dtype(x))
        if _is_equal_one(clip_norm):
            intermediate = x
        else:
            intermediate = x * clip_norm
        max_norm = self.max_op(l2norm, clip_norm)
        if _need_reduce_all(self.axis):
            max_norm = self.expand_dims(max_norm, -1)
        values_clip = self.cast(intermediate, mstype.float32) / max_norm
        values_clip = self.reshape(values_clip, self.shape(x))
        values_clip = identity(values_clip)
        return values_clip


class Norm(Cell):
    r"""
    Computes the norm of vectors, currently including the Euclidean norm, i.e., the :math:`L_2`-norm.

    .. math::
        norm(x) = \sqrt{\sum_{i=1}^{n} (x_i^2)}

    Args:
        axis (Union[tuple, int]): The axis over which to compute vector norms. Default: ().
        keep_dims (bool): If true, the axes indicated in `axis` are kept with size 1. Otherwise,
            the dimensions in `axis` are removed from the output shape. Default: False.

    Inputs:
        - **input** (Tensor) - Tensor which is not empty.

    Outputs:
        Tensor, output tensor with dimensions in `axis` reduced to 1 if `keep_dims` is True;
        otherwise a Tensor with dimensions in `axis` removed.

    Raises:
        TypeError: If `axis` is neither an int nor a tuple.
        TypeError: If `keep_dims` is not a bool.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = nn.Norm(axis=0)
        >>> input = Tensor(np.array([[4, 4, 9, 1], [2, 1, 3, 6]]), mindspore.float32)
        >>> output = net(input)
        >>> print(output)
        [4.472136 4.1231055 9.486833 6.0827627]
    """
    def __init__(self, axis=(), keep_dims=False):
        super(Norm, self).__init__()
        Validator.check_value_type("keep_dims", keep_dims, [bool], self.cls_name)
        self.axis = axis
        self.keep_dims = keep_dims
        self.reduce_sum = P.ReduceSum(True)
        self.sqrt = P.Sqrt()
        self.squeeze = P.Squeeze(self.axis)

    def construct(self, x):
        x = self.sqrt(self.reduce_sum(F.square(x), self.axis))
        if not self.keep_dims:
            x = self.squeeze(x)
        return x

    def extend_repr(self):
        return 'axis={}, keep_dims={}'.format(self.axis, self.keep_dims)


class OneHot(Cell):
    """
    Returns a one-hot tensor.

    The locations represented by indices in argument `indices` take the value on_value,
    while all other locations take the value off_value.

    Note:
        If the input `indices` has rank :math:`N`, the output will have rank :math:`N+1`. The new
        axis is created at dimension `axis`.

        If `indices` is a scalar, the output shape will be a vector of length `depth`.

        If `indices` is a vector of length `features`, the output shape will be:

        .. code-block::

            features * depth if axis == -1

            depth * features if axis == 0

        If `indices` is a matrix with shape `[batch, features]`, the output shape will be:

        .. code-block::

            batch * features * depth if axis == -1

            batch * depth * features if axis == 1

            depth * batch * features if axis == 0

    Args:
        axis (int): Features x depth if axis is -1, depth x features
            if axis is 0. Default: -1.
        depth (int): A scalar defining the depth of the one-hot dimension. Default: 1.
        on_value (float): A scalar defining the value to fill in output[i][j]
            when indices[j] = i. Default: 1.0.
        off_value (float): A scalar defining the value to fill in output[i][j]
            when indices[j] != i. Default: 0.0.
        dtype (:class:`mindspore.dtype`): Data type of `on_value` and `off_value`, not the
            data type of indices. Default: mindspore.float32.

    Inputs:
        - **indices** (Tensor) - A tensor of indices with data type of int32 or int64 and arbitrary shape.

    Outputs:
        Tensor, the one-hot tensor of data type `dtype` with dimension at `axis` expanded to `depth` and filled with
        on_value and off_value.

    Raises:
        TypeError: If `axis` or `depth` is not an int.
        TypeError: If dtype of `indices` is neither int32 nor int64.
        ValueError: If `axis` is not in range [-1, len(indices_shape)].
        ValueError: If `depth` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.OneHot(depth=4, axis=1)
        >>> indices = Tensor([[1, 3], [0, 2]], dtype=mindspore.int32)
        >>> output = net(indices)
        >>> print(output)
        [[[0. 0.]
          [1. 0.]
          [0. 0.]
          [0. 1.]]
         [[1. 0.]
          [0. 0.]
          [0. 1.]
          [0. 0.]]]
    """
    def __init__(self, axis=-1, depth=1, on_value=1.0, off_value=0.0, dtype=mstype.float32):
        super(OneHot, self).__init__()
        self.onehot = P.OneHot(axis)
        self.depth = depth
        self.dtype = dtype
        self.on_value = on_value
        self.off_value = off_value

    def construct(self, indices):
        return self.onehot(indices, self.depth, F.cast(self.on_value, self.dtype), F.cast(self.off_value, self.dtype))


class Pad(Cell):
    """
    Pads the input tensor according to the paddings and mode.

    Args:
        paddings (tuple): The shape of parameter `paddings` is (N, 2). N is the rank of input data. All elements of
            paddings are int type. For the `D`-th dimension of input, paddings[D, 0] indicates how many elements to
            pad before the contents of the input tensor in that dimension, and paddings[D, 1] indicates how many
            elements to pad after them. The padded size of each dimension D of the output is:

            .. code-block::

                paddings[D, 0] + input_x.dim_size(D) + paddings[D, 1]

        mode (str): Specifies the padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
            Default: "CONSTANT".

    Inputs:
        - **input_x** (Tensor) - The input tensor.

    Outputs:
        Tensor, the tensor after padding.

        - If `mode` is "CONSTANT", it fills the edge with 0, regardless of the values of the `input_x`.
          If the `input_x` is [[1,2,3], [4,5,6], [7,8,9]] and `paddings` is [[1,1], [2,2]], then the
          output is [[0,0,0,0,0,0,0], [0,0,1,2,3,0,0], [0,0,4,5,6,0,0], [0,0,7,8,9,0,0], [0,0,0,0,0,0,0]].
        - If `mode` is "REFLECT", it fills by mirroring the input across the axis of symmetry.
          If the `input_x` is [[1,2,3], [4,5,6], [7,8,9]] and `paddings` is [[1,1], [2,2]], then the
          output is [[6,5,4,5,6,5,4], [3,2,1,2,3,2,1], [6,5,4,5,6,5,4], [9,8,7,8,9,8,7], [6,5,4,5,6,5,4]].
        - If `mode` is "SYMMETRIC", the filling method is similar to "REFLECT". It also mirrors across
          the symmetry axis, except that it includes the symmetry axis. If the `input_x`
          is [[1,2,3], [4,5,6], [7,8,9]] and `paddings` is [[1,1], [2,2]], then the output is
          [[2,1,1,2,3,3,2], [2,1,1,2,3,3,2], [5,4,4,5,6,6,5], [8,7,7,8,9,9,8], [8,7,7,8,9,9,8]].

    Raises:
        TypeError: If `paddings` is not a tuple.
        ValueError: If length of `paddings` is more than 4 or its shape is not (n, 2).
        ValueError: If `mode` is not one of 'CONSTANT', 'REFLECT', 'SYMMETRIC'.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindspore import Tensor
        >>> from mindspore.ops import operations as P
        >>> import mindspore.nn as nn
        >>> import numpy as np
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.pad = nn.Pad(paddings=((1, 1), (2, 2)), mode="CONSTANT")
        ...     def construct(self, x):
        ...         return self.pad(x)
        >>> x = np.array([[0.3, 0.5, 0.2], [0.5, 0.7, 0.3]], dtype=np.float32)
        >>> pad = Net()
        >>> output = pad(Tensor(x))
        >>> print(output)
        [[0.  0.  0.  0.  0.  0.  0. ]
         [0.  0.  0.3 0.5 0.2 0.  0. ]
         [0.  0.  0.5 0.7 0.3 0.  0. ]
         [0.  0.  0.  0.  0.  0.  0. ]]
    """
    def __init__(self, paddings, mode="CONSTANT"):
        super(Pad, self).__init__()
        self.mode = mode
        self.paddings = paddings
        Validator.check_string(self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], 'mode', self.cls_name)
        if not isinstance(paddings, tuple):
            raise TypeError('Paddings must be tuple type.')
        for item in paddings:
            if len(item) != 2:
                raise ValueError('The shape of paddings must be (n, 2).')
        if len(paddings) > 4:
            raise ValueError('Only padding up to 4 dims is supported')
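        # P.Pad takes the paddings as a static attribute, while P.MirrorPad
        # expects them as a tensor input at call time, hence the conversion below.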
        if mode == "CONSTANT":
            self.pad = P.Pad(self.paddings)
        else:
            self.paddings = Tensor(np.array(self.paddings))
            self.pad = P.MirrorPad(mode=mode)

    def construct(self, x):
        if self.mode == "CONSTANT":
            x = self.pad(x)
        else:
            x = self.pad(x, self.paddings)
        return x

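
# `bilinear` is a @constexpr helper, so the argument checks and the output-size
# arithmetic happen once at graph compile time rather than on every call.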
@constexpr
def bilinear(shape, size, scale, align_corners):
    """Check input and calculate shape."""
    if not isinstance(align_corners, bool):
        raise TypeError("align_corners should be a bool")
    if size is None and scale is None:
        raise ValueError("size and scale are both None")
    if size is not None and scale is not None:
        raise ValueError("size and scale can not both be set")
    if size is not None:
        if not isinstance(size, (tuple, list)):
            raise ValueError("size must be a tuple or list")
        Validator.check_int(len(size), 2, Rel.EQ, "size", "bilinear")
        Validator.check_int(size[0], 1, Rel.GE, "size[0]", "bilinear")
        Validator.check_int(size[1], 1, Rel.GE, "size[1]", "bilinear")
        return size
    Validator.check_int(scale, 1, Rel.GE, "scale factor", "bilinear")
    ret = (scale * shape[2], scale * shape[3])
    return ret


class ResizeBilinear(Cell):
    r"""
    Samples the input tensor to the given size or scale_factor using bilinear interpolation.

    Inputs:
        - **x** (Tensor) - Tensor to be resized. The input tensor must be a 4-D tensor with shape
          :math:`(batch, channels, height, width)`, with data type of float16 or float32.
        - **size** (Union[tuple[int], list[int]]): A tuple or list of 2 int elements
          `(new_height, new_width)`, the new size of the tensor. One and only one of size and
          scale_factor can be set to None. Default: None.
        - **scale_factor** (int): The scale factor of the new size of the tensor. The value should be
          a positive integer. One and only one of size and scale_factor can be set to None. Default: None.
        - **align_corners** (bool): If true, rescale input by `(new_height - 1) / (height - 1)`, which exactly
          aligns the 4 corners of the images and the resized images. If false, rescale by
          `new_height / height`. Default: False.

    Outputs:
        Resized tensor.
        If size is set, the result is a 4-D tensor with shape :math:`(batch, channels, new_height, new_width)`
        in float32.
        If scale is set, the result is a 4-D tensor with shape :math:`(batch, channels, scale_factor * height,
        scale_factor * width)` in float32.

    Raises:
        TypeError: If `size` is not one of tuple, list, None.
        TypeError: If `scale_factor` is neither int nor None.
        TypeError: If `align_corners` is not a bool.
        TypeError: If dtype of `x` is neither float16 nor float32.
        ValueError: If `size` and `scale_factor` are both None or both not None.
        ValueError: If length of shape of `x` is not equal to 4.
        ValueError: If `scale_factor` is an int which is less than 1.
        ValueError: If `size` is a list or tuple whose length is not equal to 2.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> tensor = Tensor([[[[1, 2, 3, 4], [5, 6, 7, 8]]]], mindspore.float32)
        >>> resize_bilinear = nn.ResizeBilinear()
        >>> result = resize_bilinear(tensor, size=(5,5))
        >>> print(result.shape)
        (1, 1, 5, 5)
    """
    def __init__(self):
        super(ResizeBilinear, self).__init__()

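    # The output shape is resolved at compile time by the `bilinear` helper
    # above; a ResizeBilinear primitive is then built for that fixed size.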
    def construct(self, x, size=None, scale_factor=None, align_corners=False):
        shape = bilinear(x.shape, size, scale_factor, align_corners)
        resize_bilinear = P.ResizeBilinear(shape, align_corners)
        return resize_bilinear(x)


class Unfold(Cell):
    """
    Extracts patches from images.
    The input tensor must be a 4-D tensor and the data format is NCHW.

    Args:
        ksizes (Union[tuple[int], list[int]]): The size of the sliding window, must be a tuple or a list of integers,
            and the format is [1, ksize_row, ksize_col, 1].
        strides (Union[tuple[int], list[int]]): Distance between the centers of two consecutive patches,
            must be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
        rates (Union[tuple[int], list[int]]): The gap between corresponding pixel positions within each
            extracted patch, must be a tuple or a list of integers, and the format is [1, rate_row, rate_col, 1].
        padding (str): The type of padding algorithm, a string whose value is "same" or "valid", not case sensitive.
            Default: "valid".

            - same: Means that the patch can take the part beyond the original image, and this part is filled with 0.

            - valid: Means that the taken patch area must be completely covered in the original image.

    Inputs:
        - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_depth, in_row, in_col] and
          data type is number.

    Outputs:
        Tensor, a 4-D tensor whose data type is the same as `input_x`,
        and the shape is [out_batch, out_depth, out_row, out_col] where `out_batch` is the same as `in_batch`, and

        .. code-block::

            out_depth = ksize_row * ksize_col * in_depth

            out_row = (in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1

            out_col = (in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1

    Raises:
        TypeError: If `ksizes`, `strides` or `rates` is neither a tuple nor a list.
        ValueError: If the shape of `ksizes`, `strides` or `rates` is not (1, x_row, x_col, 1).
        ValueError: If the second and third elements of `ksizes`, `strides` or `rates` are less than 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> net = nn.Unfold(ksizes=[1, 2, 2, 1], strides=[1, 2, 2, 1], rates=[1, 2, 2, 1])
        >>> image = Tensor(np.ones([2, 3, 6, 6]), dtype=mstype.float16)
        >>> output = net(image)
        >>> print(output.shape)
        (2, 12, 2, 2)
    """
    def __init__(self, ksizes, strides, rates, padding="valid"):
        super(Unfold, self).__init__()

        def _check_tuple_or_list(arg_name, arg_val, prim_name):
            Validator.check_value_type(f"{arg_name}s", arg_val, [tuple, list], self.cls_name)
            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
                raise ValueError(f"For '{prim_name}' the format of {arg_name}s should be [1, {arg_name}_row, "
                                 f"{arg_name}_col, 1], but got {arg_val}.")
            if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
                raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be "
                                 f"positive integers, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
                                 f"is {arg_val[2]}")

        _check_tuple_or_list("ksize", ksizes, self.cls_name)
        _check_tuple_or_list("stride", strides, self.cls_name)
        _check_tuple_or_list("rate", rates, self.cls_name)
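        # Reorder each [1, row, col, 1] argument into (1, 1, row, col), the
        # layout the inner ExtractImagePatches primitive expects.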
        ksizes = ksizes[0], ksizes[3], ksizes[1], ksizes[2]
        strides = strides[0], strides[3], strides[1], strides[2]
        rates = rates[0], rates[3], rates[1], rates[2]
        self.extract_image_patches = inner.ExtractImagePatches(ksizes, strides, rates, padding)

    def construct(self, input_x):
        result = self.extract_image_patches(input_x)
        return result

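
# Evaluated at graph compile time: builds a constant 0/1 lower-triangular mask
# with numpy for the given shape and diagonal offset.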
@constexpr
def tril(x_shape, x_dtype, k):
    Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "tril")
    Validator.check_is_int(k, "k value", "tril")
    mask = np.tril(np.ones(x_shape), k)
    return Tensor(mask, x_dtype)


class Tril(Cell):
    """
    Returns a tensor with elements above the k-th diagonal zeroed.

    Inputs:
        - **x** (Tensor) - The input tensor.
        - **k** (int) - The index of the diagonal. Default: 0.

    Outputs:
        Tensor, has the same type as input `x`.

    Raises:
        TypeError: If `k` is not an int.
        ValueError: If length of shape of `x` is less than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[1, 2], [3, 4]]))
        >>> tril = nn.Tril()
        >>> result = tril(x)
        >>> print(result)
        [[1 0]
         [3 4]]
    """
    def __init__(self):
        super(Tril, self).__init__()
        self.dtype = P.DType()
        self.mul = P.Mul()
        self.cast = P.Cast()

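    # Zeroes elements above the k-th diagonal by multiplying with the constant
    # mask; the multiply is done in float32, then the result is cast back.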
    def construct(self, x, k=0):
        assist = tril(x.shape, self.dtype(x), k)
        result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
        return self.cast(result, self.dtype(x))


@constexpr
def triu(x_shape, x_dtype, k):
    Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "triu")
    Validator.check_is_int(k, "k value", "triu")
    mask = np.triu(np.ones(x_shape), k)
    return Tensor(mask, x_dtype)


class Triu(Cell):
    """
    Returns a tensor with elements below the k-th diagonal zeroed.

    Inputs:
        - **x** (Tensor) - The input tensor.
        - **k** (int) - The index of the diagonal. Default: 0.

    Outputs:
        Tensor, has the same type as input `x`.

    Raises:
        TypeError: If `k` is not an int.
        ValueError: If length of shape of `x` is less than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[1, 2], [3, 4]]))
        >>> triu = nn.Triu()
        >>> result = triu(x)
        >>> print(result)
        [[1 2]
         [0 4]]
    """
    def __init__(self):
        super(Triu, self).__init__()
        self.dtype = P.DType()
        self.mul = P.Mul()
        self.cast = P.Cast()

    def construct(self, x, k=0):
        assist = triu(x.shape, self.dtype(x), k)
        result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
        return self.cast(result, self.dtype(x))

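
# Compile-time helpers that build the identity-based "assist" tensors consumed
# by the inner MatrixDiag / MatrixDiagPart / MatrixSetDiag primitives.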
@constexpr
def _get_matrix_diag_assist(x_shape, x_dtype):
    Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist")
    base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1)
    assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],))
    return Tensor(assist, x_dtype)


@constexpr
def _get_matrix_diag_part_assist(x_shape, x_dtype):
    Validator.check_int(len(x_shape), 2, Rel.GE, "x rank", "_get_matrix_diag_part_assist")
    base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1)
    assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape)
    return Tensor(assist, x_dtype)


class MatrixDiag(Cell):
    r"""
    Returns a batched diagonal tensor with given batched diagonal values.

    Assume `x` has :math:`k` dimensions :math:`[I, J, K, ..., N]`, then the output is a tensor of rank
    :math:`k+1` with dimensions :math:`[I, J, K, ..., N, N]` where:

    .. code-block::

        output[i, j, k, ..., m, n] = 1{m=n} * x[i, j, k, ..., n]

    Inputs:
        - **x** (Tensor) - The diagonal values. It can be one of the following data types:
          float32, float16, int32, int8, and uint8.

    Outputs:
        Tensor, has the same type as input `x`. The shape must be x.shape + (x.shape[-1], ).

    Raises:
        TypeError: If dtype of `x` is not one of float32, float16, int32, int8 or uint8.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor(np.array([1, -1]), mstype.float32)
        >>> matrix_diag = nn.MatrixDiag()
        >>> output = matrix_diag(x)
        >>> print(output)
        [[ 1.  0.]
         [ 0. -1.]]
    """
    def __init__(self):
        super(MatrixDiag, self).__init__()
        self.matrix_diag = inner.MatrixDiag()
        self.dtype = P.DType()

    def construct(self, input_x):
        x_shape = F.shape(input_x)
        x_dtype = self.dtype(input_x)
        assist = _get_matrix_diag_assist(x_shape, x_dtype)
        out_matrix_diag = self.matrix_diag(input_x, assist)
        return out_matrix_diag


class MatrixDiagPart(Cell):
    r"""
    Returns the batched diagonal part of a batched tensor.

    Assume `x` has :math:`k` dimensions :math:`[I, J, K, ..., M, N]`, then the output is a tensor of rank
    :math:`k-1` with dimensions :math:`[I, J, K, ..., min(M, N)]` where:

    .. code-block::

        output[i, j, k, ..., n] = x[i, j, k, ..., n, n]

    Inputs:
        - **x** (Tensor) - The batched tensor. It can be one of the following data types:
          float32, float16, int32, int8, and uint8.

    Outputs:
        Tensor, has the same type as input `x`. The shape must be x.shape[:-2] + [min(x.shape[-2:])].

    Raises:
        TypeError: If dtype of `x` is not one of float32, float16, int32, int8 or uint8.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> matrix_diag_part = nn.MatrixDiagPart()
        >>> output = matrix_diag_part(x)
        >>> print(output)
        [[-1.  1.]
         [-1.  1.]
         [-1.  1.]]
    """
    def __init__(self):
        super(MatrixDiagPart, self).__init__()
        self.matrix_diag_part = inner.MatrixDiagPart()
        self.dtype = P.DType()

    def construct(self, input_x):
        x_shape = F.shape(input_x)
        x_dtype = self.dtype(input_x)
        assist = _get_matrix_diag_part_assist(x_shape, x_dtype)
        out_matrix_diag_part = self.matrix_diag_part(input_x, assist)
        return out_matrix_diag_part


class MatrixSetDiag(Cell):
    r"""
    Modifies the batched diagonal part of a batched tensor.

    Assume `x` has :math:`k+1` dimensions :math:`[I, J, K, ..., M, N]` and `diagonal` has :math:`k`
    dimensions :math:`[I, J, K, ..., min(M, N)]`. Then the output is a tensor of rank :math:`k+1` with dimensions
    :math:`[I, J, K, ..., M, N]` where:

    .. code-block::

        output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n] for m == n

        output[i, j, k, ..., m, n] = x[i, j, k, ..., m, n] for m != n

    Inputs:
        - **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types:
          float32, float16, int32, int8, and uint8.
        - **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k, where k >= 1.

    Outputs:
        Tensor, has the same type and shape as input `x`.

    Raises:
        TypeError: If dtype of `x` or `diagonal` is not one of float32, float16, int32, int8 or uint8.
        ValueError: If length of shape of `x` is less than 2.
        ValueError: If x_shape[-2] < x_shape[-1] and x_shape[:-1] != diagonal_shape.
        ValueError: If x_shape[-2] >= x_shape[-1] and x_shape[:-2] + x_shape[-1:] != diagonal_shape.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
        >>> matrix_set_diag = nn.MatrixSetDiag()
        >>> output = matrix_set_diag(x, diagonal)
        >>> print(output)
        [[[-1.  0.]
          [ 0.  2.]]
         [[-1.  0.]
          [ 0.  1.]]
         [[-1.  0.]
          [ 0.  1.]]]
    """
    def __init__(self):
        super(MatrixSetDiag, self).__init__()
        self.matrix_set_diag = inner.MatrixSetDiag()
        self.dtype = P.DType()

    def construct(self, input_x, diagonal):
        x_shape = F.shape(input_x)
        x_dtype = self.dtype(input_x)
        assist = _get_matrix_diag_part_assist(x_shape, x_dtype)
        out_matrix_set_diag = self.matrix_set_diag(input_x, diagonal, assist)
        return out_matrix_set_diag