@@ -21,6 +21,7 @@
 #include "src/runtime/kernel/arm/nnacl/fp32/space_to_batch.h"
 #include "src/runtime/kernel/arm/nnacl/errorcode.h"
 #include "include/errorcode.h"
+#include "src/runtime/runtime_api.h"
 
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
@@ -47,6 +48,31 @@ int SpaceToBatchCPUKernel::Init() {
   return ReSize();
 }
 
+int SpaceToBatchCPUKernel::SpaceToBatchParallel(int task_id) {
+  int num_unit_thread = MSMIN(thread_h_stride_, num_unit_ - task_id * thread_h_stride_);
+  if (num_unit_thread <= 0) {
+    return RET_OK;
+  }
+  int thread_offset = task_id * thread_h_stride_;
+  SpaceToBatchParameter *param = reinterpret_cast<SpaceToBatchParameter *>(this->op_parameter_);
+  auto ret = SpaceToBatch(input_ptr_, output_ptr_, *param, thread_offset, thread_offset + num_unit_thread);
+  if (ret != RET_OK) {
| MS_LOG(ERROR) << "SpaceToDepth error task_id[" << task_id << "] error_code[" << ret << "]"; | |||
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int SpaceToBatchRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
+  auto g_kernel = reinterpret_cast<SpaceToBatchCPUKernel *>(cdata);
+  auto ret = g_kernel->SpaceToBatchParallel(task_id);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "SpaceToBatchRun error task_id[" << task_id << "] error_code[" << ret << "]";
+    return RET_OP_EXECUTE_FAILURE;
+  }
+  return RET_OK;
+}
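// How the launch above splits the work: each task takes a contiguous run of
// thread_h_stride_ units, and the MSMIN/early-return pair in
// SpaceToBatchParallel trims the tail. A sketch of the same arithmetic,
// assuming UP_DIV and MSMIN are the usual ceiling-division and minimum macros:
//
//   int num_unit = 7, thread_num = 3;
//   int thread_h_num = MSMIN(thread_num, num_unit);        // 3 tasks
//   int thread_h_stride = UP_DIV(num_unit, thread_h_num);  // ceil(7 / 3) = 3
//   // task 0 covers units [0, 3), task 1 [3, 6), task 2 [6, 7);
//   // a task whose offset reaches num_unit sees num_unit_thread <= 0 and returns.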
 
 int SpaceToBatchCPUKernel::ReSize() {
   if (in_tensors_[0]->GetFormat() != schema::Format_NHWC) {
     MS_LOG(ERROR) << "space_to_batch only supports NHWC now!";
@@ -55,6 +81,10 @@ int SpaceToBatchCPUKernel::ReSize() {
   SpaceToBatchParameter *param = reinterpret_cast<SpaceToBatchParameter *>(this->op_parameter_);
   param->num_elements_ = EnumElement(param->in_shape_, param->n_dims_);
   param->num_elements_padded_ = EnumElement(param->padded_in_shape_, param->n_dims_);
+  num_unit_ = static_cast<int>(in_tensors_[kInputIndex]->shape().at(kNHWC_H));
+  num_unit_ /= param->block_sizes_[0];
+  thread_h_num_ = MSMIN(thread_num_, num_unit_);
+  thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_);
   return RET_OK;
 }
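// Note: num_unit_ here is the input height divided by the vertical block
// size, i.e. how many independent row groups the blocked copy produces.
// For example, H = 4 with block_sizes_[0] = 2 gives num_unit_ = 2, so at
// most two tasks can do useful work regardless of thread_num_.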
@@ -81,20 +111,28 @@ int SpaceToBatchCPUKernel::Run() {
       return RET_ERROR;
     }
+    auto padded_input = tmp_space[0];
+    DoPadding(input_ptr_, padded_input, *param, tmp_space + 1);
+    input_ptr_ = padded_input;
   }
-  ret = SpaceToBatch(input_ptr_, output_ptr_, *param, tmp_space);
-  if (param->need_paddings_) {
-    for (int i = 0; i < 3; ++i) {
-      context_->allocator->Free(tmp_space[i]);
-    }
-  }
-  if (ret != NNACL_OK) {
-    MS_LOG(ERROR) << "Do space to batch fails!";
-    return RET_OP_EXECUTE_FAILURE;
-  }
-  return RET_OK;
+  if (input->GetFormat() == schema::Format_NHWC) {
+    ret = LiteBackendParallelLaunch(SpaceToBatchRun, this, thread_h_num_);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "SpaceToBatch error error_code[" << ret << "]";
+    }
+  } else {
+    MS_LOG(ERROR) << "Only supports NHWC now!";
+    ret = RET_FORMAT_ERR;
+  }
+  if (param->need_paddings_) {
+    for (int i = 0; i < 3; ++i) {
+      context_->allocator->Free(tmp_space[i]);
+    }
+  }
+  return ret;
 }
 }  // namespace mindspore::kernel
 
 kernel::LiteKernel *CpuSpaceToBatchFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                      const std::vector<lite::tensor::Tensor *> &outputs,
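// Sequencing note on the rewritten Run(): DoPadding executes once, on the
// calling thread, before the parallel launch, so every task reads the same
// padded buffer and concurrent tasks only write their own [h_start, h_end)
// slice of the output. The tmp_space buffers are freed only after the launch
// has returned, for the same reason:
//
//   DoPadding(input_ptr_, tmp_space[0], *param, tmp_space + 1);       // once
//   input_ptr_ = tmp_space[0];
//   LiteBackendParallelLaunch(SpaceToBatchRun, this, thread_h_num_);  // N tasks
//   // ... free tmp_space[0..2] only after all tasks are done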
@@ -25,15 +25,20 @@ class SpaceToBatchCPUKernel : public LiteKernel {
   SpaceToBatchCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                         const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                         const lite::Primitive *primitive)
-      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) {}
   ~SpaceToBatchCPUKernel() = default;
 
   int Init() override;
   int ReSize() override;
   int Run() override;
+  int SpaceToBatchParallel(int task_id);
 
  private:
+  int thread_num_;
+  int thread_h_stride_;
+  int thread_h_num_;
+  int num_unit_;
   const float *input_ptr_;
   float *output_ptr_;
 };
@@ -20,10 +20,12 @@
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
+#include "src/runtime/runtime_api.h"
 
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
+using mindspore::lite::RET_OP_EXECUTE_FAILURE;
 using mindspore::schema::PrimitiveType_Transpose;
 
 namespace mindspore::kernel {
@@ -32,6 +34,10 @@ constexpr int kTransposeInputNum = 1;
 constexpr int kTransposeOutputNum = 1;
 }  // namespace
 int TransposeCPUKernel::Init() {
+  TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
+  num_unit_ = static_cast<int>(in_tensors_[kInputIndex]->shape().at(param->perm_[kNHWC_H]));
+  thread_h_num_ = MSMIN(thread_num_, num_unit_);
+  thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_);
   if (!InferShapeDone()) {
     return RET_OK;
   }
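// Note on the unit count above: since output_shape[k] == input_shape[perm[k]]
// for a transpose, in_tensors_[kInputIndex]->shape().at(param->perm_[kNHWC_H])
// (kNHWC_H == 1) is exactly the extent of output axis 1 -- the axis that the
// h_start/h_end bounds in nnacl/transpose.c slice across tasks.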
@@ -54,6 +60,32 @@ int TransposeCPUKernel::ReSize() {
   return RET_OK;
 }
 
+int TransposeCPUKernel::TransposeParallel(int task_id) {
+  int num_unit_thread = MSMIN(thread_h_stride_, num_unit_ - task_id * thread_h_stride_);
+  if (num_unit_thread <= 0) {
+    return RET_OK;
+  }
+  int thread_offset = task_id * thread_h_stride_;
+  TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
+  auto ret =
+      DoTranspose(in_data_, out_data_, in_shape_, out_shape_, param, thread_offset, thread_offset + num_unit_thread);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Transpose error task_id[" << task_id << "] error_code[" << ret << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int TransposeRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
+  auto g_kernel = reinterpret_cast<TransposeCPUKernel *>(cdata);
+  auto ret = g_kernel->TransposeParallel(task_id);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "TransposeRun error task_id[" << task_id << "] error_code[" << ret << "]";
+    return RET_OP_EXECUTE_FAILURE;
+  }
+  return RET_OK;
+}
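// TransposeRun is the C-style trampoline LiteBackendParallelLaunch expects:
// the launcher cannot invoke a member function directly, so the kernel passes
// itself through the void *cdata slot and the callback casts it back. The
// same pattern in miniature (Trampoline is an illustrative name):
//
//   int Trampoline(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
//     return reinterpret_cast<TransposeCPUKernel *>(cdata)->TransposeParallel(task_id);
//   }
//   // launched as: LiteBackendParallelLaunch(Trampoline, this, thread_h_num_);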
 int TransposeCPUKernel::Run() {
   auto ret = Prepare();
   if (ret != RET_OK) {
@@ -62,23 +94,24 @@ int TransposeCPUKernel::Run() {
   }
   MS_ASSERT(in_tensors_.size() == kTransposeInputNum);
   MS_ASSERT(out_tensors_.size() == kTransposeOutputNum);
| auto &inTensor = in_tensors_.front(); | |||
| auto &outTensor = out_tensors_.front(); | |||
| if (inTensor == nullptr || outTensor == nullptr) { | |||
| auto &in_tensor = in_tensors_.front(); | |||
| auto &out_tensor = out_tensors_.front(); | |||
| if (in_tensor == nullptr || out_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "null pointer dreferencing."; | |||
| return RET_ERROR; | |||
| } | |||
| auto *in_data = static_cast<float *>(inTensor->Data()); | |||
| auto *out_data = static_cast<float *>(outTensor->Data()); | |||
| auto in_shape = inTensor->shape(); | |||
| auto out_shape = outTensor->shape(); | |||
| auto *input_shape = &in_shape.front(); | |||
| auto *output_shape = &out_shape.front(); | |||
| in_data_ = reinterpret_cast<float *>(in_tensor->Data()); | |||
| out_data_ = reinterpret_cast<float *>(out_tensor->Data()); | |||
| in_shape_ = const_cast<int *>(in_tensor->shape().data()); | |||
| out_shape_ = const_cast<int *>(out_tensor->shape().data()); | |||
| ret = | |||
| DoTranspose(in_data, out_data, input_shape, output_shape, reinterpret_cast<TransposeParameter *>(op_parameter_)); | |||
| ret = LiteBackendParallelLaunch(TransposeRun, this, thread_h_num_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Tranpose error error_code[" << ret << "]"; | |||
| return ret; | |||
| } | |||
| return ret; | |||
| } | |||
| } // namespace mindspore::kernel | |||
| kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
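// One caveat on the pointer caching in Run() above (an assumption worth
// verifying against tensor::Tensor): in_data_/out_data_/in_shape_/out_shape_
// keep raw pointers into whatever Data() and shape() return. That is only
// safe if shape() exposes storage that outlives the parallel launch; if it
// returns std::vector<int> by value, .data() on the temporary dangles and the
// shapes should be copied into member arrays instead.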
@@ -29,14 +29,23 @@ class TransposeCPUKernel : public LiteKernel {
   explicit TransposeCPUKernel(OpParameter *param, const std::vector<lite::tensor::Tensor *> &inputs,
                               const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                               const lite::Primitive *primitive)
-      : LiteKernel(param, inputs, outputs, ctx, primitive) {}
+      : LiteKernel(param, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) {}
   ~TransposeCPUKernel() override = default;
 
   int Init() override;
   int ReSize() override;
   int Run() override;
+  int TransposeParallel(int task_id);
 
  private:
+  int thread_num_;
+  int thread_h_stride_;
+  int thread_h_num_;
+  int num_unit_;
+  float *in_data_;
+  float *out_data_;
+  int *in_shape_;
+  int *out_shape_;
 };
 }  // namespace mindspore::kernel
@@ -28,7 +28,7 @@ int EnumElement(int *shape, int n_dims) {
 }
 
 void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *out_strides, int *perm,
-                      int *output_shape) {
+                      int *output_shape, int h_start, int h_end) {
   const int stride0 = strides[perm[0]];
   const int stride1 = strides[perm[1]];
   const int stride2 = strides[perm[2]];
@@ -40,7 +40,6 @@ void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *
   const int out_stride3 = out_strides[3];
   const int out_stride4 = out_strides[4];
   const int output0 = output_shape[0];
-  const int output1 = output_shape[1];
   const int output2 = output_shape[2];
   const int output3 = output_shape[3];
   const int output4 = output_shape[4];
@@ -48,7 +47,7 @@ void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *
   for (int i = 0; i < output0; ++i) {
     int out_stride0_i = i * out_stride0;
     int stride0_i = i * stride0;
-    for (int j = 0; j < output1; ++j) {
+    for (int j = h_start; j < h_end; ++j) {
       int out_stride1_j = j * out_stride1;
       int stride1_j = j * stride1;
       for (int k = 0; k < output2; ++k) {
@@ -69,7 +68,8 @@ void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *
   }
 }
 
-int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_sizes) {
+int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_sizes, int h_start,
+                        int h_end) {
   int trans_in_shape[6] = {in_shape[0], in_shape[1] / block_sizes[0],
                            block_sizes[0], in_shape[2] / block_sizes[1],
                            block_sizes[1], in_shape[3]};
@@ -82,7 +82,7 @@ int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int sh
   ComputeStrides(trans_out_shape, out_strides, shape_size + 2);
   int perm[6] = {0, 2, 4, 1, 3, 5};
-  TransposeForNHWC(input, output, in_strides, out_strides, perm, trans_out_shape);
+  TransposeForNHWC(input, output, in_strides, out_strides, perm, trans_out_shape, h_start, h_end);
   return NNACL_OK;
 }
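/* Worked view of the reshape + transpose above, matching SpaceToBatchTest1
 * below: an NHWC 1x4x4x1 input with 2x2 blocks is seen as trans_in_shape
 * 1x2x2x2x2x1, i.e. (N, H/bh, bh, W/bw, bw, C); perm {0, 2, 4, 1, 3, 5}
 * hoists the block axes ahead of the spatial axes, so the flattened result
 * reads as (N*bh*bw) x (H/bh) x (W/bw) x C = 4x2x2x1 output batches.
 * h_start/h_end slice output axis 1 of this 6-D view (0..2 in that test). */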
@@ -137,21 +137,11 @@ void DoPadding(const float *input, float *padded_input, SpaceToBatchParameter pa
   }
 }
 
-int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, float *tmp_space[3]) {
-  float *padded_input = NULL;
-  int ret;
-  if (param.need_paddings_) {
-    if (tmp_space[0] == NULL || tmp_space[1] == NULL || tmp_space[2] == NULL) {
-      return NNACL_NULL_PTR;
-    }
-    padded_input = tmp_space[0];
-    DoPadding(input, padded_input, param, tmp_space + 1);
-  }
-  if (param.need_paddings_) {
-    ret = SpaceToBatchForNHWC(padded_input, output, param.padded_in_shape_, param.n_dims_, param.block_sizes_);
-  } else {
-    ret = SpaceToBatchForNHWC(input, output, param.padded_in_shape_, param.n_dims_, param.block_sizes_);
-  }
+int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, int h_start, int h_end) {
+  if (input == NULL || output == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  auto ret =
+      SpaceToBatchForNHWC(input, output, param.padded_in_shape_, param.n_dims_, param.block_sizes_, h_start, h_end);
   return ret;
 }
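/* The padding and buffer-juggling logic that used to live here moved to the
 * callers: SpaceToBatch now assumes `input` is already padded (the CPU kernel
 * runs DoPadding once before its parallel launch). The call sequence, as
 * exercised by SpaceToBatchTest2 below:
 *
 *   DoPadding(input, padded_input, param, tmp_space + 1);
 *   SpaceToBatch(padded_input, output, param, h_start, h_end);
 */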
@@ -35,10 +35,12 @@ typedef struct SpaceToBatchParameter {
 #ifdef __cplusplus
 extern "C" {
 #endif
-int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, float *tmp_space[3]);
-int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_size);
+int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, int h_start, int h_end);
+int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_size, int h_start,
+                        int h_end);
 void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *out_strides, int *perm,
-                      int *output_shape);
+                      int *output_shape, int h_start, int h_end);
 void DoPadding(const float *input, float *padded_input, SpaceToBatchParameter param, float *tmp_space[]);
 int EnumElement(int *shape, int n_dims);
 #ifdef __cplusplus
 }
@@ -18,21 +18,23 @@
 #include <string.h>
 #include "nnacl/errorcode.h"
 
-void TransposeDim2(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) {
+void TransposeDim2(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
+                   int h_start, int h_end) {
   const int stride0 = strides[perm[0]];
   const int stride1 = strides[perm[1]];
   const int output0 = output_shape[0];
   const int output1 = output_shape[1];
-  for (int i = 0; i < output0; i++) {
+  for (int i = 0; i < output0; ++i) {
     int out_stride0_i = i * output1;
     int stride0_i = i * 1 * stride0;
-    for (int j = 0; j < output1; j++) {
+    for (int j = h_start; j < h_end; ++j) {
       out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1];
     }
   }
 }
 
-void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) {
+void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
+                   int h_start, int h_end) {
   const int stride0 = strides[perm[0]];
   const int stride1 = strides[perm[1]];
   const int stride2 = strides[perm[2]];
@@ -41,20 +43,21 @@ void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strid
   const int output0 = output_shape[0];
   const int output1 = output_shape[1];
   const int output2 = output_shape[2];
-  for (int i = 0; i < output0; i++) {
+  for (int i = 0; i < output0; ++i) {
     int out_stride0_i = i * out_stride0;
     int stride0_i = i * stride0;
-    for (int j = 0; j < output1; j++) {
+    for (int j = h_start; j < h_end; ++j) {
       int out_stride1_j = j * out_stride1;
       int stride1_j = j * stride1;
-      for (int k = 0; k < output2; k++) {
+      for (int k = 0; k < output2; ++k) {
         out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2];
       }
     }
   }
 }
 
-void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) {
+void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
+                   int h_start, int h_end) {
   const int stride0 = strides[perm[0]];
   const int stride1 = strides[perm[1]];
   const int stride2 = strides[perm[2]];
@@ -67,16 +70,16 @@ void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strid
   const int output2 = output_shape[2];
   const int output3 = output_shape[3];
-  for (int i = 0; i < output0; i++) {
+  for (int i = 0; i < output0; ++i) {
     int out_stride0_i = i * out_stride0;
     int stride0_i = i * stride0;
-    for (int j = 0; j < output1; j++) {
+    for (int j = h_start; j < h_end; ++j) {
       int out_stride1_j = j * out_stride1;
       int stride1_j = j * stride1;
-      for (int k = 0; k < output2; k++) {
+      for (int k = 0; k < output2; ++k) {
         int out_stride2_k = k * out_stride2;
         int stride2_k = k * stride2;
-        for (int m = 0; m < output3; m++) {
+        for (int m = 0; m < output3; ++m) {
           out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] =
             in_data[stride0_i + stride1_j + stride2_k + m * stride3];
         }
@@ -85,7 +88,8 @@ void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strid
   }
 }
 
-void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape) {
+void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
+                   int h_start, int h_end) {
   const int stride0 = strides[perm[0]];
   const int stride1 = strides[perm[1]];
   const int stride2 = strides[perm[2]];
@@ -101,19 +105,19 @@ void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strid
   const int output3 = output_shape[3];
   const int output4 = output_shape[4];
-  for (int i = 0; i < output0; i++) {
+  for (int i = 0; i < output0; ++i) {
     int out_stride0_i = i * out_stride0;
     int stride0_i = i * stride0;
-    for (int j = 0; j < output1; j++) {
+    for (int j = h_start; j < h_end; ++j) {
       int out_stride1_j = j * out_stride1;
       int stride1_j = j * stride1;
-      for (int k = 0; k < output2; k++) {
+      for (int k = 0; k < output2; ++k) {
        int out_stride2_k = k * out_stride2;
        int stride2_k = k * stride2;
-        for (int m = 0; m < output3; m++) {
+        for (int m = 0; m < output3; ++m) {
          int out_stride3_m = m * out_stride3;
          int stride3_m = m * stride3;
-          for (int n = 0; n < output4; n++) {
+          for (int n = 0; n < output4; ++n) {
            out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] =
              in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4];
           }
@@ -124,7 +128,7 @@ void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strid
 }
 
 int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_shape,
-                TransposeParameter *transpose_param) {
+                TransposeParameter *transpose_param, int h_start, int h_end) {
   if (in_data == NULL || out_data == NULL) {
     return NNACL_ERR;
   }
@@ -140,7 +144,7 @@ int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_s
   // check if transpose is needed
   bool needTranspose = false;
-  for (int i = 1; i < num_axes; i++) {
+  for (int i = 1; i < num_axes; ++i) {
     if (perm[i] - perm[i - 1] != 1) {
       needTranspose = true;
       break;
@@ -152,13 +156,13 @@ int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_s
     return NNACL_OK;
   }
   if (num_axes == 2) {
-    TransposeDim2(in_data, out_data, strides, out_strides, perm, output_shape);
+    TransposeDim2(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end);
   } else if (num_axes == 3) {
-    TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape);
+    TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end);
   } else if (num_axes == 4) {
-    TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape);
+    TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end);
   } else if (num_axes == 5) {
-    TransposeDim5(in_data, out_data, strides, out_strides, perm, output_shape);
+    TransposeDim5(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end);
   }
   return NNACL_OK;
 }
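/* Thread-safety note: with the middle loops bounded by [h_start, h_end),
 * concurrent tasks write disjoint slices of out_data along output axis 1, so
 * the parallel launchers above need no extra synchronization. The
 * needTranspose == false early path is the one caveat: it ignores the bounds,
 * so (assumption) each task would redundantly copy the whole buffer --
 * wasteful, though the values written are identical. */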
@@ -33,11 +33,15 @@ typedef struct TransposeParameter {
 extern "C" {
 #endif
 int DoTranspose(float *in_data, float *out_data, int *input_shape, int *output_shape,
-                TransposeParameter *transpose_param);
-void TransposeDim2(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape);
-void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape);
-void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape);
-void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape);
+                TransposeParameter *transpose_param, int h_start, int h_end);
+void TransposeDim2(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
+                   int h_start, int h_end);
+void TransposeDim3(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
+                   int h_start, int h_end);
+void TransposeDim4(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
+                   int h_start, int h_end);
+void TransposeDim5(float *in_data, float *out_data, int *strides, int *out_strides, int *perm, int *output_shape,
+                   int h_start, int h_end);
 #ifdef __cplusplus
 }
 #endif
@@ -85,7 +85,7 @@ TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest1) {
   int in_shape[4] = {1, 4, 4, 1};
   int out_shape[4] = {4, 2, 2, 1};
   int block_sizes[2] = {2, 2};
-  SpaceToBatchForNHWC((const float *)input, output, in_shape, 4, block_sizes);
+  SpaceToBatchForNHWC((const float *)input, output, in_shape, 4, block_sizes, 0, 4 / 2);
   for (int i = 0; i < out_size; ++i) {
     std::cout << output[i] << " ";
   }
@@ -107,7 +107,10 @@ TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest2) {
   float padded_input[48]{}, tmp[48]{}, tmp_zero[48]{};
   float *tmp_space[3] = {padded_input, tmp, tmp_zero};
-  auto ret = SpaceToBatch((const float *)input, output, param, tmp_space);
+  // DoPadding
+  DoPadding(input, padded_input, param, tmp_space + 1);
+  auto ret = SpaceToBatch((const float *)padded_input, output, param, 0, 4 / 2);
   std::cout << "return " << ret << std::endl;
   for (int i = 0; i < out_size; ++i) {
     std::cout << output[i] << " ";
@@ -145,6 +148,7 @@ TEST_F(SpaceToBatchTestFp32, SpaceToBatchTest3) {
   outputs_tensor.emplace_back(&output_tensor);
 
   lite::Context ctx;
+  ctx.thread_num_ = 2;
   kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_SpaceToBatch};
   auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
   ASSERT_NE(creator, nullptr);
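// The new trailing arguments in the calls above are the H-slice bounds:
// 0 and 4 / 2 = 2 span the full unit range for a height-4 input with block
// size 2 (a single-slice call), while ctx.thread_num_ = 2 exercises the
// two-task split through the registered kernel.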
@@ -0,0 +1,220 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iostream>
+#include <memory>
+#include "utils/log_adapter.h"
+#include "common/common_test.h"
+#include "mindspore/lite/src/runtime/kernel/arm/fp32/transpose.h"
+#include "mindspore/lite/src/runtime/kernel/arm/nnacl/transpose.h"
+#include "mindspore/lite/src/kernel_registry.h"
+#include "mindspore/lite/src/lite_kernel.h"
+
+namespace mindspore {
+
+class TestTransposeFp32 : public mindspore::CommonTest {
+ public:
+  TestTransposeFp32() {}
+};
+
+TEST_F(TestTransposeFp32, TransposeFp32_axes4) {
+  /* 1x2x3x4 */
+  float in[24] = {-0.35779851, -0.4857257, 1.2791597, -0.36793608, 0.95098744, -0.12716428, 0.17405411, 0.42663834,
+                  -1.11871315, 1.02777593, 1.20223761, 0.30183748, 1.39663453, -1.11923312, -1.02032341, 1.91074871,
+                  1.52489095, -1.13020852, -0.66358529, 1.8033383, 0.62647028, 1.03094635, -1.65733338, 0.3952082};
+  float out[24] = {0};
+  float correct[24] = {-0.35779851, 1.39663453, 0.95098744, 1.52489095, -1.11871315, 0.62647028,
+                       -0.4857257, -1.11923312, -0.12716428, -1.13020852, 1.02777593, 1.03094635,
+                       1.2791597, -1.02032341, 0.17405411, -0.66358529, 1.20223761, -1.65733338,
+                       -0.36793608, 1.91074871, 0.42663834, 1.8033383, 0.30183748, 0.3952082};
+  int input_shape[4] = {1, 2, 3, 4};
+  int output_shape[4] = {4, 3, 2, 1};
+  int perm[8] = {3, 2, 1, 0, 0, 0, 0, 0};
+  int strides[8] = {24, 12, 4, 1, 1, 1, 1, 1};
+  int out_strides[8] = {6, 2, 1, 1, 1, 1, 1, 1};
+
+  auto param = new (std::nothrow) TransposeParameter();
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "New param fails.";
+    return;
+  }
+  param->num_axes_ = 4;
+  param->conjugate_ = false;
+  param->data_size_ = 24 * sizeof(float);
+  for (int i = 0; i < 8; i++) {
+    param->perm_[i] = perm[i];
+    param->strides_[i] = strides[i];
+    param->out_strides_[i] = out_strides[i];
+  }
+  auto ret = DoTranspose(in, out, input_shape, output_shape, param, 0, 3);
+  MS_ASSERT(ret == 0);
+  delete param;
+  CompareOutputData(out, correct, 24, 0.000001);
+}
+
+TEST_F(TestTransposeFp32, TransposeFp32_axes3) {
+  /* 2x3x4 */
+  float in[24] = {1.62434536, -0.61175641, -0.52817175, -1.07296862, 0.86540763, -2.3015387, 1.74481176, -0.7612069,
+                  0.3190391, -0.24937038, 1.46210794, -2.06014071, -0.3224172, -0.38405435, 1.13376944, -1.09989127,
+                  -0.17242821, -0.87785842, 0.04221375, 0.58281521, -1.10061918, 1.14472371, 0.90159072, 0.50249434};
+  float out[24] = {0};
+  float correct[24] = {1.62434536, -0.3224172, 0.86540763, -0.17242821, 0.3190391, -1.10061918,
+                       -0.61175641, -0.38405435, -2.3015387, -0.87785842, -0.24937038, 1.14472371,
+                       -0.52817175, 1.13376944, 1.74481176, 0.04221375, 1.46210794, 0.90159072,
+                       -1.07296862, -1.09989127, -0.7612069, 0.58281521, -2.06014071, 0.50249434};
+  int input_shape[3] = {2, 3, 4};
+  int output_shape[3] = {4, 3, 2};
+  int perm[8] = {2, 1, 0, 0, 0, 0, 0, 0};
+  int strides[8] = {12, 4, 1, 1, 1, 1, 1, 1};
+  int out_strides[8] = {6, 2, 1, 1, 1, 1, 1, 1};
+
+  auto param = new (std::nothrow) TransposeParameter();
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "New param fails.";
+    return;
+  }
+  param->num_axes_ = 3;
+  param->conjugate_ = false;
+  param->data_size_ = 24 * sizeof(float);
+  for (int i = 0; i < 8; i++) {
+    param->perm_[i] = perm[i];
+    param->strides_[i] = strides[i];
+    param->out_strides_[i] = out_strides[i];
+  }
+  auto ret = DoTranspose(in, out, input_shape, output_shape, param, 0, 3);
+  MS_ASSERT(ret == 0);
+  delete param;
+  CompareOutputData(out, correct, 24, 0.000001);
+}
+
+TEST_F(TestTransposeFp32, TransposeFp32_axes2) {
+  /* 6x4 */
+  float in[24] = {1.62434536, -0.61175641, -0.52817175, -1.07296862, 0.86540763, -2.3015387, 1.74481176, -0.7612069,
+                  0.3190391, -0.24937038, 1.46210794, -2.06014071, -0.3224172, -0.38405435, 1.13376944, -1.09989127,
+                  -0.17242821, -0.87785842, 0.04221375, 0.58281521, -1.10061918, 1.14472371, 0.90159072, 0.50249434};
+  float out[24] = {0};
+  float correct[24] = {1.62434536, 0.86540763, 0.3190391, -0.3224172, -0.17242821, -1.10061918,
+                       -0.61175641, -2.3015387, -0.24937038, -0.38405435, -0.87785842, 1.14472371,
+                       -0.52817175, 1.74481176, 1.46210794, 1.13376944, 0.04221375, 0.90159072,
+                       -1.07296862, -0.7612069, -2.06014071, -1.09989127, 0.58281521, 0.50249434};
+  int input_shape[2] = {6, 4};
+  int output_shape[2] = {4, 6};
+  int perm[8] = {1, 0, 0, 0, 0, 0, 0, 0};
+  int strides[8] = {4, 1, 1, 1, 1, 1, 1, 1};
+  int out_strides[8] = {6, 1, 1, 1, 1, 1, 1, 1};
+
+  auto param = new (std::nothrow) TransposeParameter();
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "New param fails.";
+    return;
+  }
+  param->num_axes_ = 2;
+  param->conjugate_ = false;
+  param->data_size_ = 24 * sizeof(float);
+  for (int i = 0; i < 8; i++) {
+    param->perm_[i] = perm[i];
+    param->strides_[i] = strides[i];
+    param->out_strides_[i] = out_strides[i];
+  }
+  auto ret = DoTranspose(in, out, input_shape, output_shape, param, 0, 6);
+  MS_ASSERT(ret == 0);
+  delete param;
+  CompareOutputData(out, correct, 24, 0.000001);
+}
+
+TEST_F(TestTransposeFp32, TransposeFp32_test5) {
+  /* 1x2x3x2x2 */
+  std::vector<float> input = {1.62434536, -0.61175641, -0.52817175, -1.07296862, 0.86540763, -2.3015387,
+                              1.74481176, -0.7612069, 0.3190391, -0.24937038, 1.46210794, -2.06014071,
+                              -0.3224172, -0.38405435, 1.13376944, -1.09989127, -0.17242821, -0.87785842,
+                              0.04221375, 0.58281521, -1.10061918, 1.14472371, 0.90159072, 0.50249434};
+  float correct[24] = {1.62434536, -0.3224172, 0.86540763, -0.17242821, 0.3190391, -1.10061918,
+                       -0.52817175, 1.13376944, 1.74481176, 0.04221375, 1.46210794, 0.90159072,
+                       -0.61175641, -0.38405435, -2.3015387, -0.87785842, -0.24937038, 1.14472371,
+                       -1.07296862, -1.09989127, -0.7612069, 0.58281521, -2.06014071, 0.50249434};
+  std::vector<float> output(24);
+  std::vector<int> input_shape = {1, 2, 3, 2, 2};
+  std::vector<int> output_shape = {2, 2, 3, 2, 1};
+  int perm[8] = {4, 3, 2, 1, 0, 0, 0, 0};
+  int strides[8] = {24, 12, 4, 2, 1, 1, 1, 1};
+  int out_strides[8] = {12, 6, 2, 1, 1, 1, 1, 1};
+
+  TransposeParameter param;
+  param.op_parameter_.type_ = schema::PrimitiveType_Transpose;
+  param.num_axes_ = 5;
+  param.conjugate_ = false;
+  param.data_size_ = 24 * sizeof(float);
+  for (int i = 0; i < 8; i++) {
+    param.perm_[i] = perm[i];
+    param.strides_[i] = strides[i];
+    param.out_strides_[i] = out_strides[i];
+  }
+
+  lite::tensor::Tensor input_tensor;
+  input_tensor.SetData(input.data());
+  input_tensor.set_shape(input_shape);
+  input_tensor.SetFormat(schema::Format_NHWC);
+  input_tensor.set_data_type(kNumberTypeFloat32);
+  std::vector<lite::tensor::Tensor *> inputs_tensor;
+  inputs_tensor.emplace_back(&input_tensor);
+
+  lite::tensor::Tensor output_tensor;
+  output_tensor.SetData(output.data());
+  output_tensor.set_shape(output_shape);
+  output_tensor.SetFormat(schema::Format_NHWC);
+  output_tensor.set_data_type(kNumberTypeFloat32);
+  std::vector<lite::tensor::Tensor *> outputs_tensor;
+  outputs_tensor.emplace_back(&output_tensor);
+
+  lite::Context ctx;
+  ctx.thread_num_ = 2;
+  kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_Transpose};
+  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
+  ASSERT_NE(creator, nullptr);
+  kernel::LiteKernel *kernel =
+      creator(inputs_tensor, outputs_tensor, reinterpret_cast<OpParameter *>(&param), &ctx, desc, nullptr);
+  ASSERT_NE(kernel, nullptr);
+  kernel->Run();
+
+  for (int i = 0; i < 24; ++i) {
+    std::cout << output[i] << " ";
+  }
+  std::cout << "\n";
+  CompareOutputData(output.data(), correct, 24, 0.000001);
+  input_tensor.SetData(nullptr);
+  output_tensor.SetData(nullptr);
+}
+}  // namespace mindspore