diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h
index 6e11332597..099f9a1f7c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -58,14 +58,24 @@ class MatMulGpuKernel : public GpuKernel {
     auto stride_c = SizeToInt(m_ * n_);
 
     try {
-      CHECK_CUBLAS_RET_WITH_EXCEPT(
-        kernel_node_,
-        cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
-                                   &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
-                                   &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
-        "cublasSgemm Call Fail");
+      // Use cublasGemmEx to get high performance when batch_ is 1
+      if (batch_ == 1) {
+        CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
+                                     cublasGemmEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_),
+                                                  SizeToInt(k_), &alpha, input2_addr, dtype_b_, ldb, input1_addr,
+                                                  dtype_a_, lda, &beta, output_addr, dtype_c_, ldc, CUDA_R_32F, algo_),
+                                     "cublasGemmEx Call Fail");
+      } else {
+        CHECK_CUBLAS_RET_WITH_EXCEPT(
+          kernel_node_,
+          cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
+                                     &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
+                                     &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
+          "cublasGemmStridedBatchedEx Call Fail");
+      }
     } catch (const std::exception &e) {
-      MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas cublasGemmStridedBatchedEx";
+      MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas "
+                        << (batch_ == 1 ? "cublasGemmEx" : "cublasGemmStridedBatchedEx");
     }
     return true;
   }
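Note on the matmul_gpu_kernel.h change above: cuBLAS expects column-major storage, which is why the kernel passes input2 before input1 and swaps n_ and m_, effectively computing C^T = B^T * A^T so that the row-major C = A * B comes out directly. The sketch below is not MindSpore code; it is a minimal standalone illustration of the same batch == 1 dispatch between cublasGemmEx and cublasGemmStridedBatchedEx. The GemmDispatch helper, the fixed CUBLAS_GEMM_DEFAULT algorithm, and the CUBLAS_COMPUTE_32F compute type (the cuBLAS 11+ counterpart of the CUDA_R_32F used by the kernel) are assumptions for illustration, and error checking is omitted.

```cpp
// gemm_dispatch_sketch.cu: illustrative only, not the MindSpore kernel.
// Build (assumed): nvcc gemm_dispatch_sketch.cu -lcublas
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Row-major C[b][m x n] = A[b][m x k] * B[b][k x n]. cuBLAS is column-major, so we
// compute C^T = B^T * A^T by passing B first and swapping m/n, mirroring the
// argument order (input2 before input1, n_ before m_) used by the kernel above.
void GemmDispatch(cublasHandle_t handle, const float *a, const float *b, float *c,
                  int batch, int m, int n, int k) {
  const float alpha = 1.0f, beta = 0.0f;
  const int lda = k, ldb = n, ldc = n;
  const long long stride_a = static_cast<long long>(m) * k;
  const long long stride_b = static_cast<long long>(k) * n;
  const long long stride_c = static_cast<long long>(m) * n;
  if (batch == 1) {
    // Single-GEMM fast path: plain cublasGemmEx when there is only one product.
    cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, b, CUDA_R_32F, ldb,
                 a, CUDA_R_32F, lda, &beta, c, CUDA_R_32F, ldc, CUBLAS_COMPUTE_32F,
                 CUBLAS_GEMM_DEFAULT);
  } else {
    cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha,
                               b, CUDA_R_32F, ldb, stride_b, a, CUDA_R_32F, lda, stride_a,
                               &beta, c, CUDA_R_32F, ldc, stride_c, batch,
                               CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
  }
}

int main() {
  const int m = 2, n = 3, k = 4;
  cublasHandle_t handle;
  cublasCreate(&handle);
  for (int batch : {1, 2}) {  // exercise both dispatch branches
    std::vector<float> ha(batch * m * k, 1.0f), hb(batch * k * n, 2.0f), hc(batch * m * n, 0.0f);
    float *da = nullptr, *db = nullptr, *dc = nullptr;
    cudaMalloc(reinterpret_cast<void **>(&da), ha.size() * sizeof(float));
    cudaMalloc(reinterpret_cast<void **>(&db), hb.size() * sizeof(float));
    cudaMalloc(reinterpret_cast<void **>(&dc), hc.size() * sizeof(float));
    cudaMemcpy(da, ha.data(), ha.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(db, hb.data(), hb.size() * sizeof(float), cudaMemcpyHostToDevice);
    GemmDispatch(handle, da, db, dc, batch, m, n, k);
    cudaMemcpy(hc.data(), dc, hc.size() * sizeof(float), cudaMemcpyDeviceToHost);
    // Each output element reduces k products of 1.0f * 2.0f, so every entry is 8.
    std::printf("batch=%d  c[0]=%g\n", batch, hc[0]);
    cudaFree(da); cudaFree(db); cudaFree(dc);
  }
  cublasDestroy(handle);
  return 0;
}
```

Under these assumptions the program prints c[0]=8 for both batch sizes; the batch == 1 branch takes the plain GEMM entry point and never touches the stride arguments, matching the performance motivation in the comment added by the kernel change.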
diff --git a/tests/st/ops/gpu/test_tensordot_op.py b/tests/st/ops/gpu/test_tensordot_op.py
index 6262e2706f..2aaeeb971d 100644
--- a/tests/st/ops/gpu/test_tensordot_op.py
+++ b/tests/st/ops/gpu/test_tensordot_op.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -150,7 +150,7 @@ def test_tensor_dot_fp16():
     network = NetTensorDot(axes)
     ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
     np_result = np.tensordot(x1, x2, axes)
-    np.testing.assert_array_almost_equal(ms_result_np, np_result)
+    assert np.allclose(ms_result_np, np_result, rtol=1e-3, atol=1e-3)
 
     # 3D
     shape_x1 = (60, 30, 450)
@@ -164,7 +164,7 @@ def test_tensor_dot_fp16():
     network = NetTensorDot(axes)
     ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    np_result = np.tensordot(x1, x2, axes)
-    np.testing.assert_array_almost_equal(ms_result_np, np_result)
+    assert np.allclose(ms_result_np, np_result, rtol=1e-3, atol=6e0)
 
 
 @pytest.mark.level0
@@ -173,7 +173,7 @@ def test_tensor_dot_fp16():
 def test_tensor_dot_outer():
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
     np.random.seed(2746)
-    shape_x1 = (1, 2, 3)  # incompatable dims for x1 and x2
+    shape_x1 = (1, 2, 3)  # incompatible dims for x1 and x2
     shape_x2 = (4, 5, 6)
     axes = 0  # outer product does not require multiplicable dims
     x1 = np.random.random(shape_x1).astype(np.float32)
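On the test_tensordot_op.py change above: np.testing.assert_array_almost_equal enforces an absolute bound, abs(desired - actual) < 1.5 * 10**-decimal with decimal=6 by default, which float16 results at the magnitudes these shapes produce cannot meet even for a correct kernel, while np.allclose checks abs(actual - desired) <= atol + rtol * abs(desired), so the bound scales with the result. The looser atol=6e0 in the 3D case presumably accommodates the larger reduction there. Below is a small standalone sketch of the difference; the values are illustrative and not taken from the test data.

```python
# Standalone sketch of the tolerance semantics; values are illustrative only.
import numpy as np

desired = np.array([1000.0, 2000.0], dtype=np.float32)
actual = desired * (1.0 + 1e-3)  # 0.1% relative error, roughly the resolution of fp16

# np.allclose bound scales with the result: |actual - desired| <= atol + rtol * |desired|
print(np.allclose(actual, desired, rtol=1e-3, atol=1e-3))  # True

# assert_array_almost_equal uses a fixed absolute bound, |desired - actual| < 1.5e-6
# at the default decimal=6, which values of this magnitude cannot satisfy.
try:
    np.testing.assert_array_almost_equal(actual, desired)
except AssertionError:
    print("absolute decimal=6 check rejects fp16-scale errors")
```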