From: @YeFeng_24 Reviewed-by: @hangangqiang,@hangangqiang Signed-off-by: @hangangqiangtags/v1.2.0-rc1
| @@ -61,6 +61,8 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| status = CreateOperator<schema::LogT>(primitive, schema::PrimitiveType_Log); | status = CreateOperator<schema::LogT>(primitive, schema::PrimitiveType_Log); | ||||
| } else if (tf_op.op() == "Sqrt") { | } else if (tf_op.op() == "Sqrt") { | ||||
| status = CreateOperator<schema::SqrtT>(primitive, schema::PrimitiveType_Sqrt); | status = CreateOperator<schema::SqrtT>(primitive, schema::PrimitiveType_Sqrt); | ||||
| } else if (tf_op.op() == "Pow") { | |||||
| status = CreateOperator<schema::PowerT>(primitive, schema::PrimitiveType_Power); | |||||
| } | } | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| return status; | return status; | ||||
| @@ -84,5 +86,6 @@ TFNodeRegistrar g_tfExpParser("Exp", new TFArithmeticSelfParser()); | |||||
| TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser()); | TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser()); | ||||
| TFNodeRegistrar g_tfLogParser("Log", new TFArithmeticSelfParser()); | TFNodeRegistrar g_tfLogParser("Log", new TFArithmeticSelfParser()); | ||||
| TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFArithmeticSelfParser()); | TFNodeRegistrar g_tfSqrtParser("Sqrt", new TFArithmeticSelfParser()); | ||||
| TFNodeRegistrar g_tfPowParser("Pow", new TFArithmeticSelfParser()); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,67 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_one_hot_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFOneHotParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC, | |||||
| std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF OneHotParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::OneHotT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "axis", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The axis attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->axis = static_cast<int32_t>(attr_value.i()); | |||||
| primitive->value.type = schema::PrimitiveType_OneHot; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| for (int i = 0; i < tf_op.input_size(); ++i) { | |||||
| auto status = AddOpInput(tf_op, i, inputs); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TFNodeRegistrar g_tfOneHotParser("OneHot", new TFOneHotParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ONE_HOT_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ONE_HOT_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFOneHotParser : public TFNodeParser { | |||||
| public: | |||||
| TFOneHotParser() = default; | |||||
| ~TFOneHotParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ONE_HOT_PARSER_H_ | |||||
| @@ -56,6 +56,8 @@ STATUS TFResizeParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| attr->method = schema::ResizeMethod_LINEAR; | attr->method = schema::ResizeMethod_LINEAR; | ||||
| } else if (tf_op.op() == "ResizeNearestNeighbor") { | } else if (tf_op.op() == "ResizeNearestNeighbor") { | ||||
| attr->method = schema::ResizeMethod_NEAREST; | attr->method = schema::ResizeMethod_NEAREST; | ||||
| } else if (tf_op.op() == "ResizeBicubic") { | |||||
| attr->method = schema::ResizeMethod_CUBIC; | |||||
| } else { | } else { | ||||
| attr->method = schema::ResizeMethod_UNKNOWN; | attr->method = schema::ResizeMethod_UNKNOWN; | ||||
| } | } | ||||
| @@ -90,5 +92,6 @@ STATUS TFResizeParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| } | } | ||||
| TFNodeRegistrar g_tfResizeBilinearParser("ResizeBilinear", new TFResizeParser()); | TFNodeRegistrar g_tfResizeBilinearParser("ResizeBilinear", new TFResizeParser()); | ||||
| TFNodeRegistrar g_tfResizeNearestNeighborParser("ResizeNearestNeighbor", new TFResizeParser()); | TFNodeRegistrar g_tfResizeNearestNeighborParser("ResizeNearestNeighbor", new TFResizeParser()); | ||||
| TFNodeRegistrar g_tfResizeBicubicParser("ResizeBicubic", new TFResizeParser()); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||