diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 633b284df2..80913a8431 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -116,6 +116,33 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde return RET_OK; } +lite::Tensor *LiteSession::ConvertTensor(const schema::Tensor &src_tensor) { + auto src_category = TensorCategory(&src_tensor); + std::vector shape; + if (src_tensor.dims() == nullptr) { + MS_LOG(DEBUG) << "Dims of src_tensor is nullptr"; + } + if (src_tensor.dims() != nullptr && src_category == Tensor::Category::CONST_TENSOR) { + if (src_tensor.dataType() == kObjectTypeString && src_tensor.data() != nullptr) { + shape.push_back(src_tensor.data()->size()); + } else { + for (size_t j = 0; j < src_tensor.dims()->size(); j++) { + shape.push_back(src_tensor.dims()->data()[j]); + } + } + } + lite::Tensor *dst_tensor = nullptr; + if (TypeId(src_tensor.dataType()) == kObjectTypeTensorType) { + dst_tensor = new (std::nothrow) TensorList(shape, std::vector()); + } else { + dst_tensor = new (std::nothrow) Tensor(TypeId(src_tensor.dataType()), shape, src_tensor.format(), src_category); + } + if (dst_tensor == nullptr) { + return nullptr; + } + return dst_tensor; +} + int LiteSession::ConvertTensors(const lite::Model *model) { MS_ASSERT(model != nullptr); copyed_tensor_idxes_.clear(); @@ -126,24 +153,9 @@ int LiteSession::ConvertTensors(const lite::Model *model) { MS_LOG(ERROR) << i << "th tensor in model is nullptr"; return RET_NULL_PTR; } - auto src_category = TensorCategory(src_tensor); - std::vector shape; - if (src_tensor->dims() == nullptr) { - MS_LOG(DEBUG) << "Dims of " << i << "th tensor is nullptr"; - } - if (src_tensor->dims() != nullptr && src_category == Tensor::Category::CONST_TENSOR) { - if (src_tensor->dataType() == kObjectTypeString && src_tensor->data() != nullptr) { - shape.push_back(src_tensor->data()->size()); - } else { - for (size_t j = 0; j < src_tensor->dims()->size(); j++) { - shape.push_back(src_tensor->dims()->data()[j]); - } - } - } - auto *dst_tensor = - new (std::nothrow) Tensor(TypeId(src_tensor->dataType()), shape, src_tensor->format(), src_category); + auto *dst_tensor = ConvertTensor(*src_tensor); if (dst_tensor == nullptr) { - MS_LOG(ERROR) << "new " << i << "th tensor failed"; + MS_LOG(ERROR) << "Convert new " << i << "th tensor failed!"; return RET_NULL_PTR; } auto ret = ConvertTensorsData(model, i, src_tensor, dst_tensor); diff --git a/mindspore/lite/src/lite_session.h b/mindspore/lite/src/lite_session.h index ab91fcf1f6..55f953d124 100644 --- a/mindspore/lite/src/lite_session.h +++ b/mindspore/lite/src/lite_session.h @@ -30,6 +30,7 @@ #include "schema/model_generated.h" #include "src/executor.h" #include "src/tensor.h" +#include "src/tensorlist.h" #if SUPPORT_GPU #include "src/runtime/opencl/opencl_runtime.h" #endif @@ -71,6 +72,8 @@ class LiteSession : public session::LiteSession { int ConvertTensorsData(const lite::Model *model, size_t tensor_index, const schema::Tensor *src_tensor, lite::Tensor *dst_tensor); + lite::Tensor *ConvertTensor(const schema::Tensor &src_tensor); + int ConvertTensors(const lite::Model *model); void InitGraphInOutTensors(const lite::Model *model);