From 79c08fcd48ad4bffa29ca3ec3a704584c582cbdd Mon Sep 17 00:00:00 2001 From: wangzhe Date: Mon, 14 Dec 2020 10:18:14 +0800 Subject: [PATCH] runtime support tensorlist --- mindspore/lite/src/lite_session.cc | 45 +++-- mindspore/lite/src/ops/tensorlistsetitem.cc | 56 ++++--- .../kernel/arm/fp32/TensorListFromTensor.cc | 4 +- .../kernel/arm/fp32/TensorListSetItem.cc | 13 +- .../kernel/arm/fp32/TensorListStack.cc | 5 +- mindspore/lite/src/scheduler.cc | 5 + mindspore/lite/src/tensorlist.cc | 14 +- mindspore/lite/src/tensorlist.h | 6 +- .../lite/tools/anf_exporter/anf_exporter.cc | 29 +++- .../legacy_optimizer/graph/infershape_pass.cc | 77 ++++++--- .../converter/parser/onnx/onnx_conv_parser.cc | 3 +- .../parser/tf/tf_arithmetic_parser.cc | 2 +- .../converter/parser/tf/tf_biasadd_parser.cc | 4 + .../converter/parser/tf/tf_concat_parser.cc | 6 +- .../converter/parser/tf/tf_conv_parser.cc | 4 +- .../parser/tf/tf_expand_dims_parser.cc | 4 +- .../converter/parser/tf/tf_gather_parser.cc | 4 +- .../converter/parser/tf/tf_model_parser.cc | 157 ++++++++---------- .../converter/parser/tf/tf_model_parser.h | 14 +- .../converter/parser/tf/tf_node_parser.cc | 16 ++ .../converter/parser/tf/tf_node_parser.h | 3 + .../converter/parser/tf/tf_reduce_parser.cc | 4 +- .../parser/tf/tf_reverse_sequence_parser.cc | 70 ++++++++ .../parser/tf/tf_reverse_sequence_parser.h | 37 +++++ .../converter/parser/tf/tf_split_parser.cc | 8 +- .../parser/tf/tf_stride_slice_parser.cc | 12 +- .../converter/parser/tf/tf_tile_parser.cc | 4 +- .../parser/tf/tf_transpose_parser.cc | 5 +- .../lite/tools/converter/parser/tf/tf_util.cc | 29 ++++ .../lite/tools/converter/parser/tf/tf_util.h | 2 + .../graph/clip_convert_activation_pass.cc | 1 + .../tools/optimizer/graph/infershape_pass.cc | 19 +-- 32 files changed, 445 insertions(+), 217 deletions(-) create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc create mode 100644 
mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.h diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 0ed533d902..81c6cbd53c 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -89,27 +89,38 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde auto data_type = src_tensor->dataType(); if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) && src_tensor->data() != nullptr && src_tensor->data()->size() > 0) { - MS_ASSERT(dst_tensor->Size() == src_tensor->data()->size()); - if (WeightTensorNeedCopy(model, tensor_index)) { - auto dst_data = dst_tensor->MutableData(); - if (dst_data == nullptr) { - MS_LOG(ERROR) << "Data from tensor is nullptr"; - return RET_NULL_PTR; + if (src_tensor->dataType() == kObjectTypeTensorType) { + auto tensor_list = reinterpret_cast(dst_tensor); + if (src_tensor->data() == nullptr) { + MS_LOG(ERROR) << "src_tensor->data() is nullptr"; + return RET_ERROR; + } + if (tensor_list->Decode(reinterpret_cast(src_tensor->data()->data())) != RET_OK) { + return RET_ERROR; } - memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size()); - copyed_tensor_idxes_.emplace_back(tensor_index); } else { - int pack_size = src_tensor->data()->size(); - int org_size = dst_tensor->Size(); - if (pack_size != org_size && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16)) { - auto ret = dst_tensor->MallocData(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Malloc data for tensor failed "; - return RET_ERROR; + MS_ASSERT(dst_tensor->Size() == src_tensor->data()->size()); + if (WeightTensorNeedCopy(model, tensor_index)) { + auto dst_data = dst_tensor->MutableData(); + if (dst_data == nullptr) { + MS_LOG(ERROR) << "Data from tensor is nullptr"; + return RET_NULL_PTR; } - kernel::DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData()); + memcpy(dst_data, 
src_tensor->data()->data(), dst_tensor->Size()); + copyed_tensor_idxes_.emplace_back(tensor_index); } else { - dst_tensor->set_data(const_cast(src_tensor->data()->data())); + int pack_size = src_tensor->data()->size(); + int org_size = dst_tensor->Size(); + if (pack_size != org_size && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16)) { + auto ret = dst_tensor->MallocData(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Malloc data for tensor failed "; + return RET_ERROR; + } + kernel::DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData()); + } else { + dst_tensor->set_data(const_cast(src_tensor->data()->data())); + } } } } diff --git a/mindspore/lite/src/ops/tensorlistsetitem.cc b/mindspore/lite/src/ops/tensorlistsetitem.cc index aa4a88713d..5626a877e2 100644 --- a/mindspore/lite/src/ops/tensorlistsetitem.cc +++ b/mindspore/lite/src/ops/tensorlistsetitem.cc @@ -97,6 +97,21 @@ int TensorListSetItem::InferShape(std::vector inputs_, std::vect MS_ASSERT(input0 != nullptr); auto get_index = inputs_[1]; MS_ASSERT(get_index != nullptr); + auto value_tensor = inputs_[2]; + MS_ASSERT(value_tensor != nullptr); + auto output0 = reinterpret_cast(outputs_[0]); + MS_ASSERT(output0 != nullptr); + + output0->set_data_type(input0->data_type()); + output0->set_format(input0->format()); + + if (!infer_flag()) { + return RET_INFER_INVALID; + } + if (get_index->data_c() == nullptr || value_tensor->data_c() == nullptr) { + return RET_INFER_INVALID; + } + if (get_index->data_type() != kNumberTypeInt && get_index->data_type() != kNumberTypeInt32) { MS_LOG(ERROR) << "inputs_[1]->data_type():" << get_index->data_type() << " is not int"; return RET_ERROR; @@ -110,31 +125,34 @@ int TensorListSetItem::InferShape(std::vector inputs_, std::vect return RET_NULL_PTR; } int index = reinterpret_cast(get_index->data_c())[0]; - if (index < 0 || index > (input0->ElementsNum() - 1)) { - MS_LOG(ERROR) << "index_:" << index << "must in [0, " << input0->ElementsNum() - 1 << "]"; + if 
(index < 0 || (index >= static_cast(input0->tensors().size()) && index != 0)) { + MS_LOG(ERROR) << "index_:" << index << "must in [0, " << input0->tensors().size() << "]"; return RET_ERROR; } - auto value_tensor = inputs_[2]; - MS_ASSERT(value_tensor != nullptr); - auto output0 = reinterpret_cast(outputs_[0]); - MS_ASSERT(output0 != nullptr); - output0->set_element_shape(input0->element_shape()); + output0->set_max_elements_num(input0->max_elements_num()); - output0->set_shape(input0->shape()); - output0->set_data_type(input0->data_type()); + output0->set_element_shape(input0->element_shape()); + std::vector > out_shape; - for (int i = 0; i < input0->ElementsNum(); ++i) { - auto src_ptr = input0->GetTensorIndex(i); - if (src_ptr == nullptr) { - MS_LOG(ERROR) << "input0->tensors_[" << i << "] is nullptr!"; - return RET_ERROR; - } - if (src_ptr->data_type() != kTypeUnknown) { - out_shape.push_back(src_ptr->shape()); - } else { - out_shape.push_back(std::vector()); + if (index == 0 && input0->tensors().size() == 0) { // uninitialized tensorlist + out_shape.push_back(value_tensor->shape()); + output0->set_shape(std::vector{1}); + } else { + output0->set_shape(input0->shape()); + for (int i = 0; i < input0->ElementsNum(); ++i) { + auto src_ptr = input0->GetTensorIndex(i); + if (src_ptr == nullptr) { + MS_LOG(ERROR) << "input0->tensors_[" << i << "] is nullptr!"; + return RET_ERROR; + } + if (src_ptr->data_type() != kTypeUnknown) { + out_shape.push_back(src_ptr->shape()); + } else { + out_shape.push_back(std::vector()); + } } } + out_shape[index] = value_tensor->shape(); output0->MallocTensorListData(input0->tensors_data_type(), out_shape); return RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/TensorListFromTensor.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/TensorListFromTensor.cc index 0ab87f0699..5351e69b4f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/TensorListFromTensor.cc +++ 
b/mindspore/lite/src/runtime/kernel/arm/fp32/TensorListFromTensor.cc @@ -28,8 +28,8 @@ using mindspore::schema::PrimitiveType_TensorListFromTensor; namespace mindspore::kernel { int TensorListFromTensorCPUKernel::IsCompatibleShape() { - if (input1_->data_type() != kNumberTypeInt) { // element_shape - MS_LOG(ERROR) << "in_tensors_[1] data type is must be \"kNumberTypeInt\", but now is:" << input1_->data_type(); + if (input1_->data_type() != kNumberTypeInt && input1_->data_type() != kNumberTypeInt32) { // element_shape + MS_LOG(ERROR) << "in_tensors_[1] data type is must be int"; return RET_ERROR; } int in1_ele_num = input1_->ElementsNum(); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/TensorListSetItem.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/TensorListSetItem.cc index a842cf2b74..63a0cfadd1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/TensorListSetItem.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/TensorListSetItem.cc @@ -28,16 +28,17 @@ using mindspore::schema::PrimitiveType_TensorListSetItem; namespace mindspore::kernel { -int TensorListSetItemCPUKernel::Init() { +int TensorListSetItemCPUKernel::Init() { return RET_OK; } + +int TensorListSetItemCPUKernel::Run() { input0_ = reinterpret_cast(in_tensors_[0]); if (dtype_ != input0_->data_type()) { MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type(); return RET_ERROR; } int dim0 = input0_->ElementsNum() - 1; - if (in_tensors_[1]->data_type() != kNumberTypeInt) { - MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() - << " must be equal to \"kNumberTypeInt\":" << kNumberTypeInt; + if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) { + MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int"; return RET_ERROR; } if (in_tensors_[1]->ElementsNum() != 1) { @@ -54,10 +55,6 @@ int TensorListSetItemCPUKernel::Init() { if 
(!input0_->IsCompatibleShape(input2_->shape())) { return RET_ERROR; } - return RET_OK; -} - -int TensorListSetItemCPUKernel::Run() { output0_ = reinterpret_cast(out_tensors_[0]); MS_ASSERT(output0_ != nullptr); // copy each tensor in tensors_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/TensorListStack.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/TensorListStack.cc index 742473e811..91ad3f9956 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/TensorListStack.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/TensorListStack.cc @@ -73,9 +73,8 @@ bool TensorListStackCPUKernel::IsFullyDefined(const std::vector &shape) con int TensorListStackCPUKernel::MergeElementShape() { MS_ASSERT(in_tensors_[1]); - if (in_tensors_[1]->data_type() != kNumberTypeInt) { - MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() - << " must be \"kNumberTypeInt\":" << kNumberTypeInt; + if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) { + MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int"; return RET_ERROR; } auto ele_shape_data = reinterpret_cast(in_tensors_[1]->data_c()); diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 4d2f514672..964f80ae1b 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -19,6 +19,7 @@ #include #include #include +#include "src/tensorlist.h" #include "src/ops/partial.h" #include "include/errorcode.h" #include "src/common/graph_util.h" @@ -426,6 +427,10 @@ TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector &in_ten if (dtype == kObjectTypeString) { return kNumberTypeFloat32; } + if (dtype == kObjectTypeTensorType) { + auto tensor_list = reinterpret_cast(tensor); + return tensor_list->tensors_data_type(); + } if (dtype == kNumberTypeFloat32 || dtype == kNumberTypeFloat16 || dtype == kNumberTypeInt8 || dtype == kNumberTypeInt32 || dtype == 
kNumberTypeBool) { return dtype; diff --git a/mindspore/lite/src/tensorlist.cc b/mindspore/lite/src/tensorlist.cc index fde03ac1cb..688016285b 100644 --- a/mindspore/lite/src/tensorlist.cc +++ b/mindspore/lite/src/tensorlist.cc @@ -204,7 +204,7 @@ int TensorList::CheckTensorListParam() { Tensor *TensorList::GetTensorIndex(int index) { // return tensor[index] ptr. With this function, you can modify tensors_[index] at will. - if (index < 0 || index > (this->ElementsNum() - 1)) { + if (index < 0 || index >= static_cast(tensors_.size())) { MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!"; return nullptr; } @@ -240,5 +240,17 @@ bool TensorList::IsCompatibleShape(const Tensor *src) { } return true; } + +STATUS TensorList::Decode(const int *data) { + if (data == nullptr) { + MS_LOG(ERROR) << "data is nullptr"; + return RET_ERROR; + } + tensors_data_type_ = TypeId(data[0]); + for (int j = 0; j < data[1]; ++j) { + element_shape_.push_back(data[2 + j]); + } + return RET_OK; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/tensorlist.h b/mindspore/lite/src/tensorlist.h index f810ff6613..580b78bcbd 100644 --- a/mindspore/lite/src/tensorlist.h +++ b/mindspore/lite/src/tensorlist.h @@ -60,6 +60,8 @@ class TensorList : public Tensor { public: TensorList() = default; + TensorList(std::vector shape, std::vector element_shape); + ~TensorList() override; // **Note**: This is a shallow copy, src and dst tensorlist share one memory space of each tensor in tensors_ @@ -74,8 +76,6 @@ class TensorList : public Tensor { // tensorlist deep copy memory TensorList &operator=(const TensorList &tl); - TensorList(std::vector shape, std::vector element_shape); - void set_element_shape(const std::vector &shape) { element_shape_ = shape; } std::vector &element_shape() { return element_shape_; } @@ -112,6 +112,8 @@ class TensorList : public Tensor { bool IsCompatibleShape(const Tensor *src); + STATUS Decode(const int *data); + 
protected: // The following functions must be masked. void set_data(void *data) override { return; } diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 5d419596b0..f6b13d31fd 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -572,7 +572,12 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr(cnode->abstract())) { auto tuple = std::reinterpret_pointer_cast(cnode->abstract()); - for (size_t i = 0; i < tuple->size(); i++) { + if (tuple == nullptr) { + MS_LOG(ERROR) << "tuple is nullptr"; + return; + } + auto elements = tuple->elements(); + for (size_t i = 0; i < elements.size(); i++) { auto msTensor = new (std::nothrow) schema::TensorT(); if (msTensor == nullptr) { MS_LOG(ERROR) << "new msTensor failed"; @@ -589,7 +594,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptrsize() == 1) { + if (elements.size() == 1) { node_id_map_[cnode_name] = meta_graphT->allTensors.size(); msTensor->name = cnode_name; } else { @@ -597,6 +602,18 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptrallTensors.size(); msTensor->name = name; } + + if (!utils::isa(elements[i])) { + MS_LOG(ERROR) << "abstract is not AbstractTensor"; + return; + } + auto type = kNumberTypeFloat32; + if (utils::isa(elements[i])) { + auto abstract_tensor = utils::cast(elements[i]); + auto typePtr = abstract_tensor->element()->GetTypeTrack(); + type = typePtr->type_id(); + } + msTensor->dataType = type; meta_graphT->allTensors.emplace_back(msTensor); if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) || @@ -611,8 +628,14 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr(cnode->abstract())) { + auto abstract_tensor = utils::cast(cnode->abstract()); + auto typePtr = 
abstract_tensor->element()->GetTypeTrack(); + type = typePtr->type_id(); + } + ms_tensor->dataType = type; ms_tensor->nodeType = schema::NodeType_CNode; - ms_tensor->dataType = TypeId::kNumberTypeFloat32; ms_tensor->name = cnode_name; fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); node_id_map_[cnode_name] = meta_graphT->allTensors.size(); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc index 1a03b11a62..d1c2009a5b 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc @@ -19,6 +19,7 @@ #include "src/common/log_adapter.h" #include "include/errorcode.h" #include "src/tensor.h" +#include "src/tensorlist.h" #include "src/ops/primitive_c.h" using mindspore::lite::PrimitiveC; @@ -50,32 +51,58 @@ std::vector ConvertTensorToLiteTensor(MetaGraphT *graph, const std::ve std::vector lite_tensors; bool convert_succ = true; for (size_t i = 0; i < tensor_indexs.size(); i++) { + std::unique_ptr lite_tensor = nullptr; auto &tensorT = graph->allTensors.at(tensor_indexs[i]); - auto tensor_shape = tensorT->dims; - auto lite_tensor = std::make_unique( - TypeId(tensorT->dataType), tensor_shape, tensorT->format, - TensorCategory(tensorT->nodeType, tensorT->dims.size(), TypeId(tensorT->dataType), tensorT->data.size())); - if (lite_tensor == nullptr) { - MS_LOG(ERROR) << "lite tensor is nullptr"; - convert_succ = false; - break; - } - auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t); - // when tensorT as param input - if (lite_tensor_size == 0) { - lite_tensors.emplace_back(lite_tensor.release()); - continue; - } - auto ret = lite_tensor->MallocData(); - if (ret != 0) { - MS_LOG(ERROR) << "Malloc tensor data failed"; - convert_succ = false; - break; - } - if (memcpy_s(lite_tensor->MutableData(), lite_tensor->Size(), tensorT->data.data(), 
lite_tensor_size) != EOK) { - MS_LOG(ERROR) << "memcpy_s failed"; - convert_succ = false; - break; + if (tensorT->dataType != kObjectTypeTensorType) { // convert to lite::Tensor + auto tensor_shape = tensorT->dims; + lite_tensor = std::make_unique( + TypeId(tensorT->dataType), tensor_shape, tensorT->format, + TensorCategory(tensorT->nodeType, tensorT->dims.size(), TypeId(tensorT->dataType), tensorT->data.size())); + if (lite_tensor == nullptr) { + MS_LOG(ERROR) << "lite tensor is nullptr"; + convert_succ = false; + break; + } + auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t); + // when tensorT as param input + if (lite_tensor_size == 0) { + lite_tensors.emplace_back(lite_tensor.release()); + continue; + } + auto ret = lite_tensor->MallocData(); + if (ret != 0) { + MS_LOG(ERROR) << "Malloc tensor data failed"; + convert_succ = false; + break; + } + if (memcpy_s(lite_tensor->MutableData(), lite_tensor->Size(), tensorT->data.data(), lite_tensor_size) != EOK) { + MS_LOG(ERROR) << "memcpy_s failed"; + convert_succ = false; + break; + } + } else { // convert to lite::TensorList + auto tensor_shape = tensorT->dims; + TypeId type = kTypeUnknown; + std::vector element_shape; + if (!tensorT->data.empty()) { + int *data = reinterpret_cast(tensorT->data.data()); + type = TypeId(data[0]); + if (tensorT->data.size() < 8 || (data[1] + 2) * 4 != static_cast(tensorT->data.size())) { + MS_LOG(ERROR) << "tensorlist data length illegal"; + convert_succ = false; + break; + } + for (int j = 0; j < data[1]; ++j) { + element_shape.push_back(data[j + 2]); + } + } + lite_tensor = std::make_unique(tensor_shape, element_shape); + if (lite_tensor == nullptr) { + MS_LOG(ERROR) << "lite tensorlist is nullptr"; + convert_succ = false; + break; + } + reinterpret_cast(lite_tensor.get())->set_tensors_data_type(type); } lite_tensors.emplace_back(lite_tensor.release()); } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc 
b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc index a81410bf15..fba10b0d98 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc @@ -20,7 +20,6 @@ #include namespace mindspore::lite { -constexpr int32_t kSingleGroup = 1; bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr &attr, schema::PrimitiveT *primitive) { MS_LOG(DEBUG) << "onnx DepthwiseConvParser"; @@ -175,7 +174,7 @@ lite::PrimitiveC *OnnxConvParser::ParseLitePrimitive(const onnx::GraphProto &onn MS_LOG(ERROR) << "new primitive failed"; return nullptr; } - if (attr->group > kSingleGroup && attr->group == attr->channelIn) { + if (attr->group == attr->channelIn && attr->channelIn == attr->channelOut) { if (!ParseGroupConvolution(attr, primitive.get())) { MS_LOG(ERROR) << "Convert Convolution to Depthwise failed"; return nullptr; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc index 10d4a1e6cb..3c2fe18261 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc @@ -37,7 +37,7 @@ STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op, return RET_NULL_PTR; } - if (tf_op.op() == "Add") { + if (tf_op.op() == "Add" || tf_op.op() == "AddV2") { auto attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new attr failed"; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc index cd3434aec2..52cb99a407 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc @@ -54,6 +54,10 @@ STATUS TFBiasAddParser::Parse(const tensorflow::NodeDef &tf_op, *output_size = 1; auto status = AddOpInput(tf_op, 0, inputs); + if (status != 
RET_OK) { + return status; + } + status = AddOpInput(tf_op, 1, inputs); return status; } TFNodeRegistrar g_tfBiasAddParser("BiasAdd", new TFBiasAddParser()); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_concat_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_concat_parser.cc index 7145c545ec..0b87142e93 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_concat_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_concat_parser.cc @@ -42,11 +42,11 @@ STATUS TFConcatParser::Parse(const tensorflow::NodeDef &tf_op, return RET_NULL_PTR; } - if (tf_node_map.find(tf_op.input(tf_op.input_size() - 1)) == tf_node_map.end()) { - MS_LOG(ERROR) << "Find Concat input axis failed"; + auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(tf_op.input_size() - 1)); + if (axis_node == nullptr) { + MS_LOG(ERROR) << "get concat axis attr node failed"; return RET_ERROR; } - auto axis_node = tf_node_map.at(tf_op.input(tf_op.input_size() - 1)); tensorflow::AttrValue attr_value; if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) { MS_LOG(ERROR) << "The value attr should be specified"; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc index a790cb5e93..426f004c7b 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc @@ -66,11 +66,11 @@ STATUS TFConvParser::Parse(const tensorflow::NodeDef &tf_op, attr->strideH = strides[0]; attr->strideW = strides[1]; - if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { + auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1)); + if (weight_node == nullptr) { MS_LOG(ERROR) << "Find Conv2D input weights failed"; return RET_ERROR; } - auto weight_node = tf_node_map.at(tf_op.input(1)); std::vector kernels(4); status = ParseKernels(*weight_node, attr->format, &kernels); if (status != RET_OK) { diff --git 
a/mindspore/lite/tools/converter/parser/tf/tf_expand_dims_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_expand_dims_parser.cc index ba8288c424..e547a81ac7 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_expand_dims_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_expand_dims_parser.cc @@ -42,11 +42,11 @@ STATUS TFExpandDimsParser::Parse(const tensorflow::NodeDef &tf_op, return RET_NULL_PTR; } - if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { + auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(1)); + if (axis_node == nullptr) { MS_LOG(ERROR) << "Find ExpandDims input axis failed"; return RET_ERROR; } - auto axis_node = tf_node_map.at(tf_op.input(1)); tensorflow::AttrValue attr_value; if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) { MS_LOG(ERROR) << "The value attr should be specified"; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_gather_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_gather_parser.cc index 20c1c670b7..597145f468 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_gather_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_gather_parser.cc @@ -50,11 +50,11 @@ STATUS TFGatherParser::Parse(const tensorflow::NodeDef &tf_op, bool axis_is_set = false; if (tf_op.input_size() == 3) { axis_is_set = true; - if (tf_node_map.find(tf_op.input(2)) == tf_node_map.end()) { + auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(2)); + if (axis_node == nullptr) { MS_LOG(ERROR) << "Find Gather input axis failed"; return RET_ERROR; } - auto axis_node = tf_node_map.at(tf_op.input(2)); if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) { MS_LOG(ERROR) << "The value attr should be specified"; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 2184e4d2cc..0bde8184b5 100644 --- 
a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -17,7 +17,6 @@ #include "tools/converter/parser/tf/tf_model_parser.h" #include -#include #include #include "src/common/log_adapter.h" #include "src/common/utils.h" @@ -25,31 +24,11 @@ #include "tools/common/graph_util.h" #include "tools/common/protobuf_utils.h" #include "tools/converter/parser/tf/tf_node_parser_registry.h" -#include "tools/optimizer/common/gllo_utils.h" namespace mindspore { namespace lite { namespace { -static const std::vector tensorListOutputOpList = { - schema::PrimitiveType_TensorListFromTensor, - schema::PrimitiveType_TensorListSetItem, - schema::PrimitiveType_TensorListReserve, -}; -// subgraph node input may be a:output:0/a:z:0 -std::string GetFlattenNodeName(std::string input_name) { - std::regex re("\\:+"); - std::vector input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1), - std::sregex_token_iterator()); - if (input_splits.size() == 3) { - if (input_splits[2] == "0") { - input_name = input_splits[0]; - } else { - input_name = input_splits[0] + input_splits[2]; // multi output node - } - } - return input_name; -} AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map &anf_node_map) { AnfNodePtr ret = nullptr; if (anf_node_map.find(name) != anf_node_map.end()) { @@ -67,10 +46,11 @@ std::string GetOriginInputName(const tensorflow::NodeDef &node, } auto tmp_node = &node; while (tmp_node->op() == "Identity" || tmp_node->op() == "StopGradient") { - if (tf_graph_nodes.find(tmp_node->input(0)) == tf_graph_nodes.end()) { - return tmp_node->input(0); + auto flatten_input_name = TensorFlowUtils::GetFlattenNodeName(tmp_node->input(0)); + if (tf_graph_nodes.find(flatten_input_name) == tf_graph_nodes.end()) { + return flatten_input_name; } - tmp_node = tf_graph_nodes.at(tmp_node->input(0)); + tmp_node = tf_graph_nodes.at(flatten_input_name); } return tmp_node->name(); } @@ 
-89,6 +69,10 @@ STATUS TFModelParser::ConvertConstVariant(const tensorflow::TensorProto &tensor_ MS_LOG(ERROR) << "Only TensorList type is supported now"; return RET_NOT_SUPPORT; } + if (variant.tensors_size() > 0) { + MS_LOG(ERROR) << "Only empty tensorlist is supported now"; + return RET_NOT_SUPPORT; + } auto descriptor = variant.GetMetadata().descriptor; auto reflection = variant.GetMetadata().reflection; if (descriptor == nullptr || reflection == nullptr) { @@ -232,6 +216,7 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value param_value->set_tensor_type(type); param_value->set_format(schema::Format::Format_NHWC); parameter->set_default_param(param_value); + parameter->set_name("const_" + std::to_string(anf_root_node_map.size()) + "_parameter"); return RET_OK; } @@ -263,7 +248,8 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa return status; } } else { - graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names + parameter->set_name("placeholder_" + std::to_string(anf_root_node_map.size())); + graph_input_names.emplace_back(parameter->name()); // only root graph need set graph input names } auto abstract_tensor = std::make_shared(type_ptr, shape_vector); @@ -271,12 +257,9 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa MS_LOG(ERROR) << "abstract_tensor is nullptr"; return RET_ERROR; } - parameter->set_name(node.name()); parameter->set_abstract(abstract_tensor); - (*anf_node_map)[node.name()] = parameter; (*anf_node_map)[node.name() + ":0"] = parameter; - return RET_OK; } @@ -311,48 +294,43 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } - tf_root_graph_ = std::make_unique(); - if (tf_root_graph_ == nullptr) { - MS_LOG(ERROR) << "tf_root_graph_ is nullptr"; + tf_root_graph = std::make_unique(); + if (tf_root_graph 
== nullptr) { + MS_LOG(ERROR) << "tf_root_graph is nullptr"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); return nullptr; } - status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get()); + status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph.get()); if (status != RET_OK) { MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); return nullptr; } - anf_root_graph_ = std::make_shared(); - if (anf_root_graph_ == nullptr) { + anf_root_graph = std::make_shared(); + if (anf_root_graph == nullptr) { MS_LOG(ERROR) << "funGraphPtr is nullptr"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); return nullptr; } - for (int i = 0; i < tf_root_graph_->node_size(); i++) { - auto &node_def = tf_root_graph_->node(i); - tf_root_graph_nodes_[node_def.name()] = &node_def; + for (int i = 0; i < tf_root_graph->node_size(); i++) { + auto &node_def = tf_root_graph->node(i); + tf_root_graph_nodes[node_def.name()] = &node_def; } - status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); + status = ConvertGraphInputsAndConsts(tf_root_graph_nodes, anf_root_graph, &anf_root_node_map); if (status != RET_OK) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } - bool success_flag = true; - for (int i = 0; i < tf_root_graph_->node_size(); i++) { - auto &node_def = tf_root_graph_->node(i); - status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); - if (status != RET_OK) { - success_flag = false; + for (int i = 0; i < tf_root_graph->node_size(); i++) { + auto &node_def = tf_root_graph->node(i); + if (ConvertOps(node_def, tf_root_graph_nodes, anf_root_graph, &anf_root_node_map) != RET_OK) { + MS_LOG(ERROR) << "Convert ops failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; } } - if (!success_flag) { - 
MS_LOG(ERROR) << "Convert ops failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; - } status = ConvertRootGraphOutputs(); if (status != RET_OK) { MS_LOG(ERROR) << "Convert graph outputs failed."; @@ -367,25 +345,25 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin return nullptr; } - return anf_root_graph_; + return anf_root_graph; } STATUS TFModelParser::ConvertSubgraph() { - auto graph_def_liarary = tf_root_graph_->library(); + auto graph_def_liarary = tf_root_graph->library(); auto subgraph_size = graph_def_liarary.function_size(); std::map while_cond_map; std::map while_body_map; - std::vector sub_graph_inputs; for (int i = 0; i < subgraph_size; i++) { + std::vector sub_graph_inputs; auto &tf_sub_fuction = graph_def_liarary.function(i); auto &tf_sub_signature = tf_sub_fuction.signature(); auto input_arg_size = tf_sub_signature.input_arg_size(); auto &sub_graph_name = tf_sub_signature.name(); - if (!function_while_map_.count(sub_graph_name)) { + if (!function_while_map.count(sub_graph_name)) { MS_LOG(ERROR) << "function map not contains sub graph name." 
<< sub_graph_name; return RET_ERROR; } - auto while_cnode = function_while_map_[sub_graph_name]->cast(); + auto while_cnode = function_while_map[sub_graph_name]->cast(); if (while_cnode == nullptr || static_cast(while_cnode->inputs().size()) != input_arg_size + 1) { MS_LOG(ERROR) << "while cnode not equal input arg size"; return RET_ERROR; @@ -426,9 +404,16 @@ STATUS TFModelParser::ConvertSubgraph() { // convert subgraph outputs std::vector sub_output_nodes; auto &subgraph_ret = tf_sub_fuction.ret(); - for (auto &t : subgraph_ret) { - MS_LOG(INFO) << "subret " << t.first << " " << t.second; - auto tf_output_name = GetFlattenNodeName(t.second); + auto &output_args = tf_sub_signature.output_arg(); + for (auto &output_arg : output_args) { + auto &signature_name = output_arg.name(); + if (subgraph_ret.find(signature_name) == subgraph_ret.end()) { + MS_LOG(ERROR) << "can't find signature_name: " << signature_name; + return RET_ERROR; + } + auto t = subgraph_ret.find(signature_name); + MS_LOG(INFO) << "subret " << t->first << " " << t->second; + auto tf_output_name = TensorFlowUtils::GetFlattenNodeName(t->second); AnfNodePtr anf_node = nullptr; if (tf_sub_node_map.find(tf_output_name) == tf_sub_node_map.end()) { anf_node = GetAnfNode(tf_output_name, anf_sub_node_map); @@ -456,7 +441,7 @@ STATUS TFModelParser::ConvertSubgraph() { } // hardcode subgraph inputs name for (size_t j = 0; j < sub_graph_inputs.size(); j++) { - sub_graph_inputs[j]->set_name("graph" + std::to_string(i) + "_input_" + std::to_string(j) + "parameter"); + sub_graph_inputs[j]->set_name("graph_input_" + std::to_string(j) + "parameter"); } MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name; } @@ -473,9 +458,9 @@ STATUS TFModelParser::WhileNodePostProcess(const std::map roots = {anf_root_graph_}; + std::vector roots = {anf_root_graph}; auto root_func_manager = std::make_shared(roots); - anf_root_graph_->set_manager(root_func_manager); + anf_root_graph->set_manager(root_func_manager); for (auto &kv : 
while_cond_map) { auto while_node = kv.first; auto &cond_sub_graph = kv.second; @@ -484,12 +469,9 @@ STATUS TFModelParser::WhileNodePostProcess(const std::mapset_manager(root_func_manager); auto cond_value_node = NewValueNode(cond_sub_graph); auto body_value_node = NewValueNode(body_sub_graph); - auto new_while_inputs = while_node->cast()->inputs(); - new_while_inputs[0] = cond_value_node; - new_while_inputs.insert(new_while_inputs.begin() + 1, body_value_node); - auto new_while_node = anf_root_graph_->NewCNode(new_while_inputs); - new_while_node->set_abstract(while_node->abstract()); - root_func_manager->Replace(while_node, new_while_node); + auto inputs = while_node->inputs(); + inputs.insert(inputs.begin() + 1, {cond_value_node, body_value_node}); + while_node->set_inputs(inputs); } return RET_OK; } @@ -510,7 +492,7 @@ STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def, for (size_t j = 0; j < input_names.size(); j++) { std::string input_name = input_names[j]; // input may be produced by multi-outputs node // subgraph input name x:output:index,need flatten - auto flatten_input_name = GetFlattenNodeName(input_name); + auto flatten_input_name = TensorFlowUtils::GetFlattenNodeName(input_name); if (tf_node_map.find(flatten_input_name) != tf_node_map.end()) { auto input_node = tf_node_map.at(flatten_input_name); flatten_input_name = GetOriginInputName(*input_node, tf_node_map); @@ -531,20 +513,10 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C MS_ASSERT(op != nullptr); MS_ASSERT(anf_node != nullptr); MS_ASSERT(anf_graph != nullptr); - if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node)) && output_size != 1) { - MS_LOG(ERROR) << "tensorlist output op output_size !=1"; - return RET_ERROR; - } - if (output_size == 0) { - return RET_OK; - } else if (output_size == 1) { - auto type = kFloat32; + if (output_size == 1) { std::vector shape_vector; - if (IsContain(tensorListOutputOpList, 
opt::GetCNodeType(anf_node))) { - type = TypeIdToType(kObjectTypeTensorType); - } - anf_node->set_abstract(std::make_shared(type, shape_vector)); - anf_node_map->insert(std::pair(op.name(), anf_node)); + anf_node->set_abstract(std::make_shared(kFloat32, shape_vector)); + anf_node_map->insert(std::pair(op.name() + ":0", anf_node)); } else { AbstractBasePtrList abstractList; for (int output_idx = 0; output_idx < output_size; output_idx++) { @@ -608,17 +580,17 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, // control_depends are not processed currently auto anf_node = func_graph_ptr->NewCNode(inputs); anf_node->set_fullname_with_scope(node_def.name()); - if (op_type == "StatelessWhile" || op_type == "while") { + if (op_type == "StatelessWhile" || op_type == "While") { MS_LOG(INFO) << "find while node:" << node_def.name(); tensorflow::AttrValue attr_value; if (TensorFlowUtils::FindAttrValue(node_def, "body", &attr_value)) { auto body_name = attr_value.func().name(); - function_while_map_[body_name] = anf_node; + function_while_map[body_name] = anf_node; MS_LOG(DEBUG) << "parse body name:" << body_name; } if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) { auto cond_name = attr_value.func().name(); - function_while_map_[cond_name] = anf_node; + function_while_map[cond_name] = anf_node; MS_LOG(DEBUG) << "parse cond name:" << cond_name; } } @@ -634,28 +606,31 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, STATUS TFModelParser::ConvertRootGraphOutputs() { // because output of intermediate node in anf graph may also be output tensors, we search output tensors in - // tf_root_graph_nodes_ but not anf_root_node_map_ + // tf_root_graph_nodes but not anf_root_node_map std::set all_node_inputs; std::vector output_nodes; - for (auto &pair : tf_root_graph_nodes_) { + for (auto &pair : tf_root_graph_nodes) { for (int i = 0; i < pair.second->input_size(); ++i) { - all_node_inputs.insert(pair.second->input(i)); + 
all_node_inputs.insert(TensorFlowUtils::GetNodeName(pair.second->input(i))); } } - for (auto &pair : tf_root_graph_nodes_) { + for (auto &pair : tf_root_graph_nodes) { + if (pair.second->op() == "Assert") { + continue; + } auto it = all_node_inputs.find(pair.first); if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity - auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_); - auto anf_node = GetAnfNode(origin_name, anf_root_node_map_); + auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes); + auto anf_node = GetAnfNode(origin_name, anf_root_node_map); if (anf_node == nullptr) { MS_LOG(ERROR) << "can't find anf node"; return RET_ERROR; } output_nodes.push_back(anf_node); - graph_output_names_.push_back(anf_node->fullname_with_scope()); + graph_output_names.push_back(anf_node->fullname_with_scope()); } } - auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_); + auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph); if (status != RET_OK) { MS_LOG(ERROR) << "make anf graph outputs node error"; return status; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h index d112967bd5..a4dbd4230d 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -71,13 +71,13 @@ class TFModelParser : public ModelParser { STATUS MakeAnfGraphOutputs(std::vector *output_nodes, const FuncGraphPtr &anf_graph); - FuncGraphPtr anf_root_graph_; - std::unique_ptr tf_root_graph_; // tf root graph def - std::map tf_root_graph_nodes_; // tf root graph node map - std::unordered_map anf_root_node_map_; - std::vector graph_input_names_; - std::vector graph_output_names_; - std::map function_while_map_; // tf function name->while_node_name + FuncGraphPtr anf_root_graph; + std::unique_ptr tf_root_graph; // tf root graph 
def + std::map tf_root_graph_nodes; // tf root graph node map + std::unordered_map anf_root_node_map; + std::vector graph_input_names; + std::vector graph_output_names; + std::map function_while_map; // tf function name->while_node_name }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc index 7af394d659..54ab046b9b 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc @@ -17,6 +17,8 @@ #include #include +using tensorflow::NodeDef; + namespace mindspore { namespace lite { STATUS TFNodeParser::AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, std::vector *inputs) { @@ -27,5 +29,19 @@ STATUS TFNodeParser::AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, inputs->push_back(tf_op.input(idx)); return RET_OK; } + +const NodeDef *TFNodeParser::GetConstInputNode(const std::map &tf_node_map, + const std::string &input_name) { + auto flatten_input_name = TensorFlowUtils::GetFlattenNodeName(input_name); + if (tf_node_map.find(flatten_input_name) == tf_node_map.end()) { + return nullptr; + } + auto node = tf_node_map.at(flatten_input_name); + if (node->op() != "Const") { + MS_LOG(ERROR) << "Attr node is not Const"; + return nullptr; + } + return node; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_node_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_node_parser.h index d6d3abf99b..2b36a83eef 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_node_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_node_parser.h @@ -40,6 +40,9 @@ class TFNodeParser { } STATUS AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, std::vector *inputs); + + const tensorflow::NodeDef *GetConstInputNode(const std::map &tf_node_map, + const std::string &input_name); }; } // namespace lite } 
 // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_reduce_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_reduce_parser.cc index 65f2390ff5..1776868565 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_reduce_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_reduce_parser.cc @@ -69,11 +69,11 @@ STATUS TFReduceParser::Parse(const tensorflow::NodeDef &tf_op, } attr->keepDims = attr_value.b(); - if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { + auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(1)); + if (axis_node == nullptr) { MS_LOG(ERROR) << "Find Reduce input axis failed"; return RET_ERROR; } - auto axis_node = tf_node_map.at(tf_op.input(1)); if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) { MS_LOG(ERROR) << "The value attr should be specified"; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc new file mode 100644 index 0000000000..fa7c4db27a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "tools/converter/parser/tf/tf_reverse_sequence_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFReverseSequenceParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF ReverseSequenceParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "New PrimitiveT failed"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(tf_op, "batch_dim", &attr_value)) { + MS_LOG(ERROR) << "The batch_dim attr should be specified"; + return RET_ERROR; + } + attr->batchAxis = attr_value.i(); + if (!TensorFlowUtils::FindAttrValue(tf_op, "seq_dim", &attr_value)) { + MS_LOG(ERROR) << "The seq_dim attr should be specified"; + return RET_ERROR; + } + attr->seqAxis = attr_value.i(); + + primitive->value.type = schema::PrimitiveType_ReverseSequence; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + return AddOpInput(tf_op, 0, inputs); +} +TFNodeRegistrar g_tfReverseSequenceParser("ReverseSequence", new TFReverseSequenceParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.h new file mode 100644 index 0000000000..e7b6e13742 --- /dev/null +++ 
b/mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REVERSE_SEQUENCE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REVERSE_SEQUENCE_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFReverseSequenceParser : public TFNodeParser { + public: + TFReverseSequenceParser() = default; + ~TFReverseSequenceParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REVERSE_SEQUENCE_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_split_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_split_parser.cc index 7f3912b87f..5dd8733d96 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_split_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_split_parser.cc @@ -58,11 +58,11 @@ STATUS TFSplitParser::Parse(const tensorflow::NodeDef &tf_op, input_index = 0; } - if (tf_node_map.find(tf_op.input(split_dim_index)) == tf_node_map.end()) { + auto split_dim_node = GetConstInputNode(tf_node_map, 
tf_op.input(split_dim_index)); + if (split_dim_node == nullptr) { MS_LOG(ERROR) << "Find Split input split_dim node failed"; return RET_ERROR; } - const auto &split_dim_node = tf_node_map.at(tf_op.input(split_dim_index)); if (!TensorFlowUtils::FindAttrValue(*split_dim_node, "value", &attr_value)) { MS_LOG(ERROR) << "The attribute splitDim should be specified"; return RET_PARAM_INVALID; @@ -72,11 +72,11 @@ STATUS TFSplitParser::Parse(const tensorflow::NodeDef &tf_op, *output_size = attr->numberSplit; if (tf_op.op() == "SplitV") { - if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { + auto size_splits_node = GetConstInputNode(tf_node_map, tf_op.input(1)); + if (size_splits_node == nullptr) { MS_LOG(ERROR) << "Find Split input size_splits failed"; return RET_ERROR; } - auto size_splits_node = tf_node_map.at(tf_op.input(1)); if (!TensorFlowUtils::FindAttrValue(*size_splits_node, "value", &attr_value)) { MS_LOG(ERROR) << "The attribute size splits should be specified"; return RET_PARAM_INVALID; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc index b7255da91c..6d6f31f998 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc @@ -74,11 +74,11 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op, attr->shrinkAxisMask = attr_value.i(); // begin - if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { + auto begin_node = GetConstInputNode(tf_node_map, tf_op.input(1)); + if (begin_node == nullptr) { MS_LOG(ERROR) << "Find StridedSlice input begin failed"; return RET_ERROR; } - auto begin_node = tf_node_map.at(tf_op.input(1)); if (!TensorFlowUtils::FindAttrValue(*begin_node, "value", &attr_value)) { MS_LOG(ERROR) << "The value attr should be specified"; return RET_ERROR; @@ -97,11 +97,11 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef 
&tf_op, } // end - if (tf_node_map.find(tf_op.input(2)) == tf_node_map.end()) { + auto end_node = GetConstInputNode(tf_node_map, tf_op.input(2)); + if (end_node == nullptr) { MS_LOG(ERROR) << "Find StridedSlice input end failed"; return RET_ERROR; } - auto end_node = tf_node_map.at(tf_op.input(2)); if (!TensorFlowUtils::FindAttrValue(*end_node, "value", &attr_value)) { MS_LOG(ERROR) << "The value attr should be specified"; return RET_ERROR; @@ -120,11 +120,11 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op, } // strides - if (tf_node_map.find(tf_op.input(3)) == tf_node_map.end()) { + auto stride_node = GetConstInputNode(tf_node_map, tf_op.input(3)); + if (stride_node == nullptr) { MS_LOG(ERROR) << "Find StridedSlice input strides failed"; return RET_ERROR; } - auto stride_node = tf_node_map.at(tf_op.input(3)); if (!TensorFlowUtils::FindAttrValue(*stride_node, "value", &attr_value)) { MS_LOG(ERROR) << "The value attr should be specified"; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_tile_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_tile_parser.cc index 42e6c131c5..e05bc213b3 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_tile_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_tile_parser.cc @@ -42,11 +42,11 @@ STATUS TFTileParser::Parse(const tensorflow::NodeDef &tf_op, return RET_NULL_PTR; } - if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { + auto multiplies_node = GetConstInputNode(tf_node_map, tf_op.input(1)); + if (multiplies_node == nullptr) { MS_LOG(ERROR) << "Find Tile input multiplies failed"; return RET_ERROR; } - auto multiplies_node = tf_node_map.at(tf_op.input(1)); tensorflow::AttrValue attr_value; if (!TensorFlowUtils::FindAttrValue(*multiplies_node, "value", &attr_value)) { MS_LOG(ERROR) << "The value attr should be specified"; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.cc 
b/mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.cc index 5d8e7a3ec1..b8d3f52d44 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.cc @@ -42,11 +42,12 @@ STATUS TFTransposeParser::Parse(const tensorflow::NodeDef &tf_op, return RET_NULL_PTR; } - if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { + attr->conjugate = false; + auto perm_node = GetConstInputNode(tf_node_map, tf_op.input(1)); + if (perm_node == nullptr) { MS_LOG(ERROR) << "Find Transpose input perm failed"; return RET_ERROR; } - auto perm_node = tf_node_map.at(tf_op.input(1)); tensorflow::AttrValue attr_value; if (!TensorFlowUtils::FindAttrValue(*perm_node, "value", &attr_value)) { MS_LOG(ERROR) << "The value attr should be specified"; diff --git a/mindspore/lite/tools/converter/parser/tf/tf_util.cc b/mindspore/lite/tools/converter/parser/tf/tf_util.cc index a13828787d..791c2468af 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_util.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_util.cc @@ -16,8 +16,10 @@ #include "tools/converter/parser/tf/tf_util.h" #include +#include #include #include +#include #include "src/common/log_adapter.h" #include "schema/inner/model_generated.h" @@ -112,5 +114,32 @@ bool TensorFlowUtils::DecodeInt64(std::string_view *str_view, uint64_t *value) { return true; } } + +// convert input_arg in subgraph to node_name[:index] format +std::string TensorFlowUtils::GetFlattenNodeName(const std::string &input_name) { + std::regex re("\\:+"); + std::vector input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1), + std::sregex_token_iterator()); + std::string ret = input_name; + if (input_splits.size() == 3) { + if (input_splits[2] == "0") { + ret = input_splits[0]; + } else { + ret = input_splits[0] + input_splits[2]; // multi output node + } + } + return ret; +} + +// get referenced node name from input name +std::string 
TensorFlowUtils::GetNodeName(const std::string &input_name) { + std::regex re("\\:+"); + std::vector input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1), + std::sregex_token_iterator()); + if (input_splits.size() > 1) { + return input_splits[0]; + } + return input_name; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_util.h b/mindspore/lite/tools/converter/parser/tf/tf_util.h index 9ce7eaed3a..d93cdebacb 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_util.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_util.h @@ -34,6 +34,8 @@ class TensorFlowUtils { static TypeId ParseAttrDataType(const tensorflow::NodeDef &node_def, const std::string &attr_name); static schema::Format ParseNodeFormat(const tensorflow::NodeDef &node_def); static bool DecodeInt64(std::string_view *str_view, uint64_t *value); + static std::string GetFlattenNodeName(const std::string &input_name); + static std::string GetNodeName(const std::string &input_name); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.cc b/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.cc index e23b43aab4..d50523715d 100644 --- a/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.cc @@ -101,6 +101,7 @@ bool ClipConvertActivationPass::Run(const FuncGraphPtr &graph) { op_inputs.push_back(clip_cnode->input(1)); auto new_cnode = graph->NewCNode(op_inputs); new_cnode->set_fullname_with_scope(node->fullname_with_scope()); + new_cnode->set_abstract(clip_cnode->abstract()->Clone()); manager->Replace(node, new_cnode); } return false; diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc index dfaa3f2770..ea14116151 100644 --- 
a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc @@ -121,7 +121,7 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector(cnode->input(i))) { - MS_LOG(ERROR) << "input is value node"; + MS_LOG(WARNING) << "input is value node"; continue; } @@ -178,13 +178,10 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector(param_value->tensor_addr()); - auto tensor_list = dynamic_cast(tensor.get()); - tensor_list->set_tensors_data_type(TypeId(data[0])); - std::vector shape; - for (int j = 0; j < data[1]; ++j) { - shape.push_back(data[2 + j]); + auto tensor_list = reinterpret_cast(tensor.get()); + if (tensor_list->Decode(data) != RET_OK) { + return RET_ERROR; } - tensor_list->set_element_shape(shape); } } } @@ -210,8 +207,8 @@ STATUS InferShapePass::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector< return RET_ERROR; } auto abstract_tensor = utils::cast(element); - auto typePtr = abstract_tensor->element()->GetTypeTrack(); - types.push_back(typePtr->type_id()); + auto type_ptr = abstract_tensor->element()->GetTypeTrack(); + types.push_back(type_ptr->type_id()); } } else { if (!utils::isa(abstract)) { @@ -219,8 +216,8 @@ STATUS InferShapePass::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector< return RET_ERROR; } auto abstract_tensor = utils::cast(abstract); - auto typePtr = abstract_tensor->element()->GetTypeTrack(); - types.push_back(typePtr->type_id()); + auto type_ptr = abstract_tensor->element()->GetTypeTrack(); + types.push_back(type_ptr->type_id()); } for (auto &type : types) { std::unique_ptr output_tensor = nullptr;