Browse Source

fp32 transpose and spacetobatch multi-thread

tags/v0.7.0-beta
zhongligeng 5 years ago
parent
commit
0366b70c3b
10 changed files with 386 additions and 77 deletions
  1. +47
    -9
      mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.cc
  2. +6
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.h
  3. +45
    -12
      mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc
  4. +10
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h
  5. +10
    -20
      mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_batch.c
  6. +5
    -3
      mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h
  7. +28
    -24
      mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.c
  8. +9
    -5
      mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.h
  9. +6
    -2
      mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc
  10. +220
    -0
      mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc

+ 47
- 9
mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.cc View File

@@ -21,6 +21,7 @@
#include "src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h"
#include "src/runtime/kernel/arm/nnacl/errorcode.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
@@ -47,6 +48,31 @@ int SpaceToBatchCPUKernel::Init() {
return ReSize();
}

int SpaceToBatchCPUKernel::SpaceToBatchParallel(int task_id) {
int num_unit_thread = MSMIN(thread_h_stride_, num_unit_ - task_id * thread_h_stride_);
if (num_unit_thread <= 0) {
return RET_OK;
}
int thread_offset = task_id * thread_h_stride_;
SpaceToBatchParameter *param = reinterpret_cast<SpaceToBatchParameter *>(this->op_parameter_);
auto ret = SpaceToBatch(input_ptr_, output_ptr_, *param, thread_offset, thread_offset + num_unit_thread);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SpaceToDepth error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}

int SpaceToBatchRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto g_kernel = reinterpret_cast<SpaceToBatchCPUKernel *>(cdata);
auto ret = g_kernel->SpaceToBatchParallel(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SpaceToBatchRun error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_OP_EXECUTE_FAILURE;
}
return RET_OK;
}

int SpaceToBatchCPUKernel::ReSize() {
if (in_tensors_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "space_to_batch only support NHWC now!";
@@ -55,6 +81,10 @@ int SpaceToBatchCPUKernel::ReSize() {
SpaceToBatchParameter *param = reinterpret_cast<SpaceToBatchParameter *>(this->op_parameter_);
param->num_elements_ = EnumElement(param->in_shape_, param->n_dims_);
param->num_elements_padded_ = EnumElement(param->padded_in_shape_, param->n_dims_);
num_unit_ = static_cast<int>(in_tensors_[kInputIndex]->shape().at(kNHWC_H));
num_unit_ /= param->block_sizes_[0];
thread_h_num_ = MSMIN(thread_num_, num_unit_);
thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_);
return RET_OK;
}

@@ -81,20 +111,28 @@ int SpaceToBatchCPUKernel::Run() {
return RET_ERROR;
}
}
ret = SpaceToBatch(input_ptr_, output_ptr_, *param, tmp_space);
for (int i = 0; i < 3; ++i) {
context_->allocator->Free(tmp_space);
auto padded_input = tmp_space[0];
DoPadding(input_ptr_, padded_input, *param, tmp_space + 1);
input_ptr_ = padded_input;
}

if (input->GetFormat() == schema::Format_NHWC) {
ret = LiteBackendParallelLaunch(SpaceToBatchRun, this, thread_h_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SpaceToBatch error error_code[" << ret << "]";
}
} else {
ret = SpaceToBatch(input_ptr_, output_ptr_, *param, tmp_space);
MS_LOG(ERROR) << "Only support NHWC now!";
ret = RET_FORMAT_ERR;
}
if (ret != NNACL_OK) {
MS_LOG(ERROR) << "Do space to batch fails!";
return RET_OP_EXECUTE_FAILURE;
if (param->need_paddings_) {
for (int i = 0; i < 3; ++i) {
context_->allocator->Free(tmp_space[i]);
}
}

return RET_OK;
}
return ret;
} // namespace mindspore::kernel

kernel::LiteKernel *CpuSpaceToBatchFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,


+ 6
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch.h View File

@@ -25,15 +25,20 @@ class SpaceToBatchCPUKernel : public LiteKernel {
SpaceToBatchCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) {}

~SpaceToBatchCPUKernel() = default;

int Init() override;
int ReSize() override;
int Run() override;
int SpaceToBatchParallel(int task_id);

private:
int thread_num_;
int thread_h_stride_;
int thread_h_num_;
int num_unit_;
const float *input_ptr_;
float *output_ptr_;
};


+ 45
- 12
mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc View File

@@ -20,10 +20,12 @@
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::RET_OP_EXECUTE_FAILURE;
using mindspore::schema::PrimitiveType_Transpose;

namespace mindspore::kernel {
@@ -32,6 +34,10 @@ constexpr int kTransposeInputNum = 1;
constexpr int kTransposeOutputNum = 1;
} // namespace
int TransposeCPUKernel::Init() {
TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
num_unit_ = static_cast<int>(in_tensors_[kInputIndex]->shape().at(param->perm_[kNHWC_H]));
thread_h_num_ = MSMIN(thread_num_, num_unit_);
thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_);
if (!InferShapeDone()) {
return RET_OK;
}
@@ -54,6 +60,32 @@ int TransposeCPUKernel::ReSize() {
return RET_OK;
}

int TransposeCPUKernel::TransposeParallel(int task_id) {
int num_unit_thread = MSMIN(thread_h_stride_, num_unit_ - task_id * thread_h_stride_);
if (num_unit_thread <= 0) {
return RET_OK;
}
int thread_offset = task_id * thread_h_stride_;
TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
auto ret =
DoTranspose(in_data_, out_data_, in_shape_, out_shape_, param, thread_offset, thread_offset + num_unit_thread);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Transpose error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}

int TransposeRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto g_kernel = reinterpret_cast<TransposeCPUKernel *>(cdata);
auto ret = g_kernel->TransposeParallel(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "TransposeRun error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_OP_EXECUTE_FAILURE;
}
return RET_OK;
}

int TransposeCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
@@ -62,23 +94,24 @@ int TransposeCPUKernel::Run() {
}
MS_ASSERT(in_tensors_.size() == TransposeInputNum);
MS_ASSERT(out_tensors_.size() == TransposeOutputNum);
auto &inTensor = in_tensors_.front();
auto &outTensor = out_tensors_.front();
if (inTensor == nullptr || outTensor == nullptr) {
auto &in_tensor = in_tensors_.front();
auto &out_tensor = out_tensors_.front();
if (in_tensor == nullptr || out_tensor == nullptr) {
MS_LOG(ERROR) << "null pointer dreferencing.";
return RET_ERROR;
}
auto *in_data = static_cast<float *>(inTensor->Data());
auto *out_data = static_cast<float *>(outTensor->Data());
auto in_shape = inTensor->shape();
auto out_shape = outTensor->shape();
auto *input_shape = &in_shape.front();
auto *output_shape = &out_shape.front();
in_data_ = reinterpret_cast<float *>(in_tensor->Data());
out_data_ = reinterpret_cast<float *>(out_tensor->Data());
in_shape_ = const_cast<int *>(in_tensor->shape().data());
out_shape_ = const_cast<int *>(out_tensor->shape().data());

ret =
DoTranspose(in_data, out_data, input_shape, output_shape, reinterpret_cast<TransposeParameter *>(op_parameter_));
ret = LiteBackendParallelLaunch(TransposeRun, this, thread_h_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Tranpose error error_code[" << ret << "]";
return ret;
}
return ret;
}
} // namespace mindspore::kernel

kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,


+ 10
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h View File

@@ -29,14 +29,23 @@ class TransposeCPUKernel : public LiteKernel {
explicit TransposeCPUKernel(OpParameter *param, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive)
: LiteKernel(param, inputs, outputs, ctx, primitive) {}
: LiteKernel(param, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) {}
~TransposeCPUKernel() override = default;

int Init() override;
int ReSize() override;
int Run() override;
int TransposeParallel(int task_id);

private:
int thread_num_;
int thread_h_stride_;
int thread_h_num_;
int num_unit_;
float *in_data_;
float *out_data_;
int *in_shape_;
int *out_shape_;
};
} // namespace mindspore::kernel



+ 10
- 20
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_batch.c View File

@@ -28,7 +28,7 @@ int EnumElement(int *shape, int n_dims) {
}

void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *out_strides, int *perm,
int *output_shape) {
int *output_shape, int h_start, int h_end) {
const int stride0 = strides[perm[0]];
const int stride1 = strides[perm[1]];
const int stride2 = strides[perm[2]];
@@ -40,7 +40,6 @@ void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *
const int out_stride3 = out_strides[3];
const int out_stride4 = out_strides[4];
const int output0 = output_shape[0];
const int output1 = output_shape[1];
const int output2 = output_shape[2];
const int output3 = output_shape[3];
const int output4 = output_shape[4];
@@ -48,7 +47,7 @@ void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *
for (int i = 0; i < output0; ++i) {
int out_stride0_i = i * out_stride0;
int stride0_i = i * stride0;
for (int j = 0; j < output1; ++j) {
for (int j = h_start; j < h_end; ++j) {
int out_stride1_j = j * out_stride1;
int stride1_j = j * stride1;
for (int k = 0; k < output2; ++k) {
@@ -69,7 +68,8 @@ void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *
}
}

int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_sizes) {
int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_sizes, int h_start,
int h_end) {
int trans_in_shape[6] = {in_shape[0], in_shape[1] / block_sizes[0],
block_sizes[0], in_shape[2] / block_sizes[1],
block_sizes[1], in_shape[3]};
@@ -82,7 +82,7 @@ int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int sh
ComputeStrides(trans_out_shape, out_strides, shape_size + 2);

int perm[6] = {0, 2, 4, 1, 3, 5};
TransposeForNHWC(input, output, in_strides, out_strides, perm, trans_out_shape);
TransposeForNHWC(input, output, in_strides, out_strides, perm, trans_out_shape, h_start, h_end);
return NNACL_OK;
}

@@ -137,21 +137,11 @@ void DoPadding(const float *input, float *padded_input, SpaceToBatchParameter pa
}
}

int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, float *tmp_space[3]) {
float *padded_input = NULL;
int ret;
if (param.need_paddings_) {
if (tmp_space[0] == NULL || tmp_space[1] == NULL || tmp_space[2] == NULL) {
return NNACL_NULL_PTR;
}
padded_input = tmp_space[0];
DoPadding(input, padded_input, param, tmp_space + 1);
}

if (param.need_paddings_) {
ret = SpaceToBatchForNHWC(padded_input, output, param.padded_in_shape_, param.n_dims_, param.block_sizes_);
} else {
ret = SpaceToBatchForNHWC(input, output, param.padded_in_shape_, param.n_dims_, param.block_sizes_);
int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, int h_start, int h_end) {
if (input == NULL || output == NULL) {
return NNACL_NULL_PTR;
}
auto ret =
SpaceToBatchForNHWC(input, output, param.padded_in_shape_, param.n_dims_, param.block_sizes_, h_start, h_end);
return ret;
}

+ 5
- 3
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h View File

@@ -35,10 +35,12 @@ typedef struct SpaceToBatchParameter {
#ifdef __cplusplus
extern "C" {
#endif
int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, float *tmp_space[3]);
int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_size);
int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, int h_start, int h_end);
int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_size, int h_start,
int h_end);
void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *out_strides, int *perm,
int *output_shape);
int *output_shape, int h_start, int h_end);
void DoPadding(const float *input, float *padded_input, SpaceToBatchParameter param, float *tmp_space[]);
int EnumElement(int *shape, int n_dims);
#ifdef __cplusplus
}


+ 28
- 24
mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.c View File

@@ -18,21 +18,23 @@
#include <string.h>
#include "nnacl/errorcode.h"

void TransposeDim2(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) {
void TransposeDim2(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
int h_start, int h_end) {
const int stride0 = strides[perm[0]];
const int stride1 = strides[perm[1]];
const int output0 = output_shape[0];
const int output1 = output_shape[1];
for (int i = 0; i < output0; i++) {
for (int i = 0; i < output0; ++i) {
int out_stride0_i = i * output1;
int stride0_i = i * 1 * stride0;
for (int j = 0; j < output1; j++) {
for (int j = 0; j < output1; ++j) {
out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1];
}
}
}

void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) {
void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
int h_start, int h_end) {
const int stride0 = strides[perm[0]];
const int stride1 = strides[perm[1]];
const int stride2 = strides[perm[2]];
@@ -41,20 +43,21 @@ void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strid
const int output0 = output_shape[0];
const int output1 = output_shape[1];
const int output2 = output_shape[2];
for (int i = 0; i < output0; i++) {
for (int i = 0; i < output0; ++i) {
int out_stride0_i = i * out_stride0;
int stride0_i = i * stride0;
for (int j = 0; j < output1; j++) {
for (int j = 0; j < output1; ++j) {
int out_stride1_j = j * out_stride1;
int stride1_j = j * stride1;
for (int k = 0; k < output2; k++) {
for (int k = 0; k < output2; ++k) {
out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2];
}
}
}
}

void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) {
void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
int h_start, int h_end) {
const int stride0 = strides[perm[0]];
const int stride1 = strides[perm[1]];
const int stride2 = strides[perm[2]];
@@ -67,16 +70,16 @@ void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strid
const int output2 = output_shape[2];
const int output3 = output_shape[3];

for (int i = 0; i < output0; i++) {
for (int i = 0; i < output0; ++i) {
int out_stride0_i = i * out_stride0;
int stride0_i = i * stride0;
for (int j = 0; j < output1; j++) {
for (int j = 0; j < output1; ++j) {
int out_stride1_j = j * out_stride1;
int stride1_j = j * stride1;
for (int k = 0; k < output2; k++) {
for (int k = 0; k < output2; ++k) {
int out_stride2_k = k * out_stride2;
int stride2_k = k * stride2;
for (int m = 0; m < output3; m++) {
for (int m = 0; m < output3; ++m) {
out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] =
in_data[stride0_i + stride1_j + stride2_k + m * stride3];
}
@@ -85,7 +88,8 @@ void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strid
}
}

void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) {
void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
int h_start, int h_end) {
const int stride0 = strides[perm[0]];
const int stride1 = strides[perm[1]];
const int stride2 = strides[perm[2]];
@@ -101,19 +105,19 @@ void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strid
const int output3 = output_shape[3];
const int output4 = output_shape[4];

for (int i = 0; i < output0; i++) {
for (int i = 0; i < output0; ++i) {
int out_stride0_i = i * out_stride0;
int stride0_i = i * stride0;
for (int j = 0; j < output1; j++) {
for (int j = 0; j < output1; ++j) {
int out_stride1_j = j * out_stride1;
int stride1_j = j * stride1;
for (int k = 0; k < output2; k++) {
for (int k = 0; k < output2; ++k) {
int out_stride2_k = k * out_stride2;
int stride2_k = k * stride2;
for (int m = 0; m < output3; m++) {
for (int m = 0; m < output3; ++m) {
int out_stride3_m = m * out_stride3;
int stride3_m = m * stride3;
for (int n = 0; n < output4; n++) {
for (int n = 0; n < output4; ++n) {
out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] =
in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4];
}
@@ -124,7 +128,7 @@ void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strid
}

int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_shape,
TransposeParameter *transpose_param) {
TransposeParameter *transpose_param, int h_start, int h_end) {
if (in_data == NULL || out_data == NULL) {
return NNACL_ERR;
}
@@ -140,7 +144,7 @@ int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_s

// check if transpose is needed
bool needTranspose = false;
for (int i = 1; i < num_axes; i++) {
for (int i = 1; i < num_axes; ++i) {
if (perm[i] - perm[i - 1] != 1) {
needTranspose = true;
break;
@@ -152,13 +156,13 @@ int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_s
return NNACL_OK;
}
if (num_axes == 2) {
TransposeDim2(in_data, out_data, strides, out_strides, perm, output_shape);
TransposeDim2(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end);
} else if (num_axes == 3) {
TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape);
TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end);
} else if (num_axes == 4) {
TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape);
TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end);
} else if (num_axes == 5) {
TransposeDim5(in_data, out_data, strides, out_strides, perm, output_shape);
TransposeDim5(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end);
}
return NNACL_OK;
}

+ 9
- 5
mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.h View File

@@ -33,11 +33,15 @@ typedef struct TransposeParameter {
extern "C" {
#endif
int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_shape,
TransposeParameter *transpose_param);
void TransposeDim2(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape);
void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape);
void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape);
void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape);
TransposeParameter *transpose_param, int h_start, int h_end);
void TransposeDim2(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
int h_start, int h_end);
void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
int h_start, int h_end);
void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
int h_start, int h_end);
void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
int h_start, int h_end);
#ifdef __cplusplus
}
#endif


+ 6
- 2
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/space_to_batch_fp32_tests.cc View File

@@ -85,7 +85,7 @@ TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest1) {
int in_shape[4] = {1, 4, 4, 1};
int out_shape[4] = {4, 2, 2, 1};
int block_sizes[2] = {2, 2};
SpaceToBatchForNHWC((const float *)input, output, in_shape, 4, block_sizes);
SpaceToBatchForNHWC((const float *)input, output, in_shape, 4, block_sizes, 0, 4 / 2);
for (int i = 0; i < out_size; ++i) {
std::cout << output[i] << " ";
}
@@ -107,7 +107,10 @@ TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest2) {

float padded_input[48]{}, tmp[48]{}, tmp_zero[48]{};
float *tmp_space[3] = {padded_input, tmp, tmp_zero};
auto ret = SpaceToBatch((const float *)input, output, param, tmp_space);
// DoPadding
DoPadding(input, padded_input, param, tmp_space + 1);

auto ret = SpaceToBatch((const float *)padded_input, output, param, 0, 4 / 2);
std::cout << "return " << ret << std::endl;
for (int i = 0; i < out_size; ++i) {
std::cout << output[i] << " ";
@@ -145,6 +148,7 @@ TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest3) {
outputs_tensor.emplace_back(&output_tensor);

lite::Context ctx;
ctx.thread_num_ = 2;
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_SpaceToBatch};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);


+ 220
- 0
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc View File

@@ -0,0 +1,220 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h"
#include "mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/src/lite_kernel.h"

namespace mindspore {

class TestTransposeFp32 : public mindspore::CommonTest {
public:
TestTransposeFp32() {}
};

TEST_F(TestTransposeFp32, TransposeFp32_axes4) {
/* 1x2x3x4 */
float in[24] = {-0.35779851, -0.4857257, 1.2791597, -0.36793608, 0.95098744, -0.12716428, 0.17405411, 0.42663834,
-1.11871315, 1.02777593, 1.20223761, 0.30183748, 1.39663453, -1.11923312, -1.02032341, 1.91074871,
1.52489095, -1.13020852, -0.66358529, 1.8033383, 0.62647028, 1.03094635, -1.65733338, 0.3952082};

float out[24] = {0};

float correct[24] = {-0.35779851, 1.39663453, 0.95098744, 1.52489095, -1.11871315, 0.62647028,
-0.4857257, -1.11923312, -0.12716428, -1.13020852, 1.02777593, 1.03094635,
1.2791597, -1.02032341, 0.17405411, -0.66358529, 1.20223761, -1.65733338,
-0.36793608, 1.91074871, 0.42663834, 1.8033383, 0.30183748, 0.3952082};

int input_shape[4] = {1, 2, 3, 4};
int output_shape[4] = {4, 3, 2, 1};
int perm[8] = {3, 2, 1, 0, 0, 0, 0, 0};
int strides[8] = {24, 12, 4, 1, 1, 1, 1, 1};
int out_strides[8] = {6, 2, 1, 1, 1, 1, 1, 1};

auto param = new (std::nothrow) TransposeParameter();
if (param == nullptr) {
MS_LOG(ERROR) << "New param fails.";
return;
}
param->num_axes_ = 4;
param->conjugate_ = false;
param->data_size_ = 24 * sizeof(float);

for (int i = 0; i < 8; i++) {
param->perm_[i] = perm[i];
param->strides_[i] = strides[i];
param->out_strides_[i] = out_strides[i];
}

auto ret = DoTranspose(in, out, input_shape, output_shape, param, 0, 3);
MS_ASSERT(ret == 0);
delete param;
CompareOutputData(out, correct, 24, 0.000001);
}

TEST_F(TestTransposeFp32, TransposeFp32_axes3) {
/* 2x3x4 */
float in[24] = {1.62434536, -0.61175641, -0.52817175, -1.07296862, 0.86540763, -2.3015387, 1.74481176, -0.7612069,
0.3190391, -0.24937038, 1.46210794, -2.06014071, -0.3224172, -0.38405435, 1.13376944, -1.09989127,
-0.17242821, -0.87785842, 0.04221375, 0.58281521, -1.10061918, 1.14472371, 0.90159072, 0.50249434};

float out[24] = {0};

float correct[24] = {1.62434536, -0.3224172, 0.86540763, -0.17242821, 0.3190391, -1.10061918,
-0.61175641, -0.38405435, -2.3015387, -0.87785842, -0.24937038, 1.14472371,
-0.52817175, 1.13376944, 1.74481176, 0.04221375, 1.46210794, 0.90159072,
-1.07296862, -1.09989127, -0.7612069, 0.58281521, -2.06014071, 0.50249434};

int input_shape[3] = {2, 3, 4};
int output_shape[3] = {4, 3, 2};
int perm[8] = {2, 1, 0, 0, 0, 0, 0, 0};
int strides[8] = {12, 4, 1, 1, 1, 1, 1, 1};
int out_strides[8] = {6, 2, 1, 1, 1, 1, 1, 1};

auto param = new (std::nothrow) TransposeParameter();
if (param == nullptr) {
MS_LOG(ERROR) << "New param fails.";
return;
}
param->num_axes_ = 3;
param->conjugate_ = false;
param->data_size_ = 24 * sizeof(float);

for (int i = 0; i < 8; i++) {
param->perm_[i] = perm[i];
param->strides_[i] = strides[i];
param->out_strides_[i] = out_strides[i];
}

auto ret = DoTranspose(in, out, input_shape, output_shape, param, 0, 3);
MS_ASSERT(ret == 0);
delete param;
CompareOutputData(out, correct, 24, 0.000001);
}

TEST_F(TestTransposeFp32, TransposeFp32_axes2) {
/* 6x4 */
float in[24] = {1.62434536, -0.61175641, -0.52817175, -1.07296862, 0.86540763, -2.3015387, 1.74481176, -0.7612069,
0.3190391, -0.24937038, 1.46210794, -2.06014071, -0.3224172, -0.38405435, 1.13376944, -1.09989127,
-0.17242821, -0.87785842, 0.04221375, 0.58281521, -1.10061918, 1.14472371, 0.90159072, 0.50249434};

float out[24] = {0};

float correct[24] = {1.62434536, 0.86540763, 0.3190391, -0.3224172, -0.17242821, -1.10061918,
-0.61175641, -2.3015387, -0.24937038, -0.38405435, -0.87785842, 1.14472371,
-0.52817175, 1.74481176, 1.46210794, 1.13376944, 0.04221375, 0.90159072,
-1.07296862, -0.7612069, -2.06014071, -1.09989127, 0.58281521, 0.50249434};

int input_shape[2] = {6, 4};
int output_shape[2] = {4, 6};
int perm[8] = {1, 0, 0, 0, 0, 0, 0, 0};
int strides[8] = {4, 1, 1, 1, 1, 1, 1, 1};
int out_strides[8] = {6, 1, 1, 1, 1, 1, 1, 1};

auto param = new (std::nothrow) TransposeParameter();
if (param == nullptr) {
MS_LOG(ERROR) << "New param fails.";
return;
}

param->num_axes_ = 2;
param->conjugate_ = false;
param->data_size_ = 24 * sizeof(float);

for (int i = 0; i < 8; i++) {
param->perm_[i] = perm[i];
param->strides_[i] = strides[i];
param->out_strides_[i] = out_strides[i];
}

auto ret = DoTranspose(in, out, input_shape, output_shape, param, 0, 6);
MS_ASSERT(ret == 0);
delete param;
CompareOutputData(out, correct, 24, 0.000001);
}

TEST_F(TestTransposeFp32, TransposeFp32_test5) {
/* 1x2x3x2x2 */
std::vector<float> input = {1.62434536, -0.61175641, -0.52817175, -1.07296862, 0.86540763, -2.3015387,
1.74481176, -0.7612069, 0.3190391, -0.24937038, 1.46210794, -2.06014071,
-0.3224172, -0.38405435, 1.13376944, -1.09989127, -0.17242821, -0.87785842,
0.04221375, 0.58281521, -1.10061918, 1.14472371, 0.90159072, 0.50249434};

float correct[24] = {1.62434536, -0.3224172, 0.86540763, -0.17242821, 0.3190391, -1.10061918,
-0.52817175, 1.13376944, 1.74481176, 0.04221375, 1.46210794, 0.90159072,
-0.61175641, -0.38405435, -2.3015387, -0.87785842, -0.24937038, 1.14472371,
-1.07296862, -1.09989127, -0.7612069, 0.58281521, -2.06014071, 0.50249434};

std::vector<float> output(24);
std::vector<int> input_shape = {1, 2, 3, 2, 2};
std::vector<int> output_shape = {2, 2, 3, 2, 1};
int perm[8] = {4, 3, 2, 1, 0, 0, 0, 0};
int strides[8] = {24, 12, 4, 2, 1, 1, 1, 1};
int out_strides[8] = {12, 6, 2, 1, 1, 1, 1, 1};

TransposeParameter param;
param.op_parameter_.type_ = schema::PrimitiveType_Transpose;
param.num_axes_ = 5;
param.conjugate_ = false;
param.data_size_ = 24 * sizeof(float);

for (int i = 0; i < 8; i++) {
param.perm_[i] = perm[i];
param.strides_[i] = strides[i];
param.out_strides_[i] = out_strides[i];
}

lite::tensor::Tensor input_tensor;
input_tensor.SetData(input.data());
input_tensor.set_shape(input_shape);
input_tensor.SetFormat(schema::Format_NHWC);
input_tensor.set_data_type(kNumberTypeFloat32);
std::vector<lite::tensor::Tensor *> inputs_tensor;
inputs_tensor.emplace_back(&input_tensor);

lite::tensor::Tensor output_tensor;
output_tensor.SetData(output.data());
output_tensor.set_shape(output_shape);
output_tensor.SetFormat(schema::Format_NHWC);
output_tensor.set_data_type(kNumberTypeFloat32);
std::vector<lite::tensor::Tensor *> outputs_tensor;
outputs_tensor.emplace_back(&output_tensor);

lite::Context ctx;
ctx.thread_num_ = 2;
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Transpose};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
kernel::LiteKernel *kernel =
creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&param), &ctx, desc, nullptr);
ASSERT_NE(kernel, nullptr);
kernel->Run();

for (int i = 0; i < 24; ++i) {
std::cout << output[i] << " ";
}
std::cout << "\n";
CompareOutputData(output.data(), correct, 24, 0.000001);
input_tensor.SetData(nullptr);
output_tensor.SetData(nullptr);
}

} // namespace mindspore

Loading…
Cancel
Save