From 9dba65c4247085618d2aef60986a4d283a5281c0 Mon Sep 17 00:00:00 2001 From: danishnxt Date: Mon, 9 Nov 2020 17:25:10 -0500 Subject: [PATCH] TensorDot Conv: P -> C --- .../gpu/math/tensordot_gpu_kernel.cc | 30 --- .../gpu/math/tensordot_gpu_kernel.h | 222 ------------------ mindspore/ops/_grad/grad_math_ops.py | 42 ---- mindspore/ops/composite/__init__.py | 5 +- mindspore/ops/composite/math_ops.py | 139 ++++++++++- mindspore/ops/operations/__init__.py | 2 +- mindspore/ops/operations/math_ops.py | 121 ---------- tests/st/ops/gpu/test_tensordot_op.py | 6 +- 8 files changed, 145 insertions(+), 422 deletions(-) delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.cc delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.cc deleted file mode 100644 index 62acfb46fd..0000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - TensorDot, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - TensorDotGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - TensorDot, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - TensorDotGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.h deleted file mode 100644 index 1d1de79d0d..0000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.h +++ /dev/null @@ -1,222 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TENSORDOT_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TENSORDOT_H_ - -#include -#include -#include -#include -#include "backend/kernel_compiler/gpu/gpu_kernel.h" -#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" -#include "backend/kernel_compiler/gpu/kernel_constants.h" -#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh" -#include "utils/convert_utils.h" - -namespace mindspore { -namespace kernel { -template -class TensorDotGpuKernel : public GpuKernel { - public: - TensorDotGpuKernel() - : batch_(0), - m_(0), - n_(0), - k_(0), - is_null_input_(false), - handle_(nullptr), - dtype_a_(CUDA_R_32F), - dtype_b_(CUDA_R_32F), - dtype_c_(CUDA_R_32F), - algo_(CUBLAS_GEMM_DEFAULT) {} - ~TensorDotGpuKernel() = default; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *x1_input = GetDeviceAddress(inputs, 0); - T *x2_input = GetDeviceAddress(inputs, 1); - size_t *x1_input_shape = GetDeviceAddress(workspace, 0); - size_t *x2_input_shape = GetDeviceAddress(workspace, 1); - size_t *x1_input_trans_axes = GetDeviceAddress(workspace, 2); - size_t *x2_input_trans_axes = GetDeviceAddress(workspace, 3); - // transposed interim values moved to workspace, then multiplied - T *x1_reshape = GetDeviceAddress(workspace, 4); - T *x2_reshape = GetDeviceAddress(workspace, 5); - T *output_addr = GetDeviceAddress(outputs, 0); - - // Transpose X1 - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(x1_input_shape, &x1_input_shape_[0], x1_input_shape_.size() * sizeof(size_t), - cudaMemcpyHostToDevice, 
reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync x1_input_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(x1_input_trans_axes, &x1_transpose_fwd_[0], x1_input_shape_.size() * sizeof(size_t), - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_axis_x1 failed"); - int size_x1 = SizeToInt(input_size_x1_ / sizeof(T)); - CalTranspose(size_x1, x1_input, x1_input_shape, x1_input_trans_axes, SizeToInt(x1_input_shape_.size()), x1_reshape, - reinterpret_cast(stream_ptr)); - - // Transpose X2 - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(x2_input_shape, &x2_input_shape_[0], (x2_input_shape_.size() * sizeof(size_t)), - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync x2_input_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(x2_input_trans_axes, &x2_transpose_fwd_[0], (x2_input_shape_.size() * sizeof(size_t)), - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_axis_x2 failed"); - int size_x2 = SizeToInt(input_size_x2_ / sizeof(T)); - CalTranspose(size_x2, x2_input, x2_input_shape, x2_input_trans_axes, SizeToInt(x2_input_shape_.size()), x2_reshape, - reinterpret_cast(stream_ptr)); - - // Matrix Mulitply interim transposed values with GEMM - const float alpha = 1; // constants for cublas API - const float beta = 0; - const int lda = SizeToInt(k_); - const int ldb = SizeToInt(n_); - const int ldc = n_; - auto stride_a = SizeToInt(m_ * k_); - auto stride_b = SizeToInt(k_ * n_); - auto stride_c = SizeToInt(m_ * n_); - try { - CHECK_CUBLAS_RET_WITH_EXCEPT( - cublasGemmStridedBatchedEx(handle_, CUBLAS_OP_N, CUBLAS_OP_N, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_), - &alpha, x2_reshape, dtype_b_, ldb, stride_b, x1_reshape, dtype_a_, lda, stride_a, - &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_), - "cublasSgemm Call Fail"); - } catch (const std::exception &e) { - MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke 
cublas cublasGemmStridedBatchedEx"; - } - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); - // checking for FP16 op, switch to Tensor Core if available - dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))); - dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))); - if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) { - MS_LOG(INFO) << "Input and output type is float16, allow to use Tensor Core operations if possible"; - algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP; - } - - auto tmp_x1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto tmp_x2_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - input_size_x1_ = sizeof(T); - for (size_t i = 0; i < tmp_x1_shape.size(); i++) { - x1_input_shape_.push_back(tmp_x1_shape[i]); - input_size_x1_ *= tmp_x1_shape[i]; - } - input_size_x2_ = sizeof(T); - for (size_t i = 0; i < tmp_x2_shape.size(); i++) { - x2_input_shape_.push_back(tmp_x2_shape[i]); - input_size_x2_ *= tmp_x2_shape[i]; - } - - // holding in temp values to convert to size_t vectors - std::vector x1_transpose_fwd_temp; - std::vector x1_transpose_me = GetAttr>(kernel_node, "x1_transpose_fwd"); - (void)std::transform(x1_transpose_me.begin(), x1_transpose_me.end(), std::back_inserter(x1_transpose_fwd_temp), - [](const int64_t &value) { return static_cast(value); }); - std::vector x2_transpose_fwd_temp; - std::vector x2_transpose_me = GetAttr>(kernel_node, "x2_transpose_fwd"); - (void)std::transform(x2_transpose_me.begin(), x2_transpose_me.end(), std::back_inserter(x2_transpose_fwd_temp), - [](const int64_t &value) { return static_cast(value); }); - for (size_t i = 0; i < x1_transpose_fwd_temp.size(); i++) { - 
x1_transpose_fwd_.push_back(x1_transpose_fwd_temp[i]); - } - - for (size_t i = 0; i < x2_transpose_fwd_temp.size(); i++) { - x2_transpose_fwd_.push_back(x2_transpose_fwd_temp[i]); - } - - // values to decide multiplication call specifics - std::vector x1_reshape_me = GetAttr>(kernel_node, "x1_reshape_fwd"); - (void)std::transform(x1_reshape_me.begin(), x1_reshape_me.end(), std::back_inserter(x1_reshape_fwd_), - [](const int64_t &value) { return static_cast(value); }); - std::vector x2_reshape_me = GetAttr>(kernel_node, "x2_reshape_fwd"); - (void)std::transform(x2_reshape_me.begin(), x2_reshape_me.end(), std::back_inserter(x2_reshape_fwd_), - [](const int64_t &value) { return static_cast(value); }); - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - output_size_ = sizeof(T); - for (size_t i = 0; i < output_shape.size(); i++) { - output_size_ *= output_shape[i]; - } - is_null_input_ = CHECK_NULL_INPUT(output_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "input is null"; - InitSizeLists(); - return true; - } - m_ = x1_reshape_fwd_[0]; - k_ = x1_reshape_fwd_[1]; - n_ = x2_reshape_fwd_[1]; - batch_ = 1; // kept as a single multiplication operation - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - size_t size_t_size = sizeof(size_t); - input_size_list_.push_back(input_size_x1_); - input_size_list_.push_back(input_size_x2_); - workspace_size_list_.push_back(x1_input_shape_.size() * size_t_size); - workspace_size_list_.push_back(x2_input_shape_.size() * size_t_size); - workspace_size_list_.push_back(x1_transpose_fwd_.size() * size_t_size); - workspace_size_list_.push_back(x2_transpose_fwd_.size() * size_t_size); - workspace_size_list_.push_back(input_size_x1_); - workspace_size_list_.push_back(input_size_x2_); - output_size_list_.push_back(output_size_); - } - - private: - size_t batch_; - size_t m_; - size_t n_; - size_t k_; - bool is_null_input_; - std::vector x1_input_shape_; - std::vector x2_input_shape_; 
- size_t input_size_x1_; - size_t input_size_x2_; - size_t output_size_; - std::vector x1_transpose_fwd_; // For transpose - std::vector x2_transpose_fwd_; - std::vector x1_reshape_fwd_; // For mulitplication shape - std::vector x2_reshape_fwd_; - cublasHandle_t handle_; - cudaDataType_t dtype_a_; - cudaDataType_t dtype_b_; - cudaDataType_t dtype_c_; - cublasGemmAlgo_t algo_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index 1267681c9b..dbd0df22b0 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -156,48 +156,6 @@ def bprop_batchmatmul(self): return bprop -@bprop_getters.register(P.TensorDot) -def bprop_tensordot(self): - """Grad definition for `TensorDot` operation.""" - mul_op_x1 = P.MatMul(transpose_a=False, transpose_b=True) - mul_op_x2 = P.MatMul(transpose_a=True, transpose_b=False) - invert_permutation_op = P.InvertPermutation() - transpose_op = P.Transpose() - reshape_op = P.Reshape() - - # pull transformation specifics from P.TensorDot class - x1_transpose_fwd = tuple(self.x1_transpose_fwd) - x2_transpose_fwd = tuple(self.x2_transpose_fwd) - x1_reshape_fwd = tuple(self.x1_reshape_fwd) - x2_reshape_fwd = tuple(self.x2_reshape_fwd) - dout_reshape = (self.x1_reshape_fwd[0], self.x2_reshape_fwd[1]) - - # precalculated in fwd pass due to easier computation - x1_reshape_back = tuple(self.x1_reshape_back) - x2_reshape_back = tuple(self.x2_reshape_back) - - def bprop(x1, x2, out, dout): - # reshape dy values to 2D for MatMul - dout_reshaped = reshape_op(dout, dout_reshape) - # transform inputs to forward pass equivalents - x1_transpose = transpose_op(x1, x1_transpose_fwd) - x2_transpose = transpose_op(x2, x2_transpose_fwd) - x1_reshape = reshape_op(x1_transpose, x1_reshape_fwd) - x2_reshape = 
reshape_op(x2_transpose, x2_reshape_fwd) - # calculate dx values for x1 and x2 - dx1_interim = mul_op_x1(dout_reshaped, x2_reshape) - dx2_interim = mul_op_x2(x1_reshape, dout_reshaped) - # reverse transformations on dx values for both inputs - dx1_reshape = reshape_op(dx1_interim, x1_reshape_back) - dx2_reshape = reshape_op(dx2_interim, x2_reshape_back) - dx1_retranspose_axes = invert_permutation_op(x1_transpose_fwd) - dx2_retranspose_axes = invert_permutation_op(x2_transpose_fwd) - dx1_transpose = transpose_op(dx1_reshape, dx1_retranspose_axes) - dx2_transpose = transpose_op(dx2_reshape, dx2_retranspose_axes) - return dx1_transpose, dx2_transpose - return bprop - - @bprop_getters.register(P.TensorAdd) def get_bprop_tensor_add(self): """Grad definition for `TensorAdd` operation.""" diff --git a/mindspore/ops/composite/__init__.py b/mindspore/ops/composite/__init__.py index 498f14f660..d29e9c075b 100644 --- a/mindspore/ops/composite/__init__.py +++ b/mindspore/ops/composite/__init__.py @@ -27,7 +27,7 @@ from .multitype_ops.add_impl import hyper_add from .multitype_ops.ones_like_impl import ones_like from .multitype_ops.zeros_like_impl import zeros_like from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial -from .math_ops import count_nonzero +from .math_ops import count_nonzero, TensorDot __all__ = [ @@ -50,4 +50,5 @@ __all__ = [ 'multinomial', 'clip_by_value', 'clip_by_global_norm', - 'count_nonzero'] + 'count_nonzero', + 'TensorDot'] diff --git a/mindspore/ops/composite/math_ops.py b/mindspore/ops/composite/math_ops.py index 03b8ef6f27..169a008604 100644 --- a/mindspore/ops/composite/math_ops.py +++ b/mindspore/ops/composite/math_ops.py @@ -13,6 +13,7 @@ # limitations under the License. 
# ============================================================================ """math Operations.""" +import numpy as np from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils from mindspore.common import dtype as mstype from mindspore._checkparam import Validator as validator @@ -20,7 +21,7 @@ from mindspore.ops.primitive import constexpr from mindspore.ops import functional as F from .. import operations as P - +# count_nonzero @constexpr def _check_validate_axis(axis, name): if isinstance(axis, (tuple, list)): @@ -73,3 +74,139 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32): nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype) return nonzero_num + +# TensorDot +@constexpr +def _int_to_tuple_conv(axes): + """ + Converts ints to tuples in input axes, expected by most validation checks. + """ + for x in [0, 1]: + if isinstance(axes[x], int): + axes[x] = (axes[x],) + return axes + + +@constexpr +def _check_axes(axes): + """ + Check for validity and type of axes passed to function. + """ + validator.check_value_type('axes', axes, [int, tuple, list], "TensorDot") + if not isinstance(axes, int): + axes = list(axes) # to avoid immutability issues + if len(axes) != 2: + raise ValueError("Require two axes inputs, given less") + axes = _int_to_tuple_conv(axes) # convert before length checks + if len(axes[0]) != len(axes[1]): + raise ValueError("Axes have to be the same size/length") + if len(axes[0]) != len(set(axes[0])) or len(axes[1]) != len(set(axes[1])): + raise ValueError("Axes cannot have duplicating values") + return axes + + +@constexpr +def _typecheck_input(x1_type, x2_type): + """ + Check input tensor types to be valid and confirm they are the same type. + """ + const_utils.check_valid_type(x1_type, [mstype.float32, mstype.float16], 'x1') + const_utils.check_valid_type(x2_type, [mstype.float32, mstype.float16], 'x2') + if x1_type != x2_type: + raise TypeError(f'Both Inputs must be the same Type. 
x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ') + + +@constexpr +def _validate_input(x1_shape, x2_shape, axes): + """ + Convert from single int axes to 2d tuple if required and check for validity with inputs. + """ + if isinstance(axes, int): + if axes <= 0: + # outer product, no input validation required + return ([], []) + if axes > len(x1_shape) or axes > len(x2_shape): + raise ValueError( + "Axes value too high for given input arrays dimensions.") + x1_ind = tuple(range(len(x1_shape))[-1 * axes:]) + x2_ind = tuple(range(len(x2_shape))[:axes]) + axes = tuple((x1_ind, x2_ind)) + axes = _int_to_tuple_conv(axes) + for i in range(len(axes[0])): # sizes already validated + if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]: + raise ValueError( + "Given Axes are incompatible with given input arrays") + return axes + + +@constexpr +def _calc_new_shape(shape, axes, position=0): + """ + Calculate transpose and reshape parameters for input transformations, + 'position' refers to whether tensor is first or second in the op. + """ + contraction_axes = tuple(i if i >= 0 else i + len(shape) for i in axes[position]) + prod_contraction = int(np.prod([shape[i] for i in contraction_axes])) + free_axes = tuple(i for i in range(len(shape)) if i not in contraction_axes) + free_dims = tuple(shape[i] for i in free_axes) + prod_free = int(np.prod(free_dims)) + + transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes + new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction) + return new_shape, transpose_perm, free_dims + + +def TensorDot(x1, x2, axes): + """ + Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`. + + Contraction allows for the summation of products of elements of `a` and `b` on specified axes. + The same number of axes must be specified for both x1 and x2, and values must be within range + of number of dims of both `a` and `b`. + + Selected dims in both inputs must also match. 
+ + axes = 0 leads to outer product, and axes = 1 leads to normal matrix multiplication. + axes = 1 is the same as axes = ((1,),(0,)) where length of input shape is 2 for both `x1` and `x2` + axes = 2 is the same as axes = ((1,2),(0,1)) where length of input shape is 3 for both `x1` and `x2` + + Inputs: + - **x1** (Tensor): First tensor in TensorDot op with datatype float16 or float32 + - **x2** (Tensor): Second tensor in TensorDot op with datatype float16 or float32 + - **axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]): Single value or + tuple/list of length 2 with dimensions specified for `x1` and `x2` each. If single value `N` passed, + automatically picks up last N dims from `x1` input shape and first N dims from `x2` input shape. + + Outputs: + Tensor, the shape of the output tensor is :math:`(N + M)`, where :math:`N` and :math:`M` are the free axes not + contracted in both inputs + + Examples: + >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32) + >>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32) + >>> output = C.TensorDot(input_x1, input_x2, ((0,1),(1,2))) + """ + shape_op = P.Shape() + reshape_op = P.Reshape() + transpose_op = P.Transpose() + matmul_op = P.MatMul(False, False) + # input validity checks + x1_shape = shape_op(x1) + x2_shape = shape_op(x2) + x1_type = F.dtype(x1) + x2_type = F.dtype(x2) + axes = _check_axes(axes) + _typecheck_input(x1_type, x2_type) + # input compatibility check & axes format update + axes = _validate_input(x1_shape, x2_shape, axes) + x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0) + x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1) + output_shape = x1_ret + x2_ret # combine free axes from both inputs + # run TensorDot op + x1_transposed = transpose_op(x1, x1_transpose_fwd) + x2_transposed = transpose_op(x2, x2_transpose_fwd) + x1_reshaped = reshape_op(x1_transposed, x1_reshape_fwd) + x2_reshaped = reshape_op(x2_transposed, 
x2_reshape_fwd) + mul_result = matmul_op(x1_reshaped, x2_reshaped) + final_result = reshape_op(mul_result, output_shape) + return final_result diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 738ae9c3c2..c28103360c 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -54,7 +54,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, - Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, TensorDot) + Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, RandomCategorical, StandardLaplace, Multinomial) diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 0d751f76cd..89590f879e 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -766,127 +766,6 @@ class BatchMatMul(MatMul): 'greater or equal to 3,' + f' while x size = {len(x)}, y size= {len(y)}') -class TensorDot(PrimitiveWithInfer): - """ - Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`. - - Contraction allows for the summation of products of elements of `a` and `b` on specified axes. - The same number of axes must be specified for both x1 and x2, and values must be within range - of number of dims of both `a` and `b`. - - Selected dims in both inputs must also match. - - axes = 0 leads to outer product, and axes = 1 leads to normal matrix multiplication. 
- axes = 1 is the same as axes = ((0,),(1,) where length of input shape is 2 for both `a` and `b` - axes = 2 is the same as axes = ((0,1),(1,2)) where length of input shape is 3 for both `a` and `b` - - Args: - **axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]): Single value or - tuple/list of length 2 with dimensions specified for `a` and `b` each. If single value `N` passed, - automatically picks up first N dims from `a` input shape and last N dims from `b` input shape. - - Inputs: - - **x1** (Tensor): First tensor in TensorDot op with datatype float16 or float32 - - **x2** (Tensor): Second tensor in TensorDot op with datatype float16 or float32 - - Outputs: - Tensor, the shape of the output tensor is :math:`(N + M)`. Where :math:`N` and :math:`M` are the free axes not - contracted in both inputs - - Examples: - >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32) - >>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32) - >>> tensordot = P.TensorDot(((0,1),(1,2))) - >>> output = tensordot(input_x1, input_x2) - """ - - @prim_attr_register - def __init__(self, axes): - self.axes = axes - validator.check_value_type('axes', axes, [int, tuple, list], self.name) - if not isinstance(self.axes, int): - self.axes = list(self.axes) # to avoid immutability issues - if len(self.axes) != 2: - raise ValueError("Require two axes inputs, given less") - self.int_to_tuple_conv() # convert before length checks - if len(self.axes[0]) != len(self.axes[1]): - raise ValueError("Axes have to be the same size/length") - if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])): - raise ValueError("Axes cannot have duplicating values") - self.add_prim_attr("axes", self.axes) - - def int_to_tuple_conv(self): - """ - Converts ints to tuples in input axes, expected by most validation checks. 
- """ - for x in [0, 1]: - if isinstance(self.axes[x], int): - self.axes[x] = (self.axes[x],) - - def check_input_axes(self, x1_shape, x2_shape): - """ - Convert from single int axes to 2d tuple if required and check for validity with inputs. - """ - if isinstance(self.axes, int): - if self.axes <= 0: - # outer product, no input validation required - self.axes = ([], []) # no axes selected for either - return - if self.axes > len(x1_shape) or self.axes > len(x2_shape): - raise ValueError( - "Axes value too high for given input arrays dimensions.") - x1_ind = tuple(range(len(x1_shape))[-1 * self.axes:]) - x2_ind = tuple(range(len(x2_shape))[:self.axes]) - self.axes = tuple((x1_ind, x2_ind)) - self.int_to_tuple_conv() - for i in range(len(self.axes[0])): # sizes already validated - if x1_shape[self.axes[0][i]] != x2_shape[self.axes[1][i]]: - raise ValueError( - "Given Axes are incompatible with given input arrays") - - def calc_new_shape(self, shape, position=0): - """ - Calculate transpose and reshape parameters for input transformations, - 'position' refers to whether tensor is first or second in the op. - """ - contraction_axes = [i if i >= 0 else i + len(shape) for i in self.axes[position]] - prod_contraction = int(np.prod([shape[i] for i in contraction_axes])) - free_axes = [i for i in range(len(shape)) if i not in contraction_axes] - free_dims = [shape[i] for i in free_axes] - prod_free = int(np.prod(free_dims)) - - transpose_perm = list(contraction_axes) + free_axes if position else free_axes + list(contraction_axes) - new_shape = [prod_contraction, prod_free] if position else [prod_free, prod_contraction] - return new_shape, transpose_perm, free_dims - - def generate_transform_dims(self, x1_shape, x2_shape): - """ - Initiate calls for input transform calculations and calculate paramters for output - and for backprop tranformations. 
- """ - self.x1_reshape_fwd, self.x1_transpose_fwd, x1_ret = self.calc_new_shape(x1_shape, 0) - self.x2_reshape_fwd, self.x2_transpose_fwd, x2_ret = self.calc_new_shape(x2_shape, 1) - self.output_shape = x1_ret + x2_ret # combine free axes from both inputs - self.x1_reshape_back = [x1_shape[x] for x in self.x1_transpose_fwd] - self.x2_reshape_back = [x2_shape[x] for x in self.x2_transpose_fwd] - - def infer_shape(self, x1, x2): - self.check_input_axes(x1, x2) - self.generate_transform_dims(x1, x2) - # processed parameters for reading directly into kernel - self.add_prim_attr('x1_transpose_fwd', self.x1_transpose_fwd) - self.add_prim_attr('x2_transpose_fwd', self.x2_transpose_fwd) - self.add_prim_attr('x1_reshape_fwd', self.x1_reshape_fwd) - self.add_prim_attr('x2_reshape_fwd', self.x2_reshape_fwd) - return self.output_shape - - def infer_dtype(self, x1, x2): - args = {"x1": x1, "x2": x2} - valid_dtypes = [mstype.float16, mstype.float32] - validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) - return x1 - - class CumSum(PrimitiveWithInfer): """ Computes the cumulative sum of input tensor along axis. 
diff --git a/tests/st/ops/gpu/test_tensordot_op.py b/tests/st/ops/gpu/test_tensordot_op.py index 2a0e5991c8..878d995739 100644 --- a/tests/st/ops/gpu/test_tensordot_op.py +++ b/tests/st/ops/gpu/test_tensordot_op.py @@ -20,17 +20,16 @@ import mindspore from mindspore import Tensor import mindspore.nn as nn import mindspore.context as context -from mindspore.ops import operations as P from mindspore.ops import composite as C class NetTensorDot(nn.Cell): def __init__(self, axes): super(NetTensorDot, self).__init__() - self.td = P.TensorDot(axes) + self.axes = axes def construct(self, x, y): - return self.td(x, y) + return C.TensorDot(x, y, self.axes) class GradNetwork(nn.Cell): @@ -183,6 +182,7 @@ def test_tensor_dot_outer(): x2_tensor = Tensor(x2, dtype=mindspore.float32) network = NetTensorDot(axes) + ms_result_np = network(x1_tensor, x2_tensor).asnumpy() np_result = np.tensordot(x1, x2, axes) np.testing.assert_array_almost_equal(ms_result_np, np_result)