|
|
|
@@ -104,6 +104,10 @@ int TensorListReserve::InferShape(std::vector<lite::Tensor *> inputs_, std::vect |
|
|
|
<< " must be \"kNumberTypeInt\":" << kNumberTypeInt; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (input0->data_c() == nullptr) { |
|
|
|
MS_LOG(ERROR) << "input0->data_c() is nullptr"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
auto ele_shape_ptr = reinterpret_cast<int *>(input0->data_c()); |
|
|
|
|
|
|
|
auto input1 = inputs_[1]; |
|
|
|
@@ -117,9 +121,13 @@ int TensorListReserve::InferShape(std::vector<lite::Tensor *> inputs_, std::vect |
|
|
|
MS_LOG(ERROR) << "input1->ElementsNum() must be equal to 1"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (input1->data_c() == nullptr) { |
|
|
|
MS_LOG(ERROR) << "input1->data_c() is nullptr"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
int num_elements = reinterpret_cast<int *>(input1->data_c())[0]; |
|
|
|
|
|
|
|
auto output = reinterpret_cast<TensorList *>(outputs_[0]); |
|
|
|
MS_ASSERT(output != nullptr); |
|
|
|
output->set_data_type(kObjectTypeTensorType); |
|
|
|
std::vector<std::vector<int> > tmp_shape(num_elements, std::vector<int>()); |
|
|
|
output->set_element_shape(std::vector<int>(ele_shape_ptr, ele_shape_ptr + input0->ElementsNum())); |
|
|
|
|