You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

math_ops.py 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """math Operations."""
  16. import numpy as np
  17. from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
  18. from mindspore.common import dtype as mstype
  19. from mindspore._checkparam import Validator as validator
  20. from mindspore.ops.primitive import constexpr
  21. from mindspore.ops import functional as F
  22. from .. import operations as P
  23. # count_nonzero
  24. @constexpr
  25. def _check_validate_axis(axis, name):
  26. if isinstance(axis, (tuple, list)):
  27. for idx, item in enumerate(axis):
  28. validator.check_value_type("axis[%d]" % idx, item, [int], name)
  29. axis = validator.check_value_type('axis', axis, [int, tuple, list], name)
  30. return axis
  31. @constexpr
  32. def _check_validate_keepdims(keep_dims, name):
  33. keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name)
  34. return keep_dims
  35. def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
  36. r"""
  37. Count number of nonzero elements across axis of input tensor
  38. Args:
  39. x (Tensor): Input data is used to count non-zero numbers.
  40. axis (Union[int, tuple(int), list(int)]): The dimensions to reduce. Only constant value is allowed.
  41. Default: (), reduce all dimensions.
  42. keep_dims (bool): If true, keep these reduced dimensions and the length is 1.
  43. If false, don't keep these dimensions. Default: False.
  44. dtype (Union[Number, mstype.bool\_]): The data type of the output tensor. Only constant value is allowed.
  45. Default: mstype.int32
  46. Returns:
  47. Tensor, number of nonzero element. The data type is dtype.
  48. Supported Platforms:
  49. ``Ascend`` ``GPU``
  50. Examples:
  51. >>> input_x = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32))
  52. >>> nonzero_num = count_nonzero(x=input_x, axis=[0, 1], keep_dims=True, dtype=mstype.int32)
  53. >>> print(nonzero_num)
  54. [[3]]
  55. """
  56. const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x')
  57. axis = _check_validate_axis(axis, "count_nonzero")
  58. keep_dims = _check_validate_keepdims(keep_dims, "count_nonzero")
  59. const_utils.check_type_valid(dtype, mstype.number_type + (mstype.bool_,), 'dtype')
  60. not_equal = P.NotEqual()
  61. cast = P.Cast()
  62. reduce_sum = P.ReduceSum(keep_dims)
  63. nonzero_bool = not_equal(x, 0)
  64. # ReduceSum only support float16 or float32 tensor.
  65. nonzero_val = cast(nonzero_bool, mstype.float16)
  66. nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype)
  67. return nonzero_num
  68. # tensor dot
  69. @constexpr
  70. def _int_to_tuple_conv(axes):
  71. """
  72. Converts ints to tuples in input axes, expected by most validation checks.
  73. """
  74. for x in [0, 1]:
  75. if isinstance(axes[x], int):
  76. axes[x] = (axes[x],)
  77. return axes
  78. @constexpr
  79. def _check_axes(axes):
  80. """
  81. Check for validity and type of axes passed to function.
  82. """
  83. validator.check_value_type('axes', axes, [int, tuple, list], "tensor dot")
  84. if not isinstance(axes, int):
  85. axes = list(axes) # to avoid immutability issues
  86. if len(axes) != 2:
  87. raise ValueError("Require two axes inputs, given less")
  88. axes = _int_to_tuple_conv(axes) # convert before length checks
  89. if len(axes[0]) != len(axes[1]):
  90. raise ValueError("Axes have to be the same size/length")
  91. if len(axes[0]) != len(set(axes[0])) or len(axes[1]) != len(set(axes[1])):
  92. raise ValueError("Axes cannot have duplicating values")
  93. return axes
  94. @constexpr
  95. def _typecheck_input(x1_type, x2_type):
  96. """
  97. Check input tensor types to be valid and confirm they are the same type.
  98. """
  99. const_utils.check_type_valid(x1_type, [mstype.float32, mstype.float16], 'x1')
  100. const_utils.check_type_valid(x2_type, [mstype.float32, mstype.float16], 'x2')
  101. if x1_type != x2_type:
  102. raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ')
  103. @constexpr
  104. def _axes_int_check(x1_shape, x2_shape, axes):
  105. """
  106. Convert from single int axes to 2d tuple if required
  107. """
  108. if isinstance(axes, int):
  109. if axes < 0:
  110. raise ValueError(f"axes must be at least 0 for tensor dot, got {axes}")
  111. if axes == 0:
  112. # outer product, no input validation required
  113. return ([], [])
  114. if axes > len(x1_shape) or axes > len(x2_shape):
  115. raise ValueError(
  116. "Axes value too high for given input arrays dimensions.")
  117. x1_ind = tuple(range(len(x1_shape))[-1 * axes:])
  118. x2_ind = tuple(range(len(x2_shape))[:axes])
  119. axes = tuple((x1_ind, x2_ind))
  120. axes = _int_to_tuple_conv(axes)
  121. return axes
  122. @constexpr
  123. def _validate_axes(x1_shape, x2_shape, axes):
  124. """
  125. Checks for axes having the correct length according to input, for any value in axis
  126. being out of range with given shape and also checking for compatible axes values
  127. with given inputs.
  128. """
  129. shapes = [x1_shape, x2_shape]
  130. # axis length check
  131. for ix_input, x_axes in enumerate(axes):
  132. axes_len = len(x_axes)
  133. shape_dim_len = len(shapes[ix_input])
  134. if axes_len > shape_dim_len:
  135. raise ValueError(f"axes for input: {ix_input + 1} are of length: {axes_len} "
  136. f"can only be max: {shape_dim_len} due to input shape.")
  137. # axis values range check
  138. for ix_input, x_axes in enumerate(axes):
  139. comp_shape = shapes[ix_input]
  140. max_val = len(comp_shape) - 1
  141. min_val = -1 * len(comp_shape)
  142. for _, x_value in enumerate(x_axes):
  143. if not min_val <= x_value <= max_val:
  144. raise ValueError(f"axes for input: {ix_input + 1} contains index: "
  145. f"{x_value}, but range is: [{min_val}, {max_val}]")
  146. # check axis value with input shape - both ways for axis valid
  147. invalid_a = False
  148. invalid_b = False
  149. for i in range(len(axes[0])): # sizes already validated
  150. if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]:
  151. invalid_a = True
  152. if x1_shape[axes[0][i]] != x2_shape[axes[1][len(axes[0])-1-i]]:
  153. invalid_b = True
  154. if invalid_a and invalid_b:
  155. raise ValueError("Given Axes are incompatible with given input arrays")
  156. @constexpr
  157. def _calc_new_shape(shape, axes, position=0):
  158. """
  159. Calculate transpose and reshape parameters for input transformations,
  160. 'position' refers to whether tensor is first or second in the op.
  161. """
  162. contraction_axes = tuple(i if i >= 0 else i + len(shape) for i in axes[position])
  163. prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
  164. free_axes = tuple(i for i in range(len(shape)) if i not in contraction_axes)
  165. free_dims = tuple(shape[i] for i in free_axes)
  166. prod_free = int(np.prod(free_dims))
  167. transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
  168. new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
  169. return new_shape, transpose_perm, free_dims
  170. def tensor_dot(x1, x2, axes):
  171. """
  172. Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`.
  173. Contraction allows for the summation of products of elements of `a` and `b` on specified axes.
  174. The same number of axes must be specified for both x1 and x2, and values must be within range
  175. of number of dims of both `a` and `b`.
  176. Selected dims in both inputs must also match.
  177. axes = 0 leads to outer product
  178. axes = 1 leads to normal matrix multiplication when inputs both 2D.
  179. axes = 1 is the same as axes = ((1,),(0,) where both `a` and `b` are 2D.
  180. axes = 2 is the same as axes = ((1,2),(0,1)) where both `a` and `b` are 3D.
  181. Inputs:
  182. - **x1** (Tensor) - First tensor in tensor_dot with datatype float16 or float32
  183. - **x2** (Tensor) - Second tensor in tensor_dot with datatype float16 or float32
  184. - **axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]) - Single value or
  185. tuple/list of length 2 with dimensions specified for `a` and `b` each. If single value `N` passed,
  186. automatically picks up last N dims from `a` input shape and first N dims from `b` input shape in order
  187. as axes for each respectively.
  188. Outputs:
  189. Tensor, the shape of the output tensor is :math:`(N + M)`. Where :math:`N` and :math:`M` are the free axes not
  190. contracted in both inputs
  191. Supported Platforms:
  192. ``Ascend`` ``GPU`` ``CPU``
  193. Examples:
  194. >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
  195. >>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32)
  196. >>> output = C.tensor_dot(input_x1, input_x2, ((0,1),(1,2)))
  197. >>> print(output)
  198. [[2. 2. 2]
  199. [2. 2. 2]
  200. [2. 2. 2]]
  201. """
  202. shape_op = P.Shape()
  203. reshape_op = P.Reshape()
  204. transpose_op = P.Transpose()
  205. matmul_op = P.MatMul(False, False)
  206. # input validity checks
  207. x1_shape = shape_op(x1)
  208. x2_shape = shape_op(x2)
  209. x1_type = F.dtype(x1)
  210. x2_type = F.dtype(x2)
  211. axes = _check_axes(axes)
  212. _typecheck_input(x1_type, x2_type)
  213. # input compatibility check & axes format update
  214. axes = _axes_int_check(x1_shape, x2_shape, axes)
  215. _validate_axes(x1_shape, x2_shape, axes)
  216. x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0)
  217. x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1)
  218. output_shape = x1_ret + x2_ret # combine free axes from both inputs
  219. # run tensor_dot op
  220. x1_transposed = transpose_op(x1, x1_transpose_fwd)
  221. x2_transposed = transpose_op(x2, x2_transpose_fwd)
  222. x1_reshaped = reshape_op(x1_transposed, x1_reshape_fwd)
  223. x2_reshaped = reshape_op(x2_transposed, x2_reshape_fwd)
  224. mul_result = matmul_op(x1_reshaped, x2_reshaped)
  225. final_result = reshape_op(mul_result, output_shape)
  226. return final_result
  227. @constexpr
  228. def _check_invalid_input(x1_shape, x2_shape):
  229. if len(x1_shape) < 2 or len(x2_shape) < 2:
  230. raise ValueError('C.dot inputs x1, x2 should has dimension >= 2,'
  231. + f'while x1 is ({len(x1_shape)}) and x2 is ({len(x2_shape)}).')
  232. @constexpr
  233. def _get_transpose_shape(x2_shape):
  234. x2_shape_range = tuple(range(len(x2_shape)))
  235. x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
  236. return x2_shape_transpose
  237. def dot(x1, x2):
  238. """
  239. Computation a dot product between samples in two tensors.
  240. Inputs:
  241. - **x1** (Tensor) - First tensor in Dot op with datatype float16 or float32
  242. - **x2** (Tensor) - Second tensor in Dot op with datatype float16 or float32
  243. Outputs:
  244. Tensor, dot product of x1 and x2.
  245. Supported Platforms:
  246. ``Ascend`` ``GPU`` ``CPU``
  247. Examples:
  248. >>> input_x1 = Tensor(np.ones(shape=[2, 3]), mindspore.float32)
  249. >>> input_x2 = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
  250. >>> output = C.dot(input_x1, input_x2)
  251. >>> print(output)
  252. [[[3. 3.]]
  253. [[3. 3.]]]
  254. """
  255. shape_op = P.Shape()
  256. reshape_op = P.Reshape()
  257. transpose_op = P.Transpose()
  258. matmul_op = P.MatMul(False, False)
  259. x1_shape = shape_op(x1)
  260. x2_shape = shape_op(x2)
  261. _check_invalid_input(x1_shape, x2_shape)
  262. if len(x1_shape) > 2 or len(x2_shape) > 2:
  263. x2_shape_transpose = _get_transpose_shape(x2_shape)
  264. x2_transpose = transpose_op(x2, x2_shape_transpose)
  265. x1_reshape = reshape_op(x1, (-1, x1_shape[-1]))
  266. x2_reshape = reshape_op(x2_transpose, (x2_shape[-2], -1))
  267. mul_result = matmul_op(x1_reshape, x2_reshape)
  268. return reshape_op(mul_result, x1_shape[:-1] + x2_shape[:-2] + x2_shape[-1:])
  269. return matmul_op(x1, x2)
  270. @constexpr
  271. def _get_batch_size(x1_shape, x2_shape):
  272. """
  273. Get batch sizes from two inputs
  274. """
  275. if len(x1_shape) < 2 or len(x2_shape) < 2:
  276. raise ValueError("Require both inputs with rank >= 2.")
  277. return x1_shape[0], x2_shape[0]
  278. @constexpr
  279. def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
  280. """
  281. Check whether axes are valid and cast axes from tuple to list
  282. """
  283. if axes is None:
  284. if len(x2_shape) == 2:
  285. axes = [len(x1_shape) - 1, len(x2_shape) - 1]
  286. else:
  287. axes = [len(x1_shape) - 1, len(x2_shape) - 2]
  288. if isinstance(axes, (list, tuple)):
  289. if 0 in axes:
  290. raise ValueError("Batch dim cannot be used as in axes.")
  291. if len(axes) != 2:
  292. raise ValueError("Require two axes inputs, given less")
  293. if isinstance(axes, tuple):
  294. axes = list(axes)
  295. for sub_axes in axes:
  296. if isinstance(sub_axes, (list, tuple)):
  297. raise ValueError("Require dimension to be in any of those: None, int, (int, int).")
  298. # Reverse if axis < 0
  299. if axes[0] < 0:
  300. axes[0] += len(x1_shape)
  301. if axes[1] < 0:
  302. axes[1] += len(x2_shape)
  303. elif isinstance(axes, int):
  304. if axes == 0:
  305. raise ValueError("Batch dim cannot be used as in axes.")
  306. if axes < 0:
  307. axes = [axes + len(x1_shape), axes + len(x2_shape)]
  308. elif axes > len(x1_shape) or axes > len(x2_shape):
  309. raise ValueError(
  310. "Axes value too high for given input arrays dimensions.")
  311. else:
  312. axes = [axes, axes]
  313. else:
  314. raise ValueError(
  315. "Axes type must be one of those: int, tuple(int), list(int).")
  316. return axes
  317. @constexpr
  318. def _calc_new_shape_batchdot(shape, axes, position=0):
  319. """
  320. Calculate transpose and reshape parameters for input transformations,
  321. 'position' refers to whether tensor is first or second in the op.
  322. """
  323. axis = axes[position]
  324. contraction_axes = tuple([axis])
  325. prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
  326. free_axes = tuple(i for i in range(1, len(shape)) if i not in contraction_axes)
  327. free_dims = tuple(shape[i] for i in free_axes)
  328. prod_free = int(np.prod(free_dims))
  329. transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
  330. transpose_perm = tuple([0]) + transpose_perm
  331. new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
  332. new_shape = tuple([shape[0]]) + new_shape
  333. return new_shape, transpose_perm, free_dims
  334. @constexpr
  335. def _check_batch_size(x1_batch_size, x2_batch_size):
  336. """
  337. Check whether batch size of two inputs are the same
  338. """
  339. if x1_batch_size != x2_batch_size:
  340. raise ValueError("Require both inputs with the same batch sizes.")
  341. @constexpr
  342. def _get_output_shape(batch_size, x1_ret, x2_ret):
  343. """
  344. Compute output shape for batch dot
  345. """
  346. output_shape = tuple([batch_size]) + x1_ret + x2_ret
  347. return output_shape
  348. def batch_dot(x1, x2, axes=None):
  349. """
  350. Computation of batch dot product between samples in two tensors containing batch dims.
  351. Inputs:
  352. - **x1** (Tensor) - First tensor in Batch Dot op with datatype float16 or float32
  353. - **x2** (Tensor) - Second tensor in Batch Dot op with datatype float16 or float32. x2's datatype should
  354. be same as x1's.
  355. - **axes** (Union[int, tuple(int), list(int)]) - Single value or tuple/list of length 2 with dimensions
  356. specified for `a` and `b` each. If single value `N` passed, automatically picks up last N dims from
  357. `a` input shape and last N dims from `b` input shape in order as axes for each respectively.
  358. Outputs:
  359. Tensor, batch dot product of x1 and x2.
  360. Supported Platforms:
  361. ``Ascend`` ``GPU`` ``CPU``
  362. Examples:
  363. >>> input_x1 = Tensor(np.ones(shape=[2, 2, 3]), mindspore.float32)
  364. >>> input_x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
  365. >>> axes = (-1, -2)
  366. >>> output = C.batch_dot(input_x1, input_x2, axes)
  367. >>> print(output)
  368. [[[3. 3.]
  369. [3. 3.]]
  370. [[3. 3.]
  371. [3. 3.]]]
  372. """
  373. transpose_op = P.Transpose()
  374. batch_matmul_op = P.BatchMatMul()
  375. squeeze_one_op = P.Squeeze(1)
  376. squeeze_minus_one_op = P.Squeeze(-1)
  377. # input validity checks
  378. x1_shape = F.shape(x1)
  379. x2_shape = F.shape(x2)
  380. x1_dim_num = len(x1_shape)
  381. x2_dim_num = len(x2_shape)
  382. x1_type = F.dtype(x1)
  383. x2_type = F.dtype(x2)
  384. x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape)
  385. _typecheck_input(x1_type, x2_type)
  386. _check_batch_size(x1_batch_size, x2_batch_size)
  387. axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes)
  388. if x1_dim_num == 2:
  389. x1 = F.expand_dims(x1, 1)
  390. axes[0] += 1
  391. if x2_dim_num == 2:
  392. x2 = F.expand_dims(x2, 2)
  393. x1_shape = F.shape(x1)
  394. x2_shape = F.shape(x2)
  395. x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape_batchdot(x1_shape, axes, 0)
  396. x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape_batchdot(x2_shape, axes, 1)
  397. output_shape = _get_output_shape(x1_batch_size, x1_ret, x2_ret)
  398. x1_transposed = transpose_op(x1, x1_transpose_fwd)
  399. x2_transposed = transpose_op(x2, x2_transpose_fwd)
  400. x1_reshaped = F.reshape(x1_transposed, x1_reshape_fwd)
  401. x2_reshaped = F.reshape(x2_transposed, x2_reshape_fwd)
  402. # Batch matmal op part
  403. mul_result = batch_matmul_op(x1_reshaped, x2_reshaped)
  404. final_result = F.reshape(mul_result, output_shape)
  405. # if the original dims are expanded, restore them from 3 to 2
  406. if x1_dim_num == 2:
  407. final_result = squeeze_one_op(final_result)
  408. elif x2_dim_num == 2:
  409. final_result = squeeze_minus_one_op(final_result)
  410. return final_result