|
|
|
@@ -82,6 +82,7 @@ class FusedBatchNormGpuKernel : public GpuKernel { |
|
|
|
} |
|
|
|
bool Init(const CNodePtr &kernel_node) override { |
|
|
|
InitResource(); |
|
|
|
cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; |
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); |
|
|
|
if (input_num != 5) { |
|
|
|
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGpuKernel should be 5"; |
|
|
|
@@ -112,11 +113,11 @@ class FusedBatchNormGpuKernel : public GpuKernel { |
|
|
|
} |
|
|
|
|
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT( |
|
|
|
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch_, channel_, height_, width_), |
|
|
|
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), |
|
|
|
"Set x desc failed"); |
|
|
|
|
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT( |
|
|
|
cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch_, channel_, height_, width_), |
|
|
|
cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), |
|
|
|
"Set y desc failed"); |
|
|
|
|
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT( |
|
|
|
|