Merge pull request !2039 from VectorSL/refactor-datatypetags/v0.5.0-beta
| @@ -81,7 +81,7 @@ class ArrayReduceGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| InitResource(); | InitResource(); | ||||
| data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| if (input_num != 1) { | if (input_num != 1) { | ||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but reduce op needs 1 inputs."; | MS_LOG(ERROR) << "Input number is " << input_num << ", but reduce op needs 1 inputs."; | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "kernel/kernel.h" | #include "kernel/kernel.h" | ||||
| #include "kernel/gpu/kernel_constants.h" | |||||
| #include "device/gpu/gpu_device_manager.h" | #include "device/gpu/gpu_device_manager.h" | ||||
| #include "device/gpu/gpu_common.h" | #include "device/gpu/gpu_common.h" | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| @@ -79,6 +80,22 @@ class GpuKernel : public KernelMod { | |||||
| "must match the corresponding dimension of outC or must be equal to 1."; | "must match the corresponding dimension of outC or must be equal to 1."; | ||||
| } | } | ||||
| } | } | ||||
| // choose the suitable datatype for cudnn/cublas | |||||
| inline cudnnDataType_t GetCudnnDataType(const std::string &Type) { | |||||
| auto type = kCudnnDtypeMap.find(Type); | |||||
| if (type == kCudnnDtypeMap.end()) { | |||||
| MS_EXCEPTION(TypeError) << Type << " is not supported."; | |||||
| } | |||||
| return type->second; | |||||
| } | |||||
| inline cudaDataType_t GetCudaDataType(const std::string &Type) { | |||||
| auto type = kCudaDtypeMap.find(Type); | |||||
| if (type == kCudaDtypeMap.end()) { | |||||
| MS_EXCEPTION(TypeError) << Type << " is not supported."; | |||||
| } | |||||
| return type->second; | |||||
| } | |||||
| }; | }; | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -60,7 +60,7 @@ class AddNGpuFwdKernel : public GpuKernel { | |||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| InitResource(); | InitResource(); | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| num_input_ = GetAttr<int>(kernel_node, "n"); | num_input_ = GetAttr<int>(kernel_node, "n"); | ||||
| if (IntToSize(num_input_) != input_num) { | if (IntToSize(num_input_) != input_num) { | ||||
| @@ -67,7 +67,7 @@ class BiasAddGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| InitResource(); | InitResource(); | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| auto num_dims = x_shape.size(); | auto num_dims = x_shape.size(); | ||||
| is_null_input_ = CHECK_NULL_INPUT(x_shape); | is_null_input_ = CHECK_NULL_INPUT(x_shape); | ||||
| @@ -82,9 +82,9 @@ class MatMulGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); | handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); | ||||
| dtype_a_ = kCudaDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| dtype_b_ = kCudaDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))]; | |||||
| dtype_c_ = kCudaDtypeMap[TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))]; | |||||
| dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))); | |||||
| dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))); | |||||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | ||||
| is_null_input_ = CHECK_NULL_INPUT(output_shape); | is_null_input_ = CHECK_NULL_INPUT(output_shape); | ||||
| if (is_null_input_) { | if (is_null_input_) { | ||||
| @@ -68,7 +68,7 @@ class BiasAddGradGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| InitResource(); | InitResource(); | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| auto num_dims = dy_shape.size(); | auto num_dims = dy_shape.size(); | ||||
| if (num_dims < 2) { | if (num_dims < 2) { | ||||
| @@ -191,7 +191,7 @@ class Conv2dGpuFwdKernel : public GpuKernel { | |||||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_desc_), "cudnnDestroyTensorDescriptor failed"); | CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_desc_), "cudnnDestroyTensorDescriptor failed"); | ||||
| } | } | ||||
| bool CheckParam(const CNodePtr &kernel_node) { | bool CheckParam(const CNodePtr &kernel_node) { | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| if (input_num != 2) { | if (input_num != 2) { | ||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but conv2d needs 2 inputs."; | MS_LOG(ERROR) << "Input number is " << input_num << ", but conv2d needs 2 inputs."; | ||||
| @@ -98,7 +98,7 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { | |||||
| if (!CheckParam(kernel_node)) { | if (!CheckParam(kernel_node)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | ||||
| is_null_input_ = CHECK_NULL_INPUT(dy_shape) || CHECK_NULL_INPUT(in_shape); | is_null_input_ = CHECK_NULL_INPUT(dy_shape) || CHECK_NULL_INPUT(in_shape); | ||||
| @@ -98,7 +98,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { | |||||
| if (!CheckParam(kernel_node)) { | if (!CheckParam(kernel_node)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | ||||
| is_null_input_ = CHECK_NULL_INPUT(dy_shape); | is_null_input_ = CHECK_NULL_INPUT(dy_shape); | ||||
| @@ -82,7 +82,7 @@ class FusedBatchNormGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| InitResource(); | InitResource(); | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| if (input_num != 5) { | if (input_num != 5) { | ||||
| MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGpuKernel should be 5"; | MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGpuKernel should be 5"; | ||||
| @@ -75,7 +75,7 @@ class FusedBatchNormGradGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| InitResource(); | InitResource(); | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| if (input_num != 5) { | if (input_num != 5) { | ||||
| MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGradGpuKernel should be 5"; | MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGradGpuKernel should be 5"; | ||||
| @@ -89,7 +89,7 @@ class LstmGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| InitResource(); | InitResource(); | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| seq_len_ = SizeToInt(input_shape[0]); | seq_len_ = SizeToInt(input_shape[0]); | ||||
| batch_size_ = SizeToInt(input_shape[1]); | batch_size_ = SizeToInt(input_shape[1]); | ||||
| @@ -105,7 +105,7 @@ class LstmGradDataGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| InitResource(); | InitResource(); | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | ||||
| seq_len_ = SizeToInt(input_shape[0]); | seq_len_ = SizeToInt(input_shape[0]); | ||||
| batch_size_ = SizeToInt(input_shape[1]); | batch_size_ = SizeToInt(input_shape[1]); | ||||
| @@ -84,7 +84,7 @@ class LstmGradWeightGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| InitResource(); | InitResource(); | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| seq_len_ = SizeToInt(input_shape[0]); | seq_len_ = SizeToInt(input_shape[0]); | ||||
| batch_size_ = SizeToInt(input_shape[1]); | batch_size_ = SizeToInt(input_shape[1]); | ||||
| @@ -88,7 +88,7 @@ class PoolingGpuFwdKernel : public GpuKernel { | |||||
| if (!CheckParam(kernel_node)) { | if (!CheckParam(kernel_node)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| is_null_input_ = CHECK_NULL_INPUT(input_shape); | is_null_input_ = CHECK_NULL_INPUT(input_shape); | ||||
| if (is_null_input_) { | if (is_null_input_) { | ||||
| @@ -239,7 +239,7 @@ class PoolingGradGpuFwdKernel : public GpuKernel { | |||||
| void SetPoolingMode(const CNodePtr &kernel_node) { | void SetPoolingMode(const CNodePtr &kernel_node) { | ||||
| pad_mode_ = GetAttr<std::string>(kernel_node, "padding"); | pad_mode_ = GetAttr<std::string>(kernel_node, "padding"); | ||||
| stride_ = GetAttr<std::vector<int>>(kernel_node, "strides"); | stride_ = GetAttr<std::vector<int>>(kernel_node, "strides"); | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| mode_ = AnfAlgo::GetCNodeName(kernel_node); | mode_ = AnfAlgo::GetCNodeName(kernel_node); | ||||
| if (mode_ == "AvgPoolGradGpu") { | if (mode_ == "AvgPoolGradGpu") { | ||||
| pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; | pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; | ||||
| @@ -60,7 +60,7 @@ class ReLUGpuFwdKernel : public GpuKernel { | |||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| InitResource(); | InitResource(); | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| if (input_num != 1) { | if (input_num != 1) { | ||||
| MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReLUGpuFwdKernel needs 1."; | MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReLUGpuFwdKernel needs 1."; | ||||
| @@ -60,7 +60,7 @@ class ReluGradGpuFwdKernel : public GpuKernel { | |||||
| } | } | ||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| InitResource(); | InitResource(); | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| if (input_num != 2) { | if (input_num != 2) { | ||||
| MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReluGradGpuFwdKernel needs 2."; | MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReluGradGpuFwdKernel needs 2."; | ||||
| @@ -87,7 +87,7 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { | |||||
| << ", but SoftmaxCrossEntropyWithLogitsGpuKernel needs 2 output."; | << ", but SoftmaxCrossEntropyWithLogitsGpuKernel needs 2 output."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| InferInputOutputSize(kernel_node); | InferInputOutputSize(kernel_node); | ||||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(logits_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, | CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(logits_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, | ||||
| @@ -95,7 +95,7 @@ class SoftmaxGpuKernel : public GpuKernel { | |||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| InitResource(); | InitResource(); | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| if (input_num != 1) { | if (input_num != 1) { | ||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but softmax needs 1 input."; | MS_LOG(ERROR) << "Input number is " << input_num << ", but softmax needs 1 input."; | ||||
| @@ -98,7 +98,7 @@ class SoftmaxGradGpuKernel : public GpuKernel { | |||||
| bool Init(const CNodePtr &kernel_node) override { | bool Init(const CNodePtr &kernel_node) override { | ||||
| InitResource(); | InitResource(); | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| if (input_num != 2) { | if (input_num != 2) { | ||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but softmax grad needs 2 input."; | MS_LOG(ERROR) << "Input number is " << input_num << ", but softmax grad needs 2 input."; | ||||
| @@ -89,7 +89,7 @@ class SparseSoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { | |||||
| return false; | return false; | ||||
| } | } | ||||
| is_grad_ = GetAttr<bool>(kernel_node, "is_grad"); | is_grad_ = GetAttr<bool>(kernel_node, "is_grad"); | ||||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| InferInputOutputSize(kernel_node); | InferInputOutputSize(kernel_node); | ||||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(logits_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, | CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(logits_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, | ||||
| @@ -141,7 +141,7 @@ class BatchNormFoldGpuKernel : public GpuKernel { | |||||
| input_size_ = sizeof(T) * batch_ * channel_ * height_ * width_; | input_size_ = sizeof(T) * batch_ * channel_ * height_ * width_; | ||||
| output_size_ = sizeof(T) * channel_; | output_size_ = sizeof(T) * channel_; | ||||
| cudnnDataType_t cudnnDataType = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||||
| cudnnDataType_t cudnnDataType = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); | |||||
| CHECK_CUDNN_RET_WITH_EXCEPT( | CHECK_CUDNN_RET_WITH_EXCEPT( | ||||
| cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnnDataType, batch_, channel_, height_, width_), | cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnnDataType, batch_, channel_, height_, width_), | ||||
| "Set x desc failed"); | "Set x desc failed"); | ||||