| @@ -77,6 +77,8 @@ TypePtr TypeIdToType(TypeId id) { | |||||
| return kInt16; | return kInt16; | ||||
| case kNumberTypeInt32: | case kNumberTypeInt32: | ||||
| return kInt32; | return kInt32; | ||||
| case kNumberTypeInt: | |||||
| return kInt32; | |||||
| case kNumberTypeInt64: | case kNumberTypeInt64: | ||||
| return kInt64; | return kInt64; | ||||
| case kNumberTypeUInt8: | case kNumberTypeUInt8: | ||||
| @@ -119,6 +121,8 @@ TypePtr TypeIdToType(TypeId id) { | |||||
| return kSlice; | return kSlice; | ||||
| case kObjectTypeKeyword: | case kObjectTypeKeyword: | ||||
| return kKeyword; | return kKeyword; | ||||
| case kObjectTypeTensorType: | |||||
| return kTensorType; | |||||
| case kTypeUnknown: | case kTypeUnknown: | ||||
| return kTypeNone; | return kTypeNone; | ||||
| default: | default: | ||||
| @@ -1194,6 +1194,7 @@ table TensorListSetItem { | |||||
| table TensorListReserve { | table TensorListReserve { | ||||
| elementDType : int; | elementDType : int; | ||||
| shapeType : int; | |||||
| } | } | ||||
| table All { | table All { | ||||
| @@ -150,6 +150,11 @@ | |||||
| #include "src/ops/unsorted_segment_sum.h" | #include "src/ops/unsorted_segment_sum.h" | ||||
| #include "src/ops/reciprocal.h" | #include "src/ops/reciprocal.h" | ||||
| #include "src/ops/constant.h" | #include "src/ops/constant.h" | ||||
| #include "src/ops/tensorlistfromtensor.h" | |||||
| #include "src/ops/tensorlistgetitem.h" | |||||
| #include "src/ops/tensorlistsetitem.h" | |||||
| #include "src/ops/tensorlistreserve.h" | |||||
| #include "src/ops/tensorliststack.h" | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| #include "src/ops/neg_grad.h" | #include "src/ops/neg_grad.h" | ||||
| @@ -906,6 +911,16 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||||
| return new (std::nothrow) Reciprocal(primitive); | return new (std::nothrow) Reciprocal(primitive); | ||||
| case schema::PrimitiveType_Constant: | case schema::PrimitiveType_Constant: | ||||
| return new (std::nothrow) Constant(primitive); | return new (std::nothrow) Constant(primitive); | ||||
| case schema::PrimitiveType_TensorListFromTensor: | |||||
| return new (std::nothrow) TensorListFromTensor(primitive); | |||||
| case schema::PrimitiveType_TensorListGetItem: | |||||
| return new (std::nothrow) TensorListGetItem(primitive); | |||||
| case schema::PrimitiveType_TensorListSetItem: | |||||
| return new (std::nothrow) TensorListSetItem(primitive); | |||||
| case schema::PrimitiveType_TensorListReserve: | |||||
| return new (std::nothrow) TensorListReserve(primitive); | |||||
| case schema::PrimitiveType_TensorListStack: | |||||
| return new (std::nothrow) TensorListStack(primitive); | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| case schema::PrimitiveType_ActivationGrad: | case schema::PrimitiveType_ActivationGrad: | ||||
| @@ -52,6 +52,7 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| attr->type = schema::ActivationType_TANH; | attr->type = schema::ActivationType_TANH; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); | MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); | ||||
| return RET_ERROR; | |||||
| } | } | ||||
| primitive->value.type = schema::PrimitiveType_Activation; | primitive->value.type = schema::PrimitiveType_Activation; | ||||
| @@ -117,6 +117,22 @@ STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| } | } | ||||
| primitive->value.type = schema::PrimitiveType_LessEqual; | primitive->value.type = schema::PrimitiveType_LessEqual; | ||||
| primitive->value.value = attr.release(); | primitive->value.value = attr.release(); | ||||
| } else if (tf_op.op() == "Equal") { | |||||
| auto attr = std::make_unique<schema::EqualT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Equal; | |||||
| primitive->value.value = attr.release(); | |||||
| } else if (tf_op.op() == "NotEqual") { | |||||
| auto attr = std::make_unique<schema::NotEqualT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_NotEqual; | |||||
| primitive->value.value = attr.release(); | |||||
| } | } | ||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | *primitiveC = PrimitiveC::Create(primitive.release()); | ||||
| @@ -144,5 +160,7 @@ TFNodeRegistrar g_tfGreaterParser("Greater", new TFArithmeticParser()); | |||||
| TFNodeRegistrar g_tfGreaterEqualParser("GreaterEqual", new TFArithmeticParser()); | TFNodeRegistrar g_tfGreaterEqualParser("GreaterEqual", new TFArithmeticParser()); | ||||
| TFNodeRegistrar g_tfLessParser("Less", new TFArithmeticParser()); | TFNodeRegistrar g_tfLessParser("Less", new TFArithmeticParser()); | ||||
| TFNodeRegistrar g_tfLessEqualParser("LessEqual", new TFArithmeticParser()); | TFNodeRegistrar g_tfLessEqualParser("LessEqual", new TFArithmeticParser()); | ||||
| TFNodeRegistrar g_tfEqualParser("Equal", new TFArithmeticParser()); | |||||
| TFNodeRegistrar g_tfNotEqualParser("NotEqual", new TFArithmeticParser()); | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -41,6 +41,7 @@ STATUS TFAssertParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| MS_LOG(ERROR) << "new attr failed"; | MS_LOG(ERROR) << "new attr failed"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| tensorflow::AttrValue attr_value; | tensorflow::AttrValue attr_value; | ||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "summarize", &attr_value)) { | if (!TensorFlowUtils::FindAttrValue(tf_op, "summarize", &attr_value)) { | ||||
| MS_LOG(ERROR) << "The keep_dims attr should be specified"; | MS_LOG(ERROR) << "The keep_dims attr should be specified"; | ||||
| @@ -56,12 +57,14 @@ STATUS TFAssertParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| *output_size = 1; | |||||
| auto status = AddOpInput(tf_op, 0, inputs); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| *output_size = 0; // Assert not have output | |||||
| for (int i = 0; i < tf_op.input_size(); ++i) { | |||||
| auto status = AddOpInput(tf_op, i, inputs); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| } | } | ||||
| return status; | |||||
| return RET_OK; | |||||
| } | } | ||||
| TFNodeRegistrar g_tfAssertParser("Assert", new TFAssertParser()); | TFNodeRegistrar g_tfAssertParser("Assert", new TFAssertParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| @@ -15,7 +15,6 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_ | #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_ | ||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_ | #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_ | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | #include <map> | ||||
| @@ -45,6 +45,10 @@ STATUS TFConvParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| attr->group = 1; | attr->group = 1; | ||||
| attr->format = TensorFlowUtils::ParseNodeFormat(tf_op); | attr->format = TensorFlowUtils::ParseNodeFormat(tf_op); | ||||
| if (attr->format == schema::Format_NCHW) { | |||||
| MS_LOG(ERROR) << "TF Conv2D with data_format=NCHW is not supported now"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| std::vector<int64_t> dilations(2); | std::vector<int64_t> dilations(2); | ||||
| auto status = ParseDilations(tf_op, attr->format, &dilations); | auto status = ParseDilations(tf_op, attr->format, &dilations); | ||||
| @@ -25,10 +25,17 @@ | |||||
| #include "tools/common/graph_util.h" | #include "tools/common/graph_util.h" | ||||
| #include "tools/common/protobuf_utils.h" | #include "tools/common/protobuf_utils.h" | ||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | #include "tools/converter/parser/tf/tf_node_parser_registry.h" | ||||
| #include "tools/optimizer/common/gllo_utils.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | namespace { | ||||
| static const std::vector<schema::PrimitiveType> tensorListOutputOpList = { | |||||
| schema::PrimitiveType_TensorListFromTensor, | |||||
| schema::PrimitiveType_TensorListSetItem, | |||||
| schema::PrimitiveType_TensorListReserve, | |||||
| }; | |||||
| // subgraph node input may be a:output:0/a:z:0 | // subgraph node input may be a:output:0/a:z:0 | ||||
| std::string GetFlattenNodeName(std::string input_name) { | std::string GetFlattenNodeName(std::string input_name) { | ||||
| std::regex re("\\:+"); | std::regex re("\\:+"); | ||||
| @@ -107,7 +114,7 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value | |||||
| } | } | ||||
| tensor_size = shape_size * sizeof(float); | tensor_size = shape_size * sizeof(float); | ||||
| param_value->SetTensorData(tensor_data, tensor_size); | param_value->SetTensorData(tensor_data, tensor_size); | ||||
| } else if (type == kNumberTypeInt32) { | |||||
| } else if (type == kNumberTypeInt32 || type == kNumberTypeInt) { | |||||
| auto tensor_data = new (std::nothrow) int[shape_size]; | auto tensor_data = new (std::nothrow) int[shape_size]; | ||||
| if (tensor_proto.int_val_size() == 1) { | if (tensor_proto.int_val_size() == 1) { | ||||
| int value = tensor_proto.int_val(0); | int value = tensor_proto.int_val(0); | ||||
| @@ -445,9 +452,19 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C | |||||
| MS_ASSERT(op != nullptr); | MS_ASSERT(op != nullptr); | ||||
| MS_ASSERT(anf_node != nullptr); | MS_ASSERT(anf_node != nullptr); | ||||
| MS_ASSERT(anf_graph != nullptr); | MS_ASSERT(anf_graph != nullptr); | ||||
| if (output_size == 1) { | |||||
| if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node)) && output_size != 1) { | |||||
| MS_LOG(ERROR) << "tensorlist output op output_size !=1"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (output_size == 0) { | |||||
| return RET_OK; | |||||
| } else if (output_size == 1) { | |||||
| auto type = kFloat32; | |||||
| std::vector<int64_t> shape_vector; | std::vector<int64_t> shape_vector; | ||||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); | |||||
| if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node))) { | |||||
| type = TypeIdToType(kObjectTypeTensorType); | |||||
| } | |||||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vector)); | |||||
| anf_node_map->insert(std::pair(op.name(), anf_node)); | anf_node_map->insert(std::pair(op.name(), anf_node)); | ||||
| } else { | } else { | ||||
| AbstractBasePtrList abstractList; | AbstractBasePtrList abstractList; | ||||
| @@ -0,0 +1,87 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_tensor_list_from_tensor_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFTensorListFromTensorParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, | |||||
| int *output_size) { | |||||
| MS_LOG(INFO) << "TF TensorListFromTensorParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::TensorListFromTensorT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The element_dtype attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto type = TensorFlowUtils::GetTFDataType(attr_value.type()); | |||||
| if (type == kTypeUnknown) { | |||||
| MS_LOG(ERROR) << "tensor_list_from_tensor element dtype must be known type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->elementDType = type; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "shape_type", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The shape_type attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| type = TensorFlowUtils::GetTFDataType(attr_value.type()); | |||||
| if (type == kTypeUnknown) { | |||||
| MS_LOG(ERROR) << "tensor_list_from_tensor shape type must be known type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->shapeType = type; | |||||
| primitive->value.type = schema::PrimitiveType_TensorListFromTensor; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| auto status = AddOpInput(tf_op, 0, inputs); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| status = AddOpInput(tf_op, 1, inputs); | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tfTensorListFromTensorParser("TensorListFromTensor", new TFTensorListFromTensorParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_FROM_TENSOR_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_FROM_TENSOR_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFTensorListFromTensorParser : public TFNodeParser { | |||||
| public: | |||||
| TFTensorListFromTensorParser() = default; | |||||
| ~TFTensorListFromTensorParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_FROM_TENSOR_PARSER_H_ | |||||
| @@ -0,0 +1,76 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_tensor_list_get_item_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFTensorListGetItemParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF TensorListGetItemParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::TensorListGetItemT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The element_dtype attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto type = TensorFlowUtils::GetTFDataType(attr_value.type()); | |||||
| if (type == kTypeUnknown) { | |||||
| MS_LOG(ERROR) << "tensor_list_get_item element_dtype must be known type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->elementDType = type; | |||||
| primitive->value.type = schema::PrimitiveType_TensorListGetItem; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| for (int i = 0; i < 3; ++i) { | |||||
| auto status = AddOpInput(tf_op, i, inputs); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TFNodeRegistrar g_tfTensorListGetItemParser("TensorListGetItem", new TFTensorListGetItemParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_GET_ITEM_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_GET_ITEM_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFTensorListGetItemParser : public TFNodeParser { | |||||
| public: | |||||
| TFTensorListGetItemParser() = default; | |||||
| ~TFTensorListGetItemParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_GET_ITEM_PARSER_H_ | |||||
| @@ -0,0 +1,86 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_tensor_list_reserve_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFTensorListReserveParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF TensorListReserveParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::TensorListReserveT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The element_dtype attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto type = TensorFlowUtils::GetTFDataType(attr_value.type()); | |||||
| if (type == kTypeUnknown) { | |||||
| MS_LOG(ERROR) << "tensor_list_reserve element dtype must be known type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->elementDType = type; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "shape_type", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The shape_type attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| type = TensorFlowUtils::GetTFDataType(attr_value.type()); | |||||
| if (type == kTypeUnknown) { | |||||
| MS_LOG(ERROR) << "tensor_list_reserve shape_type must be known type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->shapeType = type; | |||||
| primitive->value.type = schema::PrimitiveType_TensorListReserve; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| auto status = AddOpInput(tf_op, 0, inputs); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| status = AddOpInput(tf_op, 1, inputs); | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tfTensorListReserveParser("TensorListReserve", new TFTensorListReserveParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_RESERVE_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_RESERVE_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFTensorListReserveParser : public TFNodeParser { | |||||
| public: | |||||
| TFTensorListReserveParser() = default; | |||||
| ~TFTensorListReserveParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_RESERVE_PARSER_H_ | |||||
| @@ -0,0 +1,76 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_tensor_list_set_item_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFTensorListSetItemParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF TensorListSetItemParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::TensorListSetItemT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The element_dtype attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto type = TensorFlowUtils::GetTFDataType(attr_value.type()); | |||||
| if (type == kTypeUnknown) { | |||||
| MS_LOG(ERROR) << "tensor_list_set_item element dtype must be known type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->elementDType = type; | |||||
| primitive->value.type = schema::PrimitiveType_TensorListSetItem; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| for (int i = 0; i < 3; ++i) { | |||||
| auto status = AddOpInput(tf_op, i, inputs); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| TFNodeRegistrar g_tfTensorListSetItemParser("TensorListSetItem", new TFTensorListSetItemParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_SET_ITEM_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_SET_ITEM_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFTensorListSetItemParser : public TFNodeParser { | |||||
| public: | |||||
| TFTensorListSetItemParser() = default; | |||||
| ~TFTensorListSetItemParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_SET_ITEM_PARSER_H_ | |||||
| @@ -0,0 +1,81 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/tf/tf_tensor_list_stack_parser.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS TFTensorListStackParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||||
| MS_LOG(INFO) << "TF TensorListStackParser"; | |||||
| if (primitiveC == nullptr || output_size == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto attr = std::make_unique<schema::TensorListStackT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new attr failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The element_dtype attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto type = TensorFlowUtils::GetTFDataType(attr_value.type()); | |||||
| if (type == kTypeUnknown) { | |||||
| MS_LOG(ERROR) << "tensor_list_stack element_dtype must be known type"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->elementDType = type; | |||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "num_elements", &attr_value)) { | |||||
| MS_LOG(ERROR) << "The element_dtype attr should be specified"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->numElements = attr_value.i(); | |||||
| primitive->value.type = schema::PrimitiveType_TensorListStack; | |||||
| primitive->value.value = attr.release(); | |||||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||||
| if (*primitiveC == nullptr) { | |||||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| *output_size = 1; | |||||
| auto status = AddOpInput(tf_op, 0, inputs); | |||||
| if (status != RET_OK) { | |||||
| return status; | |||||
| } | |||||
| status = AddOpInput(tf_op, 1, inputs); | |||||
| return status; | |||||
| } | |||||
| TFNodeRegistrar g_tfTensorListStackParser("TensorListStack", new TFTensorListStackParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_STACK_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_STACK_PARSER_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class TFTensorListStackParser : public TFNodeParser { | |||||
| public: | |||||
| TFTensorListStackParser() = default; | |||||
| ~TFTensorListStackParser() override = default; | |||||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_STACK_PARSER_H_ | |||||
| @@ -23,17 +23,24 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| static const std::unordered_map<int, mindspore::TypeId> TF_TYPE_MAP = { | static const std::unordered_map<int, mindspore::TypeId> TF_TYPE_MAP = { | ||||
| {tensorflow::DT_INT8, mindspore::kNumberTypeInt8}, {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8}, | |||||
| {tensorflow::DT_INT16, mindspore::kNumberTypeInt16}, {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16}, | |||||
| {tensorflow::DT_INT32, mindspore::kNumberTypeInt32}, {tensorflow::DT_INT64, mindspore::kNumberTypeInt64}, | |||||
| {tensorflow::DT_HALF, mindspore::kNumberTypeFloat16}, {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32}, | |||||
| {tensorflow::DT_DOUBLE, mindspore::kNumberTypeFloat64}, {tensorflow::DT_COMPLEX64, mindspore::kNumberTypeComplex64}, | |||||
| {tensorflow::DT_BOOL, mindspore::kNumberTypeBool}, {tensorflow::DT_STRING, mindspore::kObjectTypeString}}; | |||||
| {tensorflow::DT_INT8, mindspore::kNumberTypeInt8}, | |||||
| {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8}, | |||||
| {tensorflow::DT_INT16, mindspore::kNumberTypeInt16}, | |||||
| {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16}, | |||||
| {tensorflow::DT_INT32, mindspore::kNumberTypeInt32}, | |||||
| {tensorflow::DT_INT64, mindspore::kNumberTypeInt64}, | |||||
| {tensorflow::DT_HALF, mindspore::kNumberTypeFloat16}, | |||||
| {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32}, | |||||
| {tensorflow::DT_DOUBLE, mindspore::kNumberTypeFloat64}, | |||||
| {tensorflow::DT_COMPLEX64, mindspore::kNumberTypeComplex64}, | |||||
| {tensorflow::DT_BOOL, mindspore::kNumberTypeBool}, | |||||
| {tensorflow::DT_STRING, mindspore::kObjectTypeString}, | |||||
| {tensorflow::DT_VARIANT, mindspore::kObjectTypeTensorType}}; | |||||
| TypeId TensorFlowUtils::GetTFDataType(const tensorflow::DataType &tf_data_type) { | TypeId TensorFlowUtils::GetTFDataType(const tensorflow::DataType &tf_data_type) { | ||||
| auto iter = TF_TYPE_MAP.find(tf_data_type); | auto iter = TF_TYPE_MAP.find(tf_data_type); | ||||
| if (iter == TF_TYPE_MAP.end()) { | if (iter == TF_TYPE_MAP.end()) { | ||||
| MS_LOG(ERROR) << "unsupported TF data type: " << tf_data_type; | |||||
| MS_LOG(WARNING) << "unsupported TF data type: " << tf_data_type; | |||||
| return kTypeUnknown; | return kTypeUnknown; | ||||
| } | } | ||||
| return iter->second; | return iter->second; | ||||