|
|
|
@@ -30,6 +30,22 @@ namespace mindspore::kernel { |
|
|
|
|
|
|
|
int TensorListSetItemCPUKernel::Init() { return RET_OK; } |
|
|
|
|
|
|
|
int TensorListSetItemCPUKernel::CheckParam() { |
|
|
|
if (dtype_ != kTypeUnknown && dtype_ != input0_->tensors_data_type()) { |
|
|
|
MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type(); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) { |
|
|
|
MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (in_tensors_[1]->ElementsNum() != 1) { |
|
|
|
MS_LOG(ERROR) << "in_tensors_[1]->ElementsNum():" << in_tensors_[1]->ElementsNum() << " must be equal to 1!"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int TensorListSetItemCPUKernel::IncrementOutputSize(int origin_size) { |
|
|
|
output0_ = reinterpret_cast<lite::TensorList *>(out_tensors_[0]); |
|
|
|
int new_tensors_size = origin_size + 1; |
|
|
|
@@ -46,19 +62,13 @@ int TensorListSetItemCPUKernel::IncrementOutputSize(int origin_size) { |
|
|
|
|
|
|
|
int TensorListSetItemCPUKernel::Run() { |
|
|
|
input0_ = reinterpret_cast<lite::TensorList *>(in_tensors_[0]); |
|
|
|
if (dtype_ != kTypeUnknown && dtype_ != input0_->tensors_data_type()) { |
|
|
|
MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type(); |
|
|
|
|
|
|
|
if (CheckParam() != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "check param failed."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
int dim0 = input0_->ElementsNum() - 1; |
|
|
|
if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) { |
|
|
|
MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (in_tensors_[1]->ElementsNum() != 1) { |
|
|
|
MS_LOG(ERROR) << "in_tensors_[1]->ElementsNum():" << in_tensors_[1]->ElementsNum() << " must be equal to 1!"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
index_ = reinterpret_cast<int *>(in_tensors_[1]->data_c())[0]; |
|
|
|
if (index_ < 0 || index_ > dim0) { |
|
|
|
if (IncrementOutputSize(output0_->shape()[0]) != RET_OK) { |
|
|
|
@@ -81,6 +91,10 @@ int TensorListSetItemCPUKernel::Run() { |
|
|
|
} |
|
|
|
} |
|
|
|
// copy each tensor in tensors_ |
|
|
|
if (input0_->tensors().empty() && index_ == 0) { |
|
|
|
input0_->set_element_shape(input2_->shape()); |
|
|
|
output0_->set_element_shape(input2_->shape()); |
|
|
|
} |
|
|
|
for (int i = 0; i < output0_->ElementsNum(); ++i) { |
|
|
|
if (i == index_) { |
|
|
|
auto dst = output0_->GetTensor(i); |
|
|
|
|