|
|
|
@@ -64,11 +64,16 @@ class MatMulGpuKernel : public GpuKernel { |
|
|
|
auto stride_a = SizeToInt(m_ * k_); |
|
|
|
auto stride_b = SizeToInt(k_ * n_); |
|
|
|
auto stride_c = SizeToInt(m_ * n_); |
|
|
|
CHECK_CUBLAS_RET_WITH_EXCEPT( |
|
|
|
cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_), |
|
|
|
&alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a, |
|
|
|
&beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_), |
|
|
|
"cublasSgemm Call Fail"); |
|
|
|
|
|
|
|
try { |
|
|
|
CHECK_CUBLAS_RET_WITH_EXCEPT( |
|
|
|
cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_), |
|
|
|
&alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a, |
|
|
|
&beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_), |
|
|
|
"cublasSgemm Call Fail"); |
|
|
|
} catch (const std::exception &e) { |
|
|
|
MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas cublasGemmStridedBatchedEx"; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
bool Init(const CNodePtr &kernel_node) override { |
|
|
|
|