| @@ -86,7 +86,7 @@ class MatMulGpuKernel : public GpuKernel { | |||||
| dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))); | dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))); | ||||
| dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))); | dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))); | ||||
| if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) { | if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) { | ||||
| MS_LOG(WARNING) << "input and output type is float16, allow to use Tensor Core operations if possible"; | |||||
| MS_LOG(INFO) << "input and output type is float16, allow to use Tensor Core operations if possible"; | |||||
| algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP; | algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP; | ||||
| } | } | ||||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | ||||
| @@ -86,6 +86,7 @@ class PoolingGradGpuKernel : public GpuKernel { | |||||
| auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); | auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); | ||||
| auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); | auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); | ||||
| data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0); | data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0); | ||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask); | is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask); | ||||
| if (is_null_input_) { | if (is_null_input_) { | ||||
| MS_LOG(WARNING) << "PoolingGradGpuKernel input is null."; | MS_LOG(WARNING) << "PoolingGradGpuKernel input is null."; | ||||
| @@ -204,7 +205,6 @@ class PoolingGradGpuKernel : public GpuKernel { | |||||
| "cudnnSetPoolingNdDescriptor failed"); | "cudnnSetPoolingNdDescriptor failed"); | ||||
| } | } | ||||
| void SetPoolingMode(const CNodePtr &kernel_node) { | void SetPoolingMode(const CNodePtr &kernel_node) { | ||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| mode_ = AnfAlgo::GetCNodeName(kernel_node); | mode_ = AnfAlgo::GetCNodeName(kernel_node); | ||||
| if (mode_ == "AvgPoolGradGpu") { | if (mode_ == "AvgPoolGradGpu") { | ||||
| pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; | pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; | ||||
| @@ -345,7 +345,6 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { | |||||
| SetGraphKernelInfo(kernel_node, func_graph); | SetGraphKernelInfo(kernel_node, func_graph); | ||||
| return; | return; | ||||
| } | } | ||||
| std::vector<std::string> inputs_format; | std::vector<std::string> inputs_format; | ||||
| std::vector<TypeId> inputs_type; | std::vector<TypeId> inputs_type; | ||||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { | ||||
| @@ -368,12 +367,10 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { | |||||
| builder->SetInputsDeviceType(inputs_type); | builder->SetInputsDeviceType(inputs_type); | ||||
| builder->SetOutputsFormat(outputs_format); | builder->SetOutputsFormat(outputs_format); | ||||
| builder->SetOutputsDeviceType(outputs_type); | builder->SetOutputsDeviceType(outputs_type); | ||||
| bool result = false; | bool result = false; | ||||
| if (kernel_type == UNKNOWN_KERNEL_TYPE) { | if (kernel_type == UNKNOWN_KERNEL_TYPE) { | ||||
| result = | result = | ||||
| kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build()); | kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build()); | ||||
| if (!result) { | if (!result) { | ||||
| result = SelectAkgKernel(kernel_node, builder->Build()); | result = SelectAkgKernel(kernel_node, builder->Build()); | ||||
| kernel_type = AKG_KERNEL; | kernel_type = AKG_KERNEL; | ||||
| @@ -381,7 +378,6 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { | |||||
| } else if (kernel_type == AKG_KERNEL) { | } else if (kernel_type == AKG_KERNEL) { | ||||
| result = SelectAkgKernel(kernel_node, builder->Build()); | result = SelectAkgKernel(kernel_node, builder->Build()); | ||||
| } | } | ||||
| if (!result) { | if (!result) { | ||||
| PrintUnsupportedTypeException(kernel_node, inputs_type, outputs_type); | PrintUnsupportedTypeException(kernel_node, inputs_type, outputs_type); | ||||
| return; | return; | ||||