|
|
|
@@ -76,9 +76,11 @@ class InstanceNormGradGpuKernel : public GpuKernel { |
|
|
|
T *dz = nullptr; |
|
|
|
|
|
|
|
float *ws_gamma = GetDeviceAddress<float>(workspace, 0); |
|
|
|
float *ws_dgamma = GetDeviceAddress<float>(workspace, 1); |
|
|
|
float *ws_dbeta = GetDeviceAddress<float>(workspace, 2); |
|
|
|
void *workspace_addr = nullptr; |
|
|
|
if (workspace_size_ != 0) { |
|
|
|
workspace_addr = GetDeviceAddress<T>(workspace, 1); |
|
|
|
workspace_addr = GetDeviceAddress<T>(workspace, 3); |
|
|
|
} |
|
|
|
|
|
|
|
size_t N = input_shape_[0]; |
|
|
|
@@ -92,14 +94,14 @@ class InstanceNormGradGpuKernel : public GpuKernel { |
|
|
|
const float alpha_param_diff = 1; |
|
|
|
const float beta_param_diff = 0; |
|
|
|
float *reserve_addr = nullptr; |
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT( |
|
|
|
kernel_node_, |
|
|
|
cudnnBatchNormalizationBackwardEx( |
|
|
|
handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff, &beta_param_diff, x_desc_, x, |
|
|
|
y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx, scale_bias_diff_desc_, ws_gamma, beta, dgamma, dbeta, |
|
|
|
epsilon_, save_mean, save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0), |
|
|
|
"Kernel launch failed"); |
|
|
|
ComputeMean(N, C, dgamma, dbeta, reinterpret_cast<cudaStream_t>(stream_ptr)); |
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, |
|
|
|
cudnnBatchNormalizationBackwardEx( |
|
|
|
handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff, |
|
|
|
&beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx, |
|
|
|
scale_bias_diff_desc_, ws_gamma, beta, ws_dgamma, ws_dbeta, epsilon_, save_mean, |
|
|
|
save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0), |
|
|
|
"Kernel launch failed"); |
|
|
|
ComputeMean(N, C, dgamma, dbeta, ws_dgamma, ws_dbeta, reinterpret_cast<cudaStream_t>(stream_ptr)); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -164,10 +166,12 @@ class InstanceNormGradGpuKernel : public GpuKernel { |
|
|
|
input_size_list_.push_back(para_size_); |
|
|
|
|
|
|
|
output_size_list_.push_back(x_size_); |
|
|
|
output_size_list_.push_back(para_size_); |
|
|
|
output_size_list_.push_back(para_size_); |
|
|
|
output_size_list_.push_back(x_size_); |
|
|
|
output_size_list_.push_back(x_size_); |
|
|
|
|
|
|
|
workspace_size_list_.push_back(para_size_); // ws gamma |
|
|
|
workspace_size_list_.push_back(para_size_); // ws dgamma |
|
|
|
workspace_size_list_.push_back(para_size_); // ws dbeta |
|
|
|
workspace_size_list_.push_back(workspace_size_); |
|
|
|
} |
|
|
|
void DestroyResource() noexcept override { |
|
|
|
|