| @@ -19,14 +19,15 @@ | |||||
| void LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma, | void LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma, | ||||
| int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db) { | int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db) { | ||||
| // var is actually 1/sqrf(var)-> var^0.5 | |||||
| // var is actually layer_norm forward output var | |||||
| float eps = 1e-12; | |||||
| const float *var_sqrt_rev = var; | const float *var_sqrt_rev = var; | ||||
| for (size_t i = 0; i < param_num; ++i) { | for (size_t i = 0; i < param_num; ++i) { | ||||
| float dgamma = 0.0f; | float dgamma = 0.0f; | ||||
| float dbeta = 0.0f; | float dbeta = 0.0f; | ||||
| for (size_t j = i; j < param_size * param_num; j += param_num) { | for (size_t j = i; j < param_size * param_num; j += param_num) { | ||||
| int norm_shift = (int)(j / block_size); | int norm_shift = (int)(j / block_size); | ||||
| dgamma += dy[j] * var_sqrt_rev[norm_shift] * (x[j] - mean[norm_shift]); | |||||
| dgamma += dy[j] * pow(var[norm_shift] + eps, -0.5) * (x[j] - mean[norm_shift]); | |||||
| dbeta += dy[j]; | dbeta += dy[j]; | ||||
| } | } | ||||
| dg[i] = dgamma; | dg[i] = dgamma; | ||||
| @@ -41,13 +42,14 @@ void LayerNormGrad(const float *x, const float *dy, const float *var, const floa | |||||
| int norm_shift = (int)(j / block_size); | int norm_shift = (int)(j / block_size); | ||||
| float dxm = x[j] - mean[norm_shift]; | float dxm = x[j] - mean[norm_shift]; | ||||
| float dyg = dy[j] * gamma[param_shift]; | float dyg = dy[j] * gamma[param_shift]; | ||||
| sum1 += -0.5f * dyg * dxm * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift]; | |||||
| sum1 += -0.5f * dyg * dxm * pow(var_sqrt_rev[norm_shift] + eps, -1.5); | |||||
| sum2 += dyg; | |||||
| sum3 += -2.0f * dxm; | sum3 += -2.0f * dxm; | ||||
| } | } | ||||
| for (size_t j = i * block_size; j < (i + 1) * block_size; ++j) { | for (size_t j = i * block_size; j < (i + 1) * block_size; ++j) { | ||||
| int param_shift = j % param_num; | int param_shift = j % param_num; | ||||
| int norm_shift = (int)(j / block_size); | int norm_shift = (int)(j / block_size); | ||||
| float var_sqrt = var_sqrt_rev[norm_shift]; | |||||
| float var_sqrt = pow(var_sqrt_rev[norm_shift] + eps, -0.5); | |||||
| float dx1 = dy[j] * gamma[param_shift] * var_sqrt; | float dx1 = dy[j] * gamma[param_shift] * var_sqrt; | ||||
| float dx2 = sum1 * 2.0f / block_size * (x[j] - mean[norm_shift]); | float dx2 = sum1 * 2.0f / block_size * (x[j] - mean[norm_shift]); | ||||
| float dx3 = (-1.0f * var_sqrt * sum2 + (1.0f / block_size) * sum1 * sum3) * (1.0f / block_size); | float dx3 = (-1.0f * var_sqrt * sum2 + (1.0f / block_size) * sum1 * sum3) * (1.0f / block_size); | ||||