Browse Source

fix train issue

pull/13599/head
zhengjun10 4 years ago
parent
commit
dfd2f8f2c1
2 changed files with 12 additions and 1 deletions
  1. +8
    -1
      mindspore/lite/nnacl/infer/strided_slice_grad_infer.c
  2. +4
    -0
      mindspore/lite/src/train/train_session.cc

+ 8
- 1
mindspore/lite/nnacl/infer/strided_slice_grad_infer.c View File

@@ -81,7 +81,14 @@ int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size,
ellipsis_mask_[i] = (bool)(param->ellipsisMask_) & (1 << i);
new_axis_mask_[i] = (bool)(param->newAxisMask_) & (1 << i);
}

param->num_axes_ = in_shape_size;
param->in_shape_length_ = in_shape_size;
for (int i = 0; i < ndim_; ++i) {
param->begins_[i] = begins_[i];
param->ends_[i] = ends_[i];
param->strides_[i] = strides_[i];
}
ShapeSet(param->in_shape_, &in_shape_size, input->shape_, input->shape_size_);
// ApplyNewAxisMask();
for (size_t i = 0; i < ndim_; i++) {
if (new_axis_mask_[i]) {


+ 4
- 0
mindspore/lite/src/train/train_session.cc View File

@@ -387,6 +387,10 @@ void TrainSession::CompileOptimizedKernels() {
}

int TrainSession::SetLearningRate(float learning_rate) {
if (learning_rate < 0.0f) {
MS_LOG(ERROR) << "learning rate should more than 0";
return RET_ERROR;
}
for (auto kernel : this->train_kernels_) {
if (IsOptimizer(kernel)) {
auto optimizer = reinterpret_cast<kernel::OptimizerKernel *>(kernel);


Loading…
Cancel
Save