math_ops.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Math operations."""
from itertools import zip_longest
from collections import deque

import numpy as np

from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from .. import operations as P


@constexpr
def _check_validate_axis(axis, name):
    if isinstance(axis, (tuple, list)):
        for idx, item in enumerate(axis):
            validator.check_value_type("axis[%d]" % idx, item, [int], name)
    axis = validator.check_value_type('axis', axis, [int, tuple, list], name)
    return axis


@constexpr
def _check_validate_keepdims(keep_dims, name):
    keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name)
    return keep_dims


def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
    r"""
    Counts the number of nonzero elements across the given axis of the input tensor.

    Args:
        x (Tensor): Input data used to count non-zero elements.
            :math:`(N, *)` where :math:`*` means any number of additional dimensions.
        axis (Union[int, tuple(int), list(int)]): The dimensions to reduce. Only constant value is allowed.
            Default: (), reduce all dimensions.
        keep_dims (bool): If true, keep the reduced dimensions with length 1.
            If false, don't keep these dimensions. Default: False.
        dtype (Union[Number, mindspore.bool\_]): The data type of the output tensor. Only constant value is
            allowed. Default: mindspore.int32.

    Returns:
        Tensor, number of nonzero elements. The data type is `dtype`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # case 1: each value specified.
        >>> x = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32))
        >>> nonzero_num = ops.count_nonzero(x=x, axis=[0, 1], keep_dims=True, dtype=mindspore.int32)
        >>> print(nonzero_num)
        [[3]]
        >>> # case 2: all values are default.
        >>> nonzero_num = ops.count_nonzero(x=x)
        >>> print(nonzero_num)
        3
        >>> # case 3: axis value was specified as 0.
        >>> nonzero_num = ops.count_nonzero(x=x, axis=[0,])
        >>> print(nonzero_num)
        [1 2 0]
        >>> # case 4: axis value was specified as 1.
        >>> nonzero_num = ops.count_nonzero(x=x, axis=[1,])
        >>> print(nonzero_num)
        [1 2]
        >>> # case 5: keep_dims value was specified.
        >>> nonzero_num = ops.count_nonzero(x=x, keep_dims=True)
        >>> print(nonzero_num)
        [[3]]
        >>> # case 6: keep_dims and axis values were specified.
        >>> nonzero_num = ops.count_nonzero(x=x, axis=[0,], keep_dims=True)
        >>> print(nonzero_num)
        [[1 2 0]]
    """
    const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x')
    axis = _check_validate_axis(axis, "count_nonzero")
    keep_dims = _check_validate_keepdims(keep_dims, "count_nonzero")
    const_utils.check_type_valid(dtype, mstype.number_type + (mstype.bool_,), 'dtype')

    not_equal = P.NotEqual()
    cast = P.Cast()
    reduce_sum = P.ReduceSum(keep_dims)
    nonzero_bool = not_equal(x, 0)
    # ReduceSum only supports float16 and float32 tensors.
    nonzero_val = cast(nonzero_bool, mstype.float32)
    nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype)
    return nonzero_num


@constexpr
def _int_to_tuple_conv(axes):
    """
    Converts ints to tuples in input axes, expected by most validation checks.
    """
    for x in [0, 1]:
        if isinstance(axes[x], int):
            axes[x] = (axes[x],)
    return axes


@constexpr
def _check_axes(axes, prim_name=None):
    """
    Check for validity and type of axes passed to function.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    validator.check_value_type('axes', axes, [int, tuple, list], "tensor dot")
    if not isinstance(axes, int):
        axes = list(axes)  # to avoid immutability issues
        if len(axes) != 2:
            raise ValueError(f"{msg_prefix} dimension of 'axes' should be 2, but got 'axes': {axes}.")
        axes = _int_to_tuple_conv(axes)  # convert before length checks
        if len(axes[0]) != len(axes[1]):
            raise ValueError(f"{msg_prefix} first and second dim of 'axes' have to be the same size/length, "
                             f"but got 'axes': {axes}.")
        if len(axes[0]) != len(set(axes[0])) or len(axes[1]) != len(set(axes[1])):
            raise ValueError(f"{msg_prefix} 'axes' cannot have duplicating values, but got {axes}.")
    return axes


@constexpr
def _typecheck_input(x1_type, x2_type, prim_name=None):
    """
    Check input tensor types to be valid and confirm they are the same type.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    const_utils.check_type_valid(x1_type, [mstype.float32, mstype.float16], 'x1')
    const_utils.check_type_valid(x2_type, [mstype.float32, mstype.float16], 'x2')
    if x1_type != x2_type:
        raise TypeError(f"{msg_prefix} inputs must be the same type, but got x1_type: {x1_type} "
                        f"and x2_type: {x2_type}.")


@constexpr
def _axes_int_check(x1_shape, x2_shape, axes, prim_name=None):
    """
    Convert from single int axes to 2d tuple if required.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if isinstance(axes, int):
        if axes < 0:
            raise ValueError(f"{msg_prefix} 'axes' must be at least 0, but got {axes}.")
        if axes == 0:
            # outer product, no input validation required
            return [], []
        if axes > len(x1_shape) or axes > len(x2_shape):
            raise ValueError(f"{msg_prefix} 'axes' cannot be greater than the length of 'x1_shape' and 'x2_shape', "
                             f"but got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
        x1_ind = tuple(range(len(x1_shape))[-1 * axes:])
        x2_ind = tuple(range(len(x2_shape))[:axes])
        axes = tuple((x1_ind, x2_ind))
        axes = _int_to_tuple_conv(axes)
    return axes
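
# Illustrative note (not in the original module): for an int `axes`,
# _axes_int_check pairs the last `axes` dims of x1 with the first `axes`
# dims of x2. A minimal sketch of the expected conversion:
#
#     _axes_int_check((2, 3, 4), (3, 4, 5), 2)   # -> ((1, 2), (0, 1))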


@constexpr
def _validate_axes(x1_shape, x2_shape, axes, prim_name=None):
    """
    Checks for axes having the correct length according to input, for any value in axis
    being out of range with given shape and also checking for compatible axes values
    with given inputs.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    shapes = [x1_shape, x2_shape]

    # axis length check
    for ix_input, x_axes in enumerate(axes):
        axes_len = len(x_axes)
        shape_dim_len = len(shapes[ix_input])
        if axes_len > shape_dim_len:
            raise ValueError(f"{msg_prefix} length of element {x_axes} in 'axes' should be less than or equal to "
                             f"{shape_dim_len}, but got {axes_len}.")

    # axis values range check
    for ix_input, x_axes in enumerate(axes):
        comp_shape = shapes[ix_input]
        max_val = len(comp_shape) - 1
        min_val = -1 * len(comp_shape)
        for _, x_value in enumerate(x_axes):
            if not min_val <= x_value <= max_val:
                raise ValueError(f"{msg_prefix} value in 'axes' should be in range: [{min_val}, {max_val}], "
                                 f"but got {x_value}.")

    # check axis value with input shape - both ways for axis valid
    invalid_a = False
    invalid_b = False
    for i in range(len(axes[0])):  # sizes already validated
        if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]:
            invalid_a = True
        if x1_shape[axes[0][i]] != x2_shape[axes[1][len(axes[0]) - 1 - i]]:
            invalid_b = True
    if invalid_a and invalid_b:
        raise ValueError(f"{msg_prefix} 'i' should exist such that 'x1_shape[axes[0][i]]' is equal to "
                         f"'x2_shape[axes[1][i]]' or 'x2_shape[axes[1][len(axes[0])-1-i]]', but got "
                         f"'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}, 'axes': {axes}.")


@constexpr
def _calc_new_shape(shape, axes, position=0):
    """
    Calculate transpose and reshape parameters for input transformations,
    'position' refers to whether tensor is first or second in the op.
    """
    contraction_axes = tuple(i if i >= 0 else i + len(shape) for i in axes[position])
    prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
    free_axes = tuple(i for i in range(len(shape)) if i not in contraction_axes)
    free_dims = tuple(shape[i] for i in free_axes)
    prod_free = int(np.prod(free_dims))

    transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
    new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
    return new_shape, transpose_perm, free_dims
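
# Illustrative note (not in the original module): a worked example of the shape
# bookkeeping above, assuming a first input of shape (2, 3, 4) contracted on
# its last dim:
#
#     _calc_new_shape((2, 3, 4), ((2,), (0,)), position=0)
#     # -> new_shape (6, 4), transpose_perm (0, 1, 2), free_dims (2, 3)
#
# i.e. the free dims are flattened into a single matrix dim so that one MatMul
# can perform the whole contraction.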


def tensor_dot(x1, x2, axes):
    """
    Computes tensor contraction on arbitrary axes between tensors `x1` and `x2`.

    Contraction allows for the summation of products of elements of `x1` and `x2` on specified axes.
    The same number of axes must be specified for both `x1` and `x2`, and the values must be within the
    range of the number of dims of both `x1` and `x2`.

    Selected dims in both inputs must also match.

    axes = 0 leads to the outer product.
    axes = 1 leads to normal matrix multiplication when both inputs are 2D.
    axes = 1 is the same as axes = ((1,), (0,)) where both `x1` and `x2` are 2D.
    axes = 2 is the same as axes = ((1, 2), (0, 1)) where both `x1` and `x2` are 3D.

    Args:
        x1 (Tensor): First tensor in tensor_dot with datatype float16 or float32.
        x2 (Tensor): Second tensor in tensor_dot with datatype float16 or float32.
        axes (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]): Single value or
            tuple/list of length 2 with dimensions specified for `x1` and `x2` each. If a single value `N` is
            passed, the last `N` dims of the `x1` input shape and the first `N` dims of the `x2` input shape
            are automatically picked up, in order, as the axes for each respectively.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N + M)`, where :math:`N` and :math:`M` are the free
        axes not contracted in both inputs.

    Raises:
        TypeError: If `x1` or `x2` is not a Tensor.
        TypeError: If `axes` is not one of the following: int, tuple, list.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32)
        >>> output = ops.tensor_dot(input_x1, input_x2, ((0, 1), (1, 2)))
        >>> print(output)
        [[2. 2. 2.]
         [2. 2. 2.]
         [2. 2. 2.]]
    """
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    transpose_op = P.Transpose()
    matmul_op = P.MatMul(False, False)
    # input validity checks
    x1_shape = shape_op(x1)
    x2_shape = shape_op(x2)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)
    axes = _check_axes(axes, 'tensor_dot')
    _typecheck_input(x1_type, x2_type, 'tensor_dot')
    # input compatibility check & axes format update
    axes = _axes_int_check(x1_shape, x2_shape, axes, 'tensor_dot')
    _validate_axes(x1_shape, x2_shape, axes, 'tensor_dot')
    x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0)
    x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1)
    output_shape = x1_ret + x2_ret  # combine free axes from both inputs
    # run tensor_dot op
    x1_transposed = transpose_op(x1, x1_transpose_fwd)
    x2_transposed = transpose_op(x2, x2_transpose_fwd)
    x1_reshaped = reshape_op(x1_transposed, x1_reshape_fwd)
    x2_reshaped = reshape_op(x2_transposed, x2_reshape_fwd)
    mul_result = matmul_op(x1_reshaped, x2_reshaped)
    final_result = reshape_op(mul_result, output_shape)
    return final_result
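
# Illustrative note (not in the original module): tensor_dot follows the same
# contraction convention as np.tensordot, so the docstring example can be
# cross-checked in plain NumPy:
#
#     import numpy as np
#     out = np.tensordot(np.ones((1, 2, 3)), np.ones((3, 1, 2)), axes=((0, 1), (1, 2)))
#     # out.shape == (3, 3) and every element == 2.0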


@constexpr
def _check_invalid_input(x1_shape, x2_shape, prim_name=None):
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if len(x1_shape) < 2 or len(x2_shape) < 2:
        raise ValueError(f"{msg_prefix} inputs x1, x2 should have 'dimension >= 2', "
                         f"but got 'len(x1_shape)': ({len(x1_shape)}) and 'len(x2_shape)': ({len(x2_shape)}).")


@constexpr
def _typecheck_input_dot(x1_type, x2_type, prim_name=None):
    """
    Check input tensor types to be valid and confirm they are the same type for dot and batch dot ops.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    const_utils.check_type_valid(x1_type, [mstype.float16, mstype.float32], 'x1')
    const_utils.check_type_valid(x2_type, [mstype.float16, mstype.float32], 'x2')
    if x1_type != x2_type:
        raise TypeError(f"{msg_prefix} inputs must be the same type, but got "
                        f"x1_type: {x1_type} and x2_type: {x2_type}.")


@constexpr
def _get_transpose_shape(x2_shape):
    x2_shape_range = tuple(range(len(x2_shape)))
    x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
    return x2_shape_transpose
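
# Illustrative note (not in the original module): _get_transpose_shape moves
# x2's second-to-last (contraction) dim to the front, e.g. for a rank-4 x2:
#
#     _get_transpose_shape((5, 6, 7, 8))   # -> permutation (2, 0, 1, 3)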


def dot(x1, x2):
    """
    Computes the dot product between samples in two tensors.

    Args:
        x1 (Tensor): First tensor in Dot op with datatype float16 or float32.
            The rank must be greater than or equal to 2.
        x2 (Tensor): Second tensor in Dot op with datatype float16 or float32.
            The rank must be greater than or equal to 2.

    Outputs:
        Tensor, dot product of `x1` and `x2`.

    Raises:
        TypeError: If types of `x1` and `x2` are not the same.
        TypeError: If dtype of `x1` or `x2` is not float16 or float32.
        ValueError: If rank of `x1` or `x2` is less than 2.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x1 = Tensor(np.ones(shape=[2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input_x1, input_x2)
        >>> print(output)
        [[[3. 3.]]
         [[3. 3.]]]
        >>> print(output.shape)
        (2, 1, 2)
        >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input_x1, input_x2)
        >>> print(output)
        [[[[3. 3.]]
          [[3. 3.]]]]
        >>> print(output.shape)
        (1, 2, 1, 2)
        >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input_x1, input_x2)
        >>> print(output)
        [[[[3. 3.]
           [3. 3.]]
          [[3. 3.]
           [3. 3.]]]]
        >>> print(output.shape)
        (1, 2, 2, 2)
        >>> input_x1 = Tensor(np.ones(shape=[3, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[2, 1, 3, 2]), mindspore.float32)
        >>> output = ops.dot(input_x1, input_x2)
        >>> print(output)
        [[[[[3. 3.]]
           [[3. 3.]]]
          [[[3. 3.]]
           [[3. 3.]]]]
         [[[[3. 3.]]
           [[3. 3.]]]
          [[[3. 3.]]
           [[3. 3.]]]]
         [[[[3. 3.]]
           [[3. 3.]]]
          [[[3. 3.]]
           [[3. 3.]]]]]
        >>> print(output.shape)
        (3, 2, 2, 1, 2)
    """
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    transpose_op = P.Transpose()
    matmul_op = P.MatMul(False, False)
    x1_shape = shape_op(x1)
    x2_shape = shape_op(x2)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)
    _typecheck_input_dot(x1_type, x2_type, 'dot')
    _check_invalid_input(x1_shape, x2_shape, 'dot')

    if len(x1_shape) > 2 or len(x2_shape) > 2:
        x2_shape_transpose = _get_transpose_shape(x2_shape)
        x2_transpose = transpose_op(x2, x2_shape_transpose)
        x1_reshape = reshape_op(x1, (-1, x1_shape[-1]))
        x2_reshape = reshape_op(x2_transpose, (x2_shape[-2], -1))
        mul_result = matmul_op(x1_reshape, x2_reshape)
        return reshape_op(mul_result, x1_shape[:-1] + x2_shape[:-2] + x2_shape[-1:])
    return matmul_op(x1, x2)


@constexpr
def _get_batch_size(x1_shape, x2_shape, prim_name=None):
    """
    Get batch sizes from two inputs.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if len(x1_shape) < 2 or len(x2_shape) < 2:
        raise ValueError(f"{msg_prefix} inputs x1, x2 should have 'dimension >= 2', "
                         f"but got 'len(x1_shape)': ({len(x1_shape)}) and 'len(x2_shape)': ({len(x2_shape)}).")
    return x1_shape[0], x2_shape[0]


@constexpr
def _typecheck_input_batch_dot(x1_type, x2_type, prim_name=None):
    """
    Check input tensor types to be valid and confirm they are the same type for batch dot ops.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    const_utils.check_type_valid(x1_type, [mstype.float32], 'x1')
    const_utils.check_type_valid(x2_type, [mstype.float32], 'x2')
    if x1_type != x2_type:
        raise TypeError(f"{msg_prefix} inputs must be the same type, but got x1_type: {x1_type} and "
                        f"x2_type: {x2_type}.")


@constexpr
def _check_axes_for_batch_dot(x1_shape, x2_shape, axes, prim_name=None):
    """
    Check whether axes are valid and cast axes from tuple to list.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if axes is None:
        if len(x2_shape) == 2:
            axes = [len(x1_shape) - 1, len(x2_shape) - 1]
        else:
            axes = [len(x1_shape) - 1, len(x2_shape) - 2]

    if isinstance(axes, (list, tuple)):
        if 0 in axes:
            raise ValueError(f"{msg_prefix} 'axes' cannot contain 0, but got axes: {axes}.")
        if len(axes) != 2:
            raise ValueError(f"{msg_prefix} length of 'axes' must be equal to 2, but got {len(axes)}.")
        if isinstance(axes, tuple):
            axes = list(axes)
        validator.check_value_type('axes[0]', axes[0], [int], 'batch_dot')
        validator.check_value_type('axes[1]', axes[1], [int], 'batch_dot')
        # Reverse if axis < 0
        if axes[0] < 0:
            axes[0] += len(x1_shape)
        if axes[1] < 0:
            axes[1] += len(x2_shape)
        validator.check_non_negative_int(axes[0], 'reversed axes[0]', 'batch_dot')
        validator.check_non_negative_int(axes[1], 'reversed axes[1]', 'batch_dot')
        if axes[0] > len(x1_shape) or axes[1] > len(x2_shape):
            raise ValueError(f"{msg_prefix} axes[0] must be less than or equal to len(x1_shape), "
                             f"and axes[1] must be less than or equal to len(x2_shape). "
                             f"But got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
    elif isinstance(axes, int):
        if axes == 0:
            raise ValueError(f"{msg_prefix} 'axes' should not be equal to 0, but got {axes}.")
        if axes < 0:
            axes = [axes + len(x1_shape), axes + len(x2_shape)]
            validator.check_non_negative_int(axes[0], 'reversed axes', 'batch_dot')
        elif axes > len(x1_shape) or axes > len(x2_shape):
            raise ValueError(f"{msg_prefix} 'axes' cannot be greater than the length of 'x1_shape' and 'x2_shape', "
                             f"but got 'axes': {axes}, 'x1_shape': {x1_shape}, 'x2_shape': {x2_shape}.")
        else:
            axes = [axes, axes]
    else:
        raise ValueError(f"{msg_prefix} type of 'axes' must be one of those: int, tuple(int), list(int), "
                         f"but got {type(axes).__name__}.")
    return axes


@constexpr
def _calc_new_shape_batchdot(shape, axes, position=0):
    """
    Calculate transpose and reshape parameters for input transformations,
    'position' refers to whether tensor is first or second in the op.
    """
    axis = axes[position]
    contraction_axes = tuple([axis])
    prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
    free_axes = tuple(i for i in range(1, len(shape)) if i not in contraction_axes)
    free_dims = tuple(shape[i] for i in free_axes)
    prod_free = int(np.prod(free_dims))

    transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
    transpose_perm = tuple([0]) + transpose_perm
    new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
    new_shape = tuple([shape[0]]) + new_shape
    return new_shape, transpose_perm, free_dims
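
# Illustrative note (not in the original module): same bookkeeping as
# _calc_new_shape, except dim 0 (the batch dim) is always kept in front.
# For a first input of shape (6, 2, 3, 4) contracted on its last axis:
#
#     _calc_new_shape_batchdot((6, 2, 3, 4), [3, 2], position=0)
#     # -> new_shape (6, 6, 4), transpose_perm (0, 1, 2, 3), free_dims (2, 3)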


@constexpr
def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None):
    """
    Check whether the batch sizes of the two inputs are the same.
    """
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if x1_batch_size != x2_batch_size:
        raise ValueError(f"{msg_prefix} inputs 'x1', 'x2' should have the same batch sizes, but got "
                         f"'x1_batch_size': {x1_batch_size} and 'x2_batch_size': {x2_batch_size}.")


@constexpr
def _get_output_shape(batch_size, x1_ret, x2_ret):
    """
    Compute output shape for batch dot.
    """
    output_shape = tuple([batch_size]) + x1_ret + x2_ret
    return output_shape


def batch_dot(x1, x2, axes=None):
    """
    Computes the batch dot product between samples in two tensors containing batch dims.

    .. math::

        output = x1[batch, :] * x2[batch, :]

    Args:
        x1 (Tensor): First tensor in Batch Dot op with datatype float32 and the rank of `x1` must be greater
            than or equal to 2.
        x2 (Tensor): Second tensor in Batch Dot op with datatype float32. The datatype of `x2` should
            be same as `x1` and the rank of `x2` must be greater than or equal to 2.
        axes (Union[int, tuple(int), list(int)]): Single value or tuple/list of length 2 with dimensions
            specified for `x1` and `x2` each. If a single value `N` is passed, `N` is used as the axis
            for both `x1` and `x2`. Default: None.

    Outputs:
        Tensor, batch dot product of `x1` and `x2`. For example, the shape of the output for input `x1`
        of shape (batch, d1, axes, d2) and `x2` of shape (batch, d3, axes, d4) is (batch, d1, d2, d3, d4),
        where d1, d2, d3 and d4 mean any number.

    Raises:
        TypeError: If types of `x1` and `x2` are not the same.
        TypeError: If dtype of `x1` or `x2` is not float32.
        ValueError: If rank of `x1` or `x2` is less than 2.
        ValueError: If the batch dim is used in `axes`.
        ValueError: If len(axes) is less than 2.
        ValueError: If `axes` is not one of those: None, int, (int, int).
        ValueError: If `axes` reversed from negative int is too low for the dimensions of the input arrays.
        ValueError: If `axes` value is too high for the dimensions of the input arrays.
        ValueError: If batch sizes of `x1` and `x2` are not the same.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x1 = Tensor(np.ones(shape=[2, 2, 3]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> axes = (-1, -2)
        >>> output = ops.batch_dot(x1, x2, axes)
        >>> print(output)
        [[[3. 3.]
          [3. 3.]]
         [[3. 3.]
          [3. 3.]]]
        >>> x1 = Tensor(np.ones(shape=[2, 2]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> axes = (1, 2)
        >>> output = ops.batch_dot(x1, x2, axes)
        >>> print(output)
        [[2. 2. 2.]
         [2. 2. 2.]]
        >>> print(output.shape)
        (2, 3)
        >>> x1 = Tensor(np.ones(shape=[6, 2, 3, 4]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[6, 5, 4, 8]), mindspore.float32)
        >>> output = ops.batch_dot(x1, x2)
        >>> print(output.shape)
        (6, 2, 3, 5, 8)
        >>> x1 = Tensor(np.ones(shape=[2, 2, 4]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[2, 5, 4, 5]), mindspore.float32)
        >>> output = ops.batch_dot(x1, x2)
        >>> print(output.shape)
        (2, 2, 5, 5)
    """
    transpose_op = P.Transpose()
    batch_matmul_op = P.BatchMatMul()
    squeeze_one_op = P.Squeeze(1)
    squeeze_minus_one_op = P.Squeeze(-1)
    # input validity checks
    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)
    x1_dim_num = len(x1_shape)
    x2_dim_num = len(x2_shape)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)

    x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape, 'batch_dot')

    _typecheck_input_batch_dot(x1_type, x2_type, 'batch_dot')
    _check_batch_size(x1_batch_size, x2_batch_size, 'batch_dot')
    axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes, 'batch_dot')

    if x1_dim_num == 2:
        x1 = F.expand_dims(x1, 1)
        axes[0] += 1
    if x2_dim_num == 2:
        x2 = F.expand_dims(x2, 2)

    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)

    x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape_batchdot(x1_shape, axes, 0)
    x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape_batchdot(x2_shape, axes, 1)
    output_shape = _get_output_shape(x1_batch_size, x1_ret, x2_ret)

    x1_transposed = transpose_op(x1, x1_transpose_fwd)
    x2_transposed = transpose_op(x2, x2_transpose_fwd)
    x1_reshaped = F.reshape(x1_transposed, x1_reshape_fwd)
    x2_reshaped = F.reshape(x2_transposed, x2_reshape_fwd)

    # batch matmul op part
    mul_result = batch_matmul_op(x1_reshaped, x2_reshaped)
    final_result = F.reshape(mul_result, output_shape)

    # if the original dims were expanded, restore them from 3 to 2
    if x1_dim_num == 2:
        final_result = squeeze_one_op(final_result)
    elif x2_dim_num == 2:
        final_result = squeeze_minus_one_op(final_result)

    return final_result


@constexpr
def _check_same_type(dtype1, dtype2):
    return dtype1 == dtype2


@constexpr
def _max(*args):
    """Returns the maximum value."""
    return max(*args)


@constexpr
def _min(*args):
    """Returns the minimum value."""
    return min(*args)


@constexpr
def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
    """Infers the shape of the last two dimensions after performing matmul."""
    shape_rem = []
    if ndim1 >= 2:
        shape_rem.append(shape1[-2])
    if transpose_b:
        if ndim2 >= 2:
            shape_rem.append(shape2[-2])
    else:
        if ndim1 >= 1:
            shape_rem.append(shape2[-1])
    return tuple(shape_rem)


@constexpr
def _check_matmul_shapes(shape1, shape2, prim_name=None):
    """Checks shape1 and shape2 are valid to perform matmul, and returns output shape after broadcasting."""
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    ndim1, ndim2 = len(shape1), len(shape2)
    if ndim1 < 1 or ndim2 < 1:
        raise ValueError(f"{msg_prefix} dimension of input operands must be at least 1, but got "
                         f"the length of shape1: {ndim1}, the length of shape2: {ndim2}.")
    if ndim2 >= 2 and shape1[-1] != shape2[-2]:
        raise ValueError(f"{msg_prefix} shape1[-1] should be equal to shape2[-2] when the length of shape2 "
                         f"is greater than or equal to 2, but got shape1[-1]: {shape1[-1]}, "
                         f"shape2[-2]: {shape2[-2]}.")
    shape_out = deque()
    for items in zip_longest(reversed(shape1[:-2]), reversed(shape2[:-2]), fillvalue=1):
        max_size = max(items)
        if any(item not in (1, max_size) for item in items):
            raise ValueError(f"{msg_prefix} operands could not be broadcast together with shape1 {shape1} and "
                             f"shape2 {shape2}.")
        shape_out.appendleft(max_size)
    return tuple(shape_out)
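
# Illustrative note (not in the original module): the batch dims are broadcast
# right-to-left, NumPy style. A sketch of the expected behaviour:
#
#     _check_matmul_shapes((2, 1, 4, 3), (5, 3, 2))   # -> (2, 5)
#     # the full matmul output shape would then be (2, 5) + (4, 2) = (2, 5, 4, 2)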


@constexpr
def _tile_size(shape, out_shape, ndim):
    """Returns tile_size such that shape * tile_size == out_shape."""
    size = [1] * ndim
    for idx, (i, j) in enumerate(zip(shape, out_shape)):
        if i != j:
            size[idx] = j
    return tuple(size)
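
# Illustrative note (not in the original module): _tile_size computes the
# repetition factors F.tile needs so that shape * tile_size == out_shape:
#
#     _tile_size((1, 3), (4, 3), 2)   # -> (4, 1)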


@constexpr
def _check_need_broadcast(shape1, shape2):
    """Returns True if broadcast is necessary for batchmatmul."""
    return shape1[:-2] != shape2[:-2]


def _expand(x, ndim):
    """Expand x to ndim from axis, which can be 0 or -1."""
    while F.rank(x) < ndim:
        x = F.expand_dims(x, 0)
    return x


def _broadcast_to(x, shape_cur, shape_to, ndim_to):
    """Broadcasts x from shape_cur to shape_to."""
    size = _tile_size(shape_cur, shape_to, ndim_to)
    return F.tile(x, size)


def matmul(x1, x2, dtype=None):
    """
    Returns the matrix product of two arrays.

    Note:
        Numpy arguments `out`, `casting`, `order`, `subok`, `signature`, and `extobj` are
        not supported.
        On GPU, the supported dtypes are np.float16 and np.float32.
        On CPU, the supported dtypes are np.float16 and np.float32.

    Args:
        x1 (Tensor): Input tensor, scalar not allowed.
            The last dimension of `x1` must be the same size as the second last dimension of `x2`,
            and the shapes of `x1` and `x2` must be broadcastable.
        x2 (Tensor): Input tensor, scalar not allowed.
            The last dimension of `x1` must be the same size as the second last dimension of `x2`,
            and the shapes of `x1` and `x2` must be broadcastable.
        dtype (:class:`mindspore.dtype`, optional): Defaults to None. Overrides the dtype of the
            output Tensor.

    Returns:
        Tensor or scalar, the matrix product of the inputs. This is a scalar only
        when both `x1`, `x2` are 1-d vectors.

    Raises:
        ValueError: If the last dimension of `x1` is not the same size as the
            second-to-last dimension of `x2`, or if a scalar value is passed in.
        ValueError: If the shapes of `x1` and `x2` could not be broadcast together.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # case 1 : Reasonable application of broadcast mechanism
        >>> x1 = Tensor(np.arange(2*3*4).reshape(2, 3, 4), mindspore.float32)
        >>> x2 = Tensor(np.arange(4*5).reshape(4, 5), mindspore.float32)
        >>> output = ops.matmul(x1, x2)
        >>> print(output)
        [[[  70.   76.   82.   88.   94.]
          [ 190.  212.  234.  256.  278.]
          [ 310.  348.  386.  424.  462.]]
         [[ 430.  484.  538.  592.  646.]
          [ 550.  620.  690.  760.  830.]
          [ 670.  756.  842.  928. 1014.]]]
        >>> print(output.shape)
        (2, 3, 5)
        >>> # case 2 : the rank of `x2` is 1
        >>> x1 = Tensor(np.ones([1, 2]), mindspore.float32)
        >>> x2 = Tensor(np.ones([2,]), mindspore.float32)
        >>> output = ops.matmul(x1, x2)
        >>> print(output)
        [2.]
        >>> print(output.shape)
        (1,)
    """
    # performs type promotion
    dtype1 = F.dtype(x1)
    dtype2 = F.dtype(x2)
    if not _check_same_type(dtype1, dtype2):
        x1 = x1.astype(mstype.float32)
        x2 = x2.astype(mstype.float32)

    ndim1_orig, ndim2_orig = F.rank(x1), F.rank(x2)
    shape1_orig, shape2_orig = F.shape(x1), F.shape(x2)
    transpose_b = ndim2_orig == 1
    shape_backbone = _check_matmul_shapes(shape1_orig, shape2_orig, 'matmul')
    # infers the shape of the output
    shape_out = shape_backbone + _infer_shape_rem(shape1_orig, shape2_orig,
                                                  ndim1_orig, ndim2_orig, transpose_b)

    x1 = _expand(x1, 2)
    x2 = _expand(x2, 2)
    if F.rank(x2) == 2:
        if F.rank(x1) > 2:
            x1 = F.reshape(x1, (-1, shape1_orig[-1]))
        res = P.MatMul(False, transpose_b)(x1, x2)
    else:
        # broadcasts x1.shape[:-2] with x2.shape[:-2]
        ndim_aligned = _max(ndim1_orig, ndim2_orig)
        x1 = _expand(x1, ndim_aligned)
        x2 = _expand(x2, ndim_aligned)
        shape1_aligned, shape2_aligned = F.shape(x1), F.shape(x2)
        x1 = _broadcast_to(x1, shape1_aligned[:-2], shape_backbone, ndim_aligned)
        x2 = _broadcast_to(x2, shape2_aligned[:-2], shape_backbone, ndim_aligned)
        res = P.BatchMatMul(False, transpose_b)(x1, x2)
    if dtype is not None:
        res = res.astype(dtype)
    return F.reshape(res, shape_out)
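
# Illustrative note (not in the original module): matmul above is intended to
# mirror np.matmul's broadcasting, so case 1 of the docstring can be
# cross-checked in plain NumPy:
#
#     import numpy as np
#     out = np.matmul(np.arange(2 * 3 * 4).reshape(2, 3, 4), np.arange(4 * 5).reshape(4, 5))
#     # out.shape == (2, 3, 5) and out[0, 0].tolist() == [70, 76, 82, 88, 94]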


@constexpr
def _create_cummin_perm(axis, x_shape):
    """Ensure axis is in [-len(x_shape), len(x_shape) - 1] and build the transpose permutation."""
    len_axis = len(x_shape)
    if not isinstance(axis, int):
        raise TypeError(f"The data type of 'axis' should be int, but got {axis}.")
    if axis < -len_axis or axis > len_axis:
        raise ValueError(f"The value of axis should be in [{-len_axis}, {len_axis}], but got {axis}.")
    prem = [i for i in range(len_axis)]
    if axis < 0:
        axis = axis + len_axis
    prem[0], prem[axis] = axis, 0
    prem = tuple(prem)
    return prem


def cummin(x, axis):
    r"""
    Computes the cumulative minimum of the elements of `x` along the dimension `axis`,
    and the index location of each minimum value found in the dimension `axis`.
    It returns the cumulative minimum of elements and the indices.

    .. math::

        y_{i} = \min(x_{1}, x_{2}, \dots, x_{i})

    Args:
        x (Tensor): The input tensor, rank of `input_x` > 0.
        axis (int): The dimension to do the operation. The axis is in the range from -len(`input_x`.shape)
            to len(`input_x`.shape) - 1. When it's in the range from 0 to len(`input_x`.shape) - 1, it means
            starting from the first dimension and counting forwards; when it's less than 0, it means counting
            backwards from the last dimension. For example, -1 means the last dimension.

    Outputs:
        - **output** (Tensor) - The output tensor of the cumulative minimum of elements.
        - **indices** (Tensor) - The result tensor of the index of each minimum value found.

    Raises:
        TypeError: If `input_x` is not a Tensor.
        TypeError: If `axis` is not an int.
        ValueError: If `axis` is out of the range [-len(`input_x`.shape), len(`input_x`.shape) - 1].

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> a = Tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220], mindspore.float32)
        >>> output = ops.cummin(a, axis=0)
        >>> print(output[0])
        [-0.2284 -0.6628 -0.6628 -0.6628 -1.3298 -1.3298]
        >>> print(output[1])
        [0 1 1 1 4 4]
    """
    cummin_op = inner.Cummin(axis=0)
    if axis == 0:
        out1, out2 = cummin_op(x)
    else:
        transpose = P.Transpose()
        x_shape = P.Shape()(x)
        prem = _create_cummin_perm(axis, x_shape)
        x = transpose(x, prem)
        out1, out2 = cummin_op(x)
        out1 = transpose(out1, prem)
        out2 = transpose(out2, prem)
    return [out1, out2]
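
# Illustrative note (not in the original module): the transpose trick in cummin
# reduces the general case to Cummin(axis=0). The same idea in NumPy, for the
# value output only (indices omitted):
#
#     import numpy as np
#     x = np.array([[3., 1.], [2., 4.]])
#     perm = (1, 0)                      # swap the target axis 1 to the front
#     y = np.minimum.accumulate(x.transpose(perm), axis=0).transpose(perm)
#     # y == np.minimum.accumulate(x, axis=1) == [[3., 1.], [2., 2.]]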