From c242166b34c426d90bf0751098230bc799cad3a7 Mon Sep 17 00:00:00 2001 From: yeyunpeng2020 Date: Tue, 27 Apr 2021 17:25:05 +0800 Subject: [PATCH] fix scale CheckSpecs bug --- .../src/runtime/kernel/opencl/kernel/scale.cc | 19 +++++++ .../runtime/kernel/opencl/opencl_fusion.cc | 1 - mindspore/lite/test/models_gpu_fp32.cfg | 49 +++++++++++++++++++ 3 files changed, 68 insertions(+), 1 deletion(-) diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc index b9be99f3e1..f9bcbe1d42 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc @@ -42,6 +42,25 @@ int ScaleOpenCLKernel::CheckSpecs() { param->activation_type_ != ActType_Relu6) { return RET_ERROR; } + auto *scale_param = reinterpret_cast(op_parameter_); + auto in_tensor = in_tensors_.at(0); + auto in_shape = in_tensor->shape(); + auto scale_tensor = in_tensors_.at(1); + auto scale_shape = scale_tensor->shape(); + auto axis = scale_param->axis_; + if (axis < 0) { + axis += in_shape.size(); + } + bool isBroadCast = scale_shape.size() != in_shape.size(); + if (isBroadCast) { + bool isScalar = scale_tensor->ElementsNum() == 1; + bool isScaleC = (in_shape.size() == 4 && axis == 3) || (in_shape.size() == 2 && axis == 1); + bool isScaleH = in_shape.size() == 4 && axis == 1; + if (isScalar || !(isScaleC || isScaleH)) { + MS_LOG(ERROR) << "unsupported scale axis " << axis << ", in shape " << in_shape << ", scale shape" << scale_shape; + return RET_ERROR; + } + } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc index 2079b65bc4..8b67e29d60 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc @@ -135,7 +135,6 @@ std::vector RemoveDuplicationsButKeepOrder(const std::vector &vec) { void Merge(LiteKernel *a, LiteKernel *b, bool remove_a) { MS_ASSERT(a); MS_ASSERT(b); - MS_ASSERT(a->op_parameter()->infer_flag_); MS_ASSERT(b->op_parameter()->infer_flag_); if (remove_a) { // pred->tensor0->a->tensor1->b: remove a tensor1 // update pred out_kernels: a.in_kernels.out_kernels.replace(a,b) diff --git a/mindspore/lite/test/models_gpu_fp32.cfg b/mindspore/lite/test/models_gpu_fp32.cfg index 3a9b8070f0..69a45b227e 100644 --- a/mindspore/lite/test/models_gpu_fp32.cfg +++ b/mindspore/lite/test/models_gpu_fp32.cfg @@ -150,3 +150,52 @@ hiai_iMaxDN_RGB.pb hiai_iMaxSR_RGB.pb hiai_lm_inference_graph.pb hiai_PoseEstimation_Pcm.pb +hiai_model_0909_kd_rot_ps_softmax.tflite +hiai_chinese_english_recognize_model_float32.tflite +hiai_bigmodel_ghost_2_1_no_normalized_no_trans_tflite.tflite +hiai_bigmodel_ghost_5_1_no_normalized_no_trans_tflite.tflite +hiai_detectmodel_desnet_256_128_64_32.tflite +mtk_AADB_HADB_MBV3_model_fp32.tflite +Q888_face_recognition.onnx +mobilenet_v1_0.25_128.tflite +mobilenet_v1_0.5_160.tflite +mobilenet_v1_0.75_192.tflite +mobilenet_v1_1.0_160.tflite +mtk_model_ckpt.tflite +mtk_age_gender.tflite +mtk_model_face_dress.tflite +mtk_face_features_v1.tflite +mtk_isface +mtk_landmark +mtk_pose_tuku +mtk_face_recognition_v1 +mtk_2012_ATLANTA_10class_20190614_v41 +mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified +mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified +mtk_detect_mbv1_640_480_nopostprocess_simplified +densenet.tflite +resnet_v2_101_299.tflite +mnasnet_1.3_224.tflite +deeplabv3_257_mv_gpu.tflite +multi_person_mobilenet_v1_075_float.tflite +ide_label_base.tflite +ml_ei_headpose.tflite +mnist.tflite +mobilenet.tflite +scan_hms_angle1.tflite +scan_hms_detect.tflite +ml_ocr_jk.tflite +nasnet_mobile.tflite +nasnet_large.tflite +model_emotions_0727_nosoftmax.tflite +inception_resnet_v2.tflite +hiai_PoseEstimation_Pcm.tflite +hiai_ssd_mobilenetv2_object.tflite +hiai_cv_focusShootOCRModel_02.tflite +hiai_cv_poseEstimation.tflite +inception_v4.tflite +mtk_model_normalize_object_scene_ps_20200519_f16.tflite +mtk_AADB_HADB_MBV2_model_f16.tflite +mtk_AADB_HADB_MBV3_model_f16.tflite +mtk_model_emotions_0725_fp16.tflite +mtk_face_features_v1_fp16.tflite