From: @danishnxt Reviewed-by: @tom__chen,@robingrosman Signed-off-by: @robingrosmantags/v1.1.0
| @@ -1,30 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE( | |||||
| TensorDot, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| TensorDotGpuKernel, float) | |||||
| MS_REG_GPU_KERNEL_ONE( | |||||
| TensorDot, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||||
| TensorDotGpuKernel, half) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -1,222 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TENSORDOT_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TENSORDOT_H_ | |||||
| #include <cublas_v2.h> | |||||
| #include <cuda_runtime_api.h> | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh" | |||||
| #include "utils/convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T> | |||||
| class TensorDotGpuKernel : public GpuKernel { | |||||
| public: | |||||
| TensorDotGpuKernel() | |||||
| : batch_(0), | |||||
| m_(0), | |||||
| n_(0), | |||||
| k_(0), | |||||
| is_null_input_(false), | |||||
| handle_(nullptr), | |||||
| dtype_a_(CUDA_R_32F), | |||||
| dtype_b_(CUDA_R_32F), | |||||
| dtype_c_(CUDA_R_32F), | |||||
| algo_(CUBLAS_GEMM_DEFAULT) {} | |||||
| ~TensorDotGpuKernel() = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||||
| if (is_null_input_) { | |||||
| return true; | |||||
| } | |||||
| T *x1_input = GetDeviceAddress<T>(inputs, 0); | |||||
| T *x2_input = GetDeviceAddress<T>(inputs, 1); | |||||
| size_t *x1_input_shape = GetDeviceAddress<size_t>(workspace, 0); | |||||
| size_t *x2_input_shape = GetDeviceAddress<size_t>(workspace, 1); | |||||
| size_t *x1_input_trans_axes = GetDeviceAddress<size_t>(workspace, 2); | |||||
| size_t *x2_input_trans_axes = GetDeviceAddress<size_t>(workspace, 3); | |||||
| // transposed interim values moved to workspace, then multiplied | |||||
| T *x1_reshape = GetDeviceAddress<T>(workspace, 4); | |||||
| T *x2_reshape = GetDeviceAddress<T>(workspace, 5); | |||||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||||
| // Transpose X1 | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(x1_input_shape, &x1_input_shape_[0], x1_input_shape_.size() * sizeof(size_t), | |||||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||||
| "cudaMemcpyAsync x1_input_shape failed"); | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(x1_input_trans_axes, &x1_transpose_fwd_[0], x1_input_shape_.size() * sizeof(size_t), | |||||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||||
| "cudaMemcpyAsync input_axis_x1 failed"); | |||||
| int size_x1 = SizeToInt(input_size_x1_ / sizeof(T)); | |||||
| CalTranspose(size_x1, x1_input, x1_input_shape, x1_input_trans_axes, SizeToInt(x1_input_shape_.size()), x1_reshape, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| // Transpose X2 | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(x2_input_shape, &x2_input_shape_[0], (x2_input_shape_.size() * sizeof(size_t)), | |||||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||||
| "cudaMemcpyAsync x2_input_shape failed"); | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpyAsync(x2_input_trans_axes, &x2_transpose_fwd_[0], (x2_input_shape_.size() * sizeof(size_t)), | |||||
| cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||||
| "cudaMemcpyAsync input_axis_x2 failed"); | |||||
| int size_x2 = SizeToInt(input_size_x2_ / sizeof(T)); | |||||
| CalTranspose(size_x2, x2_input, x2_input_shape, x2_input_trans_axes, SizeToInt(x2_input_shape_.size()), x2_reshape, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| // Matrix Mulitply interim transposed values with GEMM | |||||
| const float alpha = 1; // constants for cublas API | |||||
| const float beta = 0; | |||||
| const int lda = SizeToInt(k_); | |||||
| const int ldb = SizeToInt(n_); | |||||
| const int ldc = n_; | |||||
| auto stride_a = SizeToInt(m_ * k_); | |||||
| auto stride_b = SizeToInt(k_ * n_); | |||||
| auto stride_c = SizeToInt(m_ * n_); | |||||
| try { | |||||
| CHECK_CUBLAS_RET_WITH_EXCEPT( | |||||
| cublasGemmStridedBatchedEx(handle_, CUBLAS_OP_N, CUBLAS_OP_N, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_), | |||||
| &alpha, x2_reshape, dtype_b_, ldb, stride_b, x1_reshape, dtype_a_, lda, stride_a, | |||||
| &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_), | |||||
| "cublasSgemm Call Fail"); | |||||
| } catch (const std::exception &e) { | |||||
| MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas cublasGemmStridedBatchedEx"; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); | |||||
| // checking for FP16 op, switch to Tensor Core if available | |||||
| dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))); | |||||
| dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))); | |||||
| if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) { | |||||
| MS_LOG(INFO) << "Input and output type is float16, allow to use Tensor Core operations if possible"; | |||||
| algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP; | |||||
| } | |||||
| auto tmp_x1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| auto tmp_x2_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| input_size_x1_ = sizeof(T); | |||||
| for (size_t i = 0; i < tmp_x1_shape.size(); i++) { | |||||
| x1_input_shape_.push_back(tmp_x1_shape[i]); | |||||
| input_size_x1_ *= tmp_x1_shape[i]; | |||||
| } | |||||
| input_size_x2_ = sizeof(T); | |||||
| for (size_t i = 0; i < tmp_x2_shape.size(); i++) { | |||||
| x2_input_shape_.push_back(tmp_x2_shape[i]); | |||||
| input_size_x2_ *= tmp_x2_shape[i]; | |||||
| } | |||||
| // holding in temp values to convert to size_t vectors | |||||
| std::vector<int> x1_transpose_fwd_temp; | |||||
| std::vector<int64_t> x1_transpose_me = GetAttr<std::vector<int64_t>>(kernel_node, "x1_transpose_fwd"); | |||||
| (void)std::transform(x1_transpose_me.begin(), x1_transpose_me.end(), std::back_inserter(x1_transpose_fwd_temp), | |||||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||||
| std::vector<int> x2_transpose_fwd_temp; | |||||
| std::vector<int64_t> x2_transpose_me = GetAttr<std::vector<int64_t>>(kernel_node, "x2_transpose_fwd"); | |||||
| (void)std::transform(x2_transpose_me.begin(), x2_transpose_me.end(), std::back_inserter(x2_transpose_fwd_temp), | |||||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||||
| for (size_t i = 0; i < x1_transpose_fwd_temp.size(); i++) { | |||||
| x1_transpose_fwd_.push_back(x1_transpose_fwd_temp[i]); | |||||
| } | |||||
| for (size_t i = 0; i < x2_transpose_fwd_temp.size(); i++) { | |||||
| x2_transpose_fwd_.push_back(x2_transpose_fwd_temp[i]); | |||||
| } | |||||
| // values to decide multiplication call specifics | |||||
| std::vector<int64_t> x1_reshape_me = GetAttr<std::vector<int64_t>>(kernel_node, "x1_reshape_fwd"); | |||||
| (void)std::transform(x1_reshape_me.begin(), x1_reshape_me.end(), std::back_inserter(x1_reshape_fwd_), | |||||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||||
| std::vector<int64_t> x2_reshape_me = GetAttr<std::vector<int64_t>>(kernel_node, "x2_reshape_fwd"); | |||||
| (void)std::transform(x2_reshape_me.begin(), x2_reshape_me.end(), std::back_inserter(x2_reshape_fwd_), | |||||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||||
| output_size_ = sizeof(T); | |||||
| for (size_t i = 0; i < output_shape.size(); i++) { | |||||
| output_size_ *= output_shape[i]; | |||||
| } | |||||
| is_null_input_ = CHECK_NULL_INPUT(output_shape); | |||||
| if (is_null_input_) { | |||||
| MS_LOG(WARNING) << "input is null"; | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| m_ = x1_reshape_fwd_[0]; | |||||
| k_ = x1_reshape_fwd_[1]; | |||||
| n_ = x2_reshape_fwd_[1]; | |||||
| batch_ = 1; // kept as a single multiplication operation | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| size_t size_t_size = sizeof(size_t); | |||||
| input_size_list_.push_back(input_size_x1_); | |||||
| input_size_list_.push_back(input_size_x2_); | |||||
| workspace_size_list_.push_back(x1_input_shape_.size() * size_t_size); | |||||
| workspace_size_list_.push_back(x2_input_shape_.size() * size_t_size); | |||||
| workspace_size_list_.push_back(x1_transpose_fwd_.size() * size_t_size); | |||||
| workspace_size_list_.push_back(x2_transpose_fwd_.size() * size_t_size); | |||||
| workspace_size_list_.push_back(input_size_x1_); | |||||
| workspace_size_list_.push_back(input_size_x2_); | |||||
| output_size_list_.push_back(output_size_); | |||||
| } | |||||
| private: | |||||
| size_t batch_; | |||||
| size_t m_; | |||||
| size_t n_; | |||||
| size_t k_; | |||||
| bool is_null_input_; | |||||
| std::vector<size_t> x1_input_shape_; | |||||
| std::vector<size_t> x2_input_shape_; | |||||
| size_t input_size_x1_; | |||||
| size_t input_size_x2_; | |||||
| size_t output_size_; | |||||
| std::vector<size_t> x1_transpose_fwd_; // For transpose | |||||
| std::vector<size_t> x2_transpose_fwd_; | |||||
| std::vector<int> x1_reshape_fwd_; // For mulitplication shape | |||||
| std::vector<int> x2_reshape_fwd_; | |||||
| cublasHandle_t handle_; | |||||
| cudaDataType_t dtype_a_; | |||||
| cudaDataType_t dtype_b_; | |||||
| cudaDataType_t dtype_c_; | |||||
| cublasGemmAlgo_t algo_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif | |||||
| @@ -156,48 +156,6 @@ def bprop_batchmatmul(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(P.TensorDot) | |||||
| def bprop_tensordot(self): | |||||
| """Grad definition for `TensorDot` operation.""" | |||||
| mul_op_x1 = P.MatMul(transpose_a=False, transpose_b=True) | |||||
| mul_op_x2 = P.MatMul(transpose_a=True, transpose_b=False) | |||||
| invert_permutation_op = P.InvertPermutation() | |||||
| transpose_op = P.Transpose() | |||||
| reshape_op = P.Reshape() | |||||
| # pull transformation specifics from P.TensorDot class | |||||
| x1_transpose_fwd = tuple(self.x1_transpose_fwd) | |||||
| x2_transpose_fwd = tuple(self.x2_transpose_fwd) | |||||
| x1_reshape_fwd = tuple(self.x1_reshape_fwd) | |||||
| x2_reshape_fwd = tuple(self.x2_reshape_fwd) | |||||
| dout_reshape = (self.x1_reshape_fwd[0], self.x2_reshape_fwd[1]) | |||||
| # precalculated in fwd pass due to easier computation | |||||
| x1_reshape_back = tuple(self.x1_reshape_back) | |||||
| x2_reshape_back = tuple(self.x2_reshape_back) | |||||
| def bprop(x1, x2, out, dout): | |||||
| # reshape dy values to 2D for MatMul | |||||
| dout_reshaped = reshape_op(dout, dout_reshape) | |||||
| # transform inputs to forward pass equivalents | |||||
| x1_transpose = transpose_op(x1, x1_transpose_fwd) | |||||
| x2_transpose = transpose_op(x2, x2_transpose_fwd) | |||||
| x1_reshape = reshape_op(x1_transpose, x1_reshape_fwd) | |||||
| x2_reshape = reshape_op(x2_transpose, x2_reshape_fwd) | |||||
| # calculate dx values for x1 and x2 | |||||
| dx1_interim = mul_op_x1(dout_reshaped, x2_reshape) | |||||
| dx2_interim = mul_op_x2(x1_reshape, dout_reshaped) | |||||
| # reverse transformations on dx values for both inputs | |||||
| dx1_reshape = reshape_op(dx1_interim, x1_reshape_back) | |||||
| dx2_reshape = reshape_op(dx2_interim, x2_reshape_back) | |||||
| dx1_retranspose_axes = invert_permutation_op(x1_transpose_fwd) | |||||
| dx2_retranspose_axes = invert_permutation_op(x2_transpose_fwd) | |||||
| dx1_transpose = transpose_op(dx1_reshape, dx1_retranspose_axes) | |||||
| dx2_transpose = transpose_op(dx2_reshape, dx2_retranspose_axes) | |||||
| return dx1_transpose, dx2_transpose | |||||
| return bprop | |||||
| @bprop_getters.register(P.TensorAdd) | @bprop_getters.register(P.TensorAdd) | ||||
| def get_bprop_tensor_add(self): | def get_bprop_tensor_add(self): | ||||
| """Grad definition for `TensorAdd` operation.""" | """Grad definition for `TensorAdd` operation.""" | ||||
| @@ -27,7 +27,7 @@ from .multitype_ops.add_impl import hyper_add | |||||
| from .multitype_ops.ones_like_impl import ones_like | from .multitype_ops.ones_like_impl import ones_like | ||||
| from .multitype_ops.zeros_like_impl import zeros_like | from .multitype_ops.zeros_like_impl import zeros_like | ||||
| from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial | from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial | ||||
| from .math_ops import count_nonzero | |||||
| from .math_ops import count_nonzero, TensorDot | |||||
| __all__ = [ | __all__ = [ | ||||
| @@ -50,4 +50,5 @@ __all__ = [ | |||||
| 'multinomial', | 'multinomial', | ||||
| 'clip_by_value', | 'clip_by_value', | ||||
| 'clip_by_global_norm', | 'clip_by_global_norm', | ||||
| 'count_nonzero'] | |||||
| 'count_nonzero', | |||||
| 'TensorDot'] | |||||
| @@ -13,6 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """math Operations.""" | """math Operations.""" | ||||
| import numpy as np | |||||
| from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils | from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils | ||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore._checkparam import Validator as validator | from mindspore._checkparam import Validator as validator | ||||
| @@ -20,7 +21,7 @@ from mindspore.ops.primitive import constexpr | |||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from .. import operations as P | from .. import operations as P | ||||
| # count_nonzero | |||||
| @constexpr | @constexpr | ||||
| def _check_validate_axis(axis, name): | def _check_validate_axis(axis, name): | ||||
| if isinstance(axis, (tuple, list)): | if isinstance(axis, (tuple, list)): | ||||
| @@ -73,3 +74,139 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32): | |||||
| nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype) | nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype) | ||||
| return nonzero_num | return nonzero_num | ||||
| # TensorDot | |||||
| @constexpr | |||||
| def _int_to_tuple_conv(axes): | |||||
| """ | |||||
| Converts ints to tuples in input axes, expected by most validation checks. | |||||
| """ | |||||
| for x in [0, 1]: | |||||
| if isinstance(axes[x], int): | |||||
| axes[x] = (axes[x],) | |||||
| return axes | |||||
| @constexpr | |||||
| def _check_axes(axes): | |||||
| """ | |||||
| Check for validity and type of axes passed to function. | |||||
| """ | |||||
| validator.check_value_type('axes', axes, [int, tuple, list], "TensorDot") | |||||
| if not isinstance(axes, int): | |||||
| axes = list(axes) # to avoid immutability issues | |||||
| if len(axes) != 2: | |||||
| raise ValueError("Require two axes inputs, given less") | |||||
| axes = _int_to_tuple_conv(axes) # convert before length checks | |||||
| if len(axes[0]) != len(axes[1]): | |||||
| raise ValueError("Axes have to be the same size/length") | |||||
| if len(axes[0]) != len(set(axes[0])) or len(axes[1]) != len(set(axes[1])): | |||||
| raise ValueError("Axes cannot have duplicating values") | |||||
| return axes | |||||
| @constexpr | |||||
| def _typecheck_input(x1_type, x2_type): | |||||
| """ | |||||
| Check input tensor types to be valid and confirm they are the same type. | |||||
| """ | |||||
| const_utils.check_valid_type(x1_type, [mstype.float32, mstype.float16], 'x1') | |||||
| const_utils.check_valid_type(x2_type, [mstype.float32, mstype.float16], 'x2') | |||||
| if x1_type != x2_type: | |||||
| raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ') | |||||
| @constexpr | |||||
| def _validate_input(x1_shape, x2_shape, axes): | |||||
| """ | |||||
| Convert from single int axes to 2d tuple if required and check for validity with inputs. | |||||
| """ | |||||
| if isinstance(axes, int): | |||||
| if axes <= 0: | |||||
| # outer product, no input validation required | |||||
| return ([], []) | |||||
| if axes > len(x1_shape) or axes > len(x2_shape): | |||||
| raise ValueError( | |||||
| "Axes value too high for given input arrays dimensions.") | |||||
| x1_ind = tuple(range(len(x1_shape))[-1 * axes:]) | |||||
| x2_ind = tuple(range(len(x2_shape))[:axes]) | |||||
| axes = tuple((x1_ind, x2_ind)) | |||||
| axes = _int_to_tuple_conv(axes) | |||||
| for i in range(len(axes[0])): # sizes already validated | |||||
| if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]: | |||||
| raise ValueError( | |||||
| "Given Axes are incompatible with given input arrays") | |||||
| return axes | |||||
| @constexpr | |||||
| def _calc_new_shape(shape, axes, position=0): | |||||
| """ | |||||
| Calculate transpose and reshape parameters for input transformations, | |||||
| 'position' refers to whether tensor is first or second in the op. | |||||
| """ | |||||
| contraction_axes = tuple(i if i >= 0 else i + len(shape) for i in axes[position]) | |||||
| prod_contraction = int(np.prod([shape[i] for i in contraction_axes])) | |||||
| free_axes = tuple(i for i in range(len(shape)) if i not in contraction_axes) | |||||
| free_dims = tuple(shape[i] for i in free_axes) | |||||
| prod_free = int(np.prod(free_dims)) | |||||
| transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes | |||||
| new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction) | |||||
| return new_shape, transpose_perm, free_dims | |||||
| def TensorDot(x1, x2, axes): | |||||
| """ | |||||
| Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`. | |||||
| Contraction allows for the summation of products of elements of `a` and `b` on specified axes. | |||||
| The same number of axes must be specified for both x1 and x2, and values must be within range | |||||
| of number of dims of both `a` and `b`. | |||||
| Selected dims in both inputs must also match. | |||||
| axes = 0 leads to outer product, and axes = 1 leads to normal matrix multiplication. | |||||
| axes = 1 is the same as axes = ((0,),(1,) where length of input shape is 2 for both `a` and `b` | |||||
| axes = 2 is the same as axes = ((0,1),(1,2)) where length of input shape is 3 for both `a` and `b` | |||||
| Inputs: | |||||
| - **x1** (Tensor): First tensor in TensorDot op with datatype float16 or float32 | |||||
| - **x2** (Tensor): Second tensor in TensorDot op with datatype float16 or float32 | |||||
| - **axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]): Single value or | |||||
| tuple/list of length 2 with dimensions specified for `a` and `b` each. If single value `N` passed, | |||||
| automatically picks up first N dims from `a` input shape and last N dims from `b` input shape. | |||||
| Outputs: | |||||
| Tensor, the shape of the output tensor is :math:`(N + M)`. Where :math:`N` and :math:`M` are the free axes not | |||||
| contracted in both inputs | |||||
| Examples: | |||||
| >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32) | |||||
| >>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32) | |||||
| >>> output = C.TensorDot(input_x1, input_x2, ((0,1),(1,2))) | |||||
| """ | |||||
| shape_op = P.Shape() | |||||
| reshape_op = P.Reshape() | |||||
| transpose_op = P.Transpose() | |||||
| matmul_op = P.MatMul(False, False) | |||||
| # input validity checks | |||||
| x1_shape = shape_op(x1) | |||||
| x2_shape = shape_op(x2) | |||||
| x1_type = F.dtype(x1) | |||||
| x2_type = F.dtype(x2) | |||||
| axes = _check_axes(axes) | |||||
| _typecheck_input(x1_type, x2_type) | |||||
| # input compability check & axes format update | |||||
| axes = _validate_input(x1_shape, x2_shape, axes) | |||||
| x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0) | |||||
| x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1) | |||||
| output_shape = x1_ret + x2_ret # combine free axes from both inputs | |||||
| # run TensorDot op | |||||
| x1_transposed = transpose_op(x1, x1_transpose_fwd) | |||||
| x2_transposed = transpose_op(x2, x2_transpose_fwd) | |||||
| x1_reshaped = reshape_op(x1_transposed, x1_reshape_fwd) | |||||
| x2_reshaped = reshape_op(x2_transposed, x2_reshape_fwd) | |||||
| mul_result = matmul_op(x1_reshaped, x2_reshaped) | |||||
| final_result = reshape_op(mul_result, output_shape) | |||||
| return final_result | |||||
| @@ -54,7 +54,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A | |||||
| NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, | NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, | ||||
| Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, | Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, | ||||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, | Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, | ||||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, TensorDot) | |||||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) | |||||
| from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | ||||
| RandomCategorical, StandardLaplace, Multinomial) | RandomCategorical, StandardLaplace, Multinomial) | ||||
| @@ -775,127 +775,6 @@ class BatchMatMul(MatMul): | |||||
| 'greater or equal to 3,' + f' while x size = {len(x)}, y size= {len(y)}') | 'greater or equal to 3,' + f' while x size = {len(x)}, y size= {len(y)}') | ||||
| class TensorDot(PrimitiveWithInfer): | |||||
| """ | |||||
| Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`. | |||||
| Contraction allows for the summation of products of elements of `a` and `b` on specified axes. | |||||
| The same number of axes must be specified for both x1 and x2, and values must be within range | |||||
| of number of dims of both `a` and `b`. | |||||
| Selected dims in both inputs must also match. | |||||
| axes = 0 leads to outer product, and axes = 1 leads to normal matrix multiplication. | |||||
| axes = 1 is the same as axes = ((0,),(1,) where length of input shape is 2 for both `a` and `b` | |||||
| axes = 2 is the same as axes = ((0,1),(1,2)) where length of input shape is 3 for both `a` and `b` | |||||
| Args: | |||||
| **axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]): Single value or | |||||
| tuple/list of length 2 with dimensions specified for `a` and `b` each. If single value `N` passed, | |||||
| automatically picks up first N dims from `a` input shape and last N dims from `b` input shape. | |||||
| Inputs: | |||||
| - **x1** (Tensor): First tensor in TensorDot op with datatype float16 or float32 | |||||
| - **x2** (Tensor): Second tensor in TensorDot op with datatype float16 or float32 | |||||
| Outputs: | |||||
| Tensor, the shape of the output tensor is :math:`(N + M)`. Where :math:`N` and :math:`M` are the free axes not | |||||
| contracted in both inputs | |||||
| Examples: | |||||
| >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32) | |||||
| >>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32) | |||||
| >>> tensordot = P.TensorDot(((0,1),(1,2))) | |||||
| >>> output = tensordot(input_x1, input_x2) | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, axes): | |||||
| self.axes = axes | |||||
| validator.check_value_type('axes', axes, [int, tuple, list], self.name) | |||||
| if not isinstance(self.axes, int): | |||||
| self.axes = list(self.axes) # to avoid immutability issues | |||||
| if len(self.axes) != 2: | |||||
| raise ValueError("Require two axes inputs, given less") | |||||
| self.int_to_tuple_conv() # convert before length checks | |||||
| if len(self.axes[0]) != len(self.axes[1]): | |||||
| raise ValueError("Axes have to be the same size/length") | |||||
| if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])): | |||||
| raise ValueError("Axes cannot have duplicating values") | |||||
| self.add_prim_attr("axes", self.axes) | |||||
| def int_to_tuple_conv(self): | |||||
| """ | |||||
| Converts ints to tuples in input axes, expected by most validation checks. | |||||
| """ | |||||
| for x in [0, 1]: | |||||
| if isinstance(self.axes[x], int): | |||||
| self.axes[x] = (self.axes[x],) | |||||
| def check_input_axes(self, x1_shape, x2_shape): | |||||
| """ | |||||
| Convert from single int axes to 2d tuple if required and check for validity with inputs. | |||||
| """ | |||||
| if isinstance(self.axes, int): | |||||
| if self.axes <= 0: | |||||
| # outer product, no input validation required | |||||
| self.axes = ([], []) # no axes selected for either | |||||
| return | |||||
| if self.axes > len(x1_shape) or self.axes > len(x2_shape): | |||||
| raise ValueError( | |||||
| "Axes value too high for given input arrays dimensions.") | |||||
| x1_ind = tuple(range(len(x1_shape))[-1 * self.axes:]) | |||||
| x2_ind = tuple(range(len(x2_shape))[:self.axes]) | |||||
| self.axes = tuple((x1_ind, x2_ind)) | |||||
| self.int_to_tuple_conv() | |||||
| for i in range(len(self.axes[0])): # sizes already validated | |||||
| if x1_shape[self.axes[0][i]] != x2_shape[self.axes[1][i]]: | |||||
| raise ValueError( | |||||
| "Given Axes are incompatible with given input arrays") | |||||
| def calc_new_shape(self, shape, position=0): | |||||
| """ | |||||
| Calculate transpose and reshape parameters for input transformations, | |||||
| 'position' refers to whether tensor is first or second in the op. | |||||
| """ | |||||
| contraction_axes = [i if i >= 0 else i + len(shape) for i in self.axes[position]] | |||||
| prod_contraction = int(np.prod([shape[i] for i in contraction_axes])) | |||||
| free_axes = [i for i in range(len(shape)) if i not in contraction_axes] | |||||
| free_dims = [shape[i] for i in free_axes] | |||||
| prod_free = int(np.prod(free_dims)) | |||||
| transpose_perm = list(contraction_axes) + free_axes if position else free_axes + list(contraction_axes) | |||||
| new_shape = [prod_contraction, prod_free] if position else [prod_free, prod_contraction] | |||||
| return new_shape, transpose_perm, free_dims | |||||
| def generate_transform_dims(self, x1_shape, x2_shape): | |||||
| """ | |||||
| Initiate calls for input transform calculations and calculate paramters for output | |||||
| and for backprop tranformations. | |||||
| """ | |||||
| self.x1_reshape_fwd, self.x1_transpose_fwd, x1_ret = self.calc_new_shape(x1_shape, 0) | |||||
| self.x2_reshape_fwd, self.x2_transpose_fwd, x2_ret = self.calc_new_shape(x2_shape, 1) | |||||
| self.output_shape = x1_ret + x2_ret # combine free axes from both inputs | |||||
| self.x1_reshape_back = [x1_shape[x] for x in self.x1_transpose_fwd] | |||||
| self.x2_reshape_back = [x2_shape[x] for x in self.x2_transpose_fwd] | |||||
| def infer_shape(self, x1, x2): | |||||
| self.check_input_axes(x1, x2) | |||||
| self.generate_transform_dims(x1, x2) | |||||
| # processed parameters for reading directly into kernel | |||||
| self.add_prim_attr('x1_transpose_fwd', self.x1_transpose_fwd) | |||||
| self.add_prim_attr('x2_transpose_fwd', self.x2_transpose_fwd) | |||||
| self.add_prim_attr('x1_reshape_fwd', self.x1_reshape_fwd) | |||||
| self.add_prim_attr('x2_reshape_fwd', self.x2_reshape_fwd) | |||||
| return self.output_shape | |||||
| def infer_dtype(self, x1, x2): | |||||
| args = {"x1": x1, "x2": x2} | |||||
| valid_dtypes = [mstype.float16, mstype.float32] | |||||
| validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) | |||||
| return x1 | |||||
| class CumSum(PrimitiveWithInfer): | class CumSum(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Computes the cumulative sum of input tensor along axis. | Computes the cumulative sum of input tensor along axis. | ||||
| @@ -20,17 +20,16 @@ import mindspore | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| import mindspore.context as context | import mindspore.context as context | ||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| class NetTensorDot(nn.Cell): | class NetTensorDot(nn.Cell): | ||||
| def __init__(self, axes): | def __init__(self, axes): | ||||
| super(NetTensorDot, self).__init__() | super(NetTensorDot, self).__init__() | ||||
| self.td = P.TensorDot(axes) | |||||
| self.axes = axes | |||||
| def construct(self, x, y): | def construct(self, x, y): | ||||
| return self.td(x, y) | |||||
| return C.TensorDot(x, y, self.axes) | |||||
| class GradNetwork(nn.Cell): | class GradNetwork(nn.Cell): | ||||
| @@ -183,6 +182,7 @@ def test_tensor_dot_outer(): | |||||
| x2_tensor = Tensor(x2, dtype=mindspore.float32) | x2_tensor = Tensor(x2, dtype=mindspore.float32) | ||||
| network = NetTensorDot(axes) | network = NetTensorDot(axes) | ||||
| ms_result_np = network(x1_tensor, x2_tensor).asnumpy() | ms_result_np = network(x1_tensor, x2_tensor).asnumpy() | ||||
| np_result = np.tensordot(x1, x2, axes) | np_result = np.tensordot(x1, x2, axes) | ||||
| np.testing.assert_array_almost_equal(ms_result_np, np_result) | np.testing.assert_array_almost_equal(ms_result_np, np_result) | ||||