Browse Source

!10970 [MS_LITE] tf parser

From: @YeFeng_24
Reviewed-by: @hangangqiang
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
970c7b90bb
8 changed files with 315 additions and 0 deletions
  1. +14
    -0
      mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc
  2. +1
    -0
      mindspore/lite/tools/converter/parser/tf/tf_batchnorm_parser.cc
  3. +64
    -0
      mindspore/lite/tools/converter/parser/tf/tf_dropout_parser.cc
  4. +37
    -0
      mindspore/lite/tools/converter/parser/tf/tf_dropout_parser.h
  5. +64
    -0
      mindspore/lite/tools/converter/parser/tf/tf_softmax_parser.cc
  6. +37
    -0
      mindspore/lite/tools/converter/parser/tf/tf_softmax_parser.h
  7. +61
    -0
      mindspore/lite/tools/converter/parser/tf/tf_squared_difference_parser.cc
  8. +37
    -0
      mindspore/lite/tools/converter/parser/tf/tf_squared_difference_parser.h

+ 14
- 0
mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc View File

@@ -50,6 +50,8 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
attr->type = schema::ActivationType_SIGMOID;
} else if (tf_op.op() == "Tanh") {
attr->type = schema::ActivationType_TANH;
} else if (tf_op.op() == "LeakyRelu") {
attr->type = schema::ActivationType_LEAKY_RELU;
} else {
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
return RET_ERROR;
@@ -57,6 +59,17 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,

primitive->value.type = schema::PrimitiveType_Activation;
primitive->value.value = attr.release();
if (tf_op.op() == "LeakyRelu") {
auto attr_leaky_relu = std::make_unique<schema::LeakyReLUT>();
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) {
MS_LOG(ERROR) << "The attribute alpha shoud be specified.";
return RET_ERROR;
}
attr_leaky_relu->negativeSlope = attr_value.f();
primitive->value.type = schema::PrimitiveType_LeakyReLU;
primitive->value.value = attr_leaky_relu.release();
}
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
@@ -71,5 +84,6 @@ TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser());
TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser());
TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser());
TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser());
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFActivationParser());
} // namespace lite
} // namespace mindspore

+ 1
- 0
mindspore/lite/tools/converter/parser/tf/tf_batchnorm_parser.cc View File

@@ -60,5 +60,6 @@ STATUS TFBatchNormParser::Parse(const tensorflow::NodeDef &tf_op,
return RET_OK;
}
TFNodeRegistrar g_tfBatchNormParser("FusedBatchNormV3", new TFBatchNormParser());
TFNodeRegistrar g_tfFusedBatchNormParser("FusedBatchNorm", new TFBatchNormParser());
} // namespace lite
} // namespace mindspore

+ 64
- 0
mindspore/lite/tools/converter/parser/tf/tf_dropout_parser.cc View File

@@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_dropout_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFDropoutParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF DropoutParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::DropoutT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}

tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "ratio", &attr_value)) {
MS_LOG(ERROR) << "The ratio attr should be specified";
return RET_ERROR;
}
attr->ratio = static_cast<int32_t>(attr_value.i());
primitive->value.type = schema::PrimitiveType_Dropout;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}

*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
return status;
}
TFNodeRegistrar g_tfDropoutParser("Dropout", new TFDropoutParser());
} // namespace lite
} // namespace mindspore

+ 37
- 0
mindspore/lite/tools/converter/parser/tf/tf_dropout_parser.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DROPOUT_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DROPOUT_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFDropoutParser : public TFNodeParser {
public:
TFDropoutParser() = default;
~TFDropoutParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DROPOUT_PARSER_H_

+ 64
- 0
mindspore/lite/tools/converter/parser/tf/tf_softmax_parser.cc View File

@@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_softmax_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFSoftmaxParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF SoftmaxParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::SoftMaxT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}

tensorflow::AttrValue attr_value;
int axis = -1;
if (TensorFlowUtils::FindAttrValue(tf_op, "axis", &attr_value)) {
axis = static_cast<int32_t>(attr_value.i());
}
attr->axis = axis;
primitive->value.type = schema::PrimitiveType_SoftMax;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}

*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
return status;
}
TFNodeRegistrar g_tfSoftmaxParser("Softmax", new TFSoftmaxParser());
} // namespace lite
} // namespace mindspore

+ 37
- 0
mindspore/lite/tools/converter/parser/tf/tf_softmax_parser.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SOFTMAX_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SOFTMAX_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFSoftmaxParser : public TFNodeParser {
public:
TFSoftmaxParser() = default;
~TFSoftmaxParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SOFTMAX_PARSER_H_

+ 61
- 0
mindspore/lite/tools/converter/parser/tf/tf_squared_difference_parser.cc View File

@@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_squared_difference_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFSquaredDifferenceParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF SquaredDifferenceParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::SquaredDifferenceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}

primitive->value.type = schema::PrimitiveType_SquaredDifference;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
if (status != RET_OK) {
return RET_ERROR;
}
status = AddOpInput(tf_op, 1, inputs);
return status;
}
TFNodeRegistrar g_tfSquaredDifferenceParser("SquaredDifference", new TFSquaredDifferenceParser());
} // namespace lite
} // namespace mindspore

+ 37
- 0
mindspore/lite/tools/converter/parser/tf/tf_squared_difference_parser.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SQUARED_DIFFERENCE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SQUARED_DIFFERENCE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFSquaredDifferenceParser : public TFNodeParser {
public:
TFSquaredDifferenceParser() = default;
~TFSquaredDifferenceParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SQUARED_DIFFERENCE_PARSER_H_

Loading…
Cancel
Save