From: @yangruoqi713 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -0,0 +1,24 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/fp32/invert_permutation_fp32.h" | |||
| inline void InvertPermutation(const int *input, int *output, int num) { | |||
| for (int i = 0; i < num; i++) { | |||
| int index = input[i]; | |||
| output[index] = i; | |||
| } | |||
| } | |||
| @@ -0,0 +1,27 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_NNACL_INVERT_PERMUTATION_FP32_H_ | |||
| #define MINDSPORE_LITE_NNACL_INVERT_PERMUTATION_FP32_H_ | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void InvertPermutation(const int *input, int *output, int num); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_NNACL_INVERT_PERMUTATION_FP32_H_ | |||
| @@ -18,6 +18,13 @@ | |||
| #include "nnacl/common_func.h" | |||
| #include "nnacl/errorcode.h" | |||
| void CalculateCoordinate(float out, int in, int *bottom, int *top, float *bottom_weight) { | |||
| *bottom = (int)(floor(out)); | |||
| *top = *bottom + 1 < in ? (*bottom + 1) : (in - 1); | |||
| float top_weight = (float)out - (float)(*bottom); | |||
| *bottom_weight = 1.0f - top_weight; | |||
| } | |||
| int PrepareResizeBilinear(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate, | |||
| int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float *y_bottom_weights, | |||
| float *x_left_weights) { | |||
| @@ -32,28 +39,55 @@ int PrepareResizeBilinear(const int *input_shape, const int *output_shape, Calcu | |||
| int new_height = output_shape[1]; | |||
| int new_width = output_shape[2]; | |||
| int h, w; | |||
| for (h = 0; h < new_height; h++) { | |||
| for (int h = 0; h < new_height; h++) { | |||
| float actual_y = calculate(h, in_h, new_height); | |||
| int y_bottom = (int)(floor(actual_y)); | |||
| int y_top = y_bottom + 1 < in_h ? (y_bottom + 1) : (in_h - 1); | |||
| float y_top_weight = actual_y - (float)(y_bottom); | |||
| const float y_bottom_weight = 1.0f - y_top_weight; | |||
| y_bottoms[h] = y_bottom; | |||
| y_tops[h] = y_top; | |||
| y_bottom_weights[h] = y_bottom_weight; | |||
| CalculateCoordinate(actual_y, in_h, y_bottoms + h, y_tops + h, y_bottom_weights + h); | |||
| } | |||
| for (w = 0; w < new_width; w++) { | |||
| for (int w = 0; w < new_width; w++) { | |||
| float actual_x = calculate(w, in_w, new_width); | |||
| int x_left = (int)(floor(actual_x)); | |||
| int x_right = x_left + 1 < in_w ? (x_left + 1) : (in_w - 1); | |||
| float x_right_weight = actual_x - (float)(x_left); | |||
| const float x_left_weight = 1.0f - x_right_weight; | |||
| x_lefts[w] = x_left; | |||
| x_rights[w] = x_right; | |||
| x_left_weights[w] = x_left_weight; | |||
| CalculateCoordinate(actual_x, in_w, x_lefts + w, x_rights + w, x_left_weights + w); | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int PrepareCropAndResizeBilinear(const int *input_shape, const float *boxes, const int *box_idx, | |||
| const int *output_shape, int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, | |||
| float *y_bottom_weights, float *x_left_weights) { | |||
| if (input_shape == NULL || output_shape == NULL || y_bottoms == NULL || y_tops == NULL || x_lefts == NULL || | |||
| x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { | |||
| return NNACL_NULL_PTR; | |||
| } | |||
| int in_b = input_shape[0]; | |||
| int in_h = input_shape[1]; | |||
| int in_w = input_shape[2]; | |||
| int new_height = output_shape[1]; | |||
| int new_width = output_shape[2]; | |||
| for (int i = 0; i < in_b; i++) { | |||
| int b = box_idx[i]; | |||
| const float *box = boxes + b * 4; | |||
| int start_h = box[0] * (in_h - 1); | |||
| int end_h = box[2] * (in_h - 1); | |||
| int start_w = box[1] * (in_w - 1); | |||
| int end_w = box[3] * (in_w - 1); | |||
| if (start_h >= end_h || start_w >= end_w || end_h >= in_h || end_w >= in_w) { | |||
| return NNACL_PARAM_INVALID; | |||
| } | |||
| int *y_bottom = y_bottoms + b * new_height; | |||
| int *y_top = y_tops + b * new_height; | |||
| float *y_bottom_weight = y_bottom_weights + b * new_height; | |||
| int *x_left = x_lefts + b * new_width; | |||
| int *x_right = x_rights + b * new_width; | |||
| float *x_left_weight = x_left_weights + b * new_width; | |||
| for (int h = 0; h < new_height; h++) { | |||
| float actual_y = start_h * (in_h - 1) + h * (end_h - start_h) * (in_h - 1) / (new_height - 1); | |||
| CalculateCoordinate(actual_y, in_h, y_bottom + h, y_top + h, y_bottom_weight + h); | |||
| } | |||
| for (int w = 0; w < new_width; w++) { | |||
| float actual_x = start_w * (in_w - 1) + w * (end_w - start_w) * (in_w - 1) / (new_width - 1); | |||
| CalculateCoordinate(actual_x, in_w, x_left + w, x_right + w, x_left_weight + w); | |||
| } | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| @@ -114,82 +148,92 @@ int InterpCol(const float *bottom_line, const float *top_line, float *output, in | |||
| int ResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, | |||
| const int *y_bottoms, const int *y_tops, const int *x_lefts, const int *x_rights, | |||
| const float *y_bottom_weights, const float *x_left_weights, float *line0, float *line1, | |||
| const int n_h_begin, const int n_h_end) { | |||
| const int h_begin, const int h_end, bool is_crop) { | |||
| if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL || y_bottoms == NULL || | |||
| y_tops == NULL || x_lefts == NULL || x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { | |||
| return NNACL_NULL_PTR; | |||
| } | |||
| int in_b = input_shape[0]; | |||
| int in_h = input_shape[1]; | |||
| int in_w = input_shape[2]; | |||
| int in_c = input_shape[3]; | |||
| int new_height = output_shape[1]; | |||
| int new_width = output_shape[2]; | |||
| int h_stride = new_width * in_c; | |||
| int n_h; | |||
| int n_h_stride = new_width * in_c; | |||
| bool cache_line_used[2] = {false, false}; | |||
| int cache_line_num[2] = {-1, -1}; | |||
| float *const cache_line_ptr[2] = {line0, line1}; | |||
| float *current_line_ptr[2] = {line0, line1}; | |||
| int current_line_num[2] = {-1, -1}; | |||
| for (n_h = n_h_begin; n_h < n_h_end; n_h++) { | |||
| int n, h; | |||
| n = n_h / new_height; | |||
| h = n_h % new_height; | |||
| current_line_num[0] = n * in_h + y_bottoms[h]; | |||
| current_line_num[1] = n * in_h + y_tops[h]; | |||
| int i; | |||
| for (i = 0; i < 2; i++) { | |||
| cache_line_used[i] = false; | |||
| const int *y_bottom = y_bottoms; | |||
| const int *y_top = y_tops; | |||
| const float *y_bottom_weight = y_bottom_weights; | |||
| const int *x_left = x_lefts; | |||
| const int *x_right = x_rights; | |||
| const float *x_left_weight = x_left_weights; | |||
| for (int b = 0; b < in_b; b++) { | |||
| if (is_crop) { | |||
| y_bottom = y_bottoms + b * new_height; | |||
| y_top = y_tops + b * new_height; | |||
| y_bottom_weight = y_bottom_weights + b * new_height; | |||
| x_left = x_lefts + b * new_width; | |||
| x_right = x_rights + b * new_width; | |||
| x_left_weight = x_left_weights + b * new_width; | |||
| } | |||
| // search if we cached | |||
| int j, k; | |||
| for (j = 0; j < 2; j++) { | |||
| bool find = false; | |||
| for (k = 0; k < 2; k++) { | |||
| if (current_line_num[j] == cache_line_num[k]) { | |||
| cache_line_used[k] = true; | |||
| current_line_ptr[j] = cache_line_ptr[k]; | |||
| find = true; | |||
| break; | |||
| } | |||
| } | |||
| const float *input = input_data + b * in_h * in_w * in_c; | |||
| float *output = output_data + b * new_height * new_width * in_c; | |||
| bool cache_line_used[2] = {false, false}; | |||
| int cache_line_num[2] = {-1, -1}; | |||
| float *const cache_line_ptr[2] = {line0, line1}; | |||
| float *current_line_ptr[2] = {line0, line1}; | |||
| int current_line_num[2] = {-1, -1}; | |||
| for (int h = h_begin; h < h_end; h++) { | |||
| current_line_num[0] = y_bottom[h]; | |||
| current_line_num[1] = y_top[h]; | |||
| if (!find) { | |||
| const float *line = input_data + current_line_num[j] * in_w * in_c; | |||
| for (k = 0; k < 2; k++) { | |||
| if (!cache_line_used[k]) { | |||
| cache_line_num[k] = current_line_num[j]; | |||
| for (int i = 0; i < 2; i++) { | |||
| cache_line_used[i] = false; | |||
| } | |||
| // search if we cached | |||
| for (int j = 0; j < 2; j++) { | |||
| bool find = false; | |||
| for (int k = 0; k < 2; k++) { | |||
| if (current_line_num[j] == cache_line_num[k]) { | |||
| cache_line_used[k] = true; | |||
| current_line_ptr[j] = cache_line_ptr[k]; | |||
| InterpRow(line, current_line_ptr[j], new_width, x_left_weights, x_lefts, x_rights, in_c); | |||
| find = true; | |||
| break; | |||
| } | |||
| } | |||
| if (!find) { | |||
| const float *line = input + current_line_num[j] * in_w * in_c; | |||
| for (int k = 0; k < 2; k++) { | |||
| if (!cache_line_used[k]) { | |||
| cache_line_num[k] = current_line_num[j]; | |||
| cache_line_used[k] = true; | |||
| current_line_ptr[j] = cache_line_ptr[k]; | |||
| InterpRow(line, current_line_ptr[j], new_width, x_left_weight, x_left, x_right, in_c); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // do col interp | |||
| InterpCol(current_line_ptr[0], current_line_ptr[1], output_data + n_h * n_h_stride, new_width, y_bottom_weights[h], | |||
| in_c); | |||
| // do col interp | |||
| InterpCol(current_line_ptr[0], current_line_ptr[1], output + h * h_stride, new_width, y_bottom_weight[h], in_c); | |||
| } | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int ResizeNearestNeighbor(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, | |||
| CalculateOriginalCoordinate calculate, int coordinate_transform_mode, int tid, | |||
| int thread_num) { | |||
| int batch, y, x, c; | |||
| c = input_shape[3]; | |||
| int c = input_shape[3]; | |||
| bool align_corners = coordinate_transform_mode == 1; | |||
| for (batch = 0; batch < output_shape[0]; batch++) { | |||
| for (y = tid; y < output_shape[1]; y += thread_num) { | |||
| for (int batch = 0; batch < output_shape[0]; batch++) { | |||
| for (int y = tid; y < output_shape[1]; y += thread_num) { | |||
| float actual_y = calculate(y, input_shape[1], output_shape[1]); | |||
| int input_y; | |||
| if (align_corners) { | |||
| @@ -197,7 +241,7 @@ int ResizeNearestNeighbor(const float *input_data, float *output_data, const int | |||
| } else { | |||
| input_y = (int)(floor(actual_y)); | |||
| } | |||
| for (x = 0; x < output_shape[2]; x++) { | |||
| for (int x = 0; x < output_shape[2]; x++) { | |||
| float actual_x = calculate(x, input_shape[2], output_shape[2]); | |||
| int input_x; | |||
| if (align_corners) { | |||
| @@ -211,7 +255,6 @@ int ResizeNearestNeighbor(const float *input_data, float *output_data, const int | |||
| } | |||
| } | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| @@ -31,10 +31,14 @@ int PrepareResizeBilinear(const int *input_shape, const int *output_shape, Calcu | |||
| int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float *y_bottom_weights, | |||
| float *x_left_weights); | |||
| int PrepareCropAndResizeBilinear(const int *input_shape, const float *boxes, const int *box_idx, | |||
| const int *output_shape, int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, | |||
| float *y_bottom_weights, float *x_left_weights); | |||
| int ResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, | |||
| const int *y_bottoms, const int *y_tops, const int *x_lefts, const int *x_rights, | |||
| const float *y_bottom_weights, const float *x_left_weights, float *line0, float *line1, | |||
| const int n_h_begin, const int n_h_end); | |||
| const int h_begin, const int h_end, bool is_crop); | |||
| int ResizeNearestNeighbor(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, | |||
| CalculateOriginalCoordinate calculate, int coordinate_transform_mode, int tid, | |||
| @@ -51,6 +51,14 @@ int InvertPermutation::InferShape(std::vector<lite::Tensor *> inputs_, std::vect | |||
| if (!infer_flag()) { | |||
| return RET_INFER_INVALID; | |||
| } | |||
| if (input->data_type() != kNumberTypeInt32) { | |||
| MS_LOG(ERROR) << "InvertPermutation does not support input of data type: " << input->data_type(); | |||
| return RET_ERROR; | |||
| } | |||
| if (input->shape().size() != 1) { | |||
| MS_LOG(ERROR) << "InvertPermutation input must be one-dimensional."; | |||
| return RET_ERROR; | |||
| } | |||
| output->set_shape(input->shape()); | |||
| return RET_OK; | |||
| } | |||
| @@ -56,7 +56,7 @@ int Size::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tens | |||
| return RET_INFER_INVALID; | |||
| } | |||
| std::vector<int> out_shape; | |||
| out_shape.push_back(static_cast<int>(in_tensor->shape().size())); | |||
| out_shape.push_back(1); | |||
| out_tensor->set_shape(out_shape); | |||
| return RET_OK; | |||
| } | |||
| @@ -38,8 +38,8 @@ class ResizeBaseCPUKernel : public LiteKernel { | |||
| protected: | |||
| int method_ = 0; | |||
| int64_t new_height_ = 0; | |||
| int64_t new_width_ = 0; | |||
| int new_height_ = 0; | |||
| int new_width_ = 0; | |||
| int coordinate_transform_mode_; | |||
| bool preserve_aspect_ratio_ = false; | |||
| bool const_shape_ = false; | |||
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp32/invert_permutation_fp32.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "schema/model_generated.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_InvertPermutation; | |||
| namespace mindspore::kernel { | |||
| int InvertPermutationCPUKernel::Init() { | |||
| if (!InferShapeDone()) { | |||
| return RET_OK; | |||
| } | |||
| return ReSize(); | |||
| } | |||
| int InvertPermutationCPUKernel::ReSize() { | |||
| if (in_tensors_[0]->data_type() != kNumberTypeInt32) { | |||
| MS_LOG(ERROR) << "InvertPermutation does not support input of data type: " << in_tensors_[0]->data_type(); | |||
| return RET_ERROR; | |||
| } | |||
| if (in_tensors_[0]->shape().size() != 1) { | |||
| MS_LOG(ERROR) << "InvertPermutation input must be one-dimensional."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int InvertPermutationCPUKernel::Run() { | |||
| auto in_tensor = in_tensors_.front(); | |||
| auto out_tensor = out_tensors_.front(); | |||
| if (in_tensor == nullptr || out_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "null pointer dereferencing."; | |||
| return RET_ERROR; | |||
| } | |||
| auto input_ptr = reinterpret_cast<int32_t *>(in_tensor->data_c()); | |||
| auto output_ptr = reinterpret_cast<int32_t *>(out_tensor->data_c()); | |||
| if (input_ptr == nullptr || output_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "null pointer dereferencing."; | |||
| return RET_ERROR; | |||
| } | |||
| InvertPermutation(input_ptr, output_ptr, in_tensors_[0]->ElementsNum()); | |||
| return RET_OK; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_InvertPermutation, LiteKernelCreator<InvertPermutationCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_InvertPermutation, LiteKernelCreator<InvertPermutationCPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_INVERT_PERMUTATION_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_INVERT_PERMUTATION_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "nnacl/fp32/invert_permutation_fp32.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| using mindspore::lite::InnerContext; | |||
| namespace mindspore::kernel { | |||
| class InvertPermutationCPUKernel : public LiteKernel { | |||
| public: | |||
| InvertPermutationCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~InvertPermutationCPUKernel() = default; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_INVERT_PERMUTATION_H_ | |||
| @@ -40,6 +40,8 @@ int ResizeCPUKernel::Init() { | |||
| case schema::CoordinateTransformMode_HALF_PIXEL: | |||
| calculate_ = CalculateHalfPixel; | |||
| break; | |||
| case schema::CoordinateTransformMode_CROP_AND_RESIZE: | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "Do not support coordinate transform mode. Mode is" | |||
| << schema::EnumNameCoordinateTransformMode( | |||
| @@ -70,8 +72,15 @@ int ResizeCPUKernel::ReSize() { | |||
| auto input = in_tensors_.at(0); | |||
| auto input_shape = input->shape(); | |||
| ret = PrepareResizeBilinear(input_shape.data(), out_tensors_.at(0)->shape().data(), calculate_, y_bottoms_, y_tops_, | |||
| x_lefts_, x_rights_, y_bottom_weights_, x_left_weights_); | |||
| if (coordinate_transform_mode_ == schema::CoordinateTransformMode_CROP_AND_RESIZE) { | |||
| auto boxes = reinterpret_cast<float *>(in_tensors_.at(1)->data_c()); | |||
| auto box_idx = reinterpret_cast<int32_t *>(in_tensors_.at(2)->data_c()); | |||
| ret = PrepareCropAndResizeBilinear(input_shape.data(), boxes, box_idx, out_tensors_.at(0)->shape().data(), | |||
| y_bottoms_, y_tops_, x_lefts_, x_rights_, y_bottom_weights_, x_left_weights_); | |||
| } else { | |||
| ret = PrepareResizeBilinear(input_shape.data(), out_tensors_.at(0)->shape().data(), calculate_, y_bottoms_, | |||
| y_tops_, x_lefts_, x_rights_, y_bottom_weights_, x_left_weights_); | |||
| } | |||
| if (ret != RET_OK) { | |||
| FreeTmpBuffer(); | |||
| } | |||
| @@ -80,36 +89,42 @@ int ResizeCPUKernel::ReSize() { | |||
| } | |||
| int ResizeCPUKernel::MallocTmpBuffer() { | |||
| int b = in_tensors_.at(0)->Batch(); | |||
| // Malloc buffer to save coordinate. For mode CROP_AND_RESIZE, different batches require different cache coordinates. | |||
| // For other modes, different batches have different cache coordinates. | |||
| if (coordinate_transform_mode_ != schema::CoordinateTransformMode_CROP_AND_RESIZE) { | |||
| b = 1; | |||
| } | |||
| int c = in_tensors_.at(0)->Channel(); | |||
| int h = new_height_; | |||
| int w = new_width_; | |||
| y_bottoms_ = reinterpret_cast<int *>(malloc(sizeof(int) * h)); | |||
| y_bottoms_ = reinterpret_cast<int *>(malloc(sizeof(int) * h * b)); | |||
| if (y_bottoms_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc data failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| y_tops_ = reinterpret_cast<int *>(malloc(sizeof(int) * h)); | |||
| y_tops_ = reinterpret_cast<int *>(malloc(sizeof(int) * h * b)); | |||
| if (y_tops_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc data failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| y_bottom_weights_ = reinterpret_cast<float *>(malloc(sizeof(float) * h)); | |||
| y_bottom_weights_ = reinterpret_cast<float *>(malloc(sizeof(float) * h * b)); | |||
| if (y_bottom_weights_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc data failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| x_lefts_ = reinterpret_cast<int *>(malloc(sizeof(int) * w)); | |||
| x_lefts_ = reinterpret_cast<int *>(malloc(sizeof(int) * w * b)); | |||
| if (x_lefts_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc data failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| x_rights_ = reinterpret_cast<int *>(malloc(sizeof(int) * w)); | |||
| x_rights_ = reinterpret_cast<int *>(malloc(sizeof(int) * w * b)); | |||
| if (x_rights_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc data failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| x_left_weights_ = reinterpret_cast<float *>(malloc(sizeof(float) * w)); | |||
| x_left_weights_ = reinterpret_cast<float *>(malloc(sizeof(float) * w * b)); | |||
| if (x_left_weights_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc data failed"; | |||
| return RET_NULL_PTR; | |||
| @@ -181,20 +196,17 @@ int ResizeCPUKernel::RunImpl(int task_id) { | |||
| int ret = 0; | |||
| switch (method_) { | |||
| case static_cast<int>(schema::ResizeMethod_LINEAR): { | |||
| int n_h_begin, n_h_end; | |||
| int n = out_tensors_.at(0)->shape().at(0); | |||
| int h = new_height_; | |||
| int unit = UP_DIV(n * h, context_->thread_num_); | |||
| n_h_begin = unit * task_id; | |||
| n_h_end = std::min(n_h_begin + unit, n * h); | |||
| int unit = UP_DIV(new_height_, context_->thread_num_); | |||
| int h_begin = unit * task_id; | |||
| int h_end = std::min(h_begin + unit, new_height_); | |||
| int c = in_tensors_.at(0)->shape().at(3); | |||
| float *line0 = line_buffer_ + new_width_ * c * 2 * task_id; | |||
| float *line1 = line0 + new_width_ * c; | |||
| bool is_crop = coordinate_transform_mode_ == schema::CoordinateTransformMode_CROP_AND_RESIZE; | |||
| ret = ResizeBilinear(input_data, output_data, input_shape.data(), out_tensors_.at(0)->shape().data(), y_bottoms_, | |||
| y_tops_, x_lefts_, x_rights_, y_bottom_weights_, x_left_weights_, line0, line1, n_h_begin, | |||
| n_h_end); | |||
| y_tops_, x_lefts_, x_rights_, y_bottom_weights_, x_left_weights_, line0, line1, h_begin, | |||
| h_end, is_crop); | |||
| break; | |||
| } | |||
| case static_cast<int>(schema::ResizeMethod_NEAREST): { | |||
| @@ -202,7 +214,6 @@ int ResizeCPUKernel::RunImpl(int task_id) { | |||
| calculate_, coordinate_transform_mode_, task_id, context_->thread_num_); | |||
| break; | |||
| } | |||
| case schema::ResizeMethod_UNKNOW: | |||
| default: { | |||
| MS_LOG(ERROR) << "Resize unknown method " << method_; | |||
| ret = RET_ERROR; | |||
| @@ -218,7 +229,6 @@ int ResizeCPUKernel::Run() { | |||
| FreeTmpBuffer(); | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp32/size_fp32.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "schema/model_generated.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Size; | |||
| namespace mindspore::kernel { | |||
| int SizeCPUKernel::Init() { return RET_OK; } | |||
| int SizeCPUKernel::ReSize() { return RET_OK; } | |||
| int SizeCPUKernel::Run() { | |||
| auto in_tensor = in_tensors_.front(); | |||
| auto out_tensor = out_tensors_.front(); | |||
| if (in_tensor == nullptr || out_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "null pointer dereferencing."; | |||
| return RET_ERROR; | |||
| } | |||
| if (in_tensor->data_c() == nullptr || out_tensor->data_c() == nullptr) { | |||
| MS_LOG(ERROR) << "null pointer dereferencing."; | |||
| return RET_ERROR; | |||
| } | |||
| reinterpret_cast<int *>(out_tensor->data_c())[0] = in_tensor->ElementsNum(); | |||
| return RET_OK; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Size, LiteKernelCreator<SizeCPUKernel>) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Size, LiteKernelCreator<SizeCPUKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SIZE_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SIZE_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "nnacl/fp32/invert_permutation_fp32.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| using mindspore::lite::InnerContext; | |||
| namespace mindspore::kernel { | |||
| class SizeCPUKernel : public LiteKernel { | |||
| public: | |||
| SizeCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~SizeCPUKernel() = default; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SIZE_H_ | |||
| @@ -105,7 +105,8 @@ int UpsampleCPUKernel::RunImpl(int task_id) { | |||
| float *line0 = line_buffer_ + new_width_ * c * 2 * task_id; | |||
| float *line1 = line0 + new_width_ * c; | |||
| ret = ResizeBilinear(input_data, output_data, input_shape.data(), out_tensor->shape().data(), y_bottoms_, y_tops_, | |||
| x_lefts_, x_rights_, y_bottom_weights_, x_left_weights_, line0, line1, n_h_begin, n_h_end); | |||
| x_lefts_, x_rights_, y_bottom_weights_, x_left_weights_, line0, line1, n_h_begin, n_h_end, | |||
| false); | |||
| break; | |||
| } | |||