Merge pull request !4390 from liuzhongkai/leaky_relutags/v0.7.0-beta
| @@ -0,0 +1,70 @@ | |||
| #pragma OPENCL EXTENSION cl_arm_printf : enable | |||
| #define SLICES 4 | |||
| #define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) | |||
| #define FLT4 float4 | |||
| #define MIN(X, Y) (X < Y ? X : Y) | |||
| #define READ_FLT4 read_imagef | |||
| #define WRITE_FLT4 write_imagef | |||
| __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | |||
| __kernel void ReluScalar(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, | |||
| const float alpha) { | |||
| int C = input_shape.w; // channel size | |||
| int Y = get_global_id(0); // height id | |||
| int X = get_global_id(1); // weight id | |||
| for (int num = 0; num < UP_DIV(C, SLICES); ++num) { | |||
| FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC | |||
| FLT4 tmp; | |||
| tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * alpha; | |||
| tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * alpha; | |||
| tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * alpha; | |||
| tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * alpha; | |||
| WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC | |||
| } | |||
| } | |||
| __kernel void Relu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) { | |||
| int C = input_shape.w; // channel size | |||
| int Y = get_global_id(0); // height id | |||
| int X = get_global_id(1); // weight id | |||
| for (int num = 0; num < UP_DIV(C, SLICES); ++num) { | |||
| FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC | |||
| FLT4 tmp; | |||
| tmp.x = in_c4.x >= 0 ? in_c4.x : 0; | |||
| tmp.y = in_c4.y >= 0 ? in_c4.y : 0; | |||
| tmp.z = in_c4.z >= 0 ? in_c4.z : 0; | |||
| tmp.w = in_c4.w >= 0 ? in_c4.w : 0; | |||
| WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC | |||
| } | |||
| } | |||
| __kernel void Relu6(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) { | |||
| int C = input_shape.w; // channel size | |||
| int Y = get_global_id(0); // height id | |||
| int X = get_global_id(1); // weight id | |||
| for (int num = 0; num < UP_DIV(C, SLICES); ++num) { | |||
| FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC | |||
| FLT4 tmp; | |||
| tmp.x = in_c4.x >= 0 ? MIN(in_c4.x, 6) : 0; | |||
| tmp.y = in_c4.y >= 0 ? MIN(in_c4.y, 6) : 0; | |||
| tmp.z = in_c4.z >= 0 ? MIN(in_c4.z, 6) : 0; | |||
| tmp.w = in_c4.w >= 0 ? MIN(in_c4.w, 6) : 0; | |||
| WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC | |||
| } | |||
| } | |||
| __kernel void Sigmoid(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) { | |||
| int C = input_shape.w; // channel size | |||
| int Y = get_global_id(0); // height id | |||
| int X = get_global_id(1); // weight id | |||
| for (int num = 0; num < UP_DIV(C, SLICES); ++num) { | |||
| FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC | |||
| FLT4 tmp; | |||
| tmp.x = 1 / (1 + exp(-in_c4.x)); | |||
| tmp.y = 1 / (1 + exp(-in_c4.y)); | |||
| tmp.z = 1 / (1 + exp(-in_c4.z)); | |||
| tmp.w = 1 / (1 + exp(-in_c4.w)); | |||
| WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC | |||
| } | |||
| } | |||
| @@ -1,28 +0,0 @@ | |||
| #pragma OPENCL EXTENSION cl_arm_printf : enable | |||
| #define SLICES 4 | |||
| #define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) | |||
| #define FLT4 float4 | |||
| #define READ_FLT4 read_imagef | |||
| #define WRITE_FLT4 write_imagef | |||
| __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | |||
| __kernel void LeakyRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, | |||
| const float alpha) { | |||
| // int B = input_shape.x; // size | |||
| // int H = input_shape.y; // | |||
| // int W = input_shape.z; | |||
| int C = input_shape.w; | |||
| int Y = get_global_id(0); // height id | |||
| int X = get_global_id(1); // weight id | |||
| for (int num = 0; num < UP_DIV(C, SLICES); ++num) { | |||
| FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC | |||
| FLT4 tmp; | |||
| tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * alpha; | |||
| tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * alpha; | |||
| tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * alpha; | |||
| tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * alpha; | |||
| WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC | |||
| } | |||
| } | |||
| @@ -0,0 +1,146 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <set> | |||
| #include "src/runtime/kernel/opencl/kernel/activation.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/ops/ops.h" | |||
| #include "src/runtime/kernel/opencl/cl/fp32/activation.cl.inc" | |||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::ActivationType_LEAKY_RELU; | |||
| using mindspore::schema::ActivationType_RELU; | |||
| using mindspore::schema::ActivationType_RELU6; | |||
| using mindspore::schema::ActivationType_SIGMOID; | |||
| using mindspore::schema::PrimitiveType_Activation; | |||
| namespace mindspore::kernel { | |||
| int ActivationOpenClKernel::Init() { | |||
| const int max_shape_dim = 4; | |||
| if (in_tensors_[0]->shape().size() != max_shape_dim) { | |||
| MS_LOG(ERROR) << "Activate fun only support dim=4, but your dim=" << in_tensors_[0]->shape().size(); | |||
| return RET_ERROR; | |||
| } | |||
| std::string program_name = ""; | |||
| std::string kernel_name = ""; | |||
| std::string source = activation_source_fp32; | |||
| if (type_ == ActivationType_RELU) { | |||
| program_name = "RELU"; | |||
| kernel_name = "Relu"; | |||
| } else if (type_ == ActivationType_RELU6) { | |||
| program_name = "RELU6"; | |||
| kernel_name = "Relu6"; | |||
| } else if (type_ == ActivationType_LEAKY_RELU) { | |||
| program_name = "LEAKY_RELU"; | |||
| kernel_name = "ReluScalar"; | |||
| } else if (type_ == ActivationType_SIGMOID) { | |||
| program_name = "SIGMOID"; | |||
| kernel_name = "Sigmoid"; | |||
| } else { | |||
| MS_LOG(ERROR) << "Activation type error"; | |||
| return RET_ERROR; | |||
| } | |||
| std::set<std::string> build_options; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| ocl_runtime->LoadSource(program_name, source); | |||
| ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||
| MS_LOG(DEBUG) << op_parameter_->name_ << " init Done!"; | |||
| return RET_OK; | |||
| } | |||
| int ActivationOpenClKernel::Run() { | |||
| MS_LOG(DEBUG) << op_parameter_->name_ << " begin running!"; | |||
| int N = in_tensors_[0]->shape()[0]; | |||
| int H = in_tensors_[0]->shape()[1]; | |||
| int W = in_tensors_[0]->shape()[2]; | |||
| int C = in_tensors_[0]->shape()[3]; | |||
| cl_int4 input_shape = {N, H, W, C}; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| int arg_idx = 0; | |||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape); | |||
| if (type_ == ActivationType_LEAKY_RELU) { | |||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, alpha_); | |||
| } | |||
| std::vector<size_t> local = {1, 1}; | |||
| std::vector<size_t> global = {static_cast<size_t>(H), static_cast<size_t>(W)}; | |||
| std::cout << type_ << " " << std::endl; | |||
| auto ret = ocl_runtime->RunKernel(kernel_, global, local, nullptr); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Run kernel:" << op_parameter_->name_ << " fail."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int ActivationOpenClKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||
| int H = in_tensors_[0]->shape()[1]; | |||
| int W = in_tensors_[0]->shape()[2]; | |||
| int C = in_tensors_[0]->shape()[3]; | |||
| #ifdef ENABLE_FP16 | |||
| size_t img_dtype = CL_HALF_FLOAT; | |||
| #else | |||
| size_t img_dtype = CL_FLOAT; | |||
| #endif | |||
| img_size->clear(); | |||
| img_size->push_back(W * UP_DIV(C, C4NUM)); | |||
| img_size->push_back(H); | |||
| img_size->push_back(img_dtype); | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *OpenClActivationFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::Context *ctx, | |||
| const kernel::KernelKey &desc, const lite::Primitive *primitive) { | |||
| if (inputs.size() == 0) { | |||
| MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size(); | |||
| return nullptr; | |||
| } | |||
| if (inputs[0]->shape()[0] > 1) { | |||
| MS_LOG(ERROR) << "Activation kernel:" << opParameter->name_ << " failed: Unsupported multi-batch."; | |||
| return nullptr; | |||
| } | |||
| auto *kernel = | |||
| new (std::nothrow) ActivationOpenClKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "New kernel:" << opParameter->name_ << "is nullptr."; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init activation kernel:" << opParameter->name_ << " failed!"; | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Activation, OpenClActivationFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -14,24 +14,26 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LEAKYRELU_H | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LEAKYRELU_H_ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ACTIVATION_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ACTIVATION_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include "src/ir/tensor.h" | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/activation.h" | |||
| namespace mindspore::kernel { | |||
| class LeakyReluOpenCLKernel : public OpenCLKernel { | |||
| class ActivationOpenClKernel : public OpenCLKernel { | |||
| public: | |||
| explicit LeakyReluOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| ~LeakyReluOpenCLKernel() override{}; | |||
| explicit ActivationOpenClKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) { | |||
| type_ = (reinterpret_cast<ActivationParameter *>(parameter))->type_; | |||
| alpha_ = (reinterpret_cast<ActivationParameter *>(parameter))->alpha_; | |||
| } | |||
| ~ActivationOpenClKernel() override{}; | |||
| int Init() override; | |||
| int Run() override; | |||
| @@ -39,8 +41,9 @@ class LeakyReluOpenCLKernel : public OpenCLKernel { | |||
| private: | |||
| cl::Kernel kernel_; | |||
| int type_; | |||
| float alpha_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LEAKYRELU_H_ | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_ACTIVATION_H_ | |||
| @@ -161,7 +161,8 @@ kernel::LiteKernel *OpenCLArithmeticKernelCreator(const std::vector<lite::tensor | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::Context *ctx, | |||
| const kernel::KernelKey &desc, const lite::Primitive *primitive) { | |||
| auto *kernel = new ArithmeticOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs, ctx); | |||
| auto *kernel = | |||
| new (std::nothrow) ArithmeticOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs, ctx); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "Create OpenCL Arithmetic kernel failed!"; | |||
| return nullptr; | |||
| @@ -184,7 +184,8 @@ kernel::LiteKernel *OpenCLConv2dTransposeKernelCreator(const std::vector<lite::t | |||
| OpParameter *opParameter, const lite::Context *ctx, | |||
| const kernel::KernelKey &desc, | |||
| const lite::Primitive *primitive) { | |||
| auto *kernel = new Conv2dTransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||
| auto *kernel = | |||
| new (std::nothrow) Conv2dTransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; | |||
| return nullptr; | |||
| @@ -193,7 +193,8 @@ kernel::LiteKernel *OpenCLDepthwiseConv2dKernelCreator(const std::vector<lite::t | |||
| OpParameter *opParameter, const lite::Context *ctx, | |||
| const kernel::KernelKey &desc, | |||
| const lite::Primitive *primitive) { | |||
| auto *kernel = new DepthwiseConv2dOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||
| auto *kernel = | |||
| new (std::nothrow) DepthwiseConv2dOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; | |||
| return nullptr; | |||
| @@ -1,122 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/kernel/opencl/kernel/leaky_relu.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl.inc" | |||
| #include "src/runtime/kernel/arm/nnacl/leaky_relu_parameter.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_LeakyReLU; | |||
| namespace mindspore::kernel { | |||
| int LeakyReluOpenCLKernel::Init() { | |||
| if (in_tensors_[0]->shape().size() != 4) { | |||
| MS_LOG(ERROR) << "leaky_relu only support dim=4, but your dim=" << in_tensors_[0]->shape().size(); | |||
| return RET_ERROR; | |||
| } | |||
| std::set<std::string> build_options; | |||
| std::string source = leaky_relu_source_fp32; | |||
| std::string program_name = "LeakyRelu"; | |||
| std::string kernel_name = "LeakyRelu"; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| ocl_runtime->LoadSource(program_name, source); | |||
| ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||
| MS_LOG(DEBUG) << kernel_name << " Init Done!"; | |||
| return RET_OK; | |||
| } | |||
| int LeakyReluOpenCLKernel::Run() { | |||
| auto param = reinterpret_cast<LeakyReluParameter *>(op_parameter_); | |||
| MS_LOG(DEBUG) << " Running!"; | |||
| int N = in_tensors_[0]->shape()[0]; | |||
| int H = in_tensors_[0]->shape()[1]; | |||
| int W = in_tensors_[0]->shape()[2]; | |||
| int C = in_tensors_[0]->shape()[3]; | |||
| cl_int4 input_shape = {N, H, W, C}; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| int arg_idx = 0; | |||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, param->alpha); | |||
| std::vector<size_t> local = {1, 1}; | |||
| std::vector<size_t> global = {static_cast<size_t>(H), static_cast<size_t>(W)}; | |||
| ocl_runtime->RunKernel(kernel_, global, local, nullptr); | |||
| return RET_OK; | |||
| } | |||
| int LeakyReluOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||
| int H = in_tensors_[0]->shape()[1]; | |||
| int W = in_tensors_[0]->shape()[2]; | |||
| int C = in_tensors_[0]->shape()[3]; | |||
| #ifdef ENABLE_FP16 | |||
| size_t img_dtype = CL_HALF_FLOAT; | |||
| #else | |||
| size_t img_dtype = CL_FLOAT; | |||
| #endif | |||
| img_size->clear(); | |||
| img_size->push_back(W * UP_DIV(C, C4NUM)); | |||
| img_size->push_back(H); | |||
| img_size->push_back(img_dtype); | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *OpenCLLeakyReluKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::Context *ctx, | |||
| const kernel::KernelKey &desc, const lite::Primitive *primitive) { | |||
| if (inputs.size() == 0) { | |||
| MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size(); | |||
| return nullptr; | |||
| } | |||
| if (inputs[0]->shape()[0] > 1) { | |||
| MS_LOG(ERROR) << "Init `leaky relu` kernel failed: Unsupported multi-batch."; | |||
| return nullptr; | |||
| } | |||
| auto *kernel = new LeakyReluOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init `Leaky Relu` kernel failed!"; | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_LeakyReLU, OpenCLLeakyReluKernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -160,7 +160,8 @@ kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector<lite::tensor::Te | |||
| if (opParameter->type_ == PrimitiveType_FullConnection) { | |||
| hasBias = (reinterpret_cast<MatMulParameter *>(opParameter))->has_bias_; | |||
| } | |||
| auto *kernel = new MatMulOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs, hasBias); | |||
| auto *kernel = | |||
| new (std::nothrow) MatMulOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs, hasBias); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; | |||
| return nullptr; | |||
| @@ -145,7 +145,7 @@ kernel::LiteKernel *OpenCLPooling2dKernelCreator(const std::vector<lite::tensor: | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::Context *ctx, | |||
| const kernel::KernelKey &desc, const lite::Primitive *primitive) { | |||
| auto *kernel = new PoolingOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||
| auto *kernel = new (std::nothrow)PoolingOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "Create OpenCL Pooling kernel failed!"; | |||
| return nullptr; | |||
| @@ -158,7 +158,7 @@ kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector<lite::tensor::T | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::Context *ctx, | |||
| const kernel::KernelKey &desc, const lite::Primitive *primitive) { | |||
| auto *kernel = new SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||
| auto *kernel = new (std::nothrow)SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; | |||
| return nullptr; | |||
| @@ -109,7 +109,7 @@ kernel::LiteKernel *OpenCLTransposeKernelCreator(const std::vector<lite::tensor: | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::Context *ctx, | |||
| const kernel::KernelKey &desc, const lite::Primitive *primitive) { | |||
| auto *kernel = new TransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||
| auto *kernel = new (std::nothrow)TransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; | |||
| return nullptr; | |||
| @@ -142,7 +142,7 @@ if (SUPPORT_GPU) | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/matmul.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/leaky_relu.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/activation.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/reshape.cc | |||
| @@ -323,14 +323,14 @@ if (SUPPORT_GPU) | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/transpose_tests.cc | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/convolution_tests.cc | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc | |||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/activation_tests.cc | |||
| ) | |||
| endif() | |||
| if (ENABLE_FP16) | |||
| set(TEST_SRC | |||
| ${TEST_SRC} | |||
| ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16/convolution_fp16_tests.cc) | |||
| ) | |||
| endif () | |||
| @@ -0,0 +1,185 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include "utils/log_adapter.h" | |||
| #include "common/common_test.h" | |||
| #include "mindspore/lite/src/common/file_utils.h" | |||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||
| #include "mindspore/lite/src/runtime/opencl/opencl_allocator.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||
| #include "mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h" | |||
| using mindspore::kernel::LiteKernel; | |||
| using mindspore::kernel::SubGraphOpenCLKernel; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::ActivationType_LEAKY_RELU; | |||
| using mindspore::schema::ActivationType_RELU; | |||
| using mindspore::schema::ActivationType_RELU6; | |||
| using mindspore::schema::ActivationType_SIGMOID; | |||
| using mindspore::schema::PrimitiveType_Activation; | |||
| namespace mindspore { | |||
| class TestActivationOpenCL : public mindspore::CommonTest {}; | |||
| void LoadActivationData(void *dst, size_t dst_size, const std::string &file_path) { | |||
| if (file_path.empty()) { | |||
| memset(dst, 0x00, dst_size); | |||
| } else { | |||
| auto src_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &dst_size)); | |||
| memcpy(dst, src_data, dst_size); | |||
| } | |||
| } | |||
| void CompareRes(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) { | |||
| auto *output_data = reinterpret_cast<float *>(output_tensor->Data()); | |||
| size_t output_size = output_tensor->Size(); | |||
| auto expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size)); | |||
| constexpr float atol = 0.0002; | |||
| for (int i = 0; i < output_tensor->ElementsNum(); ++i) { | |||
| if (std::fabs(output_data[i] - expect_data[i]) > atol) { | |||
| printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); | |||
| printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); | |||
| printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]); | |||
| return; | |||
| } | |||
| } | |||
| printf("compare success!\n"); | |||
| printf("compare success!\n"); | |||
| printf("compare success!\n\n\n"); | |||
| } | |||
| void printf_tensor(mindspore::lite::tensor::Tensor *in_data) { | |||
| auto input_data = reinterpret_cast<float *>(in_data->Data()); | |||
| for (int i = 0; i < in_data->ElementsNum(); ++i) { | |||
| printf("%f ", input_data[i]); | |||
| } | |||
| printf("\n"); | |||
| MS_LOG(INFO) << "Print tensor done"; | |||
| } | |||
| kernel::ActivationOpenClKernel *create_kernel(lite::opencl::OpenCLAllocator *allocator, | |||
| const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, std::string test_name, | |||
| int type, std::string in_file, float alpha = 0.2) { | |||
| auto *param = new (std::nothrow) ActivationParameter(); | |||
| if (param == nullptr) { | |||
| MS_LOG(ERROR) << "New ActivationParameter fail."; | |||
| return nullptr; | |||
| } | |||
| memcpy(param->op_parameter_.name_, test_name.c_str(), test_name.size()); | |||
| param->alpha_ = alpha; | |||
| param->type_ = type; | |||
| auto *kernel = | |||
| new (std::nothrow) kernel::ActivationOpenClKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "Kernel:" << test_name << " create fail."; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init " << test_name << " fail."; | |||
| return nullptr; | |||
| } | |||
| MS_LOG(INFO) << "Initialize input data"; | |||
| LoadActivationData(inputs[0]->Data(), inputs[0]->Size(), in_file); | |||
| MS_LOG(INFO) << "==================input data================"; | |||
| printf_tensor(inputs[0]); | |||
| return kernel; | |||
| } | |||
| int RunSubGraphOpenCLKernel(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| kernel::ActivationOpenClKernel *kernel) { | |||
| MS_LOG(INFO) << "Create kernel SubGraphOpenCLKernel."; | |||
| std::vector<kernel::LiteKernel *> kernels{kernel}; | |||
| auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||
| if (sub_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Kernel SubGraphOpenCLKernel create fail."; | |||
| return RET_ERROR; | |||
| } | |||
| MS_LOG(INFO) << "Initialize sub_graph."; | |||
| auto ret = sub_graph->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init sub_graph error."; | |||
| return RET_ERROR; | |||
| } | |||
| MS_LOG(INFO) << "Run SubGraphOpenCLKernel."; | |||
| ret = sub_graph->Run(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Run SubGraphOpenCLKernel error."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| TEST_F(TestActivationOpenCL, LeakyReluFp32_dim4) { | |||
| MS_LOG(INFO) << "Begin test:"; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| ocl_runtime->Init(); | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| MS_LOG(INFO) << "Init tensors."; | |||
| std::vector<int> input_shape = {1, 4, 3, 8}; | |||
| auto data_type = kNumberTypeFloat32; | |||
| auto tensor_type = schema::NodeType_ValueNode; | |||
| auto *input_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type); | |||
| auto *output_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type); | |||
| std::vector<lite::tensor::Tensor *> inputs{input_tensor}; | |||
| std::vector<lite::tensor::Tensor *> outputs{output_tensor}; | |||
| // freamework to do!!! allocate memory by hand | |||
| inputs[0]->MallocData(allocator); | |||
| std::map<std::string, int> Test_Activation_Type; | |||
| std::map<std::string, std::string> Test_Res_File; | |||
| Test_Activation_Type["Relu"] = ActivationType_RELU; | |||
| Test_Activation_Type["Leaky_Relu"] = ActivationType_LEAKY_RELU; | |||
| Test_Activation_Type["Relu6"] = ActivationType_RELU6; | |||
| Test_Activation_Type["Sigmoid"] = ActivationType_SIGMOID; | |||
| Test_Res_File["Leaky_Relu"] = "/data/local/tmp/leaky_relu.bin"; | |||
| Test_Res_File["Relu"] = "/data/local/tmp/relu.bin"; | |||
| Test_Res_File["Relu6"] = "/data/local/tmp/relu6.bin"; | |||
| Test_Res_File["Sigmoid"] = "/data/local/tmp/sigmoid.bin"; | |||
| std::string in_file = "/data/local/tmp/in_data.bin"; | |||
| std::map<std::string, int>::iterator it = Test_Activation_Type.begin(); | |||
| while (it != Test_Activation_Type.end()) { | |||
| auto kernel = create_kernel(allocator, inputs, outputs, it->first, it->second, in_file, 0.3); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "Create kernel:" << it->first << " error."; | |||
| return; | |||
| } | |||
| auto ret = RunSubGraphOpenCLKernel(inputs, outputs, kernel); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << it->first << " RunSubGraphOpenCLKernel error."; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "==================output data================"; | |||
| printf_tensor(outputs[0]); | |||
| CompareRes(output_tensor, Test_Res_File[it->first]); | |||
| delete kernel; | |||
| it++; | |||
| } | |||
| delete input_tensor; | |||
| delete output_tensor; | |||
| return; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -1,110 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include "utils/log_adapter.h" | |||
| #include "common/common_test.h" | |||
| #include "mindspore/lite/src/common/file_utils.h" | |||
| #include "src/runtime/kernel/arm/nnacl/pack.h" | |||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h" | |||
| #include "mindspore/lite/src/runtime/kernel/arm/nnacl/leaky_relu_parameter.h" | |||
| using mindspore::kernel::LeakyReluOpenCLKernel; | |||
| using mindspore::kernel::LiteKernel; | |||
| using mindspore::kernel::SubGraphOpenCLKernel; | |||
| namespace mindspore { | |||
| class TestLeakyReluOpenCL : public mindspore::CommonTest {}; | |||
| void LoadDataLeakyRelu(void *dst, size_t dst_size, const std::string &file_path) { | |||
| if (file_path.empty()) { | |||
| memset(dst, 0x00, dst_size); | |||
| } else { | |||
| auto src_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &dst_size)); | |||
| memcpy(dst, src_data, dst_size); | |||
| } | |||
| } | |||
| void CompareOutLeakyRelu(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) { | |||
| auto *output_data = reinterpret_cast<float *>(output_tensor->Data()); | |||
| size_t output_size = output_tensor->Size(); | |||
| auto expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size)); | |||
| constexpr float atol = 0.0002; | |||
| for (int i = 0; i < output_tensor->ElementsNum(); ++i) { | |||
| if (std::fabs(output_data[i] - expect_data[i]) > atol) { | |||
| printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); | |||
| printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); | |||
| printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]); | |||
| return; | |||
| } | |||
| } | |||
| printf("compare success!\n"); | |||
| printf("compare success!\n"); | |||
| printf("compare success!\n\n\n"); | |||
| } | |||
| void printf_tensor(mindspore::lite::tensor::Tensor *in_data) { | |||
| auto input_data = reinterpret_cast<float *>(in_data->Data()); | |||
| for (int i = 0; i < in_data->ElementsNum(); ++i) { | |||
| printf("%f ", input_data[i]); | |||
| } | |||
| printf("\n"); | |||
| MS_LOG(INFO) << "Print tensor done"; | |||
| } | |||
| TEST_F(TestLeakyReluOpenCL, LeakyReluFp32_dim4) { | |||
| std::string in_file = "/data/local/tmp/in_data.bin"; | |||
| std::string standard_answer_file = "/data/local/tmp/out_data.bin"; | |||
| MS_LOG(INFO) << "Begin test:"; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| ocl_runtime->Init(); | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| MS_LOG(INFO) << "Init tensors."; | |||
| std::vector<int> input_shape = {1, 4, 3, 8}; | |||
| auto data_type = kNumberTypeFloat32; | |||
| auto tensor_type = schema::NodeType_ValueNode; | |||
| auto *input_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type); | |||
| auto *output_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type); | |||
| std::vector<lite::tensor::Tensor *> inputs{input_tensor}; | |||
| std::vector<lite::tensor::Tensor *> outputs{output_tensor}; | |||
| // freamework to do!!! allocate memory by hand | |||
| inputs[0]->MallocData(allocator); | |||
| auto param = new LeakyReluParameter(); | |||
| param->alpha = 0.3; | |||
| auto *leakyrelu_kernel = new kernel::LeakyReluOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| leakyrelu_kernel->Init(); | |||
| MS_LOG(INFO) << "initialize sub_graph"; | |||
| std::vector<kernel::LiteKernel *> kernels{leakyrelu_kernel}; | |||
| auto *sub_graph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||
| sub_graph->Init(); | |||
| MS_LOG(INFO) << "initialize input data"; | |||
| LoadDataLeakyRelu(input_tensor->Data(), input_tensor->Size(), in_file); | |||
| MS_LOG(INFO) << "==================input data================"; | |||
| printf_tensor(inputs[0]); | |||
| sub_graph->Run(); | |||
| MS_LOG(INFO) << "==================output data================"; | |||
| printf_tensor(outputs[0]); | |||
| CompareOutLeakyRelu(output_tensor, standard_answer_file); | |||
| } | |||
| } // namespace mindspore | |||