| @@ -67,7 +67,7 @@ ml_video_edit_v10_best_model_nomean_20200723 8 | |||||
| #hdc_ocr_detect.onnx 30 #too many subgraphs | #hdc_ocr_detect.onnx 30 #too many subgraphs | ||||
| ml_edu_kit_hand_detection.onnx 1 | ml_edu_kit_hand_detection.onnx 1 | ||||
| ml_edu_kit_hand_key_position.onnx 2 | ml_edu_kit_hand_key_position.onnx 2 | ||||
| #ml_video_edit_oneclick_adaptis.pb 2 3 | |||||
| ml_video_edit_oneclick_adaptis.pb 2 3 | |||||
| densenet.tflite 3 | densenet.tflite 3 | ||||
| resnet_v2_101_299.tflite 1 | resnet_v2_101_299.tflite 1 | ||||
| ml_video_edit_enhance.pb 2 | ml_video_edit_enhance.pb 2 | ||||
| @@ -1,9 +1,9 @@ | |||||
| ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2 11 | ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite.tflite;2 11 | ||||
| ml_video_edit_video_segment_gauss_adaptis_part2.pb;2 12.3 | |||||
| ml_video_edit_video_segment_gauss_adaptis_part2.pb;2 11 | |||||
| ml_video_edit_img_segment_adaptise.pb;2 40 | ml_video_edit_img_segment_adaptise.pb;2 40 | ||||
| ml_video_edit_img_segment_adaptise_pb2tflite.tflite;2 0.5 | ml_video_edit_img_segment_adaptise_pb2tflite.tflite;2 0.5 | ||||
| ml_video_edit_person_divison_video;2 38 | ml_video_edit_person_divison_video;2 38 | ||||
| ml_video_edit_oneclick_adaptis.pb;3 6.1 | |||||
| ml_video_edit_oneclick_adaptis.pb;3 6 | |||||
| hdc_tb_cn_neg.tflite;3 281 | hdc_tb_cn_neg.tflite;3 281 | ||||
| decoder_step_201217.pb;5 187 | decoder_step_201217.pb;5 187 | ||||
| ml_video_edit_art_transfer.onnx;3 3 | ml_video_edit_art_transfer.onnx;3 3 | ||||
| @@ -79,8 +79,13 @@ std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) { | |||||
| } | } | ||||
| return cnodes; | return cnodes; | ||||
| } | } | ||||
| ShapeVector GetShapeVectorFromTensorInfo(const tensor::TensorPtr &tensor_info, size_t *offset) { | |||||
| ShapeVector shape_vector; | |||||
| STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, ShapeVector *shape_vector, size_t *offset) { | |||||
| auto data_type = tensor_info->data_type(); | |||||
| if (data_type != kObjectTypeString) { | |||||
| MS_LOG(ERROR) << "This function only used for string tensor."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| shape_vector->clear(); | |||||
| auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c()); | auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c()); | ||||
| std::string shape_str; | std::string shape_str; | ||||
| std::string shape_size_str; | std::string shape_size_str; | ||||
| @@ -93,11 +98,15 @@ ShapeVector GetShapeVectorFromTensorInfo(const tensor::TensorPtr &tensor_info, s | |||||
| } | } | ||||
| shape_size_str.push_back(tensor_data[*offset]); | shape_size_str.push_back(tensor_data[*offset]); | ||||
| } | } | ||||
| if (*offset == 0) { | |||||
| MS_LOG(ERROR) << "string tensor's dim size not found."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| size_t shape_size = std::stoi(shape_size_str); | size_t shape_size = std::stoi(shape_size_str); | ||||
| for (; *offset < tensor_info->Size(); (*offset)++) { | for (; *offset < tensor_info->Size(); (*offset)++) { | ||||
| if (tensor_data[*offset] == ',') { | if (tensor_data[*offset] == ',') { | ||||
| cnt++; | cnt++; | ||||
| shape_vector.push_back(std::stoi(shape_str)); | |||||
| shape_vector->push_back(std::stoi(shape_str)); | |||||
| shape_str.clear(); | shape_str.clear(); | ||||
| } else { | } else { | ||||
| shape_str.push_back(tensor_data[*offset]); | shape_str.push_back(tensor_data[*offset]); | ||||
| @@ -107,8 +116,11 @@ ShapeVector GetShapeVectorFromTensorInfo(const tensor::TensorPtr &tensor_info, s | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| return shape_vector; | |||||
| if (shape_vector->empty()) { | |||||
| MS_LOG(ERROR) << "string tensor's shape shouldn't be empty."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | } | ||||
| schema::Format GetFormatByFmk(int32_t fmk_type) { | schema::Format GetFormatByFmk(int32_t fmk_type) { | ||||
| switch (fmk_type) { | switch (fmk_type) { | ||||
| @@ -124,6 +136,28 @@ schema::Format GetFormatByFmk(int32_t fmk_type) { | |||||
| return static_cast<schema::Format>(fmk_type); | return static_cast<schema::Format>(fmk_type); | ||||
| } | } | ||||
| } | } | ||||
| STATUS GetDataTypeAndShape(const ParameterPtr ¶m_node, TypeId *data_type, ShapeVector *shape_vector) { | |||||
| auto abstract_base = param_node->abstract(); | |||||
| if (abstract_base == nullptr) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name(); | |||||
| return RET_INPUT_TENSOR_ERROR; | |||||
| } | |||||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | |||||
| auto typePtr = abstract_tensor->element()->GetTypeTrack(); | |||||
| MS_ASSERT(typePtr != nullptr); | |||||
| *data_type = typePtr->type_id(); | |||||
| if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) { | |||||
| MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name(); | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| *shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { | void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { | ||||
| @@ -500,7 +534,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee | |||||
| meta_graphT->fmkType = GetValue<int>(fmk); | meta_graphT->fmkType = GetValue<int>(fmk); | ||||
| int ret = ExportSubgraph(func_graph, meta_graphT, keep_graph, copy_primitive); | int ret = ExportSubgraph(func_graph, meta_graphT, keep_graph, copy_primitive); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "ExportSubgraph failed."; | |||||
| MS_LOG(ERROR) << "Export subgraph failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -613,28 +647,22 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> &input_ano | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| schema_tensor->name = param_node->name(); | schema_tensor->name = param_node->name(); | ||||
| auto abstract_base = param_node->abstract(); | |||||
| if (abstract_base == nullptr) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name(); | |||||
| return RET_INPUT_TENSOR_ERROR; | |||||
| } | |||||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | |||||
| auto typePtr = abstract_tensor->element()->GetTypeTrack(); | |||||
| MS_ASSERT(typePtr != nullptr); | |||||
| schema_tensor->dataType = typePtr->type_id(); | |||||
| if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) { | |||||
| MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name(); | |||||
| return RET_PARAM_INVALID; | |||||
| ShapeVector shape_vector; | |||||
| TypeId data_type; | |||||
| auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "get data type and shape from param node failed."; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| schema_tensor->dataType = data_type; | |||||
| auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param()); | auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param()); | ||||
| auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); | |||||
| size_t offset = 0; | size_t offset = 0; | ||||
| if (!shape_vector.empty() && schema_tensor->dataType == kObjectTypeString) { | if (!shape_vector.empty() && schema_tensor->dataType == kObjectTypeString) { | ||||
| shape_vector = GetShapeVectorFromTensorInfo(tensor_info, &offset); | |||||
| status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "get shape vector from string tensor failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | } | ||||
| std::vector<int32_t> dims; | std::vector<int32_t> dims; | ||||
| (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims), | (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims), | ||||
| @@ -32,9 +32,6 @@ STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_t | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| std::vector<int64_t> shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end()); | std::vector<int64_t> shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end()); | ||||
| std::vector<int> shape; | |||||
| std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), | |||||
| [](const int64_t &val) { return static_cast<int32_t>(val); }); | |||||
| auto tensor_info = std::make_shared<tensor::Tensor>(data_type, shape_vector); | auto tensor_info = std::make_shared<tensor::Tensor>(data_type, shape_vector); | ||||
| if (tensor_info == nullptr) { | if (tensor_info == nullptr) { | ||||
| MS_LOG(ERROR) << "new a paramValueLite failed."; | MS_LOG(ERROR) << "new a paramValueLite failed."; | ||||
| @@ -50,6 +50,29 @@ bool IsSpecialType(const CNodePtr &cnode) { | |||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *tensor_info, const CNodePtr &cnode, size_t index) { | |||||
| AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, index); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Abstract of CNode: " << cnode->fullname_with_scope() << " is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) { | |||||
| MS_LOG(DEBUG) << "Abstract of parameter should be abstract tensor"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract); | |||||
| if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) { // input node not complete infershape | |||||
| MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *tensor_info = utils::cast<tensor::TensorPtr>(abstract_tensor->GetValueTrack()); | |||||
| if (*tensor_info == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor::Tensor of abstract is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) { | abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) { | ||||
| @@ -93,6 +116,50 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li | |||||
| return new_abstract; | return new_abstract; | ||||
| } | } | ||||
| STATUS InferShapePass::SetParameterAbstract(const ParameterPtr ¶meter) { | |||||
| MS_ASSERT(parameter != nullptr); | |||||
| auto old_abstract = parameter->abstract(); | |||||
| if (old_abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << parameter->name(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!utils::isa<abstract::AbstractTensorPtr>(old_abstract)) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << parameter->name(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(old_abstract); | |||||
| auto type_ptr = abstract_tensor->element()->GetTypeTrack(); | |||||
| if (type_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "type_ptr is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) { | |||||
| MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << parameter->name(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); | |||||
| std::vector<int32_t> shape; | |||||
| (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), | |||||
| [](const int64_t &value) { return static_cast<int32_t>(value); }); | |||||
| auto new_abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| auto new_tensor_info = std::make_shared<tensor::Tensor>(type_ptr->type_id(), shape_vector); | |||||
| if (parameter->has_default()) { | |||||
| auto old_tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param()); | |||||
| new_tensor_info = lite::CreateTensorInfo(old_tensor_info->data_c(), old_tensor_info->Size(), | |||||
| old_tensor_info->shape(), old_tensor_info->data_type()); | |||||
| if (new_tensor_info == nullptr) { | |||||
| MS_LOG(ERROR) << "create tensor info failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| new_abstract->set_value(new_tensor_info); | |||||
| parameter->set_abstract(new_abstract); | |||||
| return RET_OK; | |||||
| } | |||||
| void InferShapePass::FreeTensors(std::vector<lite::Tensor *> *tensors) { | void InferShapePass::FreeTensors(std::vector<lite::Tensor *> *tensors) { | ||||
| for (auto tensor : *tensors) { | for (auto tensor : *tensors) { | ||||
| delete tensor; | delete tensor; | ||||
| @@ -104,6 +171,12 @@ void InferShapePass::FreeTensors(std::vector<lite::Tensor *> *tensors) { | |||||
| STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *input_tensors) { | STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *input_tensors) { | ||||
| MS_ASSERT(cnode != nullptr); | MS_ASSERT(cnode != nullptr); | ||||
| MS_ASSERT(input_tensors != nullptr); | MS_ASSERT(input_tensors != nullptr); | ||||
| auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is nullptr: " << cnode->fullname_with_scope(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| const int WEIGHT_INDEX = 2; | |||||
| auto inputs = cnode->inputs(); | auto inputs = cnode->inputs(); | ||||
| for (size_t i = 1; i < inputs.size(); ++i) { | for (size_t i = 1; i < inputs.size(); ++i) { | ||||
| auto input = inputs[i]; | auto input = inputs[i]; | ||||
| @@ -117,28 +190,14 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l | |||||
| continue; | continue; | ||||
| } | } | ||||
| AbstractBasePtr abstract = GetCNodeInputAbstract(cnode, i); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Abstract of CNode: " << cnode->fullname_with_scope() << " is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) { | |||||
| MS_LOG(DEBUG) << "Abstract of parameter should be abstract tensor"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract); | |||||
| if (!utils::isa<tensor::TensorPtr>(abstract_tensor->GetValueTrack())) { // input node not complete infershape | |||||
| MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto param_value_lite = utils::cast<tensor::TensorPtr>(abstract_tensor->GetValueTrack()); | |||||
| if (param_value_lite == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor::Tensor of abstract is nullptr"; | |||||
| tensor::TensorPtr tensor_info; | |||||
| auto status = GetTensorInfoFromAbstract(&tensor_info, cnode, i); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "get tensor info failed."; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| std::unique_ptr<lite::Tensor> tensor = nullptr; | std::unique_ptr<lite::Tensor> tensor = nullptr; | ||||
| if (param_value_lite->data_type() != kObjectTypeTensorType) { | |||||
| if (tensor_info->data_type() != kObjectTypeTensorType) { | |||||
| tensor = std::make_unique<lite::Tensor>(); | tensor = std::make_unique<lite::Tensor>(); | ||||
| } else { | } else { | ||||
| tensor = std::make_unique<lite::TensorList>(); | tensor = std::make_unique<lite::TensorList>(); | ||||
| @@ -149,30 +208,36 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l | |||||
| } | } | ||||
| std::vector<int> shape; | std::vector<int> shape; | ||||
| std::transform(param_value_lite->shape().begin(), param_value_lite->shape().end(), std::back_inserter(shape), | |||||
| std::transform(tensor_info->shape().begin(), tensor_info->shape().end(), std::back_inserter(shape), | |||||
| [](const int64_t &value) { return static_cast<int32_t>(value); }); | [](const int64_t &value) { return static_cast<int32_t>(value); }); | ||||
| if (param_value_lite->data_type() != kObjectTypeTensorType) { | |||||
| if (tensor_info->data_type() != kObjectTypeTensorType) { | |||||
| tensor->set_shape(shape); | tensor->set_shape(shape); | ||||
| tensor->set_data_type(param_value_lite->data_type()); | |||||
| tensor->set_data_type(tensor_info->data_type()); | |||||
| if (primitive->GetAttr(opt::kWeightFormat) != nullptr && i == WEIGHT_INDEX) { | |||||
| tensor->set_format(static_cast<schema::Format>(GetValue<int64_t>(primitive->GetAttr(opt::kWeightFormat)))); | |||||
| } else { | |||||
| tensor->set_format(schema::Format::Format_NHWC); | |||||
| } | |||||
| } | } | ||||
| if (utils::isa<ParameterPtr>(input)) { | if (utils::isa<ParameterPtr>(input)) { | ||||
| auto parameter = input->cast<ParameterPtr>(); | auto parameter = input->cast<ParameterPtr>(); | ||||
| if (parameter->has_default()) { | if (parameter->has_default()) { | ||||
| auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param()); | |||||
| if (param_value_lite->data_type() != kObjectTypeTensorType) { | |||||
| auto default_tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param()); | |||||
| if (tensor_info->data_type() != kObjectTypeTensorType) { | |||||
| auto ret = tensor->MallocData(); | auto ret = tensor->MallocData(); | ||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(ERROR) << "Malloc tensor data failed"; | MS_LOG(ERROR) << "Malloc tensor data failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| ret = memcpy_s(tensor->MutableData(), tensor->Size(), tensor_info->data_c(), tensor_info->Size()); | |||||
| ret = | |||||
| memcpy_s(tensor->MutableData(), tensor->Size(), default_tensor_info->data_c(), default_tensor_info->Size()); | |||||
| if (tensor->Size() != 0 && ret != EOK) { | if (tensor->Size() != 0 && ret != EOK) { | ||||
| MS_LOG(ERROR) << "memcpy error: " << ret; | MS_LOG(ERROR) << "memcpy error: " << ret; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } else { | } else { | ||||
| int *data = reinterpret_cast<int *>(tensor_info->data_c()); | |||||
| int *data = reinterpret_cast<int *>(default_tensor_info->data_c()); | |||||
| auto tensor_list = reinterpret_cast<lite::TensorList *>(tensor.get()); | auto tensor_list = reinterpret_cast<lite::TensorList *>(tensor.get()); | ||||
| if (tensor_list->Decode(data) != RET_OK) { | if (tensor_list->Decode(data) != RET_OK) { | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -301,6 +366,10 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { | |||||
| auto node_list = TopoSort(func_graph->get_return()); | auto node_list = TopoSort(func_graph->get_return()); | ||||
| for (auto &node : node_list) { | for (auto &node : node_list) { | ||||
| if (utils::isa<ParameterPtr>(node)) { | if (utils::isa<ParameterPtr>(node)) { | ||||
| int status = SetParameterAbstract(node->cast<ParameterPtr>()); | |||||
| if (status != RET_OK) { | |||||
| return false; | |||||
| } | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (!utils::isa<CNodePtr>(node)) { | if (!utils::isa<CNodePtr>(node)) { | ||||
| @@ -39,6 +39,7 @@ class InferShapePass : public Pass { | |||||
| abstract::AbstractTensorPtr ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor); | abstract::AbstractTensorPtr ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor); | ||||
| STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *input_tensors); | STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *input_tensors); | ||||
| STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors); | STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors); | ||||
| STATUS SetParameterAbstract(const ParameterPtr ¶meter); | |||||
| STATUS SetCNodeAbstract(const std::vector<lite::Tensor *> &output_tensors, const std::shared_ptr<CNode> &cnode); | STATUS SetCNodeAbstract(const std::vector<lite::Tensor *> &output_tensors, const std::shared_ptr<CNode> &cnode); | ||||
| int StrIsContain(const std::vector<std::string> &total, const std::string &aim); | int StrIsContain(const std::vector<std::string> &total, const std::string &aim); | ||||
| int SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph); | int SetSubGraphInputsAbstract(const CNodePtr &cnode, const FuncGraphPtr &func_graph); | ||||