@@ -61,30 +61,30 @@ void CopyMemDevice2Device(const size_t N, const size_t C, float *gamma_addr, flo
 }
 
 __global__ void ComputeMeanKernel(const size_t thread_num, const size_t N, const size_t C,
-                                  float *save_mean_addr, float *save_var_addr) {
+                                  float *dgamma, float *dbeta, const float *ws_dgamma, const float *ws_dbeta) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < thread_num; pos += gridDim.x * blockDim.x) {
     size_t cur_addr = pos / C;
     size_t cur_local_index = pos % C;
     float tmp = 0;
     if (cur_addr) {
       for (size_t i = 0; i < N; i++) {
-        tmp += save_var_addr[i * C + cur_local_index];
+        tmp += ws_dgamma[i * C + cur_local_index];
       }
-      save_var_addr[cur_local_index] = tmp / N;
+      dgamma[cur_local_index] = tmp;
     } else {
       for (size_t i = 0; i < N; i++) {
-        tmp += save_mean_addr[i * C + cur_local_index];
+        tmp += ws_dbeta[i * C + cur_local_index];
       }
-      save_mean_addr[cur_local_index] = tmp / N;
+      dbeta[cur_local_index] = tmp;
     }
   }
   return;
 }
 
 void ComputeMean(const size_t N, const size_t C,
-                 float *save_mean_addr, float *save_var_addr,
+                 float *dgamma, float *dbeta, const float *ws_dgamma, const float *ws_dbeta,
                  cudaStream_t cuda_stream) {
   size_t thread_num = C * 2;
   ComputeMeanKernel<<<GET_BLOCKS(thread_num), GET_THREADS, 0, cuda_stream>>>(
-    thread_num, N, C, save_mean_addr, save_var_addr);
+    thread_num, N, C, dgamma, dbeta, ws_dgamma, ws_dbeta);
 }
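What the revised kernel computes: thread_num is 2 * C, so pos / C decides whether a thread reduces ws_dgamma (pos in [C, 2C)) or ws_dbeta (pos in [0, C)), and pos % C selects the channel. Each output element is now the sum over the N instances rather than the mean, consistent with gamma and beta being shared across the batch; the kernel name no longer reflects what it does. Below is a minimal host-side reference of the same reduction, for illustration only (ComputeMeanReference is not part of the patch).

// Host-side reference of the reduction performed by the new ComputeMeanKernel
// (illustration only; this helper is not part of the patch). cuDNN emits one
// dgamma/dbeta value per (instance, channel) pair, i.e. N * C values laid out
// as [n * C + c]; the kernel collapses the instance dimension by summation.
#include <cstddef>
#include <vector>

void ComputeMeanReference(std::size_t N, std::size_t C, std::vector<float> *dgamma, std::vector<float> *dbeta,
                          const std::vector<float> &ws_dgamma, const std::vector<float> &ws_dbeta) {
  for (std::size_t c = 0; c < C; ++c) {
    float sum_dgamma = 0.0f;
    float sum_dbeta = 0.0f;
    for (std::size_t n = 0; n < N; ++n) {
      sum_dgamma += ws_dgamma[n * C + c];
      sum_dbeta += ws_dbeta[n * C + c];
    }
    (*dgamma)[c] = sum_dgamma;  // the old kernel stored tmp / N here; the fix keeps the plain sum
    (*dbeta)[c] = sum_dbeta;
  }
}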
@@ -22,6 +22,6 @@ void CopyMemDevice2Device(const size_t N, const size_t C,
                           float *gamma_addr, float *beta_addr, float *runing_mean_addr, float *runnig_variance_addr,
                           float *ws_gamma, float *ws_beta, float *ws_mean, float *ws_var,
                           cudaStream_t cuda_stream);
-void ComputeMean(const size_t N, const size_t C, float *save_mean_addr, float *save_var_addr,
-                 cudaStream_t cuda_stream);
+void ComputeMean(const size_t N, const size_t C, float *dgamma, float *dbeta, const float *ws_dgamma,
+                 const float *ws_dbeta, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_INSTANCE_NORM_IMPL_H_
@@ -76,9 +76,11 @@ class InstanceNormGradGpuKernel : public GpuKernel {
     T *dz = nullptr;
 
     float *ws_gamma = GetDeviceAddress<float>(workspace, 0);
+    float *ws_dgamma = GetDeviceAddress<float>(workspace, 1);
+    float *ws_dbeta = GetDeviceAddress<float>(workspace, 2);
     void *workspace_addr = nullptr;
     if (workspace_size_ != 0) {
-      workspace_addr = GetDeviceAddress<T>(workspace, 1);
+      workspace_addr = GetDeviceAddress<T>(workspace, 3);
     }
 
     size_t N = input_shape_[0];
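Note on the workspace layout after this change: slot 0 apparently holds gamma expanded for the cuDNN call (ws_gamma, filled by CopyMemDevice2Device), slots 1 and 2 hold the per-instance dgamma/dbeta that cuDNN writes, and slot 3 is the scratch buffer cuDNN requests, which is why the old index 1 for workspace_addr becomes 3. A hypothetical readability tweak, not part of the patch, would be to name the slots so the GetDeviceAddress calls stay in sync with InitSizeLists():

// Hypothetical workspace-slot names (not in the patch); the values must match
// the push_back order in InitSizeLists().
enum WorkspaceSlot : size_t {
  kWsGamma = 0,        // gamma tiled for the cuDNN call
  kWsDgamma = 1,       // per-instance dgamma written by cuDNN
  kWsDbeta = 2,        // per-instance dbeta written by cuDNN
  kWsCudnnScratch = 3  // scratch space requested by cuDNN for BackwardEx
};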
@@ -92,14 +94,14 @@ class InstanceNormGradGpuKernel : public GpuKernel {
     const float alpha_param_diff = 1;
     const float beta_param_diff = 0;
     float *reserve_addr = nullptr;
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      kernel_node_,
-      cudnnBatchNormalizationBackwardEx(
-        handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff, &beta_param_diff, x_desc_, x,
-        y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx, scale_bias_diff_desc_, ws_gamma, beta, dgamma, dbeta,
-        epsilon_, save_mean, save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0),
-      "Kernel launch failed");
-    ComputeMean(N, C, dgamma, dbeta, reinterpret_cast<cudaStream_t>(stream_ptr));
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                cudnnBatchNormalizationBackwardEx(
+                                  handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff,
+                                  &beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx,
+                                  scale_bias_diff_desc_, ws_gamma, beta, ws_dgamma, ws_dbeta, epsilon_, save_mean,
+                                  save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0),
+                                "Kernel launch failed");
+    ComputeMean(N, C, dgamma, dbeta, ws_dgamma, ws_dbeta, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
 
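Why cuDNN can no longer write dgamma/dbeta directly: instance norm is driven through the batch norm backward path by folding the batch into the channel dimension, so the scale/bias gradient produced by cudnnBatchNormalizationBackwardEx has one entry per (instance, channel) pair, while ComputeMean writes only one reduced entry per channel into the op's dgamma/dbeta outputs. The patch therefore routes the N * C values into ws_dgamma/ws_dbeta and lets ComputeMean fold them into the outputs. A sketch of the descriptor shape this relies on follows; it is an assumption based on the [n * C + c] indexing above, not code taken from the MindSpore sources.

// Sketch only: a 1 x (N*C) x 1 x 1 tensor descriptor for the derived scale/bias
// parameters, matching the per-(instance, channel) layout that ws_dgamma and
// ws_dbeta are indexed with. Not taken from the patch.
#include <cudnn.h>

cudnnStatus_t MakeScaleBiasDiffDesc(cudnnTensorDescriptor_t desc, int n, int c) {
  return cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, n * c, 1, 1);
}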
@@ -164,10 +166,12 @@ class InstanceNormGradGpuKernel : public GpuKernel {
     input_size_list_.push_back(para_size_);
 
     output_size_list_.push_back(x_size_);
-    output_size_list_.push_back(para_size_);
-    output_size_list_.push_back(para_size_);
+    output_size_list_.push_back(x_size_);
+    output_size_list_.push_back(x_size_);
     workspace_size_list_.push_back(para_size_);  // ws gamma
+    workspace_size_list_.push_back(para_size_);  // ws dgamma
+    workspace_size_list_.push_back(para_size_);  // ws dbeta
     workspace_size_list_.push_back(workspace_size_);
   }
 
   void DestroyResource() noexcept override {
@@ -31,6 +31,7 @@ from mindspore._extends import cell_attr_register
 from mindspore.communication.management import get_group_size, get_rank
 from mindspore.communication import management
 from mindspore.ops import _selected_ops
+from mindspore.common import dtype as mstype
 from ..cell import Cell
 
 __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
@@ -999,7 +1000,7 @@ class InstanceNorm2d(Cell):
         if not isinstance(val, (Tensor, numbers.Number, str, Initializer)):
             raise TypeError(f"[{name}]Supported type for arg {key} is [Tensor, numbers.Number, str, Initializer],"
                             f"but got {type(val)}")
-        if isinstance(val, Tensor) and val.dtype is not float:
+        if isinstance(val, Tensor) and val.dtype != mstype.float32:
             raise TypeError(f"[{name}]The type of arg {key} should be float32, but got {val.dtype}")