@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -58,14 +58,24 @@ class MatMulGpuKernel : public GpuKernel {
     auto stride_c = SizeToInt(m_ * n_);
     try {
-      CHECK_CUBLAS_RET_WITH_EXCEPT(
-        kernel_node_,
-        cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
-                                   &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
-                                   &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
-        "cublasSgemm Call Fail");
+      // Use cublasGemmEx to get high performance when batch_ is 1
+      if (batch_ == 1) {
+        CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
+                                     cublasGemmEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_),
+                                                  SizeToInt(k_), &alpha, input2_addr, dtype_b_, ldb, input1_addr,
+                                                  dtype_a_, lda, &beta, output_addr, dtype_c_, ldc, CUDA_R_32F, algo_),
+                                     "cublasGemmEx Call Fail");
+      } else {
+        CHECK_CUBLAS_RET_WITH_EXCEPT(
+          kernel_node_,
+          cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
+                                     &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
+                                     &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
+          "cublasGemmStridedBatchedEx Call Fail");
+      }
     } catch (const std::exception &e) {
-      MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas cublasGemmStridedBatchedEx";
+      MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoking cublas "
+                        << (batch_ == 1 ? "cublasGemmEx" : "cublasGemmStridedBatchedEx");
     }
     return true;
   }
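For context on the dispatch above: when `batch_` is 1, the kernel now calls the plain `cublasGemmEx` entry point instead of `cublasGemmStridedBatchedEx`, since a strided-batched launch with a batch count of 1 computes exactly the same product but typically through a slower code path. Passing `input2_addr` first together with (n, m, k) is the usual column-major trick: cuBLAS computes Cᵀ = BᵀAᵀ, which leaves the row-major product A·B in the output buffer. Below is a minimal NumPy sketch of the semantics the two entry points implement; the function names are illustrative, not cuBLAS or MindSpore APIs.

```python
# Reference semantics for the two call paths (illustrative names only).
import numpy as np

def gemm_ref(a, b, alpha=1.0, beta=0.0, c=None):
    """What a single GEMM (the cublasGemmEx path) computes: alpha * a @ b + beta * c."""
    out = alpha * (a @ b)
    return out if c is None else out + beta * c

def gemm_strided_batched_ref(a, b, alpha=1.0, beta=0.0, c=None):
    """What a strided-batched GEMM computes: one GEMM per leading batch index."""
    return np.stack([gemm_ref(a[i], b[i], alpha, beta, None if c is None else c[i])
                     for i in range(a.shape[0])])

a = np.random.rand(1, 4, 3).astype(np.float32)  # batch_ == 1
b = np.random.rand(1, 3, 5).astype(np.float32)
# With a batch count of 1 the two paths agree, so the kernel can take the
# cheaper non-batched cuBLAS entry point.
assert np.allclose(gemm_strided_batched_ref(a, b)[0], gemm_ref(a[0], b[0]))
```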
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -150,7 +150,7 @@ def test_tensor_dot_fp16():
     network = NetTensorDot(axes)
     ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
     np_result = np.tensordot(x1, x2, axes)
-    np.testing.assert_array_almost_equal(ms_result_np, np_result)
+    assert np.allclose(ms_result_np, np_result, rtol=1e-3, atol=1e-3)
 
     # 3D
     shape_x1 = (60, 30, 450)
@@ -164,7 +164,7 @@ def test_tensor_dot_fp16():
     network = NetTensorDot(axes)
     ms_result_np = network(x1_tensor, x2_tensor).asnumpy()
    np_result = np.tensordot(x1, x2, axes)
-    np.testing.assert_array_almost_equal(ms_result_np, np_result)
+    assert np.allclose(ms_result_np, np_result, rtol=1e-3, atol=6e0)
 
 
 @pytest.mark.level0
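On the test side, `np.testing.assert_array_almost_equal` compares to 6 decimal places by default, which is far too strict for float16 results; the patch switches to `np.allclose` with explicit tolerances, and the 3D case gets a much looser `atol` since the larger contraction accumulates more fp16 rounding error. `np.allclose` accepts element-wise whenever `|actual - desired| <= atol + rtol * |desired|`. A small sketch with made-up values:

```python
import numpy as np

def within_tol(actual, desired, rtol, atol):
    # The element-wise test np.allclose applies.
    return np.all(np.abs(actual - desired) <= atol + rtol * np.abs(desired))

desired = np.float32(100.0)
actual = np.float32(100.05)  # plausible fp16-scale rounding error

print(within_tol(actual, desired, rtol=1e-3, atol=1e-3))  # True under the new check
# The old assert_array_almost_equal default (decimal=6) requires
# |actual - desired| < 1.5e-6 and would reject this difference.
```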
@@ -173,7 +173,7 @@ def test_tensor_dot_fp16():
 def test_tensor_dot_outer():
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
     np.random.seed(2746)
-    shape_x1 = (1, 2, 3)  # incompatable dims for x1 and x2
+    shape_x1 = (1, 2, 3)  # incompatible dims for x1 and x2
     shape_x2 = (4, 5, 6)
     axes = 0  # outer product does not require multiplicable dims
     x1 = np.random.random(shape_x1).astype(np.float32)
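The comment fixed in the last hunk refers to `axes=0`, which asks `tensordot` for an outer (tensor) product: nothing is contracted, so the two shapes never need to be compatible. A quick NumPy illustration using the same shapes as the test:

```python
import numpy as np

x1 = np.random.random((1, 2, 3)).astype(np.float32)
x2 = np.random.random((4, 5, 6)).astype(np.float32)

# axes=0 contracts nothing: the result shape is the concatenation of the
# input shapes, so no dimensions have to match between x1 and x2.
result = np.tensordot(x1, x2, axes=0)
print(result.shape)  # (1, 2, 3, 4, 5, 6)
```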