Browse Source

!806 GPU update kernel bn

Merge pull request !806 from VectorSL/gpu-fix-bn
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
d5d3bf7565
2 changed files with 4 additions and 3 deletions
  1. +3
    -2
      mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h
  2. +1
    -1
      mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.h

+ 3
- 2
mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h View File

@@ -82,6 +82,7 @@ class FusedBatchNormGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
InitResource();
+ cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))];
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 5) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGpuKernel should be 5";
@@ -112,11 +113,11 @@ class FusedBatchNormGpuKernel : public GpuKernel {
}

CHECK_CUDNN_RET_WITH_EXCEPT(
- cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch_, channel_, height_, width_),
+ cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
"Set x desc failed");

CHECK_CUDNN_RET_WITH_EXCEPT(
- cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch_, channel_, height_, width_),
+ cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
"Set y desc failed");

CHECK_CUDNN_RET_WITH_EXCEPT(


+ 1
- 1
mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.h View File

@@ -110,7 +110,7 @@ class FusedBatchNormGradGpuKernel : public GpuKernel {
cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
"Set dx desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
- cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 1, channel_, 1, 1),
+ cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1),
"Set para desc failed");

InitSizeLists();


Loading…
Cancel
Save