diff --git a/mindspore/lite/nnacl/fp32_grad/activation_grad.c b/mindspore/lite/nnacl/fp32_grad/activation_grad.c
index cb8c7fb16d..55b17f7c27 100644
--- a/mindspore/lite/nnacl/fp32_grad/activation_grad.c
+++ b/mindspore/lite/nnacl/fp32_grad/activation_grad.c
@@ -20,58 +20,60 @@
 #include "nnacl/fp32_grad/activation_grad.h"
 #include "nnacl/errorcode.h"

-inline int ReluGrad(float *src0, float *src1, int length, float *dst) {
-  for (int i = 0; i < length; ++i) {
-    dst[i] = src1[i] > 0 ? 1.0f : 0.0f;
+inline int ReluGrad(float *src0, float *src1, size_t length, float *dst) {
+  for (size_t i = 0; i < length; ++i) {
+    if (src1[i] > 0) {
+      dst[i] = src0[i];
+    } else {
+      dst[i] = 0;
+    }
   }
-  ElementMul(src0, dst, dst, length);
   return NNACL_OK;
 }

-int Relu6Grad(float *src0, float *src1, int length, float *dst) {
-  for (int i = 0; i < length; ++i) {
-    if (src1[i] < 0) {
-      dst[i] = 0;
+int Relu6Grad(float *src0, float *src1, size_t length, float *dst) {
+  for (size_t i = 0; i < length; ++i) {
+    if (src1[i] > 0.0f && src1[i] <= 6.0f) {
+      dst[i] = src0[i];
     } else {
-      dst[i] = src1[i] > 6.0f ? 0.0f : 1.0f;
+      dst[i] = 0.0f;
     }
   }
-  ElementMul(src0, dst, dst, length);
   return NNACL_OK;
 }

-int LReluGrad(float *src0, float *src1, int length, float *dst, float alpha) {
-  for (int i = 0; i < length; ++i) {
+int LReluGrad(float *src0, float *src1, size_t length, float *dst, float alpha) {
+  for (size_t i = 0; i < length; ++i) {
     dst[i] = src1[i] > 0.0f ? 1.0f : alpha;
   }
   ElementMul(src0, dst, dst, length);
   return NNACL_OK;
 }

-int SigmoidGrad(float *src0, float *src1, int length, float *dst) {
-  for (int i = 0; i < length; ++i) {
+int SigmoidGrad(float *src0, float *src1, size_t length, float *dst) {
+  for (size_t i = 0; i < length; ++i) {
     dst[i] = src0[i] * (src1[i] * (1.0f - src1[i]));
   }
   return NNACL_OK;
 }

-int TanhGrad(float *src0, float *src1, int length, float *dst) {
-  for (int i = 0; i < length; ++i) {
+int TanhGrad(float *src0, float *src1, size_t length, float *dst) {
+  for (size_t i = 0; i < length; ++i) {
     dst[i] = (1.0f - (src1[i] * src1[i])) * src0[i];
   }
   return NNACL_OK;
 }

-int HSwishGrad(float *src0, float *src1, int length, float *dst) {
-  for (int i = 0; i < length; ++i) {
+int HSwishGrad(float *src0, float *src1, size_t length, float *dst) {
+  for (size_t i = 0; i < length; ++i) {
     float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : (2.0f * src1[i] + 3.0f) / 6.0f));
     dst[i] = tmp * src0[i];
   }
   return NNACL_OK;
 }

-int HSigmoidGrad(float *src0, float *src1, int length, float *dst) {
-  for (int i = 0; i < length; ++i) {
+int HSigmoidGrad(float *src0, float *src1, size_t length, float *dst) {
+  for (size_t i = 0; i < length; ++i) {
     float tmp = (src1[i] > 3.0f ? 0.0f : (src1[i] < -3.0f ? 0.0f : 1.0f / 6.0f));
     dst[i] = tmp * src0[i];
   }
diff --git a/mindspore/lite/nnacl/fp32_grad/activation_grad.h b/mindspore/lite/nnacl/fp32_grad/activation_grad.h
index 5863f87284..7aa8f755c4 100644
--- a/mindspore/lite/nnacl/fp32_grad/activation_grad.h
+++ b/mindspore/lite/nnacl/fp32_grad/activation_grad.h
@@ -30,13 +30,13 @@ typedef struct ActivationGradParameter {
 extern "C" {
 #endif

-int ReluGrad(float *src0, float *src1, int length, float *dst);
-int Relu6Grad(float *src0, float *src1, int length, float *dst);
-int LReluGrad(float *src0, float *src1, int length, float *dst, float alpha);
-int SigmoidGrad(float *src0, float *src1, int length, float *dst);
-int TanhGrad(float *src0, float *src1, int length, float *dst);
-int HSwishGrad(float *src0, float *src1, int length, float *dst);
-int HSigmoidGrad(float *src0, float *src1, int length, float *dst);
+int ReluGrad(float *src0, float *src1, size_t length, float *dst);
+int Relu6Grad(float *src0, float *src1, size_t length, float *dst);
+int LReluGrad(float *src0, float *src1, size_t length, float *dst, float alpha);
+int SigmoidGrad(float *src0, float *src1, size_t length, float *dst);
+int TanhGrad(float *src0, float *src1, size_t length, float *dst);
+int HSwishGrad(float *src0, float *src1, size_t length, float *dst);
+int HSigmoidGrad(float *src0, float *src1, size_t length, float *dst);

 #ifdef __cplusplus
 }
diff --git a/mindspore/lite/nnacl/fp32_grad/pooling_grad.c b/mindspore/lite/nnacl/fp32_grad/pooling_grad.c
index a5d702cf38..9ecd19d7a9 100644
--- a/mindspore/lite/nnacl/fp32_grad/pooling_grad.c
+++ b/mindspore/lite/nnacl/fp32_grad/pooling_grad.c
@@ -34,21 +34,21 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
   memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float));

   float kk = (float)(win_h * win_w);
-  for (uint16_t ib = 0; ib < output_batch; ib++) {
+  for (int ib = 0; ib < output_batch; ib++) {
     float *out = &output_ptr[(ib * in_h * in_w * channel)];
     const float *inPtr = &input_ptr[(ib * output_h * output_w * channel)];
     // iterate over yt
-    for (uint16_t yh = 0; yh < output_h; yh++) {
-      for (uint16_t yw = 0; yw < output_w; yw++) {
-        for (uint16_t ic = 0; ic < channel; ic++) {
+    for (int yh = 0; yh < output_h; yh++) {
+      for (int yw = 0; yw < output_w; yw++) {
+        for (int ic = 0; ic < channel; ic++) {
           int idx = (yw + yh * output_w) * channel + ic;
           float delta = inPtr[idx] / kk;
-          for (int32_t kh = 0; kh < win_h; kh++) {
+          for (int kh = 0; kh < win_h; kh++) {
             int xh = yh * stride_h + kh - pad_h;
             if ((xh < 0) || (xh >= in_h)) {
               continue;
             }
-            for (int32_t kw = 0; kw < win_w; kw++) {
+            for (int kw = 0; kw < win_w; kw++) {
               int xw = yw * stride_w + kw - pad_w;
               if ((xw < 0) || (xw >= in_w)) {
                 continue;
@@ -78,25 +78,25 @@ void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy
   int output_batch = pooling_param->output_batch_;

   memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float));
-  for (uint16_t ib = 0; ib < output_batch; ib++) {
+  for (int ib = 0; ib < output_batch; ib++) {
     float *out = &output_ptr[(ib * in_h * in_w * channel)];
     const float *inPtr = (const float *)(&input_ptr[(ib * in_h * in_w * channel)]);
     const float *dyPtr = (const float *)(&dy_ptr[(ib * output_h * output_w * channel)]);

-    for (uint16_t yh = 0; yh < output_h; yh++) {
-      for (uint16_t yw = 0; yw < output_w; yw++) {
-        for (uint16_t ic = 0; ic < channel; ic++) {
+    for (int yh = 0; yh < output_h; yh++) {
+      for (int yw = 0; yw < output_w; yw++) {
+        for (int ic = 0; ic < channel; ic++) {
           int idx = (yw + yh * output_w) * channel + ic;
           float delta = dyPtr[idx];
           float max_val = -FLT_MAX;
           int max_idx = 0;

-          for (int32_t kh = 0; kh < win_h; kh++) {
+          for (int kh = 0; kh < win_h; kh++) {
             int xh = yh * stride_h + kh - pad_h;
             if ((xh < 0) || (xh >= in_h)) {
               continue;
             }
-            for (int32_t kw = 0; kw < win_w; kw++) {
+            for (int kw = 0; kw < win_w; kw++) {
               int xw = yw * stride_w + kw - pad_w;
               if ((xw < 0) || (xw >= in_w)) {
                 continue;
diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc
index c24ea68e5a..dd37d496af 100644
--- a/mindspore/lite/src/ops/conv2d.cc
+++ b/mindspore/lite/src/ops/conv2d.cc
@@ -160,7 +160,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
 #endif
   auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
   attr->kernelH = kernel_size.at(0);
-  attr->kernelW = kernel_size.at(1);
+  attr->kernelW = (kernel_size.size() > 1) ? kernel_size.at(1) : kernel_size.at(0);

   auto stride = CastToInt(prim.GetAttr("stride"));
   attr->strideH = stride.at(2);
@@ -240,7 +240,7 @@ void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::Primitive

   auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
   attr->kernelH = kernel_size.at(0);
-  attr->kernelW = kernel_size.at(1);
+  attr->kernelW = (kernel_size.size() > 1) ? kernel_size.at(1) : kernel_size.at(0);

   auto stride = CastToInt(prim.GetAttr("stride"));
   attr->strideH = stride.at(2);
diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.cc b/mindspore/lite/src/ops/conv2d_grad_filter.cc
index 61574dc337..db51673eab 100644
--- a/mindspore/lite/src/ops/conv2d_grad_filter.cc
+++ b/mindspore/lite/src/ops/conv2d_grad_filter.cc
@@ -104,22 +104,22 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
       attr->format = schema::Format_NUM_OF_FORMAT;
     }
     auto pad_list = CastToInt(prim.GetAttr("pad_list"));
-    attr->padUp = pad_list[0];
-    attr->padDown = pad_list[1];
-    attr->padLeft = pad_list[2];
-    attr->padRight = pad_list[3];
+    attr->padUp = pad_list.at(0);
+    attr->padDown = pad_list.at(1);
+    attr->padLeft = pad_list.at(2);
+    attr->padRight = pad_list.at(3);

     auto dilation = CastToInt(prim.GetAttr("dilation"));
-    attr->dilateH = dilation[2];
-    attr->dilateW = dilation[3];
+    attr->dilateH = dilation.at(2);
+    attr->dilateW = dilation.at(3);

     auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
-    attr->kernelH = kernel_size[0];
-    attr->kernelW = kernel_size[1];
+    attr->kernelH = kernel_size.at(0);
+    attr->kernelW = (kernel_size.size() > 1) ? kernel_size.at(1) : kernel_size.at(0);

     auto stride = CastToInt(prim.GetAttr("stride"));
-    attr->strideH = stride[0];
-    attr->strideW = stride[1];
+    attr->strideH = stride.at(0);
+    attr->strideW = stride.at(1);

     attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();
     auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
diff --git a/mindspore/lite/src/ops/conv2d_grad_input.cc b/mindspore/lite/src/ops/conv2d_grad_input.cc
index 6872c5090e..f323c8ec35 100644
--- a/mindspore/lite/src/ops/conv2d_grad_input.cc
+++ b/mindspore/lite/src/ops/conv2d_grad_input.cc
@@ -105,22 +105,22 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
       attr->format = schema::Format_NUM_OF_FORMAT;
     }
     auto pad_list = CastToInt(prim.GetAttr("pad_list"));
-    attr->padUp = pad_list[0];
-    attr->padDown = pad_list[1];
-    attr->padLeft = pad_list[2];
-    attr->padRight = pad_list[3];
+    attr->padUp = pad_list.at(0);
+    attr->padDown = pad_list.at(1);
+    attr->padLeft = pad_list.at(2);
+    attr->padRight = pad_list.at(3);

     auto dilation = CastToInt(prim.GetAttr("dilation"));
-    attr->dilateH = dilation[2];
-    attr->dilateW = dilation[3];
+    attr->dilateH = dilation.at(2);
+    attr->dilateW = dilation.at(3);

     auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
-    attr->kernelH = kernel_size[0];
-    attr->kernelW = kernel_size[1];
+    attr->kernelH = kernel_size.at(0);
+    attr->kernelW = (kernel_size.size() > 1) ? kernel_size.at(1) : kernel_size.at(0);

     auto stride = CastToInt(prim.GetAttr("stride"));
-    attr->strideH = stride[0];
-    attr->strideW = stride[1];
+    attr->strideH = stride.at(0);
+    attr->strideW = stride.at(1);

     attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front();

diff --git a/mindspore/lite/src/ops/pooling_grad.cc b/mindspore/lite/src/ops/pooling_grad.cc
index ffe8d8dcec..da24f23cfc 100644
--- a/mindspore/lite/src/ops/pooling_grad.cc
+++ b/mindspore/lite/src/ops/pooling_grad.cc
@@ -84,10 +84,15 @@ int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
     } else {
       attr->format = schema::Format_NUM_OF_FORMAT;
     }
+
     if (prim.instance_name() == "MaxPoolGrad") {
       attr->poolingMode = schema::PoolMode_MAX_POOLING;
-    } else if (prim.instance_name() == "MeanPoolGrad") {
+    } else if (prim.instance_name() == "AvgPoolGrad") {
+      attr->poolingMode = schema::PoolMode_MEAN_POOLING;
+    } else if (prim.instance_name() == "AvgPoolGradGpu") {
       attr->poolingMode = schema::PoolMode_MEAN_POOLING;
+    } else {
+      attr->poolingMode = schema::PoolMode_MAX_POOLING;
     }

     auto pad_mode = GetValue<std::string>(prim.GetAttr("padding"));
diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc
index c838262204..870ef76b78 100644
--- a/mindspore/lite/src/ops/primitive_c.cc
+++ b/mindspore/lite/src/ops/primitive_c.cc
@@ -609,7 +609,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
   } else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" ||
               op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) {
     return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType);
-  } else if ((op_type == "MaxPoolGrad") || (op_type == "MeanPoolGrad") || (op_type == "AvgPoolGradGpu")) {
+  } else if ((op_type == "MaxPoolGrad") || (op_type == "AvgPoolGrad") || (op_type == "AvgPoolGradGpu")) {
     return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType);
   } else if (op_type == "Conv2DBackpropFilter") {
     return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc
index 8e62a651e2..6b582ba935 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc
@@ -35,10 +35,6 @@ int PoolingGradCPUKernel::Init() {
   auto in_shape = in_tensors_.at(0)->shape();
   auto out_shape = in_tensors_.at(1)->shape();

-  if (pool_param->pool_mode_ == PoolMode_AvgPool) {
-    in_shape = in_tensors_.at(1)->shape();
-    out_shape = in_tensors_.at(0)->shape();
-  }

   int input_h = in_shape.at(1);
   int input_w = in_shape.at(2);
@@ -71,6 +67,7 @@ int PoolingGradCPUKernel::Execute(int task_id) {
     auto dy_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
     MaxPoolingGrad(input_ptr, dx_ptr, dy_ptr, output_ptr, pool_param, task_id);
   } else {
+    input_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
     AvgPoolingGrad(input_ptr, output_ptr, pool_param, task_id);
   }
   return RET_OK;
@@ -88,15 +85,6 @@ int PoolingGradImpl(void *cdata, int task_id) {
 }

 int PoolingGradCPUKernel::Run() {
-  // clear output buffer before parallel run
-  PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(op_parameter_);
-  auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
-  int size =
-    pooling_param->input_w_ * pooling_param->input_h_ * pooling_param->input_channel_ * pooling_param->output_batch_;
-  for (int i = 0; i < size; i++) {
-    output_ptr[i] = 0.0;
-  }
-
   int error_code = ParallelLaunch(this->context_->thread_pool_, PoolingGradImpl, this, 1);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]";
diff --git a/mindspore/lite/src/train/train_model.cc b/mindspore/lite/src/train/train_model.cc
index 3594ad3dbe..64f1739ffb 100644
--- a/mindspore/lite/src/train/train_model.cc
+++ b/mindspore/lite/src/train/train_model.cc
@@ -30,7 +30,7 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
   flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
   int schema_version = VersionVerify(&verify);
   if (schema_version == -1) {
-    MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
+    MS_LOG(ERROR) << "The model buffer is invalid, cannot get schema version";
     return nullptr;
   }
   TrainModel *model = new (std::nothrow) TrainModel();
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
index bdb1c2f3cb..4e2941e9c6 100644
--- a/mindspore/lite/src/train/train_session.cc
+++ b/mindspore/lite/src/train/train_session.cc
@@ -378,7 +378,7 @@ session::TrainSession *session::TrainSession::CreateSession(const std::string &f

   ifs.seekg(0, std::ios::end);
   auto size = ifs.tellg();
-  if (size == 0) {
+  if (size <= 0) {
     MS_LOG(ERROR) << "Could not read file " << filename;
     return nullptr;
   }
@@ -391,8 +391,12 @@ session::TrainSession *session::TrainSession::CreateSession(const std::string &f

   ifs.seekg(0, std::ios::beg);
   ifs.read(buf.get(), size);
+  if (!ifs) {
+    MS_LOG(ERROR) << "only read " << ifs.gcount() << " bytes in " << filename;
+    ifs.close();
+    return nullptr;
+  }
   ifs.close();
-
   return session::TrainSession::CreateSession(buf.get(), size, context, train_mode);
 }

diff --git a/mindspore/lite/test/models_ms_train.cfg b/mindspore/lite/test/models_ms_train.cfg
index 3f6e70526d..91ef4d47fd 100644
--- a/mindspore/lite/test/models_ms_train.cfg
+++ b/mindspore/lite/test/models_ms_train.cfg
@@ -1,5 +1,5 @@
 mini_alexnet
-# mobilenetv1
+#mobilenetv1
 mobilenetv2
 mobilenetv3
 lenet