
fix train grad bug

commit 76dc0e4ae8 (tags/v1.2.0-rc1)
zhengjun10 committed 5 years ago
3 changed files with 18 additions and 5 deletions:

  mindspore/lite/nnacl/fp32_grad/layernorm_grad.c         +6 -4
  mindspore/lite/nnacl/infer/strided_slice_grad_infer.c   +8 -1
  mindspore/lite/src/train/train_session.cc               +4 -0

mindspore/lite/nnacl/fp32_grad/layernorm_grad.c (+6 -4)

@@ -19,14 +19,15 @@

 void LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma,
                    int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db) {
-  // var is actually 1/sqrf(var)-> var^0.5
+  // var is actually layer_norm forward output var
+  float eps = 1e-12;
   const float *var_sqrt_rev = var;
   for (size_t i = 0; i < param_num; ++i) {
     float dgamma = 0.0f;
     float dbeta = 0.0f;
     for (size_t j = i; j < param_size * param_num; j += param_num) {
       int norm_shift = (int)(j / block_size);
-      dgamma += dy[j] * var_sqrt_rev[norm_shift] * (x[j] - mean[norm_shift]);
+      dgamma += dy[j] * pow(var[norm_shift] + eps, -0.5) * (x[j] - mean[norm_shift]);
       dbeta += dy[j];
     }
     dg[i] = dgamma;
@@ -41,13 +42,14 @@ void LayerNormGrad(const float *x, const float *dy, const float *var, const floa
       int norm_shift = (int)(j / block_size);
       float dxm = x[j] - mean[norm_shift];
       float dyg = dy[j] * gamma[param_shift];
-      sum1 += -0.5f * dyg * dxm * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift];
+      sum1 += -0.5f * dyg * dxm * pow(var_sqrt_rev[norm_shift] + eps, -1.5);
       sum2 += dyg;
       sum3 += -2.0f * dxm;
     }
     for (size_t j = i * block_size; j < (i + 1) * block_size; ++j) {
       int param_shift = j % param_num;
       int norm_shift = (int)(j / block_size);
-      float var_sqrt = var_sqrt_rev[norm_shift];
+      float var_sqrt = pow(var_sqrt_rev[norm_shift] + eps, -0.5);
       float dx1 = dy[j] * gamma[param_shift] * var_sqrt;
       float dx2 = sum1 * 2.0f / block_size * (x[j] - mean[norm_shift]);
       float dx3 = (-1.0f * var_sqrt * sum2 + (1.0f / block_size) * sum1 * sum3) * (1.0f / block_size);
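
Note on the change: the old code treated `var` as an already-inverted 1/sqrt(var), while the fix assumes `var` is the raw variance produced by the LayerNorm forward pass and rebuilds the inverse standard deviation as pow(var + eps, -0.5) at every use. A minimal standalone sketch of the dgamma/dbeta accumulation under that assumption (DGammaDBetaBlock and its signature are hypothetical, not part of the commit):

#include <math.h>

/* Hypothetical reference: accumulate dgamma/dbeta over one normalized block,
 * assuming `var` holds the raw forward variance, as the fixed kernel expects. */
static void DGammaDBetaBlock(const float *x, const float *dy, float mean,
                             float var, int block_size, float *dgamma,
                             float *dbeta) {
  const float eps = 1e-12f;
  const float inv_std = powf(var + eps, -0.5f); /* 1 / sqrt(var + eps) */
  for (int j = 0; j < block_size; ++j) {
    /* x_hat = (x - mean) * inv_std; dL/dgamma accumulates dy * x_hat */
    *dgamma += dy[j] * inv_std * (x[j] - mean);
    *dbeta += dy[j]; /* dL/dbeta accumulates dy */
  }
}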


mindspore/lite/nnacl/infer/strided_slice_grad_infer.c (+8 -1)

@@ -81,7 +81,14 @@ int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size,
     ellipsis_mask_[i] = (bool)(param->ellipsisMask_) & (1 << i);
     new_axis_mask_[i] = (bool)(param->newAxisMask_) & (1 << i);
   }
-
+  param->num_axes_ = in_shape_size;
+  param->in_shape_length_ = in_shape_size;
+  for (int i = 0; i < ndim_; ++i) {
+    param->begins_[i] = begins_[i];
+    param->ends_[i] = ends_[i];
+    param->strides_[i] = strides_[i];
+  }
+  ShapeSet(param->in_shape_, &in_shape_size, input->shape_, input->shape_size_);
   // ApplyNewAxisMask();
   for (size_t i = 0; i < ndim_; i++) {
     if (new_axis_mask_[i]) {
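
Note on the change: shape inference now copies the slice attributes (begins/ends/strides) and the input shape into the op parameter, so the grad kernel later reads values consistent with what infer computed instead of stale defaults. A rough standalone illustration of that copy pattern under an assumed fixed capacity (CopySliceAttrs and MAX_SLICE_DIMS are made-up names, not from the commit):

#define MAX_SLICE_DIMS 8 /* assumed capacity of the parameter's arrays */

/* Hypothetical helper: mirror infer-time slice attributes into a parameter
 * struct, clamping to the arrays' fixed capacity. */
static void CopySliceAttrs(int ndim, const int *begins, const int *ends,
                           const int *strides, int *out_begins, int *out_ends,
                           int *out_strides) {
  const int n = ndim < MAX_SLICE_DIMS ? ndim : MAX_SLICE_DIMS;
  for (int i = 0; i < n; ++i) {
    out_begins[i] = begins[i];
    out_ends[i] = ends[i];
    out_strides[i] = strides[i];
  }
}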


mindspore/lite/src/train/train_session.cc (+4 -0)

@@ -387,6 +387,10 @@ void TrainSession::CompileOptimizedKernels() {
 }
 
 int TrainSession::SetLearningRate(float learning_rate) {
+  if (learning_rate < 0.0f) {
+    MS_LOG(ERROR) << "learning rate should not be less than 0";
+    return RET_ERROR;
+  }
   for (auto kernel : this->train_kernels_) {
     if (IsOptimizer(kernel)) {
       auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);
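
Note on the change: the new guard rejects a negative learning rate up front and returns an error before any optimizer kernel is touched. A tiny standalone sketch of that fail-fast contract (SetLearningRateChecked and the local RET_* macros are illustrative stand-ins, not the session API):

#include <stdio.h>

#define RET_OK 0       /* stand-in for mindspore-lite's RET_OK */
#define RET_ERROR (-1) /* stand-in for RET_ERROR */

/* Hypothetical: validate the argument before updating any optimizer state. */
static int SetLearningRateChecked(float lr) {
  if (lr < 0.0f) {
    fprintf(stderr, "learning rate should not be less than 0\n");
    return RET_ERROR;
  }
  /* ... forward lr to each optimizer kernel here ... */
  return RET_OK;
}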

