From dfd2f8f2c1af1796ce3d0f9d3675bc92eeb51cc2 Mon Sep 17 00:00:00 2001 From: zhengjun10 Date: Fri, 19 Mar 2021 15:41:01 +0800 Subject: [PATCH] fix train issue --- mindspore/lite/nnacl/infer/strided_slice_grad_infer.c | 9 ++++++++- mindspore/lite/src/train/train_session.cc | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mindspore/lite/nnacl/infer/strided_slice_grad_infer.c b/mindspore/lite/nnacl/infer/strided_slice_grad_infer.c index cd3127d7ab..69fb638b26 100644 --- a/mindspore/lite/nnacl/infer/strided_slice_grad_infer.c +++ b/mindspore/lite/nnacl/infer/strided_slice_grad_infer.c @@ -81,7 +81,14 @@ int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, ellipsis_mask_[i] = (bool)(param->ellipsisMask_) & (1 << i); new_axis_mask_[i] = (bool)(param->newAxisMask_) & (1 << i); } - + param->num_axes_ = in_shape_size; + param->in_shape_length_ = in_shape_size; + for (int i = 0; i < ndim_; ++i) { + param->begins_[i] = begins_[i]; + param->ends_[i] = ends_[i]; + param->strides_[i] = strides_[i]; + } + ShapeSet(param->in_shape_, &in_shape_size, input->shape_, input->shape_size_); // ApplyNewAxisMask(); for (size_t i = 0; i < ndim_; i++) { if (new_axis_mask_[i]) { diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc index 80e599d73b..34d5ef2bc6 100644 --- a/mindspore/lite/src/train/train_session.cc +++ b/mindspore/lite/src/train/train_session.cc @@ -387,6 +387,10 @@ void TrainSession::CompileOptimizedKernels() { } int TrainSession::SetLearningRate(float learning_rate) { + if (learning_rate < 0.0f) { + MS_LOG(ERROR) << "learning rate should more than 0"; + return RET_ERROR; + } for (auto kernel : this->train_kernels_) { if (IsOptimizer(kernel)) { auto optimizer = reinterpret_cast(kernel);