From 2ad7c7747b74de4b1766fa84608089fc658fd45a Mon Sep 17 00:00:00 2001 From: lzk Date: Wed, 9 Dec 2020 01:15:49 -0800 Subject: [PATCH] tensorlist some modify --- mindspore/lite/src/ops/tensorlistfromtensor.cc | 4 ++++ mindspore/lite/src/ops/tensorlistgetitem.cc | 8 ++++++++ mindspore/lite/src/ops/tensorlistreserve.cc | 10 +++++++++- mindspore/lite/src/ops/tensorlistsetitem.cc | 4 ++++ mindspore/lite/src/ops/tensorliststack.cc | 4 ++++ 5 files changed, 29 insertions(+), 1 deletion(-) diff --git a/mindspore/lite/src/ops/tensorlistfromtensor.cc b/mindspore/lite/src/ops/tensorlistfromtensor.cc index f3f70cf46a..490b975cbd 100644 --- a/mindspore/lite/src/ops/tensorlistfromtensor.cc +++ b/mindspore/lite/src/ops/tensorlistfromtensor.cc @@ -126,6 +126,10 @@ int TensorListFromTensor::InferShape(std::vector inputs_, std::v } auto input1 = inputs_[1]; MS_ASSERT(input1 != nullptr); + if (input1->data_c() == nullptr) { + MS_LOG(ERROR) << "input1->data_c() is nullptr"; + return RET_NULL_PTR; + } auto ele_shape_ptr = reinterpret_cast(input1->data_c()); auto output = reinterpret_cast(outputs_[0]); MS_ASSERT(output != nullptr); diff --git a/mindspore/lite/src/ops/tensorlistgetitem.cc b/mindspore/lite/src/ops/tensorlistgetitem.cc index 129b38bf4d..1f68c49975 100644 --- a/mindspore/lite/src/ops/tensorlistgetitem.cc +++ b/mindspore/lite/src/ops/tensorlistgetitem.cc @@ -124,6 +124,10 @@ int TensorListGetItem::InferShape(std::vector inputs_, std::vect MS_LOG(ERROR) << "get_index->ElementsNum():" << get_index->ElementsNum() << " must be equal to 1!"; return RET_ERROR; } + if (get_index->data_c() == nullptr) { + MS_LOG(ERROR) << "get_index->data_c() is nullptr"; + return RET_NULL_PTR; + } index_ = reinterpret_cast(get_index->data_c())[0]; if (index_ < 0 || index_ > (input0->ElementsNum() - 1)) { MS_LOG(ERROR) << "index_:" << index_ << "must in [0, " << input0->ElementsNum() - 1 << "]"; @@ -138,6 +142,10 @@ int TensorListGetItem::InferShape(std::vector inputs_, std::vect output->set_shape(tensor_index->shape()); } else { auto input2 = inputs_[2]; + if (input2->data_c() == nullptr) { + MS_LOG(ERROR) << "input2->data_c() is nullptr"; + return RET_NULL_PTR; + } auto ele_shape_data = reinterpret_cast(input2->data_c()); for (int i = 0; i < input2->ElementsNum(); ++i) { element_shape_.push_back(ele_shape_data[i]); diff --git a/mindspore/lite/src/ops/tensorlistreserve.cc b/mindspore/lite/src/ops/tensorlistreserve.cc index 63b5ff44e9..88c77254ae 100644 --- a/mindspore/lite/src/ops/tensorlistreserve.cc +++ b/mindspore/lite/src/ops/tensorlistreserve.cc @@ -104,6 +104,10 @@ int TensorListReserve::InferShape(std::vector inputs_, std::vect << " must be \"kNumberTypeInt\":" << kNumberTypeInt; return RET_ERROR; } + if (input0->data_c() == nullptr) { + MS_LOG(ERROR) << "input0->data_c() is nullptr"; + return RET_NULL_PTR; + } auto ele_shape_ptr = reinterpret_cast(input0->data_c()); auto input1 = inputs_[1]; @@ -117,9 +121,13 @@ int TensorListReserve::InferShape(std::vector inputs_, std::vect MS_LOG(ERROR) << "input1->ElementsNum() must be equal to 1"; return RET_ERROR; } + if (input1->data_c() == nullptr) { + MS_LOG(ERROR) << "input1->data_c() is nullptr"; + return RET_NULL_PTR; + } int num_elements = reinterpret_cast(input1->data_c())[0]; - auto output = reinterpret_cast(outputs_[0]); + MS_ASSERT(output != nullptr); output->set_data_type(kObjectTypeTensorType); std::vector > tmp_shape(num_elements, std::vector()); output->set_element_shape(std::vector(ele_shape_ptr, ele_shape_ptr + input0->ElementsNum())); diff --git a/mindspore/lite/src/ops/tensorlistsetitem.cc b/mindspore/lite/src/ops/tensorlistsetitem.cc index 24f492b9df..3c256c661d 100644 --- a/mindspore/lite/src/ops/tensorlistsetitem.cc +++ b/mindspore/lite/src/ops/tensorlistsetitem.cc @@ -106,6 +106,10 @@ int TensorListSetItem::InferShape(std::vector inputs_, std::vect MS_LOG(ERROR) << "inputs_[1].ElementsNum():" << get_index->ElementsNum() << " must be equal to 1!"; return RET_ERROR; } + if (get_index->data_c() == nullptr) { + MS_LOG(ERROR) << "get_index->data_c() is nullptr"; + return RET_NULL_PTR; + } int index = reinterpret_cast(get_index->data_c())[0]; if (index < 0 || index > (input0->ElementsNum() - 1)) { MS_LOG(ERROR) << "index_:" << index << "must in [0, " << input0->ElementsNum() - 1 << "]"; diff --git a/mindspore/lite/src/ops/tensorliststack.cc b/mindspore/lite/src/ops/tensorliststack.cc index db6a3e91e8..00d564f886 100644 --- a/mindspore/lite/src/ops/tensorliststack.cc +++ b/mindspore/lite/src/ops/tensorliststack.cc @@ -125,6 +125,10 @@ int TensorListStack::InferShape(std::vector inputs_, std::vector } auto ele_shape = inputs_[1]; // element shape MS_ASSERT(ele_shape != nullptr); + if (ele_shape->data_c() == nullptr) { + MS_LOG(ERROR) << "ele_shape->data_c() is nullptr"; + return RET_NULL_PTR; + } auto ele_shape_ptr = reinterpret_cast(ele_shape->data_c()); for (int i = 0; ele_shape->ElementsNum(); ++i) { output_shape_.push_back(ele_shape_ptr[i]);