|
|
|
@@ -73,12 +73,27 @@ int ArithmeticCPUKernel::ReSize() { |
|
|
|
arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); |
|
|
|
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); |
|
|
|
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); |
|
|
|
memcpy(arithmeticParameter_->in_shape0_, static_cast<void *>(in_tensors_[0]->shape().data()), |
|
|
|
in_tensors_[0]->shape().size() * sizeof(int)); |
|
|
|
memcpy(arithmeticParameter_->in_shape1_, static_cast<void *>(in_tensors_[1]->shape().data()), |
|
|
|
in_tensors_[1]->shape().size() * sizeof(int)); |
|
|
|
memcpy(arithmeticParameter_->out_shape_, static_cast<void *>(out_tensors_[0]->shape().data()), |
|
|
|
out_tensors_[0]->shape().size() * sizeof(int)); |
|
|
|
for (size_t i = 0; i < in_tensors_[0]->shape().size(); i++) { |
|
|
|
if (arithmeticParameter_->in_shape0_[i] == -1) { |
|
|
|
memcpy(arithmeticParameter_->in_shape0_, static_cast<void *>(in_tensors_[0]->shape().data()), |
|
|
|
in_tensors_[0]->shape().size() * sizeof(int)); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
for (size_t i = 0; i < in_tensors_[1]->shape().size(); i++) { |
|
|
|
if (arithmeticParameter_->in_shape1_[i] == -1) { |
|
|
|
memcpy(arithmeticParameter_->in_shape1_, static_cast<void *>(in_tensors_[1]->shape().data()), |
|
|
|
in_tensors_[1]->shape().size() * sizeof(int)); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
for (size_t i = 0; i < out_tensors_[0]->shape().size(); i++) { |
|
|
|
if (arithmeticParameter_->out_shape_[i] == -1) { |
|
|
|
memcpy(arithmeticParameter_->out_shape_, static_cast<void *>(out_tensors_[0]->shape().data()), |
|
|
|
out_tensors_[0]->shape().size() * sizeof(int)); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { |
|
|
|
switch (arithmeticParameter_->op_parameter_.type_) { |
|
|
|
|