| @@ -77,6 +77,8 @@ TypePtr TypeIdToType(TypeId id) { | |||
| return kInt16; | |||
| case kNumberTypeInt32: | |||
| return kInt32; | |||
| case kNumberTypeInt: | |||
| return kInt32; | |||
| case kNumberTypeInt64: | |||
| return kInt64; | |||
| case kNumberTypeUInt8: | |||
| @@ -119,6 +121,8 @@ TypePtr TypeIdToType(TypeId id) { | |||
| return kSlice; | |||
| case kObjectTypeKeyword: | |||
| return kKeyword; | |||
| case kObjectTypeTensorType: | |||
| return kTensorType; | |||
| case kTypeUnknown: | |||
| return kTypeNone; | |||
| default: | |||
| @@ -1194,6 +1194,7 @@ table TensorListSetItem { | |||
| table TensorListReserve { | |||
| elementDType : int; | |||
| shapeType : int; | |||
| } | |||
| table All { | |||
| @@ -150,6 +150,11 @@ | |||
| #include "src/ops/unsorted_segment_sum.h" | |||
| #include "src/ops/reciprocal.h" | |||
| #include "src/ops/constant.h" | |||
| #include "src/ops/tensorlistfromtensor.h" | |||
| #include "src/ops/tensorlistgetitem.h" | |||
| #include "src/ops/tensorlistsetitem.h" | |||
| #include "src/ops/tensorlistreserve.h" | |||
| #include "src/ops/tensorliststack.h" | |||
| #ifdef SUPPORT_TRAIN | |||
| #include "src/ops/neg_grad.h" | |||
| @@ -906,6 +911,16 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new (std::nothrow) Reciprocal(primitive); | |||
| case schema::PrimitiveType_Constant: | |||
| return new (std::nothrow) Constant(primitive); | |||
| case schema::PrimitiveType_TensorListFromTensor: | |||
| return new (std::nothrow) TensorListFromTensor(primitive); | |||
| case schema::PrimitiveType_TensorListGetItem: | |||
| return new (std::nothrow) TensorListGetItem(primitive); | |||
| case schema::PrimitiveType_TensorListSetItem: | |||
| return new (std::nothrow) TensorListSetItem(primitive); | |||
| case schema::PrimitiveType_TensorListReserve: | |||
| return new (std::nothrow) TensorListReserve(primitive); | |||
| case schema::PrimitiveType_TensorListStack: | |||
| return new (std::nothrow) TensorListStack(primitive); | |||
| #ifdef SUPPORT_TRAIN | |||
| case schema::PrimitiveType_ActivationGrad: | |||
| @@ -52,6 +52,7 @@ STATUS TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| attr->type = schema::ActivationType_TANH; | |||
| } else { | |||
| MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); | |||
| return RET_ERROR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Activation; | |||
| @@ -117,6 +117,22 @@ STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_LessEqual; | |||
| primitive->value.value = attr.release(); | |||
| } else if (tf_op.op() == "Equal") { | |||
| auto attr = std::make_unique<schema::EqualT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_Equal; | |||
| primitive->value.value = attr.release(); | |||
| } else if (tf_op.op() == "NotEqual") { | |||
| auto attr = std::make_unique<schema::NotEqualT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| primitive->value.type = schema::PrimitiveType_NotEqual; | |||
| primitive->value.value = attr.release(); | |||
| } | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| @@ -144,5 +160,7 @@ TFNodeRegistrar g_tfGreaterParser("Greater", new TFArithmeticParser()); | |||
| TFNodeRegistrar g_tfGreaterEqualParser("GreaterEqual", new TFArithmeticParser()); | |||
| TFNodeRegistrar g_tfLessParser("Less", new TFArithmeticParser()); | |||
| TFNodeRegistrar g_tfLessEqualParser("LessEqual", new TFArithmeticParser()); | |||
| TFNodeRegistrar g_tfEqualParser("Equal", new TFArithmeticParser()); | |||
| TFNodeRegistrar g_tfNotEqualParser("NotEqual", new TFArithmeticParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -41,6 +41,7 @@ STATUS TFAssertParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "summarize", &attr_value)) { | |||
| MS_LOG(ERROR) << "The keep_dims attr should be specified"; | |||
| @@ -56,12 +57,14 @@ STATUS TFAssertParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| *output_size = 0; // Assert not have output | |||
| for (int i = 0; i < tf_op.input_size(); ++i) { | |||
| auto status = AddOpInput(tf_op, i, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| } | |||
| return status; | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfAssertParser("Assert", new TFAssertParser()); | |||
| } // namespace lite | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ASSERT_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| @@ -45,6 +45,10 @@ STATUS TFConvParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| attr->group = 1; | |||
| attr->format = TensorFlowUtils::ParseNodeFormat(tf_op); | |||
| if (attr->format == schema::Format_NCHW) { | |||
| MS_LOG(ERROR) << "TF Conv2D with data_format=NCHW is not supported now"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int64_t> dilations(2); | |||
| auto status = ParseDilations(tf_op, attr->format, &dilations); | |||
| @@ -25,10 +25,17 @@ | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/protobuf_utils.h" | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace { | |||
| static const std::vector<schema::PrimitiveType> tensorListOutputOpList = { | |||
| schema::PrimitiveType_TensorListFromTensor, | |||
| schema::PrimitiveType_TensorListSetItem, | |||
| schema::PrimitiveType_TensorListReserve, | |||
| }; | |||
| // subgraph node input may be a:output:0/a:z:0 | |||
| std::string GetFlattenNodeName(std::string input_name) { | |||
| std::regex re("\\:+"); | |||
| @@ -107,7 +114,7 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value | |||
| } | |||
| tensor_size = shape_size * sizeof(float); | |||
| param_value->SetTensorData(tensor_data, tensor_size); | |||
| } else if (type == kNumberTypeInt32) { | |||
| } else if (type == kNumberTypeInt32 || type == kNumberTypeInt) { | |||
| auto tensor_data = new (std::nothrow) int[shape_size]; | |||
| if (tensor_proto.int_val_size() == 1) { | |||
| int value = tensor_proto.int_val(0); | |||
| @@ -445,9 +452,19 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C | |||
| MS_ASSERT(op != nullptr); | |||
| MS_ASSERT(anf_node != nullptr); | |||
| MS_ASSERT(anf_graph != nullptr); | |||
| if (output_size == 1) { | |||
| if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node)) && output_size != 1) { | |||
| MS_LOG(ERROR) << "tensorlist output op output_size !=1"; | |||
| return RET_ERROR; | |||
| } | |||
| if (output_size == 0) { | |||
| return RET_OK; | |||
| } else if (output_size == 1) { | |||
| auto type = kFloat32; | |||
| std::vector<int64_t> shape_vector; | |||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); | |||
| if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node))) { | |||
| type = TypeIdToType(kObjectTypeTensorType); | |||
| } | |||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vector)); | |||
| anf_node_map->insert(std::pair(op.name(), anf_node)); | |||
| } else { | |||
| AbstractBasePtrList abstractList; | |||
| @@ -0,0 +1,87 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_tensor_list_from_tensor_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFTensorListFromTensorParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, | |||
| int *output_size) { | |||
| MS_LOG(INFO) << "TF TensorListFromTensorParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::TensorListFromTensorT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { | |||
| MS_LOG(ERROR) << "The element_dtype attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| auto type = TensorFlowUtils::GetTFDataType(attr_value.type()); | |||
| if (type == kTypeUnknown) { | |||
| MS_LOG(ERROR) << "tensor_list_from_tensor element dtype must be known type"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->elementDType = type; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "shape_type", &attr_value)) { | |||
| MS_LOG(ERROR) << "The shape_type attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| type = TensorFlowUtils::GetTFDataType(attr_value.type()); | |||
| if (type == kTypeUnknown) { | |||
| MS_LOG(ERROR) << "tensor_list_from_tensor shape type must be known type"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->shapeType = type; | |||
| primitive->value.type = schema::PrimitiveType_TensorListFromTensor; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| status = AddOpInput(tf_op, 1, inputs); | |||
| return status; | |||
| } | |||
| TFNodeRegistrar g_tfTensorListFromTensorParser("TensorListFromTensor", new TFTensorListFromTensorParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_FROM_TENSOR_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_FROM_TENSOR_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFTensorListFromTensorParser : public TFNodeParser { | |||
| public: | |||
| TFTensorListFromTensorParser() = default; | |||
| ~TFTensorListFromTensorParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_FROM_TENSOR_PARSER_H_ | |||
| @@ -0,0 +1,76 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_tensor_list_get_item_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFTensorListGetItemParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF TensorListGetItemParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::TensorListGetItemT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { | |||
| MS_LOG(ERROR) << "The element_dtype attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| auto type = TensorFlowUtils::GetTFDataType(attr_value.type()); | |||
| if (type == kTypeUnknown) { | |||
| MS_LOG(ERROR) << "tensor_list_get_item element_dtype must be known type"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->elementDType = type; | |||
| primitive->value.type = schema::PrimitiveType_TensorListGetItem; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| for (int i = 0; i < 3; ++i) { | |||
| auto status = AddOpInput(tf_op, i, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfTensorListGetItemParser("TensorListGetItem", new TFTensorListGetItemParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_GET_ITEM_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_GET_ITEM_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFTensorListGetItemParser : public TFNodeParser { | |||
| public: | |||
| TFTensorListGetItemParser() = default; | |||
| ~TFTensorListGetItemParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_GET_ITEM_PARSER_H_ | |||
| @@ -0,0 +1,86 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_tensor_list_reserve_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFTensorListReserveParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF TensorListReserveParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::TensorListReserveT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { | |||
| MS_LOG(ERROR) << "The element_dtype attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| auto type = TensorFlowUtils::GetTFDataType(attr_value.type()); | |||
| if (type == kTypeUnknown) { | |||
| MS_LOG(ERROR) << "tensor_list_reserve element dtype must be known type"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->elementDType = type; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "shape_type", &attr_value)) { | |||
| MS_LOG(ERROR) << "The shape_type attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| type = TensorFlowUtils::GetTFDataType(attr_value.type()); | |||
| if (type == kTypeUnknown) { | |||
| MS_LOG(ERROR) << "tensor_list_reserve shape_type must be known type"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->shapeType = type; | |||
| primitive->value.type = schema::PrimitiveType_TensorListReserve; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| status = AddOpInput(tf_op, 1, inputs); | |||
| return status; | |||
| } | |||
| TFNodeRegistrar g_tfTensorListReserveParser("TensorListReserve", new TFTensorListReserveParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_RESERVE_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_RESERVE_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFTensorListReserveParser : public TFNodeParser { | |||
| public: | |||
| TFTensorListReserveParser() = default; | |||
| ~TFTensorListReserveParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_RESERVE_PARSER_H_ | |||
| @@ -0,0 +1,76 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_tensor_list_set_item_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFTensorListSetItemParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF TensorListSetItemParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::TensorListSetItemT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { | |||
| MS_LOG(ERROR) << "The element_dtype attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| auto type = TensorFlowUtils::GetTFDataType(attr_value.type()); | |||
| if (type == kTypeUnknown) { | |||
| MS_LOG(ERROR) << "tensor_list_set_item element dtype must be known type"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->elementDType = type; | |||
| primitive->value.type = schema::PrimitiveType_TensorListSetItem; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| for (int i = 0; i < 3; ++i) { | |||
| auto status = AddOpInput(tf_op, i, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TFNodeRegistrar g_tfTensorListSetItemParser("TensorListSetItem", new TFTensorListSetItemParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_SET_ITEM_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_SET_ITEM_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFTensorListSetItemParser : public TFNodeParser { | |||
| public: | |||
| TFTensorListSetItemParser() = default; | |||
| ~TFTensorListSetItemParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_SET_ITEM_PARSER_H_ | |||
| @@ -0,0 +1,81 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_tensor_list_stack_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFTensorListStackParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF TensorListStackParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::TensorListStackT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "element_dtype", &attr_value)) { | |||
| MS_LOG(ERROR) << "The element_dtype attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| auto type = TensorFlowUtils::GetTFDataType(attr_value.type()); | |||
| if (type == kTypeUnknown) { | |||
| MS_LOG(ERROR) << "tensor_list_stack element_dtype must be known type"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->elementDType = type; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "num_elements", &attr_value)) { | |||
| MS_LOG(ERROR) << "The element_dtype attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->numElements = attr_value.i(); | |||
| primitive->value.type = schema::PrimitiveType_TensorListStack; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| status = AddOpInput(tf_op, 1, inputs); | |||
| return status; | |||
| } | |||
| TFNodeRegistrar g_tfTensorListStackParser("TensorListStack", new TFTensorListStackParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_STACK_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_STACK_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFTensorListStackParser : public TFNodeParser { | |||
| public: | |||
| TFTensorListStackParser() = default; | |||
| ~TFTensorListStackParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_TENSOR_LIST_STACK_PARSER_H_ | |||
| @@ -23,17 +23,24 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| static const std::unordered_map<int, mindspore::TypeId> TF_TYPE_MAP = { | |||
| {tensorflow::DT_INT8, mindspore::kNumberTypeInt8}, {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8}, | |||
| {tensorflow::DT_INT16, mindspore::kNumberTypeInt16}, {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16}, | |||
| {tensorflow::DT_INT32, mindspore::kNumberTypeInt32}, {tensorflow::DT_INT64, mindspore::kNumberTypeInt64}, | |||
| {tensorflow::DT_HALF, mindspore::kNumberTypeFloat16}, {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32}, | |||
| {tensorflow::DT_DOUBLE, mindspore::kNumberTypeFloat64}, {tensorflow::DT_COMPLEX64, mindspore::kNumberTypeComplex64}, | |||
| {tensorflow::DT_BOOL, mindspore::kNumberTypeBool}, {tensorflow::DT_STRING, mindspore::kObjectTypeString}}; | |||
| {tensorflow::DT_INT8, mindspore::kNumberTypeInt8}, | |||
| {tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8}, | |||
| {tensorflow::DT_INT16, mindspore::kNumberTypeInt16}, | |||
| {tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16}, | |||
| {tensorflow::DT_INT32, mindspore::kNumberTypeInt32}, | |||
| {tensorflow::DT_INT64, mindspore::kNumberTypeInt64}, | |||
| {tensorflow::DT_HALF, mindspore::kNumberTypeFloat16}, | |||
| {tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32}, | |||
| {tensorflow::DT_DOUBLE, mindspore::kNumberTypeFloat64}, | |||
| {tensorflow::DT_COMPLEX64, mindspore::kNumberTypeComplex64}, | |||
| {tensorflow::DT_BOOL, mindspore::kNumberTypeBool}, | |||
| {tensorflow::DT_STRING, mindspore::kObjectTypeString}, | |||
| {tensorflow::DT_VARIANT, mindspore::kObjectTypeTensorType}}; | |||
| TypeId TensorFlowUtils::GetTFDataType(const tensorflow::DataType &tf_data_type) { | |||
| auto iter = TF_TYPE_MAP.find(tf_data_type); | |||
| if (iter == TF_TYPE_MAP.end()) { | |||
| MS_LOG(ERROR) << "unsupported TF data type: " << tf_data_type; | |||
| MS_LOG(WARNING) << "unsupported TF data type: " << tf_data_type; | |||
| return kTypeUnknown; | |||
| } | |||
| return iter->second; | |||