| @@ -218,7 +218,7 @@ class BinaryOpGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| } | } | ||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| cudnnSetOpTensorDescriptor(opTensor_descriptor_, tensor_op_, cudnn_data_type_, CUDNN_NOT_PROPAGATE_NAN), | |||||
| cudnnSetOpTensorDescriptor(opTensor_descriptor_, tensor_op_, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN), | |||||
| "cudnnSetOpTensorDescriptor failed"); | "cudnnSetOpTensorDescriptor failed"); | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -142,10 +142,14 @@ class Conv2dGpuFwdKernel : public GpuKernel { | |||||
| } | } | ||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_, stride_, dilation_, dilation_, | cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_, stride_, dilation_, dilation_, | ||||
| CUDNN_CROSS_CORRELATION, cudnn_data_type_), | |||||
| CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||||
| "cudnnSetConvolution2dDescriptor failed"); | "cudnnSetConvolution2dDescriptor failed"); | ||||
| input_descriptor_real = input_desc_; | input_descriptor_real = input_desc_; | ||||
| } | } | ||||
| if (cudnn_data_type_ == CUDNN_DATA_HALF) { | |||||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), | |||||
| "cudnnSetConvolutionMathType failed.") | |||||
| } | |||||
| SelectAlgorithm(input_descriptor_real); | SelectAlgorithm(input_descriptor_real); | ||||
| InitSizeLists(); | InitSizeLists(); | ||||
| return true; | return true; | ||||
| @@ -240,7 +244,7 @@ class Conv2dGpuFwdKernel : public GpuKernel { | |||||
| "cudnnSetTensor4dDescriptor failed"); | "cudnnSetTensor4dDescriptor failed"); | ||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| cudnnSetConvolution2dDescriptor(conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_, stride_, | cudnnSetConvolution2dDescriptor(conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_, stride_, | ||||
| dilation_, dilation_, CUDNN_CROSS_CORRELATION, cudnn_data_type_), | |||||
| dilation_, dilation_, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||||
| "cudnnSetConvolution2dDescriptor failed"); | "cudnnSetConvolution2dDescriptor failed"); | ||||
| } | } | ||||
| @@ -276,6 +280,9 @@ class Conv2dGpuFwdKernel : public GpuKernel { | |||||
| "cudnnGetConvolutionForwardAlgorithm_v7 failed"); | "cudnnGetConvolutionForwardAlgorithm_v7 failed"); | ||||
| conv_algorithm_ = perf_results.algo; | conv_algorithm_ = perf_results.algo; | ||||
| } | } | ||||
| if (cudnn_data_type_ == CUDNN_DATA_HALF) { | |||||
| conv_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; | |||||
| } | |||||
| } | } | ||||
| cudnnHandle_t cudnn_handle_; | cudnnHandle_t cudnn_handle_; | ||||
| cudnnTensorDescriptor_t input_desc_; | cudnnTensorDescriptor_t input_desc_; | ||||
| @@ -141,10 +141,14 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { | |||||
| } | } | ||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_, stride_, dilation_, dilation_, | cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_, stride_, dilation_, dilation_, | ||||
| CUDNN_CROSS_CORRELATION, cudnn_data_type_), | |||||
| CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||||
| "GetConvolution2dDescriptor failed"); | "GetConvolution2dDescriptor failed"); | ||||
| x_desc_real = x_desc_; | x_desc_real = x_desc_; | ||||
| } | } | ||||
| if (cudnn_data_type_ == CUDNN_DATA_HALF) { | |||||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), | |||||
| "cudnnSetConvolutionMathType failed.") | |||||
| } | |||||
| SelectAlgorithm(x_desc_real); | SelectAlgorithm(x_desc_real); | ||||
| InitSizeLists(); | InitSizeLists(); | ||||
| return true; | return true; | ||||
| @@ -239,7 +243,7 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { | |||||
| "cudnnSetTensor4dDescriptor failed"); | "cudnnSetTensor4dDescriptor failed"); | ||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| cudnnSetConvolution2dDescriptor(conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_, stride_, | cudnnSetConvolution2dDescriptor(conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_, stride_, | ||||
| dilation_, dilation_, CUDNN_CROSS_CORRELATION, cudnn_data_type_), | |||||
| dilation_, dilation_, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||||
| "cudnnSetConvolution2dDescriptor failed"); | "cudnnSetConvolution2dDescriptor failed"); | ||||
| } | } | ||||
| void SelectAlgorithm(cudnnTensorDescriptor_t x_desc_real) { | void SelectAlgorithm(cudnnTensorDescriptor_t x_desc_real) { | ||||
| @@ -258,6 +262,9 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { | |||||
| "GetConvolutionBackwardFilterAlgorithm failed"); | "GetConvolutionBackwardFilterAlgorithm failed"); | ||||
| algo_ = perf_results.algo; | algo_ = perf_results.algo; | ||||
| } | } | ||||
| if (cudnn_data_type_ == CUDNN_DATA_HALF) { | |||||
| algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; | |||||
| } | |||||
| } | } | ||||
| void GetFilterShape(const CNodePtr &kernel_node, std::vector<int> *filter_shape) { | void GetFilterShape(const CNodePtr &kernel_node, std::vector<int> *filter_shape) { | ||||
| auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("filter_sizes")->cast<ValueTuplePtr>()->value(); | auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("filter_sizes")->cast<ValueTuplePtr>()->value(); | ||||
| @@ -142,10 +142,14 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { | |||||
| } | } | ||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_, stride_, dilation_, dilation_, | cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_, stride_, dilation_, dilation_, | ||||
| CUDNN_CROSS_CORRELATION, cudnn_data_type_), | |||||
| CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||||
| "cudnnSetConvolution2dDescriptor failed"); | "cudnnSetConvolution2dDescriptor failed"); | ||||
| dx_desc_real = dx_desc_; | dx_desc_real = dx_desc_; | ||||
| } | } | ||||
| if (cudnn_data_type_ == CUDNN_DATA_HALF) { | |||||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), | |||||
| "cudnnSetConvolutionMathType failed.") | |||||
| } | |||||
| SelectAlgorithm(dx_desc_real); | SelectAlgorithm(dx_desc_real); | ||||
| InitSizeLists(); | InitSizeLists(); | ||||
| return true; | return true; | ||||
| @@ -239,7 +243,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { | |||||
| "cudnnSetTensor4dDescriptor failed"); | "cudnnSetTensor4dDescriptor failed"); | ||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| cudnnSetConvolution2dDescriptor(conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_, stride_, | cudnnSetConvolution2dDescriptor(conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_, stride_, | ||||
| dilation_, dilation_, CUDNN_CROSS_CORRELATION, cudnn_data_type_), | |||||
| dilation_, dilation_, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), | |||||
| "cudnnSetConvolution2dDescriptor failed"); | "cudnnSetConvolution2dDescriptor failed"); | ||||
| } | } | ||||
| void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) { | void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) { | ||||
| @@ -258,6 +262,9 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { | |||||
| "cudnnGetConvolutionBackwardDataAlgorithm_v7 failed"); | "cudnnGetConvolutionBackwardDataAlgorithm_v7 failed"); | ||||
| algo_ = perf_results.algo; | algo_ = perf_results.algo; | ||||
| } | } | ||||
| if (cudnn_data_type_ == CUDNN_DATA_HALF) { | |||||
| algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; | |||||
| } | |||||
| } | } | ||||
| void GetInputShape(const CNodePtr &kernel_node, std::vector<int> *input_shape) { | void GetInputShape(const CNodePtr &kernel_node, std::vector<int> *input_shape) { | ||||
| auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("input_sizes")->cast<ValueTuplePtr>()->value(); | auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("input_sizes")->cast<ValueTuplePtr>()->value(); | ||||