From 3c4b5365a51c59fd4b69a1a46f511064f8a3482c Mon Sep 17 00:00:00 2001
From: zengxianglong
Date: Sat, 13 Mar 2021 21:27:43 +0800
Subject: [PATCH] fix some bugs and add models to the entrance guard

---
 mindspore/lite/nnacl/fp32/resize_fp32.c        | 14 +++++++++--
 .../runtime/kernel/arm/base/reshape_base.cc    |  7 ++++++
 .../src/runtime/kernel/arm/base/split_base.cc  |  2 ++
 .../kernel/arm/fp32/arithmetic_fp32.cc         |  4 ++--
 mindspore/lite/test/models_caffe.cfg           | 16 ++++++++++++-
 mindspore/lite/test/models_caffe_fp16.cfg      | 17 +++++++++++++-
 mindspore/lite/test/models_npu.cfg             | 23 +++++++++++++++++--
 mindspore/lite/test/models_onnx.cfg            |  4 ++++
 mindspore/lite/test/models_onnx_fp16.cfg       |  4 ++++
 mindspore/lite/test/models_tf.cfg              |  2 +-
 .../lite/test/models_with_multiple_inputs.cfg  |  2 +-
 11 files changed, 85 insertions(+), 10 deletions(-)

diff --git a/mindspore/lite/nnacl/fp32/resize_fp32.c b/mindspore/lite/nnacl/fp32/resize_fp32.c
index b0a4b8ecb3..32b68fb013 100644
--- a/mindspore/lite/nnacl/fp32/resize_fp32.c
+++ b/mindspore/lite/nnacl/fp32/resize_fp32.c
@@ -122,6 +122,8 @@ int PrepareCropAndResizeBilinear(const int *input_shape, const float *boxes, con
   int batch = output_shape[0];
   int new_height = output_shape[1];
   int new_width = output_shape[2];
+  float actual_x;
+  float actual_y;
 
   for (int b = 0; b < batch; b++) {
     const float *box = boxes + b * 4;
@@ -140,11 +142,19 @@ int PrepareCropAndResizeBilinear(const int *input_shape, const float *boxes, con
     int *x_right = x_rights + b * new_width;
     float *x_left_weight = x_left_weights + b * new_width;
     for (int h = 0; h < new_height; h++) {
-      float actual_y = start_h * (in_h - 1) + h * (end_h - start_h) * (in_h - 1) / (new_height - 1);
+      if (new_height > 1) {
+        actual_y = start_h * (in_h - 1) + h * (end_h - start_h) * (in_h - 1) / (new_height - 1);
+      } else {
+        actual_y = 0.5 * (end_h + start_h) * (in_h - 1);
+      }
       CalculateCoordinate(actual_y, in_h, y_bottom + h, y_top + h, y_bottom_weight + h);
     }
     for (int w = 0; w < new_width; w++) {
-      float actual_x = start_w * (in_w - 1) + w * (end_w - start_w) * (in_w - 1) / (new_width - 1);
+      if (new_width > 1) {
+        actual_x = start_w * (in_w - 1) + w * (end_w - start_w) * (in_w - 1) / (new_width - 1);
+      } else {
+        actual_x = 0.5 * (end_w + start_w) * (in_w - 1);
+      }
       CalculateCoordinate(actual_x, in_w, x_left + w, x_right + w, x_left_weight + w);
     }
   }
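A note on the resize_fp32.c hunk above: when the crop-and-resize output height or width equals 1, the original expression divides by (new_height - 1) or (new_width - 1), i.e. by zero, so the patched code falls back to the centre of the crop box for that axis. Below is a minimal standalone sketch of that mapping, assuming one helper can cover both axes; the helper name is hypothetical and not part of the nnacl API.

// Sketch only: mirrors the guarded coordinate mapping from the hunk above.
// `start`/`end` are the normalized box edges, `in_size` the input extent,
// `new_size` the crop output extent along the same axis.
#include <cstdio>

float CropSourceCoordinate(int i, float start, float end, int in_size, int new_size) {
  if (new_size > 1) {
    return start * (in_size - 1) + i * (end - start) * (in_size - 1) / (new_size - 1);
  }
  return 0.5f * (end + start) * (in_size - 1);  // degenerate axis: use the box centre
}

int main() {
  printf("%f\n", CropSourceCoordinate(0, 0.25f, 0.75f, 224, 1));  // 111.5, no division by zero
  printf("%f\n", CropSourceCoordinate(3, 0.25f, 0.75f, 224, 8));  // regular interpolation inside the box
  return 0;
}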
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc
index a0a79c9769..d5dbd28fee 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/reshape_base.cc
@@ -33,6 +33,13 @@ namespace mindspore::kernel {
 int ReshapeBaseCPUKernel::Init() { return ReSize(); }
 
 int ReshapeBaseCPUKernel::ReSize() {
+  auto out_tensor = out_tensors_.at(kOutputIndex);
+  bool is_next_conv = std::any_of(out_kernels_.begin(), out_kernels_.end(), [](LiteKernel *next_kernel) {
+    return next_kernel->Type() == schema::PrimitiveType_Conv2DFusion;
+  });
+  if (is_next_conv && out_tensor->shape().size() == 4 && out_tensor->format() == schema::Format::Format_NCHW) {
+    out_tensor->set_format(schema::Format::Format_NHWC);
+  }
   int in_data_size = in_tensors_.front()->Size();
   int thread_num = context_->thread_num_;
   cal_max_num_per_thread_ = UP_DIV(in_data_size, thread_num);
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc
index 9d0dd6a3cf..0dfb02a6a8 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/split_base.cc
@@ -46,6 +46,8 @@ int SplitBaseCPUKernel::ReSize() {
   MS_ASSERT(param);
   MS_ASSERT(input_shape.size() >= 1 && input_shape.size() <= SPLIT_STRIDES_SIZE);
 
+  auto split_dim = param->split_dim_;
+  param->split_dim_ = split_dim >= 0 ? split_dim : in_tensors_.front()->shape().size() + split_dim;
   param->strides_[input_shape.size() - 1] = 1;
   for (int i = input_shape.size() - 2; i >= 0; i--) {
     param->strides_[i] = param->strides_[i + 1] * input_shape.at(i + 1);
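The two lines added to split_base.cc above normalize a negative split axis before the strides are computed: an axis of -1 refers to the last dimension, so it is folded into the [0, rank) range by adding the input rank. A short sketch of that normalization, using hypothetical names rather than the SplitParameter struct:

#include <cassert>
#include <cstdio>
#include <vector>

// Folds a possibly negative axis into [0, rank), as the added lines do for split_dim_.
int NormalizeSplitAxis(int axis, const std::vector<int> &input_shape) {
  int rank = static_cast<int>(input_shape.size());
  int normalized = axis >= 0 ? axis : rank + axis;
  assert(normalized >= 0 && normalized < rank);
  return normalized;
}

int main() {
  std::vector<int> shape = {1, 48, 48, 24};
  printf("%d\n", NormalizeSplitAxis(-1, shape));  // 3: last axis
  printf("%d\n", NormalizeSplitAxis(2, shape));   // 2: already non-negative, unchanged
  return 0;
}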
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
index acce89c95d..bc7f35aa34 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
@@ -133,7 +133,7 @@ int ArithmeticCPUKernel::ConstTensorBroadCast() {
   }
   FreeConstTileBuff();
-  if (in_tensors_[0]->data_c() != nullptr && param_->in_elements_num0_ != param_->out_elements_num_) {
+  if (in_tensors_[0]->IsConst() && param_->in_elements_num0_ != param_->out_elements_num_) {
     input0_ptr_ = malloc(param_->out_elements_num_ * data_type_len_);
     if (input0_ptr_ == nullptr) {
       return RET_ERROR;
     }
@@ -144,7 +144,7 @@ int ArithmeticCPUKernel::ConstTensorBroadCast() {
     param_->in_elements_num0_ = param_->out_elements_num_;
     param_->broadcasting_ = false;
   }
-  if (in_tensors_[1]->data_c() != nullptr && param_->in_elements_num1_ != param_->out_elements_num_) {
+  if (in_tensors_[1]->IsConst() && param_->in_elements_num1_ != param_->out_elements_num_) {
     input1_ptr_ = malloc(param_->out_elements_num_ * data_type_len_);
     if (input1_ptr_ == nullptr) {
       FreeConstTileBuff();
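In the arithmetic_fp32.cc hunk above, the guard for pre-broadcasting an operand changes from "the tensor has a data pointer" to "the tensor is constant". The apparent intent is that only a genuinely constant operand should be tiled up to the output element count ahead of time, since a variable input can also report a non-null buffer. A hedged sketch of that decision, with hypothetical types and names rather than the ArithmeticCPUKernel interface:

#include <cstdio>

struct OperandInfo {
  bool is_const;      // operand carries constant (weight-like) data
  int element_count;  // number of elements in the operand
};

// Pre-broadcast only constant operands whose size differs from the output size.
bool ShouldPreBroadcast(const OperandInfo &operand, int out_element_count) {
  return operand.is_const && operand.element_count != out_element_count;
}

int main() {
  OperandInfo const_scalar{true, 1};
  OperandInfo runtime_input{false, 1};  // may still own a buffer at ReSize time
  printf("%d\n", ShouldPreBroadcast(const_scalar, 1024));   // 1: tile once, up front
  printf("%d\n", ShouldPreBroadcast(runtime_input, 1024));  // 0: leave broadcasting to run time
  return 0;
}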
diff --git a/mindspore/lite/test/models_caffe.cfg b/mindspore/lite/test/models_caffe.cfg
index 51178bb4d1..554b045936 100644
--- a/mindspore/lite/test/models_caffe.cfg
+++ b/mindspore/lite/test/models_caffe.cfg
@@ -72,5 +72,19 @@ ml_video_edit_img_segment
 ml_video_edit_video_segment_gauss_adaptis_part1
 ml_video_edit_Mnet
 ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145
-ml_video_edit_person_divison_video
 hdc_Face_Aesthetic_MTI_Aesthetic
+hdc_age_medium
+hdc_contour_pose_128
+hdc_emotion
+hdc_fivembnet
+hdc_isface
+hdc_mobilenetface
+#hdc_retinaface
+hdc_resnet
+ml_video_edit_detect
+ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145_20210121
+ml_video_edit_have_imageProcessLayer_interpTo145_20201015
+ml_video_edit_MnetN367_extract_1010_pay
+#ml_video_edit_person_divison_pic
+ml_video_edit_reid
+ml_video_edit_v10_best_model_nomean_20200723
diff --git a/mindspore/lite/test/models_caffe_fp16.cfg b/mindspore/lite/test/models_caffe_fp16.cfg
index c4fa87bf7b..fac4d1908b 100644
--- a/mindspore/lite/test/models_caffe_fp16.cfg
+++ b/mindspore/lite/test/models_caffe_fp16.cfg
@@ -67,4 +67,19 @@ ml_location_scene_division 8
 ml_tabel_recog 0.1
 ml_text_division 12
 ml_video_edit_Mnet 11 # Further analysis in the future
-ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145 0.5
\ No newline at end of file
+ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145 0.5
+hdc_age_medium 6
+hdc_contour_pose_128 0.5
+hdc_emotion 0.5
+hdc_fivembnet 0.5
+hdc_isface 0.5
+hdc_mobilenetface 7.5
+#hdc_retinaface 14
+hdc_resnet 7
+ml_video_edit_detect 2.5
+ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145_20210121 0.5
+ml_video_edit_have_imageProcessLayer_interpTo145_20201015 0.5
+ml_video_edit_MnetN367_extract_1010_pay 1
+#ml_video_edit_person_divison_pic 0.2
+ml_video_edit_reid 1
+ml_video_edit_v10_best_model_nomean_20200723 5
diff --git a/mindspore/lite/test/models_npu.cfg b/mindspore/lite/test/models_npu.cfg
index 1aca6948b0..4e61103649 100644
--- a/mindspore/lite/test/models_npu.cfg
+++ b/mindspore/lite/test/models_npu.cfg
@@ -39,8 +39,8 @@ ml_video_edit_video_segment_gauss_adaptis_part1 2
 ml_video_edit_generate_filter.pb 1
 ml_video_edit_img_segment_adaptise.pb 0.5 2
 ml_video_edit_video_segment_gauss_adaptis_part2.pb 3 2
-ml_video_edit_person_divison_pic 8 2
-ml_video_edit_person_divison_video 0.5
+#ml_video_edit_person_divison_pic 0.5
+#ml_video_edit_person_divison_video 13 2
 ml_video_edit_imitate_filter.onnx 230
 ml_video_edit_judge.onnx 5
 ml_video_edit_vignet.onnx 0.5
@@ -50,3 +50,22 @@ hdc_Face_Landmark5_MTI_Aesthetic.onnx 0.5
 hdc_Image_Aesthetic_MTI_Aesthetic.onnx 0.5
 hdc_mobilenet_1w_class.onnx 10
 hdc_resnet_1w_class.onnx 5
+#hdc_age_medium 6
+hdc_contour_pose_128 4
+hdc_emotion 0.5
+hdc_fivembnet 0.5
+hdc_isface 0.5
+hdc_mobilenetface 4
+#hdc_retinaface #too many subgraphs
+hdc_resnet 3
+ml_video_edit_detect 1
+ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145_20210121 0.5
+ml_video_edit_have_imageProcessLayer_interpTo145_20201015 0.5
+ml_video_edit_MnetN367_extract_1010_pay 0.5
+ml_video_edit_reid 0.5
+ml_video_edit_v10_best_model_nomean_20200723 8
+#hdc_ocr_attention.onnx 0.5 #too many subgraphs
+#hdc_ocr_detect.onnx 30 #too many subgraphs
+#ml_edu_kit_hand_detection.onnx 1
+ml_edu_kit_hand_key_position.onnx 2
+#ml_video_edit_oneclick_adaptis.pb
diff --git a/mindspore/lite/test/models_onnx.cfg b/mindspore/lite/test/models_onnx.cfg
index 08c99ad0c3..00647cab20 100644
--- a/mindspore/lite/test/models_onnx.cfg
+++ b/mindspore/lite/test/models_onnx.cfg
@@ -52,3 +52,7 @@ hdc_mobilenet_1w_class.onnx
 hdc_resnet_1w_class.onnx
 ml_video_edit_imitate_filter.onnx
 #ml_voice_detect.onnx #Accuracy error: 4.59655%, the result is close to 1.0e-8 except for the last one
+hdc_ocr_attention.onnx
+hdc_ocr_detect.onnx
+#ml_edu_kit_hand_detection.onnx
+ml_edu_kit_hand_key_position.onnx
diff --git a/mindspore/lite/test/models_onnx_fp16.cfg b/mindspore/lite/test/models_onnx_fp16.cfg
index 663568366f..9072b55aa7 100644
--- a/mindspore/lite/test/models_onnx_fp16.cfg
+++ b/mindspore/lite/test/models_onnx_fp16.cfg
@@ -37,3 +37,7 @@ adversarial_pruning.onnx 3
 residual_distill_res34_cifar10_bs_1_update.onnx 2
 residual_distill_res50_cifar10_bs_1_update.onnx 2
 #ml_voice_detect.onnx #out of float16 range because power op
+hdc_ocr_attention.onnx 1
+hdc_ocr_detect.onnx 30 #one of the output has small values
+#ml_edu_kit_hand_detection.onnx 2
+ml_edu_kit_hand_key_position.onnx 2
diff --git a/mindspore/lite/test/models_tf.cfg b/mindspore/lite/test/models_tf.cfg
index 653a72249e..be21214467 100644
--- a/mindspore/lite/test/models_tf.cfg
+++ b/mindspore/lite/test/models_tf.cfg
@@ -58,6 +58,6 @@ mtk_model_face_dress.pb 1;1,128,128,3
 mtk_model_normalize_object_scene_ps_20200519.pb 1;1,224,224,3
 ml_ocr_latin.pb 1
 ml_noya_tts_melgan.pb 1;16,16,80
-ml_video_edit_oneclick_adaptis.pb 3
+#ml_video_edit_oneclick_adaptis.pb 3
 # Q_hand_0812.pb is not suitable for float16. Out of float16 range.
 Q_hand_0812.pb
diff --git a/mindspore/lite/test/models_with_multiple_inputs.cfg b/mindspore/lite/test/models_with_multiple_inputs.cfg
index 88fd55f55a..9dabeb87dd 100644
--- a/mindspore/lite/test/models_with_multiple_inputs.cfg
+++ b/mindspore/lite/test/models_with_multiple_inputs.cfg
@@ -9,5 +9,5 @@ ml_video_edit_video_segment_gauss_adaptis_part2.pb;2
 ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2
 decoder.onnx;2;1,7,512:1,7
 fasterrcnn_crop.pb;1;420,630,3
-ml_video_edit_person_divison_pic;2
+#ml_video_edit_person_divison_video;2
 hdc_tb_cn_neg.tflite;3