Browse Source

!9728 [MS][Lite][cpu] tensorlist modify

From: @lzkcode
Reviewed-by: @zhanghaibo5,@zhang_xue_tong
Signed-off-by: @zhang_xue_tong
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
fa286fe787
5 changed files with 29 additions and 1 deletions
  1. +4
    -0
      mindspore/lite/src/ops/tensorlistfromtensor.cc
  2. +8
    -0
      mindspore/lite/src/ops/tensorlistgetitem.cc
  3. +9
    -1
      mindspore/lite/src/ops/tensorlistreserve.cc
  4. +4
    -0
      mindspore/lite/src/ops/tensorlistsetitem.cc
  5. +4
    -0
      mindspore/lite/src/ops/tensorliststack.cc

+ 4
- 0
mindspore/lite/src/ops/tensorlistfromtensor.cc View File

@@ -126,6 +126,10 @@ int TensorListFromTensor::InferShape(std::vector<lite::Tensor *> inputs_, std::v
}
auto input1 = inputs_[1];
MS_ASSERT(input1 != nullptr);
if (input1->data_c() == nullptr) {
MS_LOG(ERROR) << "input1->data_c() is nullptr";
return RET_NULL_PTR;
}
auto ele_shape_ptr = reinterpret_cast<int *>(input1->data_c());
auto output = reinterpret_cast<TensorList *>(outputs_[0]);
MS_ASSERT(output != nullptr);


+ 8
- 0
mindspore/lite/src/ops/tensorlistgetitem.cc View File

@@ -124,6 +124,10 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
MS_LOG(ERROR) << "get_index->ElementsNum():" << get_index->ElementsNum() << " must be equal to 1!";
return RET_ERROR;
}
if (get_index->data_c() == nullptr) {
MS_LOG(ERROR) << "get_index->data_c() is nullptr";
return RET_NULL_PTR;
}
index_ = reinterpret_cast<int *>(get_index->data_c())[0];
if (index_ < 0 || index_ > (input0->ElementsNum() - 1)) {
MS_LOG(ERROR) << "index_:" << index_ << "must in [0, " << input0->ElementsNum() - 1 << "]";
@@ -138,6 +142,10 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
output->set_shape(tensor_index->shape());
} else {
auto input2 = inputs_[2];
if (input2->data_c() == nullptr) {
MS_LOG(ERROR) << "input2->data_c() is nullptr";
return RET_NULL_PTR;
}
auto ele_shape_data = reinterpret_cast<int *>(input2->data_c());
for (int i = 0; i < input2->ElementsNum(); ++i) {
element_shape_.push_back(ele_shape_data[i]);


+ 9
- 1
mindspore/lite/src/ops/tensorlistreserve.cc View File

@@ -104,6 +104,10 @@ int TensorListReserve::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
<< " must be \"kNumberTypeInt\":" << kNumberTypeInt;
return RET_ERROR;
}
if (input0->data_c() == nullptr) {
MS_LOG(ERROR) << "input0->data_c() is nullptr";
return RET_NULL_PTR;
}
auto ele_shape_ptr = reinterpret_cast<int *>(input0->data_c());

auto input1 = inputs_[1];
@@ -117,9 +121,13 @@ int TensorListReserve::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
MS_LOG(ERROR) << "input1->ElementsNum() must be equal to 1";
return RET_ERROR;
}
if (input1->data_c() == nullptr) {
MS_LOG(ERROR) << "input1->data_c() is nullptr";
return RET_NULL_PTR;
}
int num_elements = reinterpret_cast<int *>(input1->data_c())[0];

auto output = reinterpret_cast<TensorList *>(outputs_[0]);
MS_ASSERT(output != nullptr);
output->set_data_type(kObjectTypeTensorType);
std::vector<std::vector<int> > tmp_shape(num_elements, std::vector<int>());
output->set_element_shape(std::vector<int>(ele_shape_ptr, ele_shape_ptr + input0->ElementsNum()));


+ 4
- 0
mindspore/lite/src/ops/tensorlistsetitem.cc View File

@@ -106,6 +106,10 @@ int TensorListSetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
MS_LOG(ERROR) << "inputs_[1].ElementsNum():" << get_index->ElementsNum() << " must be equal to 1!";
return RET_ERROR;
}
if (get_index->data_c() == nullptr) {
MS_LOG(ERROR) << "get_index->data_c() is nullptr";
return RET_NULL_PTR;
}
int index = reinterpret_cast<int *>(get_index->data_c())[0];
if (index < 0 || index > (input0->ElementsNum() - 1)) {
MS_LOG(ERROR) << "index_:" << index << "must in [0, " << input0->ElementsNum() - 1 << "]";


+ 4
- 0
mindspore/lite/src/ops/tensorliststack.cc View File

@@ -125,6 +125,10 @@ int TensorListStack::InferShape(std::vector<lite::Tensor *> inputs_, std::vector
}
auto ele_shape = inputs_[1]; // element shape
MS_ASSERT(ele_shape != nullptr);
if (ele_shape->data_c() == nullptr) {
MS_LOG(ERROR) << "ele_shape->data_c() is nullptr";
return RET_NULL_PTR;
}
auto ele_shape_ptr = reinterpret_cast<int *>(ele_shape->data_c());
for (int i = 0; ele_shape->ElementsNum(); ++i) {
output_shape_.push_back(ele_shape_ptr[i]);


Loading…
Cancel
Save