Browse Source

!8915 adjust tf converter & add some tf parsers

From: @wangzhe128
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
999cd513c6
34 changed files with 991 additions and 245 deletions
  1. +1
    -1
      mindspore/lite/tools/converter/CMakeLists.txt
  2. +5
    -0
      mindspore/lite/tools/converter/converter.cc
  3. +3
    -1
      mindspore/lite/tools/converter/converter_flags.cc
  4. +1
    -0
      mindspore/lite/tools/converter/model_parser.h
  5. +1
    -1
      mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h
  6. +3
    -3
      mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
  7. +0
    -0
      mindspore/lite/tools/converter/parser/tf/proto/attr_value.proto
  8. +0
    -0
      mindspore/lite/tools/converter/parser/tf/proto/function.proto
  9. +0
    -0
      mindspore/lite/tools/converter/parser/tf/proto/graph.proto
  10. +0
    -0
      mindspore/lite/tools/converter/parser/tf/proto/node_def.proto
  11. +0
    -0
      mindspore/lite/tools/converter/parser/tf/proto/op_def.proto
  12. +0
    -0
      mindspore/lite/tools/converter/parser/tf/proto/resource_handle.proto
  13. +0
    -0
      mindspore/lite/tools/converter/parser/tf/proto/tensor.proto
  14. +0
    -0
      mindspore/lite/tools/converter/parser/tf/proto/tensor_shape.proto
  15. +0
    -0
      mindspore/lite/tools/converter/parser/tf/proto/types.proto
  16. +0
    -0
      mindspore/lite/tools/converter/parser/tf/proto/versions.proto
  17. +68
    -0
      mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc
  18. +38
    -0
      mindspore/lite/tools/converter/parser/tf/tf_activation_parser.h
  19. +93
    -0
      mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc
  20. +36
    -0
      mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.h
  21. +61
    -0
      mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc
  22. +37
    -0
      mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.h
  23. +22
    -0
      mindspore/lite/tools/converter/parser/tf/tf_converter.cc
  24. +11
    -13
      mindspore/lite/tools/converter/parser/tf/tf_converter.h
  25. +70
    -0
      mindspore/lite/tools/converter/parser/tf/tf_matmul_parser.cc
  26. +37
    -0
      mindspore/lite/tools/converter/parser/tf/tf_matmul_parser.h
  27. +322
    -176
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
  28. +16
    -9
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.h
  29. +12
    -14
      mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc
  30. +6
    -3
      mindspore/lite/tools/converter/parser/tf/tf_node_parser.h
  31. +109
    -0
      mindspore/lite/tools/converter/parser/tf/tf_split_parser.cc
  32. +36
    -0
      mindspore/lite/tools/converter/parser/tf/tf_split_parser.h
  33. +2
    -21
      mindspore/lite/tools/converter/parser/tf/tf_util.cc
  34. +1
    -3
      mindspore/lite/tools/converter/parser/tf/tf_util.h

+ 1
- 1
mindspore/lite/tools/converter/CMakeLists.txt View File

@@ -114,7 +114,7 @@ endif ()

file(GLOB PROTO_FILE ""
${CMAKE_CURRENT_SOURCE_DIR}/parser/caffe/caffe.proto
${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/*.proto
${CMAKE_CURRENT_SOURCE_DIR}/parser/tf/proto/*.proto
${CMAKE_CURRENT_SOURCE_DIR}/parser/onnx/onnx.proto)
ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_FILE})
add_library(proto_mid OBJECT ${PROTO_SRCS})


+ 5
- 0
mindspore/lite/tools/converter/converter.cc View File

@@ -28,6 +28,7 @@
#include "parser/caffe/caffe_converter.h"
#include "parser/tflite/tflite_converter.h"
#include "parser/onnx/onnx_converter.h"
#include "parser/tf/tf_converter.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/anf_importer/import_from_protobuf.h"
#include "proto/onnx.pb.h"
@@ -149,6 +150,10 @@ int RunConverter(int argc, const char **argv) {
OnnxConverter onnxConverter;
fb_graph = onnxConverter.Convert(flags.get());
} break;
case FmkType::FmkType_TF: {
TFConverter tfConverter;
fb_graph = tfConverter.Convert(flags.get());
} break;
default: {
MS_LOG(ERROR) << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << " "
<< GetErrorInfo(RET_INPUT_PARAM_INVALID);


+ 3
- 1
mindspore/lite/tools/converter/converter_flags.cc View File

@@ -126,8 +126,10 @@ int Flags::Init(int argc, const char **argv) {
this->fmk = FmkType_TFLITE;
} else if (this->fmkIn == "ONNX") {
this->fmk = FmkType_ONNX;
} else if (this->fmkIn == "TF") {
this->fmk = FmkType_TF;
} else {
std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MINDIR|ONNX";
std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX";
return RET_INPUT_PARAM_INVALID;
}



+ 1
- 0
mindspore/lite/tools/converter/model_parser.h View File

@@ -44,6 +44,7 @@ class ModelParser {
return func_graph;
}

protected:
virtual schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type = QuantType_QUANT_NONE) = 0;



+ 1
- 1
mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h View File

@@ -34,10 +34,10 @@ class CaffeModelParser : public ModelParser {

virtual ~CaffeModelParser();

private:
schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type = QuantType_QUANT_NONE) override;

private:
STATUS SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache);

STATUS SetOpOutputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache);


+ 3
- 3
mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h View File

@@ -45,12 +45,12 @@ class OnnxModelParser : public ModelParser {
int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph,
const QuantType &quantType);

schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type = QuantType_QUANT_NONE) override;

static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);

private:
schema::MetaGraphT *ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type = QuantType_QUANT_NONE) override;

std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);

STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph);


mindspore/lite/tools/converter/parser/tf/attr_value.proto → mindspore/lite/tools/converter/parser/tf/proto/attr_value.proto View File


mindspore/lite/tools/converter/parser/tf/function.proto → mindspore/lite/tools/converter/parser/tf/proto/function.proto View File


mindspore/lite/tools/converter/parser/tf/graph.proto → mindspore/lite/tools/converter/parser/tf/proto/graph.proto View File


mindspore/lite/tools/converter/parser/tf/node_def.proto → mindspore/lite/tools/converter/parser/tf/proto/node_def.proto View File


mindspore/lite/tools/converter/parser/tf/op_def.proto → mindspore/lite/tools/converter/parser/tf/proto/op_def.proto View File


mindspore/lite/tools/converter/parser/tf/resource_handle.proto → mindspore/lite/tools/converter/parser/tf/proto/resource_handle.proto View File


mindspore/lite/tools/converter/parser/tf/tensor.proto → mindspore/lite/tools/converter/parser/tf/proto/tensor.proto View File


mindspore/lite/tools/converter/parser/tf/tensor_shape.proto → mindspore/lite/tools/converter/parser/tf/proto/tensor_shape.proto View File


mindspore/lite/tools/converter/parser/tf/types.proto → mindspore/lite/tools/converter/parser/tf/proto/types.proto View File


mindspore/lite/tools/converter/parser/tf/versions.proto → mindspore/lite/tools/converter/parser/tf/proto/versions.proto View File


+ 68
- 0
mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc View File

@@ -0,0 +1,68 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_activation_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
// Parses a TF "Relu"/"Relu6" NodeDef into a lite Activation primitive.
// On success: *primitiveC owns a new PrimitiveC, *output_size is 1, and the
// node's single data input name is appended to *inputs.
STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
                                 const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                                 PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF ActivationParser";
  if (primitiveC == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_NULL_PTR;
  }

  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::ActivationT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new op failed";
    return RET_NULL_PTR;
  }

  if (tf_op.op() == "Relu") {
    attr->type = schema::ActivationType_RELU;
  } else if (tf_op.op() == "Relu6") {
    attr->type = schema::ActivationType_RELU6;
  } else {
    MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
    // BUG FIX: previously fell through and built a primitive with an
    // undefined activation type; bail out instead.
    return RET_ERROR;
  }

  primitive->value.type = schema::PrimitiveType_Activation;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }

  *output_size = 1;
  auto status = AddOpInput(tf_op, 0, inputs);
  return status;
}
TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser());
TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser());
} // namespace lite
} // namespace mindspore

+ 38
- 0
mindspore/lite/tools/converter/parser/tf/tf_activation_parser.h View File

@@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_

#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
// Node parser for TF activation ops; registered for "Relu" and "Relu6"
// (see tf_activation_parser.cc) and emits a lite Activation primitive.
class TFActivationParser : public TFNodeParser {
public:
TFActivationParser() = default;
~TFActivationParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ACTIVATION_PARSER_H_

+ 93
- 0
mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc View File

@@ -0,0 +1,93 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_arithmetic_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
// Parses a TF binary arithmetic NodeDef ("Add"/"Sub"/"Mul"/"Div") into the
// corresponding lite primitive. On success: *primitiveC owns a new PrimitiveC,
// *output_size is 1, and both data input names are appended to *inputs.
STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op,
                                 const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                                 PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF ArithmeticParser";
  if (primitiveC == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_NULL_PTR;
  }

  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "New PrimitiveT failed";
    return RET_NULL_PTR;
  }

  if (tf_op.op() == "Add") {
    auto attr = std::make_unique<schema::AddT>();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new attr failed";
      return RET_NULL_PTR;
    }
    primitive->value.type = schema::PrimitiveType_Add;
    primitive->value.value = attr.release();
  } else if (tf_op.op() == "Sub") {
    auto attr = std::make_unique<schema::SubT>();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new attr failed";
      return RET_NULL_PTR;
    }
    primitive->value.type = schema::PrimitiveType_Sub;
    primitive->value.value = attr.release();
  } else if (tf_op.op() == "Mul") {
    auto attr = std::make_unique<schema::MulT>();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new attr failed";
      return RET_NULL_PTR;
    }
    primitive->value.type = schema::PrimitiveType_Mul;
    primitive->value.value = attr.release();
  } else if (tf_op.op() == "Div") {
    auto attr = std::make_unique<schema::DivT>();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new attr failed";
      return RET_NULL_PTR;
    }
    primitive->value.type = schema::PrimitiveType_Div;
    primitive->value.value = attr.release();
  } else {
    // BUG FIX: previously an unrecognized op fell through here with an empty
    // primitive value and was still converted; fail explicitly instead.
    MS_LOG(ERROR) << "unsupported arithmetic type:" << tf_op.op();
    return RET_ERROR;
  }

  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }

  *output_size = 1;
  auto status = AddOpInput(tf_op, 0, inputs);
  if (status != RET_OK) {
    return status;
  }
  status = AddOpInput(tf_op, 1, inputs);
  return status;
}
TFNodeRegistrar g_tfAddParser("Add", new TFArithmeticParser());
TFNodeRegistrar g_tfSubParser("Sub", new TFArithmeticParser());
TFNodeRegistrar g_tfMulParser("Mul", new TFArithmeticParser());
TFNodeRegistrar g_tfDivParser("Div", new TFArithmeticParser());
} // namespace lite
} // namespace mindspore

+ 36
- 0
mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
// Node parser for TF element-wise binary arithmetic ops; registered for
// "Add"/"Sub"/"Mul"/"Div" (see tf_arithmetic_parser.cc).
class TFArithmeticParser : public TFNodeParser {
public:
TFArithmeticParser() = default;
~TFArithmeticParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ARITHMETIC_PARSER_H_

+ 61
- 0
mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc View File

@@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_biasadd_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
// Parses a TF "BiasAdd" NodeDef into a lite BiasAdd primitive.
// TF BiasAdd has two inputs (value, bias); both are appended to *inputs.
STATUS TFBiasAddParser::Parse(const tensorflow::NodeDef &tf_op,
                              const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
                              std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF BiasAddParser";
  if (primitiveC == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_NULL_PTR;
  }

  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "New PrimitiveT failed";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::BiasAddT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new attr failed";
    return RET_NULL_PTR;
  }

  attr->axis = {1};

  // BUG FIX: value.value holds a BiasAddT, so the union tag must be
  // PrimitiveType_BiasAdd (was PrimitiveType_Add — tag/value mismatch).
  primitive->value.type = schema::PrimitiveType_BiasAdd;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }

  *output_size = 1;
  // BUG FIX: BiasAdd takes two inputs (value, bias); only input 0 was added.
  auto status = AddOpInput(tf_op, 0, inputs);
  if (status != RET_OK) {
    return status;
  }
  status = AddOpInput(tf_op, 1, inputs);
  return status;
}
TFNodeRegistrar g_tfBiasAddParser("BiasAdd", new TFBiasAddParser());
} // namespace lite
} // namespace mindspore

+ 37
- 0
mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// FIX: header guard had a typo ("BIASSADD"); corrected to BIASADD.
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASADD_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASADD_PARSER_H_

#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
// Node parser for TF "BiasAdd" ops (see tf_biasadd_parser.cc).
class TFBiasAddParser : public TFNodeParser {
 public:
  TFBiasAddParser() = default;
  ~TFBiasAddParser() override = default;

  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BIASADD_PARSER_H_

+ 22
- 0
mindspore/lite/tools/converter/parser/tf/tf_converter.cc View File

@@ -0,0 +1,22 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_converter.h"
#include "tools/converter/parser/tf/tf_model_parser.h"
namespace mindspore {
namespace lite {
TFConverter::TFConverter() { modelParser = new TFModelParser(); }
} // namespace lite
} // namespace mindspore

mindspore/lite/tools/converter/parser/tf/tf_add_parser.cc → mindspore/lite/tools/converter/parser/tf/tf_converter.h View File

@@ -13,22 +13,20 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_add_parser.h"
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_
#include <string>
#include <memory>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

#include "tools/converter/converter.h"
namespace mindspore {
namespace lite {
STATUS TFAddParser::Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model,
PrimitiveC *primitiveC, int *output_size) {
auto attr = std::make_unique<schema::PrimitiveT>();
attr->value.type = schema::PrimitiveType_Add;
primitiveC = PrimitiveC::Create(attr.release());
MS_LOG(INFO) << "primitive name" << primitiveC->type_name();
return RET_OK;
}
TFNodeRegistrar g_tfAddParser("Add", new TFAddParser());
class TFConverter : public Converter {
public:
TFConverter();

~TFConverter() = default;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_

+ 70
- 0
mindspore/lite/tools/converter/parser/tf/tf_matmul_parser.cc View File

@@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_matmul_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
// Parses a TF "MatMul" NodeDef into a lite MatMul primitive, carrying over the
// optional "transpose_a"/"transpose_b" attributes. On success *primitiveC owns
// a new PrimitiveC, *output_size is 1, and both input names are added.
STATUS TFMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
                             const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
                             std::vector<std::string> *inputs, int *output_size) {
  MS_LOG(INFO) << "TF MatMulParser";
  if (primitiveC == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_NULL_PTR;
  }

  auto prim = std::make_unique<schema::PrimitiveT>();
  if (prim == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_NULL_PTR;
  }
  auto matmul_attr = std::make_unique<schema::MatMulT>();
  if (matmul_attr == nullptr) {
    MS_LOG(ERROR) << "new op failed";
    return RET_NULL_PTR;
  }

  // Both transpose flags are optional; absent flags keep the schema defaults.
  tensorflow::AttrValue transpose_attr;
  if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_a", &transpose_attr)) {
    matmul_attr->transposeA = transpose_attr.b();
  }
  if (TensorFlowUtils::FindAttrValue(tf_op, "transpose_b", &transpose_attr)) {
    matmul_attr->transposeB = transpose_attr.b();
  }

  prim->value.type = schema::PrimitiveType_MatMul;
  prim->value.value = matmul_attr.release();
  *primitiveC = PrimitiveC::Create(prim.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }

  *output_size = 1;
  for (int input_idx = 0; input_idx < 2; ++input_idx) {
    auto status = AddOpInput(tf_op, input_idx, inputs);
    if (status != RET_OK) {
      return status;
    }
  }
  return RET_OK;
}
TFNodeRegistrar g_tfMatMulParser("MatMul", new TFMatMulParser());
} // namespace lite
} // namespace mindspore

+ 37
- 0
mindspore/lite/tools/converter/parser/tf/tf_matmul_parser.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MATMUL_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MATMUL_PARSER_H_

#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
// Node parser for TF "MatMul" ops; maps transpose_a/transpose_b attributes to
// the lite MatMul primitive (see tf_matmul_parser.cc).
class TFMatMulParser : public TFNodeParser {
public:
TFMatMulParser() = default;
~TFMatMulParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_MATMUL_PARSER_H_

+ 322
- 176
mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc View File

@@ -16,36 +16,236 @@
*/

#include "tools/converter/parser/tf/tf_model_parser.h"
#include <map>
#include <algorithm>
#include <functional>
#include <set>
#include "src/common/utils.h"
#include "src/common/log_adapter.h"
#include "tools/converter/parser/tf/tf_util.h"
#include "tools/common/graph_util.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "src/param_value_lite.h"
#include "tools/common/protobuf_utils.h"

namespace mindspore {
namespace lite {
// TensorFlow proto dtype -> MindSpore TypeId. Dtypes not listed here
// (e.g. string, resource) are reported as unsupported by GetTFDataType.
static const std::unordered_map<int, mindspore::TypeId> TF_TYPE_MAP = {
{tensorflow::DT_INT8, mindspore::kNumberTypeInt8}, {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8},
{tensorflow::DT_INT16, mindspore::kNumberTypeInt16}, {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16},
{tensorflow::DT_INT32, mindspore::kNumberTypeInt32}, {tensorflow::DT_INT64, mindspore::kNumberTypeInt64},
{tensorflow::DT_HALF, mindspore::kNumberTypeFloat16}, {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32},
{tensorflow::DT_DOUBLE, mindspore::kNumberTypeFloat64}, {tensorflow::DT_COMPLEX64, mindspore::kNumberTypeComplex64},
{tensorflow::DT_BOOL, mindspore::kNumberTypeBool}};

// Translates a TF proto dtype to a MindSpore TypeId; logs and returns
// kTypeUnknown for dtypes missing from TF_TYPE_MAP.
TypeId GetTFDataType(const tensorflow::DataType &tf_data_type) {
  const auto it = TF_TYPE_MAP.find(tf_data_type);
  if (it != TF_TYPE_MAP.end()) {
    return it->second;
  }
  MS_LOG(ERROR) << "unsupported TF data type: " << tf_data_type;
  return kTypeUnknown;
}

// Looks up an ANF node by TF tensor name, also trying the implicit ":0"
// output suffix; returns nullptr if neither form is mapped.
AnfNodePtr TFModelParser::GetAnfNode(const std::string &name) {
  auto it = anf_node_map.find(name);
  if (it == anf_node_map.end()) {
    it = anf_node_map.find(name + ":0");
  }
  return it == anf_node_map.end() ? nullptr : it->second;
}

// Resolves the real producer of `node` by walking through pass-through
// "Identity"/"StopGradient" chains and returns that producer's name.
std::string TFModelParser::GetOriginInputName(const tensorflow::NodeDef &node) {
  if (node.op() != "Identity" && node.op() != "StopGradient") {
    return node.name();
  }
  auto tmp_node = &node;
  while (tmp_node->op() == "Identity" || tmp_node->op() == "StopGradient") {
    // BUG FIX: tf_node_map[...] default-inserts nullptr for an unknown input
    // name and the next loop iteration dereferences it; use a checked lookup.
    auto it = tf_node_map.find(tmp_node->input(0));
    if (it == tf_node_map.end() || it->second == nullptr) {
      MS_LOG(ERROR) << "cannot find input node: " << tmp_node->input(0);
      break;  // fall back to the last resolvable node in the chain
    }
    tmp_node = it->second;
  }
  return tmp_node->name();
}

// Converts a TF const tensor (attr_value.tensor()) into a ParamValueLite
// attached to `parameter`, filling *shape_vector with the tensor dims.
// Supports float32, int32 and bool. Two payload layouts are handled: a single
// *_val entry splatted over the whole tensor, and raw bytes in tensor_content.
STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type,
                                         const ParameterPtr &parameter, std::vector<int64_t> *shape_vector) {
  MS_ASSERT(parameter != nullptr);
  MS_ASSERT(shape_vector != nullptr);
  const tensorflow::TensorProto &tensor_proto = attr_value.tensor();
  const tensorflow::TensorShapeProto &tensor_shape = tensor_proto.tensor_shape();
  int shape_size = 1;
  shape_vector->clear();
  for (int i = 0; i < tensor_shape.dim_size(); i++) {
    shape_vector->push_back(tensor_shape.dim(i).size());
    shape_size *= tensor_shape.dim(i).size();
  }

  int tensor_size;
  auto param_value = std::make_shared<ParamValueLite>();
  if (param_value == nullptr) {
    MS_LOG(ERROR) << "param_value is nullptr";
    return RET_ERROR;
  }
  if (type == kNumberTypeFloat32 || type == kNumberTypeFloat) {
    auto tensor_data = new (std::nothrow) float[shape_size];
    if (tensor_data == nullptr) {  // BUG FIX: nothrow new can return nullptr
      MS_LOG(ERROR) << "new tensor_data failed";
      return RET_ERROR;
    }
    if (tensor_proto.float_val_size() == 1) {
      // scalar splat: one value broadcast over the whole tensor
      float value = tensor_proto.float_val(0);
      for (int i = 0; i < shape_size; i++) {
        tensor_data[i] = value;
      }
    }
    if (tensor_proto.tensor_content().size() == shape_size * sizeof(float)) {
      const auto addr = reinterpret_cast<const float *>(tensor_proto.tensor_content().data());
      auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(float), addr, shape_size * sizeof(float));
      if (ret != EOK) {
        MS_LOG(ERROR) << "memcpy_s failed";
        delete[] tensor_data;  // BUG FIX: don't leak the buffer on error
        return RET_ERROR;
      }
    }
    param_value->set_tensor_addr(tensor_data);
    tensor_size = shape_size * sizeof(float);
  } else if (type == kNumberTypeInt32) {
    auto tensor_data = new (std::nothrow) int[shape_size];
    if (tensor_data == nullptr) {  // BUG FIX: nothrow new can return nullptr
      MS_LOG(ERROR) << "new tensor_data failed";
      return RET_ERROR;
    }
    if (tensor_proto.int_val_size() == 1) {
      int value = tensor_proto.int_val(0);
      for (int i = 0; i < shape_size; i++) {
        tensor_data[i] = value;
      }
    }
    if (tensor_proto.tensor_content().size() == shape_size * sizeof(int32_t)) {
      const auto addr = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
      auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(int32_t), addr, shape_size * sizeof(int32_t));
      if (ret != EOK) {
        MS_LOG(ERROR) << "memcpy_s failed";
        delete[] tensor_data;  // BUG FIX: don't leak the buffer on error
        return RET_ERROR;
      }
    }
    param_value->set_tensor_addr(tensor_data);
    tensor_size = shape_size * sizeof(int);
  } else if (type == kNumberTypeBool) {
    // bool tensors are widened to int storage
    auto tensor_data = new (std::nothrow) int[shape_size];
    if (tensor_data == nullptr) {  // BUG FIX: nothrow new can return nullptr
      MS_LOG(ERROR) << "new tensor_data failed";
      return RET_ERROR;
    }
    if (tensor_proto.bool_val_size() == 1) {
      int value = tensor_proto.bool_val(0);
      for (int i = 0; i < shape_size; i++) {
        tensor_data[i] = value;
      }
    }
    param_value->set_tensor_addr(tensor_data);
    tensor_size = shape_size * sizeof(int);
  } else {
    MS_LOG(ERROR) << "Unsupported dataType: " << type;
    return RET_ERROR;
  }

  std::vector<int> param_shape(shape_vector->begin(), shape_vector->end());
  param_value->set_tensor_shape(param_shape);
  param_value->set_tensor_type(type);
  param_value->set_tensor_size(tensor_size);
  param_value->set_format(schema::Format::Format_NHWC);
  parameter->set_default_param(param_value);
  parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter");
  return RET_OK;
}

// Converts a TF placeholder/const NodeDef into an ANF Parameter: reads the
// dtype and shape attrs, materializes a default value when a "value" attr is
// present (Const node), otherwise records the node as a graph input.
// The parameter is registered in anf_node_map under the node name.
STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr &parameter) {
  // BUG FIX: removed `MS_ASSERT(node != nullptr)` — `node` is a reference, so
  // comparing it against nullptr is meaningless (and ill-formed for NodeDef).
  MS_ASSERT(parameter != nullptr);

  tensorflow::AttrValue attr_value;
  // Default to float32 when no dtype attr is present.
  TypeId type = kNumberTypeFloat32;
  if (TensorFlowUtils::FindAttrValue(node, "dtype", &attr_value)) {
    type = GetTFDataType(attr_value.type());
  }
  auto type_ptr = TypeIdToType(type);

  std::vector<int> shape;
  if (TensorFlowUtils::FindAttrValue(node, "shape", &attr_value)) {
    auto &shape_attr = attr_value.shape();
    for (int i = 0; i < shape_attr.dim_size(); ++i) {
      shape.push_back(shape_attr.dim(i).size());
    }
  }
  std::vector<int64_t> shape_vector(shape.begin(), shape.end());

  if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) {
    MS_LOG(INFO) << "Found value attr, means it has default value";
    auto status = ConvertConstTensor(attr_value, type, parameter, &shape_vector);
    if (status != RET_OK) {
      return status;
    }
  } else {
    // No default value: this is a real graph input (placeholder).
    parameter->set_name("placeholder_" + std::to_string(anf_node_map.size()));
    graph_input_names.emplace_back(parameter->name());
  }

  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
  if (abstract_tensor == nullptr) {
    MS_LOG(ERROR) << "abstract_tensor is nullptr";
    return RET_ERROR;
  }
  parameter->set_abstract(abstract_tensor);

  anf_node_map[node.name()] = parameter;
  return RET_OK;
}

// Walks tf_node_map and converts every node without a data input (i.e. graph
// placeholders and consts; control-dependency inputs start with '^' and do not
// count) into an ANF Parameter of funcGraphPtr.
STATUS TFModelParser::ConvertGraphInputsAndConsts() {
  for (auto &pair : tf_node_map) {
    const auto *node = pair.second;
    bool has_data_input = false;
    for (int idx = 0; idx < node->input_size() && !has_data_input; ++idx) {
      const auto &input_name = node->input(idx);
      has_data_input = !input_name.empty() && input_name[0] != '^';
    }
    if (has_data_input) {
      continue;  // produced by another op; handled during op conversion
    }
    auto parameter = funcGraphPtr->add_parameter();
    if (ConvertParameter(*node, parameter) != RET_OK) {
      MS_LOG(ERROR) << "convert Parameter Node failed";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
auto status = ValidateFileStr(modelFile, ".prototxt");
auto status = ValidateFileStr(modelFile, ".pb");
if (status != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt";
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
if (!TensorFlowUtils::TfReadProtoFromBinary(modelFile.c_str(), tf_graph_def.get())) {
tf_graph_def = std::make_unique<tensorflow::GraphDef>();
if (tf_graph_def == nullptr) {
MS_LOG(ERROR) << "tf_graph_def is nullptr";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_graph_def.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "Open modelFile for TF converter failed!";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
funcGraphPtr = std::make_shared<FuncGraph>();
status = ConvertGraphInputs();
if (funcGraphPtr == nullptr) {
MS_LOG(ERROR) << "funGraphPtr is nullptr";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}

for (int i = 0; i < tf_graph_def->node_size(); i++) {
auto &node_def = tf_graph_def->node(i);
tf_node_map[node_def.name()] = &node_def;
}

status = ConvertGraphInputsAndConsts();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph inputs failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}

status = ConvertOps();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert ops failed.";
@@ -61,103 +261,36 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
}
return funcGraphPtr;
}
STATUS TFModelParser::ConvertConstTensor(const tensorflow::NodeDef *node, ParameterPtr parameter) {
tensorflow::AttrValue attr_value;
if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) {
tensorflow::AttrValue data_type;
tensorflow::DataType type = tensorflow::DT_FLOAT;
// datatype
if (TensorFlowUtils::FindAttrValue(node, "dtype", &data_type)) {
type = data_type.type();
}
const tensorflow::TensorProto &tensorProto = attr_value.tensor();
const tensorflow::TensorShapeProto &tensorShape = tensorProto.tensor_shape();
parameter = funcGraphPtr->add_parameter();
std::vector<int64_t> shape_vector;
int shape_size = 1;
shape_vector.resize(tensorShape.dim_size());
for (int i = 0; i < tensorShape.dim_size(); i++) {
shape_vector[i] = tensorShape.dim(i).size();
shape_size *= shape_vector[i];
schema::MetaGraphT *TFModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
                                             const QuantType &quantType) {
  // The TF parser builds a FuncGraph directly instead of a flatbuffer
  // MetaGraphT, so this ModelParser override is intentionally unsupported.
  // Fix: corrected the grammar of the error message ("not return" -> "does not return").
  MS_LOG(ERROR) << "TF Model Parser does not return MetaGraph, use TFModelParser::Parse instead";
  return nullptr;
}

STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def,
const std::vector<std::string> &input_names, std::vector<AnfNodePtr> *inputs) {
// parse inputs
for (size_t j = 0; j < input_names.size(); j++) {
std::string input_name = input_names[j]; // input may be produced by multi-outputs node
if (tf_node_map.find(input_name) != tf_node_map.end()) {
auto input_node = tf_node_map[input_name];
input_name = GetOriginInputName(*input_node);
}
// convert const to paramter
TypePtr ms_data_ype;
auto paramValue = std::make_shared<ParamValueLite>();
if (type == tensorflow::DT_FLOAT) {
ms_data_ype = kFloat32;
auto tensor_data = new (std::nothrow) float[shape_size];
if (tensorProto.float_val_size() == 1) {
float value = tensorProto.float_val(0);
for (int i = 0; i < shape_size; i++) {
tensor_data[i] = value;
}
}
if (tensorProto.tensor_content().size() == shape_size * sizeof(float)) {
const auto addr = reinterpret_cast<const float *>(tensorProto.tensor_content().data());
auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(float), addr, shape_size * sizeof(float));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s failed";
return RET_ERROR;
}
}
paramValue->set_tensor_addr(tensor_data);
paramValue->set_tensor_size(shape_size * sizeof(float));
} else if (type == tensorflow::DT_INT32) {
ms_data_ype = kInt32;
auto tensor_data = new (std::nothrow) int[shape_size];
if (tensorProto.int_val_size() == 1) {
int value = tensorProto.int_val(0);
for (int i = 0; i < shape_size; i++) {
tensor_data[i] = value;
}
}
if (tensorProto.tensor_content().size() == shape_size * sizeof(int32_t)) {
const auto addr = reinterpret_cast<const int32_t *>(tensorProto.tensor_content().data());
auto ret = ::memcpy_s(tensor_data, shape_size * sizeof(int32_t), addr, shape_size * sizeof(int32_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s failed";
return RET_ERROR;
}
}
paramValue->set_tensor_addr(tensor_data);
paramValue->set_tensor_size(shape_size * sizeof(int));
} else if (type == tensorflow::DT_BOOL) {
ms_data_ype = kFloat32;
auto tensor_data = new (std::nothrow) int[shape_size];
if (tensorProto.bool_val_size() == 1) {
int value = tensorProto.bool_val(0);
for (int i = 0; i < shape_size; i++) {
tensor_data[i] = value;
}
}
paramValue->set_tensor_addr(tensor_data);
paramValue->set_tensor_size(shape_size * sizeof(int));
} else {
MS_LOG(ERROR) << "Unsupport dataType," << node->name();
auto input = GetAnfNode(input_name);
if (input == nullptr) {
MS_LOG(ERROR) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes";
return RET_ERROR;
}
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(ms_data_ype, shape_vector);
parameter->set_abstract(abstract_tensor);
parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter");

std::vector<int> param_shape;
(void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(param_shape),
[](const int64_t &value) { return static_cast<int>(value); });

MS_ASSERT(paramValue != nullptr);
paramValue->set_tensor_shape(param_shape);
paramValue->set_tensor_type(ms_data_ype->type_id());
paramValue->set_format(schema::Format::Format_NHWC);
paramValue->set_tensor_size(shape_size * sizeof(int));
parameter->set_default_param(paramValue);
inputs->emplace_back(input);
}
return RET_OK;
}
STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size) {

STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, int output_size) {
if (output_size == 1) {
std::vector<int64_t> shape_vector;
anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector));
anf_node_map.insert(std::pair(op->name(), anf_node));
anf_node_map.insert(std::pair(op.name(), anf_node));
} else {
AbstractBasePtrList abstractList;
for (int output_idx = 0; output_idx < output_size; output_idx++) {
@@ -174,113 +307,126 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef *op, const C
CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs);
std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
getItemCNode->set_fullname_with_scope(output_item_name);
anf_node_map.insert(std::pair(output_item_name, getItemCNode));
anf_node_map.insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode));
}
anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList));
}
return RET_OK;
}

STATUS TFModelParser::ConvertOps() {
NoSupportOp::GetInstance()->SetFmkType("TENSORFLOW");
NoSupportOp::GetInstance()->SetFmkType("TF");
STATUS status = RET_OK;

// redirect identity to it's input0
ClipIdentityAndStopGradient();
int op_idx = 0;
for (int i = 0; i < tf_graph_def->node_size(); i++) {
auto node_def = tf_graph_def->mutable_node(i);
tf_node_map[node_def->name()] = node_def;
auto tf_op_type = node_def->op();
if (tf_op_type == "Placeholder" || tf_op_type == "Const") {
auto &node_def = tf_graph_def->node(i);
const auto &op_type = node_def.op();
if (op_type == "Placeholder" || op_type == "Const" || op_type == "Identity" || op_type == "StopGradient") {
continue;
}
auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(tf_op_type);
auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(tf_op_type);
NoSupportOp::GetInstance()->InsertOp(op_type);
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
MS_LOG(ERROR) << "cannot find node parser:" << tf_op_type;
MS_LOG(ERROR) << "cannot find node parser:" << op_type;
continue;
}
if (status != RET_OK) {
continue;
}
PrimitiveC *primitiveC = nullptr;
if (status == RET_OK) {
int output_size = 1;
status = node_parser->Parse(node_def, tf_graph_def, primitiveC, &output_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "node " << tf_op_type.c_str() << " parser failed";
continue;
}
std::vector<AnfNodePtr> opInputs = {NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC))};
// parse inputs
for (int j = 0; j < node_def->input_size(); j++) {
auto input_node = tf_node_map[node_def->input(i)];
// last node output
if (anf_node_map.find(input_node->name()) != anf_node_map.end()) {
opInputs.emplace_back(anf_node_map[input_node->name()]);
continue;
}
// const tensor
if (input_node->op() == "Const") {
ParameterPtr parameter;
if (ConvertConstTensor(input_node, parameter) != RET_OK) {
MS_LOG(ERROR) << "convert const tensor failed," << input_node->name();
return RET_ERROR;
}
opInputs.emplace_back(parameter);
anf_node_map[parameter->fullname_with_scope()] = parameter;
continue;
}
MS_LOG(ERROR) << "node" << node_def->name() << "has inputs neither a node output nor a weight tensor.";
return RET_ERROR;
}
auto anf_node = funcGraphPtr->NewCNode(opInputs);
anf_node->set_fullname_with_scope(tf_op_type + "-" + std::to_string(op_idx++));
int output_size;
std::vector<std::string> input_names;
status = node_parser->Parse(node_def, tf_node_map, &primitiveC, &input_names, &output_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "node " << op_type << " parser failed";
continue;
}

// parse outputs
status = ConvertOutputTensor(node_def, anf_node, output_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return status;
}
auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitiveC));
if (value_node == nullptr) {
MS_LOG(ERROR) << "value_node is nullptr";
status = RET_ERROR;
continue;
}
std::vector<AnfNodePtr> inputs = {value_node};
status = ConvertInputNodes(node_def, input_names, &inputs);
if (status != RET_OK) {
continue;
}
// control_depends are not processed currently
auto anf_node = funcGraphPtr->NewCNode(inputs);
anf_node->set_fullname_with_scope(op_type + "-" + std::to_string(op_idx++));

status = ConvertOutputTensor(node_def, anf_node, output_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
continue;
}
// redirect identity to it's input0
ClipIdentityAndStopGradient();
}
return RET_OK;
return status;
}
STATUS TFModelParser::ConvertGraphInputs() {
for (int i = 0; i < tf_graph_def->node_size(); i++) {
auto node_def = tf_graph_def->mutable_node(i);
tf_node_map[node_def->name()] = node_def;
if (node_def->op() == "Placeholder") {
auto parameter = funcGraphPtr->add_parameter();
if (ConvertConstTensor(node_def, parameter) != RET_OK) {
MS_LOG(ERROR) << "convert const tensor failed";

STATUS TFModelParser::ConvertGraphOutputs() {
// because output of intermediate node in anf graph may also be output tensors, we search output tensors in
// tf_node_map but not anf_node_map
std::set<std::string> all_node_inputs;
std::vector<AnfNodePtr> output_nodes;
for (auto &pair : tf_node_map) {
for (int i = 0; i < pair.second->input_size(); ++i) {
all_node_inputs.insert(pair.second->input(i));
}
}
for (auto &pair : tf_node_map) {
auto it = all_node_inputs.find(pair.first);
if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity
auto origin_name = GetOriginInputName(*(pair.second));
auto anf_node = GetAnfNode(origin_name);
if (anf_node == nullptr) {
MS_LOG(ERROR) << "can't find anf node";
return RET_ERROR;
}
anf_node_map[node_def->name()] = parameter;
graph_input_names.emplace_back(node_def->name());
output_nodes.push_back(anf_node);
graph_output_names.push_back(anf_node->fullname_with_scope());
}
}
return RET_OK;
}
STATUS TFModelParser::ConvertGraphOutputs() { return RET_OK; }

std::string TFModelParser::GetOriginInputName(const tensorflow::NodeDef &node) {
  // Identity and StopGradient are pass-through ops: walk back along input(0)
  // until the real producer is reached and return its name.
  // Fixes: the old code copied the whole NodeDef proto on every hop, and used
  // tf_node_map[name], which default-inserts a nullptr for an unknown name and
  // then dereferences it (undefined behavior). Use pointers + checked lookup.
  const tensorflow::NodeDef *cur = &node;
  while (cur->op() == "Identity" || cur->op() == "StopGradient") {
    auto it = tf_node_map.find(cur->input(0));
    if (it == tf_node_map.end() || it->second == nullptr) {
      MS_LOG(ERROR) << "cannot find input node: " << cur->input(0);
      return cur->input(0);  // best effort: report the unresolved name as-is
    }
    cur = it->second;
  }
  return cur->name();
}
if (output_nodes.size() > 1) {
std::vector<AnfNodePtr> &make_tuple_inputs = output_nodes;
auto make_tuple_prim_ptr = GetMakeTuplePrim();
if (make_tuple_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
return RET_NULL_PTR;
}
auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
make_tuple_inputs.insert(output_nodes.begin(), make_tuple_prim);
auto make_tuple_cnode = funcGraphPtr->NewCNode(make_tuple_inputs);
make_tuple_cnode->set_fullname_with_scope("return tuple");

void TFModelParser::ClipIdentityAndStopGradient() {
for (auto &pair : tf_node_map) {
pair.second = tf_node_map[GetOriginInputName(*pair.second)];
auto return_prim_ptr = GetReturnPrim();
if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
return RET_NULL_PTR;
}
auto value_node = NewValueNode(return_prim_ptr);
std::vector<AnfNodePtr> op_inputs = {value_node, make_tuple_cnode};
auto cnode = funcGraphPtr->NewCNode(op_inputs);
cnode->set_fullname_with_scope("return");
funcGraphPtr->set_return(cnode);
} else {
auto return_prim_ptr = GetReturnPrim();
if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
return RET_NULL_PTR;
}
auto value_node = NewValueNode(return_prim_ptr);
std::vector<AnfNodePtr> op_inputs{value_node, output_nodes.front()};
auto return_cnode = funcGraphPtr->NewCNode(op_inputs);
return_cnode->set_fullname_with_scope("return");
funcGraphPtr->set_return(return_cnode);
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 16
- 9
mindspore/lite/tools/converter/parser/tf/tf_model_parser.h View File

@@ -31,29 +31,36 @@

namespace mindspore {
namespace lite {
class TFModelParser {
class TFModelParser : public ModelParser {
public:
TFModelParser() = default;
~TFModelParser() = default;

FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType);

protected:
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;

private:
STATUS ConvertConstTensor(const tensorflow::NodeDef *op, ParameterPtr parameter);
STATUS ConvertOutputTensor(const tensorflow::NodeDef *op, const CNodePtr &anf_node, int output_size);
AnfNodePtr GetAnfNode(const std::string &name);
std::string GetOriginInputName(const tensorflow::NodeDef &node);
STATUS ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, const ParameterPtr &parameter,
std::vector<int64_t> *shape_vector);
STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr &parameter);
STATUS ConvertGraphInputsAndConsts();
STATUS ConvertInputNodes(const tensorflow::NodeDef &node_def, const std::vector<std::string> &input_names,
std::vector<AnfNodePtr> *inputs);
STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, int output_size);
STATUS ConvertOps();
STATUS ConvertGraphInputs();
STATUS ConvertGraphOutputs();

std::string GetOriginInputName(const tensorflow::NodeDef &node);

void ClipIdentityAndStopGradient();

FuncGraphPtr funcGraphPtr;
std::unique_ptr<tensorflow::GraphDef> tf_graph_def;
std::map<std::string, const tensorflow::NodeDef *> tf_node_map;
std::unordered_map<std::string, AnfNodePtr> anf_node_map;
std::vector<std::string> graph_input_names, graphOutputNames;
std::vector<std::string> graph_input_names;
std::vector<std::string> graph_output_names;
};
} // namespace lite
} // namespace mindspore


mindspore/lite/tools/converter/parser/tf/tf_add_parser.h → mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc View File

@@ -13,23 +13,21 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H

#include <memory>
#include "tools/converter/parser/tf/tf_node_parser.h"
#include <string>
#include <memory>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
class TFAddParser : public TFNodeParser {
public:
TFAddParser() = default;
~TFAddParser() override = default;
STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model,
PrimitiveC *primitiveC, int *output_size) override;
};
STATUS TFNodeParser::AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, std::vector<std::string> *inputs) {
  // Append the idx-th TF input name of tf_op to the parsed-input list.
  // Fix: validate `inputs` before dereferencing it.
  if (inputs == nullptr) {
    MS_LOG(ERROR) << "inputs is nullptr";
    return RET_NULL_PTR;
  }
  // Fix: also reject a negative index; tf_op.input(-1) would abort at runtime.
  if (idx < 0 || idx >= tf_op.input_size()) {
    MS_LOG(ERROR) << "input idx is greater than op input size";
    return RET_PARAM_INVALID;
  }
  inputs->push_back(tf_op.input(idx));
  return RET_OK;
}
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_ADD_PARSER_H

+ 6
- 3
mindspore/lite/tools/converter/parser/tf/tf_node_parser.h View File

@@ -18,6 +18,7 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H

#include <string>
#include <vector>
#include <map>
#include <memory>
#include "tools/converter/parser/tf/tf_util.h"
@@ -32,12 +33,14 @@ class TFNodeParser {

virtual ~TFNodeParser() = default;

virtual STATUS Parse(const tensorflow::NodeDef *tf_op, const std::unique_ptr<tensorflow::GraphDef> &tf_model,
PrimitiveC *primitiveC, int *output_size) {
virtual STATUS Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
return RET_OK;
}

STATUS AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, std::vector<std::string> *inputs);
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_NODE_PARSER_H

+ 109
- 0
mindspore/lite/tools/converter/parser/tf/tf_split_parser.cc View File

@@ -0,0 +1,109 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_split_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFSplitParser::Parse(const tensorflow::NodeDef &tf_op,
                            const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
                            std::vector<std::string> *inputs, int *output_size) {
  // Parse TF Split/SplitV into a lite Split primitive.
  // TF operand order: Split(split_dim, value); SplitV(value, size_splits, split_dim).
  MS_LOG(INFO) << "TF SplitParser";
  // Fix: `inputs` was dereferenced (via AddOpInput) without a null check.
  if (primitiveC == nullptr || inputs == nullptr || output_size == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_NULL_PTR;
  }
  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "New PrimitiveT failed";
    return RET_NULL_PTR;
  }
  auto attr = std::make_unique<schema::SplitT>();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "new attr failed";
    return RET_NULL_PTR;
  }

  tensorflow::AttrValue attr_value;
  if (!TensorFlowUtils::FindAttrValue(tf_op, "num_split", &attr_value)) {
    MS_LOG(ERROR) << "The attribute num_split should be specified";
    return RET_PARAM_INVALID;
  }
  attr->numberSplit = static_cast<int32_t>(attr_value.i());  // named cast over C-style cast

  // Operand positions differ between Split and SplitV.
  int split_dim_index;
  int input_index;
  if (tf_op.op() == "Split") {
    split_dim_index = 0;
    input_index = 1;
  } else {
    split_dim_index = 2;
    input_index = 0;
  }
  // Fix: guard against malformed graphs before calling tf_op.input(i), which
  // aborts on an out-of-range index.
  if (tf_op.input_size() <= split_dim_index || tf_op.input_size() <= input_index) {
    MS_LOG(ERROR) << "Split op input size is invalid";
    return RET_PARAM_INVALID;
  }

  // Single lookup instead of find() followed by at().
  auto split_dim_it = tf_node_map.find(tf_op.input(split_dim_index));
  if (split_dim_it == tf_node_map.end()) {
    MS_LOG(ERROR) << "Find Split input split_dim node failed";
    return RET_ERROR;
  }
  if (!TensorFlowUtils::FindAttrValue(*split_dim_it->second, "value", &attr_value)) {
    MS_LOG(ERROR) << "The attribute splitDim should be specified";
    return RET_PARAM_INVALID;
  }
  const auto &split_dim_tensor = attr_value.tensor();
  // Fix: int_val(0) on an empty repeated field aborts; validate first.
  if (split_dim_tensor.int_val_size() < 1) {
    MS_LOG(ERROR) << "split_dim tensor has no int value";
    return RET_PARAM_INVALID;
  }
  attr->splitDim = split_dim_tensor.int_val(0);
  *output_size = attr->numberSplit;

  if (tf_op.op() == "SplitV") {
    auto size_splits_it = tf_node_map.find(tf_op.input(1));
    if (size_splits_it == tf_node_map.end()) {
      MS_LOG(ERROR) << "Find Split input size_splits failed";
      return RET_ERROR;
    }
    if (!TensorFlowUtils::FindAttrValue(*size_splits_it->second, "value", &attr_value)) {
      MS_LOG(ERROR) << "The attribute size splits should be specified";
      return RET_PARAM_INVALID;
    }
    const auto &size_splits_tensor = attr_value.tensor();
    // NOTE(review): assumes size_splits is serialized in tensor_content (raw
    // little-endian int32), not in the int_val list — confirm for small consts.
    auto size = size_splits_tensor.tensor_content().size() / sizeof(int32_t);
    attr->sizeSplits.resize(size);
    if (size > 0) {  // avoid memcpy_s on an empty vector's data()
      auto ret = memcpy_s(attr->sizeSplits.data(), size * sizeof(int32_t),
                          size_splits_tensor.tensor_content().data(), size * sizeof(int32_t));
      if (ret != EOK) {
        MS_LOG(ERROR) << "memcpy_s failed";
        return RET_ERROR;
      }
    }
  }

  primitive->value.type = schema::PrimitiveType_Split;
  primitive->value.value = attr.release();
  *primitiveC = PrimitiveC::Create(primitive.release());
  if (*primitiveC == nullptr) {
    MS_LOG(ERROR) << "primitiveC is nullptr";
    return RET_ERROR;
  }

  // Only the value tensor becomes a graph input; split_dim/size_splits are attrs.
  return AddOpInput(tf_op, input_index, inputs);
}
TFNodeRegistrar g_tfSplitParser("Split", new TFSplitParser());
TFNodeRegistrar g_tfSplitVParser("SplitV", new TFSplitParser());
} // namespace lite
} // namespace mindspore

+ 36
- 0
mindspore/lite/tools/converter/parser/tf/tf_split_parser.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPLIT_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPLIT_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
// Node parser for TF "Split" and "SplitV" ops; registered with
// TFNodeParserRegistry and invoked by TFModelParser::ConvertOps.
class TFSplitParser : public TFNodeParser {
 public:
  TFSplitParser() = default;
  ~TFSplitParser() override = default;

  // Builds a lite Split primitive from tf_op, fills `inputs` with the data
  // input name and `output_size` with num_split; looks up the split_dim /
  // size_splits const nodes via tf_node_map.
  STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
               PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SPLIT_PARSER_H_

+ 2
- 21
mindspore/lite/tools/converter/parser/tf/tf_util.cc View File

@@ -22,9 +22,9 @@

namespace mindspore {
namespace lite {
bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name,
bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef &nodeDef, const std::string &attr_name,
tensorflow::AttrValue *attr_value) {
const google::protobuf::Map<std::string, tensorflow::AttrValue> &attr = nodeDef->attr();
const google::protobuf::Map<std::string, tensorflow::AttrValue> &attr = nodeDef.attr();
const google::protobuf::Map<std::string, tensorflow::AttrValue>::const_iterator it = attr.find(attr_name);
if (it != attr.end()) {
*attr_value = it->second;
@@ -32,24 +32,5 @@ bool TensorFlowUtils::FindAttrValue(const tensorflow::NodeDef *nodeDef, std::str
}
return false;
}

bool TensorFlowUtils::TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message) {
std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
if (!fs.is_open()) {
fprintf(stderr, "open failed %s\n", filepath);
return false;
}

google::protobuf::io::IstreamInputStream input(&fs);
google::protobuf::io::CodedInputStream codedstr(&input);

codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);

bool success = message->ParseFromCodedStream(&codedstr);

fs.close();

return success;
}
} // namespace lite
} // namespace mindspore

+ 1
- 3
mindspore/lite/tools/converter/parser/tf/tf_util.h View File

@@ -26,10 +26,8 @@ namespace mindspore {
namespace lite {
class TensorFlowUtils {
public:
static bool FindAttrValue(const tensorflow::NodeDef *nodeDef, std::string attr_name,
static bool FindAttrValue(const tensorflow::NodeDef &nodeDef, const std::string &attr_name,
tensorflow::AttrValue *attr_value);

static bool TfReadProtoFromBinary(const char *filepath, google::protobuf::Message *message);
};
} // namespace lite
} // namespace mindspore


Loading…
Cancel
Save