From: @YeFeng_24 Reviewed-by: @hangangqiang Signed-off-by:tags/v1.2.0-rc1
| @@ -50,6 +50,8 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| attr->type = schema::ActivationType_SIGMOID; | attr->type = schema::ActivationType_SIGMOID; | ||||
| } else if (tf_op.op() == "Tanh") { | } else if (tf_op.op() == "Tanh") { | ||||
| attr->type = schema::ActivationType_TANH; | attr->type = schema::ActivationType_TANH; | ||||
| } else if (tf_op.op() == "LeakyRelu") { | |||||
| attr->type = schema::ActivationType_LEAKY_RELU; | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); | MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -57,6 +59,17 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| primitive->value.type = schema::PrimitiveType_Activation; | primitive->value.type = schema::PrimitiveType_Activation; | ||||
| primitive->value.value = attr.release(); | primitive->value.value = attr.release(); | ||||
| if (tf_op.op() == "LeakyRelu") { | |||||
| auto attr_leaky_relu = std::make_unique<schema::LeakyReLUT>(); | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The attribute alpha shoud be specified."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr_leaky_relu->negativeSlope = attr_value.f(); | |||||
| primitive->value.type = schema::PrimitiveType_LeakyReLU; | |||||
| primitive->value.value = attr_leaky_relu.release(); | |||||
| } | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | *primitiveC = PrimitiveC::Create(primitive.release()); | ||||
| if (*primitiveC == nullptr) { | if (*primitiveC == nullptr) { | ||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | MS_LOG(ERROR) << "primitiveC is nullptr"; | ||||
| @@ -71,5 +84,6 @@ TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser()); | |||||
| TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser()); | TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser()); | ||||
| TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser()); | TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser()); | ||||
| TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser()); | TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser()); | ||||
| TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFActivationParser()); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -60,5 +60,6 @@ STATUS TFBatchNormParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| TFNodeRegistrar g_tfBatchNormParser("FusedBatchNormV3", new TFBatchNormParser()); | TFNodeRegistrar g_tfBatchNormParser("FusedBatchNormV3", new TFBatchNormParser()); | ||||
| TFNodeRegistrar g_tfFusedBatchNormParser("FusedBatchNorm", new TFBatchNormParser()); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,64 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_dropout_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFDropoutParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||||
| std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF DropoutParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::DropoutT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "ratio", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The ratio attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->ratio = static_cast<int32_t>(attr_value.i()); | |||||
| primitive->value.type = schema::PrimitiveType_Dropout; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| auto status = AddOpInput(tf_op, 0, inputs); | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tfDropoutParser("Dropout", new TFDropoutParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DROPOUT_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DROPOUT_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFDropoutParser : public TFNodeParser { | |||||
| public: | |||||
| TFDropoutParser() = default; | |||||
| ~TFDropoutParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DROPOUT_PARSER_H_ | |||||
| @@ -0,0 +1,64 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_softmax_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFSoftmaxParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||||
| std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF SoftmaxParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::SoftMaxT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| tensorflow::AttrValue attr_value; | |||||
| int axis = -1; | |||||
| if (TensorFlowUtils::FindAttrValue(tf_op, "axis", &attr_value)) { | |||||
| axis = static_cast<int32_t>(attr_value.i()); | |||||
| } | |||||
| attr->axis = axis; | |||||
| primitive->value.type = schema::PrimitiveType_SoftMax; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| auto status = AddOpInput(tf_op, 0, inputs); | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tfSoftmaxParser("Softmax", new TFSoftmaxParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SOFTMAX_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SOFTMAX_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFSoftmaxParser : public TFNodeParser { | |||||
| public: | |||||
| TFSoftmaxParser() = default; | |||||
| ~TFSoftmaxParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SOFTMAX_PARSER_H_ | |||||
| @@ -0,0 +1,61 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_squared_difference_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFSquaredDifferenceParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF SquaredDifferenceParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::SquaredDifferenceT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_SquaredDifference; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| auto status = AddOpInput(tf_op, 0, inputs); | |||||
| if (status != RET_OK) { | |||||
| return RET_ERROR; | |||||
| } | |||||
| status = AddOpInput(tf_op, 1, inputs); | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tfSquaredDifferenceParser("SquaredDifference", new TFSquaredDifferenceParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SQUARED_DIFFERENCE_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SQUARED_DIFFERENCE_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFSquaredDifferenceParser : public TFNodeParser { | |||||
| public: | |||||
| TFSquaredDifferenceParser() = default; | |||||
| ~TFSquaredDifferenceParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_SQUARED_DIFFERENCE_PARSER_H_ | |||||