diff --git a/mindspore/lite/nnacl/fp32_grad/batch_norm.c b/mindspore/lite/nnacl/fp32_grad/batch_norm.c index 69ff1b0323..ec381acc8d 100644 --- a/mindspore/lite/nnacl/fp32_grad/batch_norm.c +++ b/mindspore/lite/nnacl/fp32_grad/batch_norm.c @@ -71,7 +71,7 @@ void backwardP1(const float *restrict in, const float *restrict yt, const float void backwardP2(const float *restrict in, const float *restrict yt, const float *restrict mean, const float *restrict invar, const float *restrict scale, int size, int total_size, int ch, const float *dxhat_sum, const float *dxhathat_sum, float *restrict dx) { - float N = (float)total_size; + const float N = (float)total_size; for (int i = 0; i < size; i++) { for (int c = 0; c < ch; c++) { // dx_2 diff --git a/mindspore/lite/nnacl/fp32_grad/pooling_grad.c b/mindspore/lite/nnacl/fp32_grad/pooling_grad.c index 78ad6b5eda..d775b9b554 100644 --- a/mindspore/lite/nnacl/fp32_grad/pooling_grad.c +++ b/mindspore/lite/nnacl/fp32_grad/pooling_grad.c @@ -64,7 +64,7 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, Poolin #ifdef ENABLE_ARM float *out_vec = out + (xw + in_w * xh) * channel + ic; float32x4_t outr = vld1q_f32(out + (xw + in_w * xh) * channel + ic); - float32x4_t outs = vaddq_s32(outr, delta); + float32x4_t outs = vaddq_f32(outr, delta); vst1q_f32(out_vec, outs); #else @@ -94,7 +94,7 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, Poolin #ifdef ENABLE_ARM static int32x4_t MaxIndex(float32x4_t in, float32x4_t *max, int32x4_t index, int32x4_t prev_index) { uint32x4_t res = vcgtq_f32(in, *max); - uint32x4_t m_index = vbslq_f32(res, index, prev_index); + int32x4_t m_index = vbslq_s32(res, index, prev_index); *max = vbslq_f32(res, in, *max); return m_index; } @@ -127,7 +127,7 @@ void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_p int kw_s = MSMAX(0, over_w); int kw_e = MSMIN(win_w, in_w + over_w); int ic = 0; - for (; ic < channel - 4; ic += 4) { + for (; ic < (channel & ~3); ic += 4) { int idx = (yw + yh * output_w) * channel + ic; #ifdef ENABLE_ARM uint32x4_t max_idx = vdupq_n_u32(0); @@ -170,9 +170,8 @@ void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_p float delta = dyPtr[idx]; for (int kh = kh_s; kh < kh_e; kh++) { int xh = yh * stride_h + kh - pad_h; - int loop = kw_e - kw_s; - for (int kw = 0; kw < loop; kw++) { - int xw = yw * stride_w + kw + kw_s - pad_w; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; int val_idx = (xw + in_w * xh) * channel + ic; float val = inPtr[val_idx]; if (val > max_val) { diff --git a/mindspore/lite/src/ops/abs.cc b/mindspore/lite/src/ops/abs.cc index 97c53026bb..8b4ccdae2c 100644 --- a/mindspore/lite/src/ops/abs.cc +++ b/mindspore/lite/src/ops/abs.cc @@ -21,7 +21,31 @@ namespace mindspore { namespace lite { -#ifndef PRIMITIVE_WRITEABLE +#ifdef PRIMITIVE_WRITEABLE +int Abs::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Abs; + } + if (this->primitive_->value.type != schema::PrimitiveType_Abs) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + this->primitive_->value.value = new (std::nothrow) schema::AbsT(); + if 
(this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + } + return RET_OK; +} + +#else int Abs::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); MS_ASSERT(nullptr != fbb); diff --git a/mindspore/lite/src/ops/abs.h b/mindspore/lite/src/ops/abs.h index 1e8b50dcaa..f985351177 100644 --- a/mindspore/lite/src/ops/abs.h +++ b/mindspore/lite/src/ops/abs.h @@ -20,8 +20,8 @@ #include "src/ops/arithmetic_self.h" -#ifndef LITE_MINDSPORE_LITE_C_OPS_ABS_H_ -#define LITE_MINDSPORE_LITE_C_OPS_ABS_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_ABS_H_ +#define MINDSPORE_LITE_SRC_OPS_ABS_H_ namespace mindspore { namespace lite { @@ -32,10 +32,11 @@ class Abs : public ArithmeticSelf { #ifdef PRIMITIVE_WRITEABLE MS_DECLARE_PARENT(Abs, ArithmeticSelf); explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_ABS_H_ +#endif // MINDSPORE_LITE_SRC_OPS_ABS_H_ diff --git a/mindspore/lite/src/ops/cos.cc b/mindspore/lite/src/ops/cos.cc index 6a49363a09..1a02632937 100644 --- a/mindspore/lite/src/ops/cos.cc +++ b/mindspore/lite/src/ops/cos.cc @@ -22,7 +22,31 @@ namespace mindspore { namespace lite { -#ifndef PRIMITIVE_WRITEABLE +#ifdef PRIMITIVE_WRITEABLE +int Cos::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Cos; + } + if (this->primitive_->value.type != schema::PrimitiveType_Cos) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + this->primitive_->value.value = new (std::nothrow) schema::CosT(); + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + } + return RET_OK; +} + +#else int Cos::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); MS_ASSERT(nullptr != fbb); diff --git a/mindspore/lite/src/ops/cos.h b/mindspore/lite/src/ops/cos.h index aa570378bc..70269945d6 100644 --- a/mindspore/lite/src/ops/cos.h +++ b/mindspore/lite/src/ops/cos.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef LITE_MINDSPORE_LITE_C_OPS_COS_H_ -#define LITE_MINDSPORE_LITE_C_OPS_COS_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_COS_H_ +#define MINDSPORE_LITE_SRC_OPS_COS_H_ #include #include @@ -30,6 +30,7 @@ class Cos : public ArithmeticSelf { ~Cos() = default; #ifdef PRIMITIVE_WRITEABLE explicit Cos(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif @@ -37,4 +38,4 @@ class Cos : public ArithmeticSelf { } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_COS_H_ +#endif // MINDSPORE_LITE_SRC_OPS_COS_H_ diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index aaa8750640..6a61b10cd3 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -529,6 +529,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: const auto &op_type = prim.name(); if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid" || op_type == "HSwish" || op_type == "HSigmoid") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Abs") { + return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "AddN") { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "BatchNorm") { @@ -539,6 +541,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Conv2D") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Cos") { + return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "DepthwiseConv2dNative" || op_type == "DepthwiseConv2D") { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Dequant") { @@ -559,6 +563,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "RealDiv") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Reciprocal") { + return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "ReduceMax") { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "ReduceMean") { @@ -573,6 +579,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Reshape") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Sin") { + return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Slice") { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Squeeze") { diff --git a/mindspore/lite/src/ops/reciprocal.cc b/mindspore/lite/src/ops/reciprocal.cc index 9dba025aab..5944b1a2f1 100644 --- a/mindspore/lite/src/ops/reciprocal.cc +++ b/mindspore/lite/src/ops/reciprocal.cc @@ -22,7 +22,31 @@ namespace mindspore { namespace lite { -#ifndef PRIMITIVE_WRITEABLE +#ifdef PRIMITIVE_WRITEABLE +int Reciprocal::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Reciprocal; + } + if (this->primitive_->value.type != schema::PrimitiveType_Reciprocal) { + MS_LOG(ERROR) << "Primitive type is error :" << 
this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + this->primitive_->value.value = new (std::nothrow) schema::ReciprocalT(); + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + } + return RET_OK; +} + +#else PrimitiveC *ReciprocalCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } diff --git a/mindspore/lite/src/ops/reciprocal.h b/mindspore/lite/src/ops/reciprocal.h index 20677e0c33..838a8d4ccd 100644 --- a/mindspore/lite/src/ops/reciprocal.h +++ b/mindspore/lite/src/ops/reciprocal.h @@ -14,10 +14,13 @@ * limitations under the License. */ -#ifndef LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_ -#define LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_RECIPROCAL_H_ +#define MINDSPORE_LITE_SRC_OPS_RECIPROCAL_H_ #include "src/ops/arithmetic_self.h" +#ifdef PRIMITIVE_WRITEABLE +#include +#endif namespace mindspore { namespace lite { @@ -28,6 +31,7 @@ class Reciprocal : public ArithmeticSelf { #ifdef PRIMITIVE_WRITEABLE MS_DECLARE_PARENT(Reciprocal, ArithmeticSelf); explicit Reciprocal(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override { MS_ASSERT(nullptr != primitive); @@ -42,4 +46,4 @@ class Reciprocal : public ArithmeticSelf { } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_ +#endif // MINDSPORE_LITE_SRC_OPS_RECIPROCAL_H_ diff --git a/mindspore/lite/src/ops/sin.cc b/mindspore/lite/src/ops/sin.cc index b080f2f3da..4d39682bd2 100644 --- a/mindspore/lite/src/ops/sin.cc +++ b/mindspore/lite/src/ops/sin.cc @@ -23,6 +23,29 @@ namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE +int Sin::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Sin; + } + if (this->primitive_->value.type != schema::PrimitiveType_Sin) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + this->primitive_->value.value = new (std::nothrow) schema::SinT(); + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + } + return RET_OK; +} + #else int Sin::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); diff --git a/mindspore/lite/src/ops/sin.h b/mindspore/lite/src/ops/sin.h index ecae5ddccd..b8a00527ab 100644 --- a/mindspore/lite/src/ops/sin.h +++ b/mindspore/lite/src/ops/sin.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef LITE_MINDSPORE_LITE_C_OPS_SIN_H_ -#define LITE_MINDSPORE_LITE_C_OPS_SIN_H_ +#ifndef MINDSPORE_LITE_SRC_OPS_SIN_H_ +#define MINDSPORE_LITE_SRC_OPS_SIN_H_ #include #include @@ -32,6 +32,7 @@ class Sin : public ArithmeticSelf { #ifdef PRIMITIVE_WRITEABLE MS_DECLARE_PARENT(Sin, ArithmeticSelf); explicit Sin(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif @@ -39,4 +40,4 @@ class Sin : public ArithmeticSelf { } // namespace lite } // namespace mindspore -#endif // LITE_MINDSPORE_LITE_C_OPS_SIN_H_ +#endif // MINDSPORE_LITE_SRC_OPS_SIN_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc index 8e97db9e16..3b2f09683c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc @@ -24,6 +24,7 @@ using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_ExpandDims; using mindspore::schema::PrimitiveType_Flatten; +using mindspore::schema::PrimitiveType_FlattenGrad; using mindspore::schema::PrimitiveType_Reshape; using mindspore::schema::PrimitiveType_Squeeze; using mindspore::schema::PrimitiveType_Unsqueeze; @@ -77,6 +78,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reshape, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Flatten, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Flatten, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FlattenGrad, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ExpandDims, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ExpandDims, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ExpandDims, LiteKernelCreator) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.cc index 56eed20633..65d5e01bcd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.cc @@ -43,7 +43,7 @@ int ArithmeticSelfGradCPUKernel::Init() { self_grad_operation_ = ElementDiv; break; default: - MS_LOG(ERROR) << "Unsupport type: " << type; + MS_LOG(ERROR) << "Unsupported type: " << type; return RET_ERROR; } return RET_OK; diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc index 0f7bcc1b7d..d4b1904529 100644 --- a/mindspore/lite/src/train/train_session.cc +++ b/mindspore/lite/src/train/train_session.cc @@ -360,14 +360,14 @@ session::TrainSession *session::TrainSession::CreateSession(const char *model_bu } auto ret = session->Init(context); if (ret != mindspore::lite::RET_OK) { - MS_LOG(ERROR) << "init sesssion failed"; + MS_LOG(ERROR) << "init session failed"; delete session; return nullptr; } ret = session->CompileTrainGraph(model); if (ret != mindspore::lite::RET_OK) { - MS_LOG(ERROR) << "Compiling Train Graph sesssion failed"; + MS_LOG(ERROR) << "Compiling Train Graph session failed"; delete session; return nullptr; } diff --git a/mindspore/lite/test/models_ms_train.cfg b/mindspore/lite/test/models_ms_train.cfg index 329bc23d92..e63b85526f 100644 --- a/mindspore/lite/test/models_ms_train.cfg 
+++ b/mindspore/lite/test/models_ms_train.cfg @@ -1,7 +1,7 @@ mini_alexnet mobilenetv1 mobilenetv2 -mobilenetv3 +#mobilenetv3 lenet effnet effnet_tune diff --git a/mindspore/lite/test/run_net_train.sh b/mindspore/lite/test/run_net_train.sh index 5d443bf365..777e99a459 100755 --- a/mindspore/lite/test/run_net_train.sh +++ b/mindspore/lite/test/run_net_train.sh @@ -71,6 +71,7 @@ function Run_Converter() { # Run on x86 platform: function Run_x86() { # Run mindspore converted train models: + fail=0 while read line; do model_name=${line} if [[ $model_name == \#* ]]; then @@ -80,21 +81,23 @@ function Run_x86() { echo ${model_name}'_train' >> "${run_x86_log_file}" echo 'cd '${x86_path}'/mindspore-lite-'${version}'-train-linux-x64' >> "${run_x86_log_file}" cd ${x86_path}/mindspore-lite-${version}-train-linux-x64 || return 1 - echo 'LD_LIBRARY_PATH='${LD_LIBRARY_PATH}':./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark_train/benchmark_train --epochs='${epoch_num}' --modelFile='${ms_models_path}'/'${model_name}'_train.ms --inDataFile='${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin' --expectedDataFile='${train_io_path}'/'${model_name}'_outputs.bin --exportFile='${ms_models_path}'/'${model_name}'_train_exported.ms' >> "${run_x86_log_file}" + echo 'LD_LIBRARY_PATH='${LD_LIBRARY_PATH}':./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark_train/benchmark_train --epochs='${epoch_num}' --modelFile='${ms_models_path}'/'${model_name}'_train.ms --inDataFile='${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin' --expectedDataFile='${train_io_path}'/'${model_name}'_output --exportFile='${ms_models_path}'/'${model_name}'_train_exported.ms' >> "${run_x86_log_file}" echo '-------------------------------------------------------------------------------' >> "${run_x86_log_file}" LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib \ ${run_valgrind}./benchmark_train/benchmark_train \ --modelFile=${ms_models_path}/${model_name}_train.ms \ --inDataFile=${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin \ - --expectedDataFile=${train_io_path}/${model_name}_outputs.bin \ + --expectedDataFile=${train_io_path}/${model_name}_output \ --exportFile=${ms_models_path}/${model_name}_train_exported.ms >> "${run_x86_log_file}" \ --epochs=${epoch_num} --numThreads=${threads} if [ $? 
= 0 ]; then run_result='x86: '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file} else run_result='x86: '${model_name}'_train failed'; echo ${run_result} >> ${run_benchmark_train_result_file} + fail=1 fi done < ${models_mindspore_train_config} + return ${fail} } # Run on arm platform: @@ -157,7 +160,7 @@ function Run_arm() { echo 'chmod 777 benchmark_train' >> ${adb_cmd_file} adb -s ${device_id} shell < ${adb_cmd_file} - + fail=0 # Run mindir converted train models: while read line; do model_name=${line} @@ -167,7 +170,7 @@ function Run_arm() { # run benchmark_train test without clib data echo ${model_name}'_train' >> "${run_arm_log_file}" - adb -s ${device_id} push ${train_io_path}/${model_name}_input*.bin ${train_io_path}/${model_name}_outputs.bin /data/local/tmp/benchmark_train_test >> ${adb_push_log_file} + adb -s ${device_id} push ${train_io_path}/${model_name}_input*.bin ${train_io_path}/${model_name}_output*.bin /data/local/tmp/benchmark_train_test >> ${adb_push_log_file} echo 'cd /data/local/tmp/benchmark_train_test' > ${adb_cmd_run_file} echo 'chmod 777 benchmark_train' >> ${adb_cmd_run_file} if [ "$1" == arm64 ]; then @@ -182,7 +185,7 @@ function Run_arm() { --epochs=${epoch_num} \ --modelFile=${model_name}_train.ms \ --inDataFile=${tmp_dir}/${model_name}_input1.bin,${tmp_dir}/${model_name}_input2.bin \ - --expectedDataFile=${tmp_dir}/${model_name}_outputs.bin \ + --expectedDataFile=${tmp_dir}/${model_name}_output \ --exportFile=${tmp_dir}/${model_name}_train_exported.ms \ --numThreads=${threads} ENDM @@ -195,8 +198,11 @@ ENDM run_result=$1': '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file} else run_result=$1': '${model_name}'_train failed'; echo ${run_result} >> ${run_benchmark_train_result_file}; + fail=1 fi + done < ${models_mindspore_train_config} + return ${fail} } # Print start msg before run testcase diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 4654158871..d7079636ae 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -59,6 +59,41 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { } } +void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) { + bool hasDepend = false; + std::vector inputs; + inputs.clear(); + + inputs.emplace_back(cnode->input(0)); + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + AnfNodePtr inputNode = cnode->input(i); + if (!inputNode->isa()) { + inputs.emplace_back(cnode->input(i)); + continue; + } + auto dependNode = utils::cast(inputNode); + if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) || + IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) { + hasDepend = true; + bool maskOut = (dependNode->inputs().size() == 3); + for (size_t j = 1; j < dependNode->inputs().size(); ++j) { + AnfNodePtr dependInputNode = dependNode->input(j); + if (dependInputNode->isa()) { + inputs.emplace_back(dependInputNode); + if (maskOut) { + break; + } + } + } + } else { + inputs.emplace_back(cnode->input(i)); + } + } + if (hasDepend) { + cnode->set_inputs(inputs); + } +} + int AnfExporter::ConvertQuantParam(const std::unique_ptr &meta_graph, const std::shared_ptr &primitive, const std::unique_ptr &dst_node) { @@ -251,8 +286,17 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptrType() == schema::PrimitiveType_TupleGetItem) || +#ifdef SUPPORT_TRAIN + (primitive_c->Type() == 
schema::PrimitiveType_Depend) || + (primitive_c->Type() == schema::PrimitiveType_ControlDepend) || +#endif (primitive_c->Type() == schema::PrimitiveType_MakeTuple)) { continue; } diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.h b/mindspore/lite/tools/anf_exporter/anf_exporter.h index 2a0ea78ce5..4a7aaecb32 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.h +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_TOOLS_COMMON_ANF_EXPORTER_ANF_EXPORTER_H_ -#define MINDSPORE_LITE_TOOLS_COMMON_ANF_EXPORTER_ANF_EXPORTER_H_ +#ifndef MINDSPORE_LITE_TOOLS_ANF_EXPORTER_ANF_EXPORTER_H_ +#define MINDSPORE_LITE_TOOLS_ANF_EXPORTER_ANF_EXPORTER_H_ #include #include @@ -41,6 +41,7 @@ class AnfExporter { int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, schema::CNodeT *fb_node); static void RemoveIfMakeTuple(const CNodePtr &cnode); + static void RemoveIfDepend(const CNodePtr &cnode); protected: int ConvertInputCNode(const std::shared_ptr &input_anode, schema::CNodeT *output_cnode); @@ -97,4 +98,4 @@ class AnfExporter { // and clear. schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false); } // namespace mindspore::lite -#endif // MINDSPORE_LITE_TOOLS_COMMON_ANF_EXPORTER_ANF_EXPORTER_H_ +#endif // MINDSPORE_LITE_TOOLS_ANF_EXPORTER_ANF_EXPORTER_H_ diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc index b7e037e594..05ab806048 100644 --- a/mindspore/lite/tools/benchmark_train/net_train.cc +++ b/mindspore/lite/tools/benchmark_train/net_train.cc @@ -32,6 +32,42 @@ static const char *DELIM_COLON = ":"; static const char *DELIM_COMMA = ","; static const char *DELIM_SLASH = "/"; +namespace { +float *ReadFileBuf(const char *file, size_t *size) { + if (file == nullptr) { + MS_LOG(ERROR) << "file is nullptr"; + return nullptr; + } + MS_ASSERT(size != nullptr); + std::string real_path = RealPath(file); + std::ifstream ifs(real_path); + if (!ifs.good()) { + MS_LOG(ERROR) << "file: " << real_path << " is not exist"; + return nullptr; + } + + if (!ifs.is_open()) { + MS_LOG(ERROR) << "file: " << real_path << " open failed"; + return nullptr; + } + + ifs.seekg(0, std::ios::end); + *size = ifs.tellg(); + std::unique_ptr buf((new (std::nothrow) float[*size / sizeof(float) + 1])); + if (buf == nullptr) { + MS_LOG(ERROR) << "malloc buf failed, file: " << real_path; + ifs.close(); + return nullptr; + } + + ifs.seekg(0, std::ios::beg); + ifs.read(reinterpret_cast(buf.get()), *size); + ifs.close(); + + return buf.release(); +} +} // namespace + int NetTrain::GenerateRandomData(size_t size, void *data) { MS_ASSERT(data != nullptr); char *casted_data = static_cast(data); @@ -113,82 +149,34 @@ int NetTrain::ReadInputFile() { return RET_OK; } -// calibData is FP32 -int NetTrain::ReadCalibData() { - const char *calib_data_path = flags_->data_file_.c_str(); - // read calib data - std::ifstream in_file(calib_data_path); - if (!in_file.good()) { - std::cerr << "file: " << calib_data_path << " is not exist" << std::endl; - MS_LOG(ERROR) << "file: " << calib_data_path << " is not exist"; - return RET_ERROR; - } - - if (!in_file.is_open()) { - std::cerr << "file: " << calib_data_path << " open failed" << std::endl; - MS_LOG(ERROR) << "file: " << calib_data_path << " open failed"; - in_file.close(); - return RET_ERROR; - } - - std::string line; - - MS_LOG(INFO) << "Start reading calibData file"; 
- std::string tensor_name; - while (!in_file.eof()) { - getline(in_file, line); - std::stringstream string_line1(line); - size_t dim = 0; - string_line1 >> tensor_name >> dim; - std::vector dims; - size_t shape_size = 1; - for (size_t i = 0; i < dim; i++) { - size_t tmp_dim; - string_line1 >> tmp_dim; - dims.push_back(tmp_dim); - shape_size *= tmp_dim; - } - - getline(in_file, line); - std::stringstream string_line2(line); - std::vector tensor_data; - for (size_t i = 0; i < shape_size; i++) { - float tmp_data; - string_line2 >> tmp_data; - tensor_data.push_back(tmp_data); - } - auto *check_tensor = new CheckTensor(dims, tensor_data); - this->data_.insert(std::make_pair(tensor_name, check_tensor)); - } - in_file.close(); - MS_LOG(INFO) << "Finish reading calibData file"; - return RET_OK; -} - int NetTrain::CompareOutput() { std::cout << "================ Comparing Output data ================" << std::endl; float total_bias = 0; int total_size = 0; bool has_error = false; - - for (const auto &calib_tensor : data_) { - std::string node_or_tensor_name = calib_tensor.first; - auto tensors = session_->GetOutputsByNodeName(node_or_tensor_name); - mindspore::tensor::MSTensor *tensor = nullptr; - if (tensors.empty() || tensors.size() != 1) { - MS_LOG(INFO) << "Cannot find output node: " << node_or_tensor_name - << " or node has more than one output tensor, switch to GetOutputByTensorName"; - tensor = session_->GetOutputByTensorName(node_or_tensor_name); - if (tensor == nullptr) { - MS_LOG(ERROR) << "Cannot find output tensor " << node_or_tensor_name << ", get model output failed"; - return RET_ERROR; - } - } else { - tensor = tensors.front(); - } - MS_ASSERT(tensor->MutableData() != nullptr); + auto tensors_list = session_->GetOutputs(); + if (tensors_list.empty()) { + MS_LOG(ERROR) << "Cannot find output tensors, get model output failed"; + return RET_ERROR; + } + mindspore::tensor::MSTensor *tensor = nullptr; + int i = 1; + for (auto it = tensors_list.begin(); it != tensors_list.end(); ++it) { + tensor = session_->GetOutputByTensorName(it->first); auto outputs = tensor->MutableData(); - float bias = CompareData(node_or_tensor_name, tensor->shape(), reinterpret_cast(outputs)); + size_t size; + std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin"; + auto *bin_buf = ReadFileBuf(output_file.c_str(), &size); + if (bin_buf == nullptr) { + MS_LOG(ERROR) << "ReadFile return nullptr"; + return RET_ERROR; + } + if (size != tensor->Size()) { + MS_LOG(ERROR) << "Output buffer and output file differ by size. 
Tensor size: " << tensor->Size() + << ", read size: " << size; + return RET_ERROR; + } + float bias = CompareData(bin_buf, tensor->ElementsNum(), reinterpret_cast(outputs)); if (bias >= 0) { total_bias += bias; total_size++; @@ -196,6 +184,8 @@ int NetTrain::CompareOutput() { has_error = true; break; } + i++; + delete bin_buf; } if (!has_error) { @@ -206,7 +196,8 @@ int NetTrain::CompareOutput() { mean_bias = 0; } - std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%" << std::endl; + std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%" + << " threshold is:" << this->flags_->accuracy_threshold_ << std::endl; std::cout << "=======================================================" << std::endl << std::endl; if (mean_bias > this->flags_->accuracy_threshold_) { @@ -297,13 +288,6 @@ int NetTrain::MarkAccuracy() { return status; } - status = ReadCalibData(); - if (status != RET_OK) { - MS_LOG(ERROR) << "Read calib data error " << status; - std::cerr << "Read calib data error " << status << std::endl; - return status; - } - status = CompareOutput(); if (status != RET_OK) { MS_LOG(ERROR) << "Compare output error " << status; @@ -454,7 +438,7 @@ int NetTrain::RunNetTrain() { std::cout << "Run SaveToFile error"; return RET_ERROR; } - + // delete session_; status = RunExportedNet(); if (status != RET_OK) { MS_LOG(ERROR) << "Run Exported model error: " << status; diff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h index 49c0f4daf3..4c4990aa31 100644 --- a/mindspore/lite/tools/benchmark_train/net_train.h +++ b/mindspore/lite/tools/benchmark_train/net_train.h @@ -116,8 +116,6 @@ class MS_API NetTrain { int ReadInputFile(); - int ReadCalibData(); - int CompareOutput(); int InitCallbackParameter(); @@ -140,78 +138,49 @@ class MS_API NetTrain { // tensorData need to be converter first template - float CompareData(const std::string &nodeName, std::vector msShape, T *msTensorData) { - auto iter = this->data_.find(nodeName); - if (iter != this->data_.end()) { - std::vector castedMSShape; - size_t shapeSize = 1; - for (int64_t dim : msShape) { - castedMSShape.push_back(size_t(dim)); - shapeSize *= dim; + float CompareData(const float *refOutput, int size, T *msTensorData) { + size_t errorCount = 0; + float meanError = 0; + std::cout << "Data of model output: "; + for (int j = 0; j < size; j++) { + if (j < 50) { + std::cout << static_cast(msTensorData[j]) << " "; } - CheckTensor *calibTensor = iter->second; - if (calibTensor->shape != castedMSShape) { - std::ostringstream oss; - oss << "Shape of mslite output("; - for (auto dim : castedMSShape) { - oss << dim << ","; - } - oss << ") and shape source model output("; - for (auto dim : calibTensor->shape) { - oss << dim << ","; - } - oss << ") are different"; - std::cerr << oss.str() << std::endl; - MS_LOG(ERROR) << oss.str().c_str(); + if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { + std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl; + MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail"; return RET_ERROR; } - size_t errorCount = 0; - float meanError = 0; - std::cout << "Data of node " << nodeName << " : "; - for (size_t j = 0; j < shapeSize; j++) { - if (j < 50) { - std::cout << static_cast(msTensorData[j]) << " "; - } - if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { - std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl; - MS_LOG(ERROR) << "Output tensor has nan or inf data, 
compare fail"; - return RET_ERROR; - } - - auto tolerance = absoluteTolerance + relativeTolerance * fabs(calibTensor->data.at(j)); - auto absoluteError = std::fabs(msTensorData[j] - calibTensor->data.at(j)); - if (absoluteError > tolerance) { - if (fabs(calibTensor->data.at(j)) == 0) { - if (absoluteError > 1e-5) { - meanError += absoluteError; - errorCount++; - } else { - continue; - } - } else { - // just assume that atol = rtol - meanError += absoluteError / (fabs(calibTensor->data.at(j)) + FLT_MIN); + auto tolerance = absoluteTolerance + relativeTolerance * fabs(refOutput[j]); + auto absoluteError = std::fabs(msTensorData[j] - refOutput[j]); + if (absoluteError > tolerance) { + if (fabs(refOutput[j]) == 0) { + if (absoluteError > 1e-5) { + meanError += absoluteError; errorCount++; + } else { + continue; } + } else { + // just assume that atol = rtol + meanError += absoluteError / (fabs(refOutput[j]) + FLT_MIN); + errorCount++; } } - std::cout << std::endl; - if (meanError > 0.0f) { - meanError /= errorCount; - } + } + std::cout << std::endl; + if (meanError > 0.0f) { + meanError /= errorCount; + } - if (meanError <= 0.0000001) { - std::cout << "Mean bias of node/tensor " << nodeName << " : 0%" << std::endl; - } else { - std::cout << "Mean bias of node/tensor " << nodeName << " : " << meanError * 100 << "%" << std::endl; - } - return meanError; + if (meanError <= 0.0000001) { + std::cout << "Mean bias of tensor: 0%" << std::endl; } else { - MS_LOG(INFO) << "%s is not in Source Model output", nodeName.c_str(); - return RET_ERROR; + std::cout << "Mean bias of tensor: " << meanError * 100 << "%" << std::endl; } + return meanError; } int MarkPerformance(); diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 49e4156fd2..6128722572 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -144,8 +144,8 @@ int AnfTransform::AddConvertPass(const std::shared_ptr &opt int AnfTransform::AddConstFoldPass(const std::shared_ptr &optimizer, const converter::Flags *config) { auto const_fold_pm = std::make_shared("const fold fusion pass manager", false); - const_fold_pm->AddPass(std::make_shared()); if (!config->trainModel) { + const_fold_pm->AddPass(std::make_shared()); auto inne_context_ptr = std::make_shared(); inne_context_ptr->Init(); const_fold_pm->AddPass(std::make_shared(inne_context_ptr));