@@ -71,7 +71,7 @@ void backwardP1(const float *restrict in, const float *restrict yt, const float
 void backwardP2(const float *restrict in, const float *restrict yt, const float *restrict mean,
                 const float *restrict invar, const float *restrict scale, int size, int total_size, int ch,
                 const float *dxhat_sum, const float *dxhathat_sum, float *restrict dx) {
-  float N = (float)total_size;
+  const float N = (float)total_size;
   for (int i = 0; i < size; i++) {
     for (int c = 0; c < ch; c++) {
       // dx_2
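
For orientation, this kernel accumulates the standard batchnorm input gradient. A scalar sketch of the per-element term, assuming the usual decomposition (roles inferred from the signature above; dxhat taken as dy * scale):

    // Hypothetical scalar reference for one element in channel c (assumed semantics).
    float DxElement(float x, float dy, float mean, float invar, float scale,
                    float dxhat_sum, float dxhathat_sum, float N) {
      float xhat = (x - mean) * invar;  // normalized input
      float dxhat = dy * scale;         // gradient w.r.t. xhat
      return invar * (dxhat - dxhat_sum / N - xhat * dxhathat_sum / N);
    }
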
@@ -64,7 +64,7 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, Poolin
 #ifdef ENABLE_ARM
           float *out_vec = out + (xw + in_w * xh) * channel + ic;
           float32x4_t outr = vld1q_f32(out + (xw + in_w * xh) * channel + ic);
-          float32x4_t outs = vaddq_s32(outr, delta);
+          float32x4_t outs = vaddq_f32(outr, delta);
           vst1q_f32(out_vec, outs);
 #else
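
The fix matters because the two intrinsics disagree on lane type; the standard NEON signatures are:

    int32x4_t   vaddq_s32(int32x4_t a, int32x4_t b);     // integer lanes
    float32x4_t vaddq_f32(float32x4_t a, float32x4_t b);  // float lanes, as needed here
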
@@ -94,7 +94,7 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, Poolin
 #ifdef ENABLE_ARM
 static int32x4_t MaxIndex(float32x4_t in, float32x4_t *max, int32x4_t index, int32x4_t prev_index) {
   uint32x4_t res = vcgtq_f32(in, *max);
-  uint32x4_t m_index = vbslq_f32(res, index, prev_index);
+  int32x4_t m_index = vbslq_s32(res, index, prev_index);
   *max = vbslq_f32(res, in, *max);
   return m_index;
 }
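
Per lane, the two selects implement an argmax update; a scalar sketch of the same logic:

    static int MaxIndexScalar(float in, float *max, int index, int prev_index) {
      int m_index = (in > *max) ? index : prev_index;  // vbslq_s32 on the comparison mask
      if (in > *max) {
        *max = in;                                     // vbslq_f32 keeps the larger value
      }
      return m_index;
    }
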
@@ -127,7 +127,7 @@ void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_p
       int kw_s = MSMAX(0, over_w);
       int kw_e = MSMIN(win_w, in_w + over_w);
       int ic = 0;
-      for (; ic < channel - 4; ic += 4) {
+      for (; ic < (channel & ~3); ic += 4) {
         int idx = (yw + yh * output_w) * channel + ic;
 #ifdef ENABLE_ARM
         uint32x4_t max_idx = vdupq_n_u32(0);
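
channel & ~3 rounds down to the nearest multiple of four, so every full group of four channels takes the vector path and only channel % 4 trailing channels fall to the scalar tail; the old bound channel - 4 stopped one full group early. A quick check:

    _Static_assert((8 & ~3) == 8 && (7 & ~3) == 4 && (4 & ~3) == 4,
                   "round down to a multiple of 4");
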
@@ -170,9 +170,8 @@ void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_p
       float delta = dyPtr[idx];
       for (int kh = kh_s; kh < kh_e; kh++) {
         int xh = yh * stride_h + kh - pad_h;
-        int loop = kw_e - kw_s;
-        for (int kw = 0; kw < loop; kw++) {
-          int xw = yw * stride_w + kw + kw_s - pad_w;
+        for (int kw = kw_s; kw < kw_e; kw++) {
+          int xw = yw * stride_w + kw - pad_w;
           int val_idx = (xw + in_w * xh) * channel + ic;
           float val = inPtr[val_idx];
           if (val > max_val) {
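
The rewrite drops the loop temporary and walks the valid window directly; both forms visit the same input columns:

    // old: kw' = 0 .. (kw_e - kw_s - 1), xw = yw * stride_w + (kw' + kw_s) - pad_w
    // new: kw  = kw_s .. (kw_e - 1),     xw = yw * stride_w + kw - pad_w
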
@@ -21,7 +21,31 @@
 namespace mindspore {
 namespace lite {
-#ifndef PRIMITIVE_WRITEABLE
+#ifdef PRIMITIVE_WRITEABLE
+int Abs::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Abs;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Abs) {
+    MS_LOG(ERROR) << "Primitive type is wrong: " << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    this->primitive_->value.value = new (std::nothrow) schema::AbsT();
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+#else
 int Abs::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
   MS_ASSERT(nullptr != fbb);
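
The same lazy-allocate, type-check, value-allocate boilerplate recurs below for Cos, Sin, and Reciprocal. A hypothetical helper capturing the pattern (illustration only, not part of the patch):

    template <typename AttrT, schema::PrimitiveType kType>
    int UnPackSimpleAttr(schema::PrimitiveT **primitive) {
      if (*primitive == nullptr) {
        *primitive = new (std::nothrow) schema::PrimitiveT;
        if (*primitive == nullptr) {
          return RET_ERROR;  // PrimitiveT allocation failed
        }
        (*primitive)->value.type = kType;
      }
      if ((*primitive)->value.type != kType) {
        return RET_ERROR;  // unpacking into the wrong op wrapper
      }
      if ((*primitive)->value.value == nullptr) {
        (*primitive)->value.value = new (std::nothrow) AttrT();
        if ((*primitive)->value.value == nullptr) {
          return RET_ERROR;  // attribute table allocation failed
        }
      }
      return RET_OK;
    }
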
@@ -20,8 +20,8 @@
 #include "src/ops/arithmetic_self.h"
-#ifndef LITE_MINDSPORE_LITE_C_OPS_ABS_H_
-#define LITE_MINDSPORE_LITE_C_OPS_ABS_H_
+#ifndef MINDSPORE_LITE_SRC_OPS_ABS_H_
+#define MINDSPORE_LITE_SRC_OPS_ABS_H_
 namespace mindspore {
 namespace lite {
@@ -32,10 +32,11 @@ class Abs : public ArithmeticSelf {
 #ifdef PRIMITIVE_WRITEABLE
   MS_DECLARE_PARENT(Abs, ArithmeticSelf);
   explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
 };
 }  // namespace lite
 }  // namespace mindspore
-#endif  // LITE_MINDSPORE_LITE_C_OPS_ABS_H_
+#endif  // MINDSPORE_LITE_SRC_OPS_ABS_H_
@@ -22,7 +22,31 @@
 namespace mindspore {
 namespace lite {
-#ifndef PRIMITIVE_WRITEABLE
+#ifdef PRIMITIVE_WRITEABLE
+int Cos::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Cos;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Cos) {
+    MS_LOG(ERROR) << "Primitive type is wrong: " << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    this->primitive_->value.value = new (std::nothrow) schema::CosT();
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+#else
 int Cos::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
   MS_ASSERT(nullptr != fbb);
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef LITE_MINDSPORE_LITE_C_OPS_COS_H_
-#define LITE_MINDSPORE_LITE_C_OPS_COS_H_
+#ifndef MINDSPORE_LITE_SRC_OPS_COS_H_
+#define MINDSPORE_LITE_SRC_OPS_COS_H_
 #include <vector>
 #include <set>
@@ -30,6 +30,7 @@ class Cos : public ArithmeticSelf {
   ~Cos() = default;
 #ifdef PRIMITIVE_WRITEABLE
   explicit Cos(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
@@ -37,4 +38,4 @@ class Cos : public ArithmeticSelf {
 }  // namespace lite
 }  // namespace mindspore
-#endif  // LITE_MINDSPORE_LITE_C_OPS_COS_H_
+#endif  // MINDSPORE_LITE_SRC_OPS_COS_H_
@@ -529,6 +529,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
   const auto &op_type = prim.name();
   if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid" || op_type == "HSwish" || op_type == "HSigmoid") {
     return NewPrimitiveC<Activation>(prim, inputs, quantType);
+  } else if (op_type == "Abs") {
+    return NewPrimitiveC<Abs>(prim, inputs, quantType);
   } else if (op_type == "AddN") {
     return NewPrimitiveC<AddN>(prim, inputs, quantType);
   } else if (op_type == "BatchNorm") {
@@ -539,6 +541,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC<Concat>(prim, inputs, quantType);
   } else if (op_type == "Conv2D") {
     return NewPrimitiveC<Conv2D>(prim, inputs, quantType);
+  } else if (op_type == "Cos") {
+    return NewPrimitiveC<Cos>(prim, inputs, quantType);
   } else if (op_type == "DepthwiseConv2dNative" || op_type == "DepthwiseConv2D") {
     return NewPrimitiveC<DepthwiseConv2D>(prim, inputs, quantType);
   } else if (op_type == "Dequant") {
@@ -559,6 +563,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC<Quant>(prim, inputs, quantType);
   } else if (op_type == "RealDiv") {
     return NewPrimitiveC<RealDiv>(prim, inputs, quantType);
+  } else if (op_type == "Reciprocal") {
+    return NewPrimitiveC<Reciprocal>(prim, inputs, quantType);
   } else if (op_type == "ReduceMax") {
     return NewPrimitiveC<Reduce>(prim, inputs, quantType);
   } else if (op_type == "ReduceMean") {
@@ -573,6 +579,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC<Reduce>(prim, inputs, quantType);
   } else if (op_type == "Reshape") {
     return NewPrimitiveC<Reshape>(prim, inputs, quantType);
+  } else if (op_type == "Sin") {
+    return NewPrimitiveC<Sin>(prim, inputs, quantType);
   } else if (op_type == "Slice") {
     return NewPrimitiveC<Slice>(prim, inputs, quantType);
   } else if (op_type == "Squeeze") {
@@ -22,7 +22,31 @@
 namespace mindspore {
 namespace lite {
-#ifndef PRIMITIVE_WRITEABLE
+#ifdef PRIMITIVE_WRITEABLE
+int Reciprocal::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Reciprocal;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Reciprocal) {
+    MS_LOG(ERROR) << "Primitive type is wrong: " << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    this->primitive_->value.value = new (std::nothrow) schema::ReciprocalT();
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+#else
 PrimitiveC *ReciprocalCreator(const schema::Primitive *primitive) {
   return PrimitiveC::NewPrimitiveC<Reciprocal>(primitive);
 }
@@ -14,10 +14,13 @@
  * limitations under the License.
  */
-#ifndef LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_
-#define LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_
+#ifndef MINDSPORE_LITE_SRC_OPS_RECIPROCAL_H_
+#define MINDSPORE_LITE_SRC_OPS_RECIPROCAL_H_
 #include "src/ops/arithmetic_self.h"
+#ifdef PRIMITIVE_WRITEABLE
+#include <vector>
+#endif
 namespace mindspore {
 namespace lite {
@@ -28,6 +31,7 @@ class Reciprocal : public ArithmeticSelf {
 #ifdef PRIMITIVE_WRITEABLE
   MS_DECLARE_PARENT(Reciprocal, ArithmeticSelf);
   explicit Reciprocal(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
     MS_ASSERT(nullptr != primitive);
@@ -42,4 +46,4 @@ class Reciprocal : public ArithmeticSelf {
 }  // namespace lite
 }  // namespace mindspore
-#endif  // LITE_MINDSPORE_LITE_C_OPS_RECIPROCAL_H_
+#endif  // MINDSPORE_LITE_SRC_OPS_RECIPROCAL_H_
@@ -23,6 +23,29 @@
 namespace mindspore {
 namespace lite {
 #ifdef PRIMITIVE_WRITEABLE
+int Sin::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT failed";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_Sin;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_Sin) {
+    MS_LOG(ERROR) << "Primitive type is wrong: " << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    this->primitive_->value.value = new (std::nothrow) schema::SinT();
+    if (this->primitive_->value.value == nullptr) {
+      MS_LOG(ERROR) << "new primitiveT value failed";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
+#else
 int Sin::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef LITE_MINDSPORE_LITE_C_OPS_SIN_H_
-#define LITE_MINDSPORE_LITE_C_OPS_SIN_H_
+#ifndef MINDSPORE_LITE_SRC_OPS_SIN_H_
+#define MINDSPORE_LITE_SRC_OPS_SIN_H_
 #include <vector>
 #include <set>
@@ -32,6 +32,7 @@ class Sin : public ArithmeticSelf {
 #ifdef PRIMITIVE_WRITEABLE
   MS_DECLARE_PARENT(Sin, ArithmeticSelf);
   explicit Sin(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
@@ -39,4 +40,4 @@ class Sin : public ArithmeticSelf {
 }  // namespace lite
 }  // namespace mindspore
-#endif  // LITE_MINDSPORE_LITE_C_OPS_SIN_H_
+#endif  // MINDSPORE_LITE_SRC_OPS_SIN_H_
@@ -24,6 +24,7 @@ using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_ExpandDims;
 using mindspore::schema::PrimitiveType_Flatten;
+using mindspore::schema::PrimitiveType_FlattenGrad;
 using mindspore::schema::PrimitiveType_Reshape;
 using mindspore::schema::PrimitiveType_Squeeze;
 using mindspore::schema::PrimitiveType_Unsqueeze;
@@ -77,6 +78,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reshape, LiteKernelCreator<Re
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Reshape, LiteKernelCreator<ReshapeBaseCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Flatten, LiteKernelCreator<ReshapeBaseCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Flatten, LiteKernelCreator<ReshapeBaseCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FlattenGrad, LiteKernelCreator<ReshapeBaseCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ExpandDims, LiteKernelCreator<ReshapeBaseCPUKernel>)
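
Reusing ReshapeBaseCPUKernel for FlattenGrad is sound because the gradient of a flatten is just the incoming gradient reinterpreted in the original input shape; no arithmetic is performed:

    // forward:  y  = reshape(x,  {N, C * H * W})
    // backward: dx = reshape(dy, shape_of(x))  // pure layout change, data is copied through
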
@@ -43,7 +43,7 @@ int ArithmeticSelfGradCPUKernel::Init() {
       self_grad_operation_ = ElementDiv;
       break;
     default:
-      MS_LOG(ERROR) << "Unsupport type: " << type;
+      MS_LOG(ERROR) << "Unsupported type: " << type;
       return RET_ERROR;
   }
   return RET_OK;
@@ -360,14 +360,14 @@ session::TrainSession *session::TrainSession::CreateSession(const char *model_bu
   }
   auto ret = session->Init(context);
   if (ret != mindspore::lite::RET_OK) {
-    MS_LOG(ERROR) << "init sesssion failed";
+    MS_LOG(ERROR) << "init session failed";
     delete session;
     return nullptr;
   }
   ret = session->CompileTrainGraph(model);
   if (ret != mindspore::lite::RET_OK) {
-    MS_LOG(ERROR) << "Compiling Train Graph sesssion failed";
+    MS_LOG(ERROR) << "Compiling Train Graph session failed";
     delete session;
     return nullptr;
   }
@@ -1,7 +1,7 @@
 mini_alexnet
 mobilenetv1
 mobilenetv2
-mobilenetv3
+#mobilenetv3
 lenet
 effnet
 effnet_tune
@@ -71,6 +71,7 @@ function Run_Converter() {
 # Run on x86 platform:
 function Run_x86() {
     # Run mindspore converted train models:
+    fail=0
     while read line; do
         model_name=${line}
         if [[ $model_name == \#* ]]; then
@@ -80,21 +81,23 @@ function Run_x86() {
         echo ${model_name}'_train' >> "${run_x86_log_file}"
         echo 'cd '${x86_path}'/mindspore-lite-'${version}'-train-linux-x64' >> "${run_x86_log_file}"
         cd ${x86_path}/mindspore-lite-${version}-train-linux-x64 || return 1
-        echo 'LD_LIBRARY_PATH='${LD_LIBRARY_PATH}':./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark_train/benchmark_train --epochs='${epoch_num}' --modelFile='${ms_models_path}'/'${model_name}'_train.ms --inDataFile='${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin' --expectedDataFile='${train_io_path}'/'${model_name}'_outputs.bin --exportFile='${ms_models_path}'/'${model_name}'_train_exported.ms' >> "${run_x86_log_file}"
+        echo 'LD_LIBRARY_PATH='${LD_LIBRARY_PATH}':./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark_train/benchmark_train --epochs='${epoch_num}' --modelFile='${ms_models_path}'/'${model_name}'_train.ms --inDataFile='${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin' --expectedDataFile='${train_io_path}'/'${model_name}'_output --exportFile='${ms_models_path}'/'${model_name}'_train_exported.ms' >> "${run_x86_log_file}"
         echo '-------------------------------------------------------------------------------' >> "${run_x86_log_file}"
         LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib \
         ${run_valgrind}./benchmark_train/benchmark_train \
         --modelFile=${ms_models_path}/${model_name}_train.ms \
         --inDataFile=${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin \
-        --expectedDataFile=${train_io_path}/${model_name}_outputs.bin \
+        --expectedDataFile=${train_io_path}/${model_name}_output \
         --exportFile=${ms_models_path}/${model_name}_train_exported.ms >> "${run_x86_log_file}" \
         --epochs=${epoch_num} --numThreads=${threads}
         if [ $? = 0 ]; then
             run_result='x86: '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file}
         else
             run_result='x86: '${model_name}'_train failed'; echo ${run_result} >> ${run_benchmark_train_result_file}
+            fail=1
         fi
     done < ${models_mindspore_train_config}
+    return ${fail}
 }

# Run on arm platform:
@@ -157,7 +160,7 @@ function Run_arm() {
         echo 'chmod 777 benchmark_train' >> ${adb_cmd_file}
         adb -s ${device_id} shell < ${adb_cmd_file}
+    fail=0
     # Run mindir converted train models:
     while read line; do
         model_name=${line}
@@ -167,7 +170,7 @@ function Run_arm() {
         # run benchmark_train test without clib data
         echo ${model_name}'_train' >> "${run_arm_log_file}"
-        adb -s ${device_id} push ${train_io_path}/${model_name}_input*.bin ${train_io_path}/${model_name}_outputs.bin /data/local/tmp/benchmark_train_test >> ${adb_push_log_file}
+        adb -s ${device_id} push ${train_io_path}/${model_name}_input*.bin ${train_io_path}/${model_name}_output*.bin /data/local/tmp/benchmark_train_test >> ${adb_push_log_file}
         echo 'cd /data/local/tmp/benchmark_train_test' > ${adb_cmd_run_file}
         echo 'chmod 777 benchmark_train' >> ${adb_cmd_run_file}
         if [ "$1" == arm64 ]; then
@@ -182,7 +185,7 @@ function Run_arm() {
             --epochs=${epoch_num} \
             --modelFile=${model_name}_train.ms \
             --inDataFile=${tmp_dir}/${model_name}_input1.bin,${tmp_dir}/${model_name}_input2.bin \
-            --expectedDataFile=${tmp_dir}/${model_name}_outputs.bin \
+            --expectedDataFile=${tmp_dir}/${model_name}_output \
             --exportFile=${tmp_dir}/${model_name}_train_exported.ms \
             --numThreads=${threads}
 ENDM
@@ -195,8 +198,11 @@ ENDM
             run_result=$1': '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file}
         else
             run_result=$1': '${model_name}'_train failed'; echo ${run_result} >> ${run_benchmark_train_result_file};
+            fail=1
         fi
     done < ${models_mindspore_train_config}
+    return ${fail}
 }

# Print start msg before run testcase
@@ -59,6 +59,41 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
   }
 }

+void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) {
+  bool hasDepend = false;
+  std::vector<AnfNodePtr> inputs;
+  inputs.clear();
+  inputs.emplace_back(cnode->input(0));
+  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
+    AnfNodePtr inputNode = cnode->input(i);
+    if (!inputNode->isa<CNode>()) {
+      inputs.emplace_back(cnode->input(i));
+      continue;
+    }
+    auto dependNode = utils::cast<CNodePtr>(inputNode);
+    if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) ||
+        IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) {
+      hasDepend = true;
+      bool maskOut = (dependNode->inputs().size() == 3);
+      for (size_t j = 1; j < dependNode->inputs().size(); ++j) {
+        AnfNodePtr dependInputNode = dependNode->input(j);
+        if (dependInputNode->isa<CNode>()) {
+          inputs.emplace_back(dependInputNode);
+          if (maskOut) {
+            break;
+          }
+        }
+      }
+    } else {
+      inputs.emplace_back(cnode->input(i));
+    }
+  }
+  if (hasDepend) {
+    cnode->set_inputs(inputs);
+  }
+}
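
In effect the pass splices Depend/ControlDepend wrappers out of a node's input list before export; a sketch of the rewiring (pseudo-notation):

    // node(x, Depend(a, b), y)  ->  node(x, a, y)  // two data inputs: maskOut, keep the first CNode
    // node(x, Depend(a), y)     ->  node(x, a, y)  // single data input is forwarded as-is
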
 int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
                                    const std::shared_ptr<PrimitiveC> &primitive,
                                    const std::unique_ptr<schema::CNodeT> &dst_node) {
@@ -251,8 +286,17 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
         break;
       }
     }
     RemoveIfMakeTuple(cnode);
+#ifdef SUPPORT_TRAIN
+    RemoveIfDepend(cnode);
+#endif
     if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) ||
+#ifdef SUPPORT_TRAIN
+        (primitive_c->Type() == schema::PrimitiveType_Depend) ||
+        (primitive_c->Type() == schema::PrimitiveType_ControlDepend) ||
+#endif
         (primitive_c->Type() == schema::PrimitiveType_MakeTuple)) {
       continue;
     }
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_LITE_TOOLS_COMMON_ANF_EXPORTER_ANF_EXPORTER_H_
-#define MINDSPORE_LITE_TOOLS_COMMON_ANF_EXPORTER_ANF_EXPORTER_H_
+#ifndef MINDSPORE_LITE_TOOLS_ANF_EXPORTER_ANF_EXPORTER_H_
+#define MINDSPORE_LITE_TOOLS_ANF_EXPORTER_ANF_EXPORTER_H_
 #include <map>
 #include <string>
@@ -41,6 +41,7 @@ class AnfExporter {
   int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                      schema::CNodeT *fb_node);
   static void RemoveIfMakeTuple(const CNodePtr &cnode);
+  static void RemoveIfDepend(const CNodePtr &cnode);

 protected:
   int ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode);
@@ -97,4 +98,4 @@ class AnfExporter {
 // and clear.
 schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false);
 }  // namespace mindspore::lite
-#endif  // MINDSPORE_LITE_TOOLS_COMMON_ANF_EXPORTER_ANF_EXPORTER_H_
+#endif  // MINDSPORE_LITE_TOOLS_ANF_EXPORTER_ANF_EXPORTER_H_
@@ -32,6 +32,42 @@ static const char *DELIM_COLON = ":";
 static const char *DELIM_COMMA = ",";
 static const char *DELIM_SLASH = "/";

+namespace {
+float *ReadFileBuf(const char *file, size_t *size) {
+  if (file == nullptr) {
+    MS_LOG(ERROR) << "file is nullptr";
+    return nullptr;
+  }
+  MS_ASSERT(size != nullptr);
+  std::string real_path = RealPath(file);
+  std::ifstream ifs(real_path, std::ios::binary);  // raw float data, so open in binary mode
+  if (!ifs.good()) {
+    MS_LOG(ERROR) << "file: " << real_path << " does not exist";
+    return nullptr;
+  }
+  if (!ifs.is_open()) {
+    MS_LOG(ERROR) << "file: " << real_path << " open failed";
+    return nullptr;
+  }
+  ifs.seekg(0, std::ios::end);
+  *size = ifs.tellg();
+  std::unique_ptr<float[]> buf(new (std::nothrow) float[*size / sizeof(float) + 1]);
+  if (buf == nullptr) {
+    MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
+    ifs.close();
+    return nullptr;
+  }
+  ifs.seekg(0, std::ios::beg);
+  ifs.read(reinterpret_cast<char *>(buf.get()), *size);
+  ifs.close();
+  return buf.release();
+}
+}  // namespace
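
Typical use, as CompareOutput below does, is to read one expected-output blob per output tensor (the file name here is illustrative):

    size_t size = 0;
    std::unique_ptr<float[]> expected(ReadFileBuf("lenet_output1.bin", &size));  // hypothetical file
    if (expected != nullptr && size > 0) {
      // compare expected.get() against the corresponding session output tensor
    }
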
 int NetTrain::GenerateRandomData(size_t size, void *data) {
   MS_ASSERT(data != nullptr);
   char *casted_data = static_cast<char *>(data);
@@ -113,82 +149,34 @@ int NetTrain::ReadInputFile() {
   return RET_OK;
 }

-// calibData is FP32
-int NetTrain::ReadCalibData() {
-  const char *calib_data_path = flags_->data_file_.c_str();
-  // read calib data
-  std::ifstream in_file(calib_data_path);
-  if (!in_file.good()) {
-    std::cerr << "file: " << calib_data_path << " is not exist" << std::endl;
-    MS_LOG(ERROR) << "file: " << calib_data_path << " is not exist";
-    return RET_ERROR;
-  }
-  if (!in_file.is_open()) {
-    std::cerr << "file: " << calib_data_path << " open failed" << std::endl;
-    MS_LOG(ERROR) << "file: " << calib_data_path << " open failed";
-    in_file.close();
-    return RET_ERROR;
-  }
-  std::string line;
-  MS_LOG(INFO) << "Start reading calibData file";
-  std::string tensor_name;
-  while (!in_file.eof()) {
-    getline(in_file, line);
-    std::stringstream string_line1(line);
-    size_t dim = 0;
-    string_line1 >> tensor_name >> dim;
-    std::vector<size_t> dims;
-    size_t shape_size = 1;
-    for (size_t i = 0; i < dim; i++) {
-      size_t tmp_dim;
-      string_line1 >> tmp_dim;
-      dims.push_back(tmp_dim);
-      shape_size *= tmp_dim;
-    }
-    getline(in_file, line);
-    std::stringstream string_line2(line);
-    std::vector<float> tensor_data;
-    for (size_t i = 0; i < shape_size; i++) {
-      float tmp_data;
-      string_line2 >> tmp_data;
-      tensor_data.push_back(tmp_data);
-    }
-    auto *check_tensor = new CheckTensor(dims, tensor_data);
-    this->data_.insert(std::make_pair(tensor_name, check_tensor));
-  }
-  in_file.close();
-  MS_LOG(INFO) << "Finish reading calibData file";
-  return RET_OK;
-}
-
 int NetTrain::CompareOutput() {
   std::cout << "================ Comparing Output data ================" << std::endl;
   float total_bias = 0;
   int total_size = 0;
   bool has_error = false;
-  for (const auto &calib_tensor : data_) {
-    std::string node_or_tensor_name = calib_tensor.first;
-    auto tensors = session_->GetOutputsByNodeName(node_or_tensor_name);
-    mindspore::tensor::MSTensor *tensor = nullptr;
-    if (tensors.empty() || tensors.size() != 1) {
-      MS_LOG(INFO) << "Cannot find output node: " << node_or_tensor_name
-                   << " or node has more than one output tensor, switch to GetOutputByTensorName";
-      tensor = session_->GetOutputByTensorName(node_or_tensor_name);
-      if (tensor == nullptr) {
-        MS_LOG(ERROR) << "Cannot find output tensor " << node_or_tensor_name << ", get model output failed";
-        return RET_ERROR;
-      }
-    } else {
-      tensor = tensors.front();
-    }
-    MS_ASSERT(tensor->MutableData() != nullptr);
+  auto tensors_list = session_->GetOutputs();
+  if (tensors_list.empty()) {
+    MS_LOG(ERROR) << "Cannot find output tensors, get model output failed";
+    return RET_ERROR;
+  }
+  mindspore::tensor::MSTensor *tensor = nullptr;
+  int i = 1;
+  for (auto it = tensors_list.begin(); it != tensors_list.end(); ++it) {
+    tensor = session_->GetOutputByTensorName(it->first);
     auto outputs = tensor->MutableData();
-    float bias = CompareData<float>(node_or_tensor_name, tensor->shape(), reinterpret_cast<float *>(outputs));
+    size_t size;
+    std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin";
+    auto *bin_buf = ReadFileBuf(output_file.c_str(), &size);
+    if (bin_buf == nullptr) {
+      MS_LOG(ERROR) << "ReadFile return nullptr";
+      return RET_ERROR;
+    }
+    if (size != tensor->Size()) {
+      MS_LOG(ERROR) << "Output buffer and output file differ by size. Tensor size: " << tensor->Size()
+                    << ", read size: " << size;
+      return RET_ERROR;
+    }
+    float bias = CompareData<float>(bin_buf, tensor->ElementsNum(), reinterpret_cast<float *>(outputs));
     if (bias >= 0) {
       total_bias += bias;
       total_size++;
@@ -196,6 +184,8 @@ int NetTrain::CompareOutput() {
       has_error = true;
       break;
     }
+    i++;
+    delete[] bin_buf;
   }

   if (!has_error) {
@@ -206,7 +196,8 @@ int NetTrain::CompareOutput() {
     mean_bias = 0;
   }
-  std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%" << std::endl;
+  std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%"
+            << " threshold is:" << this->flags_->accuracy_threshold_ << std::endl;
   std::cout << "=======================================================" << std::endl << std::endl;
   if (mean_bias > this->flags_->accuracy_threshold_) {
@@ -297,13 +288,6 @@ int NetTrain::MarkAccuracy() {
     return status;
   }
-  status = ReadCalibData();
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "Read calib data error " << status;
-    std::cerr << "Read calib data error " << status << std::endl;
-    return status;
-  }
   status = CompareOutput();
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Compare output error " << status;
@@ -454,7 +438,7 @@ int NetTrain::RunNetTrain() {
     std::cout << "Run SaveToFile error";
     return RET_ERROR;
   }
-  // delete session_;
+
   status = RunExportedNet();
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Run Exported model error: " << status;
@@ -116,8 +116,6 @@ class MS_API NetTrain {
   int ReadInputFile();
-  int ReadCalibData();
-
   int CompareOutput();

   int InitCallbackParameter();
@@ -140,78 +138,49 @@ class MS_API NetTrain {
   // tensorData need to be converter first
   template <typename T>
-  float CompareData(const std::string &nodeName, std::vector<int> msShape, T *msTensorData) {
-    auto iter = this->data_.find(nodeName);
-    if (iter != this->data_.end()) {
-      std::vector<size_t> castedMSShape;
-      size_t shapeSize = 1;
-      for (int64_t dim : msShape) {
-        castedMSShape.push_back(size_t(dim));
-        shapeSize *= dim;
-      }
-      CheckTensor *calibTensor = iter->second;
-      if (calibTensor->shape != castedMSShape) {
-        std::ostringstream oss;
-        oss << "Shape of mslite output(";
-        for (auto dim : castedMSShape) {
-          oss << dim << ",";
-        }
-        oss << ") and shape source model output(";
-        for (auto dim : calibTensor->shape) {
-          oss << dim << ",";
-        }
-        oss << ") are different";
-        std::cerr << oss.str() << std::endl;
-        MS_LOG(ERROR) << oss.str().c_str();
-        return RET_ERROR;
-      }
-      size_t errorCount = 0;
-      float meanError = 0;
-      std::cout << "Data of node " << nodeName << " : ";
-      for (size_t j = 0; j < shapeSize; j++) {
-        if (j < 50) {
-          std::cout << static_cast<float>(msTensorData[j]) << " ";
-        }
-        if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) {
-          std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl;
-          MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
-          return RET_ERROR;
-        }
-        auto tolerance = absoluteTolerance + relativeTolerance * fabs(calibTensor->data.at(j));
-        auto absoluteError = std::fabs(msTensorData[j] - calibTensor->data.at(j));
-        if (absoluteError > tolerance) {
-          if (fabs(calibTensor->data.at(j)) == 0) {
-            if (absoluteError > 1e-5) {
-              meanError += absoluteError;
-              errorCount++;
-            } else {
-              continue;
-            }
-          } else {
-            // just assume that atol = rtol
-            meanError += absoluteError / (fabs(calibTensor->data.at(j)) + FLT_MIN);
-            errorCount++;
-          }
-        }
-      }
-      std::cout << std::endl;
-      if (meanError > 0.0f) {
-        meanError /= errorCount;
-      }
-      if (meanError <= 0.0000001) {
-        std::cout << "Mean bias of node/tensor " << nodeName << " : 0%" << std::endl;
-      } else {
-        std::cout << "Mean bias of node/tensor " << nodeName << " : " << meanError * 100 << "%" << std::endl;
-      }
-      return meanError;
-    } else {
-      MS_LOG(INFO) << "%s is not in Source Model output", nodeName.c_str();
-      return RET_ERROR;
-    }
-  }
+  float CompareData(const float *refOutput, int size, T *msTensorData) {
+    size_t errorCount = 0;
+    float meanError = 0;
+    std::cout << "Data of model output: ";
+    for (int j = 0; j < size; j++) {
+      if (j < 50) {
+        std::cout << static_cast<float>(msTensorData[j]) << " ";
+      }
+      if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) {
+        std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl;
+        MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
+        return RET_ERROR;
+      }
+      auto tolerance = absoluteTolerance + relativeTolerance * fabs(refOutput[j]);
+      auto absoluteError = std::fabs(msTensorData[j] - refOutput[j]);
+      if (absoluteError > tolerance) {
+        if (fabs(refOutput[j]) == 0) {
+          if (absoluteError > 1e-5) {
+            meanError += absoluteError;
+            errorCount++;
+          } else {
+            continue;
+          }
+        } else {
+          // just assume that atol = rtol
+          meanError += absoluteError / (fabs(refOutput[j]) + FLT_MIN);
+          errorCount++;
+        }
+      }
+    }
+    std::cout << std::endl;
+    if (meanError > 0.0f) {
+      meanError /= errorCount;
+    }
+    if (meanError <= 0.0000001) {
+      std::cout << "Mean bias of tensor: 0%" << std::endl;
+    } else {
+      std::cout << "Mean bias of tensor: " << meanError * 100 << "%" << std::endl;
+    }
+    return meanError;
+  }
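
A worked instance of the tolerance test, with absoluteTolerance = relativeTolerance = 1e-5 assumed, ref = 2.0f and out = 2.00009f:

    // tolerance     = 1e-5 + 1e-5 * 2.0 = 3e-5
    // absoluteError = |2.00009 - 2.0|   = 9e-5 > tolerance, so the element counts as an error
    // contribution  = 9e-5 / (2.0 + FLT_MIN) = 4.5e-5 added to meanError
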
   int MarkPerformance();
@@ -144,8 +144,8 @@ int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &opt
 int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
                                    const converter::Flags *config) {
   auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
-  const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
   if (!config->trainModel) {
+    const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
     auto inne_context_ptr = std::make_shared<lite::InnerContext>();
     inne_context_ptr->Init();
     const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(inne_context_ptr));