| @@ -38,9 +38,12 @@ int TopKInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o | |||||
| if (!parameter->infer_flag_) { | if (!parameter->infer_flag_) { | ||||
| return NNACL_INFER_INVALID; | return NNACL_INFER_INVALID; | ||||
| } | } | ||||
| const TensorC *input_k_tensor = inputs[1]; | |||||
| if (input_k_tensor->data_ == NULL) { | |||||
| return NNACL_INFER_INVALID; | |||||
| } | |||||
| TopkParameter *param = (TopkParameter *)parameter; | TopkParameter *param = (TopkParameter *)parameter; | ||||
| const TensorC *input_k_tensor = inputs[1]; | |||||
| param->k_ = ((int32_t *)input_k_tensor->data_)[0]; | param->k_ = ((int32_t *)input_k_tensor->data_)[0]; | ||||
| int out_shape[MAX_SHAPE_SIZE]; | int out_shape[MAX_SHAPE_SIZE]; | ||||
| @@ -75,12 +75,14 @@ class LiteModel : public Model { | |||||
| } else { | } else { | ||||
| node->name_ = c_node->name()->c_str(); | node->name_ = c_node->name()->c_str(); | ||||
| } | } | ||||
| auto count = c_node->inputIndex()->size(); | |||||
| for (uint32_t j = 0; j < count; ++j) { | |||||
| node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j))); | |||||
| if (c_node->inputIndex() != nullptr) { | |||||
| auto count = c_node->inputIndex()->size(); | |||||
| for (uint32_t j = 0; j < count; ++j) { | |||||
| node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j))); | |||||
| } | |||||
| } | } | ||||
| if (c_node->outputIndex() != nullptr) { | if (c_node->outputIndex() != nullptr) { | ||||
| count = c_node->outputIndex()->size(); | |||||
| auto count = c_node->outputIndex()->size(); | |||||
| for (uint32_t j = 0; j < count; ++j) { | for (uint32_t j = 0; j < count; ++j) { | ||||
| node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j))); | node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j))); | ||||
| } | } | ||||
| @@ -247,6 +247,7 @@ if(ENABLE_CONVERTER) | |||||
| ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc | ||||
| ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc | ||||
| ${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc | ||||
| ${LITE_DIR}/tools/optimizer/fusion/mul_add_fusion.cc | |||||
| ${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc | ||||
| ${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc | ||||
| ${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc | ${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc | ||||
| @@ -73,6 +73,40 @@ tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std | |||||
| return tensor_info; | return tensor_info; | ||||
| } | } | ||||
| AbstractBasePtr CreateTensorAbstract(const std::vector<int64_t> &shape, TypeId data_type) { | |||||
| auto tensor_info = CreateTensorInfo(nullptr, 0, shape, data_type); | |||||
| if (tensor_info == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor info failed"; | |||||
| return nullptr; | |||||
| } | |||||
| auto abstract = tensor_info->ToAbstract(); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return nullptr; | |||||
| } | |||||
| return abstract; | |||||
| } | |||||
| int SetParameterAbstractAndParam(const ParameterPtr ¶meter, const void *data, size_t data_size, | |||||
| const std::vector<int64_t> &shape, TypeId data_type) { | |||||
| if (parameter == nullptr) { | |||||
| MS_LOG(ERROR) << "Input parameter is nullptr"; | |||||
| return RET_INPUT_PARAM_INVALID; | |||||
| } | |||||
| auto tensor_info = CreateTensorInfo(data, data_size, shape, data_type); | |||||
| if (tensor_info == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor info failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto abstract = tensor_info->ToAbstract(); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| parameter->set_abstract(abstract); | |||||
| return RET_OK; | |||||
| } | |||||
| int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size) { | int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size) { | ||||
| if (tensor_info == nullptr) { | if (tensor_info == nullptr) { | ||||
| MS_LOG(ERROR) << "tensor info is nullptr."; | MS_LOG(ERROR) << "tensor info is nullptr."; | ||||
| @@ -46,6 +46,11 @@ std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT> | |||||
| tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std::vector<int64_t> &shape, | tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std::vector<int64_t> &shape, | ||||
| TypeId data_type); | TypeId data_type); | ||||
| AbstractBasePtr CreateTensorAbstract(const std::vector<int64_t> &shape, TypeId data_type); | |||||
| int SetParameterAbstractAndParam(const ParameterPtr ¶meter, const void *data, size_t data_size, | |||||
| const std::vector<int64_t> &shape, TypeId data_type); | |||||
| int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size); | int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size); | ||||
| std::unique_ptr<schema::TensorT> CreateTensorTFromTensorInfo(const tensor::TensorPtr &tensor_info, | std::unique_ptr<schema::TensorT> CreateTensorTFromTensorInfo(const tensor::TensorPtr &tensor_info, | ||||
| @@ -54,6 +54,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ../optimizer/fusion/tf_bidirection_gru_fusion.cc | ../optimizer/fusion/tf_bidirection_gru_fusion.cc | ||||
| ../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc | ../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc | ||||
| ../optimizer/fusion/matmul_add_fusion.cc | ../optimizer/fusion/matmul_add_fusion.cc | ||||
| ../optimizer/fusion/mul_add_fusion.cc | |||||
| ../optimizer/fusion/gelu_fusion.cc | ../optimizer/fusion/gelu_fusion.cc | ||||
| ../optimizer/fusion/tf_gelu_fusion.cc | ../optimizer/fusion/tf_gelu_fusion.cc | ||||
| ../optimizer/fusion/onnx_gelu_fusion.cc | ../optimizer/fusion/onnx_gelu_fusion.cc | ||||
| @@ -70,7 +70,7 @@ MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &flag) { | |||||
| } | } | ||||
| MS_LOG(INFO) << "Run anfTransform success"; | MS_LOG(INFO) << "Run anfTransform success"; | ||||
| // protobuf -> flatbuf | |||||
| // protobuf -> flatbuffer | |||||
| auto meta_graph = Export(graph, false, false, flag->trainModel); | auto meta_graph = Export(graph, false, false, flag->trainModel); | ||||
| if (meta_graph == nullptr) { | if (meta_graph == nullptr) { | ||||
| MS_LOG(ERROR) << "Export to meta graph return nullptr"; | MS_LOG(ERROR) << "Export to meta graph return nullptr"; | ||||
| @@ -39,7 +39,6 @@ | |||||
| using std::string; | using std::string; | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() { | std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() { | ||||
| std::vector<schema::CNodeT *> old_nodes{}; | std::vector<schema::CNodeT *> old_nodes{}; | ||||
| old_nodes.resize(graph_defT_->nodes.size()); | old_nodes.resize(graph_defT_->nodes.size()); | ||||
| @@ -71,54 +70,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| } | } | ||||
| } | } | ||||
| // generate and infer quant parameters | |||||
| { | |||||
| Optimizer infer_quant_param_pass; | |||||
| infer_quant_param_pass.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| infer_quant_param_pass.AddPass(new (std::nothrow) InferQuantParamPass()); | |||||
| status = infer_quant_param_pass.Run(graph_defT_); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run infer_quant_param_pass graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| } | |||||
| { | |||||
| // format transform | |||||
| // init old node indices | |||||
| auto old_nodes = GetGraphNodes(); | |||||
| Optimizer format_trans_optimizer; | |||||
| format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||||
| format_trans_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| if (ctx.fmk != converter::FmkType_TF) { | |||||
| auto infer_shape_pass = new (std::nothrow) InferShapePass(); | |||||
| if (infer_shape_pass == nullptr) { | |||||
| MS_LOG(ERROR) << "new InferShapePass failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| infer_shape_pass->set_fmk_type(ctx.fmk); | |||||
| format_trans_optimizer.AddPass(infer_shape_pass); | |||||
| } | |||||
| status = format_trans_optimizer.Run(graph_defT_); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||||
| MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| } | |||||
| { | |||||
| // init old node indices | |||||
| auto old_nodes = GetGraphNodes(); | |||||
| Optimizer format_trans_optimizer; | |||||
| format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||||
| format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||||
| status = format_trans_optimizer.Run(graph_defT_); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||||
| MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| } | |||||
| // format transpose global optimize | |||||
| { | { | ||||
| // init old node indices | // init old node indices | ||||
| auto old_nodes = GetGraphNodes(); | auto old_nodes = GetGraphNodes(); | ||||
| @@ -134,20 +86,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| } | } | ||||
| } | } | ||||
| // postconvert pass | |||||
| { | |||||
| // node replace | |||||
| if (!ctx.trainModel) { | |||||
| // init old node indices | // init old node indices | ||||
| auto old_nodes = GetGraphNodes(); | auto old_nodes = GetGraphNodes(); | ||||
| Optimizer replace_optimizer; | Optimizer replace_optimizer; | ||||
| if (!ctx.trainModel) { | |||||
| auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); | |||||
| if (batch_norm_scale_pass == nullptr) { | |||||
| MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| batch_norm_scale_pass->SetFmk(ctx.fmk); | |||||
| replace_optimizer.AddPass(batch_norm_scale_pass); | |||||
| } | |||||
| replace_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk)); | |||||
| replace_optimizer.AddPass(new (std::nothrow) BatchNormConvertScalePass(ctx.fmk)); | |||||
| replace_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | replace_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | ||||
| replace_optimizer.AddPass(new SubgraphNodePass(old_nodes)); | replace_optimizer.AddPass(new SubgraphNodePass(old_nodes)); | ||||
| status = replace_optimizer.Run(graph_defT_); | status = replace_optimizer.Run(graph_defT_); | ||||
| @@ -157,6 +102,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| } | } | ||||
| } | } | ||||
| // node fusion | |||||
| { | { | ||||
| // init old node indices | // init old node indices | ||||
| auto old_nodes = GetGraphNodes(); | auto old_nodes = GetGraphNodes(); | ||||
| @@ -171,19 +117,14 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| } | } | ||||
| } | } | ||||
| // do quantization | |||||
| // quantization | |||||
| if (ctx.fmk != converter::FmkType_TF) { | if (ctx.fmk != converter::FmkType_TF) { | ||||
| // init old node indices | // init old node indices | ||||
| auto old_nodes = GetGraphNodes(); | auto old_nodes = GetGraphNodes(); | ||||
| Optimizer tensor_quant_optimizer; | Optimizer tensor_quant_optimizer; | ||||
| tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | ||||
| auto infer_shape_pass = new (std::nothrow) InferShapePass(); | |||||
| if (infer_shape_pass == nullptr) { | |||||
| MS_LOG(ERROR) << "new InferShapePass failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| infer_shape_pass->set_fmk_type(ctx.fmk); | |||||
| tensor_quant_optimizer.AddPass(infer_shape_pass); | |||||
| tensor_quant_optimizer.AddPass(new (std::nothrow) InferQuantParamPass()); | |||||
| tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk)); | |||||
| tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass()); | tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass()); | ||||
| tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | ||||
| status = tensor_quant_optimizer.Run(graph_defT_); | status = tensor_quant_optimizer.Run(graph_defT_); | ||||
| @@ -193,38 +134,17 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| } | } | ||||
| } | } | ||||
| // insert quantNode and deQuantNode | |||||
| // quantization | |||||
| if (ctx.fmk != converter::FmkType_TF) { | if (ctx.fmk != converter::FmkType_TF) { | ||||
| // init old node indices | // init old node indices | ||||
| auto old_nodes = GetGraphNodes(); | |||||
| Optimizer quant_node_optimizer; | Optimizer quant_node_optimizer; | ||||
| quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||||
| quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | ||||
| auto infer_shape_pass = new (std::nothrow) InferShapePass(); | |||||
| if (infer_shape_pass == nullptr) { | |||||
| MS_LOG(ERROR) << "new InferShapePass failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| infer_shape_pass->set_fmk_type(ctx.fmk); | |||||
| quant_node_optimizer.AddPass(infer_shape_pass); | |||||
| status = quant_node_optimizer.Run(graph_defT_); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| auto old_nodes2 = GetGraphNodes(); | |||||
| quant_node_optimizer.AddPass(new (std::nothrow) InferQuantParamPass()); | |||||
| auto dtype_trans_pass = new (std::nothrow) DTypeTransPass(); | |||||
| if (dtype_trans_pass == nullptr) { | |||||
| MS_LOG(ERROR) << "new dtype_trans_pass failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| dtype_trans_pass->set_input_data_dtype(ctx.inputDataType); | |||||
| dtype_trans_pass->set_output_data_dtype(ctx.outputDataType); | |||||
| quant_node_optimizer.AddPass(dtype_trans_pass); | |||||
| auto old_nodes = GetGraphNodes(); | |||||
| quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk)); | |||||
| quant_node_optimizer.AddPass(new (std::nothrow) DTypeTransPass(ctx.inputDataType, ctx.outputDataType)); | |||||
| quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); | quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); | ||||
| quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | ||||
| quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2)); | |||||
| quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||||
| status = quant_node_optimizer.Run(graph_defT_); | status = quant_node_optimizer.Run(graph_defT_); | ||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | if (status != RET_OK && status != RET_NO_CHANGE) { | ||||
| MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed"; | MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed"; | ||||
| @@ -232,7 +152,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| } | } | ||||
| } | } | ||||
| // switch pass | |||||
| // controlflow pass | |||||
| { | { | ||||
| // init old node indices | // init old node indices | ||||
| auto old_nodes = GetGraphNodes(); | auto old_nodes = GetGraphNodes(); | ||||
| @@ -240,6 +160,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| switch_optimizer.AddPass(new (std::nothrow) SwitchPass()); | switch_optimizer.AddPass(new (std::nothrow) SwitchPass()); | ||||
| switch_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | switch_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | ||||
| switch_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | switch_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | ||||
| switch_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass()); | |||||
| status = switch_optimizer.Run(graph_defT_); | status = switch_optimizer.Run(graph_defT_); | ||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | if (status != RET_OK && status != RET_NO_CHANGE) { | ||||
| MS_LOG(ERROR) << "Run switch_optimizer Failed"; | MS_LOG(ERROR) << "Run switch_optimizer Failed"; | ||||
| @@ -247,34 +168,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| } | } | ||||
| } | } | ||||
| // subgraph tensor pass | |||||
| { | |||||
| Optimizer subgraph_tensor_optimizer; | |||||
| subgraph_tensor_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass()); | |||||
| status = subgraph_tensor_optimizer.Run(graph_defT_); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run subgraph tensor pass Failed"; | |||||
| return status; | |||||
| } | |||||
| } | |||||
| // tensor name | |||||
| { | |||||
| // init old node indices | |||||
| auto old_nodes = GetGraphNodes(); | |||||
| Optimizer name_optimizer; | |||||
| name_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||||
| name_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| name_optimizer.AddPass(new (std::nothrow) TensorNamePass()); | |||||
| status = name_optimizer.Run(graph_defT_); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run name_optimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| } | |||||
| { | { | ||||
| Optimizer nested_loop_optimizer; | Optimizer nested_loop_optimizer; | ||||
| auto old_nodes = GetGraphNodes(); | |||||
| nested_loop_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); | |||||
| nested_loop_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| nested_loop_optimizer.AddPass(new (std::nothrow) NestedLoopExpandPass()); | nested_loop_optimizer.AddPass(new (std::nothrow) NestedLoopExpandPass()); | ||||
| status = nested_loop_optimizer.Run(graph_defT_); | status = nested_loop_optimizer.Run(graph_defT_); | ||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | if (status != RET_OK && status != RET_NO_CHANGE) { | ||||
| @@ -284,30 +182,16 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| } | } | ||||
| { | { | ||||
| Optimizer quant_param_optimizer; | |||||
| quant_param_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); | |||||
| status = quant_param_optimizer.Run(graph_defT_); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "Run quant_param_optimizer graphPasses Failed"; | |||||
| return status; | |||||
| } | |||||
| } | |||||
| { | |||||
| Optimizer infer_shape_optimizer; | |||||
| auto infer_shape_pass = new (std::nothrow) InferShapePass(); | |||||
| if (infer_shape_pass == nullptr) { | |||||
| MS_LOG(ERROR) << "new InferShapePass failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| infer_shape_pass->set_fmk_type(ctx.fmk); | |||||
| infer_shape_optimizer.AddPass(infer_shape_pass); | |||||
| status = infer_shape_optimizer.Run(graph_defT_); | |||||
| Optimizer forming_model_optimizer; | |||||
| forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk)); | |||||
| forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); | |||||
| forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass()); | |||||
| status = forming_model_optimizer.Run(graph_defT_); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Run InferShapeOptimizer graphPasses Failed."; | MS_LOG(ERROR) << "Run InferShapeOptimizer graphPasses Failed."; | ||||
| return status; | return status; | ||||
| } | } | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } // namespace mindspore::lite | |||||
| } | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -36,14 +36,12 @@ struct BNWeightTensors { | |||||
| }; | }; | ||||
| class BatchNormConvertScalePass : public GraphPass { | class BatchNormConvertScalePass : public GraphPass { | ||||
| public: | public: | ||||
| BatchNormConvertScalePass() = default; | |||||
| explicit BatchNormConvertScalePass(converter::FmkType fmk) : fmkType(fmk) {} | |||||
| ~BatchNormConvertScalePass() = default; | ~BatchNormConvertScalePass() = default; | ||||
| STATUS Run(MetaGraphT *graph) override; | STATUS Run(MetaGraphT *graph) override; | ||||
| void SetFmk(converter::FmkType fmk) { this->fmkType = fmk; } | |||||
| protected: | protected: | ||||
| STATUS GetTransParam(MetaGraphT *graph, const std::unique_ptr<CNodeT> &bnNode); | STATUS GetTransParam(MetaGraphT *graph, const std::unique_ptr<CNodeT> &bnNode); | ||||
| @@ -276,10 +276,5 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte | |||||
| return InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num, | return InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num, | ||||
| castOpCopyer); | castOpCopyer); | ||||
| } | } | ||||
| void DTypeTransPass::set_input_data_dtype(TypeId data_type) { this->input_data_dtype = data_type; } | |||||
| void DTypeTransPass::set_output_data_dtype(TypeId data_type) { this->output_data_dtype = data_type; } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -30,16 +30,13 @@ enum DTypeTransNodeType { kInt8ToFP32, kFP32ToInt8, kUInt8ToInt8, kInt8ToUInt8 } | |||||
| class DTypeTransPass : public GraphPass { | class DTypeTransPass : public GraphPass { | ||||
| public: | public: | ||||
| DTypeTransPass() : id_(0) {} | |||||
| DTypeTransPass(TypeId model_input_data_type, TypeId model_output_data_type) | |||||
| : id_(0), input_data_dtype(model_input_data_type), output_data_dtype(model_output_data_type) {} | |||||
| ~DTypeTransPass() override = default; | ~DTypeTransPass() override = default; | ||||
| STATUS Run(schema::MetaGraphT *graph) override; | STATUS Run(schema::MetaGraphT *graph) override; | ||||
| void set_input_data_dtype(TypeId data_type); | |||||
| void set_output_data_dtype(TypeId dataType); | |||||
| private: | private: | ||||
| STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph); | STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph); | ||||
| @@ -39,14 +39,10 @@ struct InferTensor { | |||||
| class InferShapePass : public GraphPass { | class InferShapePass : public GraphPass { | ||||
| public: | public: | ||||
| InferShapePass() = default; | |||||
| ~InferShapePass() = default; | |||||
| explicit InferShapePass(converter::FmkType fmk_type) : fmk_type_(fmk_type) {} | |||||
| ~InferShapePass() override = default; | |||||
| STATUS Run(MetaGraphT *graph) override; | STATUS Run(MetaGraphT *graph) override; | ||||
| void set_fmk_type(converter::FmkType fmk_type) { this->fmk_type_ = fmk_type; } | |||||
| private: | private: | ||||
| void InitSearchTensor(MetaGraphT *graph); | void InitSearchTensor(MetaGraphT *graph); | ||||
| void AddNextInferShapeNode(std::vector<uint32_t> output_tensor_node_indexes, size_t index); | void AddNextInferShapeNode(std::vector<uint32_t> output_tensor_node_indexes, size_t index); | ||||
| @@ -34,8 +34,28 @@ class ModelParser { | |||||
| virtual ~ModelParser() = default; | virtual ~ModelParser() = default; | ||||
| virtual FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) = 0; | |||||
| FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, const QuantType &quant_type) { | |||||
| auto ret = ParseToFuncGraph(model_file, weight_file, quant_type); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Parse to func graph failed : " << ret; | |||||
| return nullptr; | |||||
| } | |||||
| ret = PostAdjust(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Adjust func graph failed : " << ret; | |||||
| return nullptr; | |||||
| } | |||||
| return this->res_graph_; | |||||
| } | |||||
| protected: | |||||
| virtual int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) = 0; | |||||
| virtual int PostAdjust() = 0; | |||||
| protected: | |||||
| FuncGraphPtr res_graph_ = nullptr; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include <vector> | #include <vector> | ||||
| #include "tools/common/tensor_util.h" | |||||
| #include "tools/converter/ops/while.h" | #include "tools/converter/ops/while.h" | ||||
| #include "utils/check_convert_utils.h" | #include "utils/check_convert_utils.h" | ||||
| #include "abstract/primitive_infer_map.h" | #include "abstract/primitive_infer_map.h" | ||||
| @@ -55,7 +56,9 @@ AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP | |||||
| AbstractBasePtrList output; | AbstractBasePtrList output; | ||||
| for (int64_t i = 0; i < (int64_t)input_args.size(); i++) { | for (int64_t i = 0; i < (int64_t)input_args.size(); i++) { | ||||
| auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape]; | auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape]; | ||||
| output.push_back(std::make_shared<abstract::AbstractTensor>(input_args[i]->BuildType(), shape)); | |||||
| auto abstract_tensor = lite::CreateTensorAbstract(shape, input_args[i]->BuildType()->type_id()); | |||||
| MS_EXCEPTION_IF_NULL(abstract_tensor); | |||||
| output.push_back(abstract_tensor); | |||||
| } | } | ||||
| return std::make_shared<abstract::AbstractTuple>(output); | return std::make_shared<abstract::AbstractTuple>(output); | ||||
| } | } | ||||
| @@ -41,34 +41,34 @@ CaffeModelParser::CaffeModelParser() = default; | |||||
| CaffeModelParser::~CaffeModelParser() = default; | CaffeModelParser::~CaffeModelParser() = default; | ||||
| FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) { | |||||
| int CaffeModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) { | |||||
| STATUS status = InitOriginModel(model_file, weight_file); | STATUS status = InitOriginModel(model_file, weight_file); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| func_graph_ptr_ = std::make_shared<FuncGraph>(); | |||||
| res_graph_ = std::make_shared<FuncGraph>(); | |||||
| status = ConvertGraphInputs(); | status = ConvertGraphInputs(); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| status = ConvertLayers(); | status = ConvertLayers(); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| status = ConvertGraphOutputs(); | status = ConvertGraphOutputs(); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph")); | |||||
| func_graph_ptr_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE))); | |||||
| return func_graph_ptr_; | |||||
| res_graph_->set_attr("graph_name", MakeValue("main_graph")); | |||||
| res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE))); | |||||
| return RET_OK; | |||||
| } | } | ||||
| STATUS CaffeModelParser::ConvertLayers() { | STATUS CaffeModelParser::ConvertLayers() { | ||||
| @@ -134,7 +134,7 @@ STATUS CaffeModelParser::ConvertLayers() { | |||||
| std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitive_c))}; | std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitive_c))}; | ||||
| op_inputs.insert(op_inputs.end(), input_nodes.begin(), input_nodes.end()); | op_inputs.insert(op_inputs.end(), input_nodes.begin(), input_nodes.end()); | ||||
| op_inputs.insert(op_inputs.end(), const_parameters.begin(), const_parameters.end()); | op_inputs.insert(op_inputs.end(), const_parameters.begin(), const_parameters.end()); | ||||
| auto new_cnode = func_graph_ptr_->NewCNode(op_inputs); | |||||
| auto new_cnode = res_graph_->NewCNode(op_inputs); | |||||
| new_cnode->set_fullname_with_scope(layer.name()); | new_cnode->set_fullname_with_scope(layer.name()); | ||||
| // convert outputs | // convert outputs | ||||
| @@ -194,14 +194,17 @@ STATUS CaffeModelParser::ConvertGraphInputs() { | |||||
| for (int i = 0; i < caffe_model_.layer_size(); i++) { | for (int i = 0; i < caffe_model_.layer_size(); i++) { | ||||
| auto layer = caffe_model_.layer(i); | auto layer = caffe_model_.layer(i); | ||||
| if (layer.type() == "Input") { | if (layer.type() == "Input") { | ||||
| auto parameter = func_graph_ptr_->add_parameter(); | |||||
| auto parameter = res_graph_->add_parameter(); | |||||
| std::vector<int64_t> shape; | std::vector<int64_t> shape; | ||||
| for (int j = 0; j < layer.input_param().shape(0).dim_size(); j++) { | for (int j = 0; j < layer.input_param().shape(0).dim_size(); j++) { | ||||
| shape.push_back(layer.input_param().shape(0).dim(j)); | shape.push_back(layer.input_param().shape(0).dim(j)); | ||||
| } | } | ||||
| auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| parameter->set_abstract(abstract_tensor); | |||||
| auto abstract = CreateTensorAbstract(shape, kNumberTypeFloat32); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| parameter->set_abstract(abstract); | |||||
| parameter->set_name("graph_input-" + std::to_string(i)); | parameter->set_name("graph_input-" + std::to_string(i)); | ||||
| nodes_.insert(std::pair(layer.top(0), parameter)); | nodes_.insert(std::pair(layer.top(0), parameter)); | ||||
| } | } | ||||
| @@ -220,10 +223,13 @@ STATUS CaffeModelParser::ConvertGraphInputs() { | |||||
| shape.push_back(caffe_model_.input_dim(j)); | shape.push_back(caffe_model_.input_dim(j)); | ||||
| } | } | ||||
| } | } | ||||
| auto parameter = func_graph_ptr_->add_parameter(); | |||||
| auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| parameter->set_abstract(abstract_tensor); | |||||
| auto parameter = res_graph_->add_parameter(); | |||||
| auto abstract = CreateTensorAbstract(shape, kNumberTypeFloat32); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| parameter->set_abstract(abstract); | |||||
| parameter->set_name("graph_input-" + caffe_model_.input(i)); | parameter->set_name("graph_input-" + caffe_model_.input(i)); | ||||
| nodes_.insert(std::pair(caffe_model_.input(i), parameter)); | nodes_.insert(std::pair(caffe_model_.input(i), parameter)); | ||||
| } | } | ||||
| @@ -234,10 +240,18 @@ STATUS CaffeModelParser::ConvertGraphInputs() { | |||||
| for (int j = 0; j < shape.dim_size(); j++) { | for (int j = 0; j < shape.dim_size(); j++) { | ||||
| shape_vector.push_back(shape.dim(j)); | shape_vector.push_back(shape.dim(j)); | ||||
| } | } | ||||
| auto parameter = func_graph_ptr_->add_parameter(); | |||||
| auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| parameter->set_abstract(abstract_tensor); | |||||
| auto parameter = res_graph_->add_parameter(); | |||||
| auto tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, kNumberTypeFloat32); | |||||
| if (tensor_info == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor info failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto abstract = tensor_info->ToAbstract(); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| parameter->set_abstract(abstract); | |||||
| parameter->set_name("graph_input-" + caffe_model_.input(i)); | parameter->set_name("graph_input-" + caffe_model_.input(i)); | ||||
| nodes_.insert(std::pair(caffe_model_.input(i), parameter)); | nodes_.insert(std::pair(caffe_model_.input(i), parameter)); | ||||
| } | } | ||||
| @@ -265,7 +279,7 @@ STATUS CaffeModelParser::ConvertGraphOutputs() { | |||||
| auto cnode = nodes_.find(output_node)->second; | auto cnode = nodes_.find(output_node)->second; | ||||
| make_tuple_inputs.emplace_back(cnode); | make_tuple_inputs.emplace_back(cnode); | ||||
| } | } | ||||
| auto make_tuple_cnode = func_graph_ptr_->NewCNode(make_tuple_inputs); | |||||
| auto make_tuple_cnode = res_graph_->NewCNode(make_tuple_inputs); | |||||
| make_tuple_cnode->set_fullname_with_scope("return tuple"); | make_tuple_cnode->set_fullname_with_scope("return tuple"); | ||||
| std::vector<AnfNodePtr> op_inputs; | std::vector<AnfNodePtr> op_inputs; | ||||
| @@ -277,9 +291,9 @@ STATUS CaffeModelParser::ConvertGraphOutputs() { | |||||
| auto value_node = NewValueNode(return_prim_ptr); | auto value_node = NewValueNode(return_prim_ptr); | ||||
| op_inputs.emplace_back(value_node); | op_inputs.emplace_back(value_node); | ||||
| op_inputs.emplace_back(make_tuple_cnode); | op_inputs.emplace_back(make_tuple_cnode); | ||||
| auto cnode = func_graph_ptr_->NewCNode(op_inputs); | |||||
| auto cnode = res_graph_->NewCNode(op_inputs); | |||||
| cnode->set_fullname_with_scope("Return"); | cnode->set_fullname_with_scope("Return"); | ||||
| func_graph_ptr_->set_return(cnode); | |||||
| res_graph_->set_return(cnode); | |||||
| } else { | } else { | ||||
| auto returnPrim = std::make_shared<ops::Return>(); | auto returnPrim = std::make_shared<ops::Return>(); | ||||
| if (returnPrim == nullptr) { | if (returnPrim == nullptr) { | ||||
| @@ -298,9 +312,9 @@ STATUS CaffeModelParser::ConvertGraphOutputs() { | |||||
| return RET_NOT_FIND_OP; | return RET_NOT_FIND_OP; | ||||
| } | } | ||||
| opInputs.emplace_back(cnode); | opInputs.emplace_back(cnode); | ||||
| auto returnCnode = func_graph_ptr_->NewCNode(opInputs); | |||||
| auto returnCnode = res_graph_->NewCNode(opInputs); | |||||
| returnCnode->set_fullname_with_scope("Return"); | returnCnode->set_fullname_with_scope("Return"); | ||||
| func_graph_ptr_->set_return(returnCnode); | |||||
| res_graph_->set_return(returnCnode); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -333,7 +347,7 @@ STATUS CaffeModelParser::ConvertBlobs(const caffe::LayerParameter &layer, std::v | |||||
| ConvertShape(layer.blobs(i), &shape); | ConvertShape(layer.blobs(i), &shape); | ||||
| // cal Weight num | // cal Weight num | ||||
| auto parameter = func_graph_ptr_->add_parameter(); | |||||
| auto parameter = res_graph_->add_parameter(); | |||||
| auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); | auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); | ||||
| std::vector<int64_t> shape_vector; | std::vector<int64_t> shape_vector; | ||||
| (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), | (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), | ||||
| @@ -402,17 +416,25 @@ STATUS CaffeModelParser::ConvertBottom(const caffe::LayerParameter &layer, std:: | |||||
| } | } | ||||
| STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CNodePtr &cnode) { | STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CNodePtr &cnode) { | ||||
| auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); | |||||
| std::vector<int64_t> shape_vector; | |||||
| if (layer.top_size() == 1) { | if (layer.top_size() == 1) { | ||||
| cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | |||||
| auto abstract = CreateTensorAbstract({}, kNumberTypeFloat32); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| cnode->set_abstract(abstract); | |||||
| nodes_[layer.top(0)] = cnode; | nodes_[layer.top(0)] = cnode; | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| AbstractBasePtrList abstract_list; | AbstractBasePtrList abstract_list; | ||||
| for (int i = 0; i < layer.top_size(); i++) { | for (int i = 0; i < layer.top_size(); i++) { | ||||
| abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | |||||
| auto abstract = CreateTensorAbstract({}, kNumberTypeFloat32); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| abstract_list.emplace_back(abstract); | |||||
| auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); | auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); | ||||
| if (tuple_get_item_prim_ptr == nullptr) { | if (tuple_get_item_prim_ptr == nullptr) { | ||||
| MS_LOG(ERROR) << "new TupleGetItem failed"; | MS_LOG(ERROR) << "new TupleGetItem failed"; | ||||
| @@ -421,7 +443,7 @@ STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CN | |||||
| auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); | auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); | ||||
| auto get_item_value = NewValueNode(MakeValue<int>(i)); | auto get_item_value = NewValueNode(MakeValue<int>(i)); | ||||
| std::vector<AnfNodePtr> inputs{tuple_get_item_prim, cnode, get_item_value}; | std::vector<AnfNodePtr> inputs{tuple_get_item_prim, cnode, get_item_value}; | ||||
| CNodePtr get_item_cnode = func_graph_ptr_->NewCNode(inputs); | |||||
| CNodePtr get_item_cnode = res_graph_->NewCNode(inputs); | |||||
| get_item_cnode->set_fullname_with_scope(layer.top(i)); | get_item_cnode->set_fullname_with_scope(layer.top(i)); | ||||
| nodes_[layer.top(i)] = get_item_cnode; | nodes_[layer.top(i)] = get_item_cnode; | ||||
| } | } | ||||
| @@ -446,4 +468,6 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name) | |||||
| } | } | ||||
| return layer.name(); | return layer.name(); | ||||
| } | } | ||||
| int CaffeModelParser::PostAdjust() { return RET_OK; } | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -32,8 +32,10 @@ class CaffeModelParser : public ModelParser { | |||||
| ~CaffeModelParser() override; | ~CaffeModelParser() override; | ||||
| FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) override; | |||||
| int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) override; | |||||
| int PostAdjust() override; | |||||
| private: | private: | ||||
| STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file); | STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file); | ||||
| @@ -59,7 +61,6 @@ class CaffeModelParser : public ModelParser { | |||||
| caffe::NetParameter caffe_weight_; | caffe::NetParameter caffe_weight_; | ||||
| std::unordered_map<std::string, caffe::LayerParameter> caffe_layers_; | std::unordered_map<std::string, caffe::LayerParameter> caffe_layers_; | ||||
| std::unordered_map<std::string, AnfNodePtr> nodes_; | std::unordered_map<std::string, AnfNodePtr> nodes_; | ||||
| FuncGraphPtr func_graph_ptr_; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -45,31 +45,31 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = { | |||||
| {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}, | {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}, | ||||
| {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}}; | {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}}; | ||||
| FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) { | |||||
| int OnnxModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) { | |||||
| NotSupportOp::GetInstance()->set_fmk_type("ONNX"); | NotSupportOp::GetInstance()->set_fmk_type("ONNX"); | ||||
| anf_root_graph_ = std::make_shared<FuncGraph>(); | |||||
| res_graph_ = std::make_shared<FuncGraph>(); | |||||
| auto status = InitOriginModel(model_file); | auto status = InitOriginModel(model_file); | ||||
| if (RET_OK != status) { | if (RET_OK != status) { | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| MS_LOG(ERROR) << "init origin model failed."; | MS_LOG(ERROR) << "init origin model failed."; | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| status = ConvertOnnxGraph(onnx_root_graph_, anf_root_graph_, &anf_nodes_map_, {}, "root_node"); | |||||
| status = ConvertOnnxGraph(onnx_root_graph_, res_graph_, &anf_nodes_map_, {}, "root_node"); | |||||
| if (RET_OK != status) { | if (RET_OK != status) { | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| MS_LOG(ERROR) << "convert onnx graph failed."; | MS_LOG(ERROR) << "convert onnx graph failed."; | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| static auto root_func_manager = Manage(anf_root_graph_); | |||||
| static auto root_func_manager = Manage(res_graph_); | |||||
| for (auto &subgraph : all_subgraphs_) { | for (auto &subgraph : all_subgraphs_) { | ||||
| subgraph->set_manager(root_func_manager); | subgraph->set_manager(root_func_manager); | ||||
| subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX))); | subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX))); | ||||
| } | } | ||||
| anf_root_graph_->set_attr("graph_name", MakeValue("main_graph")); | |||||
| anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX))); | |||||
| return anf_root_graph_; | |||||
| res_graph_->set_attr("graph_name", MakeValue("main_graph")); | |||||
| res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX))); | |||||
| return RET_OK; | |||||
| } | } | ||||
| STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) { | STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) { | ||||
| @@ -88,9 +88,9 @@ STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) { | |||||
| OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version()); | OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version()); | ||||
| onnx_root_graph_ = onnx_model_.graph(); | onnx_root_graph_ = onnx_model_.graph(); | ||||
| if (OnnxNodeParser::opset_version() > 15) { | if (OnnxNodeParser::opset_version() > 15) { | ||||
| anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX))); | |||||
| res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX))); | |||||
| } else { | } else { | ||||
| anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX_LOW_VERSION))); | |||||
| res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX_LOW_VERSION))); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -170,13 +170,16 @@ STATUS OnnxModelParser::ConvertGraphInputs(const onnx::GraphProto &onnx_graph, c | |||||
| << static_cast<onnx::TensorProto_DataType>(input_value.type().tensor_type().elem_type()); | << static_cast<onnx::TensorProto_DataType>(input_value.type().tensor_type().elem_type()); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto type_ptr = TypeIdToType(data_type); | |||||
| std::vector<int64_t> shape_vector; | std::vector<int64_t> shape_vector; | ||||
| auto onnx_shape = input_value.type().tensor_type().shape().dim(); | auto onnx_shape = input_value.type().tensor_type().shape().dim(); | ||||
| std::transform(onnx_shape.begin(), onnx_shape.end(), std::back_inserter(shape_vector), | std::transform(onnx_shape.begin(), onnx_shape.end(), std::back_inserter(shape_vector), | ||||
| [](const onnx::TensorShapeProto_Dimension &val) { return static_cast<int64_t>(val.dim_value()); }); | [](const onnx::TensorShapeProto_Dimension &val) { return static_cast<int64_t>(val.dim_value()); }); | ||||
| std::replace(shape_vector.begin(), shape_vector.end(), 0, -1); | std::replace(shape_vector.begin(), shape_vector.end(), 0, -1); | ||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| parameter->set_abstract(abstract_tensor); | parameter->set_abstract(abstract_tensor); | ||||
| parameter->set_name(input_value.name()); | parameter->set_name(input_value.name()); | ||||
| anf_nodes_map->emplace(input_value.name(), parameter); | anf_nodes_map->emplace(input_value.name(), parameter); | ||||
| @@ -490,17 +493,23 @@ STATUS OnnxModelParser::BuildOpOutputs(const onnx::NodeProto &onnx_node, const F | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| if (onnx_node.output_size() == 1) { | if (onnx_node.output_size() == 1) { | ||||
| auto type_ptr = TypeIdToType(kNumberTypeFloat32); | |||||
| std::vector<int64_t> shape_vector; | |||||
| cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | |||||
| auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| cnode->set_abstract(abstract_tensor); | |||||
| anf_nodes_map->emplace(onnx_node.output(0), cnode); | anf_nodes_map->emplace(onnx_node.output(0), cnode); | ||||
| } else { | } else { | ||||
| AbstractBasePtrList abstract_list; | AbstractBasePtrList abstract_list; | ||||
| int op_idx = 0; | int op_idx = 0; | ||||
| for (const auto &output_name : onnx_node.output()) { | for (const auto &output_name : onnx_node.output()) { | ||||
| std::vector<int64_t> shape_vector; | |||||
| auto type_ptr = TypeIdToType(kNumberTypeFloat32); | |||||
| abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | |||||
| auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| abstract_list.emplace_back(abstract_tensor); | |||||
| auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); | auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); | ||||
| if (tuple_get_item_prim_ptr == nullptr) { | if (tuple_get_item_prim_ptr == nullptr) { | ||||
| MS_LOG(ERROR) << "new TupleGetItem failed"; | MS_LOG(ERROR) << "new TupleGetItem failed"; | ||||
| @@ -687,7 +696,11 @@ ParameterPtr CreateConstParamter(const FuncGraphPtr &anf_graph, int val) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto const_node = anf_graph->add_parameter(); | auto const_node = anf_graph->add_parameter(); | ||||
| auto const_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>()); | |||||
| auto const_abstract = CreateTensorAbstract({}, kNumberTypeInt32); | |||||
| if (const_abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return nullptr; | |||||
| } | |||||
| const_node->set_abstract(const_abstract); | const_node->set_abstract(const_abstract); | ||||
| int *tensor_data = new (std::nothrow) int[1]; | int *tensor_data = new (std::nothrow) int[1]; | ||||
| if (tensor_data == nullptr) { | if (tensor_data == nullptr) { | ||||
| @@ -834,9 +847,16 @@ STATUS OnnxModelParser::AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::v | |||||
| for (int i = 0; i < act_output_num; i++) { | for (int i = 0; i < act_output_num; i++) { | ||||
| // tensor_array need as root while input | // tensor_array need as root while input | ||||
| auto while_tensor_array_input = anf_root_graph->add_parameter(); | auto while_tensor_array_input = anf_root_graph->add_parameter(); | ||||
| std::vector<int64_t> shape_vector; | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kTensorType, shape_vector); | |||||
| auto tensor_info = std::make_shared<tensor::Tensor>(kObjectTypeTensorType, shape_vector); | |||||
| auto tensor_info = CreateTensorInfo(nullptr, 0, {}, kObjectTypeTensorType); | |||||
| if (tensor_info == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor info failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto abstract_tensor = tensor_info->ToAbstract(); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| while_tensor_array_input->set_abstract(abstract_tensor); | while_tensor_array_input->set_abstract(abstract_tensor); | ||||
| while_tensor_array_input->set_default_param(tensor_info); | while_tensor_array_input->set_default_param(tensor_info); | ||||
| while_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray"); | while_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray"); | ||||
| @@ -975,7 +995,11 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf | |||||
| auto input_paramter = cond_graph->add_parameter(); | auto input_paramter = cond_graph->add_parameter(); | ||||
| input_paramter->set_name(cond_graph_name + "_input_" + std::to_string(i) + "_parameter"); | input_paramter->set_name(cond_graph_name + "_input_" + std::to_string(i) + "_parameter"); | ||||
| auto root_while_inputs = root_while_node->cast<CNodePtr>()->inputs(); | auto root_while_inputs = root_while_node->cast<CNodePtr>()->inputs(); | ||||
| auto input_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>()); | |||||
| auto input_abstract = CreateTensorAbstract({}, kNumberTypeInt32); | |||||
| if (input_abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| input_paramter->set_abstract(input_abstract); | input_paramter->set_abstract(input_abstract); | ||||
| if (i == 0) { | if (i == 0) { | ||||
| auto zero_parameter = CreateConstParamter(cond_graph, 0); | auto zero_parameter = CreateConstParamter(cond_graph, 0); | ||||
| @@ -987,7 +1011,11 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf | |||||
| MS_LOG(ERROR) << "new cnode error"; | MS_LOG(ERROR) << "new cnode error"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto less_abstract = std::make_shared<abstract::AbstractTensor>(kBool, std::vector<int64_t>()); | |||||
| auto less_abstract = CreateTensorAbstract({}, kNumberTypeBool); | |||||
| if (less_abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| less_cnode->set_abstract(less_abstract); | less_cnode->set_abstract(less_abstract); | ||||
| less_cnode->set_fullname_with_scope(cond_graph_name + "_less_cnode"); | less_cnode->set_fullname_with_scope(cond_graph_name + "_less_cnode"); | ||||
| } | } | ||||
| @@ -1020,12 +1048,11 @@ STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const | |||||
| MS_LOG(ERROR) << "quant param type don't support."; | MS_LOG(ERROR) << "quant param type don't support."; | ||||
| return RET_NOT_SUPPORT; | return RET_NOT_SUPPORT; | ||||
| } | } | ||||
| std::vector<int64_t> shape_vector; | |||||
| auto parameter_node = anf_root_graph_->add_parameter(); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), shape_vector); | |||||
| auto parameter_node = res_graph_->add_parameter(); | |||||
| auto abstract_tensor = CreateTensorAbstract({}, type); | |||||
| if (abstract_tensor == nullptr) { | if (abstract_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "new abstract_tensor failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| parameter_node->set_abstract(abstract_tensor); | parameter_node->set_abstract(abstract_tensor); | ||||
| parameter_node->set_name(name); | parameter_node->set_name(name); | ||||
| @@ -1051,9 +1078,12 @@ STATUS OnnxModelParser::BuildParameterNode(const ParameterPtr ¶meter_node, c | |||||
| MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(tensor.data_type()); | MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(tensor.data_type()); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto type_ptr = TypeIdToType(data_type); | |||||
| std::vector<int64_t> shape_vector(tensor.dims().begin(), tensor.dims().end()); | std::vector<int64_t> shape_vector(tensor.dims().begin(), tensor.dims().end()); | ||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| parameter_node->set_abstract(abstract_tensor); | parameter_node->set_abstract(abstract_tensor); | ||||
| parameter_node->set_name(tensor.name()); | parameter_node->set_name(tensor.name()); | ||||
| @@ -1142,5 +1172,7 @@ TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type | |||||
| } | } | ||||
| return iter->second; | return iter->second; | ||||
| } | } | ||||
| int OnnxModelParser::PostAdjust() { return 0; } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -40,14 +40,17 @@ class OnnxModelParser : public ModelParser { | |||||
| ~OnnxModelParser() override = default; | ~OnnxModelParser() override = default; | ||||
| FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) override; | |||||
| int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) override; | |||||
| int PostAdjust() override; | |||||
| static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | ||||
| static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor, | static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor, | ||||
| const tensor::TensorPtr ¶m_value_lite); | const tensor::TensorPtr ¶m_value_lite); | ||||
| STATUS InitOriginModel(const std::string &model_file); | |||||
| private: | private: | ||||
| STATUS InitOriginModel(const std::string &model_file); | |||||
| STATUS ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr, | STATUS ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr, | ||||
| std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, std::vector<AnfNodePtr> *graph_inputs, | std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, std::vector<AnfNodePtr> *graph_inputs, | ||||
| const std::string &root_node_name); | const std::string &root_node_name); | ||||
| @@ -94,7 +97,6 @@ class OnnxModelParser : public ModelParser { | |||||
| std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_; | std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_; | ||||
| std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_; | std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_; | ||||
| std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node | std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node | ||||
| FuncGraphPtr anf_root_graph_ = nullptr; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -417,18 +417,17 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa | |||||
| type = TensorFlowUtils::GetTFDataType(attr_value.type()); | type = TensorFlowUtils::GetTFDataType(attr_value.type()); | ||||
| } | } | ||||
| std::vector<int> shape; | |||||
| std::vector<int64_t> shape; | |||||
| if (TensorFlowUtils::FindAttrValue(node, "shape", &attr_value)) { | if (TensorFlowUtils::FindAttrValue(node, "shape", &attr_value)) { | ||||
| auto &shape_attr = attr_value.shape(); | auto &shape_attr = attr_value.shape(); | ||||
| for (int i = 0; i < shape_attr.dim_size(); ++i) { | for (int i = 0; i < shape_attr.dim_size(); ++i) { | ||||
| shape.push_back(shape_attr.dim(i).size()); | shape.push_back(shape_attr.dim(i).size()); | ||||
| } | } | ||||
| } | } | ||||
| std::vector<int64_t> shape_vector(shape.begin(), shape.end()); | |||||
| if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) { | if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) { | ||||
| MS_LOG(INFO) << "Found value attr, means it has default value"; | MS_LOG(INFO) << "Found value attr, means it has default value"; | ||||
| auto status = ConvertConstTensor(node, attr_value, type, parameter, &shape_vector); | |||||
| auto status = ConvertConstTensor(node, attr_value, type, parameter, &shape); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "convert const tensor failed."; | MS_LOG(ERROR) << "convert const tensor failed."; | ||||
| return status; | return status; | ||||
| @@ -437,10 +436,10 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa | |||||
| graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names | graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names | ||||
| } | } | ||||
| auto type_ptr = TypeIdToType(type == kNumberTypeInt64 ? kNumberTypeInt32 : type); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| type = (type == kNumberTypeInt64) ? kNumberTypeInt32 : type; | |||||
| auto abstract_tensor = CreateTensorAbstract(shape, type); | |||||
| if (abstract_tensor == nullptr) { | if (abstract_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "abstract_tensor is nullptr"; | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| parameter->set_name(node.name()); | parameter->set_name(node.name()); | ||||
| @@ -473,51 +472,51 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts( | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| FuncGraphPtr paserTfFuction() { return nullptr; } | |||||
| FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, | |||||
| const QuantType &quantType) { | |||||
| int TFModelParser::ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile, | |||||
| const QuantType &quantType) { | |||||
| NotSupportOp::GetInstance()->set_fmk_type("TF"); | NotSupportOp::GetInstance()->set_fmk_type("TF"); | ||||
| auto status = ValidateFileStr(modelFile, ".pb"); | auto status = ValidateFileStr(modelFile, ".pb"); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb"; | MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb"; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| tf_root_graph_ = std::make_unique<tensorflow::GraphDef>(); | tf_root_graph_ = std::make_unique<tensorflow::GraphDef>(); | ||||
| if (tf_root_graph_ == nullptr) { | if (tf_root_graph_ == nullptr) { | ||||
| MS_LOG(ERROR) << "tf_root_graph_ is nullptr"; | MS_LOG(ERROR) << "tf_root_graph_ is nullptr"; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get()); | status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get()); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; | MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| anf_root_graph_ = std::make_shared<FuncGraph>(); | |||||
| if (anf_root_graph_ == nullptr) { | |||||
| res_graph_ = std::make_shared<FuncGraph>(); | |||||
| if (res_graph_ == nullptr) { | |||||
| MS_LOG(ERROR) << "funGraphPtr is nullptr"; | MS_LOG(ERROR) << "funGraphPtr is nullptr"; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| anf_root_graph_->set_attr("graph_name", MakeValue("main_graph")); | |||||
| anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF))); | |||||
| res_graph_->set_attr("graph_name", MakeValue("main_graph")); | |||||
| res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF))); | |||||
| for (int i = 0; i < tf_root_graph_->node_size(); i++) { | for (int i = 0; i < tf_root_graph_->node_size(); i++) { | ||||
| auto &node_def = tf_root_graph_->node(i); | auto &node_def = tf_root_graph_->node(i); | ||||
| tf_root_graph_nodes_[node_def.name()] = &node_def; | tf_root_graph_nodes_[node_def.name()] = &node_def; | ||||
| } | } | ||||
| status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); | |||||
| status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, res_graph_, &anf_root_node_map_); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| bool success_flag = true; | bool success_flag = true; | ||||
| for (int i = 0; i < tf_root_graph_->node_size(); i++) { | for (int i = 0; i < tf_root_graph_->node_size(); i++) { | ||||
| auto &node_def = tf_root_graph_->node(i); | auto &node_def = tf_root_graph_->node(i); | ||||
| status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); | |||||
| status = ConvertOps(node_def, tf_root_graph_nodes_, res_graph_, &anf_root_node_map_); | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| success_flag = false; | success_flag = false; | ||||
| @@ -525,7 +524,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin | |||||
| } | } | ||||
| if (!success_flag) { | if (!success_flag) { | ||||
| MS_LOG(ERROR) << "Convert ops failed."; | MS_LOG(ERROR) << "Convert ops failed."; | ||||
| return nullptr; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| if (!nodes_with_null_input_.empty()) { | if (!nodes_with_null_input_.empty()) { | ||||
| @@ -533,7 +532,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Connect null inputs failed."; | MS_LOG(ERROR) << "Connect null inputs failed."; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| } | } | ||||
| @@ -541,17 +540,17 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Convert graph outputs failed."; | MS_LOG(ERROR) << "Convert graph outputs failed."; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| status = ConvertSubgraph(); | status = ConvertSubgraph(); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Convert subgraph failed."; | MS_LOG(ERROR) << "Convert subgraph failed."; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| return anf_root_graph_; | |||||
| return RET_OK; | |||||
| } | } | ||||
| STATUS TFModelParser::ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map, | STATUS TFModelParser::ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map, | ||||
| @@ -745,7 +744,7 @@ STATUS TFModelParser::ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGr | |||||
| MS_LOG(ERROR) << "while cond body size error"; | MS_LOG(ERROR) << "while cond body size error"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| static auto root_func_manager = Manage(anf_root_graph_); | |||||
| static auto root_func_manager = Manage(res_graph_); | |||||
| for (auto &kv : first_func_map) { | for (auto &kv : first_func_map) { | ||||
| auto control_flow_node = kv.first; | auto control_flow_node = kv.first; | ||||
| @@ -757,7 +756,7 @@ STATUS TFModelParser::ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGr | |||||
| auto second_value_node = NewValueNode(second_sub_graph); | auto second_value_node = NewValueNode(second_sub_graph); | ||||
| auto inputs = control_flow_node->inputs(); | auto inputs = control_flow_node->inputs(); | ||||
| inputs.insert(inputs.begin() + 1, {first_value_node, second_value_node}); | inputs.insert(inputs.begin() + 1, {first_value_node, second_value_node}); | ||||
| auto new_node = anf_root_graph_->NewCNode(inputs); // must create new node, otherwise node_users won't update | |||||
| auto new_node = res_graph_->NewCNode(inputs); // must create new node, otherwise node_users won't update | |||||
| if (new_node == nullptr) { | if (new_node == nullptr) { | ||||
| MS_LOG(ERROR) << "new node failed"; | MS_LOG(ERROR) << "new node failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -811,43 +810,46 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C | |||||
| if (output_size == 0) { | if (output_size == 0) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } else if (output_size == 1) { | } else if (output_size == 1) { | ||||
| auto type = kFloat32; | |||||
| std::vector<int64_t> shape_vector; | |||||
| auto type = kNumberTypeFloat32; | |||||
| if (IsTensorListOp(anf_node)) { | if (IsTensorListOp(anf_node)) { | ||||
| type = TypeIdToType(kObjectTypeTensorType); | |||||
| type = kObjectTypeTensorType; | |||||
| } | } | ||||
| auto abstract = std::make_shared<abstract::AbstractTensor>(type, shape_vector); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "create AbstractTensor failed"; | |||||
| auto abstract_tensor = CreateTensorAbstract({}, type); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| anf_node->set_abstract(abstract); | |||||
| anf_node->set_abstract(abstract_tensor); | |||||
| anf_node_map->insert(std::pair(op.name(), anf_node)); | anf_node_map->insert(std::pair(op.name(), anf_node)); | ||||
| } else { | } else { | ||||
| AbstractBasePtrList abstractList; | |||||
| AbstractBasePtrList abstract_list; | |||||
| for (int output_idx = 0; output_idx < output_size; output_idx++) { | for (int output_idx = 0; output_idx < output_size; output_idx++) { | ||||
| std::vector<int64_t> shape_vector; | |||||
| abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); | |||||
| auto tupleGetItemPrimPtr = std::make_shared<ops::TupleGetItem>(); | |||||
| if (tupleGetItemPrimPtr == nullptr) { | |||||
| auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| abstract_list.emplace_back(abstract_tensor); | |||||
| auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); | |||||
| if (tuple_get_item_prim_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "new TupleGetItem failed"; | MS_LOG(ERROR) << "new TupleGetItem failed"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr); | |||||
| auto getItemValue = NewValueNode(MakeValue<int>(output_idx)); | |||||
| std::vector<AnfNodePtr> inputs{tupleGetItemPrim, anf_node, getItemValue}; | |||||
| CNodePtr getItemCNode = anf_graph->NewCNode(inputs); | |||||
| auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); | |||||
| auto get_item_value = NewValueNode(MakeValue<int>(output_idx)); | |||||
| std::vector<AnfNodePtr> inputs{tuple_get_item_prim, anf_node, get_item_value}; | |||||
| CNodePtr get_item_cnode = anf_graph->NewCNode(inputs); | |||||
| std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); | std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); | ||||
| auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "create AbstractTensor failed"; | |||||
| auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32); | |||||
| if (get_item_abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| getItemCNode->set_abstract(abstract); | |||||
| getItemCNode->set_fullname_with_scope(output_item_name); | |||||
| anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode)); | |||||
| get_item_cnode->set_abstract(get_item_abstract); | |||||
| get_item_cnode->set_fullname_with_scope(output_item_name); | |||||
| anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), get_item_cnode)); | |||||
| } | } | ||||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList)); | |||||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -1003,7 +1005,7 @@ STATUS TFModelParser::ConvertRootGraphOutputs() { | |||||
| graph_output_names_.push_back(anf_node->fullname_with_scope()); | graph_output_names_.push_back(anf_node->fullname_with_scope()); | ||||
| } | } | ||||
| } | } | ||||
| auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_); | |||||
| auto status = MakeAnfGraphOutputs(&output_nodes, res_graph_); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "make anf graph outputs node error"; | MS_LOG(ERROR) << "make anf graph outputs node error"; | ||||
| return status; | return status; | ||||
| @@ -1051,5 +1053,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int TFModelParser::PostAdjust() { return 0; } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -36,9 +36,11 @@ namespace lite { | |||||
| class TFModelParser : public ModelParser { | class TFModelParser : public ModelParser { | ||||
| public: | public: | ||||
| TFModelParser() = default; | TFModelParser() = default; | ||||
| ~TFModelParser() = default; | |||||
| ~TFModelParser() override = default; | |||||
| FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); | |||||
| int ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); | |||||
| int PostAdjust() override; | |||||
| private: | private: | ||||
| static STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info); | static STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info); | ||||
| @@ -84,7 +86,6 @@ class TFModelParser : public ModelParser { | |||||
| STATUS ConnectNullInput(); | STATUS ConnectNullInput(); | ||||
| FuncGraphPtr anf_root_graph_; | |||||
| std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def | std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def | ||||
| std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map | std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map | ||||
| std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_; | std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_; | ||||
| @@ -43,46 +43,46 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *m | |||||
| return tflite::UnPackModel(tflite_model_buf_); | return tflite::UnPackModel(tflite_model_buf_); | ||||
| } | } | ||||
| FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) { | |||||
| int TfliteModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) { | |||||
| // load graph | // load graph | ||||
| tflite_model_ = ReadTfliteModel(model_file.c_str()); | tflite_model_ = ReadTfliteModel(model_file.c_str()); | ||||
| if (tflite_model_ == nullptr) { | if (tflite_model_ == nullptr) { | ||||
| MS_LOG(ERROR) << "read tflite model failed"; | MS_LOG(ERROR) << "read tflite model failed"; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | ||||
| return nullptr; | |||||
| return RET_GRAPH_FILE_ERR; | |||||
| } | } | ||||
| if (tflite_model_->subgraphs.size() != 1) { | if (tflite_model_->subgraphs.size() != 1) { | ||||
| MS_LOG(ERROR) << "read tflite model subgraphs failed"; | MS_LOG(ERROR) << "read tflite model subgraphs failed"; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); | ||||
| return nullptr; | |||||
| return RET_GRAPH_FILE_ERR; | |||||
| } | } | ||||
| func_graph_ = std::make_shared<FuncGraph>(); | |||||
| func_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TFLITE))); | |||||
| res_graph_ = std::make_shared<FuncGraph>(); | |||||
| res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TFLITE))); | |||||
| auto status = ConvertGraphInputs(); | auto status = ConvertGraphInputs(); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Convert graph inputs failed."; | MS_LOG(ERROR) << "Convert graph inputs failed."; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| status = ConvertOps(); | status = ConvertOps(); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Convert ops failed."; | MS_LOG(ERROR) << "Convert ops failed."; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| status = ConvertGraphOutputs(); | status = ConvertGraphOutputs(); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Convert graph outputs failed."; | MS_LOG(ERROR) << "Convert graph outputs failed."; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | |||||
| return status; | |||||
| } | } | ||||
| func_graph_->set_attr("graph_name", MakeValue("main_graph")); | |||||
| return func_graph_; | |||||
| res_graph_->set_attr("graph_name", MakeValue("main_graph")); | |||||
| return RET_OK; | |||||
| } | } | ||||
| std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type, const std::string &op_name) { | std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type, const std::string &op_name) { | ||||
| @@ -158,7 +158,7 @@ STATUS TfliteModelParser::ConvertOps() { | |||||
| } else { | } else { | ||||
| tensor_name = GetTensorName(i, tflite_op_type, op_name); | tensor_name = GetTensorName(i, tflite_op_type, op_name); | ||||
| } | } | ||||
| auto parameter = func_graph_->add_parameter(); | |||||
| auto parameter = res_graph_->add_parameter(); | |||||
| status = ConvertConstTensor(input_tensor.get(), parameter, tensor_name); | status = ConvertConstTensor(input_tensor.get(), parameter, tensor_name); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed."; | MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed."; | ||||
| @@ -168,7 +168,7 @@ STATUS TfliteModelParser::ConvertOps() { | |||||
| op_inputs.emplace_back(parameter); | op_inputs.emplace_back(parameter); | ||||
| nodes_.insert(std::pair(input_idx, parameter)); | nodes_.insert(std::pair(input_idx, parameter)); | ||||
| } | } | ||||
| auto new_cnode = func_graph_->NewCNode(op_inputs); | |||||
| auto new_cnode = res_graph_->NewCNode(op_inputs); | |||||
| new_cnode->set_fullname_with_scope(op_name); | new_cnode->set_fullname_with_scope(op_name); | ||||
| // parse outputs | // parse outputs | ||||
| @@ -284,13 +284,16 @@ STATUS TfliteModelParser::ConvertGraphInputs() { | |||||
| if (tflite_graph_input < 0) { | if (tflite_graph_input < 0) { | ||||
| tflite_graph_input = tflite_graph_input + tflite_subgraph->tensors.size(); | tflite_graph_input = tflite_graph_input + tflite_subgraph->tensors.size(); | ||||
| } | } | ||||
| auto parameter = func_graph_->add_parameter(); | |||||
| auto parameter = res_graph_->add_parameter(); | |||||
| const auto &tensor = tflite_subgraph->tensors.at(tflite_graph_input); | const auto &tensor = tflite_subgraph->tensors.at(tflite_graph_input); | ||||
| std::vector<int64_t> shape_vector; | std::vector<int64_t> shape_vector; | ||||
| (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), | (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), | ||||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | [](const int32_t &value) { return static_cast<int64_t>(value); }); | ||||
| auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type)); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| parameter->set_abstract(abstract_tensor); | parameter->set_abstract(abstract_tensor); | ||||
| parameter->set_name("graph_input-" + std::to_string(tflite_graph_input)); | parameter->set_name("graph_input-" + std::to_string(tflite_graph_input)); | ||||
| nodes_.insert(std::pair(tflite_graph_input, parameter)); | nodes_.insert(std::pair(tflite_graph_input, parameter)); | ||||
| @@ -318,7 +321,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs() { | |||||
| } | } | ||||
| make_tuple_inputs.emplace_back(cnode); | make_tuple_inputs.emplace_back(cnode); | ||||
| } | } | ||||
| auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs); | |||||
| auto make_tuple_cnode = res_graph_->NewCNode(make_tuple_inputs); | |||||
| make_tuple_cnode->set_fullname_with_scope("return tuple"); | make_tuple_cnode->set_fullname_with_scope("return tuple"); | ||||
| std::vector<AnfNodePtr> op_inputs; | std::vector<AnfNodePtr> op_inputs; | ||||
| @@ -330,9 +333,9 @@ STATUS TfliteModelParser::ConvertGraphOutputs() { | |||||
| auto value_node = NewValueNode(return_prim_ptr); | auto value_node = NewValueNode(return_prim_ptr); | ||||
| op_inputs.emplace_back(value_node); | op_inputs.emplace_back(value_node); | ||||
| op_inputs.emplace_back(make_tuple_cnode); | op_inputs.emplace_back(make_tuple_cnode); | ||||
| auto cnode = func_graph_->NewCNode(op_inputs); | |||||
| auto cnode = res_graph_->NewCNode(op_inputs); | |||||
| cnode->set_fullname_with_scope("Return"); | cnode->set_fullname_with_scope("Return"); | ||||
| func_graph_->set_return(cnode); | |||||
| res_graph_->set_return(cnode); | |||||
| } else { | } else { | ||||
| auto returnPrim = std::make_shared<ops::Return>(); | auto returnPrim = std::make_shared<ops::Return>(); | ||||
| if (returnPrim == nullptr) { | if (returnPrim == nullptr) { | ||||
| @@ -350,9 +353,9 @@ STATUS TfliteModelParser::ConvertGraphOutputs() { | |||||
| return RET_NOT_FIND_OP; | return RET_NOT_FIND_OP; | ||||
| } | } | ||||
| op_inputs.emplace_back(cnode); | op_inputs.emplace_back(cnode); | ||||
| auto returnCnode = func_graph_->NewCNode(op_inputs); | |||||
| auto returnCnode = res_graph_->NewCNode(op_inputs); | |||||
| returnCnode->set_fullname_with_scope("Return"); | returnCnode->set_fullname_with_scope("Return"); | ||||
| func_graph_->set_return(returnCnode); | |||||
| res_graph_->set_return(returnCnode); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -436,8 +439,12 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const | |||||
| std::vector<int64_t> shape_vector; | std::vector<int64_t> shape_vector; | ||||
| (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), | (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), | ||||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | [](const int32_t &value) { return static_cast<int64_t>(value); }); | ||||
| auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); | |||||
| dst_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | |||||
| auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type)); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| dst_cnode->set_abstract(abstract_tensor); | |||||
| nodes_.insert(std::pair(op->outputs.front(), dst_cnode)); | nodes_.insert(std::pair(op->outputs.front(), dst_cnode)); | ||||
| } else { | } else { | ||||
| AbstractBasePtrList abstract_list; | AbstractBasePtrList abstract_list; | ||||
| @@ -450,8 +457,12 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const | |||||
| std::vector<int64_t> shape_vector; | std::vector<int64_t> shape_vector; | ||||
| (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), | (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), | ||||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | [](const int32_t &value) { return static_cast<int64_t>(value); }); | ||||
| auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); | |||||
| abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector)); | |||||
| auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type)); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| abstract_list.emplace_back(abstract_tensor); | |||||
| auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); | auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); | ||||
| if (tuple_get_item_prim_ptr == nullptr) { | if (tuple_get_item_prim_ptr == nullptr) { | ||||
| MS_LOG(ERROR) << "new TupleGetItem failed"; | MS_LOG(ERROR) << "new TupleGetItem failed"; | ||||
| @@ -460,7 +471,7 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const | |||||
| auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); | auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); | ||||
| auto get_item_value = NewValueNode(MakeValue<int>(op_idx)); | auto get_item_value = NewValueNode(MakeValue<int>(op_idx)); | ||||
| std::vector<AnfNodePtr> inputs{tuple_get_item_prim, dst_cnode, get_item_value}; | std::vector<AnfNodePtr> inputs{tuple_get_item_prim, dst_cnode, get_item_value}; | ||||
| CNodePtr get_item_cnode = func_graph_->NewCNode(inputs); | |||||
| CNodePtr get_item_cnode = res_graph_->NewCNode(inputs); | |||||
| get_item_cnode->set_fullname_with_scope(dst_cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx)); | get_item_cnode->set_fullname_with_scope(dst_cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx)); | ||||
| nodes_.insert(std::pair(output_idx, get_item_cnode)); | nodes_.insert(std::pair(output_idx, get_item_cnode)); | ||||
| op_idx++; | op_idx++; | ||||
| @@ -469,4 +480,6 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int TfliteModelParser::PostAdjust() { return 0; } | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -32,13 +32,14 @@ class TfliteModelParser : public ModelParser { | |||||
| ~TfliteModelParser() override = default; | ~TfliteModelParser() override = default; | ||||
| FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) override; | |||||
| int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type) override; | |||||
| int PostAdjust() override; | |||||
| private: | private: | ||||
| std::unordered_map<int, AnfNodePtr> nodes_; | std::unordered_map<int, AnfNodePtr> nodes_; | ||||
| std::unique_ptr<tflite::ModelT> tflite_model_; | std::unique_ptr<tflite::ModelT> tflite_model_; | ||||
| FuncGraphPtr func_graph_; | |||||
| char *tflite_model_buf_ = nullptr; | char *tflite_model_buf_ = nullptr; | ||||
| std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path); | std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path); | ||||
| STATUS ConvertConstTensor(const tflite::TensorT *tensor, const ParameterPtr ¶meter, | STATUS ConvertConstTensor(const tflite::TensorT *tensor, const ParameterPtr ¶meter, | ||||
| @@ -399,6 +399,24 @@ int CheckIfCNodeIsNull(const CNodePtr &node) { | |||||
| return lite::RET_OK; | return lite::RET_OK; | ||||
| } | } | ||||
| int CheckIfParameterIsNull(const ParameterPtr &node) { | |||||
| if (node == nullptr) { | |||||
| MS_LOG(ERROR) << "The Parameter is null."; | |||||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||||
| return lite::RET_NULL_PTR; | |||||
| } | |||||
| return lite::RET_OK; | |||||
| } | |||||
| int CheckIfValueNodeIsNull(const ValueNodePtr &node) { | |||||
| if (node == nullptr) { | |||||
| MS_LOG(ERROR) << "The ValueNode is null."; | |||||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||||
| return lite::RET_NULL_PTR; | |||||
| } | |||||
| return lite::RET_OK; | |||||
| } | |||||
| int CheckIfVarIsNull(const VarPtr &var) { | int CheckIfVarIsNull(const VarPtr &var) { | ||||
| if (var == nullptr) { | if (var == nullptr) { | ||||
| MS_LOG(ERROR) << "The Var is null."; | MS_LOG(ERROR) << "The Var is null."; | ||||
| @@ -57,6 +57,10 @@ int CheckIfAnfNodeIsNull(const AnfNodePtr &node); | |||||
| int CheckIfCNodeIsNull(const CNodePtr &node); | int CheckIfCNodeIsNull(const CNodePtr &node); | ||||
| int CheckIfParameterIsNull(const ParameterPtr &node); | |||||
| int CheckIfValueNodeIsNull(const ValueNodePtr &node); | |||||
| int CheckIfVarIsNull(const VarPtr &var); | int CheckIfVarIsNull(const VarPtr &var); | ||||
| int CheckInputSize(const CNodePtr &node, int size); | int CheckInputSize(const CNodePtr &node, int size); | ||||
| @@ -0,0 +1,294 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/optimizer/fusion/mul_add_fusion.h" | |||||
| #include <memory> | |||||
| #include "ops/fusion/mul_fusion.h" | |||||
| #include "ops/fusion/add_fusion.h" | |||||
| #include "ops/fusion/scale_fusion.h" | |||||
| #include "ops/op_utils.h" | |||||
| #include "tools/optimizer/common/gllo_utils.h" | |||||
| namespace mindspore::opt { | |||||
| namespace { | |||||
| constexpr size_t kMulInputsLength = 3; | |||||
| constexpr size_t kAddInputsLength = 3; | |||||
| } // namespace | |||||
| const BaseRef MulAddFusion::DefinePattern() const { | |||||
| auto mul_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>); | |||||
| auto add_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>); | |||||
| return VectorRef({add_var, mul_var}); | |||||
| } | |||||
| bool MulAddFusion::ScaleInputShapeValid() const { | |||||
| MS_ASSERT(scale_tensor_ != nullptr); | |||||
| MS_ASSERT(bias_tensor_ != nullptr); | |||||
| auto scale_shape = scale_tensor_->shape_c(); | |||||
| auto offset_shape = bias_tensor_->shape_c(); | |||||
| if (mul_input_shape_.size() < scale_shape.size() || scale_shape.size() == 0) { | |||||
| return false; | |||||
| } | |||||
| size_t rank_diff = mul_input_shape_.size() - scale_shape.size(); | |||||
| for (size_t i = 0; i < scale_shape.size(); ++i) { | |||||
| if (mul_input_shape_[i + rank_diff] != scale_shape[i]) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| if (scale_shape != offset_shape) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool MulAddFusion::CheckMulNode(const FuncGraphPtr &func_graph) const { | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| if (mul_anode_ == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if (IsMultiOutputTensors(func_graph, mul_anode_)) { | |||||
| MS_LOG(DEBUG) << "Mul op has multi-output"; | |||||
| return false; | |||||
| } | |||||
| auto mul_node = mul_anode_->cast<CNodePtr>(); | |||||
| if (!CheckPrimitiveType(mul_node, prim::kPrimMulFusion)) { | |||||
| MS_LOG(DEBUG) << "Mul add fusion pass match only mul or add"; | |||||
| return false; | |||||
| } | |||||
| auto mul_primitive = GetValueNode<std::shared_ptr<ops::MulFusion>>(mul_node->input(0)); | |||||
| MS_ASSERT(mul_primitive != nullptr); | |||||
| auto mul_act_type = mul_primitive->get_activation_type(); | |||||
| if (mul_act_type != ActivationType::NO_ACTIVATION) { | |||||
| MS_LOG(DEBUG) << "Only support mul node with no activation"; | |||||
| return false; | |||||
| } | |||||
| if (CheckIfCNodeIsNull(mul_node) != lite::RET_OK || CheckInputSize(mul_node, kMulInputsLength) != lite::RET_OK) { | |||||
| MS_LOG(DEBUG) << "Mul op is null or has error input size"; | |||||
| return false; | |||||
| } | |||||
| // find mul's const input and mul input | |||||
| AnfNodePtr mul_pre_input_node = nullptr; | |||||
| AnfNodePtr mul_pre_const_node = nullptr; | |||||
| auto mul_pre_node_1 = mul_node->input(1); | |||||
| if (CheckIfAnfNodeIsNull(mul_pre_node_1) != lite::RET_OK) { | |||||
| MS_LOG(DEBUG) << "Pre-node of mul op is nullptr"; | |||||
| return false; | |||||
| } | |||||
| auto mul_pre_node_2 = mul_node->input(2); | |||||
| if (CheckIfAnfNodeIsNull(mul_pre_node_2) != lite::RET_OK) { | |||||
| MS_LOG(DEBUG) << "Pre-node of mul op is nullptr"; | |||||
| return false; | |||||
| } | |||||
| if (utils::isa<CNodePtr>(mul_pre_node_1) && !utils::isa<CNodePtr>(mul_pre_node_2)) { | |||||
| mul_pre_input_node = mul_pre_node_1; | |||||
| mul_pre_const_node = mul_pre_node_2; | |||||
| } else if (!utils::isa<CNodePtr>(mul_pre_node_1) && utils::isa<CNodePtr>(mul_pre_node_2)) { | |||||
| mul_pre_input_node = mul_pre_node_1; | |||||
| mul_pre_const_node = mul_pre_node_2; | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "Mul op should has a cnode input and a const input"; | |||||
| return false; | |||||
| } | |||||
| // check mul's const input | |||||
| tensor::TensorPtr mul_tensor = nullptr; | |||||
| if (utils::isa<ParameterPtr>(mul_pre_const_node)) { | |||||
| auto mul_bias_node = mul_pre_const_node->cast<ParameterPtr>(); | |||||
| MS_ASSERT(mul_bias_node != nullptr); | |||||
| if (!mul_bias_node->has_default()) { | |||||
| MS_LOG(DEBUG) << "Const input of mul op should has data"; | |||||
| return false; | |||||
| } | |||||
| mul_tensor = mul_bias_node->default_param()->cast<tensor::TensorPtr>(); | |||||
| } else if (utils::isa<ValueNodePtr>(mul_pre_const_node)) { | |||||
| auto mul_bias_node = mul_pre_const_node->cast<ValueNodePtr>(); | |||||
| MS_ASSERT(mul_bias_node != nullptr); | |||||
| if (mul_bias_node->value() == nullptr) { | |||||
| MS_LOG(DEBUG) << "Const input of mul op should has data"; | |||||
| return false; | |||||
| } | |||||
| mul_tensor = mul_bias_node->value()->cast<tensor::TensorPtr>(); | |||||
| } else { | |||||
| MS_ASSERT(false); | |||||
| } | |||||
| if (mul_tensor == nullptr) { | |||||
| MS_LOG(DEBUG) << "Const input of add op should has data"; | |||||
| return false; | |||||
| } | |||||
| mul_input_anode_ = mul_pre_input_node; | |||||
| mul_const_anode_ = mul_pre_const_node; | |||||
| scale_tensor_ = mul_tensor; | |||||
| return true; | |||||
| } | |||||
| bool MulAddFusion::CheckAddNode() const { | |||||
| if (add_anode_ == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto add_cnode = add_anode_->cast<CNodePtr>(); | |||||
| if (CheckIfCNodeIsNull(add_cnode) != lite::RET_OK || CheckInputSize(add_cnode, kAddInputsLength) != lite::RET_OK) { | |||||
| MS_LOG(DEBUG) << "Add op is null or has error input size"; | |||||
| return false; | |||||
| } | |||||
| if (!CheckPrimitiveType(add_cnode, prim::kPrimAddFusion)) { | |||||
| MS_LOG(DEBUG) << "Mul add fusion pass match only mul or add"; | |||||
| return false; | |||||
| } | |||||
| auto add_primitive = GetValueNode<std::shared_ptr<ops::AddFusion>>(add_cnode->input(0)); | |||||
| MS_ASSERT(add_primitive != nullptr); | |||||
| auto add_act_type = add_primitive->get_activation_type(); | |||||
| if (add_act_type != ActivationType::RELU && add_act_type != ActivationType::RELU6 && | |||||
| add_act_type != ActivationType::NO_ACTIVATION) { | |||||
| MS_LOG(DEBUG) << "Only support add node with relu or relu6 or no activation"; | |||||
| return false; | |||||
| } | |||||
| scale_act_type_ = add_act_type; | |||||
| // find add's const input and mul input | |||||
| AnfNodePtr add_pre_input_node = nullptr; | |||||
| AnfNodePtr add_pre_const_node = nullptr; | |||||
| auto add_pre_node_1 = add_cnode->input(1); | |||||
| if (CheckIfAnfNodeIsNull(add_pre_node_1) != lite::RET_OK) { | |||||
| MS_LOG(DEBUG) << "Pre-node of add op is nullptr"; | |||||
| return false; | |||||
| } | |||||
| auto add_pre_node_2 = add_cnode->input(2); | |||||
| if (CheckIfAnfNodeIsNull(add_pre_node_2) != lite::RET_OK) { | |||||
| MS_LOG(DEBUG) << "Pre-node of add op is nullptr"; | |||||
| return false; | |||||
| } | |||||
| if (utils::isa<CNodePtr>(add_pre_node_1) && !utils::isa<CNodePtr>(add_pre_node_2)) { | |||||
| add_pre_input_node = add_pre_node_1; | |||||
| add_pre_const_node = add_pre_node_2; | |||||
| } else if (!utils::isa<CNodePtr>(add_pre_node_1) && utils::isa<CNodePtr>(add_pre_node_2)) { | |||||
| add_pre_input_node = add_pre_node_2; | |||||
| add_pre_const_node = add_pre_node_1; | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "Add op should has a cnode input and a const input"; | |||||
| return false; | |||||
| } | |||||
| // check add's const input | |||||
| tensor::TensorPtr add_tensor = nullptr; | |||||
| if (utils::isa<ParameterPtr>(add_pre_const_node)) { | |||||
| auto add_bias_node = add_pre_const_node->cast<ParameterPtr>(); | |||||
| MS_ASSERT(add_bias_node != nullptr); | |||||
| if (!add_bias_node->has_default()) { | |||||
| MS_LOG(DEBUG) << "Const input of add op should has data"; | |||||
| return false; | |||||
| } | |||||
| add_tensor = add_bias_node->default_param()->cast<tensor::TensorPtr>(); | |||||
| } else if (utils::isa<ValueNodePtr>(add_pre_const_node)) { | |||||
| auto add_bias_node = add_pre_const_node->cast<ValueNodePtr>(); | |||||
| MS_ASSERT(add_bias_node != nullptr); | |||||
| if (add_bias_node->value() == nullptr) { | |||||
| MS_LOG(DEBUG) << "Const input of add op should has data"; | |||||
| return false; | |||||
| } | |||||
| add_tensor = add_bias_node->value()->cast<tensor::TensorPtr>(); | |||||
| } else { | |||||
| MS_ASSERT(false); | |||||
| } | |||||
| if (add_tensor == nullptr) { | |||||
| MS_LOG(DEBUG) << "Const input of add op should has data"; | |||||
| return false; | |||||
| } | |||||
| mul_anode_ = add_pre_input_node; | |||||
| add_const_anode_ = add_pre_const_node; | |||||
| bias_tensor_ = add_tensor; | |||||
| return true; | |||||
| } | |||||
| bool MulAddFusion::GetMulInputShape() const { | |||||
| MS_ASSERT(mul_input_anode_ != nullptr); | |||||
| ShapeVector mul_input_shape; | |||||
| AbstractBasePtr mul_input_abstract = nullptr; | |||||
| if (utils::isa<ParameterPtr>(mul_input_anode_)) { | |||||
| auto mul_input_node = mul_input_anode_->cast<ParameterPtr>(); | |||||
| MS_ASSERT(mul_bias_node != nullptr); | |||||
| mul_input_abstract = mul_input_node->abstract(); | |||||
| } else if (utils::isa<ValueNodePtr>(mul_input_anode_)) { | |||||
| auto mul_input_node = mul_input_anode_->cast<ValueNodePtr>(); | |||||
| MS_ASSERT(mul_input_node != nullptr); | |||||
| mul_input_abstract = mul_input_node->abstract(); | |||||
| } else if (utils::isa<CNodePtr>(mul_input_anode_)) { | |||||
| auto mul_input_node = mul_input_anode_->cast<CNodePtr>(); | |||||
| MS_ASSERT(mul_input_node != nullptr); | |||||
| mul_input_abstract = mul_input_node->abstract(); | |||||
| } else { | |||||
| MS_ASSERT(false); | |||||
| } | |||||
| if (mul_input_abstract == nullptr) { | |||||
| MS_LOG(DEBUG) << "Mul input node has no abstract"; | |||||
| return false; | |||||
| } | |||||
| if (!utils::isa<abstract::AbstractTensorPtr>(mul_input_abstract)) { | |||||
| MS_LOG(DEBUG) << "Abstract of mul input node should be AbstractTensor"; | |||||
| return false; | |||||
| } | |||||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(mul_input_abstract); | |||||
| MS_ASSERT(abstract_tensor != nullptr); | |||||
| MS_ASSERT(abstract_tensor->BuildShape() != nullptr); | |||||
| if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) { | |||||
| MS_LOG(DEBUG) << "BuildShape of abstract of mul input node should be ShapePtr"; | |||||
| return false; | |||||
| } | |||||
| mul_input_shape_ = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); | |||||
| return true; | |||||
| } | |||||
| const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| MS_ASSERT(node != nullptr); | |||||
| if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) { | |||||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); | |||||
| return nullptr; | |||||
| } | |||||
| add_anode_ = node; | |||||
| if (!CheckAddNode()) { | |||||
| MS_LOG(DEBUG) << "Add op is not suit for mul-add-fusion: " << node->fullname_with_scope(); | |||||
| return nullptr; | |||||
| } | |||||
| MS_ASSERT(mul_anode_ != nullptr); | |||||
| MS_ASSERT(bias_tensor_ != nullptr); | |||||
| MS_ASSERT(add_const_anode_ != nullptr); | |||||
| if (!CheckMulNode(func_graph)) { | |||||
| MS_LOG(DEBUG) << "Mul op is not suit for mul-add-fusion: " << mul_anode_->fullname_with_scope(); | |||||
| return nullptr; | |||||
| } | |||||
| MS_ASSERT(mul_input_anode_ != nullptr); | |||||
| MS_ASSERT(scale_tensor_ != nullptr); | |||||
| MS_ASSERT(mul_const_anode_ != nullptr); | |||||
| if (!GetMulInputShape()) { | |||||
| MS_LOG(DEBUG) << "Get input shape of mul op failed"; | |||||
| return nullptr; | |||||
| } | |||||
| // scale requires scale shape tail sub of input shape, scale shape same as bias shape | |||||
| if (!ScaleInputShapeValid()) { | |||||
| MS_LOG(DEBUG) << "Check input shape, scale shape and bias shape failed"; | |||||
| return nullptr; | |||||
| } | |||||
| // create scale primitive | |||||
| auto scale_primitive = new (std::nothrow) mindspore::ops::ScaleFusion(); | |||||
| if (scale_primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "new scale primitive failed"; | |||||
| return nullptr; | |||||
| } | |||||
| scale_primitive->set_activation_type(scale_act_type_); | |||||
| scale_primitive->set_axis(0 - bias_tensor_->shape_c().size()); | |||||
| // create scale op | |||||
| auto scale_node = func_graph->NewCNode(std::shared_ptr<ops::PrimitiveC>(scale_primitive), | |||||
| {mul_input_anode_, mul_const_anode_, add_const_anode_}); | |||||
| return scale_node; | |||||
| } | |||||
| } // namespace mindspore::opt | |||||
| @@ -0,0 +1,53 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_MUL_ADD_FUSION_H_ | |||||
| #define MINDSPORE_LITE_SRC_PASS_FUSION_MUL_ADD_FUSION_H_ | |||||
| #include <string> | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| #include "utils/check_convert_utils.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class MulAddFusion : public PatternProcessPass { | |||||
| public: | |||||
| explicit MulAddFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion") | |||||
| : PatternProcessPass(name, multigraph) {} | |||||
| ~MulAddFusion() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| private: | |||||
| bool CheckMulNode(const FuncGraphPtr &func_graph) const; | |||||
| bool CheckAddNode() const; | |||||
| bool GetMulInputShape() const; | |||||
| bool ScaleInputShapeValid() const; | |||||
| private: | |||||
| mutable AnfNodePtr mul_anode_ = nullptr; | |||||
| mutable AnfNodePtr mul_input_anode_ = nullptr; | |||||
| mutable AnfNodePtr mul_const_anode_ = nullptr; | |||||
| mutable ShapeVector mul_input_shape_; | |||||
| mutable AnfNodePtr add_anode_ = nullptr; | |||||
| mutable AnfNodePtr add_const_anode_ = nullptr; | |||||
| mutable tensor::TensorPtr scale_tensor_ = nullptr; | |||||
| mutable tensor::TensorPtr bias_tensor_ = nullptr; | |||||
| mutable ActivationType scale_act_type_ = ActivationType::NO_ACTIVATION; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_ | |||||
| @@ -256,11 +256,12 @@ ParameterPtr TfBidirectionGruFusion::AddDefaultParameter(const FuncGraphPtr &fun | |||||
| auto parameter = func_graph->add_parameter(); | auto parameter = func_graph->add_parameter(); | ||||
| parameter->set_name(name); | parameter->set_name(name); | ||||
| std::vector<int64_t> shape_vector(shape.begin(), shape.end()); | std::vector<int64_t> shape_vector(shape.begin(), shape.end()); | ||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), shape_vector); | |||||
| if (abstract_tensor == nullptr) { | |||||
| auto abstract = lite::CreateTensorAbstract(shape_vector, type); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| parameter->set_abstract(abstract_tensor); | |||||
| parameter->set_abstract(abstract); | |||||
| auto gate_weight_default = std::make_shared<tensor::Tensor>(type, shape_vector); | auto gate_weight_default = std::make_shared<tensor::Tensor>(type, shape_vector); | ||||
| if (gate_weight_default == nullptr) { | if (gate_weight_default == nullptr) { | ||||
| @@ -502,13 +502,12 @@ CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_grap | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| CNodePtr get_item_cnode = func_graph->NewCNode(tuple_get_item_prim, {node, get_item_value}); | CNodePtr get_item_cnode = func_graph->NewCNode(tuple_get_item_prim, {node, get_item_value}); | ||||
| std::vector<int64_t> shape_vector; | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "create abstract_tensor failed"; | |||||
| auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| get_item_cnode->set_abstract(abstract_tensor); | |||||
| get_item_cnode->set_abstract(abstract); | |||||
| get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" + | get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" + | ||||
| std::to_string(item_index)); | std::to_string(item_index)); | ||||
| return get_item_cnode; | return get_item_cnode; | ||||
| @@ -581,13 +580,12 @@ STATUS TfliteLstmCellFusion::SetAbstractTuple(const CNodePtr &cnode, const int o | |||||
| MS_ASSERT(cnode != nullptr); | MS_ASSERT(cnode != nullptr); | ||||
| AbstractBasePtrList abstract_list; | AbstractBasePtrList abstract_list; | ||||
| for (int i = 0; i < output_num; ++i) { | for (int i = 0; i < output_num; ++i) { | ||||
| std::vector<int64_t> shape_vector; | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "create abstract_tensor failed"; | |||||
| auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| abstract_list.emplace_back(abstract_tensor); | |||||
| abstract_list.emplace_back(abstract); | |||||
| } | } | ||||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | ||||
| if (abstract_tuple == nullptr) { | if (abstract_tuple == nullptr) { | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include "ops/return.h" | #include "ops/return.h" | ||||
| #include "ops/tuple_get_item.h" | #include "ops/tuple_get_item.h" | ||||
| #include "tools/converter/ops/while.h" | #include "tools/converter/ops/while.h" | ||||
| #include "tools/common/tensor_util.h" | |||||
| namespace { | namespace { | ||||
| mindspore::ValueNodePtr GetWhileAnfPrim() { | mindspore::ValueNodePtr GetWhileAnfPrim() { | ||||
| @@ -207,9 +208,13 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() { | |||||
| auto node_users = manager->node_users()[node]; | auto node_users = manager->node_users()[node]; | ||||
| for (auto &node_user : node_users) { | for (auto &node_user : node_users) { | ||||
| // new getitem | // new getitem | ||||
| AbstractBasePtrList abstractList; | |||||
| std::vector<int64_t> shape_vector; | |||||
| abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); | |||||
| AbstractBasePtrList abstract_list; | |||||
| auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| abstract_list.emplace_back(abstract); | |||||
| auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); | auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); | ||||
| if (tuple_get_item_prim_ptr == nullptr) { | if (tuple_get_item_prim_ptr == nullptr) { | ||||
| MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; | MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; | ||||
| @@ -225,12 +230,12 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() { | |||||
| std::vector<AnfNodePtr> inputs{tuple_get_item_prim, while_node_, getItemValue}; | std::vector<AnfNodePtr> inputs{tuple_get_item_prim, while_node_, getItemValue}; | ||||
| CNodePtr get_item_node = fg_->NewCNode(inputs); | CNodePtr get_item_node = fg_->NewCNode(inputs); | ||||
| std::string output_item_name = while_node_->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); | std::string output_item_name = while_node_->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); | ||||
| auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "create AbstractTensor failed"; | |||||
| return RET_NULL_PTR; | |||||
| auto get_item_node_abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32); | |||||
| if (get_item_node_abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create get_item_node_abstract failed"; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| get_item_node->set_abstract(abstract); | |||||
| get_item_node->set_abstract(get_item_node_abstract); | |||||
| get_item_node->set_fullname_with_scope(output_item_name); | get_item_node->set_fullname_with_scope(output_item_name); | ||||
| // set | // set | ||||
| if (fg_->nodes().contains(node_user.first)) { | if (fg_->nodes().contains(node_user.first)) { | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "src/tensor.h" | #include "src/tensor.h" | ||||
| #include "tools/converter/quantizer/quant_cast.h" | #include "tools/converter/quantizer/quant_cast.h" | ||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "tools/common/tensor_util.h" | |||||
| #include "securec/include/securec.h" | #include "securec/include/securec.h" | ||||
| namespace mindspore::opt { | namespace mindspore::opt { | ||||
| @@ -101,13 +102,16 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto type_id = static_cast<TypeId>(weight_value->data_type()); | auto type_id = static_cast<TypeId>(weight_value->data_type()); | ||||
| auto type_ptr = TypeIdToType(type_id); | |||||
| auto shape = weight_value->shape(); | auto shape = weight_value->shape(); | ||||
| std::vector<int64_t> shape_vector; | std::vector<int64_t> shape_vector; | ||||
| (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), | (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), | ||||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | [](const int32_t &value) { return static_cast<int64_t>(value); }); | ||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| weight_node->set_abstract(abstract_tensor); | |||||
| auto abstract = lite::CreateTensorAbstract(shape_vector, type_id); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| weight_node->set_abstract(abstract); | |||||
| } | } | ||||
| } | } | ||||
| return true; | return true; | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include "tools/common/node_util.h" | #include "tools/common/node_util.h" | ||||
| #include "tools/common/tensor_util.h" | #include "tools/common/tensor_util.h" | ||||
| #include "src/common/common.h" | #include "src/common/common.h" | ||||
| #include "src/common/tensor_util.h" | |||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "src/ops/ops_utils.h" | #include "src/ops/ops_utils.h" | ||||
| #include "src/runtime/infer_manager.h" | #include "src/runtime/infer_manager.h" | ||||
| @@ -28,19 +29,6 @@ | |||||
| namespace mindspore::opt { | namespace mindspore::opt { | ||||
| namespace { | namespace { | ||||
| constexpr size_t INITIAL_SIZE = 1024; | constexpr size_t INITIAL_SIZE = 1024; | ||||
| tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) { | |||||
| std::vector<int> shape(tensor->shape()); | |||||
| std::vector<int64_t> shape_vector; | |||||
| std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), | |||||
| [](const int32_t &value) { return static_cast<int64_t>(value); }); | |||||
| auto tensor_info = std::make_shared<tensor::Tensor>(tensor->data_type(), shape_vector); | |||||
| if (tensor_info == nullptr) { | |||||
| MS_LOG(ERROR) << "new tensor::Tensor failed"; | |||||
| return nullptr; | |||||
| } | |||||
| return tensor_info; | |||||
| } | |||||
| bool IsSpecialType(const CNodePtr &cnode) { | bool IsSpecialType(const CNodePtr &cnode) { | ||||
| if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) || | if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) || | ||||
| CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimReturn) || | CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimReturn) || | ||||
| @@ -75,21 +63,14 @@ STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *tensor_info, const CNodePtr | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) { | |||||
| abstract::AbstractBasePtr InferShapePass::ConvertLiteTensorToAbstract(lite::Tensor *tensor) { | |||||
| MS_ASSERT(nullptr != tensor); | MS_ASSERT(nullptr != tensor); | ||||
| std::vector<int> shape(tensor->shape()); | |||||
| auto shape = tensor->shape(); | |||||
| auto type_id = static_cast<TypeId>(tensor->data_type()); | auto type_id = static_cast<TypeId>(tensor->data_type()); | ||||
| auto type_ptr = TypeIdToType(type_id); | |||||
| std::vector<int64_t> shape_vector(shape.begin(), shape.end()); | std::vector<int64_t> shape_vector(shape.begin(), shape.end()); | ||||
| auto new_abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| if (new_abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "new AbstractTensor failed"; | |||||
| return nullptr; | |||||
| } | |||||
| auto tensor_info = NewTensorInfo(tensor); | |||||
| auto tensor_info = lite::CreateTensorInfo(nullptr, 0, shape_vector, type_id); | |||||
| if (tensor_info == nullptr) { | if (tensor_info == nullptr) { | ||||
| MS_LOG(ERROR) << "new tensor::Tensor failed"; | |||||
| MS_LOG(DEBUG) << "Create tensor info failed"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -112,8 +93,12 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| } | } | ||||
| new_abstract->set_value(tensor_info); | |||||
| return new_abstract; | |||||
| auto abstract = tensor_info->ToAbstract(); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(DEBUG) << "Create tensor abstarct failed"; | |||||
| return nullptr; | |||||
| } | |||||
| return abstract; | |||||
| } | } | ||||
| STATUS InferShapePass::SetParameterAbstract(const ParameterPtr ¶meter) { | STATUS InferShapePass::SetParameterAbstract(const ParameterPtr ¶meter) { | ||||
| @@ -143,8 +128,6 @@ STATUS InferShapePass::SetParameterAbstract(const ParameterPtr ¶meter) { | |||||
| std::vector<int32_t> shape; | std::vector<int32_t> shape; | ||||
| (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), | (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), | ||||
| [](const int64_t &value) { return static_cast<int32_t>(value); }); | [](const int64_t &value) { return static_cast<int32_t>(value); }); | ||||
| auto new_abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| auto new_tensor_info = std::make_shared<tensor::Tensor>(type_ptr->type_id(), shape_vector); | auto new_tensor_info = std::make_shared<tensor::Tensor>(type_ptr->type_id(), shape_vector); | ||||
| if (parameter->has_default()) { | if (parameter->has_default()) { | ||||
| auto old_tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param()); | auto old_tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param()); | ||||
| @@ -155,7 +138,11 @@ STATUS InferShapePass::SetParameterAbstract(const ParameterPtr ¶meter) { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| new_abstract->set_value(new_tensor_info); | |||||
| auto new_abstract = new_tensor_info->ToAbstract(); | |||||
| if (new_abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| parameter->set_abstract(new_abstract); | parameter->set_abstract(new_abstract); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -304,7 +291,7 @@ STATUS InferShapePass::SetCNodeAbstract(const std::vector<lite::Tensor *> &outpu | |||||
| } | } | ||||
| if (output_tensors.size() == 1) { | if (output_tensors.size() == 1) { | ||||
| auto tensor = output_tensors.front(); | auto tensor = output_tensors.front(); | ||||
| auto new_abstract = ConvertLiteTensorToAbstractTensor(tensor); | |||||
| auto new_abstract = ConvertLiteTensorToAbstract(tensor); | |||||
| if (new_abstract == nullptr) { | if (new_abstract == nullptr) { | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -313,7 +300,7 @@ STATUS InferShapePass::SetCNodeAbstract(const std::vector<lite::Tensor *> &outpu | |||||
| AbstractBasePtrList abstract_list; | AbstractBasePtrList abstract_list; | ||||
| for (size_t i = 0; i < output_tensors.size(); i++) { | for (size_t i = 0; i < output_tensors.size(); i++) { | ||||
| auto tensor = output_tensors.front(); | auto tensor = output_tensors.front(); | ||||
| auto new_abstract = ConvertLiteTensorToAbstractTensor(tensor); | |||||
| auto new_abstract = ConvertLiteTensorToAbstract(tensor); | |||||
| if (new_abstract == nullptr) { | if (new_abstract == nullptr) { | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -36,7 +36,7 @@ class InferShapePass : public Pass { | |||||
| private: | private: | ||||
| void FreeTensors(std::vector<lite::Tensor *> *tensors); | void FreeTensors(std::vector<lite::Tensor *> *tensors); | ||||
| abstract::AbstractTensorPtr ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor); | |||||
| abstract::AbstractBasePtr ConvertLiteTensorToAbstract(lite::Tensor *tensor); | |||||
| STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *input_tensors); | STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *input_tensors); | ||||
| STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors); | STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors); | ||||
| STATUS SetParameterAbstract(const ParameterPtr ¶meter); | STATUS SetParameterAbstract(const ParameterPtr ¶meter); | ||||
| @@ -179,23 +179,23 @@ int MindirAdjustPass::ValueNodeInt64Convert(AnfNodePtr anf_node) { | |||||
| if (!utils::isa<ValueNodePtr>(anf_node)) { | if (!utils::isa<ValueNodePtr>(anf_node)) { | ||||
| return lite::RET_NO_CHANGE; | return lite::RET_NO_CHANGE; | ||||
| } | } | ||||
| auto valueNode = anf_node->cast<ValueNodePtr>(); | |||||
| if (valueNode->abstract() == nullptr) { | |||||
| auto value_node = anf_node->cast<ValueNodePtr>(); | |||||
| if (value_node->abstract() == nullptr) { | |||||
| return lite::RET_NO_CHANGE; | return lite::RET_NO_CHANGE; | ||||
| } | } | ||||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueNode->abstract()); | |||||
| if (abstractTensor == nullptr) { | |||||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(value_node->abstract()); | |||||
| if (abstract_tensor == nullptr) { | |||||
| return lite::RET_NO_CHANGE; | return lite::RET_NO_CHANGE; | ||||
| } | } | ||||
| auto value = abstractTensor->GetValueTrack(); | |||||
| auto value = abstract_tensor->GetValueTrack(); | |||||
| if (value != nullptr && value->isa<tensor::Tensor>()) { | if (value != nullptr && value->isa<tensor::Tensor>()) { | ||||
| if (abstractTensor->element() == nullptr) { | |||||
| if (abstract_tensor->element() == nullptr) { | |||||
| MS_LOG(ERROR) << "abstractTensor->element() is nullptr."; | MS_LOG(ERROR) << "abstractTensor->element() is nullptr."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto typePtr = abstractTensor->element()->GetTypeTrack(); | |||||
| if (typePtr->type_id() == kNumberTypeInt64) { | |||||
| auto shape_vector = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | |||||
| auto type_ptr = abstract_tensor->element()->GetTypeTrack(); | |||||
| if (type_ptr->type_id() == kNumberTypeInt64) { | |||||
| auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); | |||||
| auto dest_tensor_info = std::make_shared<tensor::Tensor>(kNumberTypeInt32, shape_vector); | auto dest_tensor_info = std::make_shared<tensor::Tensor>(kNumberTypeInt32, shape_vector); | ||||
| auto *dest_data_buf = reinterpret_cast<int32_t *>(dest_tensor_info->data_c()); | auto *dest_data_buf = reinterpret_cast<int32_t *>(dest_tensor_info->data_c()); | ||||
| auto src_tensor_info = value->cast<tensor::TensorPtr>(); | auto src_tensor_info = value->cast<tensor::TensorPtr>(); | ||||
| @@ -204,10 +204,10 @@ int MindirAdjustPass::ValueNodeInt64Convert(AnfNodePtr anf_node) { | |||||
| for (int i = 0; i < dest_tensor_info->ElementsNum(); i++) { | for (int i = 0; i < dest_tensor_info->ElementsNum(); i++) { | ||||
| dest_data_buf[i] = src_data_buf[i]; | dest_data_buf[i] = src_data_buf[i]; | ||||
| } | } | ||||
| abstractTensor->set_value(dest_tensor_info); | |||||
| abstractTensor->set_type(TypeIdToType(kNumberTypeInt32)); | |||||
| abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt32)); | |||||
| valueNode->set_value(dest_tensor_info); | |||||
| abstract_tensor->set_value(dest_tensor_info); | |||||
| abstract_tensor->set_type(TypeIdToType(kNumberTypeInt32)); | |||||
| abstract_tensor->element()->set_type(TypeIdToType(kNumberTypeInt32)); | |||||
| value_node->set_value(dest_tensor_info); | |||||
| } | } | ||||
| } | } | ||||
| return lite::RET_NO_CHANGE; | return lite::RET_NO_CHANGE; | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "ops/transpose.h" | #include "ops/transpose.h" | ||||
| #include "tools/optimizer/common/gllo_utils.h" | #include "tools/optimizer/common/gllo_utils.h" | ||||
| #include "tools/common/tensor_util.h" | |||||
| using mindspore::lite::converter::FmkType_CAFFE; | using mindspore::lite::converter::FmkType_CAFFE; | ||||
| using mindspore::lite::converter::FmkType_MS; | using mindspore::lite::converter::FmkType_MS; | ||||
| @@ -92,9 +93,20 @@ lite::STATUS WeightFormatTransformPass::TransposeInsertForWeightSharing(const Fu | |||||
| auto perm_node = BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_perm"); | auto perm_node = BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_perm"); | ||||
| auto prim = std::make_shared<ops::Transpose>(); | auto prim = std::make_shared<ops::Transpose>(); | ||||
| auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node}); | auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node}); | ||||
| auto type_ptr = TypeIdToType(kTypeUnknown); | |||||
| std::vector<int64_t> shape_vector; | |||||
| auto abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| if (!weight_node->has_default()) { | |||||
| MS_LOG(DEBUG) << "Weight parameter should has default parameter."; | |||||
| return lite::RET_ERROR; | |||||
| } | |||||
| auto weight_tensor = weight_node->default_param()->cast<tensor::TensorPtr>(); | |||||
| if (weight_tensor == nullptr) { | |||||
| MS_LOG(DEBUG) << "Default parameter of weight parameter should be a tensor."; | |||||
| return lite::RET_ERROR; | |||||
| } | |||||
| auto abstract = lite::CreateTensorAbstract(weight_tensor->shape_c(), weight_tensor->data_type()); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| transpose_node->set_abstract(abstract); | transpose_node->set_abstract(abstract); | ||||
| transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_post"); | transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_post"); | ||||
| for (auto &adjust_node : adjust_nodes) { | for (auto &adjust_node : adjust_nodes) { | ||||
| @@ -177,11 +189,14 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto type_id = static_cast<TypeId>(weight_value->data_type()); | auto type_id = static_cast<TypeId>(weight_value->data_type()); | ||||
| auto type_ptr = TypeIdToType(type_id); | |||||
| auto shape = weight_value->shape(); | auto shape = weight_value->shape(); | ||||
| std::vector<int64_t> shape_vector(shape.begin(), shape.end()); | std::vector<int64_t> shape_vector(shape.begin(), shape.end()); | ||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| weight_node->set_abstract(abstract_tensor); | |||||
| auto abstract = lite::CreateTensorAbstract(shape_vector, type_id); | |||||
| if (abstract == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor abstarct failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| weight_node->set_abstract(abstract); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||