From: @wangzhe128 Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -114,7 +114,7 @@ endif () | |||
| file(GLOB PROTO_FILE "" | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/*.proto | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/proto/*.proto | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto) | |||
| ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE}) | |||
| add_library(proto_mid OBJECT ${PROTO_SRCS}) | |||
| @@ -28,6 +28,7 @@ | |||
| #include "parser/caffe/caffe_converter.h" | |||
| #include "parser/tflite/tflite_converter.h" | |||
| #include "parser/onnx/onnx_converter.h" | |||
| #include "parser/tf/tf_converter.h" | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "tools/anf_importer/import_from_protobuf.h" | |||
| #include "proto/onnx.pb.h" | |||
| @@ -149,6 +150,10 @@ int RunConverter(int argc, const char **argv) { | |||
| OnnxConverter onnxConverter; | |||
| fb_graph = onnxConverter.Convert(flags.get()); | |||
| } break; | |||
| case FmkType::FmkType_TF: { | |||
| TFConverter tfConverter; | |||
| fb_graph = tfConverter.Convert(flags.get()); | |||
| } break; | |||
| default: { | |||
| MS_LOG(ERROR) << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << " " | |||
| << GetErrorInfo(RET_INPUT_PARAM_INVALID); | |||
| @@ -126,8 +126,10 @@ int Flags::Init(int argc, const char **argv) { | |||
| this->fmk = FmkType_TFLITE; | |||
| } else if (this->fmkIn == "ONNX") { | |||
| this->fmk = FmkType_ONNX; | |||
| } else if (this->fmkIn == "TF") { | |||
| this->fmk = FmkType_TF; | |||
| } else { | |||
| std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MINDIR|ONNX"; | |||
| std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX"; | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| @@ -44,6 +44,7 @@ class ModelParser { | |||
| return func_graph; | |||
| } | |||
| protected: | |||
| virtual schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type = QuantType_QUANT_NONE) = 0; | |||
| @@ -34,10 +34,10 @@ class CaffeModelParser : public ModelParser { | |||
| virtual ~CaffeModelParser(); | |||
| private: | |||
| schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type = QuantType_QUANT_NONE) override; | |||
| private: | |||
| STATUS SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache); | |||
| STATUS SetOpOutputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache); | |||
| @@ -45,12 +45,12 @@ class OnnxModelParser : public ModelParser { | |||
| int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph, | |||
| const QuantType &quantType); | |||
| schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type = QuantType_QUANT_NONE) override; | |||
| static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | |||
| private: | |||
| schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file, | |||
| const QuantType &quant_type = QuantType_QUANT_NONE) override; | |||
| std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); | |||
| STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph); | |||
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_activation_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF ActivationParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "primitive is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::ActivationT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (tf_op.op() == "Relu") { | |||
| attr->type = schema::ActivationType_RELU; | |||
| } else if (tf_op.op() == "Relu6") { | |||
| attr->type = schema::ActivationType_RELU6; | |||
| } else { | |||
| MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Activation; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| return status; | |||
| } | |||
| TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser()); | |||
| TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFActivationParser : public TFNodeParser { | |||
| public: | |||
| TFActivationParser() = default; | |||
| ~TFActivationParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_ | |||
| @@ -0,0 +1,93 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_arithmetic_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF ArithmeticParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (tf_op.op() == "Add") { | |||
| auto attr = std::make_unique<schema::AddT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Add; | |||
| primitive->value.value = attr.release(); | |||
| } else if (tf_op.op() == "Sub") { | |||
| auto attr = std::make_unique<schema::SubT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Sub; | |||
| primitive->value.value = attr.release(); | |||
| } else if (tf_op.op() == "Mul") { | |||
| auto attr = std::make_unique<schema::MulT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Mul; | |||
| primitive->value.value = attr.release(); | |||
| } else if (tf_op.op() == "Div") { | |||
| auto attr = std::make_unique<schema::DivT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Div; | |||
| primitive->value.value = attr.release(); | |||
| } | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| status = AddOpInput(tf_op, 1, inputs); | |||
| return status; | |||
| } | |||
| TFNodeRegistrar g_tfAddParser("Add", new TFArithmeticParser()); | |||
| TFNodeRegistrar g_tfSubParser("Sub", new TFArithmeticParser()); | |||
| TFNodeRegistrar g_tfMulParser("Mul", new TFArithmeticParser()); | |||
| TFNodeRegistrar g_tfDivParser("Div", new TFArithmeticParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFArithmeticParser : public TFNodeParser { | |||
| public: | |||
| TFArithmeticParser() = default; | |||
| ~TFArithmeticParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_ | |||
| @@ -0,0 +1,61 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_biasadd_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFBiasAddParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF BiasAddParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::BiasAddT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->axis = {1}; | |||
| primitive->value.type = schema::PrimitiveType_Add; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| return status; | |||
| } | |||
| TFNodeRegistrar g_tfBiasAddParser("BiasAdd", new TFBiasAddParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASSADD_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASSADD_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFBiasAddParser : public TFNodeParser { | |||
| public: | |||
| TFBiasAddParser() = default; | |||
| ~TFBiasAddParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASSADD_PARSER_H_ | |||
| @@ -0,0 +1,22 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_converter.h" | |||
| #include "tools/converter/parser/tf/tf_model_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| TFConverter::TFConverter() { modelParser = new TFModelParser(); } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -13,22 +13,20 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_add_parser.h" | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/converter/converter.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFAddParser::Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model, | |||
| PrimitiveC *primitiveC, int *output_size) { | |||
| auto attr = std::make_unique<schema::PrimitiveT>(); | |||
| attr->value.type = schema::PrimitiveType_Add; | |||
| primitiveC = PrimitiveC::Create(attr.release()); | |||
| MS_LOG(INFO) << "primitive name" << primitiveC->type_name(); | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfAddParser("Add", new TFAddParser()); | |||
| class TFConverter : public Converter { | |||
| public: | |||
| TFConverter(); | |||
| ~TFConverter() = default; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_ | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_matmul_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFMatMulParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF MatMulParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "primitive is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::MatMulT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_a", &attr_value)) { | |||
| attr->transposeA = attr_value.b(); | |||
| } | |||
| if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_b", &attr_value)) { | |||
| attr->transposeB = attr_value.b(); | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_MatMul; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| status = AddOpInput(tf_op, 1, inputs); | |||
| return status; | |||
| } | |||
| TFNodeRegistrar g_tfMatMulParser("MatMul", new TFMatMulParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MATMUL_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MATMUL_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFMatMulParser : public TFNodeParser { | |||
| public: | |||
| TFMatMulParser() = default; | |||
| ~TFMatMulParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MATMUL_PARSER_H_ | |||
| @@ -16,36 +16,236 @@ | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_model_parser.h" | |||
| #include <map> | |||
| #include <algorithm> | |||
| #include <functional> | |||
| #include <set> | |||
| #include "src/common/utils.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "tools/converter/parser/tf/tf_util.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "tools/common/protobuf_utils.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| static const std::unordered_map<int, mindspore::TypeId> TF_TYPE_MAP = { | |||
| {tensorflow::DT_INT8, mindspore::kNumberTypeInt8}, {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8}, | |||
| {tensorflow::DT_INT16, mindspore::kNumberTypeInt16}, {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16}, | |||
| {tensorflow::DT_INT32, mindspore::kNumberTypeInt32}, {tensorflow::DT_INT64, mindspore::kNumberTypeInt64}, | |||
| {tensorflow::DT_HALF, mindspore::kNumberTypeFloat16}, {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32}, | |||
| {tensorflow::DT_DOUBLE, mindspore::kNumberTypeFloat64}, {tensorflow::DT_COMPLEX64, mindspore::kNumberTypeComplex64}, | |||
| {tensorflow::DT_BOOL, mindspore::kNumberTypeBool}}; | |||
| TypeId GetTFDataType(const tensorflow::DataType &tf_data_type) { | |||
| auto iter = TF_TYPE_MAP.find(tf_data_type); | |||
| if (iter == TF_TYPE_MAP.end()) { | |||
| MS_LOG(ERROR) << "unsupported TF data type: " << tf_data_type; | |||
| return kTypeUnknown; | |||
| } | |||
| return iter->second; | |||
| } | |||
| AnfNodePtr TFModelParser::GetAnfNode(const std::string &name) { | |||
| AnfNodePtr ret = nullptr; | |||
| if (anf_node_map.find(name) != anf_node_map.end()) { | |||
| ret = anf_node_map[name]; | |||
| } else if (anf_node_map.find(name + ":0") != anf_node_map.end()) { | |||
| ret = anf_node_map[name + ":0"]; | |||
| } | |||
| return ret; | |||
| } | |||
| std::string TFModelParser::GetOriginInputName(const tensorflow::NodeDef &node) { | |||
| if (node.op() != "Identity" && node.op() != "StopGradient") { | |||
| return node.name(); | |||
| } | |||
| auto tmp_node = &node; | |||
| while (tmp_node->op() == "Identity" || tmp_node->op() == "StopGradient") { | |||
| tmp_node = tf_node_map[tmp_node->input(0)]; | |||
| } | |||
| return tmp_node->name(); | |||
| } | |||
| STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, | |||
| const ParameterPtr ¶meter, std::vector<int64_t> *shape_vector) { | |||
| MS_ASSERT(parameter != nullptr); | |||
| MS_ASSERT(shape_vector != nullptr); | |||
| const tensorflow::TensorProto &tensor_proto = attr_value.tensor(); | |||
| const tensorflow::TensorShapeProto &tensor_shape = tensor_proto.tensor_shape(); | |||
| int shape_size = 1; | |||
| shape_vector->clear(); | |||
| for (int i = 0; i < tensor_shape.dim_size(); i++) { | |||
| shape_vector->push_back(tensor_shape.dim(i).size()); | |||
| shape_size *= tensor_shape.dim(i).size(); | |||
| } | |||
| int tensor_size; | |||
| auto param_value = std::make_shared<ParamValueLite>(); | |||
| if (param_value == nullptr) { | |||
| MS_LOG(ERROR) << "param_value is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| if (type == kNumberTypeFloat32 || type == kNumberTypeFloat) { | |||
| auto tensor_data = new (std::nothrow) float[shape_size]; | |||
| if (tensor_proto.float_val_size() == 1) { | |||
| float value = tensor_proto.float_val(0); | |||
| for (int i = 0; i < shape_size; i++) { | |||
| tensor_data[i] = value; | |||
| } | |||
| } | |||
| if (tensor_proto.tensor_content().size() == shape_size * sizeof(float)) { | |||
| const auto addr = reinterpret_cast<const float *>(tensor_proto.tensor_content().data()); | |||
| auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(float), addr, shape_size * sizeof(float)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| param_value->set_tensor_addr(tensor_data); | |||
| tensor_size = shape_size * sizeof(float); | |||
| } else if (type == kNumberTypeInt32) { | |||
| auto tensor_data = new (std::nothrow) int[shape_size]; | |||
| if (tensor_proto.int_val_size() == 1) { | |||
| int value = tensor_proto.int_val(0); | |||
| for (int i = 0; i < shape_size; i++) { | |||
| tensor_data[i] = value; | |||
| } | |||
| } | |||
| if (tensor_proto.tensor_content().size() == shape_size * sizeof(int32_t)) { | |||
| const auto addr = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data()); | |||
| auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(int32_t), addr, shape_size * sizeof(int32_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| param_value->set_tensor_addr(tensor_data); | |||
| tensor_size = shape_size * sizeof(int); | |||
| } else if (type == kNumberTypeBool) { | |||
| auto tensor_data = new (std::nothrow) int[shape_size]; | |||
| if (tensor_proto.bool_val_size() == 1) { | |||
| int value = tensor_proto.bool_val(0); | |||
| for (int i = 0; i < shape_size; i++) { | |||
| tensor_data[i] = value; | |||
| } | |||
| } | |||
| param_value->set_tensor_addr(tensor_data); | |||
| tensor_size = shape_size * sizeof(int); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport dataType: " << type; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> param_shape(shape_vector->begin(), shape_vector->end()); | |||
| param_value->set_tensor_shape(param_shape); | |||
| param_value->set_tensor_type(type); | |||
| param_value->set_tensor_size(tensor_size); | |||
| param_value->set_format(schema::Format::Format_NHWC); | |||
| parameter->set_default_param(param_value); | |||
| parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter"); | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter) { | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(parameter != nullptr); | |||
| tensorflow::AttrValue attr_value; | |||
| TypeId type = kNumberTypeFloat32; | |||
| if (TensorFlowUtils::FindAttrValue(node, "dtype", &attr_value)) { | |||
| type = GetTFDataType(attr_value.type()); | |||
| } | |||
| auto type_ptr = TypeIdToType(type); | |||
| std::vector<int> shape; | |||
| if (TensorFlowUtils::FindAttrValue(node, "shape", &attr_value)) { | |||
| auto &shape_attr = attr_value.shape(); | |||
| for (int i = 0; i < shape_attr.dim_size(); ++i) { | |||
| shape.push_back(shape_attr.dim(i).size()); | |||
| } | |||
| } | |||
| std::vector<int64_t> shape_vector(shape.begin(), shape.end()); | |||
| if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) { | |||
| MS_LOG(INFO) << "Found value attr, means it has default value"; | |||
| auto status = ConvertConstTensor(attr_value, type, parameter, &shape_vector); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| } else { | |||
| parameter->set_name("placeholder_" + std::to_string(anf_node_map.size())); | |||
| graph_input_names.emplace_back(parameter->name()); | |||
| } | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| if (abstract_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "abstract_tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| parameter->set_abstract(abstract_tensor); | |||
| anf_node_map[node.name()] = parameter; | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConvertGraphInputsAndConsts() { | |||
| for (auto &pair : tf_node_map) { | |||
| bool have_data_depend = false; | |||
| for (int i = 0; i < pair.second->input_size(); ++i) { | |||
| auto name = pair.second->input(i); | |||
| if (!name.empty() && name[0] != '^') { // control_depend input start with "^" | |||
| have_data_depend = true; | |||
| break; | |||
| } | |||
| } | |||
| if (!have_data_depend) { | |||
| auto parameter = funcGraphPtr->add_parameter(); | |||
| if (ConvertParameter(*pair.second, parameter) != RET_OK) { | |||
| MS_LOG(ERROR) << "convert Parameter Node failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType) { | |||
| auto status = ValidateFileStr(modelFile, ".prototxt"); | |||
| auto status = ValidateFileStr(modelFile, ".pb"); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; | |||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| if (!TensorFlowUtils::TfReadProtoFromBinary(modelFile.c_str(), tf_graph_def.get())) { | |||
| tf_graph_def = std::make_unique<tensorflow::GraphDef>(); | |||
| if (tf_graph_def == nullptr) { | |||
| MS_LOG(ERROR) << "tf_graph_def is nullptr"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_graph_def.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| funcGraphPtr = std::make_shared<FuncGraph>(); | |||
| status = ConvertGraphInputs(); | |||
| if (funcGraphPtr == nullptr) { | |||
| MS_LOG(ERROR) << "funGraphPtr is nullptr"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| for (int i = 0; i < tf_graph_def->node_size(); i++) { | |||
| auto &node_def = tf_graph_def->node(i); | |||
| tf_node_map[node_def.name()] = &node_def; | |||
| } | |||
| status = ConvertGraphInputsAndConsts(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert graph inputs failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| status = ConvertOps(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert ops failed."; | |||
| @@ -61,103 +261,36 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin | |||
| } | |||
| return funcGraphPtr; | |||
| } | |||
| STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef *node, ParameterPtr parameter) { | |||
| tensorflow::AttrValue attr_value; | |||
| if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) { | |||
| tensorflow::AttrValue data_type; | |||
| tensorflow::DataType type = tensorflow::DT_FLOAT; | |||
| // datatype | |||
| if (TensorFlowUtils::FindAttrValue(node, "dtype", &data_type)) { | |||
| type = data_type.type(); | |||
| } | |||
| const tensorflow::TensorProto &tensorProto = attr_value.tensor(); | |||
| const tensorflow::TensorShapeProto &tensorShape = tensorProto.tensor_shape(); | |||
| parameter = funcGraphPtr->add_parameter(); | |||
| std::vector<int64_t> shape_vector; | |||
| int shape_size = 1; | |||
| shape_vector.resize(tensorShape.dim_size()); | |||
| for (int i = 0; i < tensorShape.dim_size(); i++) { | |||
| shape_vector[i] = tensorShape.dim(i).size(); | |||
| shape_size *= shape_vector[i]; | |||
| schema::MetaGraphT *TFModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType) { | |||
| MS_LOG(ERROR) << "TF Model Parser not return MetaGraph, use TFModelParser::Parse instead"; | |||
| return nullptr; | |||
| } | |||
| STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def, | |||
| const std::vector<std::string> &input_names, std::vector<AnfNodePtr> *inputs) { | |||
| // parse inputs | |||
| for (size_t j = 0; j < input_names.size(); j++) { | |||
| std::string input_name = input_names[j]; // input may be produced by multi-outputs node | |||
| if (tf_node_map.find(input_name) != tf_node_map.end()) { | |||
| auto input_node = tf_node_map[input_name]; | |||
| input_name = GetOriginInputName(*input_node); | |||
| } | |||
| // convert const to paramter | |||
| TypePtr ms_data_ype; | |||
| auto paramValue = std::make_shared<ParamValueLite>(); | |||
| if (type == tensorflow::DT_FLOAT) { | |||
| ms_data_ype = kFloat32; | |||
| auto tensor_data = new (std::nothrow) float[shape_size]; | |||
| if (tensorProto.float_val_size() == 1) { | |||
| float value = tensorProto.float_val(0); | |||
| for (int i = 0; i < shape_size; i++) { | |||
| tensor_data[i] = value; | |||
| } | |||
| } | |||
| if (tensorProto.tensor_content().size() == shape_size * sizeof(float)) { | |||
| const auto addr = reinterpret_cast<const float *>(tensorProto.tensor_content().data()); | |||
| auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(float), addr, shape_size * sizeof(float)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| paramValue->set_tensor_addr(tensor_data); | |||
| paramValue->set_tensor_size(shape_size * sizeof(float)); | |||
| } else if (type == tensorflow::DT_INT32) { | |||
| ms_data_ype = kInt32; | |||
| auto tensor_data = new (std::nothrow) int[shape_size]; | |||
| if (tensorProto.int_val_size() == 1) { | |||
| int value = tensorProto.int_val(0); | |||
| for (int i = 0; i < shape_size; i++) { | |||
| tensor_data[i] = value; | |||
| } | |||
| } | |||
| if (tensorProto.tensor_content().size() == shape_size * sizeof(int32_t)) { | |||
| const auto addr = reinterpret_cast<const int32_t *>(tensorProto.tensor_content().data()); | |||
| auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(int32_t), addr, shape_size * sizeof(int32_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| paramValue->set_tensor_addr(tensor_data); | |||
| paramValue->set_tensor_size(shape_size * sizeof(int)); | |||
| } else if (type == tensorflow::DT_BOOL) { | |||
| ms_data_ype = kFloat32; | |||
| auto tensor_data = new (std::nothrow) int[shape_size]; | |||
| if (tensorProto.bool_val_size() == 1) { | |||
| int value = tensorProto.bool_val(0); | |||
| for (int i = 0; i < shape_size; i++) { | |||
| tensor_data[i] = value; | |||
| } | |||
| } | |||
| paramValue->set_tensor_addr(tensor_data); | |||
| paramValue->set_tensor_size(shape_size * sizeof(int)); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport dataType," << node->name(); | |||
| auto input = GetAnfNode(input_name); | |||
| if (input == nullptr) { | |||
| MS_LOG(ERROR) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes"; | |||
| return RET_ERROR; | |||
| } | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(ms_data_ype, shape_vector); | |||
| parameter->set_abstract(abstract_tensor); | |||
| parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter"); | |||
| std::vector<int> param_shape; | |||
| (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(param_shape), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| MS_ASSERT(paramValue != nullptr); | |||
| paramValue->set_tensor_shape(param_shape); | |||
| paramValue->set_tensor_type(ms_data_ype->type_id()); | |||
| paramValue->set_format(schema::Format::Format_NHWC); | |||
| paramValue->set_tensor_size(shape_size * sizeof(int)); | |||
| parameter->set_default_param(paramValue); | |||
| inputs->emplace_back(input); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size) { | |||
| STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, int output_size) { | |||
| if (output_size == 1) { | |||
| std::vector<int64_t> shape_vector; | |||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); | |||
| anf_node_map.insert(std::pair(op->name(), anf_node)); | |||
| anf_node_map.insert(std::pair(op.name(), anf_node)); | |||
| } else { | |||
| AbstractBasePtrList abstractList; | |||
| for (int output_idx = 0; output_idx < output_size; output_idx++) { | |||
| @@ -174,113 +307,126 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef *op, const C | |||
| CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs); | |||
| std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); | |||
| getItemCNode->set_fullname_with_scope(output_item_name); | |||
| anf_node_map.insert(std::pair(output_item_name, getItemCNode)); | |||
| anf_node_map.insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode)); | |||
| } | |||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConvertOps() { | |||
| NoSupportOp::GetInstance()->SetFmkType("TENSORFLOW"); | |||
| NoSupportOp::GetInstance()->SetFmkType("TF"); | |||
| STATUS status = RET_OK; | |||
| // redirect identity to it's input0 | |||
| ClipIdentityAndStopGradient(); | |||
| int op_idx = 0; | |||
| for (int i = 0; i < tf_graph_def->node_size(); i++) { | |||
| auto node_def = tf_graph_def->mutable_node(i); | |||
| tf_node_map[node_def->name()] = node_def; | |||
| auto tf_op_type = node_def->op(); | |||
| if (tf_op_type == "Placeholder" || tf_op_type == "Const") { | |||
| auto &node_def = tf_graph_def->node(i); | |||
| const auto &op_type = node_def.op(); | |||
| if (op_type == "Placeholder" || op_type == "Const" || op_type == "Identity" || op_type == "StopGradient") { | |||
| continue; | |||
| } | |||
| auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(tf_op_type); | |||
| auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type); | |||
| if (node_parser == nullptr) { | |||
| NoSupportOp::GetInstance()->InsertOp(tf_op_type); | |||
| NoSupportOp::GetInstance()->InsertOp(op_type); | |||
| status = (status == RET_OK ? RET_NOT_FIND_OP : status); | |||
| MS_LOG(ERROR) << "cannot find node parser:" << tf_op_type; | |||
| MS_LOG(ERROR) << "cannot find node parser:" << op_type; | |||
| continue; | |||
| } | |||
| if (status != RET_OK) { | |||
| continue; | |||
| } | |||
| PrimitiveC *primitiveC = nullptr; | |||
| if (status == RET_OK) { | |||
| int output_size = 1; | |||
| status = node_parser->Parse(node_def, tf_graph_def, primitiveC, &output_size); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "node " << tf_op_type.c_str() << " parser failed"; | |||
| continue; | |||
| } | |||
| std::vector<AnfNodePtr> opInputs = {NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC))}; | |||
| // parse inputs | |||
| for (int j = 0; j < node_def->input_size(); j++) { | |||
| auto input_node = tf_node_map[node_def->input(i)]; | |||
| // last node output | |||
| if (anf_node_map.find(input_node->name()) != anf_node_map.end()) { | |||
| opInputs.emplace_back(anf_node_map[input_node->name()]); | |||
| continue; | |||
| } | |||
| // const tensor | |||
| if (input_node->op() == "Const") { | |||
| ParameterPtr parameter; | |||
| if (ConvertConstTensor(input_node, parameter) != RET_OK) { | |||
| MS_LOG(ERROR) << "convert const tensor failed," << input_node->name(); | |||
| return RET_ERROR; | |||
| } | |||
| opInputs.emplace_back(parameter); | |||
| anf_node_map[parameter->fullname_with_scope()] = parameter; | |||
| continue; | |||
| } | |||
| MS_LOG(ERROR) << "node" << node_def->name() << "has inputs neither a node output nor a weight tensor."; | |||
| return RET_ERROR; | |||
| } | |||
| auto anf_node = funcGraphPtr->NewCNode(opInputs); | |||
| anf_node->set_fullname_with_scope(tf_op_type + "-" + std::to_string(op_idx++)); | |||
| int output_size; | |||
| std::vector<std::string> input_names; | |||
| status = node_parser->Parse(node_def, tf_node_map, &primitiveC, &input_names, &output_size); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "node " << op_type << " parser failed"; | |||
| continue; | |||
| } | |||
| // parse outputs | |||
| status = ConvertOutputTensor(node_def, anf_node, output_size); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return status; | |||
| } | |||
| auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC)); | |||
| if (value_node == nullptr) { | |||
| MS_LOG(ERROR) << "value_node is nullptr"; | |||
| status = RET_ERROR; | |||
| continue; | |||
| } | |||
| std::vector<AnfNodePtr> inputs = {value_node}; | |||
| status = ConvertInputNodes(node_def, input_names, &inputs); | |||
| if (status != RET_OK) { | |||
| continue; | |||
| } | |||
| // control_depends are not processed currently | |||
| auto anf_node = funcGraphPtr->NewCNode(inputs); | |||
| anf_node->set_fullname_with_scope(op_type + "-" + std::to_string(op_idx++)); | |||
| status = ConvertOutputTensor(node_def, anf_node, output_size); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| continue; | |||
| } | |||
| // redirect identity to it's input0 | |||
| ClipIdentityAndStopGradient(); | |||
| } | |||
| return RET_OK; | |||
| return status; | |||
| } | |||
| STATUS TFModelParser::ConvertGraphInputs() { | |||
| for (int i = 0; i < tf_graph_def->node_size(); i++) { | |||
| auto node_def = tf_graph_def->mutable_node(i); | |||
| tf_node_map[node_def->name()] = node_def; | |||
| if (node_def->op() == "Placeholder") { | |||
| auto parameter = funcGraphPtr->add_parameter(); | |||
| if (ConvertConstTensor(node_def, parameter) != RET_OK) { | |||
| MS_LOG(ERROR) << "convert const tensor failed"; | |||
| STATUS TFModelParser::ConvertGraphOutputs() { | |||
| // because output of intermediate node in anf graph may also be output tensors, we search output tensors in | |||
| // tf_node_map but not anf_node_map | |||
| std::set<std::string> all_node_inputs; | |||
| std::vector<AnfNodePtr> output_nodes; | |||
| for (auto &pair : tf_node_map) { | |||
| for (int i = 0; i < pair.second->input_size(); ++i) { | |||
| all_node_inputs.insert(pair.second->input(i)); | |||
| } | |||
| } | |||
| for (auto &pair : tf_node_map) { | |||
| auto it = all_node_inputs.find(pair.first); | |||
| if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity | |||
| auto origin_name = GetOriginInputName(*(pair.second)); | |||
| auto anf_node = GetAnfNode(origin_name); | |||
| if (anf_node == nullptr) { | |||
| MS_LOG(ERROR) << "can't find anf node"; | |||
| return RET_ERROR; | |||
| } | |||
| anf_node_map[node_def->name()] = parameter; | |||
| graph_input_names.emplace_back(node_def->name()); | |||
| output_nodes.push_back(anf_node); | |||
| graph_output_names.push_back(anf_node->fullname_with_scope()); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TFModelParser::ConvertGraphOutputs() { return RET_OK; } | |||
| std::string TFModelParser::GetOriginInputName(const tensorflow::NodeDef &node) { | |||
| if (node.op() != "Identity" && node.op() != "StopGradient") { | |||
| return node.name(); | |||
| } | |||
| auto tmpNode = node; | |||
| while (tmpNode.op() == "Identity" || tmpNode.op() == "StopGradient") { | |||
| tmpNode = *tf_node_map[tmpNode.input(0)]; | |||
| } | |||
| return tmpNode.name(); | |||
| } | |||
| if (output_nodes.size() > 1) { | |||
| std::vector<AnfNodePtr> &make_tuple_inputs = output_nodes; | |||
| auto make_tuple_prim_ptr = GetMakeTuplePrim(); | |||
| if (make_tuple_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); | |||
| make_tuple_inputs.insert(output_nodes.begin(), make_tuple_prim); | |||
| auto make_tuple_cnode = funcGraphPtr->NewCNode(make_tuple_inputs); | |||
| make_tuple_cnode->set_fullname_with_scope("return tuple"); | |||
| void TFModelParser::ClipIdentityAndStopGradient() { | |||
| for (auto &pair : tf_node_map) { | |||
| pair.second = tf_node_map[GetOriginInputName(*pair.second)]; | |||
| auto return_prim_ptr = GetReturnPrim(); | |||
| if (return_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto value_node = NewValueNode(return_prim_ptr); | |||
| std::vector<AnfNodePtr> op_inputs = {value_node, make_tuple_cnode}; | |||
| auto cnode = funcGraphPtr->NewCNode(op_inputs); | |||
| cnode->set_fullname_with_scope("return"); | |||
| funcGraphPtr->set_return(cnode); | |||
| } else { | |||
| auto return_prim_ptr = GetReturnPrim(); | |||
| if (return_prim_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "GetReturnPrim return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto value_node = NewValueNode(return_prim_ptr); | |||
| std::vector<AnfNodePtr> op_inputs{value_node, output_nodes.front()}; | |||
| auto return_cnode = funcGraphPtr->NewCNode(op_inputs); | |||
| return_cnode->set_fullname_with_scope("return"); | |||
| funcGraphPtr->set_return(return_cnode); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -31,29 +31,36 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFModelParser { | |||
| class TFModelParser : public ModelParser { | |||
| public: | |||
| TFModelParser() = default; | |||
| ~TFModelParser() = default; | |||
| FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); | |||
| protected: | |||
| schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType = QuantType_QUANT_NONE) override; | |||
| private: | |||
| STATUS ConvertConstTensor(const tensorflow::NodeDef *op, ParameterPtr parameter); | |||
| STATUS ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size); | |||
| AnfNodePtr GetAnfNode(const std::string &name); | |||
| std::string GetOriginInputName(const tensorflow::NodeDef &node); | |||
| STATUS ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, const ParameterPtr ¶meter, | |||
| std::vector<int64_t> *shape_vector); | |||
| STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter); | |||
| STATUS ConvertGraphInputsAndConsts(); | |||
| STATUS ConvertInputNodes(const tensorflow::NodeDef &node_def, const std::vector<std::string> &input_names, | |||
| std::vector<AnfNodePtr> *inputs); | |||
| STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, int output_size); | |||
| STATUS ConvertOps(); | |||
| STATUS ConvertGraphInputs(); | |||
| STATUS ConvertGraphOutputs(); | |||
| std::string GetOriginInputName(const tensorflow::NodeDef &node); | |||
| void ClipIdentityAndStopGradient(); | |||
| FuncGraphPtr funcGraphPtr; | |||
| std::unique_ptr<tensorflow::GraphDef> tf_graph_def; | |||
| std::map<std::string, const tensorflow::NodeDef *> tf_node_map; | |||
| std::unordered_map<std::string, AnfNodePtr> anf_node_map; | |||
| std::vector<std::string> graph_input_names, graphOutputNames; | |||
| std::vector<std::string> graph_input_names; | |||
| std::vector<std::string> graph_output_names; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -13,23 +13,21 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H | |||
| #include <memory> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFAddParser : public TFNodeParser { | |||
| public: | |||
| TFAddParser() = default; | |||
| ~TFAddParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model, | |||
| PrimitiveC *primitiveC, int *output_size) override; | |||
| }; | |||
| STATUS TFNodeParser::AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, std::vector<std::string> *inputs) { | |||
| if (tf_op.input_size() <= idx) { | |||
| MS_LOG(ERROR) << "input idx is greater than op input size"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| inputs->push_back(tf_op.input(idx)); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H | |||
| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H | |||
| #include <string> | |||
| #include <vector> | |||
| #include <map> | |||
| #include <memory> | |||
| #include "tools/converter/parser/tf/tf_util.h" | |||
| @@ -32,12 +33,14 @@ class TFNodeParser { | |||
| virtual ~TFNodeParser() = default; | |||
| virtual STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model, | |||
| PrimitiveC *primitiveC, int *output_size) { | |||
| virtual STATUS Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| return RET_OK; | |||
| } | |||
| STATUS AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, std::vector<std::string> *inputs); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H | |||
| @@ -0,0 +1,109 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_split_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFSplitParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||
| std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF SplitParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::SplitT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "num_split", &attr_value)) { | |||
| MS_LOG(ERROR) << "The attribute num_split should be specified"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| attr->numberSplit = (int32_t)(attr_value.i()); | |||
| int split_dim_index; | |||
| int input_index; | |||
| if (tf_op.op() == "Split") { | |||
| split_dim_index = 0; | |||
| input_index = 1; | |||
| } else { | |||
| split_dim_index = 2; | |||
| input_index = 0; | |||
| } | |||
| if (tf_node_map.find(tf_op.input(split_dim_index)) == tf_node_map.end()) { | |||
| MS_LOG(ERROR) << "Find Split input split_dim node failed"; | |||
| return RET_ERROR; | |||
| } | |||
| const auto &split_dim_node = tf_node_map.at(tf_op.input(split_dim_index)); | |||
| if (!TensorFlowUtils::FindAttrValue(*split_dim_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "The attribute splitDim should be specified"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto split_dim_tensor = attr_value.tensor(); | |||
| attr->splitDim = split_dim_tensor.int_val(0); | |||
| *output_size = attr->numberSplit; | |||
| if (tf_op.op() == "SplitV") { | |||
| if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { | |||
| MS_LOG(ERROR) << "Find Split input size_splits failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto size_splits_node = tf_node_map.at(tf_op.input(1)); | |||
| if (!TensorFlowUtils::FindAttrValue(*size_splits_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "The attribute size splits should be specified"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto size_splits_tensor = attr_value.tensor(); | |||
| auto size = size_splits_tensor.tensor_content().size() / sizeof(int32_t); | |||
| attr->sizeSplits.resize(size); | |||
| auto ret = memcpy_s(attr->sizeSplits.data(), size * sizeof(int32_t), size_splits_tensor.tensor_content().data(), | |||
| size * sizeof(int32_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Split; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto status = AddOpInput(tf_op, input_index, inputs); | |||
| return status; | |||
| } | |||
| TFNodeRegistrar g_tfSplitParser("Split", new TFSplitParser()); | |||
| TFNodeRegistrar g_tfSplitVParser("SplitV", new TFSplitParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPLIT_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPLIT_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFSplitParser : public TFNodeParser { | |||
| public: | |||
| TFSplitParser() = default; | |||
| ~TFSplitParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPLIT_PARSER_H_ | |||
| @@ -22,9 +22,9 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name, | |||
| bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef &nodeDef, const std::string &attr_name, | |||
| tensorflow::AttrValue *attr_value) { | |||
| const google::protobuf::Map<std::string, tensorflow::AttrValue> &attr = nodeDef->attr(); | |||
| const google::protobuf::Map<std::string, tensorflow::AttrValue> &attr = nodeDef.attr(); | |||
| const google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr.find(attr_name); | |||
| if (it != attr.end()) { | |||
| *attr_value = it->second; | |||
| @@ -32,24 +32,5 @@ bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef *nodeDef, std::str | |||
| } | |||
| return false; | |||
| } | |||
| bool TensorFlowUtils::TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message) { | |||
| std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary); | |||
| if (!fs.is_open()) { | |||
| fprintf(stderr, "open failed %s\n", filepath); | |||
| return false; | |||
| } | |||
| google::protobuf::io::IstreamInputStream input(&fs); | |||
| google::protobuf::io::CodedInputStream codedstr(&input); | |||
| codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2); | |||
| bool success = message->ParseFromCodedStream(&codedstr); | |||
| fs.close(); | |||
| return success; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -26,10 +26,8 @@ namespace mindspore { | |||
| namespace lite { | |||
| class TensorFlowUtils { | |||
| public: | |||
| static bool FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name, | |||
| static bool FindAttrValue(const tensorflow::NodeDef &nodeDef, const std::string &attr_name, | |||
| tensorflow::AttrValue *attr_value); | |||
| static bool TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||