Merge pull request !4539 from zhaozhenlong/lite/ops/fp16/transposetags/v0.7.0-beta
| @@ -0,0 +1,156 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp16/transpose_fp16.h" | |||
| #include <vector> | |||
| #include "src/runtime/kernel/arm/nnacl/fp16/transpose_fp16.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "nnacl/fp16/cast_fp16.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::lite::RET_OP_EXECUTE_FAILURE; | |||
| using mindspore::schema::PrimitiveType_Transpose; | |||
| namespace mindspore::kernel { | |||
| namespace { | |||
| constexpr int kTransposeInputNum = 1; | |||
| constexpr int kTransposeOutputNum = 1; | |||
| } // namespace | |||
| int TransposeFp16CPUKernel::Init() { | |||
| TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_); | |||
| num_unit_ = static_cast<int>(in_tensors_[kInputIndex]->shape().at(param->perm_[kNHWC_H])); | |||
| thread_h_num_ = MSMIN(thread_num_, num_unit_); | |||
| thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_); | |||
| if (!InferShapeDone()) { | |||
| return RET_OK; | |||
| } | |||
| return ReSize(); | |||
| } | |||
| int TransposeFp16CPUKernel::ReSize() { | |||
| auto &inTensor = in_tensors_.front(); | |||
| auto &outTensor = out_tensors_.front(); | |||
| auto param = reinterpret_cast<TransposeParameter *>(op_parameter_); | |||
| auto in_shape = inTensor->shape(); | |||
| auto out_shape = outTensor->shape(); | |||
| param->strides_[param->num_axes_ - 1] = 1; | |||
| param->out_strides_[param->num_axes_ - 1] = 1; | |||
| param->data_size_ = inTensor->Size(); | |||
| for (int i = param->num_axes_ - 2; i >= 0; i--) { | |||
| param->strides_[i] = in_shape[i + 1] * param->strides_[i + 1]; | |||
| param->out_strides_[i] = out_shape[i + 1] * param->out_strides_[i + 1]; | |||
| } | |||
| if (fp16_in_data_ != nullptr) { | |||
| free(fp16_in_data_); | |||
| fp16_in_data_ = nullptr; | |||
| } | |||
| fp16_in_data_ = reinterpret_cast<float16_t *>(malloc(sizeof(float16_t) * inTensor->ElementsNum())); | |||
| if (fp16_out_data_ != nullptr) { | |||
| free(fp16_out_data_); | |||
| fp16_out_data_ = nullptr; | |||
| } | |||
| fp16_out_data_ = reinterpret_cast<float16_t *>(malloc(sizeof(float16_t) * outTensor->ElementsNum())); | |||
| return RET_OK; | |||
| } | |||
| int TransposeFp16CPUKernel::TransposeParallel(int task_id) { | |||
| int num_unit_thread = MSMIN(thread_h_stride_, num_unit_ - task_id * thread_h_stride_); | |||
| if (num_unit_thread <= 0) { | |||
| return RET_OK; | |||
| } | |||
| int thread_offset = task_id * thread_h_stride_; | |||
| TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_); | |||
| auto ret = DoTranspose(fp16_in_data_, fp16_out_data_, in_shape_, out_shape_, param, thread_offset, | |||
| thread_offset + num_unit_thread); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Transpose error task_id[" << task_id << "] error_code[" << ret << "]"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int TransposeRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||
| auto g_kernel = reinterpret_cast<TransposeFp16CPUKernel *>(cdata); | |||
| auto ret = g_kernel->TransposeParallel(task_id); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "TransposeRun error task_id[" << task_id << "] error_code[" << ret << "]"; | |||
| return RET_OP_EXECUTE_FAILURE; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int TransposeFp16CPUKernel::Run() { | |||
| auto ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Prepare fail!ret: " << ret; | |||
| return ret; | |||
| } | |||
| MS_ASSERT(in_tensors_.size() == TransposeInputNum); | |||
| MS_ASSERT(out_tensors_.size() == TransposeOutputNum); | |||
| auto &in_tensor = in_tensors_.front(); | |||
| auto &out_tensor = out_tensors_.front(); | |||
| if (in_tensor == nullptr || out_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "null pointer dreferencing."; | |||
| return RET_ERROR; | |||
| } | |||
| in_data_ = reinterpret_cast<float *>(in_tensor->Data()); | |||
| out_data_ = reinterpret_cast<float *>(out_tensor->Data()); | |||
| Float32ToFloat16(in_data_, fp16_in_data_, in_tensor->ElementsNum()); | |||
| in_shape_ = const_cast<int *>(in_tensor->shape().data()); | |||
| out_shape_ = const_cast<int *>(out_tensor->shape().data()); | |||
| ret = LiteBackendParallelLaunch(TransposeRun, this, thread_h_num_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Tranpose error error_code[" << ret << "]"; | |||
| return ret; | |||
| } | |||
| Float16ToFloat32(fp16_out_data_, out_data_, out_tensor->ElementsNum()); | |||
| return ret; | |||
| } // namespace mindspore::kernel | |||
| kernel::LiteKernel *CpuTransposeFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::Context *ctx, | |||
| const kernel::KernelKey &desc, const lite::Primitive *primitive) { | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_Transpose); | |||
| if (opParameter == nullptr) { | |||
| MS_LOG(ERROR) << "desc type is not Transpose"; | |||
| return nullptr; | |||
| } | |||
| auto *kernel = new (std::nothrow) TransposeFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "New kernel fails."; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Transpose, CpuTransposeFp16KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_CPU_ARM_FP16_TRANSPOSE_FP16_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_CPU_ARM_FP16_TRANSPOSE_FP16_H_ | |||
| #include <arm_neon.h> | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "src/kernel_factory.h" | |||
| namespace mindspore::kernel { | |||
| class TransposeFp16CPUKernel : public LiteKernel { | |||
| public: | |||
| explicit TransposeFp16CPUKernel(OpParameter *param, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx, | |||
| const lite::Primitive *primitive) | |||
| : LiteKernel(param, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) {} | |||
| ~TransposeFp16CPUKernel() { | |||
| if (fp16_in_data_ != nullptr) { | |||
| free(fp16_in_data_); | |||
| fp16_in_data_ = nullptr; | |||
| } | |||
| if (fp16_out_data_ != nullptr) { | |||
| free(fp16_out_data_); | |||
| fp16_out_data_ = nullptr; | |||
| } | |||
| } | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int TransposeParallel(int task_id); | |||
| private: | |||
| int thread_num_; | |||
| int thread_h_stride_; | |||
| int thread_h_num_; | |||
| int num_unit_; | |||
| float *in_data_; | |||
| float *out_data_; | |||
| float16_t *fp16_in_data_ = nullptr; | |||
| float16_t *fp16_out_data_ = nullptr; | |||
| int *in_shape_; | |||
| int *out_shape_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_CCSRC_KERNEL_CPU_ARM_FP16_TRANSPOSE_FP16_H_ | |||
| @@ -0,0 +1,168 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/fp16/transpose_fp16.h" | |||
| #include <string.h> | |||
| #include "nnacl/errorcode.h" | |||
| void TransposeDim2(float16_t *in_data, float16_t *out_data, int *strides, int *out_strides, int *perm, | |||
| int *output_shape, int h_start, int h_end) { | |||
| const int stride0 = strides[perm[0]]; | |||
| const int stride1 = strides[perm[1]]; | |||
| const int output0 = output_shape[0]; | |||
| const int output1 = output_shape[1]; | |||
| for (int i = 0; i < output0; ++i) { | |||
| int out_stride0_i = i * output1; | |||
| int stride0_i = i * 1 * stride0; | |||
| for (int j = 0; j < output1; ++j) { | |||
| out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; | |||
| } | |||
| } | |||
| } | |||
| void TransposeDim3(float16_t *in_data, float16_t *out_data, int *strides, int *out_strides, int *perm, | |||
| int *output_shape, int h_start, int h_end) { | |||
| const int stride0 = strides[perm[0]]; | |||
| const int stride1 = strides[perm[1]]; | |||
| const int stride2 = strides[perm[2]]; | |||
| const int out_stride0 = out_strides[0]; | |||
| const int out_stride1 = out_strides[1]; | |||
| const int output0 = output_shape[0]; | |||
| const int output1 = output_shape[1]; | |||
| const int output2 = output_shape[2]; | |||
| for (int i = 0; i < output0; ++i) { | |||
| int out_stride0_i = i * out_stride0; | |||
| int stride0_i = i * stride0; | |||
| for (int j = 0; j < output1; ++j) { | |||
| int out_stride1_j = j * out_stride1; | |||
| int stride1_j = j * stride1; | |||
| for (int k = 0; k < output2; ++k) { | |||
| out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void TransposeDim4(float16_t *in_data, float16_t *out_data, int *strides, int *out_strides, int *perm, | |||
| int *output_shape, int h_start, int h_end) { | |||
| const int stride0 = strides[perm[0]]; | |||
| const int stride1 = strides[perm[1]]; | |||
| const int stride2 = strides[perm[2]]; | |||
| const int stride3 = strides[perm[3]]; | |||
| const int out_stride0 = out_strides[0]; | |||
| const int out_stride1 = out_strides[1]; | |||
| const int out_stride2 = out_strides[2]; | |||
| const int output0 = output_shape[0]; | |||
| const int output1 = output_shape[1]; | |||
| const int output2 = output_shape[2]; | |||
| const int output3 = output_shape[3]; | |||
| for (int i = 0; i < output0; ++i) { | |||
| int out_stride0_i = i * out_stride0; | |||
| int stride0_i = i * stride0; | |||
| for (int j = 0; j < output1; ++j) { | |||
| int out_stride1_j = j * out_stride1; | |||
| int stride1_j = j * stride1; | |||
| for (int k = 0; k < output2; ++k) { | |||
| int out_stride2_k = k * out_stride2; | |||
| int stride2_k = k * stride2; | |||
| for (int m = 0; m < output3; ++m) { | |||
| out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = | |||
| in_data[stride0_i + stride1_j + stride2_k + m * stride3]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void TransposeDim5(float16_t *in_data, float16_t *out_data, int *strides, int *out_strides, int *perm, | |||
| int *output_shape, int h_start, int h_end) { | |||
| const int stride0 = strides[perm[0]]; | |||
| const int stride1 = strides[perm[1]]; | |||
| const int stride2 = strides[perm[2]]; | |||
| const int stride3 = strides[perm[3]]; | |||
| const int stride4 = strides[perm[4]]; | |||
| const int out_stride0 = out_strides[0]; | |||
| const int out_stride1 = out_strides[1]; | |||
| const int out_stride2 = out_strides[2]; | |||
| const int out_stride3 = out_strides[3]; | |||
| const int output0 = output_shape[0]; | |||
| const int output1 = output_shape[1]; | |||
| const int output2 = output_shape[2]; | |||
| const int output3 = output_shape[3]; | |||
| const int output4 = output_shape[4]; | |||
| for (int i = 0; i < output0; ++i) { | |||
| int out_stride0_i = i * out_stride0; | |||
| int stride0_i = i * stride0; | |||
| for (int j = 0; j < output1; ++j) { | |||
| int out_stride1_j = j * out_stride1; | |||
| int stride1_j = j * stride1; | |||
| for (int k = 0; k < output2; ++k) { | |||
| int out_stride2_k = k * out_stride2; | |||
| int stride2_k = k * stride2; | |||
| for (int m = 0; m < output3; ++m) { | |||
| int out_stride3_m = m * out_stride3; | |||
| int stride3_m = m * stride3; | |||
| for (int n = 0; n < output4; ++n) { | |||
| out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] = | |||
| in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| int DoTranspose(float16_t *in_data, float16_t *out_data, int *input_shape, int *output_shape, | |||
| TransposeParameter *transpose_param, int h_start, int h_end) { | |||
| if (in_data == NULL || out_data == NULL) { | |||
| return NNACL_ERR; | |||
| } | |||
| int *perm = transpose_param->perm_; | |||
| int *strides = transpose_param->strides_; | |||
| int *out_strides = transpose_param->out_strides_; | |||
| int data_size = transpose_param->data_size_; | |||
| int num_axes = transpose_param->num_axes_; | |||
| if (num_axes < 2 || num_axes > 5) { | |||
| return NNACL_ERR; | |||
| } | |||
| // check if transpose is needed | |||
| bool needTranspose = false; | |||
| for (int i = 1; i < num_axes; ++i) { | |||
| if (perm[i] - perm[i - 1] != 1) { | |||
| needTranspose = true; | |||
| break; | |||
| } | |||
| } | |||
| if (!needTranspose) { | |||
| (void)memcpy(out_data, in_data, data_size); | |||
| return NNACL_OK; | |||
| } | |||
| if (num_axes == 2) { | |||
| TransposeDim2(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); | |||
| } else if (num_axes == 3) { | |||
| TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); | |||
| } else if (num_axes == 4) { | |||
| TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); | |||
| } else if (num_axes == 5) { | |||
| TransposeDim5(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_TRANSPOSE_FP16_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_TRANSPOSE_FP16_H_ | |||
| #include "nnacl/op_base.h" | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| typedef struct TransposeParameter { | |||
| OpParameter op_parameter_; | |||
| int perm_[8]; | |||
| bool conjugate_; | |||
| int num_axes_; | |||
| int strides_[8]; | |||
| int out_strides_[8]; | |||
| int data_size_; | |||
| } TransposeParameter; | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| int DoTranspose(float16_t *in_data, float16_t *out_data, int *input_shape, int *output_shape, | |||
| TransposeParameter *transpose_param, int h_start, int h_end); | |||
| void TransposeDim2(float16_t *in_data, float16_t *out_data, int *strides, int *out_strides, int *perm, | |||
| int *output_shape, int h_start, int h_end); | |||
| void TransposeDim3(float16_t *in_data, float16_t *out_data, int *strides, int *out_strides, int *perm, | |||
| int *output_shape, int h_start, int h_end); | |||
| void TransposeDim4(float16_t *in_data, float16_t *out_data, int *strides, int *out_strides, int *perm, | |||
| int *output_shape, int h_start, int h_end); | |||
| void TransposeDim5(float16_t *in_data, float16_t *out_data, int *strides, int *out_strides, int *perm, | |||
| int *output_shape, int h_start, int h_end); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_TRANSPOSE_FP16_H_ | |||