@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_MATMUL_GPU_KERNEL_H
-#define MINDSPORE_MATMUL_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
 #include <cublas_v2.h>
 #include <cuda_runtime_api.h>
@@ -30,19 +30,7 @@ namespace kernel {
 template <typename T>
 class MatMulGpuKernel : public GpuKernel {
  public:
-  MatMulGpuKernel()
-      : batch_(0),
-        m_(0),
-        n_(0),
-        k_(0),
-        is_null_input_(false),
-        transpose_x1_(CUBLAS_OP_N),
-        transpose_x2_(CUBLAS_OP_N),
-        handle_(nullptr),
-        dtype_a_(CUDA_R_32F),
-        dtype_b_(CUDA_R_32F),
-        dtype_c_(CUDA_R_32F),
-        algo_(CUBLAS_GEMM_DEFAULT) {}
+  MatMulGpuKernel() { ResetResource(); }
   ~MatMulGpuKernel() = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@@ -122,6 +110,24 @@ class MatMulGpuKernel : public GpuKernel {
     return true;
   }
+
+  void ResetResource() noexcept override {
+    batch_ = 0;
+    m_ = 0;
+    n_ = 0;
+    k_ = 0;
+    is_null_input_ = false;
+    transpose_x1_ = CUBLAS_OP_N;
+    transpose_x2_ = CUBLAS_OP_N;
+    handle_ = nullptr;
+    dtype_a_ = CUDA_R_32F;
+    dtype_b_ = CUDA_R_32F;
+    dtype_c_ = CUDA_R_32F;
+    algo_ = CUBLAS_GEMM_DEFAULT;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }

  protected:
   void InitSizeLists() override {
     size_t unit_size = sizeof(T);
@@ -158,4 +164,4 @@ class MatMulGpuKernel : public GpuKernel {
 }  // namespace kernel
 }  // namespace mindspore
-#endif
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
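The constructor now delegates to `ResetResource()`, which also clears the size lists; presumably this lets the framework return the kernel to a clean state and rerun its initialization when the input shapes change at runtime (the dynamic-shape path), instead of relying on constructor-time defaults alone. Below is a minimal Python sketch of that reset-and-reinit lifecycle; the names (`DummyMatMulKernel`, `init`, `reset_resource`) are illustrative only, not MindSpore API.

```python
# Conceptual mirror of the C++ pattern above, under the assumption that the
# framework calls reset + init again whenever the incoming shape changes.
class DummyMatMulKernel:
    def __init__(self):
        self.reset_resource()

    def reset_resource(self):
        """Return every piece of shape-dependent state to its defaults."""
        self.m = self.n = self.k = 0
        self.input_size_list = []
        self.output_size_list = []

    def init(self, x_shape, y_shape):
        """Recompute buffer sizes for the current (possibly new) input shapes."""
        self.m, self.k = x_shape
        _, self.n = y_shape
        self.input_size_list = [self.m * self.k * 4, self.k * self.n * 4]  # float32 bytes
        self.output_size_list = [self.m * self.n * 4]


kernel = DummyMatMulKernel()
kernel.init((2, 3), (3, 4))
# On a shape change, reset and re-init instead of building a new kernel object:
kernel.reset_resource()
kernel.init((5, 3), (3, 7))
assert kernel.output_size_list == [5 * 7 * 4]
```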
@@ -289,7 +289,10 @@ AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &pri
                               const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);

 template <typename T>
 AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tuple or list or dict.
@@ -317,6 +317,7 @@ AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr
     std::make_shared<AbstractTensor>(start->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
   return ret;
 }
 AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list) {
   const std::string op_name = primitive->name();
@@ -326,5 +327,93 @@ AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &pri
   auto input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
   return input->Broaden();
 }
+
+AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(y);
+  MS_EXCEPTION_IF_NULL(y->shape());
+  if (x->shape()->shape().size() != 2 || y->shape()->shape().size() != 2) {
+    MS_LOG(EXCEPTION) << "MatMul inputs x, y should have the same dimension size, which must be equal to 2.";
+  }
+  ValuePtr TAptr = primitive->GetAttr("transpose_a");
+  ValuePtr TBptr = primitive->GetAttr("transpose_b");
+  bool TA = GetValue<bool>(TAptr);
+  bool TB = GetValue<bool>(TBptr);
+  ShapeVector x_min_shape = x->shape()->min_shape();
+  ShapeVector x_max_shape = x->shape()->max_shape();
+  ShapeVector y_min_shape = y->shape()->min_shape();
+  ShapeVector y_max_shape = y->shape()->max_shape();
+  (void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape);
+  (void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape);
+  ShapeVector ret_shape;
+  ShapeVector ret_min_shape;
+  ShapeVector ret_max_shape;
+  auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void {
+    output.push_back(xshp[(TA ? 1 : 0)]);
+    output.push_back(yshp[(TB ? 0 : 1)]);
+    return;
+  };
+  make_shape(ret_shape, x->shape()->shape(), y->shape()->shape());
+  make_shape(ret_min_shape, x_min_shape, y_min_shape);
+  make_shape(ret_max_shape, x_max_shape, y_max_shape);
+  return std::make_shared<AbstractTensor>(x->element(),
+                                          std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
+}
+
+AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(y);
+  MS_EXCEPTION_IF_NULL(y->shape());
+  if (x->shape()->shape().size() != y->shape()->shape().size() || x->shape()->shape().size() < 3) {
+    MS_LOG(EXCEPTION)
+      << "BatchMatMul inputs x, y should have the same dimension size, which must be greater than or equal to 3.";
+  }
+  ValuePtr TAptr = primitive->GetAttr("transpose_a");
+  ValuePtr TBptr = primitive->GetAttr("transpose_b");
+  bool TA = GetValue<bool>(TAptr);
+  bool TB = GetValue<bool>(TBptr);
+  ShapeVector x_min_shape = x->shape()->min_shape();
+  ShapeVector x_max_shape = x->shape()->max_shape();
+  ShapeVector y_min_shape = y->shape()->min_shape();
+  ShapeVector y_max_shape = y->shape()->max_shape();
+  (void)CheckMinMaxShape(x->shape()->shape(), &x_min_shape, &x_max_shape);
+  (void)CheckMinMaxShape(y->shape()->shape(), &y_min_shape, &y_max_shape);
+  ShapeVector ret_shape;
+  ShapeVector ret_min_shape;
+  ShapeVector ret_max_shape;
+  auto make_shape = [&TA, &TB](ShapeVector &output, const ShapeVector xshp, const ShapeVector yshp) -> void {
+    for (size_t i = 0; i < xshp.size() - 2; i++) {
+      if (xshp[i] != yshp[i]) {
+        if (xshp[i] > 0 && yshp[i] > 0) {
+          MS_LOG(EXCEPTION) << "BatchMatMul input x, y are different at index " << i << ".";
+        }
+        output.push_back(Shape::SHP_ANY);
+      } else {
+        output.push_back(xshp[i]);
+      }
+    }
+    size_t offset = xshp.size() - 2;
+    output.push_back(xshp[offset + (TA ? 1 : 0)]);
+    output.push_back(yshp[offset + (TB ? 0 : 1)]);
+    return;
+  };
+  make_shape(ret_shape, x->shape()->shape(), y->shape()->shape());
+  make_shape(ret_min_shape, x_min_shape, y_min_shape);
+  make_shape(ret_max_shape, x_max_shape, y_max_shape);
+  return std::make_shared<AbstractTensor>(x->element(),
+                                          std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
+}
 }  // namespace abstract
 }  // namespace mindspore
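The two infer functions compute the output shape (and its min/max variants for dynamic shape) directly from the transpose flags; `InferImplBatchMatMul` additionally merges the leading batch dimensions, emitting `Shape::SHP_ANY` when one side is still unknown. Below is a standalone Python sketch of the same shape rules, not the C++ code itself; `UNKNOWN` stands in for `Shape::SHP_ANY`.

```python
UNKNOWN = -1  # stands in for Shape::SHP_ANY (a dynamic dimension)


def matmul_out_shape(x_shape, y_shape, transpose_a=False, transpose_b=False):
    # Rows come from x (dim 1 if x is transposed, else dim 0);
    # columns come from y (dim 0 if y is transposed, else dim 1).
    return (x_shape[1 if transpose_a else 0], y_shape[0 if transpose_b else 1])


def batch_matmul_out_shape(x_shape, y_shape, transpose_a=False, transpose_b=False):
    batch = []
    for xd, yd in zip(x_shape[:-2], y_shape[:-2]):
        if xd != yd:
            if xd > 0 and yd > 0:
                raise ValueError(f"batch dims differ: {xd} vs {yd}")
            batch.append(UNKNOWN)  # at least one side is dynamic, keep it unknown
        else:
            batch.append(xd)
    rows, cols = matmul_out_shape(x_shape[-2:], y_shape[-2:], transpose_a, transpose_b)
    return tuple(batch) + (rows, cols)


assert matmul_out_shape((3, 2), (3, 4), transpose_a=True) == (2, 4)
assert batch_matmul_out_shape((2, 4, 1, 3), (2, 4, 3, 4)) == (2, 4, 1, 4)
assert batch_matmul_out_shape((UNKNOWN, 4, 1, 3), (2, 4, 3, 4)) == (UNKNOWN, 4, 1, 4)
```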
@@ -49,6 +49,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
     {prim::kPrimLinSpace, {InferImplLinSpace, true}},
     {prim::kPrimAddN, {InferImplAddN, true}},
+    {prim::kPrimMatMul, {InferImplMatMul, true}},
+    {prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}},
     // Array
     {prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
     {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
@@ -689,7 +689,7 @@ class CumProd(PrimitiveWithInfer):
             raise ValueError(f"For {self.name}, axis must be const.")


-class MatMul(PrimitiveWithInfer):
+class MatMul(PrimitiveWithCheck):
     """
     Multiplies matrix `a` and matrix `b`.

@@ -730,10 +730,10 @@ class MatMul(PrimitiveWithInfer):
     def check_shape_size(self, x1, x2):
         if len(x1) != 2 or len(x2) != 2:
-            raise ValueError('P.MatMul inputs x1, x2 should has the same dimension size and '
+            raise ValueError('P.MatMul inputs x1, x2 should have the same dimension size and '
                              + f'equal to 2, while x1 size is ({len(x1)}) and x2 size is ({len(x2)}).')

-    def infer_shape(self, x1, x2):
+    def check_shape(self, x1, x2):
         self.check_shape_size(x1, x2)
         cls_name = self.name
         # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
@@ -747,23 +747,18 @@ class MatMul(PrimitiveWithInfer):
         x2_last = x2[-2:]
         x1_col = x1_last[not self.transpose_a]
         x2_row = x2_last[self.transpose_b]
-        if x1_col != x2_row:
-            raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
-                             + f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})'
-                             + f', x2 shape {x2}(transpose_b={self.transpose_b}).')
+        if np.all(np.array(x1) != -1) and np.all(np.array(x2) != -1):
+            if x1_col != x2_row:
+                raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
+                                 + f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})'
+                                 + f', x2 shape {x2}(transpose_b={self.transpose_b}).')
         # set attribute
         self.add_prim_attr('transpose_x1', self.transpose_a)
         self.add_prim_attr('transpose_x2', self.transpose_b)
-        ret_dims = x1[: -2] + [x1_last[self.transpose_a], x2_last[not self.transpose_b]]
-        return ret_dims

-    def infer_dtype(self, x1, x2):
+    def check_dtype(self, x1, x2):
         args = {"x1": x1, "x2": x2}
         validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type + mstype.int_type, self.name)
-        if x1.element_type() == mstype.int8:
-            return mstype.tensor_type(mstype.int32)
-        return x1


 class BatchMatMul(MatMul):
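Switching the base class from `PrimitiveWithInfer` to `PrimitiveWithCheck` means the Python side no longer computes the output shape or dtype; that now happens in the C++ `InferImplMatMul` registered above, which can also carry min/max shapes for dynamic inputs. The `check_*` methods only validate, and the inner-dimension check is skipped while any dimension is still unknown (`-1`). The sketch below is a hedged, standalone restatement of that relaxed check as a plain function, not the MindSpore validator itself.

```python
import numpy as np


def check_matmul_shapes(x1, x2, transpose_a=False, transpose_b=False):
    """Raise only when both shapes are fully known and the inner dims mismatch."""
    x1_col = x1[-2:][int(not transpose_a)]
    x2_row = x2[-2:][int(transpose_b)]
    if np.all(np.array(x1) != -1) and np.all(np.array(x2) != -1):
        if x1_col != x2_row:
            raise ValueError(f"cannot multiply shapes {x1} and {x2}")


check_matmul_shapes([3, 4], [4, 5])     # static shapes: inner dims are checked
check_matmul_shapes([-1, 4], [-1, 5])   # dynamic dims present: the check is deferred
```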
@@ -21,11 +21,9 @@ import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
+from mindspore.ops.operations import _inner_ops as inner


-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
 class BatchMatMulNet(nn.Cell):
     def __init__(self, transpose_a=False, transpose_b=False):
         super(BatchMatMulNet, self).__init__()
@@ -34,7 +32,9 @@ class BatchMatMulNet(nn.Cell):
     def construct(self, x, y):
         return self.batch_matmul(x, y)


+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
 def test_4d():
     input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
     input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)
@@ -140,3 +140,38 @@ def test_4D_fp16():
                         [[4340, 4396, 4456, 4510]],
                         [[5816, 5880, 5948, 6016]]]]).astype(np.float16)
     assert (output.asnumpy() == expect).all()
+
+
+class BatchMatMul_d(nn.Cell):
+    def __init__(self, transpose_a=False, transpose_b=False):
+        super(BatchMatMul_d, self).__init__()
+        self.batch_matmul = P.BatchMatMul(transpose_a, transpose_b)
+        self.test_dynamic = inner.GpuConvertToDynamicShape()
+
+    def construct(self, x, y):
+        x = self.test_dynamic(x)
+        y = self.test_dynamic(y)
+        return self.batch_matmul(x, y)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_batchmatmul_dynamic():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+    net = BatchMatMul_d()
+
+    x1 = np.arange(8).reshape(2, 2, 2).astype(np.float32)
+    y1 = np.arange(28).reshape(2, 2, 7).astype(np.float32)
+    output1 = net(Tensor(x1), Tensor(y1))
+    expect1 = np.matmul(x1, y1)
+    assert (output1.asnumpy() == expect1).all()
+
+    x2 = np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3).astype(np.float32)
+    y2 = np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4).astype(np.float32)
+    output2 = net(Tensor(x2), Tensor(y2))
+    expect2 = np.matmul(x2, y2)
+    assert (output2.asnumpy() == expect2).all()
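A note on the assertions: the dynamic-shape tests route their inputs through `inner.GpuConvertToDynamicShape` so that the MatMul/BatchMatMul kernels are exercised with dynamic shapes, yet they still compare bit-exactly against `np.matmul`. That appears safe here because the operands are small consecutive integers, so every product and partial sum is an integer far below 2^24 and is represented exactly in float32. The small check below illustrates the point; it is a side note, not part of the change.

```python
import numpy as np

x = np.arange(8).reshape(2, 2, 2).astype(np.float32)
y = np.arange(28).reshape(2, 2, 7).astype(np.float32)

# float32 and float64 matmul agree exactly for these integer-valued inputs,
# so an exact equality check against np.matmul is well defined.
assert (np.matmul(x, y) == np.matmul(x.astype(np.float64), y.astype(np.float64))).all()
```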
@@ -0,0 +1,54 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.ops.operations import _inner_ops as inner
+
+
+class MatMul_d(nn.Cell):
+    def __init__(self):
+        super(MatMul_d, self).__init__()
+        self.test_dynamic = inner.GpuConvertToDynamicShape()
+        self.matmul = P.MatMul()
+
+    def construct(self, x, y):
+        x = self.test_dynamic(x)
+        y = self.test_dynamic(y)
+        return self.matmul(x, y)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_MatMul_dynamic():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+    net = MatMul_d()
+
+    x1 = np.arange(2).reshape(1, 2).astype(np.float32)
+    y1 = np.arange(4).reshape(2, 2).astype(np.float32)
+    output1 = net(Tensor(x1), Tensor(y1))
+    expect1 = np.matmul(x1, y1)
+    np.testing.assert_array_almost_equal(output1.asnumpy(), expect1)
+
+    x2 = np.arange(102).reshape(34, 3).astype(np.float32)
+    y2 = np.arange(18).reshape(3, 6).astype(np.float32)
+    output2 = net(Tensor(x2), Tensor(y2))
+    expect2 = np.matmul(x2, y2)
+    np.testing.assert_array_almost_equal(output2.asnumpy(), expect2)