From 07bf8680fc2718fde4ecb97586bd2ea8d1265a67 Mon Sep 17 00:00:00 2001 From: yefeng Date: Tue, 5 Jan 2021 17:32:09 +0800 Subject: [PATCH] 032-tf_parser-6 --- .../parser/tf/tf_activation_parser.cc | 14 ++++ .../parser/tf/tf_batchnorm_parser.cc | 1 + .../converter/parser/tf/tf_dropout_parser.cc | 64 +++++++++++++++++++ .../converter/parser/tf/tf_dropout_parser.h | 37 +++++++++++ .../converter/parser/tf/tf_softmax_parser.cc | 64 +++++++++++++++++++ .../converter/parser/tf/tf_softmax_parser.h | 37 +++++++++++ .../parser/tf/tf_squared_difference_parser.cc | 61 ++++++++++++++++++ .../parser/tf/tf_squared_difference_parser.h | 37 +++++++++++ 8 files changed, 315 insertions(+) create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_dropout_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_dropout_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_softmax_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_softmax_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_squared_difference_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_squared_difference_parser.h diff --git a/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc index 1f29d0f3da..2151fe987e 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc @@ -50,6 +50,8 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, attr->type = schema::ActivationType_SIGMOID; } else if (tf_op.op() == "Tanh") { attr->type = schema::ActivationType_TANH; + } else if (tf_op.op() == "LeakyRelu") { + attr->type = schema::ActivationType_LEAKY_RELU; } else { MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); return RET_ERROR; @@ -57,6 +59,17 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, primitive->value.type = schema::PrimitiveType_Activation; primitive->value.value = attr.release(); + if (tf_op.op() == "LeakyRelu") { + auto attr_leaky_relu = std::make_unique(); + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) { + MS_LOG(ERROR) << "The attribute alpha shoud be specified."; + return RET_ERROR; + } + attr_leaky_relu->negativeSlope = attr_value.f(); + primitive->value.type = schema::PrimitiveType_LeakyReLU; + primitive->value.value = attr_leaky_relu.release(); + } *primitiveC = PrimitiveC::Create(primitive.release()); if (*primitiveC == nullptr) { MS_LOG(ERROR) << "primitiveC is nullptr"; @@ -71,5 +84,6 @@ TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser()); TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser()); TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser()); TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser()); +TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFActivationParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_batchnorm_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_batchnorm_parser.cc index a228ff7228..5d0cb25ca2 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_batchnorm_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_batchnorm_parser.cc @@ -60,5 +60,6 @@ STATUS TFBatchNormParser::Parse(const tensorflow::NodeDef &tf_op, return RET_OK; } TFNodeRegistrar g_tfBatchNormParser("FusedBatchNormV3", new TFBatchNormParser()); +TFNodeRegistrar g_tfFusedBatchNormParser("FusedBatchNorm", new TFBatchNormParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_dropout_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_dropout_parser.cc new file mode 100644 index 0000000000..5e7474fc30 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_dropout_parser.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_dropout_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFDropoutParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF DropoutParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "ratio", &attr_value)) { + MS_LOG(ERROR) << "The ratio attr should be specified"; + return RET_ERROR; + } + attr->ratio = static_cast(attr_value.i()); + primitive->value.type = schema::PrimitiveType_Dropout; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfDropoutParser("Dropout", new TFDropoutParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_dropout_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_dropout_parser.h new file mode 100644 index 0000000000..62e992ac89 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_dropout_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DROPOUT_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DROPOUT_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFDropoutParser : public TFNodeParser { + public: + TFDropoutParser() = default; + ~TFDropoutParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DROPOUT_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_softmax_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_softmax_parser.cc new file mode 100644 index 0000000000..ace962b1a0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_softmax_parser.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_softmax_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFSoftmaxParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF SoftmaxParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + int axis = -1; + if (TensorFlowUtils::FindAttrValue(tf_op, "axis", &attr_value)) { + axis = static_cast(attr_value.i()); + } + attr->axis = axis; + primitive->value.type = schema::PrimitiveType_SoftMax; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + return status; +} +TFNodeRegistrar g_tfSoftmaxParser("Softmax", new TFSoftmaxParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_softmax_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_softmax_parser.h new file mode 100644 index 0000000000..ec7d91aa25 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_softmax_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SOFTMAX_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SOFTMAX_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFSoftmaxParser : public TFNodeParser { + public: + TFSoftmaxParser() = default; + ~TFSoftmaxParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SOFTMAX_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_squared_difference_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_squared_difference_parser.cc new file mode 100644 index 0000000000..bbda1360a8 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_squared_difference_parser.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_squared_difference_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFSquaredDifferenceParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF SquaredDifferenceParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + primitive->value.type = schema::PrimitiveType_SquaredDifference; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + *output_size = 1; + auto status = AddOpInput(tf_op, 0, inputs); + if (status != RET_OK) { + return RET_ERROR; + } + status = AddOpInput(tf_op, 1, inputs); + return status; +} +TFNodeRegistrar g_tfSquaredDifferenceParser("SquaredDifference", new TFSquaredDifferenceParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_squared_difference_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_squared_difference_parser.h new file mode 100644 index 0000000000..2b557bf615 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_squared_difference_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SQUARED_DIFFERENCE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SQUARED_DIFFERENCE_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFSquaredDifferenceParser : public TFNodeParser { + public: + TFSquaredDifferenceParser() = default; + ~TFSquaredDifferenceParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SQUARED_DIFFERENCE_PARSER_H_