Browse Source

new implement for resize bicubic

tags/v1.2.0-rc1
fuzhiye 4 years ago
parent
commit
3c30d58273
12 changed files with 412 additions and 140 deletions
  1. +185
    -5
      mindspore/lite/nnacl/fp32/resize_fp32.c
  2. +9
    -0
      mindspore/lite/nnacl/fp32/resize_fp32.h
  3. +1
    -0
      mindspore/lite/nnacl/resize_parameter.h
  4. +8
    -0
      mindspore/lite/src/common/log_util.h
  5. +1
    -0
      mindspore/lite/src/lite_kernel.h
  6. +1
    -0
      mindspore/lite/src/ops/populate/resize_populate.cc
  7. +8
    -13
      mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc
  8. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/base/resize_base.h
  9. +149
    -112
      mindspore/lite/src/runtime/kernel/arm/fp32/resize_fp32.cc
  10. +40
    -7
      mindspore/lite/src/runtime/kernel/arm/fp32/resize_fp32.h
  11. +5
    -0
      mindspore/lite/tools/converter/parser/tf/tf_resize_parser.cc
  12. +4
    -2
      mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc

+ 185
- 5
mindspore/lite/nnacl/fp32/resize_fp32.c View File

@@ -19,12 +19,48 @@
#include "nnacl/errorcode.h"

void CalculateCoordinate(float out, int in, int *bottom, int *top, float *bottom_weight) {
*bottom = (int)(floor(out));
*bottom = (int)(floorf(out));
*top = *bottom + 1 < in ? (*bottom + 1) : (in - 1);
float top_weight = (float)out - (float)(*bottom);
*bottom_weight = 1.0f - top_weight;
}

static void BicubicBaseFunc(float a, const float x, float *weight) {
if (x > 1 && x < 2) {
weight[0] = a * x * x * x - 5 * a * x * x + 8 * a * x - 4 * a;
} else if (x >= 0 && x <= 1) {
weight[0] = ((a + 2) * x - (a + 3)) * x * x + 1;
} else {
weight[0] = 0;
}
}

// a is a coefficient
// W(x) = { (a + 2) * |x| * |x| * |x| - (a + 3) * |x| * |x| + 1, for |x| <= 1
// { a * |x| * |x| * |x| - 5 * a * |x| * |x| + 8 * a *|x| - 4 * a, for 1 < |x| < 2
// { 0, otherwise
// the value of 'a' depends on if is half_pixel_center(the scheme is the same as tf).
// If is half pixel mode, a equals to -0.5, otherwise -0.75.
void CalculateWightForBicubic(float out, int in, int *bottom, int *top, float *weights, float a) {
// can not exchange the order of calculating bottom[1] and bottom[0], because the order is decided outside.
bottom[1] = (int)(floorf(out));
bottom[0] = (bottom[1] - 1) < 0 ? 0 : (bottom[1] - 1);
top[0] = (bottom[1] + 1) < in ? (bottom[1] + 1) : (in - 1);
top[1] = (top[0] + 1) < in ? (top[0] + 1) : (in - 1);

// get positive value
float distance[4] = {1, 0, 1, 2};
float tmp_dis = out - (float)bottom[1];
distance[0] += tmp_dis;
distance[1] += tmp_dis;
distance[2] -= tmp_dis;
distance[3] -= tmp_dis;

for (int i = 0; i < 4; ++i) {
BicubicBaseFunc(a, distance[i], &weights[i]);
}
}

int PrepareResizeBilinear(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate,
int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float *y_bottom_weights,
float *x_left_weights) {
@@ -50,6 +86,30 @@ int PrepareResizeBilinear(const int *input_shape, const int *output_shape, Calcu
return NNACL_OK;
}

int PrepareResizeBicubic(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate,
int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float *y_weights, float *x_weights,
float cubic_coeff) {
if (input_shape == NULL || output_shape == NULL || y_bottoms == NULL || y_tops == NULL || x_lefts == NULL ||
x_rights == NULL || y_weights == NULL || x_weights == NULL) {
return NNACL_NULL_PTR;
}

int in_h = input_shape[1];
int in_w = input_shape[2];
int new_height = output_shape[1];
int new_width = output_shape[2];

for (int h = 0; h < new_height; h++) {
float actual_y = calculate(h, in_h, new_height);
CalculateWightForBicubic(actual_y, in_h, y_bottoms + 2 * h, y_tops + 2 * h, y_weights + 4 * h, cubic_coeff);
}
for (int w = 0; w < new_width; w++) {
float actual_x = calculate(w, in_w, new_width);
CalculateWightForBicubic(actual_x, in_w, x_lefts + 2 * w, x_rights + 2 * w, x_weights + 4 * w, cubic_coeff);
}
return NNACL_OK;
}

int PrepareCropAndResizeBilinear(const int *input_shape, const float *boxes, const int *box_idx,
const int *output_shape, int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights,
float *y_bottom_weights, float *x_left_weights) {
@@ -222,6 +282,126 @@ int ResizeBilinear(const float *input_data, float *output_data, const int *input
return NNACL_OK;
}

void BicubicInterpRow(const float *src, float *dst, int len, const float *weights, const int *lefts, const int *rights,
int in_c) {
int l = 0;
for (; l < len; l++) {
const float weight1 = weights[4 * l];
const float weight2 = weights[4 * l + 1];
const float weight3 = weights[4 * l + 2];
const float weight4 = weights[4 * l + 3];
int c = 0;
#ifdef ENABLE_NEON
float32x4_t weight1_vec = vdupq_n_f32(weight1);
float32x4_t weight2_vec = vdupq_n_f32(weight2);
float32x4_t weight3_vec = vdupq_n_f32(weight3);
float32x4_t weight4_vec = vdupq_n_f32(weight4);

for (; c <= in_c - 4; c += 4) {
float32x4_t src1_vec = vld1q_f32(src + lefts[2 * l] * in_c + c);
float32x4_t src2_vec = vld1q_f32(src + lefts[2 * l + 1] * in_c + c);
float32x4_t src3_vec = vld1q_f32(src + rights[2 * l] * in_c + c);
float32x4_t src4_vec = vld1q_f32(src + rights[2 * l + 1] * in_c + c);

float32x4_t interp_value =
src1_vec * weight1_vec + src2_vec * weight2_vec + src3_vec * weight3_vec + src4_vec * weight4_vec;
vst1q_f32(dst + l * in_c + c, interp_value);
}
#endif
int pos1 = lefts[2 * l] * in_c;
int pos2 = lefts[2 * l + 1] * in_c;
int pos3 = rights[2 * l] * in_c;
int pos4 = rights[2 * l + 1] * in_c;

for (; c < in_c; c++) {
float value1 = src[pos1 + c];
float value2 = src[pos2 + c];
float value3 = src[pos3 + c];
float value4 = src[pos4 + c];
dst[l * in_c + c] = value1 * weight1 + value2 * weight2 + value3 * weight3 + value4 * weight4;
}
}
}

void BicubicInterpCol(const float *src1, const float *src2, const float *src3, const float *src4, float *dst, int len,
const float *weights, int in_c) {
int l = 0;
for (; l < len; l++) {
int c = 0;
int l_stride = l * in_c;
const float weight1 = weights[4 * l];
const float weight2 = weights[4 * l + 1];
const float weight3 = weights[4 * l + 2];
const float weight4 = weights[4 * l + 3];
#ifdef ENABLE_NEON
float32x4_t weight1_vec = vdupq_n_f32(weight1);
float32x4_t weight2_vec = vdupq_n_f32(weight2);
float32x4_t weight3_vec = vdupq_n_f32(weight3);
float32x4_t weight4_vec = vdupq_n_f32(weight4);

for (; c <= in_c - 4; c += 4) {
float32x4_t src1_vec = vld1q_f32(src1 + l_stride + c);
float32x4_t src2_vec = vld1q_f32(src2 + l_stride + c);
float32x4_t src3_vec = vld1q_f32(src3 + l_stride + c);
float32x4_t src4_vec = vld1q_f32(src4 + l_stride + c);
float32x4_t interp_value =
src1_vec * weight1_vec + src2_vec * weight2_vec + src3_vec * weight3_vec + src4_vec * weight4_vec;
vst1q_f32(dst + l_stride + c, interp_value);
}
#endif
for (; c < in_c; c++) {
float value1 = src1[l_stride + c];
float value2 = src2[l_stride + c];
float value3 = src3[l_stride + c];
float value4 = src4[l_stride + c];
dst[l_stride + c] = value1 * weight1 + value2 * weight2 + value3 * weight3 + value4 * weight4;
}
}
}

void Bicubic(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
const int *y_bottom, const int *y_top, const int *x_lefts, const int *x_rights, const float *y_weights,
const float *x_weights, float *line_buffer, const int h_begin, const int h_end) {
int in_w = input_shape[2];
int in_c = input_shape[3];
int new_width = output_shape[2];
int h_stride = new_width * in_c;

float *line_array[4] = {line_buffer, line_buffer + h_stride, line_buffer + 2 * h_stride, line_buffer + 3 * h_stride};
for (int h = h_begin; h < h_end; h++) {
for (int i = 0; i < 2; ++i) {
BicubicInterpRow(input_data + y_bottom[2 * h + i] * in_w * in_c, line_array[i], new_width, x_weights, x_lefts,
x_rights, in_c);
}
for (int j = 0; j < 2; ++j) {
BicubicInterpRow(input_data + y_top[2 * h + j] * in_w * in_c, line_array[j + 2], new_width, x_weights, x_lefts,
x_rights, in_c);
}

BicubicInterpCol(line_array[0], line_array[1], line_array[2], line_array[3], output_data + h * h_stride, new_width,
y_weights, in_c);
}
}

int ResizeBicubic(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
const int *y_bottoms, const int *y_tops, const int *x_lefts, const int *x_rights,
const float *y_weights, const float *x_weights, float *line_buffer, const int h_begin,
const int h_end) {
if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL || y_bottoms == NULL ||
y_tops == NULL || x_lefts == NULL || x_rights == NULL || y_weights == NULL || x_weights == NULL) {
return NNACL_NULL_PTR;
}
int input_cube_per_batch = input_shape[1] * input_shape[2] * input_shape[3];
int output_cube_per_batch = output_shape[1] * output_shape[2] * input_shape[3];
for (int b = 0; b < input_shape[0]; b++) {
const float *input = input_data + b * input_cube_per_batch;
float *output = output_data + b * output_cube_per_batch;
Bicubic(input, output, input_shape, output_shape, y_bottoms, y_tops, x_lefts, x_rights, y_weights, x_weights,
line_buffer, h_begin, h_end);
}
return NNACL_OK;
}

int CropAndResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
const int *y_bottoms, const int *y_tops, const int *x_lefts, const int *x_rights,
const float *y_bottom_weights, const float *x_left_weights, float *line0, float *line1,
@@ -260,17 +440,17 @@ int ResizeNearestNeighbor(const float *input_data, float *output_data, const int
float actual_y = calculate(y, input_shape[1], output_shape[1]);
int input_y;
if (align_corners) {
input_y = (int)(round(actual_y));
input_y = (int)(roundf(actual_y));
} else {
input_y = (int)(floor(actual_y));
input_y = (int)(floorf(actual_y));
}
for (int x = 0; x < output_shape[2]; x++) {
float actual_x = calculate(x, input_shape[2], output_shape[2]);
int input_x;
if (align_corners) {
input_x = (int)(round(actual_x));
input_x = (int)(roundf(actual_x));
} else {
input_x = (int)(floor(actual_x));
input_x = (int)(floorf(actual_x));
}
int in_offset = offset(input_shape, batch, input_y, input_x, 0);
int out_offset = offset(output_shape, batch, y, x, 0);


+ 9
- 0
mindspore/lite/nnacl/fp32/resize_fp32.h View File

@@ -31,11 +31,20 @@ int PrepareResizeBilinear(const int *input_shape, const int *output_shape, Calcu
int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float *y_bottom_weights,
float *x_left_weights);

int PrepareResizeBicubic(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate,
int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float *y_bottom_weights,
float *x_left_weights, float cubic_coeff);

int ResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
const int *y_bottoms, const int *y_tops, const int *x_lefts, const int *x_rights,
const float *y_bottom_weights, const float *x_left_weights, float *line0, float *line1,
const int h_begin, const int h_end);

int ResizeBicubic(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
const int *y_bottoms, const int *y_tops, const int *x_lefts, const int *x_rights,
const float *y_bottom_weights, const float *x_left_weights, float *line_buffer, const int h_begin,
const int h_end);

int PrepareCropAndResizeBilinear(const int *input_shape, const float *boxes, const int *box_idx,
const int *output_shape, int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights,
float *y_bottom_weights, float *x_left_weights);


+ 1
- 0
mindspore/lite/nnacl/resize_parameter.h View File

@@ -24,6 +24,7 @@ typedef struct ResizeParameter {
int64_t new_height_;
int64_t new_width_;
int coordinate_transform_mode_;
float cubic_coeff_;
bool preserve_aspect_ratio_;
} ResizeParameter;



+ 8
- 0
mindspore/lite/src/common/log_util.h View File

@@ -28,4 +28,12 @@
} \
} while (0)

#define CHECK_MALLOC_RES(ptr, errcode) \
do { \
if ((ptr) == nullptr) { \
MS_LOG(ERROR) << "malloc data failed."; \
return errcode; \
} \
} while (0);

#endif // MINDSPORE_LITE_SRC_COMMON_LOG_UTIL_H_

+ 1
- 0
mindspore/lite/src/lite_kernel.h View File

@@ -21,6 +21,7 @@
#include <memory>
#include <utility>
#include "src/common/utils.h"
#include "src/common/log_util.h"
#ifdef ENABLE_ARM
#include <arm_neon.h>
#endif


+ 1
- 0
mindspore/lite/src/ops/populate/resize_populate.cc View File

@@ -35,6 +35,7 @@ OpParameter *PopulateResizeParameter(const void *prim) {
resize_param->new_width_ = value->new_width();
resize_param->coordinate_transform_mode_ = value->coordinate_transform_mode();
resize_param->preserve_aspect_ratio_ = value->preserve_aspect_ratio();
resize_param->cubic_coeff_ = value->cubic_coeff();
return reinterpret_cast<OpParameter *>(resize_param);
}



+ 8
- 13
mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc View File

@@ -19,7 +19,6 @@
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/kernel/arm/fp32/resize_fp32.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INVALID_OP_ATTR;
@@ -39,9 +38,8 @@ int ResizeBaseCPUKernel::CheckParameters() {
return RET_NULL_PTR;
}
method_ = parameter->method_;
if (method_ != static_cast<int>(schema::ResizeMethod_LINEAR) &&
method_ != static_cast<int>(schema::ResizeMethod_NEAREST)) {
MS_LOG(ERROR) << "Resize method should be bilinear or nearest_neighbor, but got " << method_;
if (method_ == schema::ResizeMethod::ResizeMethod_UNKNOWN) {
MS_LOG(ERROR) << "Resize method can not be unknown.";
return RET_INVALID_OP_ATTR;
}
if (this->in_tensors_.size() == 1) {
@@ -78,25 +76,23 @@ int ResizeBaseCPUKernel::CheckParameters() {
}

int ResizeBaseCPUKernel::CheckInputsOuputs() {
// inputs
if (in_tensors_.size() <= kMaxInputNum) {
for (size_t i = 0; i < in_tensors_.size(); i++) {
auto input = in_tensors_.at(i);
if (input == nullptr) {
return RET_NULL_PTR;
}
for (auto input : in_tensors_) {
MSLITE_CHECK_PTR(input);
}
} else {
MS_LOG(ERROR) << "Resize input num should be no more than" << kMaxInputNum << ", but got " << in_tensors_.size();
return RET_ERROR;
}

// outputs
if (out_tensors_.size() != kOutputNum) {
MS_LOG(ERROR) << "Resize output num should be " << kOutputNum << ", but got " << out_tensors_.size();
return RET_ERROR;
}
auto output = out_tensors_.at(0);
if (output == nullptr) {
return RET_NULL_PTR;
}
MSLITE_CHECK_PTR(output);
return RET_OK;
}

@@ -116,7 +112,6 @@ int ResizeBaseCPUKernel::Init() {
MS_LOG(ERROR) << "Resize op support input rank 4, got " << input_shape.size();
return RET_ERROR;
}

return RET_OK;
}
} // namespace mindspore::kernel

+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/base/resize_base.h View File

@@ -30,7 +30,7 @@ class ResizeBaseCPUKernel : public LiteKernel {
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: LiteKernel(parameter, inputs, outputs, ctx) {}

virtual ~ResizeBaseCPUKernel() = default;
~ResizeBaseCPUKernel() override = default;

int Init() override;
int ReSize() override { return 0; };


+ 149
- 112
mindspore/lite/src/runtime/kernel/arm/fp32/resize_fp32.cc View File

@@ -14,6 +14,8 @@
* limitations under the License.
*/

#include <map>
#include <utility>
#include "src/runtime/kernel/arm/fp32/resize_fp32.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
@@ -25,26 +27,18 @@ using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INVALID_OP_ATTR;
using mindspore::lite::RET_NULL_PTR;
using mindspore::lite::RET_OK;
using mindspore::schema::CoordinateTransformMode_ALIGN_CORNERS;
using mindspore::schema::CoordinateTransformMode_ASYMMETRIC;
using mindspore::schema::CoordinateTransformMode_HALF_PIXEL;
using mindspore::schema::PrimitiveType_Resize;

namespace mindspore::kernel {
int ResizeCPUKernel::Init() {
auto ret = ResizeBaseCPUKernel::Init();
switch (coordinate_transform_mode_) {
case schema::CoordinateTransformMode_ASYMMETRIC:
calculate_ = CalculateAsymmetric;
break;
case schema::CoordinateTransformMode_ALIGN_CORNERS:
calculate_ = CalculateAlignCorners;
break;
case schema::CoordinateTransformMode_HALF_PIXEL:
calculate_ = CalculateHalfPixel;
break;
default:
MS_LOG(ERROR) << "Do not support coordinate transform mode. Mode is"
<< schema::EnumNameCoordinateTransformMode(
static_cast<schema::CoordinateTransformMode>(coordinate_transform_mode_));
if (ret != RET_OK) {
return ret;
}
ret = SelectCalculatorFunc();
if (ret != RET_OK) {
return ret;
}
@@ -55,98 +49,106 @@ int ResizeCPUKernel::Init() {
}

int ResizeCPUKernel::ReSize() {
int ret = RET_OK;
if (method_ == static_cast<int>(schema::ResizeMethod_LINEAR)) {
if (!const_shape_) {
new_height_ = out_tensors_.at(0)->shape()[1];
new_width_ = out_tensors_.at(0)->shape()[2];
}
FreeTmpBuffer();
ret = MallocTmpBuffer();
if (ret != RET_OK) {
FreeTmpBuffer();
return ret;
}
if (method_ == static_cast<int>(schema::ResizeMethod_NEAREST)) {
return RET_OK;
}

auto input = in_tensors_.at(0);
auto input_shape = input->shape();
ret = PrepareResizeBilinear(input_shape.data(), out_tensors_.at(0)->shape().data(), calculate_, y_bottoms_, y_tops_,
x_lefts_, x_rights_, y_bottom_weights_, x_left_weights_);
if (ret != RET_OK) {
FreeTmpBuffer();
}
if (!const_shape_) {
new_height_ = out_tensors_.at(0)->shape()[1];
new_width_ = out_tensors_.at(0)->shape()[2];
}
return ret;
}

int ResizeCPUKernel::MallocTmpBuffer() {
int c = in_tensors_.at(0)->Channel();
int h = new_height_;
int w = new_width_;
y_bottoms_ = reinterpret_cast<int *>(malloc(sizeof(int) * h));
if (y_bottoms_ == nullptr) {
MS_LOG(ERROR) << "malloc data failed";
return RET_NULL_PTR;
}
y_tops_ = reinterpret_cast<int *>(malloc(sizeof(int) * h));
if (y_tops_ == nullptr) {
MS_LOG(ERROR) << "malloc data failed";
return RET_NULL_PTR;
}
y_bottom_weights_ = reinterpret_cast<float *>(malloc(sizeof(float) * h));
if (y_bottom_weights_ == nullptr) {
MS_LOG(ERROR) << "malloc data failed";
return RET_NULL_PTR;
}

x_lefts_ = reinterpret_cast<int *>(malloc(sizeof(int) * w));
if (x_lefts_ == nullptr) {
MS_LOG(ERROR) << "malloc data failed";
return RET_NULL_PTR;
}
x_rights_ = reinterpret_cast<int *>(malloc(sizeof(int) * w));
if (x_rights_ == nullptr) {
MS_LOG(ERROR) << "malloc data failed";
return RET_NULL_PTR;
}
x_left_weights_ = reinterpret_cast<float *>(malloc(sizeof(float) * w));
if (x_left_weights_ == nullptr) {
MS_LOG(ERROR) << "malloc data failed";
return RET_NULL_PTR;
}
line_buffer_ = reinterpret_cast<float *>(malloc(sizeof(float) * w * c * 2 * context_->thread_num_));
if (line_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc data failed";
return RET_NULL_PTR;
auto ret = MallocTmpBuffer();
if (ret != RET_OK) {
FreeTmpBuffer();
return ret;
}

ret = ResizePrepare();
if (ret != RET_OK) {
FreeTmpBuffer();
return ret;
}
return RET_OK;
}

void ResizeCPUKernel::FreeTmpBuffer() {
if (y_bottoms_ != nullptr) {
free(y_bottoms_);
y_bottoms_ = nullptr;
// Bilinear interpolation :
// Bilinear interpolation considers the closest 2x2 neighborhood of known pixel values surrounding the unknown pixel.
// It takes a weighted average of these 4 pixels to arrive at its final interpolated value. Thus, we need to reserve
// twice bigger space than coordinates arrays for weight arrays. It means x_weight_len is twice as much as x_len in
// detail.
//
// Bicubic interpolation:
// Bicubic goes one step beyond bilinear by considering the closest 4x4 neighborhood of known pixels --- for a total of
// 16 pixels. Since these are at various distances from the unknown pixel, closer pixels are given a higher weighting in
// the calculation.
void ResizeCPUKernel::CalTmpBufferLen(int *x_len, int *y_len, int *x_weight_len, int *y_weight_len) {
if (method_ == static_cast<int>(schema::ResizeMethod_LINEAR)) {
*x_len = *x_weight_len = new_width_;
*y_len = *y_weight_len = new_height_;
}
if (method_ == static_cast<int>(schema::ResizeMethod_CUBIC)) {
*x_len = new_width_ * 2;
*y_len = new_height_ * 2;
*x_weight_len = new_width_ * 4;
*y_weight_len = new_height_ * 4;
}
if (y_tops_ != nullptr) {
free(y_tops_);
y_tops_ = nullptr;
}

// If resize method is bicubic, x_lefts_ array stores two elements (index - 1, index - 2) for every output coordinate
// index. For example, there is a 1-D output coordinate array:
// [0, 0.5, 1]
// now, search two elements at left and two at right for every position in output array.
// Thus, x_lefts_ array looks like :
// x_lefts_ [-2, -1, -1.5, -0.5, -1, 0]
// \ / \ / \ /
// \ / \ / \/
// corresponding to index : 0 0.5 1
// Apply to x_rights_ array by the same way.
int ResizeCPUKernel::MallocTmpBuffer() {
// make sure y_bottoms_, y_tops_, etc. are null before malloc
FreeTmpBuffer();

int x_len = 0, y_len = 0, x_weight_len = 0, y_weight_len = 0;
CalTmpBufferLen(&x_len, &y_len, &x_weight_len, &y_weight_len);

// malloc memory for x, y coordinates
{
coordinate_.x_lefts_ = reinterpret_cast<int *>(malloc(sizeof(int) * x_len));
CHECK_MALLOC_RES(coordinate_.x_lefts_, RET_NULL_PTR)
coordinate_.x_rights_ = reinterpret_cast<int *>(malloc(sizeof(int) * x_len));
CHECK_MALLOC_RES(coordinate_.x_rights_, RET_NULL_PTR)
coordinate_.y_tops_ = reinterpret_cast<int *>(malloc(sizeof(int) * y_len));
CHECK_MALLOC_RES(coordinate_.y_tops_, RET_NULL_PTR)
coordinate_.y_bottoms_ = reinterpret_cast<int *>(malloc(sizeof(int) * y_len));
CHECK_MALLOC_RES(coordinate_.y_bottoms_, RET_NULL_PTR)
}
if (y_bottom_weights_ != nullptr) {
free(y_bottom_weights_);
y_bottom_weights_ = nullptr;

// malloc memory for weights of x, y axes
{
x_weights_ = reinterpret_cast<float *>(malloc(sizeof(float) * x_weight_len));
CHECK_MALLOC_RES(x_weights_, RET_NULL_PTR)
y_weights_ = reinterpret_cast<float *>(malloc(sizeof(float) * y_weight_len));
CHECK_MALLOC_RES(y_weights_, RET_NULL_PTR)
}

if (x_lefts_ != nullptr) {
free(x_lefts_);
x_lefts_ = nullptr;
{
line_buffer_ = reinterpret_cast<float *>(
malloc(sizeof(float) * x_len * in_tensors_.at(0)->Channel() * 2 * context_->thread_num_));
CHECK_MALLOC_RES(line_buffer_, RET_NULL_PTR)
}
if (x_rights_ != nullptr) {
free(x_rights_);
x_rights_ = nullptr;
return RET_OK;
}

void ResizeCPUKernel::FreeTmpBuffer() {
coordinate_.FreeData();
if (y_weights_ != nullptr) {
free(y_weights_);
y_weights_ = nullptr;
}
if (x_left_weights_ != nullptr) {
free(x_left_weights_);
x_left_weights_ = nullptr;
if (x_weights_ != nullptr) {
free(x_weights_);
x_weights_ = nullptr;
}
if (line_buffer_ != nullptr) {
free(line_buffer_);
@@ -167,18 +169,12 @@ int ResizeImpl(void *cdata, int task_id) {
int ResizeCPUKernel::RunImpl(int task_id) {
auto input = in_tensors_.at(0);
auto input_data = reinterpret_cast<float *>(input->data_c());
if (input_data == nullptr) {
return RET_NULL_PTR;
}
auto output_data = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
if (output_data == nullptr) {
return RET_NULL_PTR;
}
MSLITE_CHECK_PTR(context_);
MSLITE_CHECK_PTR(input_data);
MSLITE_CHECK_PTR(output_data);

auto input_shape = input->shape();
if (context_ == nullptr) {
return RET_NULL_PTR;
}
int ret = 0;
switch (method_) {
case static_cast<int>(schema::ResizeMethod_LINEAR): {
int unit = UP_DIV(new_height_, context_->thread_num_);
@@ -187,22 +183,29 @@ int ResizeCPUKernel::RunImpl(int task_id) {
int c = in_tensors_.at(0)->shape().at(3);
float *line0 = line_buffer_ + new_width_ * c * 2 * task_id;
float *line1 = line0 + new_width_ * c;
ret =
ResizeBilinear(input_data, output_data, input_shape.data(), out_tensors_.at(0)->shape().data(), y_bottoms_,
y_tops_, x_lefts_, x_rights_, y_bottom_weights_, x_left_weights_, line0, line1, h_begin, h_end);
break;
return ResizeBilinear(input_data, output_data, input_shape.data(), out_tensors_.at(0)->shape().data(),
coordinate_.y_bottoms_, coordinate_.y_tops_, coordinate_.x_lefts_, coordinate_.x_rights_,
y_weights_, x_weights_, line0, line1, h_begin, h_end);
}
case static_cast<int>(schema::ResizeMethod_NEAREST): {
ret = ResizeNearestNeighbor(input_data, output_data, input_shape.data(), out_tensors_[0]->shape().data(),
calculate_, coordinate_transform_mode_, task_id, context_->thread_num_);
break;
return ResizeNearestNeighbor(input_data, output_data, input_shape.data(), out_tensors_[0]->shape().data(),
calculate_, coordinate_transform_mode_, task_id, context_->thread_num_);
}
case static_cast<int>(schema::ResizeMethod_CUBIC): {
int unit = UP_DIV(new_height_, context_->thread_num_);
int h_begin = unit * task_id;
int h_end = std::min(h_begin + unit, new_height_);
int c = in_tensors_.at(0)->Channel();
float *line_buffer = line_buffer_ + new_width_ * c * 4 * task_id;
return ResizeBicubic(input_data, output_data, input_shape.data(), out_tensors_.at(0)->shape().data(),
coordinate_.y_bottoms_, coordinate_.y_tops_, coordinate_.x_lefts_, coordinate_.x_rights_,
y_weights_, x_weights_, line_buffer, h_begin, h_end);
}
default: {
MS_LOG(ERROR) << "Resize unknown method " << method_;
ret = RET_ERROR;
return RET_ERROR;
}
}
return ret;
}

int ResizeCPUKernel::Run() {
@@ -215,5 +218,39 @@ int ResizeCPUKernel::Run() {
return RET_OK;
}

int ResizeCPUKernel::ResizePrepare() {
auto input_shape = in_tensors_.at(0)->shape();
if (method_ == static_cast<int>(schema::ResizeMethod_LINEAR)) {
return PrepareResizeBilinear(input_shape.data(), out_tensors_.at(0)->shape().data(), calculate_,
coordinate_.y_bottoms_, coordinate_.y_tops_, coordinate_.x_lefts_,
coordinate_.x_rights_, y_weights_, x_weights_);
}
if (method_ == static_cast<int>(schema::ResizeMethod_CUBIC)) {
auto cubic_coeff = reinterpret_cast<ResizeParameter *>(op_parameter_)->cubic_coeff_;
return PrepareResizeBicubic(input_shape.data(), out_tensors_.at(0)->shape().data(), calculate_,
coordinate_.y_bottoms_, coordinate_.y_tops_, coordinate_.x_lefts_,
coordinate_.x_rights_, y_weights_, x_weights_, cubic_coeff);
}
return RET_OK;
}

int ResizeCPUKernel::SelectCalculatorFunc() {
std::map<int, CalculateOriginalCoordinate> cal_fuc_list = {
std::make_pair(CoordinateTransformMode_ASYMMETRIC, CalculateAsymmetric),
std::make_pair(CoordinateTransformMode_ALIGN_CORNERS, CalculateAlignCorners),
std::make_pair(CoordinateTransformMode_HALF_PIXEL, CalculateHalfPixel)};

auto fun_pair = cal_fuc_list.find(coordinate_transform_mode_);
if (fun_pair != cal_fuc_list.end()) {
calculate_ = fun_pair->second;
} else {
MS_LOG(ERROR) << "Do not support coordinate transform mode. Mode is"
<< schema::EnumNameCoordinateTransformMode(
static_cast<schema::CoordinateTransformMode>(coordinate_transform_mode_));
return RET_ERROR;
}
return RET_OK;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Resize, LiteKernelCreator<ResizeCPUKernel>)
} // namespace mindspore::kernel

+ 40
- 7
mindspore/lite/src/runtime/kernel/arm/fp32/resize_fp32.h View File

@@ -23,6 +23,39 @@
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/resize_base.h"

struct ResizeCoordinate {
int *x_lefts_;
int *x_rights_;
int *y_tops_;
int *y_bottoms_;

ResizeCoordinate() {
x_lefts_ = nullptr;
x_rights_ = nullptr;
y_tops_ = nullptr;
y_bottoms_ = nullptr;
}

void FreeData() {
if (x_lefts_ != nullptr) {
free(x_lefts_);
x_lefts_ = nullptr;
}
if (x_rights_ != nullptr) {
free(x_rights_);
x_rights_ = nullptr;
}
if (y_tops_ != nullptr) {
free(y_tops_);
y_tops_ = nullptr;
}
if (y_bottoms_ != nullptr) {
free(y_bottoms_);
y_bottoms_ = nullptr;
}
}
};

namespace mindspore::kernel {
class ResizeCPUKernel : public ResizeBaseCPUKernel {
public:
@@ -30,22 +63,22 @@ class ResizeCPUKernel : public ResizeBaseCPUKernel {
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: ResizeBaseCPUKernel(parameter, inputs, outputs, ctx) {}

~ResizeCPUKernel() { FreeTmpBuffer(); }
~ResizeCPUKernel() override { FreeTmpBuffer(); }

int Init() override;
int ReSize() override;
int Run() override;
virtual int RunImpl(int task_id);
int SelectCalculatorFunc();
int ResizePrepare();
void CalTmpBufferLen(int *x_len, int *y_len, int *x_weight_len, int *y_weight_len);
int MallocTmpBuffer();
void FreeTmpBuffer();

protected:
int *y_tops_ = nullptr;
int *y_bottoms_ = nullptr;
int *x_lefts_ = nullptr;
int *x_rights_ = nullptr;
float *y_bottom_weights_ = nullptr;
float *x_left_weights_ = nullptr;
ResizeCoordinate coordinate_;
float *y_weights_ = nullptr;
float *x_weights_ = nullptr;
float *line_buffer_ = nullptr;
CalculateOriginalCoordinate calculate_ = nullptr;
};


+ 5
- 0
mindspore/lite/tools/converter/parser/tf/tf_resize_parser.cc View File

@@ -30,15 +30,20 @@ ops::PrimitiveC *TFResizeParser::Parse(const tensorflow::NodeDef &tf_op,

tensorflow::AttrValue attr_value;
prim->set_format(mindspore::Format::NHWC);
prim->set_cubic_coeff(-0.75f);
if (!TensorFlowUtils::FindAttrValue(tf_op, "align_corners", &attr_value)) {
MS_LOG(ERROR) << "The align_corners attr should be specified";
return nullptr;
}
if (attr_value.b()) {
prim->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::ALIGN_CORNERS);
} else if (TensorFlowUtils::FindAttrValue(tf_op, "half_pixel_centers", &attr_value) && attr_value.b()) {
prim->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::HALF_PIXEL);
prim->set_cubic_coeff(-0.5f);
} else {
prim->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::ASYMMETRIC);
}

if (tf_op.op() == "ResizeBilinear") {
prim->set_method(mindspore::ResizeMethod::LINEAR);
} else if (tf_op.op() == "ResizeNearestNeighbor") {


+ 4
- 2
mindspore/lite/tools/converter/parser/tflite/tflite_resize_parser.cc View File

@@ -38,6 +38,8 @@ ops::PrimitiveC *TfliteResizeParser::Parse(const std::unique_ptr<tflite::Operato
MS_LOG(ERROR) << "tflite_subgraph is nullptr";
return nullptr;
}
prim->set_cubic_coeff(-0.75f);
prim->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::ASYMMETRIC);
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
if (tflite_op_type == tflite::BuiltinOperator_RESIZE_BILINEAR) {
MS_LOG(DEBUG) << "parse TfliteResizeBilinearParser";
@@ -50,8 +52,8 @@ ops::PrimitiveC *TfliteResizeParser::Parse(const std::unique_ptr<tflite::Operato
prim->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::ALIGN_CORNERS);
}
if (tfliteAttr->half_pixel_centers) {
MS_LOG(ERROR) << "Does not support half pixel centers";
return nullptr;
prim->set_coordinate_transform_mode(mindspore::CoordinateTransformMode::HALF_PIXEL);
prim->set_cubic_coeff(-0.5f);
}
prim->set_method(mindspore::ResizeMethod::LINEAR);
} else if (tflite_op_type == tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR) {


Loading…
Cancel
Save