From 206762c9358877f60efe065b87f45f2d990e6d5c Mon Sep 17 00:00:00 2001
From: zhouyuanshen
Date: Tue, 16 Mar 2021 19:21:17 +0800
Subject: [PATCH] fix updated gamma/beta not matching those of torch

---
 .../gpu/cuda_impl/instance_norm_impl.cu       | 14 +++++-----
 .../gpu/cuda_impl/instance_norm_impl.cuh      |  4 +--
 .../gpu/nn/instance_norm_grad_gpu_kernel.h    | 26 +++++++++++--------
 mindspore/nn/layer/normalization.py           |  3 ++-
 4 files changed, 26 insertions(+), 21 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cu
index 7af68dc46f..11a3337785 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cu
@@ -61,30 +61,30 @@ void CopyMemDevice2Device(const size_t N, const size_t C, float *gamma_addr, flo
 }
 
 __global__ void ComputeMeanKernel(const size_t thread_num, const size_t N, const size_t C,
-                                  float *save_mean_addr, float *save_var_addr) {
+                                  float *dgamma, float *dbeta, const float *ws_dgamma, const float *ws_dbeta) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < thread_num; pos += gridDim.x * blockDim.x) {
     size_t cur_addr = pos / C;
     size_t cur_local_index = pos % C;
     float tmp = 0;
     if (cur_addr) {
       for (size_t i = 0; i < N; i++) {
-        tmp += save_var_addr[i * C + cur_local_index];
+        tmp += ws_dgamma[i * C + cur_local_index];
       }
-      save_var_addr[cur_local_index] = tmp / N;
+      dgamma[cur_local_index] = tmp;
     } else {
       for (size_t i = 0; i < N; i++) {
-        tmp += save_mean_addr[i * C + cur_local_index];
+        tmp += ws_dbeta[i * C + cur_local_index];
       }
-      save_mean_addr[cur_local_index] = tmp / N;
+      dbeta[cur_local_index] = tmp;
     }
   }
   return;
 }
 
 void ComputeMean(const size_t N, const size_t C,
-                 float *save_mean_addr, float *save_var_addr,
+                 float *dgamma, float *dbeta, const float *ws_dgamma, const float *ws_dbeta,
                  cudaStream_t cuda_stream) {
   size_t thread_num = C * 2;
   ComputeMeanKernel<<<GET_BLOCKS(thread_num), GET_THREADS, 0, cuda_stream>>>(
-    thread_num, N, C, save_mean_addr, save_var_addr);
+    thread_num, N, C, dgamma, dbeta, ws_dgamma, ws_dbeta);
 }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh
index 053d529cb0..adb0d5895b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh
@@ -22,6 +22,6 @@
 void CopyMemDevice2Device(const size_t N, const size_t C, float *gamma_addr, float *beta_addr,
                           float *runing_mean_addr, float *runnig_variance_addr, float *ws_gamma, float *ws_beta,
                           float *ws_mean, float *ws_var, cudaStream_t cuda_stream);
-void ComputeMean(const size_t N, const size_t C, float *save_mean_addr, float *save_var_addr,
-                 cudaStream_t cuda_stream);
+void ComputeMean(const size_t N, const size_t C, float *dgamma, float *dbeta, const float *ws_dgamma,
+                 const float *ws_dbeta, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_INSTANCE_NORM_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h
index 277278ed67..bb4d019ab8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h
@@ -76,9 +76,11 @@ class InstanceNormGradGpuKernel : public GpuKernel {
     T *dz = nullptr;
 
     float *ws_gamma = GetDeviceAddress<float>(workspace, 0);
+    float *ws_dgamma = GetDeviceAddress<float>(workspace, 1);
+    float *ws_dbeta = GetDeviceAddress<float>(workspace, 2);
     void *workspace_addr = nullptr;
     if (workspace_size_ != 0) {
-      workspace_addr = GetDeviceAddress<T>(workspace, 1);
+      workspace_addr = GetDeviceAddress<T>(workspace, 3);
     }
 
     size_t N = input_shape_[0];
@@ -92,14 +94,14 @@ class InstanceNormGradGpuKernel : public GpuKernel {
     const float alpha_param_diff = 1;
     const float beta_param_diff = 0;
     float *reserve_addr = nullptr;
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      kernel_node_,
-      cudnnBatchNormalizationBackwardEx(
-        handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff, &beta_param_diff, x_desc_, x,
-        y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx, scale_bias_diff_desc_, ws_gamma, beta, dgamma, dbeta,
-        epsilon_, save_mean, save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0),
-      "Kernel launch failed");
-    ComputeMean(N, C, dgamma, dbeta, reinterpret_cast<cudaStream_t>(stream_ptr));
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                cudnnBatchNormalizationBackwardEx(
+                                  handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff,
+                                  &beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx,
+                                  scale_bias_diff_desc_, ws_gamma, beta, ws_dgamma, ws_dbeta, epsilon_, save_mean,
+                                  save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0),
+                                "Kernel launch failed");
+    ComputeMean(N, C, dgamma, dbeta, ws_dgamma, ws_dbeta, reinterpret_cast<cudaStream_t>(stream_ptr));
 
     return true;
   }
@@ -164,10 +166,12 @@ class InstanceNormGradGpuKernel : public GpuKernel {
     input_size_list_.push_back(para_size_);
 
     output_size_list_.push_back(x_size_);
-    output_size_list_.push_back(para_size_);
-    output_size_list_.push_back(para_size_);
+    output_size_list_.push_back(x_size_);
+    output_size_list_.push_back(x_size_);
 
     workspace_size_list_.push_back(para_size_);  // ws gamma
+    workspace_size_list_.push_back(para_size_);  // ws dgamma
+    workspace_size_list_.push_back(para_size_);  // ws dbeta
     workspace_size_list_.push_back(workspace_size_);
   }
  void DestroyResource() noexcept override {
diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 2b2a1cb229..47bc6cd90a 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -31,6 +31,7 @@ from mindspore._extends import cell_attr_register
 from mindspore.communication.management import get_group_size, get_rank
 from mindspore.communication import management
 from mindspore.ops import _selected_ops
+from mindspore.common import dtype as mstype
 from ..cell import Cell
 
 __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
@@ -999,7 +1000,7 @@ class InstanceNorm2d(Cell):
             if not isinstance(val, (Tensor, numbers.Number, str, Initializer)):
                 raise TypeError(f"[{name}]Supported type for arg {key} is [Tensor, numbers.Number, str, Initializer],"
                                 f"but got {type(val)}")
-            if isinstance(val, Tensor) and val.dtype is not float:
+            if isinstance(val, Tensor) and val.dtype != mstype.float32:
                 raise TypeError(f"[{name}]The type of arg {key} should be float32, but got {val.dtype}")
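
Note on the reduction change: with per-instance statistics, cudnnBatchNormalizationBackwardEx produces one C-sized gamma/beta gradient per sample, which this patch stages in the ws_dgamma/ws_dbeta workspaces (laid out as N x C) before reducing over N. torch.nn.InstanceNorm2d accumulates these per-sample contributions by summation, so the old ComputeMeanKernel, which divided by N, left dgamma/dbeta off by a factor of N. A minimal PyTorch sketch of the behavior being matched (it assumes PyTorch is installed; the shapes and variable names are illustrative, not from the patch):

    import torch

    N, C, H, W = 4, 3, 5, 5
    x = torch.randn(N, C, H, W)
    norm = torch.nn.InstanceNorm2d(C, affine=True)

    # gamma/beta gradients for the whole batch at once.
    (norm(x) ** 2).sum().backward()
    batch_dgamma = norm.weight.grad.clone()
    batch_dbeta = norm.bias.grad.clone()

    # One sample at a time; backward() accumulates (sums) into .grad.
    norm.weight.grad = None
    norm.bias.grad = None
    for i in range(N):
        (norm(x[i:i + 1]) ** 2).sum().backward()

    # torch sums the per-instance contributions; dividing by N, as the
    # old kernel did, would make this check fail for N > 1.
    assert torch.allclose(norm.weight.grad, batch_dgamma, atol=1e-4)
    assert torch.allclose(norm.bias.grad, batch_dbeta, atol=1e-4)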
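For reference, the reduction the rewritten kernel performs (the ComputeMean name is kept from the old averaging version) is equivalent to the following NumPy sketch; reduce_instance_grads is a hypothetical helper written for illustration, and only the N x C workspace layout and the sum over the instance axis come from the patch:

    import numpy as np

    def reduce_instance_grads(ws_dgamma, ws_dbeta, N, C):
        # Each workspace buffer holds one C-sized gradient per instance.
        # The old kernel computed the mean (tmp / N) here; the fix keeps
        # the plain sum so the result matches torch.
        dgamma = ws_dgamma.reshape(N, C).sum(axis=0)
        dbeta = ws_dbeta.reshape(N, C).sum(axis=0)
        return dgamma, dbeta

In the CUDA kernel itself, thread_num = C * 2 and cur_addr = pos / C split the grid so that threads with pos < C reduce ws_dbeta while threads with pos >= C reduce ws_dgamma; each output element is written by exactly one thread, so no atomics are needed.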