Browse Source

!9645 [ms][lite][cpu] lite_session add tensorlist logic

From: @lzkcode
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
a86eb3ed74
2 changed files with 32 additions and 17 deletions
  1. +29
    -17
      mindspore/lite/src/lite_session.cc
  2. +3
    -0
      mindspore/lite/src/lite_session.h

+ 29
- 17
mindspore/lite/src/lite_session.cc View File

@@ -116,6 +116,33 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
return RET_OK;
}

lite::Tensor *LiteSession::ConvertTensor(const schema::Tensor &src_tensor) {
auto src_category = TensorCategory(&src_tensor);
std::vector<int> shape;
if (src_tensor.dims() == nullptr) {
MS_LOG(DEBUG) << "Dims of src_tensor is nullptr";
}
if (src_tensor.dims() != nullptr && src_category == Tensor::Category::CONST_TENSOR) {
if (src_tensor.dataType() == kObjectTypeString && src_tensor.data() != nullptr) {
shape.push_back(src_tensor.data()->size());
} else {
for (size_t j = 0; j < src_tensor.dims()->size(); j++) {
shape.push_back(src_tensor.dims()->data()[j]);
}
}
}
lite::Tensor *dst_tensor = nullptr;
if (TypeId(src_tensor.dataType()) == kObjectTypeTensorType) {
dst_tensor = new (std::nothrow) TensorList(shape, std::vector<int>());
} else {
dst_tensor = new (std::nothrow) Tensor(TypeId(src_tensor.dataType()), shape, src_tensor.format(), src_category);
}
if (dst_tensor == nullptr) {
return nullptr;
}
return dst_tensor;
}

int LiteSession::ConvertTensors(const lite::Model *model) {
MS_ASSERT(model != nullptr);
copyed_tensor_idxes_.clear();
@@ -126,24 +153,9 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
MS_LOG(ERROR) << i << "th tensor in model is nullptr";
return RET_NULL_PTR;
}
auto src_category = TensorCategory(src_tensor);
std::vector<int> shape;
if (src_tensor->dims() == nullptr) {
MS_LOG(DEBUG) << "Dims of " << i << "th tensor is nullptr";
}
if (src_tensor->dims() != nullptr && src_category == Tensor::Category::CONST_TENSOR) {
if (src_tensor->dataType() == kObjectTypeString && src_tensor->data() != nullptr) {
shape.push_back(src_tensor->data()->size());
} else {
for (size_t j = 0; j < src_tensor->dims()->size(); j++) {
shape.push_back(src_tensor->dims()->data()[j]);
}
}
}
auto *dst_tensor =
new (std::nothrow) Tensor(TypeId(src_tensor->dataType()), shape, src_tensor->format(), src_category);
auto *dst_tensor = ConvertTensor(*src_tensor);
if (dst_tensor == nullptr) {
MS_LOG(ERROR) << "new " << i << "th tensor failed";
MS_LOG(ERROR) << "Convert new " << i << "th tensor failed!";
return RET_NULL_PTR;
}
auto ret = ConvertTensorsData(model, i, src_tensor, dst_tensor);


+ 3
- 0
mindspore/lite/src/lite_session.h View File

@@ -30,6 +30,7 @@
#include "schema/model_generated.h"
#include "src/executor.h"
#include "src/tensor.h"
#include "src/tensorlist.h"
#if SUPPORT_GPU
#include "src/runtime/opencl/opencl_runtime.h"
#endif
@@ -71,6 +72,8 @@ class LiteSession : public session::LiteSession {
int ConvertTensorsData(const lite::Model *model, size_t tensor_index, const schema::Tensor *src_tensor,
lite::Tensor *dst_tensor);

lite::Tensor *ConvertTensor(const schema::Tensor &src_tensor);

int ConvertTensors(const lite::Model *model);

void InitGraphInOutTensors(const lite::Model *model);


Loading…
Cancel
Save