diff --git a/mindspore/lite/nnacl/fp32_grad/layernorm_grad.c b/mindspore/lite/nnacl/fp32_grad/layernorm_grad.c index 38357d39c4..b68010cddc 100644 --- a/mindspore/lite/nnacl/fp32_grad/layernorm_grad.c +++ b/mindspore/lite/nnacl/fp32_grad/layernorm_grad.c @@ -19,14 +19,15 @@ void LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma, int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db) { - // var is actually 1/sqrf(var)-> var^0.5 + // var is actually layer_norm forward output var + float eps = 1e-12; const float *var_sqrt_rev = var; for (size_t i = 0; i < param_num; ++i) { float dgamma = 0.0f; float dbeta = 0.0f; for (size_t j = i; j < param_size * param_num; j += param_num) { int norm_shift = (int)(j / block_size); - dgamma += dy[j] * var_sqrt_rev[norm_shift] * (x[j] - mean[norm_shift]); + dgamma += dy[j] * pow(var[norm_shift] + eps, -0.5) * (x[j] - mean[norm_shift]); dbeta += dy[j]; } dg[i] = dgamma; @@ -41,13 +42,14 @@ void LayerNormGrad(const float *x, const float *dy, const float *var, const floa int norm_shift = (int)(j / block_size); float dxm = x[j] - mean[norm_shift]; float dyg = dy[j] * gamma[param_shift]; - sum1 += -0.5f * dyg * dxm * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift]; + sum1 += -0.5f * dyg * dxm * pow(var_sqrt_rev[norm_shift] + eps, -1.5); + sum2 += dyg; sum3 += -2.0f * dxm; } for (size_t j = i * block_size; j < (i + 1) * block_size; ++j) { int param_shift = j % param_num; int norm_shift = (int)(j / block_size); - float var_sqrt = var_sqrt_rev[norm_shift]; + float var_sqrt = pow(var_sqrt_rev[norm_shift] + eps, -0.5); float dx1 = dy[j] * gamma[param_shift] * var_sqrt; float dx2 = sum1 * 2.0f / block_size * (x[j] - mean[norm_shift]); float dx3 = (-1.0f * var_sqrt * sum2 + (1.0f / block_size) * sum1 * sum3) * (1.0f / block_size);