From c525b8503681723eae1f469c8b835254c02d3670 Mon Sep 17 00:00:00 2001
From: lzk
Date: Fri, 18 Dec 2020 01:38:17 -0800
Subject: [PATCH] stack modify

---
 mindspore/lite/src/ops/tensorlist_stack.cc    |  7 ++-
 .../arm/fp32/tensorlist_setitem_fp32.cc       |  2 +-
 .../kernel/arm/fp32/tensorlist_stack_fp32.cc  | 47 +++++++++----------
 3 files changed, 26 insertions(+), 30 deletions(-)

diff --git a/mindspore/lite/src/ops/tensorlist_stack.cc b/mindspore/lite/src/ops/tensorlist_stack.cc
index 3c162e0c2a..5e91be6ee1 100644
--- a/mindspore/lite/src/ops/tensorlist_stack.cc
+++ b/mindspore/lite/src/ops/tensorlist_stack.cc
@@ -95,7 +95,7 @@ int TensorListStack::UnPackToFlatBuilder(const schema::Primitive *primitive, fla
     MS_LOG(ERROR) << "value_as_TensorListStack return nullptr";
     return RET_ERROR;
   }
-  auto val_offset = schema::CreateTensorListStack(*fbb, attr->elementDType(), attr->numElements());
+  auto val_offset = schema::CreateTensorListStack(*fbb, attr->numElements(), attr->elementDType());
   auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_TensorListStack, val_offset.o);
   fbb->Finish(prim_offset);
   return RET_OK;
@@ -159,9 +159,8 @@ int TensorListStack::InferShape(std::vector inputs_, std::vector
   auto output = outputs_.front();
   MS_ASSERT(output != nullptr);
   output->set_data_type(input0->tensors_data_type());
-  output->set_shape(std::vector(
-    1,
-    input0->ElementsNum() * std::accumulate(output_shape_.begin(), output_shape_.end(), 1LL, std::multiplies())));
+  output_shape_.insert(output_shape_.begin(), input0->ElementsNum());
+  output->set_shape(output_shape_);
   return RET_OK;
 }
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc
index fcaa460f3e..7ce204a95d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.cc
@@ -32,7 +32,7 @@ int TensorListSetItemCPUKernel::Init() { return RET_OK; }
 
 int TensorListSetItemCPUKernel::Run() {
   input0_ = reinterpret_cast<lite::TensorList *>(in_tensors_[0]);
-  if (dtype_ != input0_->data_type()) {
+  if (dtype_ != input0_->tensors_data_type()) {
     MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type();
     return RET_ERROR;
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_stack_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_stack_fp32.cc
index ec04883bc9..1f7453616a 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_stack_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_stack_fp32.cc
@@ -42,6 +42,16 @@ int TensorListStackCPUKernel::CheckParam() {
     return RET_ERROR;
   }
   num_element_ = input0_->ElementsNum();
+  if (output0_->shape().size() < 1) {
+    MS_LOG(ERROR) << "out_tensors_[0].shape().size():" << output0_->shape().size()
+                  << " must be greater than or equal to 1!";
+    return RET_ERROR;
+  }
+  int dim0 = output0_->shape()[0];
+  if (dim0 != num_element_) {
+    MS_LOG(ERROR) << "out_tensors_[0].shape()[0] must be:" << num_element_ << ", but now is:" << dim0;
+    return RET_ERROR;
+  }
   return RET_OK;
 }
 
@@ -50,16 +60,7 @@ int TensorListStackCPUKernel::Init() {
   MS_ASSERT(input0_ != nullptr);
   output0_ = out_tensors_[0];
   MS_ASSERT(output0_ != nullptr);
-  if (output0_->shape().size() != 2) {
-    MS_LOG(ERROR) << "out_tensors_[0].shape().size():" << output0_->shape().size() << " must be equal to 2!";
-    return RET_ERROR;
-  }
-  int dim0 = output0_->shape()[0];
-  if (dim0 != 1) {  // dim0 must be 1
-    MS_LOG(ERROR) << "out_tensors_[0].shape()[0] must be 1, but now is:" << dim0;
-    return RET_ERROR;
-  }
-  return CheckParam();
+  return RET_OK;
 }
 
 bool TensorListStackCPUKernel::IsFullyDefined(const std::vector<int> &shape) const {
@@ -129,26 +130,22 @@ int TensorListStackCPUKernel::MergeSubShape(const std::vector &shape) {
 }
 
 int TensorListStackCPUKernel::Run() {
+  if (CheckParam() != RET_OK) {
+    MS_LOG(ERROR) << "CheckParam failed!";
+    return RET_ERROR;
+  }
   if (output0_->ElementsNum() == 0) {
     return RET_OK;
   }
-  size_t in_ele_num = 0;
-  for (int i = 0; i < num_element_; ++i) {
-    auto tensor = input0_->GetTensorIndex(i);
-    MS_ASSERT(tensor != nullptr);
-    if (tensor->data_type() == kTypeUnknown) {
-      if (TypeUnknownSize == 0) {
-        TypeUnknownSize = MergeElementShape();
-      }
-      in_ele_num += TypeUnknownSize;
-    } else {
-      in_ele_num += std::accumulate(tensor->shape().begin(), tensor->shape().end(), 1LL, std::multiplies());
-    }
+  auto ret = MergeElementShape();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "MergeElementShape failed!";
+    return RET_ERROR;
   }
+  size_t in_ele_num = num_element_ * TypeUnknownSize;
   size_t out_ele_num = output0_->ElementsNum();
-  if (in_ele_num > out_ele_num) {
-    MS_LOG(ERROR) << "out_tensors_[0]->ElementsNum():" << out_ele_num
-                  << "must be greater than or equal to in_ele_num:" << in_ele_num;
+  if (in_ele_num != out_ele_num) {
+    MS_LOG(ERROR) << "out_tensors_[0]->ElementsNum():" << out_ele_num << "must be equal to in_ele_num:" << in_ele_num;
     return RET_ERROR;
   }
   auto out_ptr = reinterpret_cast(output0_->MutableData());