From: @wangzhe128 Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -114,7 +114,7 @@ endif () | |||||
| file(GLOB PROTO_FILE "" | file(GLOB PROTO_FILE "" | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto | ${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/*.proto | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/proto/*.proto | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto) | ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto) | ||||
| ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE}) | ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE}) | ||||
| add_library(proto_mid OBJECT ${PROTO_SRCS}) | add_library(proto_mid OBJECT ${PROTO_SRCS}) | ||||
| @@ -28,6 +28,7 @@ | |||||
| #include "parser/caffe/caffe_converter.h" | #include "parser/caffe/caffe_converter.h" | ||||
| #include "parser/tflite/tflite_converter.h" | #include "parser/tflite/tflite_converter.h" | ||||
| #include "parser/onnx/onnx_converter.h" | #include "parser/onnx/onnx_converter.h" | ||||
| #include "parser/tf/tf_converter.h" | |||||
| #include "tools/anf_exporter/anf_exporter.h" | #include "tools/anf_exporter/anf_exporter.h" | ||||
| #include "tools/anf_importer/import_from_protobuf.h" | #include "tools/anf_importer/import_from_protobuf.h" | ||||
| #include "proto/onnx.pb.h" | #include "proto/onnx.pb.h" | ||||
| @@ -149,6 +150,10 @@ int RunConverter(int argc, const char **argv) { | |||||
| OnnxConverter onnxConverter; | OnnxConverter onnxConverter; | ||||
| fb_graph = onnxConverter.Convert(flags.get()); | fb_graph = onnxConverter.Convert(flags.get()); | ||||
| } break; | } break; | ||||
| case FmkType::FmkType_TF: { | |||||
| TFConverter tfConverter; | |||||
| fb_graph = tfConverter.Convert(flags.get()); | |||||
| } break; | |||||
| default: { | default: { | ||||
| MS_LOG(ERROR) << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << " " | MS_LOG(ERROR) << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << " " | ||||
| << GetErrorInfo(RET_INPUT_PARAM_INVALID); | << GetErrorInfo(RET_INPUT_PARAM_INVALID); | ||||
| @@ -126,8 +126,10 @@ int Flags::Init(int argc, const char **argv) { | |||||
| this->fmk = FmkType_TFLITE; | this->fmk = FmkType_TFLITE; | ||||
| } else if (this->fmkIn == "ONNX") { | } else if (this->fmkIn == "ONNX") { | ||||
| this->fmk = FmkType_ONNX; | this->fmk = FmkType_ONNX; | ||||
| } else if (this->fmkIn == "TF") { | |||||
| this->fmk = FmkType_TF; | |||||
| } else { | } else { | ||||
| std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MINDIR|ONNX"; | |||||
| std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX"; | |||||
| return RET_INPUT_PARAM_INVALID; | return RET_INPUT_PARAM_INVALID; | ||||
| } | } | ||||
| @@ -44,6 +44,7 @@ class ModelParser { | |||||
| return func_graph; | return func_graph; | ||||
| } | } | ||||
| protected: | |||||
| virtual schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | virtual schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | ||||
| const QuantType &quant_type = QuantType_QUANT_NONE) = 0; | const QuantType &quant_type = QuantType_QUANT_NONE) = 0; | ||||
| @@ -34,10 +34,10 @@ class CaffeModelParser : public ModelParser { | |||||
| virtual ~CaffeModelParser(); | virtual ~CaffeModelParser(); | ||||
| private: | |||||
| schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | ||||
| const QuantType &quant_type = QuantType_QUANT_NONE) override; | const QuantType &quant_type = QuantType_QUANT_NONE) override; | ||||
| private: | |||||
| STATUS SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache); | STATUS SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache); | ||||
| STATUS SetOpOutputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache); | STATUS SetOpOutputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache); | ||||
| @@ -45,12 +45,12 @@ class OnnxModelParser : public ModelParser { | |||||
| int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph, | int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph, | ||||
| const QuantType &quantType); | const QuantType &quantType); | ||||
| schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type = QuantType_QUANT_NONE) override; | |||||
| static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | ||||
| private: | private: | ||||
| schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||||
| const QuantType &quant_type = QuantType_QUANT_NONE) override; | |||||
| std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); | std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); | ||||
| STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph); | STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph); | ||||
| @@ -0,0 +1,68 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_activation_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF ActivationParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::ActivationT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (tf_op.op() == "Relu") { | |||||
| attr->type = schema::ActivationType_RELU; | |||||
| } else if (tf_op.op() == "Relu6") { | |||||
| attr->type = schema::ActivationType_RELU6; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Activation; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| auto status = AddOpInput(tf_op, 0, inputs); | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser()); | |||||
| TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFActivationParser : public TFNodeParser { | |||||
| public: | |||||
| TFActivationParser() = default; | |||||
| ~TFActivationParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_ | |||||
| @@ -0,0 +1,93 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_arithmetic_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF ArithmeticParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (tf_op.op() == "Add") { | |||||
| auto attr = std::make_unique<schema::AddT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Add; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tf_op.op() == "Sub") { | |||||
| auto attr = std::make_unique<schema::SubT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Sub; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tf_op.op() == "Mul") { | |||||
| auto attr = std::make_unique<schema::MulT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Mul; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tf_op.op() == "Div") { | |||||
| auto attr = std::make_unique<schema::DivT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Div; | |||||
| primitive->value.value = attr.release(); | |||||
| } | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| auto status = AddOpInput(tf_op, 0, inputs); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| status = AddOpInput(tf_op, 1, inputs); | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tfAddParser("Add", new TFArithmeticParser()); | |||||
| TFNodeRegistrar g_tfSubParser("Sub", new TFArithmeticParser()); | |||||
| TFNodeRegistrar g_tfMulParser("Mul", new TFArithmeticParser()); | |||||
| TFNodeRegistrar g_tfDivParser("Div", new TFArithmeticParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFArithmeticParser : public TFNodeParser { | |||||
| public: | |||||
| TFArithmeticParser() = default; | |||||
| ~TFArithmeticParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_ | |||||
| @@ -0,0 +1,61 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_biasadd_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFBiasAddParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||||
| std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF BiasAddParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::BiasAddT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| attr->axis = {1}; | |||||
| primitive->value.type = schema::PrimitiveType_Add; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| auto status = AddOpInput(tf_op, 0, inputs); | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tfBiasAddParser("BiasAdd", new TFBiasAddParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASSADD_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASSADD_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFBiasAddParser : public TFNodeParser { | |||||
| public: | |||||
| TFBiasAddParser() = default; | |||||
| ~TFBiasAddParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASSADD_PARSER_H_ | |||||
| @@ -0,0 +1,22 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_converter.h" | |||||
| #include "tools/converter/parser/tf/tf_model_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| TFConverter::TFConverter() { modelParser = new TFModelParser(); } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -13,22 +13,20 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "tools/converter/parser/tf/tf_add_parser.h" | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_ | |||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| #include "tools/converter/converter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TFAddParser::Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model, | |||||
| PrimitiveC *primitiveC, int *output_size) { | |||||
| auto attr = std::make_unique<schema::PrimitiveT>(); | |||||
| attr->value.type = schema::PrimitiveType_Add; | |||||
| primitiveC = PrimitiveC::Create(attr.release()); | |||||
| MS_LOG(INFO) << "primitive name" << primitiveC->type_name(); | |||||
| return RET_OK; | |||||
| } | |||||
| TFNodeRegistrar g_tfAddParser("Add", new TFAddParser()); | |||||
| class TFConverter : public Converter { | |||||
| public: | |||||
| TFConverter(); | |||||
| ~TFConverter() = default; | |||||
| }; | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_ | |||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_matmul_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFMatMulParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||||
| std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF MatMulParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::MatMulT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_a", &attr_value)) { | |||||
| attr->transposeA = attr_value.b(); | |||||
| } | |||||
| if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_b", &attr_value)) { | |||||
| attr->transposeB = attr_value.b(); | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_MatMul; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| auto status = AddOpInput(tf_op, 0, inputs); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| status = AddOpInput(tf_op, 1, inputs); | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tfMatMulParser("MatMul", new TFMatMulParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MATMUL_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MATMUL_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFMatMulParser : public TFNodeParser { | |||||
| public: | |||||
| TFMatMulParser() = default; | |||||
| ~TFMatMulParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MATMUL_PARSER_H_ | |||||
| @@ -16,36 +16,236 @@ | |||||
| */ | */ | ||||
| #include "tools/converter/parser/tf/tf_model_parser.h" | #include "tools/converter/parser/tf/tf_model_parser.h" | ||||
| #include <map> | |||||
| #include <algorithm> | |||||
| #include <functional> | |||||
| #include <set> | |||||
| #include "src/common/utils.h" | |||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "tools/converter/parser/tf/tf_util.h" | |||||
| #include "tools/common/graph_util.h" | #include "tools/common/graph_util.h" | ||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | #include "tools/converter/parser/tf/tf_node_parser_registry.h" | ||||
| #include "src/param_value_lite.h" | #include "src/param_value_lite.h" | ||||
| #include "tools/common/protobuf_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| static const std::unordered_map<int, mindspore::TypeId> TF_TYPE_MAP = { | |||||
| {tensorflow::DT_INT8, mindspore::kNumberTypeInt8}, {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8}, | |||||
| {tensorflow::DT_INT16, mindspore::kNumberTypeInt16}, {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16}, | |||||
| {tensorflow::DT_INT32, mindspore::kNumberTypeInt32}, {tensorflow::DT_INT64, mindspore::kNumberTypeInt64}, | |||||
| {tensorflow::DT_HALF, mindspore::kNumberTypeFloat16}, {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32}, | |||||
| {tensorflow::DT_DOUBLE, mindspore::kNumberTypeFloat64}, {tensorflow::DT_COMPLEX64, mindspore::kNumberTypeComplex64}, | |||||
| {tensorflow::DT_BOOL, mindspore::kNumberTypeBool}}; | |||||
| TypeId GetTFDataType(const tensorflow::DataType &tf_data_type) { | |||||
| auto iter = TF_TYPE_MAP.find(tf_data_type); | |||||
| if (iter == TF_TYPE_MAP.end()) { | |||||
| MS_LOG(ERROR) << "unsupported TF data type: " << tf_data_type; | |||||
| return kTypeUnknown; | |||||
| } | |||||
| return iter->second; | |||||
| } | |||||
| AnfNodePtr TFModelParser::GetAnfNode(const std::string &name) { | |||||
| AnfNodePtr ret = nullptr; | |||||
| if (anf_node_map.find(name) != anf_node_map.end()) { | |||||
| ret = anf_node_map[name]; | |||||
| } else if (anf_node_map.find(name + ":0") != anf_node_map.end()) { | |||||
| ret = anf_node_map[name + ":0"]; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| std::string TFModelParser::GetOriginInputName(const tensorflow::NodeDef &node) { | |||||
| if (node.op() != "Identity" && node.op() != "StopGradient") { | |||||
| return node.name(); | |||||
| } | |||||
| auto tmp_node = &node; | |||||
| while (tmp_node->op() == "Identity" || tmp_node->op() == "StopGradient") { | |||||
| tmp_node = tf_node_map[tmp_node->input(0)]; | |||||
| } | |||||
| return tmp_node->name(); | |||||
| } | |||||
| STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, | |||||
| const ParameterPtr ¶meter, std::vector<int64_t> *shape_vector) { | |||||
| MS_ASSERT(parameter != nullptr); | |||||
| MS_ASSERT(shape_vector != nullptr); | |||||
| const tensorflow::TensorProto &tensor_proto = attr_value.tensor(); | |||||
| const tensorflow::TensorShapeProto &tensor_shape = tensor_proto.tensor_shape(); | |||||
| int shape_size = 1; | |||||
| shape_vector->clear(); | |||||
| for (int i = 0; i < tensor_shape.dim_size(); i++) { | |||||
| shape_vector->push_back(tensor_shape.dim(i).size()); | |||||
| shape_size *= tensor_shape.dim(i).size(); | |||||
| } | |||||
| int tensor_size; | |||||
| auto param_value = std::make_shared<ParamValueLite>(); | |||||
| if (param_value == nullptr) { | |||||
| MS_LOG(ERROR) << "param_value is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (type == kNumberTypeFloat32 || type == kNumberTypeFloat) { | |||||
| auto tensor_data = new (std::nothrow) float[shape_size]; | |||||
| if (tensor_proto.float_val_size() == 1) { | |||||
| float value = tensor_proto.float_val(0); | |||||
| for (int i = 0; i < shape_size; i++) { | |||||
| tensor_data[i] = value; | |||||
| } | |||||
| } | |||||
| if (tensor_proto.tensor_content().size() == shape_size * sizeof(float)) { | |||||
| const auto addr = reinterpret_cast<const float *>(tensor_proto.tensor_content().data()); | |||||
| auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(float), addr, shape_size * sizeof(float)); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| param_value->set_tensor_addr(tensor_data); | |||||
| tensor_size = shape_size * sizeof(float); | |||||
| } else if (type == kNumberTypeInt32) { | |||||
| auto tensor_data = new (std::nothrow) int[shape_size]; | |||||
| if (tensor_proto.int_val_size() == 1) { | |||||
| int value = tensor_proto.int_val(0); | |||||
| for (int i = 0; i < shape_size; i++) { | |||||
| tensor_data[i] = value; | |||||
| } | |||||
| } | |||||
| if (tensor_proto.tensor_content().size() == shape_size * sizeof(int32_t)) { | |||||
| const auto addr = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data()); | |||||
| auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(int32_t), addr, shape_size * sizeof(int32_t)); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| param_value->set_tensor_addr(tensor_data); | |||||
| tensor_size = shape_size * sizeof(int); | |||||
| } else if (type == kNumberTypeBool) { | |||||
| auto tensor_data = new (std::nothrow) int[shape_size]; | |||||
| if (tensor_proto.bool_val_size() == 1) { | |||||
| int value = tensor_proto.bool_val(0); | |||||
| for (int i = 0; i < shape_size; i++) { | |||||
| tensor_data[i] = value; | |||||
| } | |||||
| } | |||||
| param_value->set_tensor_addr(tensor_data); | |||||
| tensor_size = shape_size * sizeof(int); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupport dataType: " << type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| std::vector<int> param_shape(shape_vector->begin(), shape_vector->end()); | |||||
| param_value->set_tensor_shape(param_shape); | |||||
| param_value->set_tensor_type(type); | |||||
| param_value->set_tensor_size(tensor_size); | |||||
| param_value->set_format(schema::Format::Format_NHWC); | |||||
| parameter->set_default_param(param_value); | |||||
| parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter"); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter) { | |||||
| MS_ASSERT(node != nullptr); | |||||
| MS_ASSERT(parameter != nullptr); | |||||
| tensorflow::AttrValue attr_value; | |||||
| TypeId type = kNumberTypeFloat32; | |||||
| if (TensorFlowUtils::FindAttrValue(node, "dtype", &attr_value)) { | |||||
| type = GetTFDataType(attr_value.type()); | |||||
| } | |||||
| auto type_ptr = TypeIdToType(type); | |||||
| std::vector<int> shape; | |||||
| if (TensorFlowUtils::FindAttrValue(node, "shape", &attr_value)) { | |||||
| auto &shape_attr = attr_value.shape(); | |||||
| for (int i = 0; i < shape_attr.dim_size(); ++i) { | |||||
| shape.push_back(shape_attr.dim(i).size()); | |||||
| } | |||||
| } | |||||
| std::vector<int64_t> shape_vector(shape.begin(), shape.end()); | |||||
| if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) { | |||||
| MS_LOG(INFO) << "Found value attr, means it has default value"; | |||||
| auto status = ConvertConstTensor(attr_value, type, parameter, &shape_vector); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| } else { | |||||
| parameter->set_name("placeholder_" + std::to_string(anf_node_map.size())); | |||||
| graph_input_names.emplace_back(parameter->name()); | |||||
| } | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||||
| if (abstract_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "abstract_tensor is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| parameter->set_abstract(abstract_tensor); | |||||
| anf_node_map[node.name()] = parameter; | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TFModelParser::ConvertGraphInputsAndConsts() { | |||||
| for (auto &pair : tf_node_map) { | |||||
| bool have_data_depend = false; | |||||
| for (int i = 0; i < pair.second->input_size(); ++i) { | |||||
| auto name = pair.second->input(i); | |||||
| if (!name.empty() && name[0] != '^') { // control_depend input start with "^" | |||||
| have_data_depend = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (!have_data_depend) { | |||||
| auto parameter = funcGraphPtr->add_parameter(); | |||||
| if (ConvertParameter(*pair.second, parameter) != RET_OK) { | |||||
| MS_LOG(ERROR) << "convert Parameter Node failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, | FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, | ||||
| const QuantType &quantType) { | const QuantType &quantType) { | ||||
| auto status = ValidateFileStr(modelFile, ".prototxt"); | |||||
| auto status = ValidateFileStr(modelFile, ".pb"); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; | |||||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (!TensorFlowUtils::TfReadProtoFromBinary(modelFile.c_str(), tf_graph_def.get())) { | |||||
| tf_graph_def = std::make_unique<tensorflow::GraphDef>(); | |||||
| if (tf_graph_def == nullptr) { | |||||
| MS_LOG(ERROR) << "tf_graph_def is nullptr"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||||
| return nullptr; | |||||
| } | |||||
| status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_graph_def.get()); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; | MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| funcGraphPtr = std::make_shared<FuncGraph>(); | funcGraphPtr = std::make_shared<FuncGraph>(); | ||||
| status = ConvertGraphInputs(); | |||||
| if (funcGraphPtr == nullptr) { | |||||
| MS_LOG(ERROR) << "funGraphPtr is nullptr"; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||||
| return nullptr; | |||||
| } | |||||
| for (int i = 0; i < tf_graph_def->node_size(); i++) { | |||||
| auto &node_def = tf_graph_def->node(i); | |||||
| tf_node_map[node_def.name()] = &node_def; | |||||
| } | |||||
| status = ConvertGraphInputsAndConsts(); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Convert graph inputs failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| status = ConvertOps(); | status = ConvertOps(); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "Convert ops failed."; | MS_LOG(ERROR) << "Convert ops failed."; | ||||
| @@ -61,103 +261,36 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin | |||||
| } | } | ||||
| return funcGraphPtr; | return funcGraphPtr; | ||||
| } | } | ||||
| STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef *node, ParameterPtr parameter) { | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) { | |||||
| tensorflow::AttrValue data_type; | |||||
| tensorflow::DataType type = tensorflow::DT_FLOAT; | |||||
| // datatype | |||||
| if (TensorFlowUtils::FindAttrValue(node, "dtype", &data_type)) { | |||||
| type = data_type.type(); | |||||
| } | |||||
| const tensorflow::TensorProto &tensorProto = attr_value.tensor(); | |||||
| const tensorflow::TensorShapeProto &tensorShape = tensorProto.tensor_shape(); | |||||
| parameter = funcGraphPtr->add_parameter(); | |||||
| std::vector<int64_t> shape_vector; | |||||
| int shape_size = 1; | |||||
| shape_vector.resize(tensorShape.dim_size()); | |||||
| for (int i = 0; i < tensorShape.dim_size(); i++) { | |||||
| shape_vector[i] = tensorShape.dim(i).size(); | |||||
| shape_size *= shape_vector[i]; | |||||
| schema::MetaGraphT *TFModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||||
| const QuantType &quantType) { | |||||
| MS_LOG(ERROR) << "TF Model Parser not return MetaGraph, use TFModelParser::Parse instead"; | |||||
| return nullptr; | |||||
| } | |||||
| STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def, | |||||
| const std::vector<std::string> &input_names, std::vector<AnfNodePtr> *inputs) { | |||||
| // parse inputs | |||||
| for (size_t j = 0; j < input_names.size(); j++) { | |||||
| std::string input_name = input_names[j]; // input may be produced by multi-outputs node | |||||
| if (tf_node_map.find(input_name) != tf_node_map.end()) { | |||||
| auto input_node = tf_node_map[input_name]; | |||||
| input_name = GetOriginInputName(*input_node); | |||||
| } | } | ||||
| // convert const to paramter | |||||
| TypePtr ms_data_ype; | |||||
| auto paramValue = std::make_shared<ParamValueLite>(); | |||||
| if (type == tensorflow::DT_FLOAT) { | |||||
| ms_data_ype = kFloat32; | |||||
| auto tensor_data = new (std::nothrow) float[shape_size]; | |||||
| if (tensorProto.float_val_size() == 1) { | |||||
| float value = tensorProto.float_val(0); | |||||
| for (int i = 0; i < shape_size; i++) { | |||||
| tensor_data[i] = value; | |||||
| } | |||||
| } | |||||
| if (tensorProto.tensor_content().size() == shape_size * sizeof(float)) { | |||||
| const auto addr = reinterpret_cast<const float *>(tensorProto.tensor_content().data()); | |||||
| auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(float), addr, shape_size * sizeof(float)); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| paramValue->set_tensor_addr(tensor_data); | |||||
| paramValue->set_tensor_size(shape_size * sizeof(float)); | |||||
| } else if (type == tensorflow::DT_INT32) { | |||||
| ms_data_ype = kInt32; | |||||
| auto tensor_data = new (std::nothrow) int[shape_size]; | |||||
| if (tensorProto.int_val_size() == 1) { | |||||
| int value = tensorProto.int_val(0); | |||||
| for (int i = 0; i < shape_size; i++) { | |||||
| tensor_data[i] = value; | |||||
| } | |||||
| } | |||||
| if (tensorProto.tensor_content().size() == shape_size * sizeof(int32_t)) { | |||||
| const auto addr = reinterpret_cast<const int32_t *>(tensorProto.tensor_content().data()); | |||||
| auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(int32_t), addr, shape_size * sizeof(int32_t)); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| paramValue->set_tensor_addr(tensor_data); | |||||
| paramValue->set_tensor_size(shape_size * sizeof(int)); | |||||
| } else if (type == tensorflow::DT_BOOL) { | |||||
| ms_data_ype = kFloat32; | |||||
| auto tensor_data = new (std::nothrow) int[shape_size]; | |||||
| if (tensorProto.bool_val_size() == 1) { | |||||
| int value = tensorProto.bool_val(0); | |||||
| for (int i = 0; i < shape_size; i++) { | |||||
| tensor_data[i] = value; | |||||
| } | |||||
| } | |||||
| paramValue->set_tensor_addr(tensor_data); | |||||
| paramValue->set_tensor_size(shape_size * sizeof(int)); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupport dataType," << node->name(); | |||||
| auto input = GetAnfNode(input_name); | |||||
| if (input == nullptr) { | |||||
| MS_LOG(ERROR) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(ms_data_ype, shape_vector); | |||||
| parameter->set_abstract(abstract_tensor); | |||||
| parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter"); | |||||
| std::vector<int> param_shape; | |||||
| (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(param_shape), | |||||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||||
| MS_ASSERT(paramValue != nullptr); | |||||
| paramValue->set_tensor_shape(param_shape); | |||||
| paramValue->set_tensor_type(ms_data_ype->type_id()); | |||||
| paramValue->set_format(schema::Format::Format_NHWC); | |||||
| paramValue->set_tensor_size(shape_size * sizeof(int)); | |||||
| parameter->set_default_param(paramValue); | |||||
| inputs->emplace_back(input); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size) { | |||||
| STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, int output_size) { | |||||
| if (output_size == 1) { | if (output_size == 1) { | ||||
| std::vector<int64_t> shape_vector; | std::vector<int64_t> shape_vector; | ||||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); | anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); | ||||
| anf_node_map.insert(std::pair(op->name(), anf_node)); | |||||
| anf_node_map.insert(std::pair(op.name(), anf_node)); | |||||
| } else { | } else { | ||||
| AbstractBasePtrList abstractList; | AbstractBasePtrList abstractList; | ||||
| for (int output_idx = 0; output_idx < output_size; output_idx++) { | for (int output_idx = 0; output_idx < output_size; output_idx++) { | ||||
| @@ -174,113 +307,126 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef *op, const C | |||||
| CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs); | CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs); | ||||
| std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); | std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); | ||||
| getItemCNode->set_fullname_with_scope(output_item_name); | getItemCNode->set_fullname_with_scope(output_item_name); | ||||
| anf_node_map.insert(std::pair(output_item_name, getItemCNode)); | |||||
| anf_node_map.insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode)); | |||||
| } | } | ||||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList)); | anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList)); | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS TFModelParser::ConvertOps() { | STATUS TFModelParser::ConvertOps() { | ||||
| NoSupportOp::GetInstance()->SetFmkType("TENSORFLOW"); | |||||
| NoSupportOp::GetInstance()->SetFmkType("TF"); | |||||
| STATUS status = RET_OK; | STATUS status = RET_OK; | ||||
| // redirect identity to it's input0 | |||||
| ClipIdentityAndStopGradient(); | |||||
| int op_idx = 0; | int op_idx = 0; | ||||
| for (int i = 0; i < tf_graph_def->node_size(); i++) { | for (int i = 0; i < tf_graph_def->node_size(); i++) { | ||||
| auto node_def = tf_graph_def->mutable_node(i); | |||||
| tf_node_map[node_def->name()] = node_def; | |||||
| auto tf_op_type = node_def->op(); | |||||
| if (tf_op_type == "Placeholder" || tf_op_type == "Const") { | |||||
| auto &node_def = tf_graph_def->node(i); | |||||
| const auto &op_type = node_def.op(); | |||||
| if (op_type == "Placeholder" || op_type == "Const" || op_type == "Identity" || op_type == "StopGradient") { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(tf_op_type); | |||||
| auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type); | |||||
| if (node_parser == nullptr) { | if (node_parser == nullptr) { | ||||
| NoSupportOp::GetInstance()->InsertOp(tf_op_type); | |||||
| NoSupportOp::GetInstance()->InsertOp(op_type); | |||||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | status = (status == RET_OK ? RET_NOT_FIND_OP : status); | ||||
| MS_LOG(ERROR) << "cannot find node parser:" << tf_op_type; | |||||
| MS_LOG(ERROR) << "cannot find node parser:" << op_type; | |||||
| continue; | |||||
| } | |||||
| if (status != RET_OK) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| PrimitiveC *primitiveC = nullptr; | PrimitiveC *primitiveC = nullptr; | ||||
| if (status == RET_OK) { | |||||
| int output_size = 1; | |||||
| status = node_parser->Parse(node_def, tf_graph_def, primitiveC, &output_size); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "node " << tf_op_type.c_str() << " parser failed"; | |||||
| continue; | |||||
| } | |||||
| std::vector<AnfNodePtr> opInputs = {NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC))}; | |||||
| // parse inputs | |||||
| for (int j = 0; j < node_def->input_size(); j++) { | |||||
| auto input_node = tf_node_map[node_def->input(i)]; | |||||
| // last node output | |||||
| if (anf_node_map.find(input_node->name()) != anf_node_map.end()) { | |||||
| opInputs.emplace_back(anf_node_map[input_node->name()]); | |||||
| continue; | |||||
| } | |||||
| // const tensor | |||||
| if (input_node->op() == "Const") { | |||||
| ParameterPtr parameter; | |||||
| if (ConvertConstTensor(input_node, parameter) != RET_OK) { | |||||
| MS_LOG(ERROR) << "convert const tensor failed," << input_node->name(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| opInputs.emplace_back(parameter); | |||||
| anf_node_map[parameter->fullname_with_scope()] = parameter; | |||||
| continue; | |||||
| } | |||||
| MS_LOG(ERROR) << "node" << node_def->name() << "has inputs neither a node output nor a weight tensor."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto anf_node = funcGraphPtr->NewCNode(opInputs); | |||||
| anf_node->set_fullname_with_scope(tf_op_type + "-" + std::to_string(op_idx++)); | |||||
| int output_size; | |||||
| std::vector<std::string> input_names; | |||||
| status = node_parser->Parse(node_def, tf_node_map, &primitiveC, &input_names, &output_size); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "node " << op_type << " parser failed"; | |||||
| continue; | |||||
| } | |||||
| // parse outputs | |||||
| status = ConvertOutputTensor(node_def, anf_node, output_size); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| return status; | |||||
| } | |||||
| auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC)); | |||||
| if (value_node == nullptr) { | |||||
| MS_LOG(ERROR) << "value_node is nullptr"; | |||||
| status = RET_ERROR; | |||||
| continue; | |||||
| } | |||||
| std::vector<AnfNodePtr> inputs = {value_node}; | |||||
| status = ConvertInputNodes(node_def, input_names, &inputs); | |||||
| if (status != RET_OK) { | |||||
| continue; | |||||
| } | |||||
| // control_depends are not processed currently | |||||
| auto anf_node = funcGraphPtr->NewCNode(inputs); | |||||
| anf_node->set_fullname_with_scope(op_type + "-" + std::to_string(op_idx++)); | |||||
| status = ConvertOutputTensor(node_def, anf_node, output_size); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||||
| continue; | |||||
| } | } | ||||
| // redirect identity to it's input0 | |||||
| ClipIdentityAndStopGradient(); | |||||
| } | } | ||||
| return RET_OK; | |||||
| return status; | |||||
| } | } | ||||
| STATUS TFModelParser::ConvertGraphInputs() { | |||||
| for (int i = 0; i < tf_graph_def->node_size(); i++) { | |||||
| auto node_def = tf_graph_def->mutable_node(i); | |||||
| tf_node_map[node_def->name()] = node_def; | |||||
| if (node_def->op() == "Placeholder") { | |||||
| auto parameter = funcGraphPtr->add_parameter(); | |||||
| if (ConvertConstTensor(node_def, parameter) != RET_OK) { | |||||
| MS_LOG(ERROR) << "convert const tensor failed"; | |||||
| STATUS TFModelParser::ConvertGraphOutputs() { | |||||
| // because output of intermediate node in anf graph may also be output tensors, we search output tensors in | |||||
| // tf_node_map but not anf_node_map | |||||
| std::set<std::string> all_node_inputs; | |||||
| std::vector<AnfNodePtr> output_nodes; | |||||
| for (auto &pair : tf_node_map) { | |||||
| for (int i = 0; i < pair.second->input_size(); ++i) { | |||||
| all_node_inputs.insert(pair.second->input(i)); | |||||
| } | |||||
| } | |||||
| for (auto &pair : tf_node_map) { | |||||
| auto it = all_node_inputs.find(pair.first); | |||||
| if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity | |||||
| auto origin_name = GetOriginInputName(*(pair.second)); | |||||
| auto anf_node = GetAnfNode(origin_name); | |||||
| if (anf_node == nullptr) { | |||||
| MS_LOG(ERROR) << "can't find anf node"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| anf_node_map[node_def->name()] = parameter; | |||||
| graph_input_names.emplace_back(node_def->name()); | |||||
| output_nodes.push_back(anf_node); | |||||
| graph_output_names.push_back(anf_node->fullname_with_scope()); | |||||
| } | } | ||||
| } | } | ||||
| return RET_OK; | |||||
| } | |||||
| STATUS TFModelParser::ConvertGraphOutputs() { return RET_OK; } | |||||
| std::string TFModelParser::GetOriginInputName(const tensorflow::NodeDef &node) { | |||||
| if (node.op() != "Identity" && node.op() != "StopGradient") { | |||||
| return node.name(); | |||||
| } | |||||
| auto tmpNode = node; | |||||
| while (tmpNode.op() == "Identity" || tmpNode.op() == "StopGradient") { | |||||
| tmpNode = *tf_node_map[tmpNode.input(0)]; | |||||
| } | |||||
| return tmpNode.name(); | |||||
| } | |||||
| if (output_nodes.size() > 1) { | |||||
| std::vector<AnfNodePtr> &make_tuple_inputs = output_nodes; | |||||
| auto make_tuple_prim_ptr = GetMakeTuplePrim(); | |||||
| if (make_tuple_prim_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); | |||||
| make_tuple_inputs.insert(output_nodes.begin(), make_tuple_prim); | |||||
| auto make_tuple_cnode = funcGraphPtr->NewCNode(make_tuple_inputs); | |||||
| make_tuple_cnode->set_fullname_with_scope("return tuple"); | |||||
| void TFModelParser::ClipIdentityAndStopGradient() { | |||||
| for (auto &pair : tf_node_map) { | |||||
| pair.second = tf_node_map[GetOriginInputName(*pair.second)]; | |||||
| auto return_prim_ptr = GetReturnPrim(); | |||||
| if (return_prim_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto value_node = NewValueNode(return_prim_ptr); | |||||
| std::vector<AnfNodePtr> op_inputs = {value_node, make_tuple_cnode}; | |||||
| auto cnode = funcGraphPtr->NewCNode(op_inputs); | |||||
| cnode->set_fullname_with_scope("return"); | |||||
| funcGraphPtr->set_return(cnode); | |||||
| } else { | |||||
| auto return_prim_ptr = GetReturnPrim(); | |||||
| if (return_prim_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto value_node = NewValueNode(return_prim_ptr); | |||||
| std::vector<AnfNodePtr> op_inputs{value_node, output_nodes.front()}; | |||||
| auto return_cnode = funcGraphPtr->NewCNode(op_inputs); | |||||
| return_cnode->set_fullname_with_scope("return"); | |||||
| funcGraphPtr->set_return(return_cnode); | |||||
| } | } | ||||
| return RET_OK; | |||||
| } | } | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -31,29 +31,36 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| class TFModelParser { | |||||
| class TFModelParser : public ModelParser { | |||||
| public: | public: | ||||
| TFModelParser() = default; | TFModelParser() = default; | ||||
| ~TFModelParser() = default; | ~TFModelParser() = default; | ||||
| FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); | FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); | ||||
| protected: | |||||
| schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||||
| const QuantType &quantType = QuantType_QUANT_NONE) override; | |||||
| private: | private: | ||||
| STATUS ConvertConstTensor(const tensorflow::NodeDef *op, ParameterPtr parameter); | |||||
| STATUS ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size); | |||||
| AnfNodePtr GetAnfNode(const std::string &name); | |||||
| std::string GetOriginInputName(const tensorflow::NodeDef &node); | |||||
| STATUS ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, const ParameterPtr ¶meter, | |||||
| std::vector<int64_t> *shape_vector); | |||||
| STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter); | |||||
| STATUS ConvertGraphInputsAndConsts(); | |||||
| STATUS ConvertInputNodes(const tensorflow::NodeDef &node_def, const std::vector<std::string> &input_names, | |||||
| std::vector<AnfNodePtr> *inputs); | |||||
| STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, int output_size); | |||||
| STATUS ConvertOps(); | STATUS ConvertOps(); | ||||
| STATUS ConvertGraphInputs(); | |||||
| STATUS ConvertGraphOutputs(); | STATUS ConvertGraphOutputs(); | ||||
| std::string GetOriginInputName(const tensorflow::NodeDef &node); | |||||
| void ClipIdentityAndStopGradient(); | |||||
| FuncGraphPtr funcGraphPtr; | FuncGraphPtr funcGraphPtr; | ||||
| std::unique_ptr<tensorflow::GraphDef> tf_graph_def; | std::unique_ptr<tensorflow::GraphDef> tf_graph_def; | ||||
| std::map<std::string, const tensorflow::NodeDef *> tf_node_map; | std::map<std::string, const tensorflow::NodeDef *> tf_node_map; | ||||
| std::unordered_map<std::string, AnfNodePtr> anf_node_map; | std::unordered_map<std::string, AnfNodePtr> anf_node_map; | ||||
| std::vector<std::string> graph_input_names, graphOutputNames; | |||||
| std::vector<std::string> graph_input_names; | |||||
| std::vector<std::string> graph_output_names; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -13,23 +13,21 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H | |||||
| #include <memory> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | #include "tools/converter/parser/tf/tf_node_parser.h" | ||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| class TFAddParser : public TFNodeParser { | |||||
| public: | |||||
| TFAddParser() = default; | |||||
| ~TFAddParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model, | |||||
| PrimitiveC *primitiveC, int *output_size) override; | |||||
| }; | |||||
| STATUS TFNodeParser::AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, std::vector<std::string> *inputs) { | |||||
| if (tf_op.input_size() <= idx) { | |||||
| MS_LOG(ERROR) << "input idx is greater than op input size"; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| inputs->push_back(tf_op.input(idx)); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H | |||||
| @@ -18,6 +18,7 @@ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H | #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | |||||
| #include <map> | #include <map> | ||||
| #include <memory> | #include <memory> | ||||
| #include "tools/converter/parser/tf/tf_util.h" | #include "tools/converter/parser/tf/tf_util.h" | ||||
| @@ -32,12 +33,14 @@ class TFNodeParser { | |||||
| virtual ~TFNodeParser() = default; | virtual ~TFNodeParser() = default; | ||||
| virtual STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model, | |||||
| PrimitiveC *primitiveC, int *output_size) { | |||||
| virtual STATUS Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||||
| std::vector<std::string> *inputs, int *output_size) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, std::vector<std::string> *inputs); | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H | #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H | ||||
| @@ -0,0 +1,109 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_split_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFSplitParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||||
| std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF SplitParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::SplitT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "num_split", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The attribute num_split should be specified"; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| attr->numberSplit = (int32_t)(attr_value.i()); | |||||
| int split_dim_index; | |||||
| int input_index; | |||||
| if (tf_op.op() == "Split") { | |||||
| split_dim_index = 0; | |||||
| input_index = 1; | |||||
| } else { | |||||
| split_dim_index = 2; | |||||
| input_index = 0; | |||||
| } | |||||
| if (tf_node_map.find(tf_op.input(split_dim_index)) == tf_node_map.end()) { | |||||
| MS_LOG(ERROR) << "Find Split input split_dim node failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| const auto &split_dim_node = tf_node_map.at(tf_op.input(split_dim_index)); | |||||
| if (!TensorFlowUtils::FindAttrValue(*split_dim_node, "value", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The attribute splitDim should be specified"; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| auto split_dim_tensor = attr_value.tensor(); | |||||
| attr->splitDim = split_dim_tensor.int_val(0); | |||||
| *output_size = attr->numberSplit; | |||||
| if (tf_op.op() == "SplitV") { | |||||
| if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { | |||||
| MS_LOG(ERROR) << "Find Split input size_splits failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto size_splits_node = tf_node_map.at(tf_op.input(1)); | |||||
| if (!TensorFlowUtils::FindAttrValue(*size_splits_node, "value", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The attribute size splits should be specified"; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| auto size_splits_tensor = attr_value.tensor(); | |||||
| auto size = size_splits_tensor.tensor_content().size() / sizeof(int32_t); | |||||
| attr->sizeSplits.resize(size); | |||||
| auto ret = memcpy_s(attr->sizeSplits.data(), size * sizeof(int32_t), size_splits_tensor.tensor_content().data(), | |||||
| size * sizeof(int32_t)); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Split; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto status = AddOpInput(tf_op, input_index, inputs); | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tfSplitParser("Split", new TFSplitParser()); | |||||
| TFNodeRegistrar g_tfSplitVParser("SplitV", new TFSplitParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPLIT_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPLIT_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFSplitParser : public TFNodeParser { | |||||
| public: | |||||
| TFSplitParser() = default; | |||||
| ~TFSplitParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPLIT_PARSER_H_ | |||||
| @@ -22,9 +22,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name, | |||||
| bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef &nodeDef, const std::string &attr_name, | |||||
| tensorflow::AttrValue *attr_value) { | tensorflow::AttrValue *attr_value) { | ||||
| const google::protobuf::Map<std::string, tensorflow::AttrValue> &attr = nodeDef->attr(); | |||||
| const google::protobuf::Map<std::string, tensorflow::AttrValue> &attr = nodeDef.attr(); | |||||
| const google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr.find(attr_name); | const google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr.find(attr_name); | ||||
| if (it != attr.end()) { | if (it != attr.end()) { | ||||
| *attr_value = it->second; | *attr_value = it->second; | ||||
| @@ -32,24 +32,5 @@ bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef *nodeDef, std::str | |||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| bool TensorFlowUtils::TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message) { | |||||
| std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary); | |||||
| if (!fs.is_open()) { | |||||
| fprintf(stderr, "open failed %s\n", filepath); | |||||
| return false; | |||||
| } | |||||
| google::protobuf::io::IstreamInputStream input(&fs); | |||||
| google::protobuf::io::CodedInputStream codedstr(&input); | |||||
| codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2); | |||||
| bool success = message->ParseFromCodedStream(&codedstr); | |||||
| fs.close(); | |||||
| return success; | |||||
| } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -26,10 +26,8 @@ namespace mindspore { | |||||
| namespace lite { | namespace lite { | ||||
| class TensorFlowUtils { | class TensorFlowUtils { | ||||
| public: | public: | ||||
| static bool FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name, | |||||
| static bool FindAttrValue(const tensorflow::NodeDef &nodeDef, const std::string &attr_name, | |||||
| tensorflow::AttrValue *attr_value); | tensorflow::AttrValue *attr_value); | ||||
| static bool TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message); | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||