You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

basic.py 26 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  """Basic neural-network layers built on MindSpore primitives.

  Provides Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, Pad, Unfold,
  MatrixDiag, MatrixDiagPart and MatrixSetDiag.
  """
  16. import numpy as np
  17. import mindspore.common.dtype as mstype
  18. from mindspore.common.seed import _get_graph_seed
  19. from mindspore.common.tensor import Tensor
  20. from mindspore.common.initializer import initializer
  21. from mindspore.ops import operations as P
  22. from mindspore.ops import functional as F
  23. from mindspore.ops.functional import identity
  24. from mindspore.ops.operations import _inner_ops as inner
  25. from mindspore.ops.primitive import constexpr, Primitive
  26. from mindspore.common.parameter import Parameter
  27. from mindspore._extends import cell_attr_register
  28. from mindspore._checkparam import Rel, Validator
  29. from mindspore.common.api import ms_function
  30. from mindspore import context
  31. from ..cell import Cell
  32. from .activation import get_activation
  # Public layers exported by this module.
  __all__ = [
      'Dropout',
      'Flatten',
      'Dense',
      'ClipByNorm',
      'Norm',
      'OneHot',
      'Pad',
      'Unfold',
      'MatrixDiag',
      'MatrixDiagPart',
      'MatrixSetDiag',
  ]
  class Dropout(Cell):
      r"""
      Dropout layer for the input.

      Randomly set some elements of the input tensor to zero with probability :math:`1 - keep\_prob` during training
      using samples from a Bernoulli distribution.

      Note:
          Each channel will be zeroed out independently on every construct call.

          The outputs are scaled by a factor of :math:`\frac{1}{keep\_prob}` during training so
          that the output layer remains at a similar scale. During inference, this
          layer returns the same tensor as the input.

          This technique is proposed in paper `Dropout: A Simple Way to Prevent Neural Networks from Overfitting
          <http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf>`_ and proved to be effective to reduce
          over-fitting and prevents neurons from co-adaptation. See more details in `Improving neural networks by
          preventing co-adaptation of feature detectors
          <https://arxiv.org/pdf/1207.0580.pdf>`_.

      Args:
          keep_prob (float): The keep rate, greater than 0 and less than or equal to 1. E.g. keep_prob=0.9
              drops out 10% of input units. Default: 0.5.
          dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
              NOTE(review): this value is validated and shown by `extend_repr`, but `construct`
              derives the working dtype from the input tensor itself.

      Raises:
          TypeError: If `keep_prob` is not a float.
          ValueError: If `keep_prob` is not in range (0, 1].

      Inputs:
          - **input** (Tensor) - The input tensor.

      Outputs:
          Tensor, output tensor with the same shape as the input.

      Examples:
          >>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
          >>> net = nn.Dropout(keep_prob=0.8)
          >>> net(x)
          [[[1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0]],
           [[1.0, 1.0, 1.0],
            [1.0, 1.0, 1.0]]]
      """

      def __init__(self, keep_prob=0.5, dtype=mstype.float32):
          super(Dropout, self).__init__()
          # Check the type first: comparing a non-numeric keep_prob against 0/1 below would
          # otherwise fail with an opaque comparison error instead of the validator's message.
          Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
          if keep_prob <= 0 or keep_prob > 1:
              raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
          Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
          self.keep_prob = keep_prob
          seed0, seed1 = _get_graph_seed(0, "dropout")
          self.seed0 = seed0
          self.seed1 = seed1
          self.dtype = dtype
          self.get_shape = P.Shape()
          self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
          self.dropout_do_mask = P.DropoutDoMask()
          self.cast = P.Cast()
          self.is_gpu = context.get_context('device_target') in ["GPU"]
          self.dropout = P.Dropout(keep_prob)

      def construct(self, x):
          # Inference mode: dropout is the identity.
          if not self.training:
              return x
          # GPU uses the single Dropout primitive, which returns (output, mask); the mask is unused.
          if self.is_gpu:
              out, _ = self.dropout(x)
              return out
          # Nothing is dropped when every element is kept.
          if self.keep_prob == 1:
              return x
          shape = self.get_shape(x)
          dtype = P.DType()(x)
          # keep_prob follows the input dtype for float16/float32 inputs, otherwise float16.
          if _is_float_dtype(dtype):
              keep_prob = self.cast(self.keep_prob, dtype)
          else:
              keep_prob = self.cast(self.keep_prob, mstype.float16)
          output = self.dropout_gen_mask(shape, keep_prob)
          return self.dropout_do_mask(x, output, keep_prob)

      def extend_repr(self):
          str_info = 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
          return str_info
  class Flatten(Cell):
      r"""
      Flatten layer for the input.

      Keeps the 0-th (batch) axis and collapses all remaining axes into a single axis.

      Inputs:
          - **input** (Tensor) - Tensor of shape :math:`(N, \ldots)` to be flattened.

      Outputs:
          Tensor of shape :math:`(N, X)`, where :math:`X` is the product of the
          non-batch dimensions of the input.

      Examples:
          >>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
          >>> net = nn.Flatten()
          >>> net(input)
          [[1.2 1.2 2.1 2.1]
           [2.2 2.2 3.2 3.2]]
      """

      def __init__(self):
          super(Flatten, self).__init__()

      def construct(self, x):
          batch_size = F.shape(x)[0]
          # -1 lets reshape infer the size of the merged trailing axis.
          return F.reshape(x, (batch_size, -1))
  class Dense(Cell):
      r"""
      The dense connected layer.

      Applies dense connected layer for the input. This layer implements the operation as:

      .. math::
          \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),

      where :math:`\text{activation}` is the activation function passed as the activation
      argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
      data type as the inputs created by the layer, and :math:`\text{bias}` is a bias vector
      with the same data type as the inputs created by the layer (only if has_bias is True).
      The weight is stored with shape :math:`(out\_channels, in\_channels)` and applied transposed.

      Args:
          in_channels (int): The number of channels in the input space.
          out_channels (int): The number of channels in the output space.
          weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
              is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
          bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
              same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
          has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
          activation (Union[str, Cell, Primitive]): activate function applied to the output of the fully connected layer,
              eg. 'ReLU'. Default: None.

      Raises:
          ValueError: If weight_init or bias_init shape is incorrect.
          TypeError: If `activation` is not None, a str, a Cell or a Primitive.

      Inputs:
          - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

      Outputs:
          Tensor of shape :math:`(N, out\_channels)`.

      Examples:
          >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
          >>> net = nn.Dense(3, 4)
          >>> net(input)
          [[ 2.5246444   2.2738023   0.5711005  -3.9399147 ]
           [ 1.0739875   4.0155234   0.94188046 -5.459526  ]]
      """

      @cell_attr_register(attrs=['has_bias', 'activation'])
      def __init__(self,
                   in_channels,
                   out_channels,
                   weight_init='normal',
                   bias_init='zeros',
                   has_bias=True,
                   activation=None):
          super(Dense, self).__init__()
          self.in_channels = Validator.check_positive_int(in_channels)
          self.out_channels = Validator.check_positive_int(out_channels)
          self.has_bias = Validator.check_bool(has_bias)

          # A user-supplied weight tensor must already be (out_channels, in_channels).
          if isinstance(weight_init, Tensor):
              if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
                      weight_init.shape[1] != in_channels:
                  raise ValueError("Weight init shape error.")
          self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")

          self.bias = None
          if self.has_bias:
              if isinstance(bias_init, Tensor):
                  if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                      raise ValueError("Bias init shape error.")
              self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
              self.bias_add = P.BiasAdd()

          # transpose_b=True because the weight is stored as (out_channels, in_channels).
          self.matmul = P.MatMul(transpose_b=True)
          self.activation = get_activation(activation) if isinstance(activation, str) else activation
          if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
              raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
          self.activation_flag = self.activation is not None

      def construct(self, x):
          x = self.matmul(x, self.weight)
          if self.has_bias:
              x = self.bias_add(x, self.bias)
          if self.activation_flag:
              x = self.activation(x)
          return x

      def extend_repr(self):
          s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
          if self.has_bias:
              s += ', has_bias={}'.format(self.has_bias)
          if self.activation_flag:
              s += ', activation={}'.format(self.activation)
          return s
  @constexpr
  def _is_equal_one(x):
      """Return True if every element of the tensor `x` equals 1.0.

      Args:
          x (Union[Tensor, None]): Value to test; read through ``x.asnumpy()``.

      Returns:
          bool, False for ``None`` or an empty tensor, otherwise whether all elements are 1.0.
      """
      if x is None:
          return False
      values = x.asnumpy()
      # Compare element-wise: a mean of 1.0 (e.g. [0.5, 1.5]) does not mean every value is 1.
      return bool(values.size > 0 and (values == 1.0).all())
  @constexpr
  def _dtype_check(x_dtype):
      """Raise TypeError unless `x_dtype` is mstype.float32 or mstype.float16."""
      supported = (mstype.float32, mstype.float16)
      if x_dtype in supported:
          return
      raise TypeError("The input type must be float32 or float16.")
  @constexpr
  def _is_float_dtype(dtype):
      """Return True when `dtype` is mstype.float32 or mstype.float16, otherwise False."""
      return dtype in (mstype.float32, mstype.float16)
  class ClipByNorm(Cell):
      r"""
      Clips tensor values to a maximum :math:`L_2`-norm.

      The output of this layer remains the same if the :math:`L_2`-norm of the input tensor
      is not greater than the argument clip_norm. Otherwise the tensor will be normalized as:

      .. math::
          \text{output}(X) = \frac{\text{clip\_norm} * X}{L_2(X)},

      where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.

      Inputs:
          - **input** (Tensor) - Tensor of shape N-D. The type must be float32 or float16.
          - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`.

      Outputs:
          Tensor, clipped tensor with the same shape as the input, whose type is float32.

      Examples:
          >>> net = nn.ClipByNorm()
          >>> input = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
          >>> clip_norm = Tensor(np.array([100]).astype(np.float32))
          >>> net(input, clip_norm)
      """

      def __init__(self):
          super(ClipByNorm, self).__init__()
          self.reduce_sum = P.ReduceSum(keep_dims=True)
          self.select_ = P.Select()
          self.greater_ = P.Greater()
          self.cast = P.Cast()
          self.sqrt = P.Sqrt()
          self.max_op = P.Maximum()
          self.shape = P.Shape()
          self.reshape = P.Reshape()
          self.fill = P.Fill()
          self.expand_dims = P.ExpandDims()
          self.dtype = P.DType()

      @ms_function
      def construct(self, x, clip_norm):
          """add ms_function decorator for pynative mode"""
          # Sum of squares over all axes (keep_dims=True), computed in float32.
          squared_sum = self.cast(self.reduce_sum(F.square(x)), mstype.float32)
          is_positive = self.greater_(squared_sum, 0)
          # Where the sum is zero, feed 1 to sqrt and keep the original zero afterwards,
          # presumably so sqrt is never evaluated at 0.
          ones = self.fill(self.dtype(is_positive), self.shape(is_positive), 1.0)
          safe_sum = self.select_(is_positive, squared_sum, self.cast(ones, self.dtype(squared_sum)))
          l2norm = self.select_(is_positive, self.sqrt(safe_sum), squared_sum)
          _dtype_check(self.dtype(x))
          # Skip the multiplication when clip_norm is exactly one.
          if _is_equal_one(clip_norm):
              scaled = x
          else:
              scaled = x * clip_norm
          # Dividing by max(norm, clip_norm) leaves x unchanged when its norm is within bounds.
          max_norm = self.max_op(l2norm, clip_norm)
          clipped = self.cast(scaled, mstype.float32) / self.expand_dims(max_norm, -1)
          clipped = self.reshape(clipped, self.shape(x))
          return identity(clipped)
  class Norm(Cell):
      """
      Computes the norm of vectors, currently including Euclidean norm, i.e., :math:`L_2`-norm.

      Args:
          axis (Union[tuple, int]): The axis over which to compute vector norms. Default: ().
          keep_dims (bool): If true, the axis indicated in `axis` are kept with size 1. Otherwise,
              the dimensions in `axis` are removed from the output shape. Default: False.

      Inputs:
          - **input** (Tensor) - Tensor which is not empty.

      Outputs:
          Tensor. When `keep_dims` is True the reduced dimensions are kept with size 1;
          otherwise they are squeezed out of the result.

      Examples:
          >>> net = nn.Norm(axis=0)
          >>> input = Tensor(np.random.randint(0, 10, [2, 4]), mindspore.float32)
          >>> net(input)
          [2.236068 9.848858 4.       5.656854]
      """

      def __init__(self, axis=(), keep_dims=False):
          super(Norm, self).__init__()
          Validator.check_value_type("keep_dims", keep_dims, [bool], self.cls_name)
          self.axis = axis
          self.keep_dims = keep_dims
          # The sum always keeps the reduced dims; squeeze removes them when keep_dims is False.
          self.reduce_sum = P.ReduceSum(True)
          self.sqrt = P.Sqrt()
          self.squeeze = P.Squeeze(self.axis)

      def construct(self, x):
          norm = self.sqrt(self.reduce_sum(F.square(x), self.axis))
          if self.keep_dims:
              return norm
          return self.squeeze(norm)

      def extend_repr(self):
          return 'axis={}, keep_dims={}'.format(self.axis, self.keep_dims)
  class OneHot(Cell):
      """
      Returns a one-hot tensor.

      The locations represented by indices in argument 'indices' take value on_value,
      while all other locations take value off_value.

      Note:
          If the input indices is rank :math:`N`, the output will have rank :math:`N+1`. The new
          axis is created at dimension `axis`.

      Args:
          axis (int): Features x depth if axis is -1, depth x features
              if axis is 0. Default: -1.
          depth (int): A scalar defining the depth of the one hot dimension. Default: 1.
          on_value (float): A scalar defining the value to fill in output[i][j]
              when indices[j] = i. Default: 1.0.
          off_value (float): A scalar defining the value to fill in output[i][j]
              when indices[j] != i. Default: 0.0.
          dtype (:class:`mindspore.dtype`): Data type of 'on_value' and 'off_value', not the
              data type of indices. Default: mindspore.float32.

      Inputs:
          - **indices** (Tensor) - A tensor of indices of data type mindspore.int32 and arbitrary shape.

      Outputs:
          Tensor, the one-hot tensor of data type 'dtype' with dimension at 'axis' expanded to 'depth' and filled with
          on_value and off_value.

      Examples:
          >>> net = nn.OneHot(depth=4, axis=1)
          >>> indices = Tensor([[1, 3], [0, 2]], dtype=mindspore.int32)
          >>> net(indices)
          [[[0. 0.]
            [1. 0.]
            [0. 0.]
            [0. 1.]]
           [[1. 0.]
            [0. 0.]
            [0. 1.]
            [0. 0.]]]
      """

      def __init__(self, axis=-1, depth=1, on_value=1.0, off_value=0.0, dtype=mstype.float32):
          super(OneHot, self).__init__()
          self.onehot = P.OneHot(axis)
          self.depth = depth
          self.dtype = dtype
          self.on_value = on_value
          self.off_value = off_value

      def construct(self, indices):
          # The fill values are cast to the configured output dtype on every call.
          on_value = F.cast(self.on_value, self.dtype)
          off_value = F.cast(self.off_value, self.dtype)
          return self.onehot(indices, self.depth, on_value, off_value)
  class Pad(Cell):
      """
      Pads the input tensor according to the paddings and mode.

      Args:
          paddings (tuple): The shape of parameter `paddings` is (N, 2). N is the rank of input data. All elements of
              paddings are int type. For `D` th dimension of input, paddings[D, 0] indicates how many sizes to be
              extended ahead of the `D` th dimension of the input tensor, and paddings[D, 1] indicates how many sizes to
              be extended behind the `D` th dimension of the input tensor. At most 4 dimensions can be padded.
          mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
              Default: "CONSTANT".

      Inputs:
          - **input_x** (Tensor) - The input tensor.

      Outputs:
          Tensor, the tensor after padding.

          - If `mode` is "CONSTANT", it fills the edge with 0, regardless of the values of the `input_x`.
            If the `input_x` is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the
            output is [[0,0,0,0,0,0,0],[0,0,1,2,3,0,0],[0,0,4,5,6,0,0],[0,0,7,8,9,0,0],[0,0,0,0,0,0,0]].
          - If `mode` is "REFLECT", it fills in by mirroring the values across the edge, excluding the edge itself.
            If the `input_x` is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the
            output is [[6,5,4,5,6,5,4],[3,2,1,2,3,2,1],[6,5,4,5,6,5,4],[9,8,7,8,9,8,7],[6,5,4,5,6,5,4]].
          - If `mode` is "SYMMETRIC", the filling method is similar to "REFLECT", except that the mirrored
            values include the edge itself. If the `input_x` is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is
            [[1,1],[2,2]], then the output is
            [[2,1,1,2,3,3,2],[2,1,1,2,3,3,2],[5,4,4,5,6,6,5],[8,7,7,8,9,9,8],[8,7,7,8,9,9,8]].

      Examples:
          >>> from mindspore import Tensor
          >>> from mindspore.ops import operations as P
          >>> import mindspore.nn as nn
          >>> import numpy as np
          >>> class Net(nn.Cell):
          ...     def __init__(self):
          ...         super(Net, self).__init__()
          ...         self.pad = nn.Pad(paddings=((1, 1), (2, 2)), mode="CONSTANT")
          ...     def construct(self, x):
          ...         return self.pad(x)
          >>> x = np.random.random(size=(2, 3)).astype(np.float32)
          >>> pad = Net()
          >>> ms_output = pad(Tensor(x))
      """

      def __init__(self, paddings, mode="CONSTANT"):
          super(Pad, self).__init__()
          self.mode = mode
          self.paddings = paddings
          Validator.check_string(self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], 'mode', self.cls_name)
          if not isinstance(paddings, tuple):
              raise TypeError('Paddings must be tuple type.')
          if any(len(item) != 2 for item in paddings):
              raise ValueError('The shape of paddings must be (n, 2).')
          if len(paddings) > 4:
              raise ValueError('Only padding up to 4 dims is supported')
          if mode == "CONSTANT":
              # Pad takes the paddings as a construction-time attribute.
              self.pad = P.Pad(self.paddings)
          else:
              # MirrorPad receives the paddings as a tensor argument at call time.
              self.paddings = Tensor(np.array(self.paddings))
              self.pad = P.MirrorPad(mode=mode)

      def construct(self, x):
          if self.mode == "CONSTANT":
              return self.pad(x)
          return self.pad(x, self.paddings)
  class Unfold(Cell):
      """
      Extract patches from images.

      The input tensor must be a 4-D tensor and the data format is NCHW.

      Args:
          ksizes (Union[tuple[int], list[int]]): The size of sliding window, must be a tuple or a list of integers,
              and the format is [1, ksize_row, ksize_col, 1].
          strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
              must be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
          rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dimension
              pixel positions, must be a tuple or a list of integers, and the format is [1, rate_row, rate_col, 1].
          padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
              not case sensitive. Default: "valid".

              - same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
              - valid: Means that the taken patch area must be completely covered in the original image.

      Inputs:
          - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_depth, in_row, in_col] and
            data type is number.

      Outputs:
          Tensor, a 4-D tensor whose data type is same as 'input_x',
          and the shape is [out_batch, out_depth, out_row, out_col], the out_batch is the same as the in_batch.

      Examples:
          >>> net = Unfold(ksizes=[1, 2, 2, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1])
          >>> image = Tensor(np.ones([1, 1, 3, 3]), dtype=mstype.float16)
          >>> net(image).shape
          (1, 4, 2, 2)
      """

      def __init__(self, ksizes, strides, rates, padding="valid"):
          super(Unfold, self).__init__()
          self.extract_image_patches = inner.ExtractImagePatches(ksizes, strides, rates, padding)
          self.transpose = P.Transpose()
          self.format_NHWC = (0, 2, 3, 1)
          self.format_NCHW = (0, 3, 1, 2)
          self.is_ge = context.get_context("enable_ge")

      def construct(self, input_x):
          if not self.is_ge:
              return self.extract_image_patches(input_x)
          # With enable_ge the op is wrapped in NCHW -> NHWC -> NCHW transposes
          # (presumably the GE kernel expects NHWC).
          nhwc_input = self.transpose(input_x, self.format_NHWC)
          patches = self.extract_image_patches(nhwc_input)
          return self.transpose(patches, self.format_NCHW)
  @constexpr
  def _get_matrix_diag_assist(x_shape, x_dtype):
      """Build the assist tensor consumed by the MatrixDiag primitive.

      Args:
          x_shape (tuple): Shape ``(..., n)`` of the diagonal values; rank must be >= 1.
          x_dtype: Data type of the returned tensor.

      Returns:
          Tensor of shape ``(..., n, n)`` holding an ``n x n`` identity matrix for every batch index.
      """
      Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist")
      size = x_shape[-1]
      # Broadcast one identity over the batch dims; copy() yields a contiguous, writable array.
      batched_eye = np.broadcast_to(np.eye(size, size), x_shape[:-1] + (size, size)).copy()
      return Tensor(batched_eye, x_dtype)
  @constexpr
  def _get_matrix_diag_part_assist(x_shape, x_dtype):
      """Build the assist tensor consumed by the MatrixDiagPart / MatrixSetDiag primitives.

      Args:
          x_shape (tuple): Shape ``(..., rows, cols)`` of the batched matrix; rank must be >= 2.
          x_dtype: Data type of the returned tensor.

      Returns:
          Tensor of shape `x_shape` holding a ``rows x cols`` identity pattern for every batch index.
      """
      Validator.check_int(len(x_shape), 2, Rel.GE, "x rank", "_get_matrix_diag_part_assist")
      rows, cols = x_shape[-2], x_shape[-1]
      # Broadcast one (possibly rectangular) identity over the batch dims, then materialize it.
      batched_eye = np.broadcast_to(np.eye(rows, cols), x_shape).copy()
      return Tensor(batched_eye, x_dtype)
  class MatrixDiag(Cell):
      """
      Returns a batched diagonal tensor with a given batched diagonal values.

      Inputs:
          - **x** (Tensor) - The diagonal values. It can be one of the following data types:
            float32, float16, int32, int8, and uint8.

      Outputs:
          Tensor, has the same type as input `x`. The shape is x.shape + (x.shape[-1], ).

      Examples:
          >>> x = Tensor(np.array([1, -1]), mstype.float32)
          >>> matrix_diag = nn.MatrixDiag()
          >>> result = matrix_diag(x)
          [[1.  0.]
           [0. -1.]]
      """

      def __init__(self):
          super(MatrixDiag, self).__init__()
          self.matrix_diag = inner.MatrixDiag()
          self.dtype = P.DType()

      def construct(self, input_x):
          # The primitive needs a batched identity "assist" tensor matching the output shape.
          assist = _get_matrix_diag_assist(F.shape(input_x), self.dtype(input_x))
          return self.matrix_diag(input_x, assist)
  class MatrixDiagPart(Cell):
      r"""
      Returns the batched diagonal part of a batched tensor.

      Inputs:
          - **x** (Tensor) - The batched tensor. It can be one of the following data types:
            float32, float16, int32, int8, and uint8.

      Outputs:
          Tensor, has the same type as input `x`. The shape is x.shape[:-2] + [min(x.shape[-2:])].

      Examples:
          >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
          >>> matrix_diag_part = nn.MatrixDiagPart()
          >>> result = matrix_diag_part(x)
          [[-1., 1.], [-1., 1.], [-1., 1.]]
      """

      def __init__(self):
          super(MatrixDiagPart, self).__init__()
          self.matrix_diag_part = inner.MatrixDiagPart()
          self.dtype = P.DType()

      def construct(self, input_x):
          # The assist tensor has the input's shape with ones on each matrix diagonal.
          assist = _get_matrix_diag_part_assist(F.shape(input_x), self.dtype(input_x))
          return self.matrix_diag_part(input_x, assist)
  class MatrixSetDiag(Cell):
      r"""
      Modify the batched diagonal part of a batched tensor.

      Inputs:
          - **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types:
            float32, float16, int32, int8, and uint8.
          - **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k, where k >= 1.

      Outputs:
          Tensor, has the same type and shape as input `x`.

      Examples:
          >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
          >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
          >>> matrix_set_diag = nn.MatrixSetDiag()
          >>> result = matrix_set_diag(x, diagonal)
          [[[-1, 0], [0, 2]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]]
      """

      def __init__(self):
          super(MatrixSetDiag, self).__init__()
          self.matrix_set_diag = inner.MatrixSetDiag()
          self.dtype = P.DType()

      def construct(self, input_x, diagonal):
          # Same identity-pattern assist tensor as MatrixDiagPart, shaped like input_x.
          assist = _get_matrix_diag_part_assist(F.shape(input_x), self.dtype(input_x))
          return self.matrix_set_diag(input_x, diagonal, assist)