|
|
|
@@ -42,7 +42,7 @@ class MatMulGpuKernel : public GpuKernel { |
|
|
|
dtype_a_(CUDA_R_32F), |
|
|
|
dtype_b_(CUDA_R_32F), |
|
|
|
dtype_c_(CUDA_R_32F), |
|
|
|
algo_(CUBLAS_GEMM_DEFAULT_TENSOR_OP) {} |
|
|
|
algo_(CUBLAS_GEMM_DEFAULT) {} |
|
|
|
// Defaulted destructor: this kernel owns no raw resources directly
// (cuBLAS handles/workspaces are presumably managed by the base class
// or the framework — confirm against GpuKernel).
~MatMulGpuKernel() = default; |
|
|
|
// Returns the cached per-input buffer sizes (bytes, presumably — confirm
// against the framework's size-list convention). The list is populated
// elsewhere in this class; this accessor never recomputes it.
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } |
|
|
|
// Returns the cached per-output buffer sizes; mirror of GetInputSizeList
// for the kernel's outputs. Read-only view of output_size_list_.
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } |
|
|
|
@@ -85,6 +85,10 @@ class MatMulGpuKernel : public GpuKernel { |
|
|
|
dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); |
|
|
|
dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))); |
|
|
|
dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))); |
|
|
|
if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) { |
|
|
|
MS_LOG(WARNING) << "input and output type is float16, allow to use Tensor Core operations if possible"; |
|
|
|
algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP; |
|
|
|
} |
|
|
|
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); |
|
|
|
is_null_input_ = CHECK_NULL_INPUT(output_shape); |
|
|
|
if (is_null_input_) { |
|
|
|
|