Browse Source

!8386 [MS] Changing TensorDot from P Operations op to Composite op

From: @danishnxt
Reviewed-by: @tom__chen,@robingrosman
Signed-off-by: @robingrosman
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
a511e32cf8
8 changed files with 145 additions and 422 deletions
  1. +0
    -30
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.cc
  2. +0
    -222
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.h
  3. +0
    -42
      mindspore/ops/_grad/grad_math_ops.py
  4. +3
    -2
      mindspore/ops/composite/__init__.py
  5. +138
    -1
      mindspore/ops/composite/math_ops.py
  6. +1
    -1
      mindspore/ops/operations/__init__.py
  7. +0
    -121
      mindspore/ops/operations/math_ops.py
  8. +3
    -3
      tests/st/ops/gpu/test_tensordot_op.py

+ 0
- 30
mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.cc View File

@@ -1,30 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
TensorDot,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
TensorDotGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
TensorDot,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
TensorDotGpuKernel, half)
} // namespace kernel
} // namespace mindspore

+ 0
- 222
mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.h View File

@@ -1,222 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TENSORDOT_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TENSORDOT_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
#include "utils/convert_utils.h"
namespace mindspore {
namespace kernel {
template <typename T>
class TensorDotGpuKernel : public GpuKernel {
public:
TensorDotGpuKernel()
: batch_(0),
m_(0),
n_(0),
k_(0),
is_null_input_(false),
handle_(nullptr),
dtype_a_(CUDA_R_32F),
dtype_b_(CUDA_R_32F),
dtype_c_(CUDA_R_32F),
algo_(CUBLAS_GEMM_DEFAULT) {}
~TensorDotGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *x1_input = GetDeviceAddress<T>(inputs, 0);
T *x2_input = GetDeviceAddress<T>(inputs, 1);
size_t *x1_input_shape = GetDeviceAddress<size_t>(workspace, 0);
size_t *x2_input_shape = GetDeviceAddress<size_t>(workspace, 1);
size_t *x1_input_trans_axes = GetDeviceAddress<size_t>(workspace, 2);
size_t *x2_input_trans_axes = GetDeviceAddress<size_t>(workspace, 3);
// transposed interim values moved to workspace, then multiplied
T *x1_reshape = GetDeviceAddress<T>(workspace, 4);
T *x2_reshape = GetDeviceAddress<T>(workspace, 5);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
// Transpose X1
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(x1_input_shape, &x1_input_shape_[0], x1_input_shape_.size() * sizeof(size_t),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync x1_input_shape failed");
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(x1_input_trans_axes, &x1_transpose_fwd_[0], x1_input_shape_.size() * sizeof(size_t),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_axis_x1 failed");
int size_x1 = SizeToInt(input_size_x1_ / sizeof(T));
CalTranspose(size_x1, x1_input, x1_input_shape, x1_input_trans_axes, SizeToInt(x1_input_shape_.size()), x1_reshape,
reinterpret_cast<cudaStream_t>(stream_ptr));
// Transpose X2
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(x2_input_shape, &x2_input_shape_[0], (x2_input_shape_.size() * sizeof(size_t)),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync x2_input_shape failed");
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(x2_input_trans_axes, &x2_transpose_fwd_[0], (x2_input_shape_.size() * sizeof(size_t)),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_axis_x2 failed");
int size_x2 = SizeToInt(input_size_x2_ / sizeof(T));
CalTranspose(size_x2, x2_input, x2_input_shape, x2_input_trans_axes, SizeToInt(x2_input_shape_.size()), x2_reshape,
reinterpret_cast<cudaStream_t>(stream_ptr));
// Matrix Mulitply interim transposed values with GEMM
const float alpha = 1; // constants for cublas API
const float beta = 0;
const int lda = SizeToInt(k_);
const int ldb = SizeToInt(n_);
const int ldc = n_;
auto stride_a = SizeToInt(m_ * k_);
auto stride_b = SizeToInt(k_ * n_);
auto stride_c = SizeToInt(m_ * n_);
try {
CHECK_CUBLAS_RET_WITH_EXCEPT(
cublasGemmStridedBatchedEx(handle_, CUBLAS_OP_N, CUBLAS_OP_N, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
&alpha, x2_reshape, dtype_b_, ldb, stride_b, x1_reshape, dtype_a_, lda, stride_a,
&beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
"cublasSgemm Call Fail");
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas cublasGemmStridedBatchedEx";
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
// checking for FP16 op, switch to Tensor Core if available
dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1)));
dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0)));
if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) {
MS_LOG(INFO) << "Input and output type is float16, allow to use Tensor Core operations if possible";
algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
auto tmp_x1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto tmp_x2_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
input_size_x1_ = sizeof(T);
for (size_t i = 0; i < tmp_x1_shape.size(); i++) {
x1_input_shape_.push_back(tmp_x1_shape[i]);
input_size_x1_ *= tmp_x1_shape[i];
}
input_size_x2_ = sizeof(T);
for (size_t i = 0; i < tmp_x2_shape.size(); i++) {
x2_input_shape_.push_back(tmp_x2_shape[i]);
input_size_x2_ *= tmp_x2_shape[i];
}
// holding in temp values to convert to size_t vectors
std::vector<int> x1_transpose_fwd_temp;
std::vector<int64_t> x1_transpose_me = GetAttr<std::vector<int64_t>>(kernel_node, "x1_transpose_fwd");
(void)std::transform(x1_transpose_me.begin(), x1_transpose_me.end(), std::back_inserter(x1_transpose_fwd_temp),
[](const int64_t &value) { return static_cast<int>(value); });
std::vector<int> x2_transpose_fwd_temp;
std::vector<int64_t> x2_transpose_me = GetAttr<std::vector<int64_t>>(kernel_node, "x2_transpose_fwd");
(void)std::transform(x2_transpose_me.begin(), x2_transpose_me.end(), std::back_inserter(x2_transpose_fwd_temp),
[](const int64_t &value) { return static_cast<int>(value); });
for (size_t i = 0; i < x1_transpose_fwd_temp.size(); i++) {
x1_transpose_fwd_.push_back(x1_transpose_fwd_temp[i]);
}
for (size_t i = 0; i < x2_transpose_fwd_temp.size(); i++) {
x2_transpose_fwd_.push_back(x2_transpose_fwd_temp[i]);
}
// values to decide multiplication call specifics
std::vector<int64_t> x1_reshape_me = GetAttr<std::vector<int64_t>>(kernel_node, "x1_reshape_fwd");
(void)std::transform(x1_reshape_me.begin(), x1_reshape_me.end(), std::back_inserter(x1_reshape_fwd_),
[](const int64_t &value) { return static_cast<int>(value); });
std::vector<int64_t> x2_reshape_me = GetAttr<std::vector<int64_t>>(kernel_node, "x2_reshape_fwd");
(void)std::transform(x2_reshape_me.begin(), x2_reshape_me.end(), std::back_inserter(x2_reshape_fwd_),
[](const int64_t &value) { return static_cast<int>(value); });
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
output_size_ = sizeof(T);
for (size_t i = 0; i < output_shape.size(); i++) {
output_size_ *= output_shape[i];
}
is_null_input_ = CHECK_NULL_INPUT(output_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "input is null";
InitSizeLists();
return true;
}
m_ = x1_reshape_fwd_[0];
k_ = x1_reshape_fwd_[1];
n_ = x2_reshape_fwd_[1];
batch_ = 1; // kept as a single multiplication operation
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
size_t size_t_size = sizeof(size_t);
input_size_list_.push_back(input_size_x1_);
input_size_list_.push_back(input_size_x2_);
workspace_size_list_.push_back(x1_input_shape_.size() * size_t_size);
workspace_size_list_.push_back(x2_input_shape_.size() * size_t_size);
workspace_size_list_.push_back(x1_transpose_fwd_.size() * size_t_size);
workspace_size_list_.push_back(x2_transpose_fwd_.size() * size_t_size);
workspace_size_list_.push_back(input_size_x1_);
workspace_size_list_.push_back(input_size_x2_);
output_size_list_.push_back(output_size_);
}
private:
size_t batch_;
size_t m_;
size_t n_;
size_t k_;
bool is_null_input_;
std::vector<size_t> x1_input_shape_;
std::vector<size_t> x2_input_shape_;
size_t input_size_x1_;
size_t input_size_x2_;
size_t output_size_;
std::vector<size_t> x1_transpose_fwd_; // For transpose
std::vector<size_t> x2_transpose_fwd_;
std::vector<int> x1_reshape_fwd_; // For mulitplication shape
std::vector<int> x2_reshape_fwd_;
cublasHandle_t handle_;
cudaDataType_t dtype_a_;
cudaDataType_t dtype_b_;
cudaDataType_t dtype_c_;
cublasGemmAlgo_t algo_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif

+ 0
- 42
mindspore/ops/_grad/grad_math_ops.py View File

@@ -156,48 +156,6 @@ def bprop_batchmatmul(self):
return bprop return bprop




@bprop_getters.register(P.TensorDot)
def bprop_tensordot(self):
"""Grad definition for `TensorDot` operation."""
mul_op_x1 = P.MatMul(transpose_a=False, transpose_b=True)
mul_op_x2 = P.MatMul(transpose_a=True, transpose_b=False)
invert_permutation_op = P.InvertPermutation()
transpose_op = P.Transpose()
reshape_op = P.Reshape()

# pull transformation specifics from P.TensorDot class
x1_transpose_fwd = tuple(self.x1_transpose_fwd)
x2_transpose_fwd = tuple(self.x2_transpose_fwd)
x1_reshape_fwd = tuple(self.x1_reshape_fwd)
x2_reshape_fwd = tuple(self.x2_reshape_fwd)
dout_reshape = (self.x1_reshape_fwd[0], self.x2_reshape_fwd[1])

# precalculated in fwd pass due to easier computation
x1_reshape_back = tuple(self.x1_reshape_back)
x2_reshape_back = tuple(self.x2_reshape_back)

def bprop(x1, x2, out, dout):
# reshape dy values to 2D for MatMul
dout_reshaped = reshape_op(dout, dout_reshape)
# transform inputs to forward pass equivalents
x1_transpose = transpose_op(x1, x1_transpose_fwd)
x2_transpose = transpose_op(x2, x2_transpose_fwd)
x1_reshape = reshape_op(x1_transpose, x1_reshape_fwd)
x2_reshape = reshape_op(x2_transpose, x2_reshape_fwd)
# calculate dx values for x1 and x2
dx1_interim = mul_op_x1(dout_reshaped, x2_reshape)
dx2_interim = mul_op_x2(x1_reshape, dout_reshaped)
# reverse transformations on dx values for both inputs
dx1_reshape = reshape_op(dx1_interim, x1_reshape_back)
dx2_reshape = reshape_op(dx2_interim, x2_reshape_back)
dx1_retranspose_axes = invert_permutation_op(x1_transpose_fwd)
dx2_retranspose_axes = invert_permutation_op(x2_transpose_fwd)
dx1_transpose = transpose_op(dx1_reshape, dx1_retranspose_axes)
dx2_transpose = transpose_op(dx2_reshape, dx2_retranspose_axes)
return dx1_transpose, dx2_transpose
return bprop


@bprop_getters.register(P.TensorAdd) @bprop_getters.register(P.TensorAdd)
def get_bprop_tensor_add(self): def get_bprop_tensor_add(self):
"""Grad definition for `TensorAdd` operation.""" """Grad definition for `TensorAdd` operation."""


+ 3
- 2
mindspore/ops/composite/__init__.py View File

@@ -27,7 +27,7 @@ from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like from .multitype_ops.ones_like_impl import ones_like
from .multitype_ops.zeros_like_impl import zeros_like from .multitype_ops.zeros_like_impl import zeros_like
from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial
from .math_ops import count_nonzero
from .math_ops import count_nonzero, TensorDot




__all__ = [ __all__ = [
@@ -50,4 +50,5 @@ __all__ = [
'multinomial', 'multinomial',
'clip_by_value', 'clip_by_value',
'clip_by_global_norm', 'clip_by_global_norm',
'count_nonzero']
'count_nonzero',
'TensorDot']

+ 138
- 1
mindspore/ops/composite/math_ops.py View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""math Operations.""" """math Operations."""
import numpy as np
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
@@ -20,7 +21,7 @@ from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F from mindspore.ops import functional as F
from .. import operations as P from .. import operations as P


# count_nonzero
@constexpr @constexpr
def _check_validate_axis(axis, name): def _check_validate_axis(axis, name):
if isinstance(axis, (tuple, list)): if isinstance(axis, (tuple, list)):
@@ -73,3 +74,139 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype) nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype)


return nonzero_num return nonzero_num

# TensorDot
@constexpr
def _int_to_tuple_conv(axes):
"""
Converts ints to tuples in input axes, expected by most validation checks.
"""
for x in [0, 1]:
if isinstance(axes[x], int):
axes[x] = (axes[x],)
return axes


@constexpr
def _check_axes(axes):
"""
Check for validity and type of axes passed to function.
"""
validator.check_value_type('axes', axes, [int, tuple, list], "TensorDot")
if not isinstance(axes, int):
axes = list(axes) # to avoid immutability issues
if len(axes) != 2:
raise ValueError("Require two axes inputs, given less")
axes = _int_to_tuple_conv(axes) # convert before length checks
if len(axes[0]) != len(axes[1]):
raise ValueError("Axes have to be the same size/length")
if len(axes[0]) != len(set(axes[0])) or len(axes[1]) != len(set(axes[1])):
raise ValueError("Axes cannot have duplicating values")
return axes


@constexpr
def _typecheck_input(x1_type, x2_type):
"""
Check input tensor types to be valid and confirm they are the same type.
"""
const_utils.check_valid_type(x1_type, [mstype.float32, mstype.float16], 'x1')
const_utils.check_valid_type(x2_type, [mstype.float32, mstype.float16], 'x2')
if x1_type != x2_type:
raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ')


@constexpr
def _validate_input(x1_shape, x2_shape, axes):
"""
Convert from single int axes to 2d tuple if required and check for validity with inputs.
"""
if isinstance(axes, int):
if axes <= 0:
# outer product, no input validation required
return ([], [])
if axes > len(x1_shape) or axes > len(x2_shape):
raise ValueError(
"Axes value too high for given input arrays dimensions.")
x1_ind = tuple(range(len(x1_shape))[-1 * axes:])
x2_ind = tuple(range(len(x2_shape))[:axes])
axes = tuple((x1_ind, x2_ind))
axes = _int_to_tuple_conv(axes)
for i in range(len(axes[0])): # sizes already validated
if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]:
raise ValueError(
"Given Axes are incompatible with given input arrays")
return axes


@constexpr
def _calc_new_shape(shape, axes, position=0):
"""
Calculate transpose and reshape parameters for input transformations,
'position' refers to whether tensor is first or second in the op.
"""
contraction_axes = tuple(i if i >= 0 else i + len(shape) for i in axes[position])
prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
free_axes = tuple(i for i in range(len(shape)) if i not in contraction_axes)
free_dims = tuple(shape[i] for i in free_axes)
prod_free = int(np.prod(free_dims))

transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
return new_shape, transpose_perm, free_dims


def TensorDot(x1, x2, axes):
"""
Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`.

Contraction allows for the summation of products of elements of `a` and `b` on specified axes.
The same number of axes must be specified for both x1 and x2, and values must be within range
of number of dims of both `a` and `b`.

Selected dims in both inputs must also match.

axes = 0 leads to outer product, and axes = 1 leads to normal matrix multiplication.
axes = 1 is the same as axes = ((0,),(1,) where length of input shape is 2 for both `a` and `b`
axes = 2 is the same as axes = ((0,1),(1,2)) where length of input shape is 3 for both `a` and `b`

Inputs:
- **x1** (Tensor): First tensor in TensorDot op with datatype float16 or float32
- **x2** (Tensor): Second tensor in TensorDot op with datatype float16 or float32
- **axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]): Single value or
tuple/list of length 2 with dimensions specified for `a` and `b` each. If single value `N` passed,
automatically picks up first N dims from `a` input shape and last N dims from `b` input shape.

Outputs:
Tensor, the shape of the output tensor is :math:`(N + M)`. Where :math:`N` and :math:`M` are the free axes not
contracted in both inputs

Examples:
>>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
>>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32)
>>> output = C.TensorDot(input_x1, input_x2, ((0,1),(1,2)))
"""
shape_op = P.Shape()
reshape_op = P.Reshape()
transpose_op = P.Transpose()
matmul_op = P.MatMul(False, False)
# input validity checks
x1_shape = shape_op(x1)
x2_shape = shape_op(x2)
x1_type = F.dtype(x1)
x2_type = F.dtype(x2)
axes = _check_axes(axes)
_typecheck_input(x1_type, x2_type)
# input compability check & axes format update
axes = _validate_input(x1_shape, x2_shape, axes)
x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0)
x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1)
output_shape = x1_ret + x2_ret # combine free axes from both inputs
# run TensorDot op
x1_transposed = transpose_op(x1, x1_transpose_fwd)
x2_transposed = transpose_op(x2, x2_transpose_fwd)
x1_reshaped = reshape_op(x1_transposed, x1_reshape_fwd)
x2_reshaped = reshape_op(x2_transposed, x2_reshape_fwd)
mul_result = matmul_op(x1_reshaped, x2_reshaped)
final_result = reshape_op(mul_result, output_shape)
return final_result

+ 1
- 1
mindspore/ops/operations/__init__.py View File

@@ -54,7 +54,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, TensorDot)
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)


from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial) RandomCategorical, StandardLaplace, Multinomial)


+ 0
- 121
mindspore/ops/operations/math_ops.py View File

@@ -775,127 +775,6 @@ class BatchMatMul(MatMul):
'greater or equal to 3,' + f' while x size = {len(x)}, y size= {len(y)}') 'greater or equal to 3,' + f' while x size = {len(x)}, y size= {len(y)}')




class TensorDot(PrimitiveWithInfer):
"""
Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`.

Contraction allows for the summation of products of elements of `a` and `b` on specified axes.
The same number of axes must be specified for both x1 and x2, and values must be within range
of number of dims of both `a` and `b`.

Selected dims in both inputs must also match.

axes = 0 leads to outer product, and axes = 1 leads to normal matrix multiplication.
axes = 1 is the same as axes = ((0,),(1,) where length of input shape is 2 for both `a` and `b`
axes = 2 is the same as axes = ((0,1),(1,2)) where length of input shape is 3 for both `a` and `b`

Args:
**axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]): Single value or
tuple/list of length 2 with dimensions specified for `a` and `b` each. If single value `N` passed,
automatically picks up first N dims from `a` input shape and last N dims from `b` input shape.

Inputs:
- **x1** (Tensor): First tensor in TensorDot op with datatype float16 or float32
- **x2** (Tensor): Second tensor in TensorDot op with datatype float16 or float32

Outputs:
Tensor, the shape of the output tensor is :math:`(N + M)`. Where :math:`N` and :math:`M` are the free axes not
contracted in both inputs

Examples:
>>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
>>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32)
>>> tensordot = P.TensorDot(((0,1),(1,2)))
>>> output = tensordot(input_x1, input_x2)
"""

@prim_attr_register
def __init__(self, axes):
self.axes = axes
validator.check_value_type('axes', axes, [int, tuple, list], self.name)
if not isinstance(self.axes, int):
self.axes = list(self.axes) # to avoid immutability issues
if len(self.axes) != 2:
raise ValueError("Require two axes inputs, given less")
self.int_to_tuple_conv() # convert before length checks
if len(self.axes[0]) != len(self.axes[1]):
raise ValueError("Axes have to be the same size/length")
if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])):
raise ValueError("Axes cannot have duplicating values")
self.add_prim_attr("axes", self.axes)

def int_to_tuple_conv(self):
"""
Converts ints to tuples in input axes, expected by most validation checks.
"""
for x in [0, 1]:
if isinstance(self.axes[x], int):
self.axes[x] = (self.axes[x],)

def check_input_axes(self, x1_shape, x2_shape):
"""
Convert from single int axes to 2d tuple if required and check for validity with inputs.
"""
if isinstance(self.axes, int):
if self.axes <= 0:
# outer product, no input validation required
self.axes = ([], []) # no axes selected for either
return
if self.axes > len(x1_shape) or self.axes > len(x2_shape):
raise ValueError(
"Axes value too high for given input arrays dimensions.")
x1_ind = tuple(range(len(x1_shape))[-1 * self.axes:])
x2_ind = tuple(range(len(x2_shape))[:self.axes])
self.axes = tuple((x1_ind, x2_ind))
self.int_to_tuple_conv()
for i in range(len(self.axes[0])): # sizes already validated
if x1_shape[self.axes[0][i]] != x2_shape[self.axes[1][i]]:
raise ValueError(
"Given Axes are incompatible with given input arrays")

def calc_new_shape(self, shape, position=0):
"""
Calculate transpose and reshape parameters for input transformations,
'position' refers to whether tensor is first or second in the op.
"""
contraction_axes = [i if i >= 0 else i + len(shape) for i in self.axes[position]]
prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
free_axes = [i for i in range(len(shape)) if i not in contraction_axes]
free_dims = [shape[i] for i in free_axes]
prod_free = int(np.prod(free_dims))

transpose_perm = list(contraction_axes) + free_axes if position else free_axes + list(contraction_axes)
new_shape = [prod_contraction, prod_free] if position else [prod_free, prod_contraction]
return new_shape, transpose_perm, free_dims

def generate_transform_dims(self, x1_shape, x2_shape):
"""
Initiate calls for input transform calculations and calculate paramters for output
and for backprop tranformations.
"""
self.x1_reshape_fwd, self.x1_transpose_fwd, x1_ret = self.calc_new_shape(x1_shape, 0)
self.x2_reshape_fwd, self.x2_transpose_fwd, x2_ret = self.calc_new_shape(x2_shape, 1)
self.output_shape = x1_ret + x2_ret # combine free axes from both inputs
self.x1_reshape_back = [x1_shape[x] for x in self.x1_transpose_fwd]
self.x2_reshape_back = [x2_shape[x] for x in self.x2_transpose_fwd]

def infer_shape(self, x1, x2):
self.check_input_axes(x1, x2)
self.generate_transform_dims(x1, x2)
# processed parameters for reading directly into kernel
self.add_prim_attr('x1_transpose_fwd', self.x1_transpose_fwd)
self.add_prim_attr('x2_transpose_fwd', self.x2_transpose_fwd)
self.add_prim_attr('x1_reshape_fwd', self.x1_reshape_fwd)
self.add_prim_attr('x2_reshape_fwd', self.x2_reshape_fwd)
return self.output_shape

def infer_dtype(self, x1, x2):
args = {"x1": x1, "x2": x2}
valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
return x1


class CumSum(PrimitiveWithInfer): class CumSum(PrimitiveWithInfer):
""" """
Computes the cumulative sum of input tensor along axis. Computes the cumulative sum of input tensor along axis.


+ 3
- 3
tests/st/ops/gpu/test_tensordot_op.py View File

@@ -20,17 +20,16 @@ import mindspore
from mindspore import Tensor from mindspore import Tensor
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context
from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C




class NetTensorDot(nn.Cell): class NetTensorDot(nn.Cell):
def __init__(self, axes): def __init__(self, axes):
super(NetTensorDot, self).__init__() super(NetTensorDot, self).__init__()
self.td = P.TensorDot(axes)
self.axes = axes


def construct(self, x, y): def construct(self, x, y):
return self.td(x, y)
return C.TensorDot(x, y, self.axes)




class GradNetwork(nn.Cell): class GradNetwork(nn.Cell):
@@ -183,6 +182,7 @@ def test_tensor_dot_outer():
x2_tensor = Tensor(x2, dtype=mindspore.float32) x2_tensor = Tensor(x2, dtype=mindspore.float32)


network = NetTensorDot(axes) network = NetTensorDot(axes)

ms_result_np = network(x1_tensor, x2_tensor).asnumpy() ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
np_result = np.tensordot(x1, x2, axes) np_result = np.tensordot(x1, x2, axes)
np.testing.assert_array_almost_equal(ms_result_np, np_result) np.testing.assert_array_almost_equal(ms_result_np, np_result)


Loading…
Cancel
Save