Browse Source

!9907 runtime support tensorlist

From: @wangzhe128
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
e54fed511f
32 changed files with 445 additions and 217 deletions
  1. +28
    -17
      mindspore/lite/src/lite_session.cc
  2. +37
    -19
      mindspore/lite/src/ops/tensorlistsetitem.cc
  3. +2
    -2
      mindspore/lite/src/runtime/kernel/arm/fp32/TensorListFromTensor.cc
  4. +5
    -8
      mindspore/lite/src/runtime/kernel/arm/fp32/TensorListSetItem.cc
  5. +2
    -3
      mindspore/lite/src/runtime/kernel/arm/fp32/TensorListStack.cc
  6. +5
    -0
      mindspore/lite/src/scheduler.cc
  7. +13
    -1
      mindspore/lite/src/tensorlist.cc
  8. +4
    -2
      mindspore/lite/src/tensorlist.h
  9. +26
    -3
      mindspore/lite/tools/anf_exporter/anf_exporter.cc
  10. +52
    -25
      mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc
  11. +1
    -2
      mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc
  12. +1
    -1
      mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc
  13. +4
    -0
      mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc
  14. +3
    -3
      mindspore/lite/tools/converter/parser/tf/tf_concat_parser.cc
  15. +2
    -2
      mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc
  16. +2
    -2
      mindspore/lite/tools/converter/parser/tf/tf_expand_dims_parser.cc
  17. +2
    -2
      mindspore/lite/tools/converter/parser/tf/tf_gather_parser.cc
  18. +66
    -91
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
  19. +7
    -7
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.h
  20. +16
    -0
      mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc
  21. +3
    -0
      mindspore/lite/tools/converter/parser/tf/tf_node_parser.h
  22. +2
    -2
      mindspore/lite/tools/converter/parser/tf/tf_reduce_parser.cc
  23. +70
    -0
      mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc
  24. +37
    -0
      mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.h
  25. +4
    -4
      mindspore/lite/tools/converter/parser/tf/tf_split_parser.cc
  26. +6
    -6
      mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc
  27. +2
    -2
      mindspore/lite/tools/converter/parser/tf/tf_tile_parser.cc
  28. +3
    -2
      mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.cc
  29. +29
    -0
      mindspore/lite/tools/converter/parser/tf/tf_util.cc
  30. +2
    -0
      mindspore/lite/tools/converter/parser/tf/tf_util.h
  31. +1
    -0
      mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.cc
  32. +8
    -11
      mindspore/lite/tools/optimizer/graph/infershape_pass.cc

+ 28
- 17
mindspore/lite/src/lite_session.cc View File

@@ -89,27 +89,38 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
auto data_type = src_tensor->dataType();
if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) &&
src_tensor->data() != nullptr && src_tensor->data()->size() > 0) {
MS_ASSERT(dst_tensor->Size() == src_tensor->data()->size());
if (WeightTensorNeedCopy(model, tensor_index)) {
auto dst_data = dst_tensor->MutableData();
if (dst_data == nullptr) {
MS_LOG(ERROR) << "Data from tensor is nullptr";
return RET_NULL_PTR;
if (src_tensor->dataType() == kObjectTypeTensorType) {
auto tensor_list = reinterpret_cast<TensorList *>(dst_tensor);
if (src_tensor->data() == nullptr) {
MS_LOG(ERROR) << "src_tensor->data() is nullptr";
return RET_ERROR;
}
if (tensor_list->Decode(reinterpret_cast<const int *>(src_tensor->data()->data())) != RET_OK) {
return RET_ERROR;
}
memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size());
copyed_tensor_idxes_.emplace_back(tensor_index);
} else {
int pack_size = src_tensor->data()->size();
int org_size = dst_tensor->Size();
if (pack_size != org_size && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16)) {
auto ret = dst_tensor->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Malloc data for tensor failed ";
return RET_ERROR;
MS_ASSERT(dst_tensor->Size() == src_tensor->data()->size());
if (WeightTensorNeedCopy(model, tensor_index)) {
auto dst_data = dst_tensor->MutableData();
if (dst_data == nullptr) {
MS_LOG(ERROR) << "Data from tensor is nullptr";
return RET_NULL_PTR;
}
kernel::DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData());
memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size());
copyed_tensor_idxes_.emplace_back(tensor_index);
} else {
dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data()));
int pack_size = src_tensor->data()->size();
int org_size = dst_tensor->Size();
if (pack_size != org_size && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16)) {
auto ret = dst_tensor->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Malloc data for tensor failed ";
return RET_ERROR;
}
kernel::DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData());
} else {
dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data()));
}
}
}
}


+ 37
- 19
mindspore/lite/src/ops/tensorlistsetitem.cc View File

@@ -97,6 +97,21 @@ int TensorListSetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
MS_ASSERT(input0 != nullptr);
auto get_index = inputs_[1];
MS_ASSERT(get_index != nullptr);
auto value_tensor = inputs_[2];
MS_ASSERT(value_tensor != nullptr);
auto output0 = reinterpret_cast<TensorList *>(outputs_[0]);
MS_ASSERT(output0 != nullptr);

output0->set_data_type(input0->data_type());
output0->set_format(input0->format());

if (!infer_flag()) {
return RET_INFER_INVALID;
}
if (get_index->data_c() == nullptr || value_tensor->data_c() == nullptr) {
return RET_INFER_INVALID;
}

if (get_index->data_type() != kNumberTypeInt && get_index->data_type() != kNumberTypeInt32) {
MS_LOG(ERROR) << "inputs_[1]->data_type():" << get_index->data_type() << " is not int";
return RET_ERROR;
@@ -110,31 +125,34 @@ int TensorListSetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
return RET_NULL_PTR;
}
int index = reinterpret_cast<int *>(get_index->data_c())[0];
if (index < 0 || index > (input0->ElementsNum() - 1)) {
MS_LOG(ERROR) << "index_:" << index << "must in [0, " << input0->ElementsNum() - 1 << "]";
if (index < 0 || (index >= static_cast<int>(input0->tensors().size()) && index != 0)) {
MS_LOG(ERROR) << "index_:" << index << "must in [0, " << input0->tensors().size() << "]";
return RET_ERROR;
}
auto value_tensor = inputs_[2];
MS_ASSERT(value_tensor != nullptr);
auto output0 = reinterpret_cast<TensorList *>(outputs_[0]);
MS_ASSERT(output0 != nullptr);
output0->set_element_shape(input0->element_shape());

output0->set_max_elements_num(input0->max_elements_num());
output0->set_shape(input0->shape());
output0->set_data_type(input0->data_type());
output0->set_element_shape(input0->element_shape());

std::vector<std::vector<int> > out_shape;
for (int i = 0; i < input0->ElementsNum(); ++i) {
auto src_ptr = input0->GetTensorIndex(i);
if (src_ptr == nullptr) {
MS_LOG(ERROR) << "input0->tensors_[" << i << "] is nullptr!";
return RET_ERROR;
}
if (src_ptr->data_type() != kTypeUnknown) {
out_shape.push_back(src_ptr->shape());
} else {
out_shape.push_back(std::vector<int>());
if (index == 0 && input0->tensors().size() == 0) { // uninitialized tensorlist
out_shape.push_back(value_tensor->shape());
output0->set_shape(std::vector<int>{1});
} else {
output0->set_shape(input0->shape());
for (int i = 0; i < input0->ElementsNum(); ++i) {
auto src_ptr = input0->GetTensorIndex(i);
if (src_ptr == nullptr) {
MS_LOG(ERROR) << "input0->tensors_[" << i << "] is nullptr!";
return RET_ERROR;
}
if (src_ptr->data_type() != kTypeUnknown) {
out_shape.push_back(src_ptr->shape());
} else {
out_shape.push_back(std::vector<int>());
}
}
}

out_shape[index] = value_tensor->shape();
output0->MallocTensorListData(input0->tensors_data_type(), out_shape);
return RET_OK;


+ 2
- 2
mindspore/lite/src/runtime/kernel/arm/fp32/TensorListFromTensor.cc View File

@@ -28,8 +28,8 @@ using mindspore::schema::PrimitiveType_TensorListFromTensor;
namespace mindspore::kernel {

int TensorListFromTensorCPUKernel::IsCompatibleShape() {
if (input1_->data_type() != kNumberTypeInt) { // element_shape
MS_LOG(ERROR) << "in_tensors_[1] data type is must be \"kNumberTypeInt\", but now is:" << input1_->data_type();
if (input1_->data_type() != kNumberTypeInt && input1_->data_type() != kNumberTypeInt32) { // element_shape
MS_LOG(ERROR) << "in_tensors_[1] data type is must be int";
return RET_ERROR;
}
int in1_ele_num = input1_->ElementsNum();


+ 5
- 8
mindspore/lite/src/runtime/kernel/arm/fp32/TensorListSetItem.cc View File

@@ -28,16 +28,17 @@ using mindspore::schema::PrimitiveType_TensorListSetItem;

namespace mindspore::kernel {

int TensorListSetItemCPUKernel::Init() {
int TensorListSetItemCPUKernel::Init() { return RET_OK; }

int TensorListSetItemCPUKernel::Run() {
input0_ = reinterpret_cast<lite::TensorList *>(in_tensors_[0]);
if (dtype_ != input0_->data_type()) {
MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type();
return RET_ERROR;
}
int dim0 = input0_->ElementsNum() - 1;
if (in_tensors_[1]->data_type() != kNumberTypeInt) {
MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type()
<< " must be equal to \"kNumberTypeInt\":" << kNumberTypeInt;
if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) {
MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int";
return RET_ERROR;
}
if (in_tensors_[1]->ElementsNum() != 1) {
@@ -54,10 +55,6 @@ int TensorListSetItemCPUKernel::Init() {
if (!input0_->IsCompatibleShape(input2_->shape())) {
return RET_ERROR;
}
return RET_OK;
}

int TensorListSetItemCPUKernel::Run() {
output0_ = reinterpret_cast<lite::TensorList *>(out_tensors_[0]);
MS_ASSERT(output0_ != nullptr);
// copy each tensor in tensors_


+ 2
- 3
mindspore/lite/src/runtime/kernel/arm/fp32/TensorListStack.cc View File

@@ -73,9 +73,8 @@ bool TensorListStackCPUKernel::IsFullyDefined(const std::vector<int> &shape) con

int TensorListStackCPUKernel::MergeElementShape() {
MS_ASSERT(in_tensors_[1]);
if (in_tensors_[1]->data_type() != kNumberTypeInt) {
MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type()
<< " must be \"kNumberTypeInt\":" << kNumberTypeInt;
if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) {
MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int";
return RET_ERROR;
}
auto ele_shape_data = reinterpret_cast<int *>(in_tensors_[1]->data_c());


+ 5
- 0
mindspore/lite/src/scheduler.cc View File

@@ -19,6 +19,7 @@
#include <queue>
#include <string>
#include <vector>
#include "src/tensorlist.h"
#include "src/ops/partial.h"
#include "include/errorcode.h"
#include "src/common/graph_util.h"
@@ -426,6 +427,10 @@ TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_ten
if (dtype == kObjectTypeString) {
return kNumberTypeFloat32;
}
if (dtype == kObjectTypeTensorType) {
auto tensor_list = reinterpret_cast<TensorList *>(tensor);
return tensor_list->tensors_data_type();
}
if (dtype == kNumberTypeFloat32 || dtype == kNumberTypeFloat16 || dtype == kNumberTypeInt8 ||
dtype == kNumberTypeInt32 || dtype == kNumberTypeBool) {
return dtype;


+ 13
- 1
mindspore/lite/src/tensorlist.cc View File

@@ -204,7 +204,7 @@ int TensorList::CheckTensorListParam() {

Tensor *TensorList::GetTensorIndex(int index) {
// return tensor[index] ptr. With this function, you can modify tensors_[index] at will.
if (index < 0 || index > (this->ElementsNum() - 1)) {
if (index < 0 || index >= static_cast<int>(tensors_.size())) {
MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!";
return nullptr;
}
@@ -240,5 +240,17 @@ bool TensorList::IsCompatibleShape(const Tensor *src) {
}
return true;
}

STATUS TensorList::Decode(const int *data) {
if (data == nullptr) {
MS_LOG(ERROR) << "data is nullptr";
return RET_ERROR;
}
tensors_data_type_ = TypeId(data[0]);
for (int j = 0; j < data[1]; ++j) {
element_shape_.push_back(data[2 + j]);
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 4
- 2
mindspore/lite/src/tensorlist.h View File

@@ -60,6 +60,8 @@ class TensorList : public Tensor {
public:
TensorList() = default;

TensorList(std::vector<int> shape, std::vector<int> element_shape);

~TensorList() override;

// **Note**: This is a shallow copy, src and dst tensorlist share one memory space of each tensor in tensors_
@@ -74,8 +76,6 @@ class TensorList : public Tensor {
// tensorlist deep copy memory
TensorList &operator=(const TensorList &tl);

TensorList(std::vector<int> shape, std::vector<int> element_shape);

void set_element_shape(const std::vector<int> &shape) { element_shape_ = shape; }

std::vector<int> &element_shape() { return element_shape_; }
@@ -112,6 +112,8 @@ class TensorList : public Tensor {

bool IsCompatibleShape(const Tensor *src);

STATUS Decode(const int *data);

protected:
// The following functions must be masked.
void set_data(void *data) override { return; }


+ 26
- 3
mindspore/lite/tools/anf_exporter/anf_exporter.cc View File

@@ -572,7 +572,12 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s

if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
for (size_t i = 0; i < tuple->size(); i++) {
if (tuple == nullptr) {
MS_LOG(ERROR) << "tuple is nullptr";
return;
}
auto elements = tuple->elements();
for (size_t i = 0; i < elements.size(); i++) {
auto msTensor = new (std::nothrow) schema::TensorT();
if (msTensor == nullptr) {
MS_LOG(ERROR) << "new msTensor failed";
@@ -589,7 +594,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam))
break;
#else
if (tuple->size() == 1) {
if (elements.size() == 1) {
node_id_map_[cnode_name] = meta_graphT->allTensors.size();
msTensor->name = cnode_name;
} else {
@@ -597,6 +602,18 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
node_id_map_[name] = meta_graphT->allTensors.size();
msTensor->name = name;
}

if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
MS_LOG(ERROR) << "abstract is not AbstractTensor";
return;
}
auto type = kNumberTypeFloat32;
if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
auto typePtr = abstract_tensor->element()->GetTypeTrack();
type = typePtr->type_id();
}
msTensor->dataType = type;
meta_graphT->allTensors.emplace_back(msTensor);
if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
@@ -611,8 +628,14 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
MS_LOG(ERROR) << "new tensor failed";
return;
}
auto type = kNumberTypeFloat32;
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
auto typePtr = abstract_tensor->element()->GetTypeTrack();
type = typePtr->type_id();
}
ms_tensor->dataType = type;
ms_tensor->nodeType = schema::NodeType_CNode;
ms_tensor->dataType = TypeId::kNumberTypeFloat32;
ms_tensor->name = cnode_name;
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
node_id_map_[cnode_name] = meta_graphT->allTensors.size();


+ 52
- 25
mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc View File

@@ -19,6 +19,7 @@
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "src/tensor.h"
#include "src/tensorlist.h"
#include "src/ops/primitive_c.h"

using mindspore::lite::PrimitiveC;
@@ -50,32 +51,58 @@ std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::ve
std::vector<Tensor *> lite_tensors;
bool convert_succ = true;
for (size_t i = 0; i < tensor_indexs.size(); i++) {
std::unique_ptr<Tensor> lite_tensor = nullptr;
auto &tensorT = graph->allTensors.at(tensor_indexs[i]);
auto tensor_shape = tensorT->dims;
auto lite_tensor = std::make_unique<Tensor>(
TypeId(tensorT->dataType), tensor_shape, tensorT->format,
TensorCategory(tensorT->nodeType, tensorT->dims.size(), TypeId(tensorT->dataType), tensorT->data.size()));
if (lite_tensor == nullptr) {
MS_LOG(ERROR) << "lite tensor is nullptr";
convert_succ = false;
break;
}
auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t);
// when tensorT as param input
if (lite_tensor_size == 0) {
lite_tensors.emplace_back(lite_tensor.release());
continue;
}
auto ret = lite_tensor->MallocData();
if (ret != 0) {
MS_LOG(ERROR) << "Malloc tensor data failed";
convert_succ = false;
break;
}
if (memcpy_s(lite_tensor->MutableData(), lite_tensor->Size(), tensorT->data.data(), lite_tensor_size) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed";
convert_succ = false;
break;
if (tensorT->dataType != kObjectTypeTensorType) { // convert to lite::Tensor
auto tensor_shape = tensorT->dims;
lite_tensor = std::make_unique<Tensor>(
TypeId(tensorT->dataType), tensor_shape, tensorT->format,
TensorCategory(tensorT->nodeType, tensorT->dims.size(), TypeId(tensorT->dataType), tensorT->data.size()));
if (lite_tensor == nullptr) {
MS_LOG(ERROR) << "lite tensor is nullptr";
convert_succ = false;
break;
}
auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t);
// when tensorT as param input
if (lite_tensor_size == 0) {
lite_tensors.emplace_back(lite_tensor.release());
continue;
}
auto ret = lite_tensor->MallocData();
if (ret != 0) {
MS_LOG(ERROR) << "Malloc tensor data failed";
convert_succ = false;
break;
}
if (memcpy_s(lite_tensor->MutableData(), lite_tensor->Size(), tensorT->data.data(), lite_tensor_size) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed";
convert_succ = false;
break;
}
} else { // convert to lite::TensorList
auto tensor_shape = tensorT->dims;
TypeId type = kTypeUnknown;
std::vector<int> element_shape;
if (!tensorT->data.empty()) {
int *data = reinterpret_cast<int *>(tensorT->data.data());
type = TypeId(data[0]);
if (tensorT->data.size() < 8 || (data[1] + 2) * 4 != static_cast<int>(tensorT->data.size())) {
MS_LOG(ERROR) << "tensorlist data length illegal";
convert_succ = false;
break;
}
for (int j = 0; j < data[1]; ++j) {
element_shape.push_back(data[j + 2]);
}
}
lite_tensor = std::make_unique<TensorList>(tensor_shape, element_shape);
if (lite_tensor == nullptr) {
MS_LOG(ERROR) << "lite tensorlist is nullptr";
convert_succ = false;
break;
}
reinterpret_cast<TensorList *>(lite_tensor.get())->set_tensors_data_type(type);
}
lite_tensors.emplace_back(lite_tensor.release());
}


+ 1
- 2
mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc View File

@@ -20,7 +20,6 @@
#include <vector>

namespace mindspore::lite {
constexpr int32_t kSingleGroup = 1;
bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr,
schema::PrimitiveT *primitive) {
MS_LOG(DEBUG) << "onnx DepthwiseConvParser";
@@ -175,7 +174,7 @@ lite::PrimitiveC *OnnxConvParser::ParseLitePrimitive(const onnx::GraphProto &onn
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
if (attr->group > kSingleGroup && attr->group == attr->channelIn) {
if (attr->group == attr->channelIn && attr->channelIn == attr->channelOut) {
if (!ParseGroupConvolution(attr, primitive.get())) {
MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
return nullptr;


+ 1
- 1
mindspore/lite/tools/converter/parser/tf/tf_arithmetic_parser.cc View File

@@ -37,7 +37,7 @@ STATUS TFArithmeticParser::Parse(const tensorflow::NodeDef &tf_op,
return RET_NULL_PTR;
}

if (tf_op.op() == "Add") {
if (tf_op.op() == "Add" || tf_op.op() == "AddV2") {
auto attr = std::make_unique<schema::AddT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";


+ 4
- 0
mindspore/lite/tools/converter/parser/tf/tf_biasadd_parser.cc View File

@@ -54,6 +54,10 @@ STATUS TFBiasAddParser::Parse(const tensorflow::NodeDef &tf_op,

*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
if (status != RET_OK) {
return status;
}
status = AddOpInput(tf_op, 1, inputs);
return status;
}
TFNodeRegistrar g_tfBiasAddParser("BiasAdd", new TFBiasAddParser());


+ 3
- 3
mindspore/lite/tools/converter/parser/tf/tf_concat_parser.cc View File

@@ -42,11 +42,11 @@ STATUS TFConcatParser::Parse(const tensorflow::NodeDef &tf_op,
return RET_NULL_PTR;
}

if (tf_node_map.find(tf_op.input(tf_op.input_size() - 1)) == tf_node_map.end()) {
MS_LOG(ERROR) << "Find Concat input axis failed";
auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(tf_op.input_size() - 1));
if (axis_node == nullptr) {
MS_LOG(ERROR) << "get concat axis attr node failed";
return RET_ERROR;
}
auto axis_node = tf_node_map.at(tf_op.input(tf_op.input_size() - 1));
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";


+ 2
- 2
mindspore/lite/tools/converter/parser/tf/tf_conv_parser.cc View File

@@ -66,11 +66,11 @@ STATUS TFConvParser::Parse(const tensorflow::NodeDef &tf_op,
attr->strideH = strides[0];
attr->strideW = strides[1];

if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) {
auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1));
if (weight_node == nullptr) {
MS_LOG(ERROR) << "Find Conv2D input weights failed";
return RET_ERROR;
}
auto weight_node = tf_node_map.at(tf_op.input(1));
std::vector<int64_t> kernels(4);
status = ParseKernels(*weight_node, attr->format, &kernels);
if (status != RET_OK) {


+ 2
- 2
mindspore/lite/tools/converter/parser/tf/tf_expand_dims_parser.cc View File

@@ -42,11 +42,11 @@ STATUS TFExpandDimsParser::Parse(const tensorflow::NodeDef &tf_op,
return RET_NULL_PTR;
}

if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) {
auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(1));
if (axis_node == nullptr) {
MS_LOG(ERROR) << "Find ExpandDims input axis failed";
return RET_ERROR;
}
auto axis_node = tf_node_map.at(tf_op.input(1));
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";


+ 2
- 2
mindspore/lite/tools/converter/parser/tf/tf_gather_parser.cc View File

@@ -50,11 +50,11 @@ STATUS TFGatherParser::Parse(const tensorflow::NodeDef &tf_op,
bool axis_is_set = false;
if (tf_op.input_size() == 3) {
axis_is_set = true;
if (tf_node_map.find(tf_op.input(2)) == tf_node_map.end()) {
auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(2));
if (axis_node == nullptr) {
MS_LOG(ERROR) << "Find Gather input axis failed";
return RET_ERROR;
}
auto axis_node = tf_node_map.at(tf_op.input(2));
if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;


+ 66
- 91
mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc View File

@@ -17,7 +17,6 @@

#include "tools/converter/parser/tf/tf_model_parser.h"
#include <functional>
#include <regex>
#include <set>
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
@@ -25,31 +24,11 @@
#include "tools/common/graph_util.h"
#include "tools/common/protobuf_utils.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/optimizer/common/gllo_utils.h"

namespace mindspore {
namespace lite {
namespace {
static const std::vector<schema::PrimitiveType> tensorListOutputOpList = {
schema::PrimitiveType_TensorListFromTensor,
schema::PrimitiveType_TensorListSetItem,
schema::PrimitiveType_TensorListReserve,
};

// subgraph node input may be a:output:0/a:z:0
std::string GetFlattenNodeName(std::string input_name) {
std::regex re("\\:+");
std::vector<std::string> input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1),
std::sregex_token_iterator());
if (input_splits.size() == 3) {
if (input_splits[2] == "0") {
input_name = input_splits[0];
} else {
input_name = input_splits[0] + input_splits[2]; // multi output node
}
}
return input_name;
}
AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map<std::string, AnfNodePtr> &anf_node_map) {
AnfNodePtr ret = nullptr;
if (anf_node_map.find(name) != anf_node_map.end()) {
@@ -67,10 +46,11 @@ std::string GetOriginInputName(const tensorflow::NodeDef &node,
}
auto tmp_node = &node;
while (tmp_node->op() == "Identity" || tmp_node->op() == "StopGradient") {
if (tf_graph_nodes.find(tmp_node->input(0)) == tf_graph_nodes.end()) {
return tmp_node->input(0);
auto flatten_input_name = TensorFlowUtils::GetFlattenNodeName(tmp_node->input(0));
if (tf_graph_nodes.find(flatten_input_name) == tf_graph_nodes.end()) {
return flatten_input_name;
}
tmp_node = tf_graph_nodes.at(tmp_node->input(0));
tmp_node = tf_graph_nodes.at(flatten_input_name);
}
return tmp_node->name();
}
@@ -89,6 +69,10 @@ STATUS TFModelParser::ConvertConstVariant(const tensorflow::TensorProto &tensor_
MS_LOG(ERROR) << "Only TensorList type is supported now";
return RET_NOT_SUPPORT;
}
if (variant.tensors_size() > 0) {
MS_LOG(ERROR) << "Only empty tensorlist is supported now";
return RET_NOT_SUPPORT;
}
auto descriptor = variant.GetMetadata().descriptor;
auto reflection = variant.GetMetadata().reflection;
if (descriptor == nullptr || reflection == nullptr) {
@@ -232,6 +216,7 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value
param_value->set_tensor_type(type);
param_value->set_format(schema::Format::Format_NHWC);
parameter->set_default_param(param_value);
parameter->set_name("const_" + std::to_string(anf_root_node_map.size()) + "_parameter");
return RET_OK;
}

@@ -263,7 +248,8 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa
return status;
}
} else {
graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names
parameter->set_name("placeholder_" + std::to_string(anf_root_node_map.size()));
graph_input_names.emplace_back(parameter->name()); // only root graph need set graph input names
}

auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
@@ -271,12 +257,9 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa
MS_LOG(ERROR) << "abstract_tensor is nullptr";
return RET_ERROR;
}
parameter->set_name(node.name());
parameter->set_abstract(abstract_tensor);

(*anf_node_map)[node.name()] = parameter;
(*anf_node_map)[node.name() + ":0"] = parameter;

return RET_OK;
}

@@ -311,48 +294,43 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
tf_root_graph_ = std::make_unique<tensorflow::GraphDef>();
if (tf_root_graph_ == nullptr) {
MS_LOG(ERROR) << "tf_root_graph_ is nullptr";
tf_root_graph = std::make_unique<tensorflow::GraphDef>();
if (tf_root_graph == nullptr) {
MS_LOG(ERROR) << "tf_root_graph is nullptr";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get());
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "Open modelFile for TF converter failed!";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
anf_root_graph_ = std::make_shared<FuncGraph>();
if (anf_root_graph_ == nullptr) {
anf_root_graph = std::make_shared<FuncGraph>();
if (anf_root_graph == nullptr) {
MS_LOG(ERROR) << "funGraphPtr is nullptr";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}

for (int i = 0; i < tf_root_graph_->node_size(); i++) {
auto &node_def = tf_root_graph_->node(i);
tf_root_graph_nodes_[node_def.name()] = &node_def;
for (int i = 0; i < tf_root_graph->node_size(); i++) {
auto &node_def = tf_root_graph->node(i);
tf_root_graph_nodes[node_def.name()] = &node_def;
}

status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_);
status = ConvertGraphInputsAndConsts(tf_root_graph_nodes, anf_root_graph, &anf_root_node_map);
if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
bool success_flag = true;
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
auto &node_def = tf_root_graph_->node(i);
status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_);
if (status != RET_OK) {
success_flag = false;
for (int i = 0; i < tf_root_graph->node_size(); i++) {
auto &node_def = tf_root_graph->node(i);
if (ConvertOps(node_def, tf_root_graph_nodes, anf_root_graph, &anf_root_node_map) != RET_OK) {
MS_LOG(ERROR) << "Convert ops failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
}
if (!success_flag) {
MS_LOG(ERROR) << "Convert ops failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
status = ConvertRootGraphOutputs();
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert graph outputs failed.";
@@ -367,25 +345,25 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
return nullptr;
}

return anf_root_graph_;
return anf_root_graph;
}
STATUS TFModelParser::ConvertSubgraph() {
auto graph_def_liarary = tf_root_graph_->library();
auto graph_def_liarary = tf_root_graph->library();
auto subgraph_size = graph_def_liarary.function_size();
std::map<CNodePtr, FuncGraphPtr> while_cond_map;
std::map<CNodePtr, FuncGraphPtr> while_body_map;
std::vector<ParameterPtr> sub_graph_inputs;
for (int i = 0; i < subgraph_size; i++) {
std::vector<ParameterPtr> sub_graph_inputs;
auto &tf_sub_fuction = graph_def_liarary.function(i);
auto &tf_sub_signature = tf_sub_fuction.signature();
auto input_arg_size = tf_sub_signature.input_arg_size();

auto &sub_graph_name = tf_sub_signature.name();
if (!function_while_map_.count(sub_graph_name)) {
if (!function_while_map.count(sub_graph_name)) {
MS_LOG(ERROR) << "function map not contains sub graph name." << sub_graph_name;
return RET_ERROR;
}
auto while_cnode = function_while_map_[sub_graph_name]->cast<CNodePtr>();
auto while_cnode = function_while_map[sub_graph_name]->cast<CNodePtr>();
if (while_cnode == nullptr || static_cast<int>(while_cnode->inputs().size()) != input_arg_size + 1) {
MS_LOG(ERROR) << "while cnode not equal input arg size";
return RET_ERROR;
@@ -426,9 +404,16 @@ STATUS TFModelParser::ConvertSubgraph() {
// convert subgraph outputs
std::vector<AnfNodePtr> sub_output_nodes;
auto &subgraph_ret = tf_sub_fuction.ret();
for (auto &t : subgraph_ret) {
MS_LOG(INFO) << "subret " << t.first << " " << t.second;
auto tf_output_name = GetFlattenNodeName(t.second);
auto &output_args = tf_sub_signature.output_arg();
for (auto &output_arg : output_args) {
auto &signature_name = output_arg.name();
if (subgraph_ret.find(signature_name) == subgraph_ret.end()) {
MS_LOG(ERROR) << "can't find signature_name: " << signature_name;
return RET_ERROR;
}
auto t = subgraph_ret.find(signature_name);
MS_LOG(INFO) << "subret " << t->first << " " << t->second;
auto tf_output_name = TensorFlowUtils::GetFlattenNodeName(t->second);
AnfNodePtr anf_node = nullptr;
if (tf_sub_node_map.find(tf_output_name) == tf_sub_node_map.end()) {
anf_node = GetAnfNode(tf_output_name, anf_sub_node_map);
@@ -456,7 +441,7 @@ STATUS TFModelParser::ConvertSubgraph() {
}
// hardcode subgraph inputs name
for (size_t j = 0; j < sub_graph_inputs.size(); j++) {
sub_graph_inputs[j]->set_name("graph" + std::to_string(i) + "_input_" + std::to_string(j) + "parameter");
sub_graph_inputs[j]->set_name("graph_input_" + std::to_string(j) + "parameter");
}
MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name;
}
@@ -473,9 +458,9 @@ STATUS TFModelParser::WhileNodePostProcess(const std::map<CNodePtr, FuncGraphPtr
MS_LOG(ERROR) << "while cond body size error";
return RET_ERROR;
}
std::vector<FuncGraphPtr> roots = {anf_root_graph_};
std::vector<FuncGraphPtr> roots = {anf_root_graph};
auto root_func_manager = std::make_shared<FuncGraphManager>(roots);
anf_root_graph_->set_manager(root_func_manager);
anf_root_graph->set_manager(root_func_manager);
for (auto &kv : while_cond_map) {
auto while_node = kv.first;
auto &cond_sub_graph = kv.second;
@@ -484,12 +469,9 @@ STATUS TFModelParser::WhileNodePostProcess(const std::map<CNodePtr, FuncGraphPtr
body_sub_graph->set_manager(root_func_manager);
auto cond_value_node = NewValueNode(cond_sub_graph);
auto body_value_node = NewValueNode(body_sub_graph);
auto new_while_inputs = while_node->cast<CNodePtr>()->inputs();
new_while_inputs[0] = cond_value_node;
new_while_inputs.insert(new_while_inputs.begin() + 1, body_value_node);
auto new_while_node = anf_root_graph_->NewCNode(new_while_inputs);
new_while_node->set_abstract(while_node->abstract());
root_func_manager->Replace(while_node, new_while_node);
auto inputs = while_node->inputs();
inputs.insert(inputs.begin() + 1, {cond_value_node, body_value_node});
while_node->set_inputs(inputs);
}
return RET_OK;
}
@@ -510,7 +492,7 @@ STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def,
for (size_t j = 0; j < input_names.size(); j++) {
std::string input_name = input_names[j]; // input may be produced by multi-outputs node
// subgraph input name x:output:index,need flatten
auto flatten_input_name = GetFlattenNodeName(input_name);
auto flatten_input_name = TensorFlowUtils::GetFlattenNodeName(input_name);
if (tf_node_map.find(flatten_input_name) != tf_node_map.end()) {
auto input_node = tf_node_map.at(flatten_input_name);
flatten_input_name = GetOriginInputName(*input_node, tf_node_map);
@@ -531,20 +513,10 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
MS_ASSERT(op != nullptr);
MS_ASSERT(anf_node != nullptr);
MS_ASSERT(anf_graph != nullptr);
if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node)) && output_size != 1) {
MS_LOG(ERROR) << "tensorlist output op output_size !=1";
return RET_ERROR;
}
if (output_size == 0) {
return RET_OK;
} else if (output_size == 1) {
auto type = kFloat32;
if (output_size == 1) {
std::vector<int64_t> shape_vector;
if (IsContain(tensorListOutputOpList, opt::GetCNodeType(anf_node))) {
type = TypeIdToType(kObjectTypeTensorType);
}
anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vector));
anf_node_map->insert(std::pair(op.name(), anf_node));
anf_node->set_abstract(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector));
anf_node_map->insert(std::pair(op.name() + ":0", anf_node));
} else {
AbstractBasePtrList abstractList;
for (int output_idx = 0; output_idx < output_size; output_idx++) {
@@ -608,17 +580,17 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
// control_depends are not processed currently
auto anf_node = func_graph_ptr->NewCNode(inputs);
anf_node->set_fullname_with_scope(node_def.name());
if (op_type == "StatelessWhile" || op_type == "while") {
if (op_type == "StatelessWhile" || op_type == "While") {
MS_LOG(INFO) << "find while node:" << node_def.name();
tensorflow::AttrValue attr_value;
if (TensorFlowUtils::FindAttrValue(node_def, "body", &attr_value)) {
auto body_name = attr_value.func().name();
function_while_map_[body_name] = anf_node;
function_while_map[body_name] = anf_node;
MS_LOG(DEBUG) << "parse body name:" << body_name;
}
if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) {
auto cond_name = attr_value.func().name();
function_while_map_[cond_name] = anf_node;
function_while_map[cond_name] = anf_node;
MS_LOG(DEBUG) << "parse cond name:" << cond_name;
}
}
@@ -634,28 +606,31 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,

STATUS TFModelParser::ConvertRootGraphOutputs() {
// because output of intermediate node in anf graph may also be output tensors, we search output tensors in
// tf_root_graph_nodes_ but not anf_root_node_map_
// tf_root_graph_nodes but not anf_root_node_map
std::set<std::string> all_node_inputs;
std::vector<AnfNodePtr> output_nodes;
for (auto &pair : tf_root_graph_nodes_) {
for (auto &pair : tf_root_graph_nodes) {
for (int i = 0; i < pair.second->input_size(); ++i) {
all_node_inputs.insert(pair.second->input(i));
all_node_inputs.insert(TensorFlowUtils::GetNodeName(pair.second->input(i)));
}
}
for (auto &pair : tf_root_graph_nodes_) {
for (auto &pair : tf_root_graph_nodes) {
if (pair.second->op() == "Assert") {
continue;
}
auto it = all_node_inputs.find(pair.first);
if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity
auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_);
auto anf_node = GetAnfNode(origin_name, anf_root_node_map_);
auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes);
auto anf_node = GetAnfNode(origin_name, anf_root_node_map);
if (anf_node == nullptr) {
MS_LOG(ERROR) << "can't find anf node";
return RET_ERROR;
}
output_nodes.push_back(anf_node);
graph_output_names_.push_back(anf_node->fullname_with_scope());
graph_output_names.push_back(anf_node->fullname_with_scope());
}
}
auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_);
auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "make anf graph outputs node error";
return status;


+ 7
- 7
mindspore/lite/tools/converter/parser/tf/tf_model_parser.h View File

@@ -71,13 +71,13 @@ class TFModelParser : public ModelParser {

STATUS MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph);

FuncGraphPtr anf_root_graph_;
std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def
std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map
std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_;
std::vector<std::string> graph_input_names_;
std::vector<std::string> graph_output_names_;
std::map<std::string, AnfNodePtr> function_while_map_; // tf function name->while_node_name
FuncGraphPtr anf_root_graph;
std::unique_ptr<tensorflow::GraphDef> tf_root_graph; // tf root graph def
std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes; // tf root graph node map
std::unordered_map<std::string, AnfNodePtr> anf_root_node_map;
std::vector<std::string> graph_input_names;
std::vector<std::string> graph_output_names;
std::map<std::string, AnfNodePtr> function_while_map; // tf function name->while_node_name
};
} // namespace lite
} // namespace mindspore


+ 16
- 0
mindspore/lite/tools/converter/parser/tf/tf_node_parser.cc View File

@@ -17,6 +17,8 @@
#include <string>
#include <vector>

using tensorflow::NodeDef;

namespace mindspore {
namespace lite {
STATUS TFNodeParser::AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, std::vector<std::string> *inputs) {
@@ -27,5 +29,19 @@ STATUS TFNodeParser::AddOpInput(const tensorflow::NodeDef &tf_op, const int idx,
inputs->push_back(tf_op.input(idx));
return RET_OK;
}

const NodeDef *TFNodeParser::GetConstInputNode(const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
const std::string &input_name) {
auto flatten_input_name = TensorFlowUtils::GetFlattenNodeName(input_name);
if (tf_node_map.find(flatten_input_name) == tf_node_map.end()) {
return nullptr;
}
auto node = tf_node_map.at(flatten_input_name);
if (node->op() != "Const") {
MS_LOG(ERROR) << "Attr node is not Const";
return nullptr;
}
return node;
}
} // namespace lite
} // namespace mindspore

+ 3
- 0
mindspore/lite/tools/converter/parser/tf/tf_node_parser.h View File

@@ -40,6 +40,9 @@ class TFNodeParser {
}

STATUS AddOpInput(const tensorflow::NodeDef &tf_op, const int idx, std::vector<std::string> *inputs);

const tensorflow::NodeDef *GetConstInputNode(const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
const std::string &input_name);
};
} // namespace lite
} // namespace mindspore


+ 2
- 2
mindspore/lite/tools/converter/parser/tf/tf_reduce_parser.cc View File

@@ -69,11 +69,11 @@ STATUS TFReduceParser::Parse(const tensorflow::NodeDef &tf_op,
}
attr->keepDims = attr_value.b();

if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) {
auto axis_node = GetConstInputNode(tf_node_map, tf_op.input(1));
if (axis_node == nullptr) {
MS_LOG(ERROR) << "Find Reduce input axis failed";
return RET_ERROR;
}
auto axis_node = tf_node_map.at(tf_op.input(1));
if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;


+ 70
- 0
mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.cc View File

@@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WRRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_reverse_sequence_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"

namespace mindspore {
namespace lite {
STATUS TFReverseSequenceParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF ReverseSequenceParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}

auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::ReverseSequenceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}

tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "batch_dim", &attr_value)) {
MS_LOG(ERROR) << "The batch_dim attr should be specified";
return RET_ERROR;
}
attr->batchAxis = attr_value.i();
if (!TensorFlowUtils::FindAttrValue(tf_op, "seq_dim", &attr_value)) {
MS_LOG(ERROR) << "The seq_dim attr should be specified";
return RET_ERROR;
}
attr->seqAxis = attr_value.i();

primitive->value.type = schema::PrimitiveType_ReverseSequence;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}

*output_size = 1;
return AddOpInput(tf_op, 0, inputs);
}
TFNodeRegistrar g_tfReverseSequenceParser("ReverseSequence", new TFReverseSequenceParser());
} // namespace lite
} // namespace mindspore

+ 37
- 0
mindspore/lite/tools/converter/parser/tf/tf_reverse_sequence_parser.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REVERSE_SEQUENCE_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REVERSE_SEQUENCE_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"

namespace mindspore {
namespace lite {
class TFReverseSequenceParser : public TFNodeParser {
public:
TFReverseSequenceParser() = default;
~TFReverseSequenceParser() override = default;

STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_REVERSE_SEQUENCE_PARSER_H_

+ 4
- 4
mindspore/lite/tools/converter/parser/tf/tf_split_parser.cc View File

@@ -58,11 +58,11 @@ STATUS TFSplitParser::Parse(const tensorflow::NodeDef &tf_op,
input_index = 0;
}

if (tf_node_map.find(tf_op.input(split_dim_index)) == tf_node_map.end()) {
auto split_dim_node = GetConstInputNode(tf_node_map, tf_op.input(split_dim_index));
if (split_dim_node == nullptr) {
MS_LOG(ERROR) << "Find Split input split_dim node failed";
return RET_ERROR;
}
const auto &split_dim_node = tf_node_map.at(tf_op.input(split_dim_index));
if (!TensorFlowUtils::FindAttrValue(*split_dim_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The attribute splitDim should be specified";
return RET_PARAM_INVALID;
@@ -72,11 +72,11 @@ STATUS TFSplitParser::Parse(const tensorflow::NodeDef &tf_op,
*output_size = attr->numberSplit;

if (tf_op.op() == "SplitV") {
if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) {
auto size_splits_node = GetConstInputNode(tf_node_map, tf_op.input(1));
if (size_splits_node == nullptr) {
MS_LOG(ERROR) << "Find Split input size_splits failed";
return RET_ERROR;
}
auto size_splits_node = tf_node_map.at(tf_op.input(1));
if (!TensorFlowUtils::FindAttrValue(*size_splits_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The attribute size splits should be specified";
return RET_PARAM_INVALID;


+ 6
- 6
mindspore/lite/tools/converter/parser/tf/tf_stride_slice_parser.cc View File

@@ -74,11 +74,11 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op,
attr->shrinkAxisMask = attr_value.i();

// begin
if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) {
auto begin_node = GetConstInputNode(tf_node_map, tf_op.input(1));
if (begin_node == nullptr) {
MS_LOG(ERROR) << "Find StridedSlice input begin failed";
return RET_ERROR;
}
auto begin_node = tf_node_map.at(tf_op.input(1));
if (!TensorFlowUtils::FindAttrValue(*begin_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;
@@ -97,11 +97,11 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op,
}

// end
if (tf_node_map.find(tf_op.input(2)) == tf_node_map.end()) {
auto end_node = GetConstInputNode(tf_node_map, tf_op.input(2));
if (end_node == nullptr) {
MS_LOG(ERROR) << "Find StridedSlice input end failed";
return RET_ERROR;
}
auto end_node = tf_node_map.at(tf_op.input(2));
if (!TensorFlowUtils::FindAttrValue(*end_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;
@@ -120,11 +120,11 @@ STATUS TFStrideSliceParser::Parse(const tensorflow::NodeDef &tf_op,
}

// strides
if (tf_node_map.find(tf_op.input(3)) == tf_node_map.end()) {
auto stride_node = GetConstInputNode(tf_node_map, tf_op.input(3));
if (stride_node == nullptr) {
MS_LOG(ERROR) << "Find StridedSlice input strides failed";
return RET_ERROR;
}
auto stride_node = tf_node_map.at(tf_op.input(3));
if (!TensorFlowUtils::FindAttrValue(*stride_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;


+ 2
- 2
mindspore/lite/tools/converter/parser/tf/tf_tile_parser.cc View File

@@ -42,11 +42,11 @@ STATUS TFTileParser::Parse(const tensorflow::NodeDef &tf_op,
return RET_NULL_PTR;
}

if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) {
auto multiplies_node = GetConstInputNode(tf_node_map, tf_op.input(1));
if (multiplies_node == nullptr) {
MS_LOG(ERROR) << "Find Tile input multiplies failed";
return RET_ERROR;
}
auto multiplies_node = tf_node_map.at(tf_op.input(1));
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(*multiplies_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";


+ 3
- 2
mindspore/lite/tools/converter/parser/tf/tf_transpose_parser.cc View File

@@ -42,11 +42,12 @@ STATUS TFTransposeParser::Parse(const tensorflow::NodeDef &tf_op,
return RET_NULL_PTR;
}

if (tf_node_map.find(tf_op.input(1)) == tf_node_map.end()) {
attr->conjugate = false;
auto perm_node = GetConstInputNode(tf_node_map, tf_op.input(1));
if (perm_node == nullptr) {
MS_LOG(ERROR) << "Find Transpose input perm failed";
return RET_ERROR;
}
auto perm_node = tf_node_map.at(tf_op.input(1));
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(*perm_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";


+ 29
- 0
mindspore/lite/tools/converter/parser/tf/tf_util.cc View File

@@ -16,8 +16,10 @@

#include "tools/converter/parser/tf/tf_util.h"
#include <string>
#include <vector>
#include <string_view>
#include <unordered_map>
#include <regex>
#include "src/common/log_adapter.h"
#include "schema/inner/model_generated.h"

@@ -112,5 +114,32 @@ bool TensorFlowUtils::DecodeInt64(std::string_view *str_view, uint64_t *value) {
return true;
}
}

// convert input_arg in subgraph to node_name[:index] format
std::string TensorFlowUtils::GetFlattenNodeName(const std::string &input_name) {
std::regex re("\\:+");
std::vector<std::string> input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1),
std::sregex_token_iterator());
std::string ret = input_name;
if (input_splits.size() == 3) {
if (input_splits[2] == "0") {
ret = input_splits[0];
} else {
ret = input_splits[0] + input_splits[2]; // multi output node
}
}
return ret;
}

// get referenced node name from input name
std::string TensorFlowUtils::GetNodeName(const std::string &input_name) {
std::regex re("\\:+");
std::vector<std::string> input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1),
std::sregex_token_iterator());
if (input_splits.size() > 1) {
return input_splits[0];
}
return input_name;
}
} // namespace lite
} // namespace mindspore

+ 2
- 0
mindspore/lite/tools/converter/parser/tf/tf_util.h View File

@@ -34,6 +34,8 @@ class TensorFlowUtils {
static TypeId ParseAttrDataType(const tensorflow::NodeDef &node_def, const std::string &attr_name);
static schema::Format ParseNodeFormat(const tensorflow::NodeDef &node_def);
static bool DecodeInt64(std::string_view *str_view, uint64_t *value);
static std::string GetFlattenNodeName(const std::string &input_name);
static std::string GetNodeName(const std::string &input_name);
};
} // namespace lite
} // namespace mindspore


+ 1
- 0
mindspore/lite/tools/optimizer/graph/clip_convert_activation_pass.cc View File

@@ -101,6 +101,7 @@ bool ClipConvertActivationPass::Run(const FuncGraphPtr &graph) {
op_inputs.push_back(clip_cnode->input(1));
auto new_cnode = graph->NewCNode(op_inputs);
new_cnode->set_fullname_with_scope(node->fullname_with_scope());
new_cnode->set_abstract(clip_cnode->abstract()->Clone());
manager->Replace(node, new_cnode);
}
return false;


+ 8
- 11
mindspore/lite/tools/optimizer/graph/infershape_pass.cc View File

@@ -121,7 +121,7 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l
}

if (utils::isa<ValueNodePtr>(cnode->input(i))) {
MS_LOG(ERROR) << "input is value node";
MS_LOG(WARNING) << "input is value node";
continue;
}

@@ -178,13 +178,10 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l
}
} else {
int *data = reinterpret_cast<int *>(param_value->tensor_addr());
auto tensor_list = dynamic_cast<lite::TensorList *>(tensor.get());
tensor_list->set_tensors_data_type(TypeId(data[0]));
std::vector<int> shape;
for (int j = 0; j < data[1]; ++j) {
shape.push_back(data[2 + j]);
auto tensor_list = reinterpret_cast<lite::TensorList *>(tensor.get());
if (tensor_list->Decode(data) != RET_OK) {
return RET_ERROR;
}
tensor_list->set_element_shape(shape);
}
}
}
@@ -210,8 +207,8 @@ STATUS InferShapePass::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(element);
auto typePtr = abstract_tensor->element()->GetTypeTrack();
types.push_back(typePtr->type_id());
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
types.push_back(type_ptr->type_id());
}
} else {
if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
@@ -219,8 +216,8 @@ STATUS InferShapePass::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
auto typePtr = abstract_tensor->element()->GetTypeTrack();
types.push_back(typePtr->type_id());
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
types.push_back(type_ptr->type_id());
}
for (auto &type : types) {
std::unique_ptr<lite::Tensor> output_tensor = nullptr;


Loading…
Cancel
Save