| @@ -88,7 +88,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const { | |||||
| #if CUDART_VERSION >= 9010 | #if CUDART_VERSION >= 9010 | ||||
| auto io16_c32 = [&]() { | auto io16_c32 = [&]() { | ||||
| #if CUDART_VERSION >= 11000 | |||||
| cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); | |||||
| #else | |||||
| cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | ||||
| #endif | |||||
| auto zero = handle->zero_device(); | auto zero = handle->zero_device(); | ||||
| auto one = handle->one_device(); | auto one = handle->one_device(); | ||||
| cublas_check(cublasGemmBatchedEx( | cublas_check(cublasGemmBatchedEx( | ||||
| @@ -104,7 +108,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const { | |||||
| #if CUDART_VERSION >= 9000 | #if CUDART_VERSION >= 9000 | ||||
| auto io16_c16 = [&]() { | auto io16_c16 = [&]() { | ||||
| #if CUDART_VERSION >= 11000 | |||||
| cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); | |||||
| #else | |||||
| cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | ||||
| #endif | |||||
| auto zero = handle->zero_device_h(); | auto zero = handle->zero_device_h(); | ||||
| auto one = handle->one_device_h(); | auto one = handle->one_device_h(); | ||||
| cublas_check(cublasHgemmBatched( | cublas_check(cublasHgemmBatched( | ||||
| @@ -124,7 +124,7 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const | |||||
| batched_igemm(); | batched_igemm(); | ||||
| } else if (desc.dt_compute == CUBLAS_COMPUTE_16F) { | } else if (desc.dt_compute == CUBLAS_COMPUTE_16F) { | ||||
| batched_hgemm(); | batched_hgemm(); | ||||
| } else if (desc.dt_compute == CUBLAS_COMPUTE_32F) { | |||||
| } else if (desc.dt_compute == CUBLAS_COMPUTE_32F_FAST_TF32) { | |||||
| batched_sgemm(); | batched_sgemm(); | ||||
| } else { | } else { | ||||
| megdnn_throw("compute_type must be int32/float16/float32"); | megdnn_throw("compute_type must be int32/float16/float32"); | ||||
| @@ -49,18 +49,26 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const { | |||||
| auto sgemm = [&]() { | auto sgemm = [&]() { | ||||
| auto zero = handle->zero_device(); | auto zero = handle->zero_device(); | ||||
| auto one = handle->one_device(); | auto one = handle->one_device(); | ||||
| #if CUDART_VERSION >= 11000 | |||||
| cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); | |||||
| #endif | |||||
| cublas_check(cublasSgemm( | cublas_check(cublasSgemm( | ||||
| cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, | cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, | ||||
| param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one, | param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one, | ||||
| args.tensor_b.ptr<dt_float32>(), args.tensor_b.layout.stride[0], | args.tensor_b.ptr<dt_float32>(), args.tensor_b.layout.stride[0], | ||||
| args.tensor_a.ptr<dt_float32>(), args.tensor_a.layout.stride[0], zero, | args.tensor_a.ptr<dt_float32>(), args.tensor_a.layout.stride[0], zero, | ||||
| args.tensor_c.ptr<dt_float32>(), args.tensor_c.layout.stride[0])); | args.tensor_c.ptr<dt_float32>(), args.tensor_c.layout.stride[0])); | ||||
| #if CUDART_VERSION >= 11000 | |||||
| cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH)); | |||||
| #endif | |||||
| }; | }; | ||||
| auto sgemm_ex = [&]() { | auto sgemm_ex = [&]() { | ||||
| auto zero = handle->zero_device(); | auto zero = handle->zero_device(); | ||||
| auto one = handle->one_device(); | auto one = handle->one_device(); | ||||
| #if CUDART_VERSION >= 9000 | |||||
| #if CUDART_VERSION >= 11000 | |||||
| cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); | |||||
| #elif CUDART_VERSION >= 9000 | |||||
| cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | ||||
| #endif | #endif | ||||
| auto sgemm_ex_err = cublasSgemmEx( | auto sgemm_ex_err = cublasSgemmEx( | ||||
| @@ -78,7 +86,9 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const { | |||||
| }; | }; | ||||
| auto hgemm = [&]() { | auto hgemm = [&]() { | ||||
| #if CUDART_VERSION >= 9000 | |||||
| #if CUDART_VERSION >= 11000 | |||||
| cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); | |||||
| #elif CUDART_VERSION >= 9000 | |||||
| cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | ||||
| #endif | #endif | ||||
| auto one_half = handle->one_device_h(); | auto one_half = handle->one_device_h(); | ||||
| @@ -28,7 +28,7 @@ static cublasComputeType_t to_cublas_compute_type(DType tp) { | |||||
| case DTypeEnum::Float16: | case DTypeEnum::Float16: | ||||
| return CUBLAS_COMPUTE_16F; | return CUBLAS_COMPUTE_16F; | ||||
| case DTypeEnum::Float32: | case DTypeEnum::Float32: | ||||
| return CUBLAS_COMPUTE_32F; | |||||
| return CUBLAS_COMPUTE_32F_FAST_TF32; | |||||
| case DTypeEnum::Int32: | case DTypeEnum::Int32: | ||||
| case DTypeEnum::QuantizedS32: | case DTypeEnum::QuantizedS32: | ||||
| return CUBLAS_COMPUTE_32I; | return CUBLAS_COMPUTE_32I; | ||||
| @@ -107,7 +107,7 @@ void MatrixMulForwardImpl::AlgoCuBlasLt::exec(const ExecArgs& args) const { | |||||
| case CUBLAS_COMPUTE_16F: | case CUBLAS_COMPUTE_16F: | ||||
| hgemm(); | hgemm(); | ||||
| break; | break; | ||||
| case CUBLAS_COMPUTE_32F: | |||||
| case CUBLAS_COMPUTE_32F_FAST_TF32: | |||||
| sgemm(); | sgemm(); | ||||
| break; | break; | ||||
| case CUBLAS_COMPUTE_32I: | case CUBLAS_COMPUTE_32I: | ||||