Merge pull request !4132 from chenzupeng/master-litetags/v0.7.0-beta
| @@ -9,4 +9,5 @@ set(OPENCL_KERNEL_SRC | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/kernel/softmax.cc | ${CMAKE_CURRENT_SOURCE_DIR}/kernel/softmax.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/kernel/concat.cc | ${CMAKE_CURRENT_SOURCE_DIR}/kernel/concat.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/kernel/conv2d_transpose.cc | ${CMAKE_CURRENT_SOURCE_DIR}/kernel/conv2d_transpose.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/kernel/transpose.cc | |||||
| ) | ) | ||||
| @@ -0,0 +1,44 @@ | |||||
| #define FLT half | |||||
| #define FLT4 half4 | |||||
| #define READ_IMAGE read_imageh | |||||
| __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | |||||
| __kernel void transpose(__read_only image2d_t src_data, __global float4 *dst_data, int2 HW, int2 C) { | |||||
| int X = get_global_id(0); | |||||
| int Y = get_global_id(1); | |||||
| if (X >= HW.y || Y >= C.y) { | |||||
| return; | |||||
| } | |||||
| FLT4 result[4]; | |||||
| result[0] = (FLT4)(0.0f); | |||||
| result[1] = (FLT4)(0.0f); | |||||
| result[2] = (FLT4)(0.0f); | |||||
| result[3] = (FLT4)(0.0f); | |||||
| FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X)); | |||||
| FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 1)); | |||||
| FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 2)); | |||||
| FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 3)); | |||||
| result[0].x = x0.x; | |||||
| result[0].y = x1.x; | |||||
| result[0].z = x2.x; | |||||
| result[0].w = x3.x; | |||||
| result[1].x = x0.y; | |||||
| result[1].y = x1.y; | |||||
| result[1].z = x2.y; | |||||
| result[1].w = x3.y; | |||||
| result[2].x = x0.z; | |||||
| result[2].y = x1.z; | |||||
| result[2].z = x2.z; | |||||
| result[2].w = x3.z; | |||||
| result[3].x = x0.w; | |||||
| result[3].y = x1.w; | |||||
| result[3].z = x2.w; | |||||
| result[3].w = x3.w; | |||||
| dst_data[4 * Y * HW.y + X] = result[0]; | |||||
| dst_data[(4 * Y + 1) * HW.y + X] = result[1]; | |||||
| dst_data[(4 * Y + 2) * HW.y + X] = result[2]; | |||||
| dst_data[(4 * Y + 3) * HW.y + X] = result[3]; | |||||
| } | |||||
| @@ -0,0 +1,44 @@ | |||||
| #define FLT float | |||||
| #define FLT4 float4 | |||||
| #define READ_IMAGE read_imagef | |||||
| __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | |||||
| __kernel void transpose(__read_only image2d_t src_data, __global float4 *dst_data, int2 HW, int2 C) { | |||||
| int X = get_global_id(0); | |||||
| int Y = get_global_id(1); | |||||
| if (X >= HW.y || Y >= C.y) { | |||||
| return; | |||||
| } | |||||
| FLT4 result[4]; | |||||
| result[0] = (FLT4)(0.0f); | |||||
| result[1] = (FLT4)(0.0f); | |||||
| result[2] = (FLT4)(0.0f); | |||||
| result[3] = (FLT4)(0.0f); | |||||
| FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X)); | |||||
| FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 1)); | |||||
| FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 2)); | |||||
| FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 3)); | |||||
| result[0].x = x0.x; | |||||
| result[0].y = x1.x; | |||||
| result[0].z = x2.x; | |||||
| result[0].w = x3.x; | |||||
| result[1].x = x0.y; | |||||
| result[1].y = x1.y; | |||||
| result[1].z = x2.y; | |||||
| result[1].w = x3.y; | |||||
| result[2].x = x0.z; | |||||
| result[2].y = x1.z; | |||||
| result[2].z = x2.z; | |||||
| result[2].w = x3.z; | |||||
| result[3].x = x0.w; | |||||
| result[3].y = x1.w; | |||||
| result[3].z = x2.w; | |||||
| result[3].w = x3.w; | |||||
| dst_data[4 * Y * HW.y + X] = result[0]; | |||||
| dst_data[(4 * Y + 1) * HW.y + X] = result[1]; | |||||
| dst_data[(4 * Y + 2) * HW.y + X] = result[2]; | |||||
| dst_data[(4 * Y + 3) * HW.y + X] = result[3]; | |||||
| } | |||||
| @@ -0,0 +1,107 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include "include/errorcode.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/opencl/opencl_runtime.h" | |||||
| #include "src/runtime/kernel/opencl/kernel/transpose.h" | |||||
| #ifndef PROGRAM_WITH_IL | |||||
| #include "src/runtime/kernel/opencl/cl/fp16/transpose.cl.inc" | |||||
| #include "src/runtime/kernel/opencl/cl/fp32/transpose.cl.inc" | |||||
| #endif | |||||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_Transpose; | |||||
| namespace mindspore::kernel { | |||||
| int TransposeOpenCLKernel::Init() { | |||||
| std::string kernel_name = "transpose"; | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||||
| #ifdef PROGRAM_WITH_IL | |||||
| ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); | |||||
| #else | |||||
| std::set<std::string> build_options; | |||||
| #ifdef ENABLE_FP16 | |||||
| std::string source = transpose_source_fp16; | |||||
| #else | |||||
| std::string source = transpose_source_fp32; | |||||
| #endif | |||||
| std::string program_name = "transpose"; | |||||
| ocl_runtime->LoadSource(program_name, source); | |||||
| ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||||
| #endif | |||||
| auto input_format = inputs_[0]->GetFormat(); | |||||
| if (input_format != schema::Format_NHWC4) { | |||||
| MS_LOG(ERROR) << "input format(" << input_format << ") " | |||||
| << "format not support!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if ((inputs_[0]->Height() * inputs_[0]->Width()) % 4 != 0) { | |||||
| MS_LOG(ERROR) << "input H * W % 4 != 0 not support!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| outputs_[0]->SetFormat(schema::Format_NCHW); | |||||
| MS_LOG(DEBUG) << kernel_name << " Init Done!"; | |||||
| return RET_OK; | |||||
| } | |||||
| int TransposeOpenCLKernel::ReSize() { return 0; } | |||||
| int TransposeOpenCLKernel::Run() { | |||||
| MS_LOG(DEBUG) << this->Name() << " Running!"; | |||||
| std::vector<int> shapex = inputs_[0]->shape(); | |||||
| int h = shapex[1]; | |||||
| int w = shapex[2]; | |||||
| int c = shapex[3]; | |||||
| int c4 = UP_DIV(c, 4); | |||||
| int hw4 = UP_DIV(h * w, 4); | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||||
| // local size should less than MAX_GROUP_SIZE | |||||
| std::vector<size_t> local = {16, 16}; | |||||
| std::vector<size_t> global = {UP_ROUND(hw4, local[0]), UP_ROUND(c4, local[1])}; | |||||
| cl_int2 HW = {h * w, hw4}; | |||||
| cl_int2 C = {c, c4}; | |||||
| ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data()); | |||||
| ocl_runtime->SetKernelArg(kernel_, 1, outputs_[0]->Data()); | |||||
| ocl_runtime->SetKernelArg(kernel_, 2, HW); | |||||
| ocl_runtime->SetKernelArg(kernel_, 3, C); | |||||
| ocl_runtime->RunKernel(kernel_, global, local, nullptr); | |||||
| return 0; | |||||
| } | |||||
| kernel::LiteKernel *OpenCLTransposeKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc) { | |||||
| auto *kernel = new TransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||||
| auto ret = kernel->Init(); | |||||
| if (0 != ret) { | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Transpose, OpenCLTransposeKernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_TRANSPOSE_H_ | |||||
| #define MINDSPORE_LITE_SRC_BACKEND_OPENCL_TRANSPOSE_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "src/runtime/opencl/opencl_runtime.h" | |||||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||||
| namespace mindspore::kernel { | |||||
| class TransposeOpenCLKernel : public OpenCLKernel { | |||||
| public: | |||||
| explicit TransposeOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||||
| ~TransposeOpenCLKernel() override{}; | |||||
| int Init() override; | |||||
| int ReSize() override; | |||||
| int Run() override; | |||||
| private: | |||||
| cl::Kernel kernel_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_TRANSPOSE_H_ | |||||
| @@ -130,6 +130,7 @@ if (SUPPORT_GPU) | |||||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc | ${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc | ||||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc | ${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc | ||||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc | ${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc | ||||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc | |||||
| ) | ) | ||||
| endif() | endif() | ||||
| ### minddata lite | ### minddata lite | ||||
| @@ -295,6 +296,7 @@ if (SUPPORT_GPU) | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/max_pooling_tests.cc | ${TEST_DIR}/ut/src/runtime/kernel/opencl/max_pooling_tests.cc | ||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/utils_tests.cc | ${TEST_DIR}/ut/src/runtime/kernel/opencl/utils_tests.cc | ||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc | ${TEST_DIR}/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc | ||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/transpose_tests.cc | |||||
| ) | ) | ||||
| endif() | endif() | ||||
| @@ -42,10 +42,11 @@ TEST_F(BenchmarkTest, TestOCR_02) { | |||||
| } | } | ||||
| TEST_F(BenchmarkTest, TestOCR_02_GPU) { | TEST_F(BenchmarkTest, TestOCR_02_GPU) { | ||||
| const char *argv[] = {"./benchmark", "--modelPath=./hiai/hiai_cv_focusShootOCRMOdel_02.ms" | |||||
| "--inDataPath=./hiai/hiai_cv_focusShootOCRMOdel_02.bin" | |||||
| "--calibDataPath=./hiai/hiai_cv_focusShootOCRMOdel_02.txt"}; | |||||
| auto status = RunBenchmark(2, argv); | |||||
| const char *argv[] = {"./benchmark", "--modelPath=./hiai/model_02.ms", | |||||
| "--inDataPath=./hiai/model_02_in.bin", | |||||
| "--calibDataPath=./hiai/model_02_out.bin", | |||||
| "--device=GPU"}; | |||||
| auto status = RunBenchmark(5, argv); | |||||
| ASSERT_EQ(status, RET_OK); | ASSERT_EQ(status, RET_OK); | ||||
| } | } | ||||
| @@ -0,0 +1,76 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include "mindspore/core/utils/log_adapter.h" | |||||
| #include "common/common_test.h" | |||||
| #include "mindspore/lite/src/common/file_utils.h" | |||||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h" | |||||
| namespace mindspore { | |||||
| class TestTransposeOpenCL : public mindspore::Common { | |||||
| public: | |||||
| TestTransposeOpenCL() {} | |||||
| }; | |||||
| TEST_F(TestTransposeOpenCL, TransposeFp32) { | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||||
| ocl_runtime->Init(); | |||||
| auto allocator = ocl_runtime->GetAllocator(); | |||||
| int h = 64; | |||||
| int w = 1; | |||||
| int c = 7360; | |||||
| size_t input_size; | |||||
| std::string input_path = "./test_data/transpose/transpose_fp32_input.bin"; | |||||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||||
| lite::tensor::Tensor *tensor_x = | |||||
| new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, h, w, c}, schema::Format_NHWC4); | |||||
| lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, c, h, w}); | |||||
| std::vector<lite::tensor::Tensor *> inputs{tensor_x}; | |||||
| std::vector<lite::tensor::Tensor *> outputs{tensor_out}; | |||||
| auto *arith_kernel = new kernel::TransposeOpenCLKernel(nullptr, inputs, outputs); | |||||
| arith_kernel->Init(); | |||||
| inputs[0]->MallocData(allocator); | |||||
| std::vector<kernel::LiteKernel *> kernels{arith_kernel}; | |||||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||||
| pGraph->Init(); | |||||
| memcpy(inputs[0]->Data(), input_data, input_size); | |||||
| pGraph->Run(); | |||||
| size_t output_size; | |||||
| std::string output_path = "./test_data/transpose/transpose_fp32_output.bin"; | |||||
| auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); | |||||
| printf("==================output data=================\n"); | |||||
| float *output_data = reinterpret_cast<float *>(tensor_out->Data()); | |||||
| std::cout << std::endl; | |||||
| int size_n = h * w * c; | |||||
| size_n = size_n > 100 ? 100 : size_n; | |||||
| for (int i = 0; i < size_n; i++) { | |||||
| std::cout << output_data[i] << " "; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| // compare | |||||
| CompareOutputData(output_data, correct_data, h * w * c, 0.00001); | |||||
| MS_LOG(INFO) << "TestMatMulFp32 passed"; | |||||
| } | |||||
| } // namespace mindspore | |||||