Browse Source

!9875 support parsing TF const variant & make InferShapePass in anfTransform support TensorList

From: @wangzhe128
Reviewed-by: @hangangqiang,@ddwsky
Signed-off-by: @hangangqiang
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
b64d2b273f
9 changed files with 189 additions and 33 deletions
  1. +4
    -5
      mindspore/lite/src/ops/tensorlistreserve.cc
  2. +2
    -3
      mindspore/lite/src/ops/tensorlistsetitem.cc
  3. +2
    -3
      mindspore/lite/src/tensorlist.cc
  4. +79
    -0
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
  5. +2
    -0
      mindspore/lite/tools/converter/parser/tf/tf_model_parser.h
  6. +35
    -1
      mindspore/lite/tools/converter/parser/tf/tf_util.cc
  7. +2
    -0
      mindspore/lite/tools/converter/parser/tf/tf_util.h
  8. +62
    -21
      mindspore/lite/tools/optimizer/graph/infershape_pass.cc
  9. +1
    -0
      mindspore/lite/tools/optimizer/graph/infershape_pass.h

+ 4
- 5
mindspore/lite/src/ops/tensorlistreserve.cc View File

@@ -99,9 +99,8 @@ int TensorListReserve::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
auto input0 = inputs_.front();
MS_ASSERT(input0 != nullptr);
auto ele_shape_type = input0->data_type();
if (ele_shape_type != kNumberTypeInt) {
MS_LOG(ERROR) << "ele_shape_tensor.data_type():" << ele_shape_type
<< " must be \"kNumberTypeInt\":" << kNumberTypeInt;
if (ele_shape_type != kNumberTypeInt && ele_shape_type != kNumberTypeInt32) {
MS_LOG(ERROR) << "ele_shape_tensor.data_type():" << ele_shape_type << " is not int";
return RET_ERROR;
}
if (input0->data_c() == nullptr) {
@@ -113,8 +112,8 @@ int TensorListReserve::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
auto input1 = inputs_[1];
MS_ASSERT(input1 != nullptr);
auto num_ele_type = input1->data_type();
if (num_ele_type != kNumberTypeInt) {
MS_LOG(ERROR) << "num_ele_tensor.data_type():" << num_ele_type << " must be \"kNumberTypeInt\":" << kNumberTypeInt;
if (num_ele_type != kNumberTypeInt && num_ele_type != kNumberTypeInt32) {
MS_LOG(ERROR) << "num_ele_tensor.data_type():" << num_ele_type << " is not int";
return RET_ERROR;
}
if (input1->ElementsNum() != 1) {


+ 2
- 3
mindspore/lite/src/ops/tensorlistsetitem.cc View File

@@ -97,9 +97,8 @@ int TensorListSetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
MS_ASSERT(input0 != nullptr);
auto get_index = inputs_[1];
MS_ASSERT(get_index != nullptr);
if (get_index->data_type() != kNumberTypeInt) {
MS_LOG(ERROR) << "inputs_[1]->data_type():" << get_index->data_type()
<< " must be equal to \"kNumberTypeInt\":" << kNumberTypeInt;
if (get_index->data_type() != kNumberTypeInt && get_index->data_type() != kNumberTypeInt32) {
MS_LOG(ERROR) << "inputs_[1]->data_type():" << get_index->data_type() << " is not int";
return RET_ERROR;
}
if (get_index->ElementsNum() != 1) {


+ 2
- 3
mindspore/lite/src/tensorlist.cc View File

@@ -228,9 +228,8 @@ bool TensorList::IsCompatibleShape(const Tensor *src) {
if (static_cast<size_t>(src->ElementsNum()) != this->element_shape_.size()) {
return false;
}
if (src->data_type() != kNumberTypeInt) {
MS_LOG(ERROR) << "src tensor data_type:" << src->data_type()
<< " must be equal to \"kNumberTypeInt\":" << kNumberTypeInt;
if (src->data_type() != kNumberTypeInt && src->data_type() != kNumberTypeInt32) {
MS_LOG(ERROR) << "src tensor data_type:" << src->data_type() << " is not int";
return false;
}
auto src_ptr = reinterpret_cast<int *>(src->data_c());


+ 79
- 0
mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc View File

@@ -76,6 +76,80 @@ std::string GetOriginInputName(const tensorflow::NodeDef &node,
}
} // namespace

STATUS TFModelParser::ConvertConstVariant(const tensorflow::TensorProto &tensor_proto,
const ParamValueLitePtr &param_value) {
MS_ASSERT(param_value != nullptr);
auto variant_size = tensor_proto.variant_val_size();
if (variant_size != 1) {
MS_LOG(ERROR) << "only support variant_val_size == 1 now";
return RET_ERROR;
}
auto &variant = tensor_proto.variant_val(0);
if (variant.type_name() != "tensorflow::TensorList") {
MS_LOG(ERROR) << "Only TensorList type is supported now";
return RET_NOT_SUPPORT;
}
auto descriptor = variant.GetMetadata().descriptor;
auto reflection = variant.GetMetadata().reflection;
if (descriptor == nullptr || reflection == nullptr) {
MS_LOG(ERROR) << "descriptor or reflection is nullptr";
return RET_ERROR;
}
auto field_descriptor = descriptor->field(1);
if (field_descriptor == nullptr) {
MS_LOG(ERROR) << "field_descriptor is nullptr";
return RET_ERROR;
}
auto type = field_descriptor->type();
if (type != google::protobuf::FieldDescriptor::TYPE_BYTES) {
MS_LOG(ERROR) << "metadata type is not TYPE_BYTES";
return RET_ERROR;
}
auto str = reflection->GetString(variant, field_descriptor);
std::string_view str_view(str);
uint64_t scratch;
if (!TensorFlowUtils::DecodeInt64(&str_view, &scratch)) {
return RET_ERROR;
}
size_t num_invalid_tensors = static_cast<size_t>(scratch);
for (size_t i = 0; i < num_invalid_tensors; ++i) {
if (!TensorFlowUtils::DecodeInt64(&str_view, &scratch)) {
return RET_ERROR;
}
}
if (!TensorFlowUtils::DecodeInt64(&str_view, &scratch)) {
return RET_ERROR;
}
size_t element_dtype = static_cast<size_t>(scratch);
if (!TensorFlowUtils::DecodeInt64(&str_view, &scratch)) {
return RET_ERROR;
}
std::string element_shape_str = std::string(str_view.data(), str_view.size());
tensorflow::TensorShapeProto element_shape_proto;
element_shape_proto.ParseFromString(element_shape_str);
auto dim_size = element_shape_proto.dim_size();
// we encode element_dtype,shape.size,shape[i]... into data
auto tensor_data = new (std::nothrow) int[dim_size + 2];
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "tensor_data is nullptr";
return RET_ERROR;
}
tensor_data[0] = TensorFlowUtils::GetTFDataType(tensorflow::DataType(element_dtype));
tensor_data[1] = element_shape_proto.dim_size();
for (int i = 0; i < dim_size; ++i) {
auto dim = element_shape_proto.dim(i).size();
if (dim > static_cast<int64_t>(INT32_MAX) || dim < static_cast<int64_t>(INT32_MIN)) {
MS_LOG(ERROR) << "int64 data " << dim << " is out of int32 range";
delete[] tensor_data;
return RET_ERROR;
} else {
tensor_data[i + 2] = static_cast<int>(dim);
}
}
param_value->SetTensorData(tensor_data, (dim_size + 2) * sizeof(int));
return RET_OK;
}

STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type,
const ParameterPtr &parameter, std::vector<int64_t> *shape_vector) {
MS_ASSERT(parameter != nullptr);
@@ -143,6 +217,11 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value
}
tensor_size = shape_size * sizeof(int);
param_value->SetTensorData(tensor_data, tensor_size);
} else if (type == kObjectTypeTensorType) {
auto status = ConvertConstVariant(tensor_proto, param_value);
if (status != RET_OK) {
return status;
}
} else {
MS_LOG(ERROR) << "Unsupport dataType: " << type;
return RET_ERROR;


+ 2
- 0
mindspore/lite/tools/converter/parser/tf/tf_model_parser.h View File

@@ -28,6 +28,7 @@
#include "securec/include/securec.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/model_parser.h"
#include "mindspore/lite/src/param_value_lite.h"

namespace mindspore {
namespace lite {
@@ -43,6 +44,7 @@ class TFModelParser : public ModelParser {
const QuantType &quantType = QuantType_QUANT_NONE) override;

private:
STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, const ParamValueLitePtr &param_value);
STATUS ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, const ParameterPtr &parameter,
std::vector<int64_t> *shape_vector);
STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr &parameter,


+ 35
- 1
mindspore/lite/tools/converter/parser/tf/tf_util.cc View File

@@ -16,6 +16,7 @@

#include "tools/converter/parser/tf/tf_util.h"
#include <string>
#include <string_view>
#include <unordered_map>
#include "src/common/log_adapter.h"
#include "schema/inner/model_generated.h"
@@ -27,7 +28,7 @@ static const std::unordered_map<int, mindspore::TypeId> TF_TYPE_MAP = {
{tensorflow::DT_UINT8, mindspore::kNumberTypeUInt8},
{tensorflow::DT_INT16, mindspore::kNumberTypeInt16},
{tensorflow::DT_UINT16, mindspore::kNumberTypeUInt16},
{tensorflow::DT_INT32, mindspore::kNumberTypeInt32},
{tensorflow::DT_INT32, mindspore::kNumberTypeInt},
{tensorflow::DT_INT64, mindspore::kNumberTypeInt64},
{tensorflow::DT_HALF, mindspore::kNumberTypeFloat16},
{tensorflow::DT_FLOAT, mindspore::kNumberTypeFloat32},
@@ -65,6 +66,7 @@ TypeId TensorFlowUtils::ParseAttrDataType(const tensorflow::NodeDef &node_def, c
}
return GetTFDataType(attr_value.type());
}

schema::Format TensorFlowUtils::ParseNodeFormat(const tensorflow::NodeDef &node_def) {
tensorflow::AttrValue attr_value;
if (!FindAttrValue(node_def, "data_format", &attr_value)) {
@@ -78,5 +80,37 @@ schema::Format TensorFlowUtils::ParseNodeFormat(const tensorflow::NodeDef &node_
}
return schema::Format_NUM_OF_FORMAT;
}

bool TensorFlowUtils::DecodeInt64(std::string_view *str_view, uint64_t *value) {
if (str_view == nullptr || value == nullptr) {
*value = 0;
MS_LOG(ERROR) << "str_view or value is nullptr";
return false;
}
auto data = str_view->data();
const auto end = data + str_view->size();

const char *next = nullptr;
uint64_t result = 0;
for (uint32_t shift = 0; shift <= 63 && data < end; shift += 7) {
uint64_t byte = *(reinterpret_cast<const unsigned char *>(data));
data++;
if (byte & 128) {
result |= ((byte & 127) << shift);
} else {
result |= (byte << shift);
*value = result;
next = reinterpret_cast<const char *>(data);
break;
}
}

if (next == nullptr) {
return false;
} else {
*str_view = std::string_view(next, end - next);
return true;
}
}
} // namespace lite
} // namespace mindspore

+ 2
- 0
mindspore/lite/tools/converter/parser/tf/tf_util.h View File

@@ -18,6 +18,7 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_UTIL_H

#include <string>
#include <string_view>
#include "proto/node_def.pb.h"
#include "ir/dtype/type_id.h"
#include "include/errorcode.h"
@@ -32,6 +33,7 @@ class TensorFlowUtils {
tensorflow::AttrValue *attr_value);
static TypeId ParseAttrDataType(const tensorflow::NodeDef &node_def, const std::string &attr_name);
static schema::Format ParseNodeFormat(const tensorflow::NodeDef &node_def);
static bool DecodeInt64(std::string_view *str_view, uint64_t *value);
};
} // namespace lite
} // namespace mindspore


+ 62
- 21
mindspore/lite/tools/optimizer/graph/infershape_pass.cc View File

@@ -119,11 +119,6 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l
MS_LOG(ERROR) << "input is nullptr";
return RET_ERROR;
}
auto tensor = std::make_unique<lite::Tensor>();
if (tensor == nullptr) {
MS_LOG(ERROR) << "new input tensor failed";
return RET_ERROR;
}

if (utils::isa<ValueNodePtr>(cnode->input(i))) {
MS_LOG(ERROR) << "input is value node";
@@ -149,23 +144,47 @@ STATUS InferShapePass::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l
MS_LOG(ERROR) << "ParamValueLite of abstract is nullptr";
return RET_ERROR;
}
tensor->set_shape(param_value_lite->tensor_shape());
tensor->set_data_type(param_value_lite->tensor_type());
tensor->set_format(schema::Format(param_value_lite->format()));

std::unique_ptr<lite::Tensor> tensor = nullptr;
if (param_value_lite->tensor_type() != kObjectTypeTensorType) {
tensor = std::make_unique<lite::Tensor>();
} else {
tensor = std::make_unique<lite::TensorList>();
}
if (tensor == nullptr) {
MS_LOG(ERROR) << "new input tensor failed";
return RET_ERROR;
}
if (param_value_lite->tensor_type() != kObjectTypeTensorType) {
tensor->set_shape(param_value_lite->tensor_shape());
tensor->set_data_type(param_value_lite->tensor_type());
tensor->set_format(schema::Format(param_value_lite->format()));
}

if (utils::isa<ParameterPtr>(input)) {
auto parameter = input->cast<ParameterPtr>();
if (parameter->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
auto ret = tensor->MallocData();
if (ret != 0) {
MS_LOG(ERROR) << "Malloc tensor data failed";
return RET_ERROR;
}
ret = memcpy_s(tensor->MutableData(), tensor->Size(), param_value->tensor_addr(), param_value->tensor_size());
if (tensor->Size() != 0 && ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
if (param_value_lite->tensor_type() != kObjectTypeTensorType) {
auto ret = tensor->MallocData();
if (ret != 0) {
MS_LOG(ERROR) << "Malloc tensor data failed";
return RET_ERROR;
}
ret = memcpy_s(tensor->MutableData(), tensor->Size(), param_value->tensor_addr(), param_value->tensor_size());
if (tensor->Size() != 0 && ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
}
} else {
int *data = reinterpret_cast<int *>(param_value->tensor_addr());
auto tensor_list = dynamic_cast<lite::TensorList *>(tensor.get());
tensor_list->set_tensors_data_type(TypeId(data[0]));
std::vector<int> shape;
for (int j = 0; j < data[1]; ++j) {
shape.push_back(data[2 + j]);
}
tensor_list->set_element_shape(shape);
}
}
}
@@ -181,13 +200,35 @@ STATUS InferShapePass::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<
MS_LOG(ERROR) << "abstract is nullptr";
return RET_ERROR;
}
size_t num_outputs = 1;
std::vector<TypeId> types;
if (utils::isa<abstract::AbstractTuple>(abstract)) {
auto abstract_tuple = abstract->cast<abstract::AbstractTuplePtr>();
num_outputs = abstract_tuple->size();
auto elements = abstract_tuple->elements();
for (auto &element : elements) {
if (!utils::isa<abstract::AbstractTensorPtr>(element)) {
MS_LOG(ERROR) << "abstract is not AbstractTensor";
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(element);
auto typePtr = abstract_tensor->element()->GetTypeTrack();
types.push_back(typePtr->type_id());
}
} else {
if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
MS_LOG(ERROR) << "abstract is not AbstractTensor";
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
auto typePtr = abstract_tensor->element()->GetTypeTrack();
types.push_back(typePtr->type_id());
}
for (size_t i = 0; i < num_outputs; ++i) {
auto output_tensor = std::make_unique<lite::Tensor>();
for (auto &type : types) {
std::unique_ptr<lite::Tensor> output_tensor = nullptr;
if (type == kObjectTypeTensorType) {
output_tensor = std::make_unique<lite::TensorList>();
} else {
output_tensor = std::make_unique<lite::Tensor>();
}
if (output_tensor == nullptr) {
MS_LOG(ERROR) << "new output tensor failed";
return RET_ERROR;


+ 1
- 0
mindspore/lite/tools/optimizer/graph/infershape_pass.h View File

@@ -22,6 +22,7 @@
#include "tools/optimizer/common/gllo_utils.h"
#include "backend/optimizer/common/pass.h"
#include "mindspore/lite/src/tensor.h"
#include "mindspore/lite/src/tensorlist.h"
#include "mindspore/lite/include/errorcode.h"
using mindspore::lite::STATUS;
using mindspore::lite::converter::FmkType;


Loading…
Cancel
Save