|
|
@@ -33,7 +33,6 @@ int ResizeCPUKernel::Init() { |
|
|
if (ret != RET_OK) { |
|
|
if (ret != RET_OK) { |
|
|
return ret; |
|
|
return ret; |
|
|
} |
|
|
} |
|
|
thread_num_ = context_->thread_num_; |
|
|
|
|
|
if (!InferShapeDone()) { |
|
|
if (!InferShapeDone()) { |
|
|
return RET_OK; |
|
|
return RET_OK; |
|
|
} |
|
|
} |
|
|
@@ -43,7 +42,6 @@ int ResizeCPUKernel::Init() { |
|
|
int ResizeCPUKernel::ReSize() { |
|
|
int ResizeCPUKernel::ReSize() { |
|
|
int ret = RET_OK; |
|
|
int ret = RET_OK; |
|
|
if (method_ == static_cast<int>(schema::ResizeMethod_LINEAR)) { |
|
|
if (method_ == static_cast<int>(schema::ResizeMethod_LINEAR)) { |
|
|
thread_num_ = 1; |
|
|
|
|
|
FreeTmpBuffer(); |
|
|
FreeTmpBuffer(); |
|
|
ret = MallocTmpBuffer(); |
|
|
ret = MallocTmpBuffer(); |
|
|
if (ret != RET_OK) { |
|
|
if (ret != RET_OK) { |
|
|
@@ -97,7 +95,7 @@ int ResizeCPUKernel::MallocTmpBuffer() { |
|
|
MS_LOG(ERROR) << "malloc data failed"; |
|
|
MS_LOG(ERROR) << "malloc data failed"; |
|
|
return RET_NULL_PTR; |
|
|
return RET_NULL_PTR; |
|
|
} |
|
|
} |
|
|
line_buffer_ = reinterpret_cast<float *>(malloc(sizeof(float) * w * c * 2 * thread_num_)); |
|
|
|
|
|
|
|
|
line_buffer_ = reinterpret_cast<float *>(malloc(sizeof(float) * w * c * 2 * context_->thread_num_)); |
|
|
if (line_buffer_ == nullptr) { |
|
|
if (line_buffer_ == nullptr) { |
|
|
MS_LOG(ERROR) << "malloc data failed"; |
|
|
MS_LOG(ERROR) << "malloc data failed"; |
|
|
return RET_NULL_PTR; |
|
|
return RET_NULL_PTR; |
|
|
@@ -168,14 +166,14 @@ int ResizeCPUKernel::RunImpl(int task_id) { |
|
|
int n_h_begin, n_h_end; |
|
|
int n_h_begin, n_h_end; |
|
|
int n = out_tensors_.at(0)->shape()[0]; |
|
|
int n = out_tensors_.at(0)->shape()[0]; |
|
|
int h = new_height_; |
|
|
int h = new_height_; |
|
|
int unit = UP_DIV(n * h, thread_num_); |
|
|
|
|
|
|
|
|
int unit = UP_DIV(n * h, context_->thread_num_); |
|
|
n_h_begin = unit * task_id; |
|
|
n_h_begin = unit * task_id; |
|
|
n_h_end = std::min(n_h_begin + unit, n * h); |
|
|
n_h_end = std::min(n_h_begin + unit, n * h); |
|
|
int c = in_tensors_.at(0)->shape()[3]; |
|
|
int c = in_tensors_.at(0)->shape()[3]; |
|
|
line0_ = line_buffer_ + new_width_ * c * 2 * task_id; |
|
|
|
|
|
line1_ = line0_ + new_width_ * c; |
|
|
|
|
|
|
|
|
float *line0 = line_buffer_ + new_width_ * c * 2 * task_id; |
|
|
|
|
|
float *line1 = line0 + new_width_ * c; |
|
|
ret = ResizeBilinear2(input_data, output_data, input_shape.data(), out_tensors_[0]->shape().data(), y_bottoms_, |
|
|
ret = ResizeBilinear2(input_data, output_data, input_shape.data(), out_tensors_[0]->shape().data(), y_bottoms_, |
|
|
y_tops_, x_lefts_, x_rights_, y_bottom_weights_, x_left_weights_, line0_, line1_, n_h_begin, |
|
|
|
|
|
|
|
|
y_tops_, x_lefts_, x_rights_, y_bottom_weights_, x_left_weights_, line0, line1, n_h_begin, |
|
|
n_h_end); |
|
|
n_h_end); |
|
|
|
|
|
|
|
|
break; |
|
|
break; |
|
|
@@ -193,7 +191,7 @@ int ResizeCPUKernel::RunImpl(int task_id) { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
ret = ResizeNearestNeighbor(input_data, output_data, input_shape.data(), out_tensors_[0]->shape().data(), |
|
|
ret = ResizeNearestNeighbor(input_data, output_data, input_shape.data(), out_tensors_[0]->shape().data(), |
|
|
align_corners_, task_id, thread_num_); |
|
|
|
|
|
|
|
|
align_corners_, task_id, context_->thread_num_); |
|
|
break; |
|
|
break; |
|
|
} |
|
|
} |
|
|
case schema::ResizeMethod_UNKNOW: |
|
|
case schema::ResizeMethod_UNKNOW: |
|
|
@@ -206,7 +204,7 @@ int ResizeCPUKernel::RunImpl(int task_id) { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
int ResizeCPUKernel::Run() { |
|
|
int ResizeCPUKernel::Run() { |
|
|
int error_code = ParallelLaunch(this->context_->thread_pool_, ResizeImpl, this, thread_num_); |
|
|
|
|
|
|
|
|
int error_code = ParallelLaunch(this->context_->thread_pool_, ResizeImpl, this, context_->thread_num_); |
|
|
if (error_code != RET_OK) { |
|
|
if (error_code != RET_OK) { |
|
|
MS_LOG(ERROR) << "Resize run error, error_code[" << error_code << "]"; |
|
|
MS_LOG(ERROR) << "Resize run error, error_code[" << error_code << "]"; |
|
|
FreeTmpBuffer(); |
|
|
FreeTmpBuffer(); |
|
|
|