|
|
@@ -42,6 +42,16 @@ int TensorListStackCPUKernel::CheckParam() { |
|
|
return RET_ERROR; |
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
num_element_ = input0_->ElementsNum(); |
|
|
num_element_ = input0_->ElementsNum(); |
|
|
|
|
|
if (output0_->shape().size() < 1) { |
|
|
|
|
|
MS_LOG(ERROR) << "out_tensors_[0].shape().size():" << output0_->shape().size() |
|
|
|
|
|
<< " must be greater than or equal to 1!"; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
int dim0 = output0_->shape()[0]; |
|
|
|
|
|
if (dim0 != num_element_) { |
|
|
|
|
|
MS_LOG(ERROR) << "out_tensors_[0].shape()[0] must be:" << num_element_ << ", but now is:" << dim0; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
return RET_OK; |
|
|
return RET_OK; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -50,16 +60,7 @@ int TensorListStackCPUKernel::Init() { |
|
|
MS_ASSERT(input0_ != nullptr); |
|
|
MS_ASSERT(input0_ != nullptr); |
|
|
output0_ = out_tensors_[0]; |
|
|
output0_ = out_tensors_[0]; |
|
|
MS_ASSERT(output0_ != nullptr); |
|
|
MS_ASSERT(output0_ != nullptr); |
|
|
if (output0_->shape().size() != 2) { |
|
|
|
|
|
MS_LOG(ERROR) << "out_tensors_[0].shape().size():" << output0_->shape().size() << " must be equal to 2!"; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
int dim0 = output0_->shape()[0]; |
|
|
|
|
|
if (dim0 != 1) { // dim0 must be 1 |
|
|
|
|
|
MS_LOG(ERROR) << "out_tensors_[0].shape()[0] must be 1, but now is:" << dim0; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
return CheckParam(); |
|
|
|
|
|
|
|
|
return RET_OK; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool TensorListStackCPUKernel::IsFullyDefined(const std::vector<int> &shape) const { |
|
|
bool TensorListStackCPUKernel::IsFullyDefined(const std::vector<int> &shape) const { |
|
|
@@ -129,26 +130,22 @@ int TensorListStackCPUKernel::MergeSubShape(const std::vector<int> &shape) { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
int TensorListStackCPUKernel::Run() { |
|
|
int TensorListStackCPUKernel::Run() { |
|
|
|
|
|
if (CheckParam() != RET_OK) { |
|
|
|
|
|
MS_LOG(ERROR) << "CheckParam failed!"; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
if (output0_->ElementsNum() == 0) { |
|
|
if (output0_->ElementsNum() == 0) { |
|
|
return RET_OK; |
|
|
return RET_OK; |
|
|
} |
|
|
} |
|
|
size_t in_ele_num = 0; |
|
|
|
|
|
for (int i = 0; i < num_element_; ++i) { |
|
|
|
|
|
auto tensor = input0_->GetTensorIndex(i); |
|
|
|
|
|
MS_ASSERT(tensor != nullptr); |
|
|
|
|
|
if (tensor->data_type() == kTypeUnknown) { |
|
|
|
|
|
if (TypeUnknownSize == 0) { |
|
|
|
|
|
TypeUnknownSize = MergeElementShape(); |
|
|
|
|
|
} |
|
|
|
|
|
in_ele_num += TypeUnknownSize; |
|
|
|
|
|
} else { |
|
|
|
|
|
in_ele_num += std::accumulate(tensor->shape().begin(), tensor->shape().end(), 1LL, std::multiplies<int>()); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
auto ret = MergeElementShape(); |
|
|
|
|
|
if (ret != RET_OK) { |
|
|
|
|
|
MS_LOG(ERROR) << "MergeElementShape failed!"; |
|
|
|
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
|
|
|
size_t in_ele_num = num_element_ * TypeUnknownSize; |
|
|
size_t out_ele_num = output0_->ElementsNum(); |
|
|
size_t out_ele_num = output0_->ElementsNum(); |
|
|
if (in_ele_num > out_ele_num) { |
|
|
|
|
|
MS_LOG(ERROR) << "out_tensors_[0]->ElementsNum():" << out_ele_num |
|
|
|
|
|
<< "must be greater than or equal to in_ele_num:" << in_ele_num; |
|
|
|
|
|
|
|
|
if (in_ele_num != out_ele_num) { |
|
|
|
|
|
MS_LOG(ERROR) << "out_tensors_[0]->ElementsNum():" << out_ele_num << "must be equal to in_ele_num:" << in_ele_num; |
|
|
return RET_ERROR; |
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
auto out_ptr = reinterpret_cast<float *>(output0_->MutableData()); |
|
|
auto out_ptr = reinterpret_cast<float *>(output0_->MutableData()); |
|
|
|