From 6a6d538e7213206d99d18c122d63aa5e92f46632 Mon Sep 17 00:00:00 2001
From: chenzupeng
Date: Fri, 7 Aug 2020 19:28:42 +0800
Subject: [PATCH] add opencl transpose(NHWC2HCHW)

---
 .../src/runtime/kernel/opencl/CMakeLists.txt  |   1 +
 .../kernel/opencl/cl/fp16/transpose.cl        |  51 +++
 .../kernel/opencl/cl/fp32/transpose.cl        |  50 +++
 .../runtime/kernel/opencl/kernel/transpose.cc | 107 ++++++++++++++++++
 .../runtime/kernel/opencl/kernel/transpose.h  |  44 +++
 mindspore/lite/test/CMakeLists.txt            |   2 +
 mindspore/lite/test/st/benchmark_test.cc      |   9 +-
 .../runtime/kernel/opencl/transpose_tests.cc  |  76 +++++++++++++
 8 files changed, 336 insertions(+), 4 deletions(-)
 create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp16/transpose.cl
 create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp32/transpose.cl
 create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
 create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h
 create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc

diff --git a/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt
index 6fae2e3a76..b090065ca1 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt
+++ b/mindspore/lite/src/runtime/kernel/opencl/CMakeLists.txt
@@ -9,4 +9,5 @@ set(OPENCL_KERNEL_SRC
         ${CMAKE_CURRENT_SOURCE_DIR}/kernel/softmax.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/kernel/concat.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/kernel/conv2d_transpose.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/kernel/transpose.cc
         )
diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/transpose.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/transpose.cl
new file mode 100644
index 0000000000..ebc3db633f
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp16/transpose.cl
@@ -0,0 +1,51 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#define FLT half
+#define FLT4 half4
+#define READ_IMAGE read_imageh
+__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+// One work-item (X, Y) transposes a 4x4 tile: it reads the four texels at image
+// column Y, rows 4 * X .. 4 * X + 3, regroups their components so that result[k]
+// holds component k of all four texels, and stores result[k] (widened to float4) at
+// dst_data[(4 * Y + k) * HW.y + X]. HW.y bounds X and is the dst row stride in
+// float4 units; C.y bounds Y; C.x is the number of valid dst rows (channels).
+// NOTE(review): assumes src_data is addressed as (x = Y, y = H * W index) - confirm.
+__kernel void transpose(__read_only image2d_t src_data, __global float4 *dst_data, int2 HW, int2 C) {
+  int X = get_global_id(0);
+  int Y = get_global_id(1);
+  if (X >= HW.y || Y >= C.y) {  // global size is rounded up on the host, drop the padding items
+    return;
+  }
+  FLT4 result[4];
+  result[0] = (FLT4)(0.0f);
+  result[1] = (FLT4)(0.0f);
+  result[2] = (FLT4)(0.0f);
+  result[3] = (FLT4)(0.0f);
+  FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X));
+  FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 1));
+  FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 2));
+  FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 3));
+  result[0].x = x0.x;
+  result[0].y = x1.x;
+  result[0].z = x2.x;
+  result[0].w = x3.x;
+
+  result[1].x = x0.y;
+  result[1].y = x1.y;
+  result[1].z = x2.y;
+  result[1].w = x3.y;
+
+  result[2].x = x0.z;
+  result[2].y = x1.z;
+  result[2].z = x2.z;
+  result[2].w = x3.z;
+
+  result[3].x = x0.w;
+  result[3].y = x1.w;
+  result[3].z = x2.w;
+  result[3].w = x3.w;
+
+  dst_data[4 * Y * HW.y + X] = convert_float4(result[0]);  // 4 * Y < C.x always holds when Y < C.y
+  if (4 * Y + 1 < C.x) dst_data[(4 * Y + 1) * HW.y + X] = convert_float4(result[1]);
+  if (4 * Y + 2 < C.x) dst_data[(4 * Y + 2) * HW.y + X] = convert_float4(result[2]);
+  if (4 * Y + 3 < C.x) dst_data[(4 * Y + 3) * HW.y + X] = convert_float4(result[3]);
+}
diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/transpose.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/transpose.cl
new file mode 100644
index 0000000000..08069ee80f
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/transpose.cl
@@ -0,0 +1,50 @@
+#define FLT float
+#define FLT4 float4
+#define READ_IMAGE read_imagef
+__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+// One work-item (X, Y) transposes a 4x4 tile: it reads the four texels at image
+// column Y, rows 4 * X .. 4 * X + 3, regroups their components so that result[k]
+// holds component k of all four texels, and stores result[k] as one float4 at
+// dst_data[(4 * Y + k) * HW.y + X]. HW.y bounds X and is the dst row stride in
+// float4 units; C.y bounds Y; C.x is the number of valid dst rows (channels).
+// NOTE(review): assumes src_data is addressed as (x = Y, y = H * W index) - confirm.
+__kernel void transpose(__read_only image2d_t src_data, __global float4 *dst_data, int2 HW, int2 C) {
+  int X = get_global_id(0);
+  int Y = get_global_id(1);
+  if (X >= HW.y || Y >= C.y) {  // global size is rounded up on the host, drop the padding items
+    return;
+  }
+  FLT4 result[4];
+  result[0] = (FLT4)(0.0f);
+  result[1] = (FLT4)(0.0f);
+  result[2] = (FLT4)(0.0f);
+  result[3] = (FLT4)(0.0f);
+  FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X));
+  FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 1));
+  FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 2));
+  FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 3));
+  result[0].x = x0.x;
+  result[0].y = x1.x;
+  result[0].z = x2.x;
+  result[0].w = x3.x;
+
+  result[1].x = x0.y;
+  result[1].y = x1.y;
+  result[1].z = x2.y;
+  result[1].w = x3.y;
+
+  result[2].x = x0.z;
+  result[2].y = x1.z;
+  result[2].z = x2.z;
+  result[2].w = x3.z;
+
+  result[3].x = x0.w;
+  result[3].y = x1.w;
+  result[3].z = x2.w;
+  result[3].w = x3.w;
+
+  dst_data[4 * Y * HW.y + X] = result[0];  // 4 * Y < C.x always holds when Y < C.y
+  if (4 * Y + 1 < C.x) dst_data[(4 * Y + 1) * HW.y + X] = result[1];
+  if (4 * Y + 2 < C.x) dst_data[(4 * Y + 2) * HW.y + X] = result[2];
+  if (4 * Y + 3 < C.x) dst_data[(4 * Y + 3) * HW.y + X] = result[3];
+}
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
new file mode 100644
index 0000000000..428c071800
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
@@ -0,0 +1,107 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <set>
+#include <string>
+#include "include/errorcode.h"
+#include "src/kernel_registry.h"
+#include "src/runtime/opencl/opencl_runtime.h"
+#include "src/runtime/kernel/opencl/kernel/transpose.h"
+#ifndef PROGRAM_WITH_IL
+#include "src/runtime/kernel/opencl/cl/fp16/transpose.cl.inc"
+#include "src/runtime/kernel/opencl/cl/fp32/transpose.cl.inc"
+#endif
+
+using mindspore::kernel::KERNEL_ARCH::kGPU;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_Transpose;
+
+namespace mindspore::kernel {
+
+int TransposeOpenCLKernel::Init() {  // builds the "transpose" kernel, then validates the input layout
+  std::string kernel_name = "transpose";
+  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+
+#ifdef PROGRAM_WITH_IL
+  ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name);
+#else
+  std::set<std::string> build_options;
+#ifdef ENABLE_FP16
+  std::string source = transpose_source_fp16;
+#else
+  std::string source = transpose_source_fp32;
+#endif
+  std::string program_name = "transpose";
+  ocl_runtime->LoadSource(program_name, source);
+  ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
+#endif
+  auto input_format = inputs_[0]->GetFormat();
+  if (input_format != schema::Format_NHWC4) {  // only NHWC4 input is accepted
+    MS_LOG(ERROR) << "input format(" << input_format << ") "
+                  << "format not support!";
+    return RET_ERROR;
+  }
+  if ((inputs_[0]->Height() * inputs_[0]->Width()) % 4 != 0) {  // Run() dispatches H * W in groups of four
+    MS_LOG(ERROR) << "input H * W % 4 != 0 not support!";
+    return RET_ERROR;
+  }
+  outputs_[0]->SetFormat(schema::Format_NCHW);
+  MS_LOG(DEBUG) << kernel_name << " Init Done!";
+  return RET_OK;
+}
+
+int TransposeOpenCLKernel::ReSize() { return RET_OK; }  // nothing to recompute on resize
+
+int TransposeOpenCLKernel::Run() {  // binds the four kernel arguments and enqueues one 2-D dispatch
+  MS_LOG(DEBUG) << this->Name() << " Running!";
+  std::vector<int> shapex = inputs_[0]->shape();  // indexed below as {N, H, W, C}
+  int h = shapex[1];
+  int w = shapex[2];
+  int c = shapex[3];
+  int c4 = UP_DIV(c, 4);       // second dispatch dimension, passed to the kernel as C.y
+  int hw4 = UP_DIV(h * w, 4);  // first dispatch dimension, passed to the kernel as HW.y
+  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+  // local size should be less than MAX_GROUP_SIZE
+  std::vector<size_t> local = {16, 16};
+  std::vector<size_t> global = {UP_ROUND(hw4, local[0]), UP_ROUND(c4, local[1])};
+
+  cl_int2 HW = {h * w, hw4};
+  cl_int2 C = {c, c4};
+  ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data());
+  ocl_runtime->SetKernelArg(kernel_, 1, outputs_[0]->Data());
+  ocl_runtime->SetKernelArg(kernel_, 2, HW);
+  ocl_runtime->SetKernelArg(kernel_, 3, C);
+  ocl_runtime->RunKernel(kernel_, global, local, nullptr);
+  return RET_OK;
+}
+
+kernel::LiteKernel *OpenCLTransposeKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
+                                                 const std::vector<lite::tensor::Tensor *> &outputs,
+                                                 OpParameter *opParameter, const lite::Context *ctx,
+                                                 const kernel::KernelKey &desc) {
+  auto *kernel = new TransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
+  auto ret = kernel->Init();
+  if (ret != RET_OK) {  // Init rejected the input: release the kernel and report failure
+    delete kernel;
+    return nullptr;
+  }
+  return kernel;
+}
+
+REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Transpose, OpenCLTransposeKernelCreator)
+}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h
new file mode 100644
index 0000000000..c16557b5b0
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_TRANSPOSE_H_
+#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_TRANSPOSE_H_
+
+#include <vector>
+
+#include "src/lite_kernel.h"
+#include "src/runtime/opencl/opencl_runtime.h"
+#include "src/runtime/kernel/opencl/opencl_kernel.h"
+
+
+namespace mindspore::kernel {
+class TransposeOpenCLKernel : public OpenCLKernel {  // OpenCL transpose kernel; methods are defined in transpose.cc
+ public:
+  explicit TransposeOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
+                                 const std::vector<lite::tensor::Tensor *> &outputs)
+      : OpenCLKernel(parameter, inputs, outputs) {}
+  ~TransposeOpenCLKernel() override = default;
+
+  int Init() override;
+  int ReSize() override;
+  int Run() override;
+
+ private:
+  cl::Kernel kernel_;  // compiled "transpose" kernel object
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_BACKEND_OPENCL_TRANSPOSE_H_
diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index 1ab193459a..ea1591ad4b 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -129,6 +129,7 @@ if (SUPPORT_GPU)
             ${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc
             ${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc
             ${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
+            ${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc
             )
 endif()
 ### minddata lite
@@ -295,6 +296,7 @@ if (SUPPORT_GPU)
         ${TEST_DIR}/ut/src/runtime/kernel/opencl/max_pooling_tests.cc
         ${TEST_DIR}/ut/src/runtime/kernel/opencl/utils_tests.cc
         ${TEST_DIR}/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc
+        ${TEST_DIR}/ut/src/runtime/kernel/opencl/transpose_tests.cc
         )
 endif()
 
diff --git a/mindspore/lite/test/st/benchmark_test.cc b/mindspore/lite/test/st/benchmark_test.cc
index 9bd8af960e..86b468677b 100644
--- a/mindspore/lite/test/st/benchmark_test.cc
+++ b/mindspore/lite/test/st/benchmark_test.cc
@@ -42,10 +42,11 @@ TEST_F(BenchmarkTest, TestOCR_02) {
 }
 
 TEST_F(BenchmarkTest, TestOCR_02_GPU) {
-const char *argv[] = {"./benchmark", "--modelPath=./hiai/hiai_cv_focusShootOCRMOdel_02.ms"
-                      "--inDataPath=./hiai/hiai_cv_focusShootOCRMOdel_02.bin"
-                      "--calibDataPath=./hiai/hiai_cv_focusShootOCRMOdel_02.txt"};
-auto status = RunBenchmark(2, argv);
+const char *argv[] = {"./benchmark", "--modelPath=./hiai/model_02.ms",
+                      "--inDataPath=./hiai/model_02_in.bin",
+                      "--calibDataPath=./hiai/model_02_out.bin",
+                      "--device=GPU"};
+auto status = RunBenchmark(5, argv);
 ASSERT_EQ(status, RET_OK);
 }
 
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc
new file mode 100644
index 0000000000..20324e0cdb
--- /dev/null
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc
@@ -0,0 +1,76 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iostream>
+#include <memory>
+#include "mindspore/core/utils/log_adapter.h"
+#include "common/common_test.h"
+#include "mindspore/lite/src/common/file_utils.h"
+#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h"
+
+namespace mindspore {
+class TestTransposeOpenCL : public mindspore::Common {
+ public:
+  TestTransposeOpenCL() {}
+};
+
+TEST_F(TestTransposeOpenCL, TransposeFp32) {  // runs the transpose kernel on file data and compares with a reference
+  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+  ocl_runtime->Init();
+  auto allocator = ocl_runtime->GetAllocator();
+  int h = 64;
+  int w = 1;
+  int c = 7360;
+  size_t input_size;
+  std::string input_path = "./test_data/transpose/transpose_fp32_input.bin";
+  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
+
+  lite::tensor::Tensor *tensor_x =
+    new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, h, w, c}, schema::Format_NHWC4);
+
+  lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, c, h, w});
+  std::vector<lite::tensor::Tensor *> inputs{tensor_x};
+  std::vector<lite::tensor::Tensor *> outputs{tensor_out};
+  auto *transpose_kernel = new kernel::TransposeOpenCLKernel(nullptr, inputs, outputs);
+  transpose_kernel->Init();
+
+  inputs[0]->MallocData(allocator);
+
+  std::vector<kernel::LiteKernel *> kernels{transpose_kernel};
+  auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
+  pGraph->Init();
+  memcpy(inputs[0]->Data(), input_data, input_size);
+  pGraph->Run();
+
+  size_t output_size;
+  std::string output_path = "./test_data/transpose/transpose_fp32_output.bin";
+  auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
+  printf("==================output data=================\n");
+  float *output_data = reinterpret_cast<float *>(tensor_out->Data());
+  std::cout << std::endl;
+  int size_n = h * w * c;
+  size_n = size_n > 100 ? 100 : size_n;  // print at most the first 100 values
+  for (int i = 0; i < size_n; i++) {
+    std::cout << output_data[i] << " ";
+  }
+  std::cout << std::endl;
+
+  // compare
+  CompareOutputData(output_data, correct_data, h * w * c, 0.00001);
+  MS_LOG(INFO) << "TestTransposeFp32 passed";  // NOTE(review): pGraph, tensors and ReadFile buffers are never released here
+}
+}  // namespace mindspore