diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_reserve_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_reserve_infer.c
index 3a7381f3ac..49efd7fd83 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_reserve_infer.c
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_reserve_infer.c
@@ -16,6 +16,7 @@
 
 #include "nnacl/infer/tensorlist_reserve_infer.h"
 #include "nnacl/infer/infer_register.h"
+#include "nnacl/tensorlist_parameter.h"
 
 int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
                                 size_t outputs_size, OpParameter *parameter) {
@@ -26,6 +27,7 @@ int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size
   }
 #endif
 
+  TensorListParameter *reserve_param = (TensorListParameter *)parameter;
   const TensorC *input0 = inputs[0];
   int ele_shape_type = input0->data_type_;
   if (ele_shape_type != kNumberTypeInt && ele_shape_type != kNumberTypeInt32) {
@@ -35,6 +37,7 @@ int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size
   TensorListC *output = (TensorListC *)(outputs[0]);
   output->data_type_ = kObjectTypeTensorType;
   output->format_ = Format_NHWC;
+  output->tensors_data_type_ = reserve_param->element_dtype_;
 
   if (input0->data_ == NULL) {
     return NNACL_INFER_INVALID;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c
index 13f1ec740e..270c0085dc 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c
@@ -39,6 +39,9 @@ int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
 
   TensorC *output = outputs[0];
   SetDataTypeFormat(output, input);
+  if (parameter->quant_type_ == QuantType_QUANT_WEIGHT) {
+    output->data_type_ = kNumberTypeFloat32;
+  }
   if (!parameter->infer_flag_) {
     return NNACL_INFER_INVALID;
   }
diff --git a/mindspore/lite/src/kernel_registry.cc b/mindspore/lite/src/kernel_registry.cc
index 25c4cad323..508bcaefdf 100644
--- a/mindspore/lite/src/kernel_registry.cc
+++ b/mindspore/lite/src/kernel_registry.cc
@@ -179,31 +179,35 @@ bool KernelRegistry::SupportKernel(const KernelKey &key) {
   return kernel_creator != nullptr;
 }
 
-kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors,
-                                              const std::vector<Tensor *> &out_tensors, const InnerContext *ctx,
-                                              const kernel::KernelKey &key, OpParameter *parameter,
-                                              const void *primitive) {
+int KernelRegistry::GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
+                              const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *parameter,
+                              kernel::LiteKernel **kernel, const void *primitive) {
   MS_ASSERT(ctx != nullptr);
+  MS_ASSERT(kernel != nullptr);
   if (key.vendor == kBuiltin) {
     auto creator = GetCreator(key);
     if (creator != nullptr) {
-      auto kernel = creator(in_tensors, out_tensors, parameter, ctx, key);
-      if (kernel != nullptr) {
-        kernel->set_desc(key);
-        return kernel;
+      *kernel = creator(in_tensors, out_tensors, parameter, ctx, key);
+      if (*kernel != nullptr) {
+        (*kernel)->set_desc(key);
+        return RET_OK;
       }
+      return RET_ERROR;
     }
   } else {
     auto creator = GetDelegateCreator(key);
-    if (creator == nullptr) {
-      return nullptr;
+    if (creator != nullptr) {
+      std::vector<tensor::MSTensor *> tensors_in;
+      Tensor2MSTensor(std::move(in_tensors), &tensors_in);
+      std::vector<tensor::MSTensor *> tensors_out;
+      Tensor2MSTensor(std::move(out_tensors), &tensors_out);
+      *kernel = creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx);
+      if (*kernel != nullptr) {
+        return RET_OK;
+      }
+      return RET_ERROR;
     }
-    std::vector<tensor::MSTensor *> tensors_in;
-    Tensor2MSTensor(std::move(in_tensors), &tensors_in);
-    std::vector<tensor::MSTensor *> tensors_out;
-    Tensor2MSTensor(std::move(out_tensors), &tensors_out);
-    return creator(tensors_in, tensors_out, static_cast<const schema::Primitive *>(primitive), ctx);
   }
-  return nullptr;
+  return RET_NOT_SUPPORT;
 }
 }  // namespace mindspore::lite
diff --git a/mindspore/lite/src/kernel_registry.h b/mindspore/lite/src/kernel_registry.h
index 12ab8b55c8..1d4064bcc0 100644
--- a/mindspore/lite/src/kernel_registry.h
+++ b/mindspore/lite/src/kernel_registry.h
@@ -48,9 +48,9 @@ class KernelRegistry {
                      kernel::CreateKernel creator);
   bool Merge(const std::unordered_map &newCreators);
   bool SupportKernel(const kernel::KernelKey &key);
-  kernel::LiteKernel *GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
-                                const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *op_parameter,
-                                const void *primitive = nullptr);
+  int GetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
+                const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *op_parameter,
+                kernel::LiteKernel **kernel, const void *primitive = nullptr);
 
 protected:
  static const int device_type_length_{kKernelArch_MAX - kKernelArch_MIN + 1};
diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc
index ab061e1fa1..bb0883791e 100644
--- a/mindspore/lite/src/lite_session.cc
+++ b/mindspore/lite/src/lite_session.cc
@@ -151,6 +151,12 @@ lite::Tensor *LiteSession::ConvertTensor(const schema::Tensor &src_tensor) {
   lite::Tensor *dst_tensor = nullptr;
   if (TypeId(src_tensor.dataType()) == kObjectTypeTensorType) {
     dst_tensor = new (std::nothrow) TensorList(shape, std::vector<int>(), src_category);
+    // set tensor list datatype
+    auto tensor_list = reinterpret_cast<TensorList *>(dst_tensor);
+    if (src_tensor.data() != nullptr) {
+      auto tensor_data_type = TypeId(reinterpret_cast<const int *>(src_tensor.data()->data())[0]);
+      tensor_list->set_tensors_data_type(tensor_data_type);
+    }
   } else {
     dst_tensor = new (std::nothrow) Tensor(TypeId(src_tensor.dataType()), shape, src_tensor.format(), src_category);
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
index 78eeedde44..d183b7a287 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
@@ -425,6 +425,7 @@ int ArithmeticCPUKernel::Run() {
 
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MulFusion, LiteKernelCreator<ArithmeticCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_MulFusion, LiteKernelCreator<ArithmeticCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_AddFusion, LiteKernelCreator<ArithmeticCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_AddFusion, LiteKernelCreator<ArithmeticCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_AddFusion, LiteKernelCreator<ArithmeticCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SubFusion, LiteKernelCreator<ArithmeticCPUKernel>)
diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc
index eb621edb94..3bf0abdb3d 100644
--- a/mindspore/lite/src/scheduler.cc
+++ b/mindspore/lite/src/scheduler.cc
@@ -340,26 +340,26 @@ inline void RestoreTensorData(std::map<Tensor *, Tensor *> *restored_origin_tens
   }
 }
 }  // namespace
 
-kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors,
-                                             const std::vector<Tensor *> &out_tensors, OpParameter *op_parameter,
-                                             const kernel::KernelKey &desc, TypeId kernel_data_type) {
+int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
+                             OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
+                             kernel::LiteKernel **kernel) {
   MS_ASSERT(op_parameter != nullptr);
   auto op_type = op_parameter->type_;
   if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
-    return nullptr;
+    return RET_NOT_SUPPORT;
   }
   kernel::KernelKey cpu_desc = desc;
   if (kernel_data_type == kNumberTypeFloat16) {
     if (!context_->IsCpuFloat16Enabled() ||
         (cpu_desc.data_type != kNumberTypeFloat32 && cpu_desc.data_type != kNumberTypeFloat16)) {
-      return nullptr;
+      return RET_NOT_SUPPORT;
     }
     cpu_desc.data_type = kNumberTypeFloat16;
   }
   auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type);
   if (ret != RET_OK) {
     MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
-    return nullptr;
+    return RET_NOT_SUPPORT;
   }
   std::map<Tensor *, Tensor *> restored_origin_tensors;
@@ -367,28 +367,27 @@ kernel::LiteKernel *Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_ten
     ret = CastConstTensorsData(in_tensors, &restored_origin_tensors, kernel_data_type);
     if (ret != RET_OK) {
       MS_LOG(DEBUG) << "CastConstTensorsData failed: " << ret;
-      return nullptr;
+      return RET_NOT_SUPPORT;
     }
     // we don't need to restore tensor for copy data
     ret = CopyConstTensorData(in_tensors, op_type);
     if (ret != RET_OK) {
       MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
-      return nullptr;
+      return RET_NOT_SUPPORT;
     }
   }
-  auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter);
-  if (kernel != nullptr) {
+  ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, cpu_desc, op_parameter, kernel);
+  if (ret == RET_OK) {
     MS_LOG(DEBUG) << "Get TypeId(" << kernel_data_type << ") op success: " << PrimitiveCurVersionTypeName(op_type);
     FreeRestoreTensors(&restored_origin_tensors);
   } else {
     RestoreTensorData(&restored_origin_tensors);
   }
-  return kernel;
+  return ret;
 }  // namespace mindspore::lite
 
-kernel::LiteKernel *Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_tensors,
-                                             const std::vector<Tensor *> &out_tensors, OpParameter *op_parameter,
-                                             const kernel::KernelKey &desc) {
+int Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
+                             OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel) {
   MS_ASSERT(op_parameter != nullptr);
 
   if (context_->IsGpuEnabled()) {
@@ -402,30 +401,27 @@ kernel::LiteKernel *Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_ten
     auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32);
     if (ret != RET_OK) {
       MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
-      return nullptr;
+      return RET_NOT_SUPPORT;
     }
-    // we don't need to restore tensor for copy data
     ret = CopyConstTensorData(in_tensors, op_parameter->type_);
     if (ret != RET_OK) {
       MS_LOG(DEBUG) << "CopyConstTensorsData failed: " << ret;
-      return nullptr;
+      return RET_NOT_SUPPORT;
    }
-    auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, gpu_desc, op_parameter);
-    if (kernel != nullptr) {
+    ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, gpu_desc, op_parameter, kernel);
+    if (ret == RET_OK) {
      MS_LOG(DEBUG) << "Get gpu op success: " << PrimitiveCurVersionTypeName(gpu_desc.type);
     } else {
       MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(gpu_desc.type);
     }
-    return kernel;
-  } else {
-    return nullptr;
+    return ret;
   }
+  return RET_NOT_SUPPORT;
 }
 
-kernel::LiteKernel *Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_tensors,
-                                             const std::vector<Tensor *> &out_tensors, OpParameter *op_parameter,
-                                             const kernel::KernelKey &desc) {
+int Scheduler::FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
+                             OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel) {
   MS_ASSERT(op_parameter != nullptr);
   kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type};
   if (context_->IsNpuEnabled()) {
@@ -435,23 +431,22 @@
     auto ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32);
     if (ret != RET_OK) {
       MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
-      return nullptr;
+      return RET_NOT_SUPPORT;
     }
     for (auto tensor : in_tensors) {
       if (tensor->data_type() == kNumberTypeFloat16) {
         tensor->set_data_type(kNumberTypeFloat32);
       }
     }
-    auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, npu_desc, op_parameter);
-    if (kernel != nullptr) {
+    ret = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, context_, npu_desc, op_parameter, kernel);
+    if (ret == RET_OK) {
       MS_LOG(DEBUG) << "Get npu op success: " << PrimitiveCurVersionTypeName(npu_desc.type);
     } else {
       MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(npu_desc.type);
     }
-    return kernel;
-  } else {
-    return nullptr;
+    return ret;
   }
+  return RET_NOT_SUPPORT;
 }
 
 kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in_tensors,
@@ -469,55 +464,62 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
   bool infer_shape_interrupt = !op_parameter->infer_flag_;
   kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(op_parameter->type_)};
   kernel::LiteKernel *kernel = nullptr;
+  int status;
 #ifdef SUPPORT_GPU
   // if (node->device_type_ == DT_GPU || node->device_type_ == DEFAULT) {
-  kernel = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc);
-  if (kernel != nullptr) {
+  status = FindGpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel);
+  if (status == RET_OK) {
     return kernel;
   } else {
     MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
                   << node->name_;
-    auto ret = InferNodeShape(node, &infer_shape_interrupt);
-    if (ret == RET_INFER_INVALID || ret == RET_OK) {
-      op_parameter = op_parameters_[node->output_indices_.at(0)];
-    } else {
-      MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
-      return nullptr;
+    if (status == RET_ERROR) {
+      auto ret = InferNodeShape(node, &infer_shape_interrupt);
+      if (ret == RET_INFER_INVALID || ret == RET_OK) {
+        op_parameter = op_parameters_[node->output_indices_.at(0)];
+      } else {
+        MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
+        return nullptr;
+      }
     }
   }
   // }
 #endif
 #ifdef SUPPORT_NPU
   // if (node->device_type_ == DT_NPU || node->device_type_ == DEFAULT) {
-  kernel = FindNpuKernel(in_tensors, out_tensors, op_parameter, desc);
-  if (kernel != nullptr) {
+  status = FindNpuKernel(in_tensors, out_tensors, op_parameter, desc, &kernel);
+  if (status == RET_OK) {
     return kernel;
   } else {
     MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
                   << node->name_;
-    auto ret = InferNodeShape(node, &infer_shape_interrupt);
-    if (ret == RET_INFER_INVALID || ret == RET_OK) {
-      op_parameter = op_parameters_[node->output_indices_.at(0)];
-    } else {
-      MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
-      return nullptr;
+    if (status == RET_ERROR) {
+      auto ret = InferNodeShape(node, &infer_shape_interrupt);
+      if (ret == RET_INFER_INVALID || ret == RET_OK) {
+        op_parameter = op_parameters_[node->output_indices_.at(0)];
+      } else {
+        MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
+        return nullptr;
+      }
     }
   }
   // }
 #endif
   if (prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kTypeUnknown) {
-    kernel = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat16);
-    if (kernel != nullptr) {
+    status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat16, &kernel);
+    if (status == RET_OK) {
       return kernel;
     } else {
       MS_LOG(DEBUG) << "Get fp16 op failed, scheduler to cpu: " << PrimitiveCurVersionTypeName(desc.type) << " "
                     << node->name_;
-      auto ret = InferNodeShape(node, &infer_shape_interrupt);
-      if (ret == RET_INFER_INVALID || ret == RET_OK) {
-        op_parameter = op_parameters_[node->output_indices_.at(0)];
-      } else {
-        MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
-        return nullptr;
+      if (status == RET_ERROR) {
+        auto ret = InferNodeShape(node, &infer_shape_interrupt);
+        if (ret == RET_INFER_INVALID || ret == RET_OK) {
+          op_parameter = op_parameters_[node->output_indices_.at(0)];
+        } else {
+          MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
+          return nullptr;
+        }
       }
     }
   }
@@ -526,10 +528,10 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
     desc.data_type = kNumberTypeFloat32;
   }
   if (prefer_data_type == kNumberTypeFloat32 || prefer_data_type == kTypeUnknown) {
-    kernel = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat32);
-    if (kernel != nullptr) {
+    status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat32, &kernel);
+    if (status == RET_OK) {
       return kernel;
-    } else {
+    } else if (status == RET_ERROR) {
       auto ret = InferNodeShape(node, &infer_shape_interrupt);
       if (!(ret == RET_INFER_INVALID || ret == RET_OK)) {
         MS_LOG(ERROR) << "Try repeat infer fail: " << node->name_;
diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h
index 1d69d4519e..67fe1c6159 100644
--- a/mindspore/lite/src/scheduler.h
+++ b/mindspore/lite/src/scheduler.h
@@ -60,12 +60,13 @@ class Scheduler {
   kernel::LiteKernel *FindBackendKernel(const std::vector<Tensor *> &in_tensors,
                                         const std::vector<Tensor *> &out_tensors, const Model::Node *node,
                                         TypeId prefer_data_type = kTypeUnknown);
-  kernel::LiteKernel *FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
-                                    OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type);
-  kernel::LiteKernel *FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
-                                    OpParameter *op_parameter, const kernel::KernelKey &desc);
-  kernel::LiteKernel *FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
-                                    OpParameter *op_parameter, const kernel::KernelKey &desc);
+  int FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
+                    OpParameter *op_parameter, const kernel::KernelKey &desc, TypeId kernel_data_type,
+                    kernel::LiteKernel **kernel);
+  int FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
+                    OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel);
+  int FindNpuKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
+                    OpParameter *op_parameter, const kernel::KernelKey &desc, kernel::LiteKernel **kernel);
   // schedule a partial node to a subgraph_kernel
   kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node);
   // schedule a node to a kernel
diff --git a/mindspore/lite/test/models_caffe_fp16.cfg b/mindspore/lite/test/models_caffe_fp16.cfg
index cce066f536..f4a98f3fb5 100644
--- a/mindspore/lite/test/models_caffe_fp16.cfg
+++ b/mindspore/lite/test/models_caffe_fp16.cfg
@@ -95,7 +95,7 @@ ml_video_edit_img_segment 3
 ml_video_edit_video_segment_gauss_adaptis_part1 5
 # When the input range is [-1,1], the precision is poor, and the output value is very small (10e-5). If the input range is adjusted to [0,255], the precision will decrease to 15.5415%, and the rest is cumulative error.
 ml_handpose 175
-hdc_Face_Aesthetic_MTI_Aesthetic 22
+hdc_Face_Aesthetic_MTI_Aesthetic 0.5
 ml_face_compare 5.5
 ml_face_tracking 2.5
 ml_face_beard 0.5
diff --git a/mindspore/lite/test/models_onnx.cfg b/mindspore/lite/test/models_onnx.cfg
index 4c5f0e313b..91b2316350 100644
--- a/mindspore/lite/test/models_onnx.cfg
+++ b/mindspore/lite/test/models_onnx.cfg
@@ -73,7 +73,7 @@ mtk_face_recognition_v3.onnx
 mtk_face_recognition_v2.onnx
 ml_2012_ocr_detection_tmp.onnx
 ml_video_edit_enhance_update_tmp.onnx
-#Harmony_Voiceprint_resnet18.onnx
+Harmony_Voiceprint_resnet18.onnx;1,150,40,1
 bloom_hongmo_detection_tmp.onnx
 Q_face_recognition.onnx
 Q888_face_recognition.onnx
diff --git a/mindspore/lite/test/models_onnx_fp16.cfg b/mindspore/lite/test/models_onnx_fp16.cfg
index 12ca98d70b..d4da1babc8 100644
--- a/mindspore/lite/test/models_onnx_fp16.cfg
+++ b/mindspore/lite/test/models_onnx_fp16.cfg
@@ -79,7 +79,7 @@ mtk_face_features_v2.onnx;1,256,192,3 0.5
 mtk_face_recognition_v3.onnx 0.5
 mtk_face_recognition_v2.onnx 2.5
 ml_2012_ocr_detection_tmp.onnx 0.5
-#Harmony_Voiceprint_resnet18.onnx;1,1,200,40 4.5
+Harmony_Voiceprint_resnet18.onnx;1,150,40,1 4.5
 bloom_hongmo_detection_tmp.onnx 0.5
 Q_face_recognition.onnx 2
 ml_video_edit_enhance_update_tmp.onnx 0.5
diff --git a/mindspore/lite/test/models_tflite.cfg b/mindspore/lite/test/models_tflite.cfg
index 1fb9767391..015d38e1ed 100644
--- a/mindspore/lite/test/models_tflite.cfg
+++ b/mindspore/lite/test/models_tflite.cfg
@@ -36,7 +36,6 @@ mnasnet_1.3_224.tflite
 inception_v3.tflite
 deeplabv3_257_mv_gpu.tflite
 multi_person_mobilenet_v1_075_float.tflite
-#hiai_vad.tflite
 ide_label_base.tflite
 ide_label_retrained.tflite
 ml_ei_headpose.tflite
@@ -164,8 +163,6 @@ hiai_detectmodel_desnet_256_128_64_32.tflite
 lite-model_aiy_vision_classifier_food_V1_1.tflite
 lite-model_disease-classification_1.tflite
 lite-model_models_mushroom-identification_v1_1.tflite
-#lite-model_albert_lite_base_squadv1_metadata_1.tflite
-#lite-model_mobilebert_1_metadata_1.tflite
 smartreply_1_default_1.tflite
 text_classification.tflite
 Q_detect_fpn_add_inception-1448650.tflite
@@ -183,3 +180,8 @@ Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite
 Q888_face_emo_dress_mv3_orderd.tflite
 Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite.tflite
 Q_iMaxSR_RGB_385_p_pb2tflite.tflite
+bloom_new_detect.tflite
+bloom_model_age_gender.tflite
+bloom_isface.tflite
+hiai_object_detect_814.tflite
+hiai_object_tflite_graph_8bit.tflite
diff --git a/mindspore/lite/test/models_tflite_fp16.cfg b/mindspore/lite/test/models_tflite_fp16.cfg
index b06d681105..2ce4e15a5d 100644
--- a/mindspore/lite/test/models_tflite_fp16.cfg
+++ b/mindspore/lite/test/models_tflite_fp16.cfg
@@ -209,3 +209,9 @@ Q888_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite 2
 Q888_face_emo_dress_mv3_orderd.tflite 2.5
 Q_iMaxDN_RGB_385_p_RGB_RGB_pb2tflite.tflite 1
 Q_iMaxSR_RGB_385_p_pb2tflite.tflite 5
+bloom_new_detect.tflite 3.5
+bloom_model_age_gender.tflite 0.5
+bloom_isface.tflite 0.5
+# The output values of the conv layers range from -e±5 to e±5, which almost reaches the representation limit of fp16.
+# In this range, the fp16 data has a large bias, and the accumulation of this bias lowers the final precision.
+hiai_object_detect_814.tflite 14
diff --git a/mindspore/lite/test/models_with_multiple_inputs.cfg b/mindspore/lite/test/models_with_multiple_inputs.cfg
index 58ce0c934b..865f1e23a9 100644
--- a/mindspore/lite/test/models_with_multiple_inputs.cfg
+++ b/mindspore/lite/test/models_with_multiple_inputs.cfg
@@ -8,6 +8,7 @@ ml_video_edit_img_segment_adaptise_pb2tflite.tflite;2
 ml_video_edit_video_segment_gauss_adaptis_part2.pb;2
 ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2
 decoder.onnx;2;1,7,512:1,7
+# fasterrcnn_crop.pb is the same model as gts_object_detect_Ics.pb.
 fasterrcnn_crop.pb;1;420,630,3
 ml_video_edit_person_divison_video;2
 hdc_tb_cn_neg.tflite;3
@@ -31,4 +32,11 @@ add_uint8.tflite;2
 ml_Heatmap_depth_240180;2
 ml_Heatmap_depth_180240;2
 hiai_nlu_model.pb;3;1,16:1,16:1,16
-gts_object_detect_lcs.pb;1;420,630,3
\ No newline at end of file
+# calib data file on the server is incorrect
+#gts_object_detect_Ics.pb;1;420,630,3
+ml_headpose_pb2tflite.tflite;3;16:1,64,64,3:16
+ml_ei_headpose_pb2tflite.tflite;3;16:1,64,64,3:16
+hiai_transformer_encoder.pb;15
+lite-model_albert_lite_base_squadv1_metadata_1.tflite;3
+lite-model_mobilebert_1_metadata_1.tflite;3
+hiai_vad.tflite;2
diff --git a/mindspore/lite/test/models_with_multiple_inputs_fp16.cfg b/mindspore/lite/test/models_with_multiple_inputs_fp16.cfg
index 71e2ab3036..6c6333a13e 100644
--- a/mindspore/lite/test/models_with_multiple_inputs_fp16.cfg
+++ b/mindspore/lite/test/models_with_multiple_inputs_fp16.cfg
@@ -26,3 +26,6 @@ ml_tts_vocoder.pb;66 53
 # The outputs of two Heatmap_depth models have small value
 ml_Heatmap_depth_240180;2 10 16
 ml_Heatmap_depth_180240;2 7 7
+ml_headpose_pb2tflite.tflite;3;16:1,64,64,3:16 1
+ml_ei_headpose_pb2tflite.tflite;3;16:1,64,64,3:16 0.5
+hiai_transformer_encoder.pb;15 4
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc
index e3477fe2ad..0ef9d85a1e 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc
@@ -63,7 +63,7 @@ void ConvertTensorList(MetaGraphT *graph, uint32_t index, bool *convert_succ, st
   if (!tensorT->data.empty()) {
     int *data = reinterpret_cast<int *>(tensorT->data.data());
     type = TypeId(data[0]);
-    if (tensorT->data.size() < 8 || (data[1] + 2) * 4 != static_cast<int>(tensorT->data.size())) {
+    if (tensorT->data.size() < 8 || (data[1] != 0 && (data[1] + 2) * 4 != static_cast<int>(tensorT->data.size()))) {
      MS_LOG(ERROR) << "tensorlist data length illegal";
       *convert_succ = false;
       return;
@@ -229,29 +229,36 @@ void SetDataType(MetaGraphT *graph, const std::vector<Tensor *> &output_tensors,
     output_tensor->dataType = output_tensors[i]->data_type();
     if (output_tensors[i]->data_type() == kObjectTypeTensorType) {
       auto tensor_list = reinterpret_cast<TensorList *>(output_tensors[i]);
+      if (output_tensor->data.empty()) {
+        output_tensor->data.resize(8, 0);
+      }
       if (tensor_list->tensors_data_type() == kTypeUnknown) {
-        tensors_->at(node->outputIndex[i]).is_infer_ = false;
+        tensors_->at(node->outputIndex[i]).is_inferred_ = false;
+        return;
       }
+      output_tensor->data.at(0) = tensor_list->tensors_data_type();
+    } else if (output_tensors[i]->data_type() == kTypeUnknown) {
+      tensors_->at(node->outputIndex[i]).is_inferred_ = false;
+      return;
     }
+    tensors_->at(node->outputIndex[i]).is_inferred_ = true;
+    return;
   }
 }  // namespace
 
 STATUS InferShapePass::Run(MetaGraphT *graph) {
-  graph_ = graph;
-  InitSearchTensor(graph);
   MS_ASSERT(graph != nullptr);
-  for (auto idx : graph->inputIndex) {
-    auto input_tensor = graph->allTensors[idx].get();
+  InitSearchTensor(graph);
+  for (auto input_idx : graph->inputIndex) {
+    auto input_tensor = graph->allTensors[input_idx].get();
     for (auto &dim : input_tensor->dims) {
       if (dim == 0) {
         MS_LOG(WARNING) << "One dimension of the input shape is 0, which would be set to -1 as a default value.";
         dim = DEFAULT_DIM_VALUE;
       }
     }
-  }
-  for (auto g_input_idx : graph->inputIndex) {
-    auto g_input_shape = graph->allTensors.at(g_input_idx)->dims;
-    if (std::find(g_input_shape.begin(), g_input_shape.end(), -1) != g_input_shape.end() || fmk_type_ == FmkType_TF) {
+    auto input_shape = graph->allTensors.at(input_idx)->dims;
+    if (std::find(input_shape.begin(), input_shape.end(), -1) != input_shape.end() || fmk_type_ == FmkType_TF) {
       infer_interrupt_ = true;
     }
   }
@@ -286,11 +293,11 @@ STATUS InferShapePass::Run(MetaGraphT *graph) {
         auto output_dims = output_tensors[i]->shape();
         auto &output_tensor = graph->allTensors.at(node->outputIndex[i]);
         output_tensor->dims.swap(output_dims);
-        SetDataType(graph_, output_tensors, &tensors_, i, infer_node_index);
+        SetDataType(graph, output_tensors, &tensors_, i, infer_node_index);
       }
     } else if (status == RET_INFER_INVALID) {
       for (size_t i = 0; i < output_tensors.size(); i++) {
-        SetDataType(graph_, output_tensors, &tensors_, i, infer_node_index);
+        SetDataType(graph, output_tensors, &tensors_, i, infer_node_index);
       }
       infer_interrupt_ = true;
     } else {
@@ -300,7 +307,7 @@ STATUS InferShapePass::Run(MetaGraphT *graph) {
       return RET_INFER_ERR;
     }
     FreeTensors(&input_tensors, &output_tensors);
-    AddOutputNode(infer_node_index);
+    AddOutputNodes(graph, infer_node_index);
   }
   return RET_OK;
 }
@@ -313,82 +320,60 @@ void InferShapePass::InitSearchTensor(MetaGraphT *graph) {
     auto node_input_indexes = node->inputIndex;
     // init in_nodes index
     for (size_t j = 0; j < node_input_indexes.size(); j++) {
-      tensors_[node_input_indexes[j]].in_nodes_.push_back(i);
+      tensors_[node_input_indexes[j]].next_nodes_.push_back(i);
     }
     auto node_output_indexes = node->outputIndex;
     for (size_t j = 0; j < node_output_indexes.size(); j++) {
-      tensors_[node_output_indexes[j]].out_nodes_.push_back(i);
-      all_node_output_tensor_indexes.insert(all_node_output_tensor_indexes.end(), node_output_indexes.begin(),
-                                            node_output_indexes.end());
+      tensors_[node_output_indexes[j]].prev_nodes_.push_back(i);
+    }
+    all_node_output_tensor_indexes.insert(all_node_output_tensor_indexes.end(), node_output_indexes.begin(),
+                                          node_output_indexes.end());
+  }
+  for (uint32_t i = 0; i < tensors_.size(); i++) {
+    if (tensors_[i].prev_nodes_.empty() || IsContain(graph->inputIndex, i) || !graph->allTensors.at(i)->data.empty()) {
+      tensors_[i].is_inferred_ = true;
     }
   }
   for (size_t i = 0; i < graph->nodes.size(); i++) {
-    auto &node = graph->nodes[i];
+    auto &node = graph->nodes.at(i);
     if (std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
-                    [&](uint32_t index) { return !IsContain(all_node_output_tensor_indexes, index); })) {
+                    [&](uint32_t idx) { return tensors_[idx].is_inferred_; })) {
       infer_node_indexes_.push_back(i);
     }
   }
-  for (size_t i = 0; i < tensors_.size(); i++) {
-    if (tensors_[i].out_nodes_.empty()) {
-      tensors_[i].is_infer_ = true;
-    }
-  }
 }
 
-void InferShapePass::AddOutputNode(uint32_t infer_node_index) {
-  auto &node = graph_->nodes[infer_node_index];
+void InferShapePass::AddOutputNodes(MetaGraphT *graph, uint32_t infer_node_index) {
+  auto &node = graph->nodes.at(infer_node_index);
   for (size_t i = 0; i < node->outputIndex.size(); i++) {
-    auto output_tensor_node_indexes = tensors_[node->outputIndex[i]].in_nodes_;
-    tensors_[node->outputIndex[i]].is_infer_ = true;
-    for (size_t j = 0; j < output_tensor_node_indexes.size(); j++) {
-      bool flag = false;
-      auto &output_tensor_node = graph_->nodes[output_tensor_node_indexes[j]];
-      for (size_t k = 0; k < output_tensor_node->outputIndex.size(); k++) {
-        if (graph_->allTensors.at(output_tensor_node->outputIndex[k])->dataType != kObjectTypeTensorType) {
-          if (graph_->allTensors.at(output_tensor_node->outputIndex[k])->dataType == kTypeUnknown ||
-              tensors_[output_tensor_node->outputIndex[k]].is_infer_ == false) {
-            flag = true;
-            break;
-          }
-        } else {
-          if (tensors_[output_tensor_node->outputIndex[k]].is_infer_ == false) {
-            flag = true;
-            break;
-          }
-        }
-      }
-      if (flag) {
-        AddNextInferShapeNode(output_tensor_node_indexes, j);
+    auto next_nodes_indexes = tensors_[node->outputIndex[i]].next_nodes_;
+    for (size_t j = 0; j < next_nodes_indexes.size(); j++) {
+      auto &next_node = graph->nodes.at(next_nodes_indexes[j]);
+      if (std::any_of(next_node->outputIndex.begin(), next_node->outputIndex.end(),
+                      [&](uint32_t idx) { return !tensors_[idx].is_inferred_; })) {
+        AddNextInferShapeNode(graph, next_nodes_indexes, j);
      }
    }
  }
 }
 
-void InferShapePass::AddNextInferShapeNode(std::vector<uint32_t> output_tensor_node_indexes, size_t index) {
-  auto &output_tensor_node = graph_->nodes.at(output_tensor_node_indexes[index]);
-  if (find(infer_node_indexes_.begin(), infer_node_indexes_.end(), output_tensor_node_indexes[index]) ==
+void InferShapePass::AddNextInferShapeNode(MetaGraphT *graph, std::vector<uint32_t> next_nodes_indexes, size_t index) {
+  auto &next_node = graph->nodes.at(next_nodes_indexes[index]);
+  if (find(infer_node_indexes_.begin(), infer_node_indexes_.end(), next_nodes_indexes[index]) ==
       infer_node_indexes_.end()) {
-    auto output_tensor_node_type = output_tensor_node->primitive->value.type;
-    if (output_tensor_node_type == schema::PrimitiveType_Merge) {
-      if (std::all_of(output_tensor_node->inputIndex.begin(),
-                      output_tensor_node->inputIndex.begin() + output_tensor_node->inputIndex.size() / 2,
-                      [&](uint32_t k) { return tensors_[k].is_infer_; }) ||
-          std::all_of(output_tensor_node->inputIndex.begin() + output_tensor_node->inputIndex.size() / 2,
-                      output_tensor_node->inputIndex.end(), [&](uint32_t k) { return tensors_[k].is_infer_; })) {
-        infer_node_indexes_.push_back(output_tensor_node_indexes[index]);
-      }
-    } else {
-      bool flag = true;
-      for (size_t i = 0; i < output_tensor_node->inputIndex.size(); i++) {
-        if (!(tensors_[output_tensor_node->inputIndex[i]].is_infer_)) {
-          flag = false;
-          break;
-        }
-      }
-      if (flag) {
-        infer_node_indexes_.push_back(output_tensor_node_indexes[index]);
+    auto next_node_type = next_node->primitive->value.type;
+    if (next_node_type == schema::PrimitiveType_Merge) {
+      if (std::all_of(next_node->inputIndex.begin(), next_node->inputIndex.begin() + next_node->inputIndex.size() / 2,
+                      [&](uint32_t i) { return tensors_[i].is_inferred_; }) ||
+          std::all_of(next_node->inputIndex.begin() + next_node->inputIndex.size() / 2, next_node->inputIndex.end(),
+                      [&](uint32_t i) { return tensors_[i].is_inferred_; })) {
+        infer_node_indexes_.push_back(next_nodes_indexes[index]);
       }
+    } else if (std::all_of(next_node->inputIndex.begin(), next_node->inputIndex.end(),
+                           [&](uint32_t i) { return tensors_[i].is_inferred_; }) ||
+               std::any_of(next_node->inputIndex.begin(), next_node->inputIndex.end(),
+                           [&](uint32_t i) { return graph->allTensors.at(i)->dataType == kObjectTypeTensorType; })) {
+      infer_node_indexes_.push_back(next_nodes_indexes[index]);
     }
   }
 }
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.h
index d5fda95ec3..491d5644c5 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.h
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.h
@@ -32,9 +32,9 @@ namespace mindspore {
 namespace lite {
 struct InferTensor {
-  std::vector<uint32_t> in_nodes_; /* used current tensor as input */
-  std::vector<uint32_t> out_nodes_;
-  bool is_infer_;
+  std::vector<uint32_t> next_nodes_;
+  std::vector<uint32_t> prev_nodes_;
+  bool is_inferred_;
 };
 
 class InferShapePass : public GraphPass {
@@ -45,11 +45,10 @@ class InferShapePass : public GraphPass {
 
  private:
   void InitSearchTensor(MetaGraphT *graph);
-  void AddNextInferShapeNode(std::vector<uint32_t> output_tensor_node_indexes, size_t index);
-  void AddOutputNode(uint32_t infer_node_index);
+  void AddNextInferShapeNode(MetaGraphT *graph, std::vector<uint32_t> next_nodes_indexes, size_t index);
+  void AddOutputNodes(MetaGraphT *graph, uint32_t infer_node_index);
   lite::converter::FmkType fmk_type_ = FmkType_TF;
-  MetaGraphT *graph_ = nullptr;
   std::vector<InferTensor> tensors_ = {};
   std::vector<uint32_t> infer_node_indexes_ = {};
   bool infer_interrupt_ = false;
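
Note on the API change in kernel_registry.cc and scheduler.cc above: kernel lookup no longer hands back a bare pointer. KernelRegistry::GetKernel and the Scheduler::FindCpuKernel/FindGpuKernel/FindNpuKernel helpers now return a status and pass the created kernel out through a kernel::LiteKernel ** argument, so a caller can tell "creator registered but construction failed" (RET_ERROR, worth re-running shape inference and retrying) apart from "nothing registered for this key" (RET_NOT_SUPPORT, fall through to the next backend or data type). A minimal sketch of the resulting caller pattern, assuming only the types used in this patch; TryGetKernel is an illustrative helper, not a function added by the patch:

  // Sketch only: mirrors the status-based lookup that Scheduler::FindBackendKernel performs above.
  kernel::LiteKernel *TryGetKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
                                   const InnerContext *ctx, const kernel::KernelKey &desc, OpParameter *op_parameter) {
    kernel::LiteKernel *kernel = nullptr;
    int status = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, ctx, desc, op_parameter, &kernel);
    if (status == RET_OK) {
      return kernel;  // creator found and kernel constructed
    }
    if (status == RET_ERROR) {
      // A creator exists but construction failed; the scheduler re-runs InferNodeShape
      // at this point before retrying with another data type or backend.
    }
    return nullptr;  // RET_NOT_SUPPORT: no creator registered for this KernelKey
  }

On the TensorList side, the lite_session.cc and infershape_pass.cc hunks both rely on the flat int32 layout implied by the (data[1] + 2) * 4 length check and the data.resize(8, 0) default: word 0 holds the element TypeId and word 1 the element-shape length, followed by the shape values. That is why reading data[0] alone is enough to seed set_tensors_data_type() when converting a schema::Tensor of type kObjectTypeTensorType.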