From: @wangzhe128 (tags/v1.1.0)
| @@ -89,27 +89,38 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde | |||
| auto data_type = src_tensor->dataType(); | |||
| if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) && | |||
| src_tensor->data() != nullptr && src_tensor->data()->size() > 0) { | |||
| MS_ASSERT(dst_tensor->Size() == src_tensor->data()->size()); | |||
| if (WeightTensorNeedCopy(model, tensor_index)) { | |||
| auto dst_data = dst_tensor->MutableData(); | |||
| if (dst_data == nullptr) { | |||
| MS_LOG(ERROR) << "Data from tensor is nullptr"; | |||
| return RET_NULL_PTR; | |||
| if (src_tensor->dataType() == kObjectTypeTensorType) { | |||
| auto tensor_list = reinterpret_cast<TensorList *>(dst_tensor); | |||
| if (src_tensor->data() == nullptr) { | |||
| MS_LOG(ERROR) << "src_tensor->data() is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| if (tensor_list->Decode(reinterpret_cast<const int *>(src_tensor->data()->data())) != RET_OK) { | |||
| return RET_ERROR; | |||
| } | |||
| memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size()); | |||
| copyed_tensor_idxes_.emplace_back(tensor_index); | |||
| } else { | |||
| int pack_size = src_tensor->data()->size(); | |||
| int org_size = dst_tensor->Size(); | |||
| if (pack_size != org_size && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16)) { | |||
| auto ret = dst_tensor->MallocData(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Malloc data for tensor failed "; | |||
| return RET_ERROR; | |||
| MS_ASSERT(dst_tensor->Size() == src_tensor->data()->size()); | |||
| if (WeightTensorNeedCopy(model, tensor_index)) { | |||
| auto dst_data = dst_tensor->MutableData(); | |||
| if (dst_data == nullptr) { | |||
| MS_LOG(ERROR) << "Data from tensor is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| kernel::DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData()); | |||
| memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size()); | |||
| copyed_tensor_idxes_.emplace_back(tensor_index); | |||
| } else { | |||
| dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data())); | |||
| int pack_size = src_tensor->data()->size(); | |||
| int org_size = dst_tensor->Size(); | |||
| if (pack_size != org_size && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16)) { | |||
| auto ret = dst_tensor->MallocData(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Malloc data for tensor failed "; | |||
| return RET_ERROR; | |||
| } | |||
| kernel::DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData()); | |||
| } else { | |||
| dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data())); | |||
| } | |||
| } | |||
| } | |||
| } | |||
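A side note on the weight-handling branch above: a const weight can alias the model buffer only when the stored byte count matches the runtime tensor's size; packed int8/int16 weights are smaller than their unpacked form, so they get a fresh allocation and `DequantUtil::UnPackToInt`. Below is a minimal standalone sketch of that decision (plain C++ with a hypothetical helper name, not the MindSpore API):

```cpp
#include <cstddef>

// Hypothetical helper mirroring the else-branch above: share the model buffer unless the
// stored (packed) size differs from the tensor's unpacked size for an int8/int16 weight,
// in which case the runtime must allocate its own buffer and unpack into it.
bool CanShareModelBuffer(size_t packed_bytes, size_t tensor_bytes, bool is_int8_or_int16) {
  bool needs_unpack = (packed_bytes != tensor_bytes) && is_int8_or_int16;
  return !needs_unpack;
}
```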
| @@ -97,6 +97,21 @@ int TensorListSetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect | |||
| MS_ASSERT(input0 != nullptr); | |||
| auto get_index = inputs_[1]; | |||
| MS_ASSERT(get_index != nullptr); | |||
| auto value_tensor = inputs_[2]; | |||
| MS_ASSERT(value_tensor != nullptr); | |||
| auto output0 = reinterpret_cast<TensorList *>(outputs_[0]); | |||
| MS_ASSERT(output0 != nullptr); | |||
| output0->set_data_type(input0->data_type()); | |||
| output0->set_format(input0->format()); | |||
| if (!infer_flag()) { | |||
| return RET_INFER_INVALID; | |||
| } | |||
| if (get_index->data_c() == nullptr || value_tensor->data_c() == nullptr) { | |||
| return RET_INFER_INVALID; | |||
| } | |||
| if (get_index->data_type() != kNumberTypeInt && get_index->data_type() != kNumberTypeInt32) { | |||
| MS_LOG(ERROR) << "inputs_[1]->data_type():" << get_index->data_type() << " is not int"; | |||
| return RET_ERROR; | |||
| @@ -110,31 +125,34 @@ int TensorListSetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect | |||
| return RET_NULL_PTR; | |||
| } | |||
| int index = reinterpret_cast<int *>(get_index->data_c())[0]; | |||
| if (index < 0 || index > (input0->ElementsNum() - 1)) { | |||
| MS_LOG(ERROR) << "index_:" << index << "must in [0, " << input0->ElementsNum() - 1 << "]"; | |||
| if (index < 0 || (index >= static_cast<int>(input0->tensors().size()) && index != 0)) { | |||
| MS_LOG(ERROR) << "index_:" << index << "must in [0, " << input0->tensors().size() << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| auto value_tensor = inputs_[2]; | |||
| MS_ASSERT(value_tensor != nullptr); | |||
| auto output0 = reinterpret_cast<TensorList *>(outputs_[0]); | |||
| MS_ASSERT(output0 != nullptr); | |||
| output0->set_element_shape(input0->element_shape()); | |||
| output0->set_max_elements_num(input0->max_elements_num()); | |||
| output0->set_shape(input0->shape()); | |||
| output0->set_data_type(input0->data_type()); | |||
| output0->set_element_shape(input0->element_shape()); | |||
| std::vector<std::vector<int> > out_shape; | |||
| for (int i = 0; i < input0->ElementsNum(); ++i) { | |||
| auto src_ptr = input0->GetTensorIndex(i); | |||
| if (src_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "input0->tensors_[" << i << "] is nullptr!"; | |||
| return RET_ERROR; | |||
| } | |||
| if (src_ptr->data_type() != kTypeUnknown) { | |||
| out_shape.push_back(src_ptr->shape()); | |||
| } else { | |||
| out_shape.push_back(std::vector<int>()); | |||
| if (index == 0 && input0->tensors().size() == 0) { // uninitialized tensorlist | |||
| out_shape.push_back(value_tensor->shape()); | |||
| output0->set_shape(std::vector<int>{1}); | |||
| } else { | |||
| output0->set_shape(input0->shape()); | |||
| for (int i = 0; i < input0->ElementsNum(); ++i) { | |||
| auto src_ptr = input0->GetTensorIndex(i); | |||
| if (src_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "input0->tensors_[" << i << "] is nullptr!"; | |||
| return RET_ERROR; | |||
| } | |||
| if (src_ptr->data_type() != kTypeUnknown) { | |||
| out_shape.push_back(src_ptr->shape()); | |||
| } else { | |||
| out_shape.push_back(std::vector<int>()); | |||
| } | |||
| } | |||
| } | |||
| out_shape[index] = value_tensor->shape(); | |||
| output0->MallocTensorListData(input0->tensors_data_type(), out_shape); | |||
| return RET_OK; | |||
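To summarize the element-shape handling introduced above: when the destination tensorlist is still uninitialized (index 0 into an empty list), the output gets a single slot that takes the value's shape; otherwise the existing element shapes are carried over and only the indexed slot is overwritten. A standalone sketch under those assumptions (plain C++, not the lite::TensorList API):

```cpp
#include <vector>

// Standalone sketch of the out_shape construction in TensorListSetItem::InferShape above.
// input_elements: shapes of the tensors already held by the input tensorlist
// value_shape:    shape of the tensor being written at `index`
std::vector<std::vector<int>> BuildOutputElementShapes(
    const std::vector<std::vector<int>> &input_elements, const std::vector<int> &value_shape, int index) {
  std::vector<std::vector<int>> out_shape;
  if (index == 0 && input_elements.empty()) {
    out_shape.push_back(value_shape);  // uninitialized tensorlist: output holds exactly one element
  } else {
    out_shape = input_elements;        // keep the shapes of untouched elements
  }
  out_shape[index] = value_shape;      // the written slot always takes the value's shape
  return out_shape;
}
```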
| @@ -28,8 +28,8 @@ using mindspore::schema::PrimitiveType_TensorListFromTensor; | |||
| namespace mindspore::kernel { | |||
| int TensorListFromTensorCPUKernel::IsCompatibleShape() { | |||
| if (input1_->data_type() != kNumberTypeInt) { // element_shape | |||
| MS_LOG(ERROR) << "in_tensors_[1] data type is must be \"kNumberTypeInt\", but now is:" << input1_->data_type(); | |||
| if (input1_->data_type() != kNumberTypeInt && input1_->data_type() != kNumberTypeInt32) { // element_shape | |||
| MS_LOG(ERROR) << "in_tensors_[1] data type is must be int"; | |||
| return RET_ERROR; | |||
| } | |||
| int in1_ele_num = input1_->ElementsNum(); | |||
| @@ -28,16 +28,17 @@ using mindspore::schema::PrimitiveType_TensorListSetItem; | |||
| namespace mindspore::kernel { | |||
| int TensorListSetItemCPUKernel::Init() { | |||
| int TensorListSetItemCPUKernel::Init() { return RET_OK; } | |||
| int TensorListSetItemCPUKernel::Run() { | |||
| input0_ = reinterpret_cast<lite::TensorList *>(in_tensors_[0]); | |||
| if (dtype_ != input0_->data_type()) { | |||
| MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type(); | |||
| return RET_ERROR; | |||
| } | |||
| int dim0 = input0_->ElementsNum() - 1; | |||
| if (in_tensors_[1]->data_type() != kNumberTypeInt) { | |||
| MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() | |||
| << " must be equal to \"kNumberTypeInt\":" << kNumberTypeInt; | |||
| if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) { | |||
| MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int"; | |||
| return RET_ERROR; | |||
| } | |||
| if (in_tensors_[1]->ElementsNum() != 1) { | |||
| @@ -54,10 +55,6 @@ int TensorListSetItemCPUKernel::Init() { | |||
| if (!input0_->IsCompatibleShape(input2_->shape())) { | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int TensorListSetItemCPUKernel::Run() { | |||
| output0_ = reinterpret_cast<lite::TensorList *>(out_tensors_[0]); | |||
| MS_ASSERT(output0_ != nullptr); | |||
| // copy each tensor in tensors_ | |||
| @@ -73,9 +73,8 @@ bool TensorListStackCPUKernel::IsFullyDefined(const std::vector<int> &shape) con | |||
| int TensorListStackCPUKernel::MergeElementShape() { | |||
| MS_ASSERT(in_tensors_[1]); | |||
| if (in_tensors_[1]->data_type() != kNumberTypeInt) { | |||
| MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() | |||
| << " must be \"kNumberTypeInt\":" << kNumberTypeInt; | |||
| if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) { | |||
| MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int"; | |||
| return RET_ERROR; | |||
| } | |||
| auto ele_shape_data = reinterpret_cast<int *>(in_tensors_[1]->data_c()); | |||
| @@ -19,6 +19,7 @@ | |||
| #include <queue> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "src/tensorlist.h" | |||
| #include "src/ops/partial.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/common/graph_util.h" | |||
| @@ -426,6 +427,10 @@ TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_ten | |||
| if (dtype == kObjectTypeString) { | |||
| return kNumberTypeFloat32; | |||
| } | |||
| if (dtype == kObjectTypeTensorType) { | |||
| auto tensor_list = reinterpret_cast<TensorList *>(tensor); | |||
| return tensor_list->tensors_data_type(); | |||
| } | |||
| if (dtype == kNumberTypeFloat32 || dtype == kNumberTypeFloat16 || dtype == kNumberTypeInt8 || | |||
| dtype == kNumberTypeInt32 || dtype == kNumberTypeBool) { | |||
| return dtype; | |||
| @@ -204,7 +204,7 @@ int TensorList::CheckTensorListParam() { | |||
| Tensor *TensorList::GetTensorIndex(int index) { | |||
| // return tensor[index] ptr. With this function, you can modify tensors_[index] at will. | |||
| if (index < 0 || index > (this->ElementsNum() - 1)) { | |||
| if (index < 0 || index >= static_cast<int>(tensors_.size())) { | |||
| MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!"; | |||
| return nullptr; | |||
| } | |||
| @@ -240,5 +240,17 @@ bool TensorList::IsCompatibleShape(const Tensor *src) { | |||
| } | |||
| return true; | |||
| } | |||
| STATUS TensorList::Decode(const int *data) { | |||
| if (data == nullptr) { | |||
| MS_LOG(ERROR) << "data is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| tensors_data_type_ = TypeId(data[0]); | |||
| for (int j = 0; j < data[1]; ++j) { | |||
| element_shape_.push_back(data[2 + j]); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
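For readers unfamiliar with the serialized form consumed by `TensorList::Decode` above: the int buffer appears to be laid out as `[tensors_data_type, element_shape_size, dim_0, ..., dim_{n-1}]`. A small standalone illustration of that layout (plain C++; the type id value below is made up):

```cpp
#include <cstdio>
#include <vector>

// Mirrors the layout read by TensorList::Decode:
//   data[0]     -> data type id of the element tensors
//   data[1]     -> number of dims in element_shape
//   data[2 + j] -> j-th dim of element_shape
struct TensorListHeader {
  int tensors_data_type;
  std::vector<int> element_shape;
};

TensorListHeader DecodeTensorListBuffer(const int *data) {
  TensorListHeader header{data[0], {}};
  for (int j = 0; j < data[1]; ++j) {
    header.element_shape.push_back(data[2 + j]);
  }
  return header;
}

int main() {
  const int buffer[] = {43, 2, 3, 4};  // hypothetical type id 43, element shape {3, 4}
  TensorListHeader header = DecodeTensorListBuffer(buffer);
  std::printf("type=%d dims=%zu\n", header.tensors_data_type, header.element_shape.size());
  return 0;
}
```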
| @@ -60,6 +60,8 @@ class TensorList : public Tensor { | |||
| public: | |||
| TensorList() = default; | |||
| TensorList(std::vector<int> shape, std::vector<int> element_shape); | |||
| ~TensorList() override; | |||
| // **Note**: This is a shallow copy, src and dst tensorlist share one memory space of each tensor in tensors_ | |||
| @@ -74,8 +76,6 @@ class TensorList : public Tensor { | |||
| // tensorlist deep copy memory | |||
| TensorList &operator=(const TensorList &tl); | |||
| TensorList(std::vector<int> shape, std::vector<int> element_shape); | |||
| void set_element_shape(const std::vector<int> &shape) { element_shape_ = shape; } | |||
| std::vector<int> &element_shape() { return element_shape_; } | |||
| @@ -112,6 +112,8 @@ class TensorList : public Tensor { | |||
| bool IsCompatibleShape(const Tensor *src); | |||
| STATUS Decode(const int *data); | |||
| protected: | |||
| // The following functions must be masked. | |||
| void set_data(void *data) override { return; } | |||
| @@ -572,7 +572,12 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s | |||
| if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) { | |||
| auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract()); | |||
| for (size_t i = 0; i < tuple->size(); i++) { | |||
| if (tuple == nullptr) { | |||
| MS_LOG(ERROR) << "tuple is nullptr"; | |||
| return; | |||
| } | |||
| auto elements = tuple->elements(); | |||
| for (size_t i = 0; i < elements.size(); i++) { | |||
| auto msTensor = new (std::nothrow) schema::TensorT(); | |||
| if (msTensor == nullptr) { | |||
| MS_LOG(ERROR) << "new msTensor failed"; | |||
| @@ -589,7 +594,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam)) | |||
| break; | |||
| #else | |||
| if (tuple->size() == 1) { | |||
| if (elements.size() == 1) { | |||
| node_id_map_[cnode_name] = meta_graphT->allTensors.size(); | |||
| msTensor->name = cnode_name; | |||
| } else { | |||
| @@ -597,6 +602,18 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s | |||
| node_id_map_[name] = meta_graphT->allTensors.size(); | |||
| msTensor->name = name; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) { | |||
| MS_LOG(ERROR) << "abstract is not AbstractTensor"; | |||
| return; | |||
| } | |||
| auto type = kNumberTypeFloat32; | |||
| if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) { | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]); | |||
| auto typePtr = abstract_tensor->element()->GetTypeTrack(); | |||
| type = typePtr->type_id(); | |||
| } | |||
| msTensor->dataType = type; | |||
| meta_graphT->allTensors.emplace_back(msTensor); | |||
| if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) || | |||
| IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) || | |||
| @@ -611,8 +628,14 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s | |||
| MS_LOG(ERROR) << "new tensor failed"; | |||
| return; | |||
| } | |||
| auto type = kNumberTypeFloat32; | |||
| if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) { | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract()); | |||
| auto typePtr = abstract_tensor->element()->GetTypeTrack(); | |||
| type = typePtr->type_id(); | |||
| } | |||
| ms_tensor->dataType = type; | |||
| ms_tensor->nodeType = schema::NodeType_CNode; | |||
| ms_tensor->dataType = TypeId::kNumberTypeFloat32; | |||
| ms_tensor->name = cnode_name; | |||
| fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); | |||
| node_id_map_[cnode_name] = meta_graphT->allTensors.size(); | |||
| @@ -19,6 +19,7 @@ | |||
| #include "src/common/log_adapter.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/tensor.h" | |||
| #include "src/tensorlist.h" | |||
| #include "src/ops/primitive_c.h" | |||
| using mindspore::lite::PrimitiveC; | |||
| @@ -50,32 +51,58 @@ std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::ve | |||
| std::vector<Tensor *> lite_tensors; | |||
| bool convert_succ = true; | |||
| for (size_t i = 0; i < tensor_indexs.size(); i++) { | |||
| std::unique_ptr<Tensor> lite_tensor = nullptr; | |||
| auto &tensorT = graph->allTensors.at(tensor_indexs[i]); | |||
| auto tensor_shape = tensorT->dims; | |||
| auto lite_tensor = std::make_unique<Tensor>( | |||
| TypeId(tensorT->dataType), tensor_shape, tensorT->format, | |||
| TensorCategory(tensorT->nodeType, tensorT->dims.size(), TypeId(tensorT->dataType), tensorT->data.size())); | |||
| if (lite_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "lite tensor is nullptr"; | |||
| convert_succ = false; | |||
| break; | |||
| } | |||
| auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t); | |||
| // when tensorT as param input | |||
| if (lite_tensor_size == 0) { | |||
| lite_tensors.emplace_back(lite_tensor.release()); | |||
| continue; | |||
| } | |||
| auto ret = lite_tensor->MallocData(); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "Malloc tensor data failed"; | |||
| convert_succ = false; | |||
| break; | |||
| } | |||
| if (memcpy_s(lite_tensor->MutableData(), lite_tensor->Size(), tensorT->data.data(), lite_tensor_size) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| convert_succ = false; | |||
| break; | |||
| if (tensorT->dataType != kObjectTypeTensorType) { // convert to lite::Tensor | |||
| auto tensor_shape = tensorT->dims; | |||
| lite_tensor = std::make_unique<Tensor>( | |||
| TypeId(tensorT->dataType), tensor_shape, tensorT->format, | |||
| TensorCategory(tensorT->nodeType, tensorT->dims.size(), TypeId(tensorT->dataType), tensorT->data.size())); | |||
| if (lite_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "lite tensor is nullptr"; | |||
| convert_succ = false; | |||
| break; | |||
| } | |||
| auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t); | |||
| // when tensorT as param input | |||
| if (lite_tensor_size == 0) { | |||
| lite_tensors.emplace_back(lite_tensor.release()); | |||
| continue; | |||
| } | |||
| auto ret = lite_tensor->MallocData(); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "Malloc tensor data failed"; | |||
| convert_succ = false; | |||
| break; | |||
| } | |||
| if (memcpy_s(lite_tensor->MutableData(), lite_tensor->Size(), tensorT->data.data(), lite_tensor_size) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| convert_succ = false; | |||
| break; | |||
| } | |||
| } else { // convert to lite::TensorList | |||
| auto tensor_shape = tensorT->dims; | |||
| TypeId type = kTypeUnknown; | |||
| std::vector<int> element_shape; | |||
| if (!tensorT->data.empty()) { | |||
| int *data = reinterpret_cast<int *>(tensorT->data.data()); | |||
| type = TypeId(data[0]); | |||
| if (tensorT->data.size() < 8 || (data[1] + 2) * 4 != static_cast<int>(tensorT->data.size())) { | |||
| MS_LOG(ERROR) << "tensorlist data length illegal"; | |||
| convert_succ = false; | |||
| break; | |||
| } | |||
| for (int j = 0; j < data[1]; ++j) { | |||
| element_shape.push_back(data[j + 2]); | |||
| } | |||
| } | |||
| lite_tensor = std::make_unique<TensorList>(tensor_shape, element_shape); | |||
| if (lite_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "lite tensorlist is nullptr"; | |||
| convert_succ = false; | |||
| break; | |||
| } | |||
| reinterpret_cast<TensorList *>(lite_tensor.get())->set_tensors_data_type(type); | |||
| } | |||
| lite_tensors.emplace_back(lite_tensor.release()); | |||
| } | |||
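The tensorlist branch above validates the raw buffer before decoding: the payload must contain at least two ints (type id and dim count), and its byte length must equal `(data[1] + 2) * 4`, i.e. the header plus exactly `data[1]` dimension ints. A small standalone check under those assumptions:

```cpp
#include <cstdint>
#include <vector>

// Standalone sketch of the tensorlist length check above: the serialized buffer must hold
// exactly (ndim + 2) 32-bit ints: [type_id, ndim, dim_0, ..., dim_{ndim-1}].
bool TensorListBufferLengthOk(const std::vector<uint8_t> &data) {
  if (data.size() < 8) return false;  // need at least type_id and ndim
  const int *ints = reinterpret_cast<const int *>(data.data());
  return (ints[1] + 2) * 4 == static_cast<int>(data.size());
}
```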
| @@ -20,7 +20,6 @@ | |||
| #include <vector> | |||
| namespace mindspore::lite { | |||
| constexpr int32_t kSingleGroup = 1; | |||
| bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, | |||
| schema::PrimitiveT *primitive) { | |||
| MS_LOG(DEBUG) << "onnx DepthwiseConvParser"; | |||
| @@ -175,7 +174,7 @@ lite::PrimitiveC *OnnxConvParser::ParseLitePrimitive(const onnx::GraphProto &onn | |||
| MS_LOG(ERROR) << "new primitive failed"; | |||
| return nullptr; | |||
| } | |||
| if (attr->group > kSingleGroup && attr->group == attr->channelIn) { | |||
| if (attr->group == attr->channelIn && attr->channelIn == attr->channelOut) { | |||
| if (!ParseGroupConvolution(attr, primitive.get())) { | |||
| MS_LOG(ERROR) << "Convert Convolution to Depthwise failed"; | |||
| return nullptr; | |||
| @@ -37,7 +37,7 @@ STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (tf_op.op() == "Add") { | |||
| if (tf_op.op() == "Add" || tf_op.op() == "AddV2") { | |||
| auto attr = std::make_unique<schema::AddT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| @@ -54,6 +54,10 @@ STATUS TFBiasAddParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| *output_size = 1; | |||
| auto status = AddOpInput(tf_op, 0, inputs); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| status = AddOpInput(tf_op, 1, inputs); | |||
| return status; | |||
| } | |||
| TFNodeRegistrar g_tfBiasAddParser("BiasAdd", new TFBiasAddParser()); | |||
| @@ -42,11 +42,11 @@ STATUS TFConcatParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (tf_node_map.find(tf_op.input(tf_op.input_size() - 1)) == tf_node_map.end()) { | |||
| MS_LOG(ERROR) << "Find Concat input axis failed"; | |||
| auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(tf_op.input_size() - 1)); | |||
| if (axis_node == nullptr) { | |||
| MS_LOG(ERROR) << "get concat axis attr node failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto axis_node = tf_node_map.at(tf_op.input(tf_op.input_size() - 1)); | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "The value attr should be specified"; | |||
| @@ -66,11 +66,11 @@ STATUS TFConvParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| attr->strideH = strides[0]; | |||
| attr->strideW = strides[1]; | |||
| if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { | |||
| auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1)); | |||
| if (weight_node == nullptr) { | |||
| MS_LOG(ERROR) << "Find Conv2D input weights failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto weight_node = tf_node_map.at(tf_op.input(1)); | |||
| std::vector<int64_t> kernels(4); | |||
| status = ParseKernels(*weight_node, attr->format, &kernels); | |||
| if (status != RET_OK) { | |||
| @@ -42,11 +42,11 @@ STATUS TFExpandDimsParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { | |||
| auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(1)); | |||
| if (axis_node == nullptr) { | |||
| MS_LOG(ERROR) << "Find ExpandDims input axis failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto axis_node = tf_node_map.at(tf_op.input(1)); | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "The value attr should be specified"; | |||
| @@ -50,11 +50,11 @@ STATUS TFGatherParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| bool axis_is_set = false; | |||
| if (tf_op.input_size() == 3) { | |||
| axis_is_set = true; | |||
| if (tf_node_map.find(tf_op.input(2)) == tf_node_map.end()) { | |||
| auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(2)); | |||
| if (axis_node == nullptr) { | |||
| MS_LOG(ERROR) << "Find Gather input axis failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto axis_node = tf_node_map.at(tf_op.input(2)); | |||
| if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "The value attr should be specified"; | |||
| return RET_ERROR; | |||
| @@ -17,7 +17,6 @@ | |||
| #include "tools/converter/parser/tf/tf_model_parser.h" | |||
| #include <functional> | |||
| #include <regex> | |||
| #include <set> | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/utils.h" | |||
| @@ -25,31 +24,11 @@ | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/protobuf_utils.h" | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace { | |||
| static const std::vector<schema::PrimitiveType> tensorListOutputOpList = { | |||
| schema::PrimitiveType_TensorListFromTensor, | |||
| schema::PrimitiveType_TensorListSetItem, | |||
| schema::PrimitiveType_TensorListReserve, | |||
| }; | |||
| // subgraph node input may be a:output:0/a:z:0 | |||
| std::string GetFlattenNodeName(std::string input_name) { | |||
| std::regex re("\\:+"); | |||
| std::vector<std::string> input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1), | |||
| std::sregex_token_iterator()); | |||
| if (input_splits.size() == 3) { | |||
| if (input_splits[2] == "0") { | |||
| input_name = input_splits[0]; | |||
| } else { | |||
| input_name = input_splits[0] + input_splits[2]; // multi output node | |||
| } | |||
| } | |||
| return input_name; | |||
| } | |||
| AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) { | |||
| AnfNodePtr ret = nullptr; | |||
| if (anf_node_map.find(name) != anf_node_map.end()) { | |||
| @@ -67,10 +46,11 @@ std::string GetOriginInputName(const tensorflow::NodeDef &node, | |||
| } | |||
| auto tmp_node = &node; | |||
| while (tmp_node->op() == "Identity" || tmp_node->op() == "StopGradient") { | |||
| if (tf_graph_nodes.find(tmp_node->input(0)) == tf_graph_nodes.end()) { | |||
| return tmp_node->input(0); | |||
| auto flatten_input_name = TensorFlowUtils::GetFlattenNodeName(tmp_node->input(0)); | |||
| if (tf_graph_nodes.find(flatten_input_name) == tf_graph_nodes.end()) { | |||
| return flatten_input_name; | |||
| } | |||
| tmp_node = tf_graph_nodes.at(tmp_node->input(0)); | |||
| tmp_node = tf_graph_nodes.at(flatten_input_name); | |||
| } | |||
| return tmp_node->name(); | |||
| } | |||
| @@ -89,6 +69,10 @@ STATUS TFModelParser::ConvertConstVariant(const tensorflow::TensorProto &tensor_ | |||
| MS_LOG(ERROR) << "Only TensorList type is supported now"; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| if (variant.tensors_size() > 0) { | |||
| MS_LOG(ERROR) << "Only empty tensorlist is supported now"; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| auto descriptor = variant.GetMetadata().descriptor; | |||
| auto reflection = variant.GetMetadata().reflection; | |||
| if (descriptor == nullptr || reflection == nullptr) { | |||
| @@ -232,6 +216,7 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value | |||
| param_value->set_tensor_type(type); | |||
| param_value->set_format(schema::Format::Format_NHWC); | |||
| parameter->set_default_param(param_value); | |||
| parameter->set_name("const_" + std::to_string(anf_root_node_map.size()) + "_parameter"); | |||
| return RET_OK; | |||
| } | |||
| @@ -263,7 +248,8 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa | |||
| return status; | |||
| } | |||
| } else { | |||
| graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names | |||
| parameter->set_name("placeholder_" + std::to_string(anf_root_node_map.size())); | |||
| graph_input_names.emplace_back(parameter->name()); // only root graph need set graph input names | |||
| } | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector); | |||
| @@ -271,12 +257,9 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa | |||
| MS_LOG(ERROR) << "abstract_tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| parameter->set_name(node.name()); | |||
| parameter->set_abstract(abstract_tensor); | |||
| (*anf_node_map)[node.name()] = parameter; | |||
| (*anf_node_map)[node.name() + ":0"] = parameter; | |||
| return RET_OK; | |||
| } | |||
| @@ -311,48 +294,43 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| tf_root_graph_ = std::make_unique<tensorflow::GraphDef>(); | |||
| if (tf_root_graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "tf_root_graph_ is nullptr"; | |||
| tf_root_graph = std::make_unique<tensorflow::GraphDef>(); | |||
| if (tf_root_graph == nullptr) { | |||
| MS_LOG(ERROR) << "tf_root_graph is nullptr"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get()); | |||
| status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| anf_root_graph_ = std::make_shared<FuncGraph>(); | |||
| if (anf_root_graph_ == nullptr) { | |||
| anf_root_graph = std::make_shared<FuncGraph>(); | |||
| if (anf_root_graph == nullptr) { | |||
| MS_LOG(ERROR) << "funGraphPtr is nullptr"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| for (int i = 0; i < tf_root_graph_->node_size(); i++) { | |||
| auto &node_def = tf_root_graph_->node(i); | |||
| tf_root_graph_nodes_[node_def.name()] = &node_def; | |||
| for (int i = 0; i < tf_root_graph->node_size(); i++) { | |||
| auto &node_def = tf_root_graph->node(i); | |||
| tf_root_graph_nodes[node_def.name()] = &node_def; | |||
| } | |||
| status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); | |||
| status = ConvertGraphInputsAndConsts(tf_root_graph_nodes, anf_root_graph, &anf_root_node_map); | |||
| if (status != RET_OK) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| bool success_flag = true; | |||
| for (int i = 0; i < tf_root_graph_->node_size(); i++) { | |||
| auto &node_def = tf_root_graph_->node(i); | |||
| status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); | |||
| if (status != RET_OK) { | |||
| success_flag = false; | |||
| for (int i = 0; i < tf_root_graph->node_size(); i++) { | |||
| auto &node_def = tf_root_graph->node(i); | |||
| if (ConvertOps(node_def, tf_root_graph_nodes, anf_root_graph, &anf_root_node_map) != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert ops failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| } | |||
| if (!success_flag) { | |||
| MS_LOG(ERROR) << "Convert ops failed."; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| } | |||
| status = ConvertRootGraphOutputs(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert graph outputs failed."; | |||
| @@ -367,25 +345,25 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin | |||
| return nullptr; | |||
| } | |||
| return anf_root_graph_; | |||
| return anf_root_graph; | |||
| } | |||
| STATUS TFModelParser::ConvertSubgraph() { | |||
| auto graph_def_liarary = tf_root_graph_->library(); | |||
| auto graph_def_liarary = tf_root_graph->library(); | |||
| auto subgraph_size = graph_def_liarary.function_size(); | |||
| std::map<CNodePtr, FuncGraphPtr> while_cond_map; | |||
| std::map<CNodePtr, FuncGraphPtr> while_body_map; | |||
| std::vector<ParameterPtr> sub_graph_inputs; | |||
| for (int i = 0; i < subgraph_size; i++) { | |||
| std::vector<ParameterPtr> sub_graph_inputs; | |||
| auto &tf_sub_fuction = graph_def_liarary.function(i); | |||
| auto &tf_sub_signature = tf_sub_fuction.signature(); | |||
| auto input_arg_size = tf_sub_signature.input_arg_size(); | |||
| auto &sub_graph_name = tf_sub_signature.name(); | |||
| if (!function_while_map_.count(sub_graph_name)) { | |||
| if (!function_while_map.count(sub_graph_name)) { | |||
| MS_LOG(ERROR) << "function map not contains sub graph name." << sub_graph_name; | |||
| return RET_ERROR; | |||
| } | |||
| auto while_cnode = function_while_map_[sub_graph_name]->cast<CNodePtr>(); | |||
| auto while_cnode = function_while_map[sub_graph_name]->cast<CNodePtr>(); | |||
| if (while_cnode == nullptr || static_cast<int>(while_cnode->inputs().size()) != input_arg_size + 1) { | |||
| MS_LOG(ERROR) << "while cnode not equal input arg size"; | |||
| return RET_ERROR; | |||
| @@ -426,9 +404,16 @@ STATUS TFModelParser::ConvertSubgraph() { | |||
| // convert subgraph outputs | |||
| std::vector<AnfNodePtr> sub_output_nodes; | |||
| auto &subgraph_ret = tf_sub_fuction.ret(); | |||
| for (auto &t : subgraph_ret) { | |||
| MS_LOG(INFO) << "subret " << t.first << " " << t.second; | |||
| auto tf_output_name = GetFlattenNodeName(t.second); | |||
| auto &output_args = tf_sub_signature.output_arg(); | |||
| for (auto &output_arg : output_args) { | |||
| auto &signature_name = output_arg.name(); | |||
| if (subgraph_ret.find(signature_name) == subgraph_ret.end()) { | |||
| MS_LOG(ERROR) << "can't find signature_name: " << signature_name; | |||
| return RET_ERROR; | |||
| } | |||
| auto t = subgraph_ret.find(signature_name); | |||
| MS_LOG(INFO) << "subret " << t->first << " " << t->second; | |||
| auto tf_output_name = TensorFlowUtils::GetFlattenNodeName(t->second); | |||
| AnfNodePtr anf_node = nullptr; | |||
| if (tf_sub_node_map.find(tf_output_name) == tf_sub_node_map.end()) { | |||
| anf_node = GetAnfNode(tf_output_name, anf_sub_node_map); | |||
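The rewritten output handling above iterates the function signature's `output_arg` list instead of the raw `ret` map, so subgraph outputs keep the order declared by the signature, and a missing name is reported as an error. A standalone sketch of that lookup-and-order step (plain C++ with std::map standing in for the protobuf types):

```cpp
#include <map>
#include <string>
#include <vector>

// Standalone sketch: order subgraph outputs by the signature's output_arg names and map
// each one through the ret map to the TF node that produces it. An empty result stands in
// for the error path taken when a signature name is missing.
std::vector<std::string> OrderSubgraphOutputs(const std::vector<std::string> &output_arg_names,
                                              const std::map<std::string, std::string> &subgraph_ret) {
  std::vector<std::string> producers;
  for (const auto &name : output_arg_names) {
    auto it = subgraph_ret.find(name);
    if (it == subgraph_ret.end()) {
      return {};
    }
    producers.push_back(it->second);
  }
  return producers;
}
```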
| @@ -456,7 +441,7 @@ STATUS TFModelParser::ConvertSubgraph() { | |||
| } | |||
| // hardcode subgraph inputs name | |||
| for (size_t j = 0; j < sub_graph_inputs.size(); j++) { | |||
| sub_graph_inputs[j]->set_name("graph" + std::to_string(i) + "_input_" + std::to_string(j) + "parameter"); | |||
| sub_graph_inputs[j]->set_name("graph_input_" + std::to_string(j) + "parameter"); | |||
| } | |||
| MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name; | |||
| } | |||
| @@ -473,9 +458,9 @@ STATUS TFModelParser::WhileNodePostProcess(const std::map<CNodePtr, FuncGraphPtr | |||
| MS_LOG(ERROR) << "while cond body size error"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<FuncGraphPtr> roots = {anf_root_graph_}; | |||
| std::vector<FuncGraphPtr> roots = {anf_root_graph}; | |||
| auto root_func_manager = std::make_shared<FuncGraphManager>(roots); | |||
| anf_root_graph_->set_manager(root_func_manager); | |||
| anf_root_graph->set_manager(root_func_manager); | |||
| for (auto &kv : while_cond_map) { | |||
| auto while_node = kv.first; | |||
| auto &cond_sub_graph = kv.second; | |||
| @@ -484,12 +469,9 @@ STATUS TFModelParser::WhileNodePostProcess(const std::map<CNodePtr, FuncGraphPtr | |||
| body_sub_graph->set_manager(root_func_manager); | |||
| auto cond_value_node = NewValueNode(cond_sub_graph); | |||
| auto body_value_node = NewValueNode(body_sub_graph); | |||
| auto new_while_inputs = while_node->cast<CNodePtr>()->inputs(); | |||
| new_while_inputs[0] = cond_value_node; | |||
| new_while_inputs.insert(new_while_inputs.begin() + 1, body_value_node); | |||
| auto new_while_node = anf_root_graph_->NewCNode(new_while_inputs); | |||
| new_while_node->set_abstract(while_node->abstract()); | |||
| root_func_manager->Replace(while_node, new_while_node); | |||
| auto inputs = while_node->inputs(); | |||
| inputs.insert(inputs.begin() + 1, {cond_value_node, body_value_node}); | |||
| while_node->set_inputs(inputs); | |||
| } | |||
| return RET_OK; | |||
| } | |||
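The post-processing above no longer rebuilds the while CNode; it splices the cond and body graph value nodes in right after the primitive input, so the input layout becomes `{while_prim, cond_graph, body_graph, original args...}`. A tiny standalone illustration of that splice (plain strings instead of AnfNodePtr):

```cpp
#include <string>
#include <vector>

// Standalone illustration of the input splice above, using strings in place of AnfNodePtr.
std::vector<std::string> SpliceWhileInputs(std::vector<std::string> inputs,
                                           const std::string &cond_graph, const std::string &body_graph) {
  // inputs[0] is the While primitive; the two subgraphs go right after it.
  inputs.insert(inputs.begin() + 1, {cond_graph, body_graph});
  return inputs;
}
// e.g. {"While", "x", "y"} becomes {"While", "cond_g", "body_g", "x", "y"}.
```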
| @@ -510,7 +492,7 @@ STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def, | |||
| for (size_t j = 0; j < input_names.size(); j++) { | |||
| std::string input_name = input_names[j]; // input may be produced by multi-outputs node | |||
| // subgraph input name x:output:index,need flatten | |||
| auto flatten_input_name = GetFlattenNodeName(input_name); | |||
| auto flatten_input_name = TensorFlowUtils::GetFlattenNodeName(input_name); | |||
| if (tf_node_map.find(flatten_input_name) != tf_node_map.end()) { | |||
| auto input_node = tf_node_map.at(flatten_input_name); | |||
| flatten_input_name = GetOriginInputName(*input_node, tf_node_map); | |||
| @@ -531,20 +513,10 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C | |||
| MS_ASSERT(op != nullptr); | |||
| MS_ASSERT(anf_node != nullptr); | |||
| MS_ASSERT(anf_graph != nullptr); | |||
| if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node)) && output_size != 1) { | |||
| MS_LOG(ERROR) << "tensorlist output op output_size !=1"; | |||
| return RET_ERROR; | |||
| } | |||
| if (output_size == 0) { | |||
| return RET_OK; | |||
| } else if (output_size == 1) { | |||
| auto type = kFloat32; | |||
| if (output_size == 1) { | |||
| std::vector<int64_t> shape_vector; | |||
| if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node))) { | |||
| type = TypeIdToType(kObjectTypeTensorType); | |||
| } | |||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vector)); | |||
| anf_node_map->insert(std::pair(op.name(), anf_node)); | |||
| anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector)); | |||
| anf_node_map->insert(std::pair(op.name() + ":0", anf_node)); | |||
| } else { | |||
| AbstractBasePtrList abstractList; | |||
| for (int output_idx = 0; output_idx < output_size; output_idx++) { | |||
| @@ -608,17 +580,17 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, | |||
| // control_depends are not processed currently | |||
| auto anf_node = func_graph_ptr->NewCNode(inputs); | |||
| anf_node->set_fullname_with_scope(node_def.name()); | |||
| if (op_type == "StatelessWhile" || op_type == "while") { | |||
| if (op_type == "StatelessWhile" || op_type == "While") { | |||
| MS_LOG(INFO) << "find while node:" << node_def.name(); | |||
| tensorflow::AttrValue attr_value; | |||
| if (TensorFlowUtils::FindAttrValue(node_def, "body", &attr_value)) { | |||
| auto body_name = attr_value.func().name(); | |||
| function_while_map_[body_name] = anf_node; | |||
| function_while_map[body_name] = anf_node; | |||
| MS_LOG(DEBUG) << "parse body name:" << body_name; | |||
| } | |||
| if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) { | |||
| auto cond_name = attr_value.func().name(); | |||
| function_while_map_[cond_name] = anf_node; | |||
| function_while_map[cond_name] = anf_node; | |||
| MS_LOG(DEBUG) << "parse cond name:" << cond_name; | |||
| } | |||
| } | |||
| @@ -634,28 +606,31 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, | |||
| STATUS TFModelParser::ConvertRootGraphOutputs() { | |||
| // because output of intermediate node in anf graph may also be output tensors, we search output tensors in | |||
| // tf_root_graph_nodes_ but not anf_root_node_map_ | |||
| // tf_root_graph_nodes but not anf_root_node_map | |||
| std::set<std::string> all_node_inputs; | |||
| std::vector<AnfNodePtr> output_nodes; | |||
| for (auto &pair : tf_root_graph_nodes_) { | |||
| for (auto &pair : tf_root_graph_nodes) { | |||
| for (int i = 0; i < pair.second->input_size(); ++i) { | |||
| all_node_inputs.insert(pair.second->input(i)); | |||
| all_node_inputs.insert(TensorFlowUtils::GetNodeName(pair.second->input(i))); | |||
| } | |||
| } | |||
| for (auto &pair : tf_root_graph_nodes_) { | |||
| for (auto &pair : tf_root_graph_nodes) { | |||
| if (pair.second->op() == "Assert") { | |||
| continue; | |||
| } | |||
| auto it = all_node_inputs.find(pair.first); | |||
| if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity | |||
| auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_); | |||
| auto anf_node = GetAnfNode(origin_name, anf_root_node_map_); | |||
| auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes); | |||
| auto anf_node = GetAnfNode(origin_name, anf_root_node_map); | |||
| if (anf_node == nullptr) { | |||
| MS_LOG(ERROR) << "can't find anf node"; | |||
| return RET_ERROR; | |||
| } | |||
| output_nodes.push_back(anf_node); | |||
| graph_output_names_.push_back(anf_node->fullname_with_scope()); | |||
| graph_output_names.push_back(anf_node->fullname_with_scope()); | |||
| } | |||
| } | |||
| auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_); | |||
| auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "make anf graph outputs node error"; | |||
| return status; | |||
| @@ -71,13 +71,13 @@ class TFModelParser : public ModelParser { | |||
| STATUS MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph); | |||
| FuncGraphPtr anf_root_graph_; | |||
| std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def | |||
| std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map | |||
| std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_; | |||
| std::vector<std::string> graph_input_names_; | |||
| std::vector<std::string> graph_output_names_; | |||
| std::map<std::string, AnfNodePtr> function_while_map_; // tf function name->while_node_name | |||
| FuncGraphPtr anf_root_graph; | |||
| std::unique_ptr<tensorflow::GraphDef> tf_root_graph; // tf root graph def | |||
| std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes; // tf root graph node map | |||
| std::unordered_map<std::string, AnfNodePtr> anf_root_node_map; | |||
| std::vector<std::string> graph_input_names; | |||
| std::vector<std::string> graph_output_names; | |||
| std::map<std::string, AnfNodePtr> function_while_map; // tf function name->while_node_name | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -17,6 +17,8 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| using tensorflow::NodeDef; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFNodeParser::AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, std::vector<std::string> *inputs) { | |||
| @@ -27,5 +29,19 @@ STATUS TFNodeParser::AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, | |||
| inputs->push_back(tf_op.input(idx)); | |||
| return RET_OK; | |||
| } | |||
| const NodeDef *TFNodeParser::GetConstInputNode(const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| const std::string &input_name) { | |||
| auto flatten_input_name = TensorFlowUtils::GetFlattenNodeName(input_name); | |||
| if (tf_node_map.find(flatten_input_name) == tf_node_map.end()) { | |||
| return nullptr; | |||
| } | |||
| auto node = tf_node_map.at(flatten_input_name); | |||
| if (node->op() != "Const") { | |||
| MS_LOG(ERROR) << "Attr node is not Const"; | |||
| return nullptr; | |||
| } | |||
| return node; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
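The new `GetConstInputNode` helper above replaces the repeated `tf_node_map.find(...)` pattern in the individual parsers: it first flattens the TF input name (dropping the `:port:index` suffix), then requires the resolved node to be a `Const` op. A standalone sketch of that lookup under the same assumptions (a simplified node type, not tensorflow::NodeDef):

```cpp
#include <map>
#include <string>

// Simplified stand-in for tensorflow::NodeDef, just enough to show the lookup rule.
struct FakeNodeDef {
  std::string op;
};

// Resolve an input name to its defining node and accept it only if it is a Const op.
// `flatten` stands in for TensorFlowUtils::GetFlattenNodeName.
const FakeNodeDef *FindConstInput(const std::map<std::string, FakeNodeDef> &node_map,
                                  const std::string &input_name,
                                  std::string (*flatten)(const std::string &)) {
  auto it = node_map.find(flatten(input_name));
  if (it == node_map.end() || it->second.op != "Const") {
    return nullptr;
  }
  return &it->second;
}
```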
| @@ -40,6 +40,9 @@ class TFNodeParser { | |||
| } | |||
| STATUS AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, std::vector<std::string> *inputs); | |||
| const tensorflow::NodeDef *GetConstInputNode(const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| const std::string &input_name); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -69,11 +69,11 @@ STATUS TFReduceParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| } | |||
| attr->keepDims = attr_value.b(); | |||
| if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { | |||
| auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(1)); | |||
| if (axis_node == nullptr) { | |||
| MS_LOG(ERROR) << "Find Reduce input axis failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto axis_node = tf_node_map.at(tf_op.input(1)); | |||
| if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "The value attr should be specified"; | |||
| return RET_ERROR; | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/tf/tf_reverse_sequence_parser.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS TFReverseSequenceParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) { | |||
| MS_LOG(INFO) << "TF ReverseSequenceParser"; | |||
| if (primitiveC == nullptr || output_size == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (primitive == nullptr) { | |||
| MS_LOG(ERROR) << "New PrimitiveT failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto attr = std::make_unique<schema::ReverseSequenceT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new attr failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "batch_dim", &attr_value)) { | |||
| MS_LOG(ERROR) << "The batch_dim attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->batchAxis = attr_value.i(); | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "seq_dim", &attr_value)) { | |||
| MS_LOG(ERROR) << "The seq_dim attr should be specified"; | |||
| return RET_ERROR; | |||
| } | |||
| attr->seqAxis = attr_value.i(); | |||
| primitive->value.type = schema::PrimitiveType_ReverseSequence; | |||
| primitive->value.value = attr.release(); | |||
| *primitiveC = PrimitiveC::Create(primitive.release()); | |||
| if (*primitiveC == nullptr) { | |||
| MS_LOG(ERROR) << "primitiveC is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| *output_size = 1; | |||
| return AddOpInput(tf_op, 0, inputs); | |||
| } | |||
| TFNodeRegistrar g_tfReverseSequenceParser("ReverseSequence", new TFReverseSequenceParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REVERSE_SEQUENCE_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REVERSE_SEQUENCE_PARSER_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <vector> | |||
| #include "tools/converter/parser/tf/tf_node_parser.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TFReverseSequenceParser : public TFNodeParser { | |||
| public: | |||
| TFReverseSequenceParser() = default; | |||
| ~TFReverseSequenceParser() override = default; | |||
| STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map, | |||
| PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REVERSE_SEQUENCE_PARSER_H_ | |||
| @@ -58,11 +58,11 @@ STATUS TFSplitParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| input_index = 0; | |||
| } | |||
| if (tf_node_map.find(tf_op.input(split_dim_index)) == tf_node_map.end()) { | |||
| auto split_dim_node = GetConstInputNode(tf_node_map, tf_op.input(split_dim_index)); | |||
| if (split_dim_node == nullptr) { | |||
| MS_LOG(ERROR) << "Find Split input split_dim node failed"; | |||
| return RET_ERROR; | |||
| } | |||
| const auto &split_dim_node = tf_node_map.at(tf_op.input(split_dim_index)); | |||
| if (!TensorFlowUtils::FindAttrValue(*split_dim_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "The attribute splitDim should be specified"; | |||
| return RET_PARAM_INVALID; | |||
| @@ -72,11 +72,11 @@ STATUS TFSplitParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| *output_size = attr->numberSplit; | |||
| if (tf_op.op() == "SplitV") { | |||
| if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { | |||
| auto size_splits_node = GetConstInputNode(tf_node_map, tf_op.input(1)); | |||
| if (size_splits_node == nullptr) { | |||
| MS_LOG(ERROR) << "Find Split input size_splits failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto size_splits_node = tf_node_map.at(tf_op.input(1)); | |||
| if (!TensorFlowUtils::FindAttrValue(*size_splits_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "The attribute size splits should be specified"; | |||
| return RET_PARAM_INVALID; | |||
| @@ -74,11 +74,11 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| attr->shrinkAxisMask = attr_value.i(); | |||
| // begin | |||
| if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { | |||
| auto begin_node = GetConstInputNode(tf_node_map, tf_op.input(1)); | |||
| if (begin_node == nullptr) { | |||
| MS_LOG(ERROR) << "Find StridedSlice input begin failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto begin_node = tf_node_map.at(tf_op.input(1)); | |||
| if (!TensorFlowUtils::FindAttrValue(*begin_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "The value attr should be specified"; | |||
| return RET_ERROR; | |||
| @@ -97,11 +97,11 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| } | |||
| // end | |||
| if (tf_node_map.find(tf_op.input(2)) == tf_node_map.end()) { | |||
| auto end_node = GetConstInputNode(tf_node_map, tf_op.input(2)); | |||
| if (end_node == nullptr) { | |||
| MS_LOG(ERROR) << "Find StridedSlice input end failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto end_node = tf_node_map.at(tf_op.input(2)); | |||
| if (!TensorFlowUtils::FindAttrValue(*end_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "The value attr should be specified"; | |||
| return RET_ERROR; | |||
| @@ -120,11 +120,11 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| } | |||
| // strides | |||
| if (tf_node_map.find(tf_op.input(3)) == tf_node_map.end()) { | |||
| auto stride_node = GetConstInputNode(tf_node_map, tf_op.input(3)); | |||
| if (stride_node == nullptr) { | |||
| MS_LOG(ERROR) << "Find StridedSlice input strides failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto stride_node = tf_node_map.at(tf_op.input(3)); | |||
| if (!TensorFlowUtils::FindAttrValue(*stride_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "The value attr should be specified"; | |||
| return RET_ERROR; | |||
| @@ -42,11 +42,11 @@ STATUS TFTileParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { | |||
| auto multiplies_node = GetConstInputNode(tf_node_map, tf_op.input(1)); | |||
| if (multiplies_node == nullptr) { | |||
| MS_LOG(ERROR) << "Find Tile input multiplies failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto multiplies_node = tf_node_map.at(tf_op.input(1)); | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(*multiplies_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "The value attr should be specified"; | |||
| @@ -42,11 +42,12 @@ STATUS TFTransposeParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) { | |||
| attr->conjugate = false; | |||
| auto perm_node = GetConstInputNode(tf_node_map, tf_op.input(1)); | |||
| if (perm_node == nullptr) { | |||
| MS_LOG(ERROR) << "Find Transpose input perm failed"; | |||
| return RET_ERROR; | |||
| } | |||
| auto perm_node = tf_node_map.at(tf_op.input(1)); | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(*perm_node, "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "The value attr should be specified"; | |||
| @@ -16,8 +16,10 @@ | |||
| #include "tools/converter/parser/tf/tf_util.h" | |||
| #include <string> | |||
| #include <vector> | |||
| #include <string_view> | |||
| #include <unordered_map> | |||
| #include <regex> | |||
| #include "src/common/log_adapter.h" | |||
| #include "schema/inner/model_generated.h" | |||
| @@ -112,5 +114,32 @@ bool TensorFlowUtils::DecodeInt64(std::string_view *str_view, uint64_t *value) { | |||
| return true; | |||
| } | |||
| } | |||
| // convert input_arg in subgraph to node_name[:index] format | |||
| std::string TensorFlowUtils::GetFlattenNodeName(const std::string &input_name) { | |||
| std::regex re("\\:+"); | |||
| std::vector<std::string> input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1), | |||
| std::sregex_token_iterator()); | |||
| std::string ret = input_name; | |||
| if (input_splits.size() == 3) { | |||
| if (input_splits[2] == "0") { | |||
| ret = input_splits[0]; | |||
| } else { | |||
| ret = input_splits[0] + input_splits[2]; // multi output node | |||
| } | |||
| } | |||
| return ret; | |||
| } | |||
| // get referenced node name from input name | |||
| std::string TensorFlowUtils::GetNodeName(const std::string &input_name) { | |||
| std::regex re("\\:+"); | |||
| std::vector<std::string> input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1), | |||
| std::sregex_token_iterator()); | |||
| if (input_splits.size() > 1) { | |||
| return input_splits[0]; | |||
| } | |||
| return input_name; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
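The two helpers above normalize TF's `node:arg:index` input names: `GetFlattenNodeName` keeps just the node name when the output index is 0 and appends the index otherwise (to disambiguate multi-output nodes), while `GetNodeName` always strips everything after the first colon. A runnable standalone version of the same splitting rule, with a few example inputs:

```cpp
#include <cassert>
#include <regex>
#include <string>
#include <vector>

// Same splitting rule as TensorFlowUtils::GetFlattenNodeName above.
std::string FlattenNodeName(const std::string &input_name) {
  std::regex re(":+");
  std::vector<std::string> parts(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1),
                                 std::sregex_token_iterator());
  if (parts.size() == 3) {
    return parts[2] == "0" ? parts[0] : parts[0] + parts[2];  // multi-output node keeps its index
  }
  return input_name;
}

int main() {
  assert(FlattenNodeName("a:output:0") == "a");
  assert(FlattenNodeName("a:z:2") == "a2");
  assert(FlattenNodeName("plain_node") == "plain_node");
  return 0;
}
```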
| @@ -34,6 +34,8 @@ class TensorFlowUtils { | |||
| static TypeId ParseAttrDataType(const tensorflow::NodeDef &node_def, const std::string &attr_name); | |||
| static schema::Format ParseNodeFormat(const tensorflow::NodeDef &node_def); | |||
| static bool DecodeInt64(std::string_view *str_view, uint64_t *value); | |||
| static std::string GetFlattenNodeName(const std::string &input_name); | |||
| static std::string GetNodeName(const std::string &input_name); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -101,6 +101,7 @@ bool ClipConvertActivationPass::Run(const FuncGraphPtr &graph) { | |||
| op_inputs.push_back(clip_cnode->input(1)); | |||
| auto new_cnode = graph->NewCNode(op_inputs); | |||
| new_cnode->set_fullname_with_scope(node->fullname_with_scope()); | |||
| new_cnode->set_abstract(clip_cnode->abstract()->Clone()); | |||
| manager->Replace(node, new_cnode); | |||
| } | |||
| return false; | |||
| @@ -121,7 +121,7 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l | |||
| } | |||
| if (utils::isa<ValueNodePtr>(cnode->input(i))) { | |||
| MS_LOG(ERROR) << "input is value node"; | |||
| MS_LOG(WARNING) << "input is value node"; | |||
| continue; | |||
| } | |||
| @@ -178,13 +178,10 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l | |||
| } | |||
| } else { | |||
| int *data = reinterpret_cast<int *>(param_value->tensor_addr()); | |||
| auto tensor_list = dynamic_cast<lite::TensorList *>(tensor.get()); | |||
| tensor_list->set_tensors_data_type(TypeId(data[0])); | |||
| std::vector<int> shape; | |||
| for (int j = 0; j < data[1]; ++j) { | |||
| shape.push_back(data[2 + j]); | |||
| auto tensor_list = reinterpret_cast<lite::TensorList *>(tensor.get()); | |||
| if (tensor_list->Decode(data) != RET_OK) { | |||
| return RET_ERROR; | |||
| } | |||
| tensor_list->set_element_shape(shape); | |||
| } | |||
| } | |||
| } | |||
| @@ -210,8 +207,8 @@ STATUS InferShapePass::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector< | |||
| return RET_ERROR; | |||
| } | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(element); | |||
| auto typePtr = abstract_tensor->element()->GetTypeTrack(); | |||
| types.push_back(typePtr->type_id()); | |||
| auto type_ptr = abstract_tensor->element()->GetTypeTrack(); | |||
| types.push_back(type_ptr->type_id()); | |||
| } | |||
| } else { | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) { | |||
| @@ -219,8 +216,8 @@ STATUS InferShapePass::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector< | |||
| return RET_ERROR; | |||
| } | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract); | |||
| auto typePtr = abstract_tensor->element()->GetTypeTrack(); | |||
| types.push_back(typePtr->type_id()); | |||
| auto type_ptr = abstract_tensor->element()->GetTypeTrack(); | |||
| types.push_back(type_ptr->type_id()); | |||
| } | |||
| for (auto &type : types) { | |||
| std::unique_ptr<lite::Tensor> output_tensor = nullptr; | |||