| @@ -3,7 +3,8 @@ | |||||
| #define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) | #define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) | ||||
| __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | ||||
| __kernel void FullConnection_NHWC4(__read_only image2d_t input, __global FLT16 *weight, __read_only image2d_t bias, | __kernel void FullConnection_NHWC4(__read_only image2d_t input, __global FLT16 *weight, __read_only image2d_t bias, | ||||
| __write_only image2d_t output, int4 in_shape, int2 out_shape) { | |||||
| __write_only image2d_t output, int4 in_shape, int2 out_shape, float act_min, | |||||
| float act_max) { | |||||
| int gidx = get_global_id(0); // CO4 | int gidx = get_global_id(0); // CO4 | ||||
| int gidz = get_global_id(2); // N | int gidz = get_global_id(2); // N | ||||
| int lidx = get_local_id(0); | int lidx = get_local_id(0); | ||||
| @@ -12,9 +13,9 @@ __kernel void FullConnection_NHWC4(__read_only image2d_t input, __global FLT16 * | |||||
| int hwci4 = ci4 * in_shape.y * in_shape.z; | int hwci4 = ci4 * in_shape.y * in_shape.z; | ||||
| int co4 = UP_DIV(out_shape.y, C4NUM); | int co4 = UP_DIV(out_shape.y, C4NUM); | ||||
| int n = out_shape.x; | int n = out_shape.x; | ||||
| bool inside = gidx < co4 && gidz < n; | |||||
| if (gidx >= co4 || gidz >= n) return; | |||||
| FLT4 result = (FLT4)(0.0f); | FLT4 result = (FLT4)(0.0f); | ||||
| for (uint i = lidy; i < hwci4 && inside; i += 4) { | |||||
| for (uint i = lidy; i < hwci4; i += 4) { | |||||
| int index_h = i / (ci4 * in_shape.z); | int index_h = i / (ci4 * in_shape.z); | ||||
| int index_wci4 = i % (ci4 * in_shape.z); | int index_wci4 = i % (ci4 * in_shape.z); | ||||
| FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index_wci4, gidz * in_shape.y + index_h)); | FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index_wci4, gidz * in_shape.y + index_h)); | ||||
| @@ -24,83 +25,15 @@ __kernel void FullConnection_NHWC4(__read_only image2d_t input, __global FLT16 * | |||||
| result.z += dot(v, w.s89ab); | result.z += dot(v, w.s89ab); | ||||
| result.w += dot(v, w.scdef); | result.w += dot(v, w.scdef); | ||||
| } | } | ||||
| __local FLT4 temp[32][4]; | |||||
| temp[lidx][lidy] = result; | |||||
| __local FLT4 temp[4]; | |||||
| temp[lidy] = result; | |||||
| barrier(CLK_LOCAL_MEM_FENCE); | barrier(CLK_LOCAL_MEM_FENCE); | ||||
| if (lidy == 0 && inside) { | |||||
| result += temp[lidx][1]; | |||||
| result += temp[lidx][2]; | |||||
| result += temp[lidx][3]; | |||||
| if (lidy == 0) { | |||||
| result += temp[1]; | |||||
| result += temp[2]; | |||||
| result += temp[3]; | |||||
| result += READ_IMAGE(bias, smp_zero, (int2)(gidx, 0)); | result += READ_IMAGE(bias, smp_zero, (int2)(gidx, 0)); | ||||
| result = clamp(result, (FLT)(act_min), (FLT)(act_max)); | |||||
| WRITE_IMAGE(output, (int2)(gidx, gidz), result); | WRITE_IMAGE(output, (int2)(gidx, gidz), result); | ||||
| } | } | ||||
| } | } | ||||
// Fully-connected layer over NHWC4-packed image input with fused ReLU.
// Launched with local size {32, 4, 1}: each of the 4 work-items along local y
// accumulates a strided quarter of the H*W*CI4 reduction for its CO4 slice,
// partial sums are combined through local memory, then bias and ReLU are
// applied by the lidy == 0 lane.
__kernel void FullConnection_NHWC4_ReLU(__read_only image2d_t input, __global FLT16 *weight, __read_only image2d_t bias,
                                        __write_only image2d_t output, int4 in_shape, int2 out_shape) {
  int gidx = get_global_id(0);  // CO4
  int gidz = get_global_id(2);  // N
  int lidx = get_local_id(0);
  int lidy = get_local_id(1);
  int ci4 = UP_DIV(in_shape.w, C4NUM);
  int hwci4 = ci4 * in_shape.y * in_shape.z;
  int co4 = UP_DIV(out_shape.y, C4NUM);
  int n = out_shape.x;
  // Out-of-range work-items stay alive (they must still reach the barrier);
  // they just skip the accumulation loop and the final write.
  bool inside = gidx < co4 && gidz < n;
  FLT4 result = (FLT4)(0.0f);
  for (uint i = lidy; i < hwci4 && inside; i += 4) {
    // NHWC4 addressing: image x = w * ci4 + ci4-slice, image y = n * H + h.
    int index_h = i / (ci4 * in_shape.z);
    int index_wci4 = i % (ci4 * in_shape.z);
    FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index_wci4, gidz * in_shape.y + index_h));
    FLT16 w = weight[i * co4 + gidx];
    result.x += dot(v, w.s0123);
    result.y += dot(v, w.s4567);
    result.z += dot(v, w.s89ab);
    result.w += dot(v, w.scdef);
  }
  // Local reduction buffer sized for the full {32, 4} work-group.
  __local FLT4 temp[32][4];
  temp[lidx][lidy] = result;
  barrier(CLK_LOCAL_MEM_FENCE);
  if (lidy == 0 && inside) {
    result += temp[lidx][1];
    result += temp[lidx][2];
    result += temp[lidx][3];
    result += READ_IMAGE(bias, smp_zero, (int2)(gidx, 0));
    result = max(result, (FLT4)(0.f));  // fused ReLU
    WRITE_IMAGE(output, (int2)(gidx, gidz), result);
  }
}
// Fully-connected layer over NC4HW4-packed image input (no activation).
// Same reduction scheme as the NHWC4 variant (local size {32, 4, 1}), but the
// input image is addressed in NC4HW4 order and the output is written one
// column wide at row n * co4 + co4-slice.
__kernel void FullConnection_NC4HW4(__read_only image2d_t input, __global FLT16 *weight, __read_only image2d_t bias,
                                    __write_only image2d_t output, int4 in_shape, int2 out_shape) {
  int gidx = get_global_id(0);  // CO4
  int gidz = get_global_id(2);  // N
  int lidx = get_local_id(0);
  int lidy = get_local_id(1);
  int ci4 = UP_DIV(in_shape.w, C4NUM);
  int hwci4 = ci4 * in_shape.y * in_shape.z;
  int co4 = UP_DIV(out_shape.y, C4NUM);
  int n = out_shape.x;
  // Out-of-range work-items stay alive so the whole group reaches the barrier.
  bool inside = gidx < co4 && gidz < n;
  FLT4 result = (FLT4)(0.0f);
  for (uint i = lidy; i < hwci4 && inside; i += 4) {
    // NC4HW4 addressing: image x = w, image y = (n * ci4 + ci4-slice) * H + h.
    int index_ci4h = i / in_shape.z;
    int index_w = i % in_shape.z;
    FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index_w, gidz * in_shape.y * ci4 + index_ci4h));
    FLT16 w = weight[i * co4 + gidx];
    result.x += dot(v, w.s0123);
    result.y += dot(v, w.s4567);
    result.z += dot(v, w.s89ab);
    result.w += dot(v, w.scdef);
  }
  // Local reduction buffer sized for the full {32, 4} work-group.
  __local FLT4 temp[32][4];
  temp[lidx][lidy] = result;
  barrier(CLK_LOCAL_MEM_FENCE);
  if (lidy == 0 && inside) {
    result += temp[lidx][1];
    result += temp[lidx][2];
    result += temp[lidx][3];
    result += READ_IMAGE(bias, smp_zero, (int2)(gidx, 0));
    // NC4HW4 output layout: single-column image, row = n * co4 + gidx.
    WRITE_IMAGE(output, (int2)(0, gidz * co4 + gidx), result);
  }
}
| @@ -19,6 +19,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/runtime/kernel/opencl/kernel/convolution.h" | #include "src/runtime/kernel/opencl/kernel/convolution.h" | ||||
| #include "src/runtime/kernel/opencl/kernel/fullconnection.h" | |||||
| #include "src/runtime/kernel/opencl/utils.h" | #include "src/runtime/kernel/opencl/utils.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| @@ -29,6 +30,7 @@ using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| using mindspore::lite::RET_OK; | using mindspore::lite::RET_OK; | ||||
| using mindspore::schema::PrimitiveType_Conv2D; | using mindspore::schema::PrimitiveType_Conv2D; | ||||
| using mindspore::schema::PrimitiveType_FullConnection; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| @@ -339,12 +341,40 @@ kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vector<lite::Tenso | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | ||||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | const lite::InnerContext *ctx, const kernel::KernelKey &desc, | ||||
| const mindspore::lite::PrimitiveC *primitive) { | const mindspore::lite::PrimitiveC *primitive) { | ||||
| auto *kernel = | |||||
| new (std::nothrow) ConvolutionOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "Create OpenCL Convolution kernel failed!"; | |||||
| free(opParameter); | |||||
| return nullptr; | |||||
| kernel::LiteKernel *kernel; | |||||
| bool is_hw1 = inputs[0]->shape().size() == 4 && inputs[0]->shape()[1] == 1 && inputs[0]->shape()[2] == 1 && | |||||
| outputs[0]->shape().size() == 4 && outputs[0]->shape()[1] == 1 && outputs[0]->shape()[2] == 1; | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | |||||
| bool is_pad_stride_ok = conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1 && conv_param->stride_h_ == 1 && | |||||
| conv_param->stride_w_ == 1 && conv_param->pad_u_ == 0 && conv_param->pad_d_ == 0 && | |||||
| conv_param->pad_l_ == 0 && conv_param->pad_r_ == 0 && conv_param->dilation_h_ == 1 && | |||||
| conv_param->dilation_w_ == 1; | |||||
| if (is_hw1 && is_pad_stride_ok) { | |||||
| auto param = static_cast<MatMulParameter *>(malloc(sizeof(MatMulParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "Create OpenCL FullConnection kernel param failed!"; | |||||
| return nullptr; | |||||
| } | |||||
| param->op_parameter_.type_ = PrimitiveType_FullConnection; | |||||
| param->a_transpose_ = false; | |||||
| param->b_transpose_ = true; | |||||
| param->act_type_ = conv_param->act_type_; | |||||
| kernel = new (std::nothrow) FullConnectionOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "Create OpenCL FullConnection kernel failed!"; | |||||
| free(param); | |||||
| free(opParameter); | |||||
| return nullptr; | |||||
| } else { | |||||
| free(opParameter); | |||||
| } | |||||
| } else { | |||||
| kernel = new (std::nothrow) ConvolutionOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "Create OpenCL Convolution kernel failed!"; | |||||
| free(opParameter); | |||||
| return nullptr; | |||||
| } | |||||
| } | } | ||||
| auto ret = kernel->Init(); | auto ret = kernel->Init(); | ||||
| if (ret != mindspore::lite::RET_OK) { | if (ret != mindspore::lite::RET_OK) { | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include "nnacl/fp32/common_func.h" | #include "nnacl/fp32/common_func.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/opencl/kernel/fullconnection.h" | #include "src/runtime/kernel/opencl/kernel/fullconnection.h" | ||||
| #include "src/runtime/kernel/opencl/utils.h" | |||||
| #ifndef PROGRAM_WITH_IL | #ifndef PROGRAM_WITH_IL | ||||
| #include "src/runtime/kernel/opencl/cl/fullconnection.cl.inc" | #include "src/runtime/kernel/opencl/cl/fullconnection.cl.inc" | ||||
| #endif | #endif | ||||
| @@ -43,23 +44,22 @@ int FullConnectionOpenCLKernel::Init() { | |||||
| transposeB = param->b_transpose_; | transposeB = param->b_transpose_; | ||||
| enable_fp16_ = ocl_runtime_->GetFp16Enable(); | enable_fp16_ = ocl_runtime_->GetFp16Enable(); | ||||
| if ((in_tensors_[0]->shape().size() != 4 && in_tensors_[0]->shape().size() != 2) || | if ((in_tensors_[0]->shape().size() != 4 && in_tensors_[0]->shape().size() != 2) || | ||||
| out_tensors_[0]->shape().size() != 2) { | |||||
| MS_LOG(ERROR) << "fullconnection only support input shape size = 2 or 4"; | |||||
| (out_tensors_[0]->shape().size() != 4 && out_tensors_[0]->shape().size() != 2)) { | |||||
| MS_LOG(ERROR) << "fullconnection only support input output shape size = 2 or 4"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (in_tensors_[0]->shape().size() == 4) { | |||||
| inShape = {in_tensors_[0]->shape()[0], in_tensors_[0]->shape()[1], in_tensors_[0]->shape()[2], | |||||
| in_tensors_[0]->shape()[3]}; | |||||
| } else { | |||||
| inShape = {in_tensors_[0]->shape()[0], 1, 1, in_tensors_[0]->shape()[1]}; | |||||
| } | |||||
| outShape = out_tensors_[0]->shape(); | |||||
| // call default move constructor(elemwised moved) | |||||
| inShape = Image2DInfo(in_tensors_[0]); | |||||
| outShape = Image2DInfo(out_tensors_[0]); | |||||
| switch (param->act_type_) { | switch (param->act_type_) { | ||||
| case ActType_No: | case ActType_No: | ||||
| break; | break; | ||||
| case ActType_Relu: | case ActType_Relu: | ||||
| kernel_name += "_ReLU"; | |||||
| activation_min_ = 0.f; | |||||
| break; | |||||
| case ActType_Relu6: | |||||
| activation_min_ = 0.f; | |||||
| activation_max_ = 6.f; | |||||
| break; | break; | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Unsupported activation type " << param->act_type_; | MS_LOG(ERROR) << "Unsupported activation type " << param->act_type_; | ||||
| @@ -81,14 +81,13 @@ int FullConnectionOpenCLKernel::Init() { | |||||
| } | } | ||||
| void FullConnectionOpenCLKernel::PadWeight() { | void FullConnectionOpenCLKernel::PadWeight() { | ||||
| // ABMCI @ ABCICO = ABMCO | |||||
| auto allocator = ocl_runtime_->GetAllocator(); | auto allocator = ocl_runtime_->GetAllocator(); | ||||
| int ci = inShape[3]; | |||||
| int ci = inShape.C; | |||||
| int ci4 = UP_DIV(ci, C4NUM); | int ci4 = UP_DIV(ci, C4NUM); | ||||
| int co = outShape[1]; | |||||
| int co = outShape.C; | |||||
| int co4 = UP_DIV(co, C4NUM); | int co4 = UP_DIV(co, C4NUM); | ||||
| int h = inShape[1]; | |||||
| int w = inShape[2]; | |||||
| int h = inShape.H; | |||||
| int w = inShape.W; | |||||
| size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float); | size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float); | ||||
| padWeight_ = allocator->Malloc(h * w * ci4 * co4 * C4NUM * C4NUM * dtype_size); | padWeight_ = allocator->Malloc(h * w * ci4 * co4 * C4NUM * C4NUM * dtype_size); | ||||
| @@ -172,18 +171,20 @@ void FullConnectionOpenCLKernel::PadWeight() { | |||||
// Enqueues the FullConnection kernel: binds input/weight/bias/output plus the
// shape and activation-range arguments, then launches one work-group of 4
// work-items (local y) per (CO4 slice, batch element).
int FullConnectionOpenCLKernel::Run() {
  MS_LOG(DEBUG) << this->name() << " Running!";
  // Local size {1, 4, 1} matches the kernel's 4-lane local-memory reduction
  // and stays well under the device's maximum work-group size.
  std::vector<size_t> local = {1, 4, 1};
  std::vector<size_t> global = {UP_DIV(outShape.C, C4NUM), 4, outShape.N};
  int arg_count = 0;
  // Image2DInfo dimensions are size_t; narrow explicitly for the cl_int args.
  cl_int4 in_shape = {static_cast<int>(inShape.N), static_cast<int>(inShape.H), static_cast<int>(inShape.W),
                      static_cast<int>(inShape.C)};
  cl_int2 out_shape = {static_cast<int>(outShape.N), static_cast<int>(outShape.C)};
  // Argument order must match the FullConnection_NHWC4 kernel signature.
  ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_tensors_[0]->data_c());
  ocl_runtime_->SetKernelArg(kernel_, arg_count++, padWeight_, lite::opencl::MemType::BUF);
  ocl_runtime_->SetKernelArg(kernel_, arg_count++, bias_);
  ocl_runtime_->SetKernelArg(kernel_, arg_count++, out_tensors_[0]->data_c());
  ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_shape);
  ocl_runtime_->SetKernelArg(kernel_, arg_count++, out_shape);
  // Activation is fused as a clamp range set in Init(): [-FLT_MAX, FLT_MAX]
  // for none, [0, FLT_MAX] for ReLU, [0, 6] for ReLU6.
  ocl_runtime_->SetKernelArg(kernel_, arg_count++, activation_min_);
  ocl_runtime_->SetKernelArg(kernel_, arg_count++, activation_max_);
  ocl_runtime_->RunKernel(kernel_, global, local, nullptr);
  return RET_OK;
}
| @@ -43,8 +43,10 @@ class FullConnectionOpenCLKernel : public OpenCLKernel { | |||||
| bool enable_fp16_{false}; | bool enable_fp16_{false}; | ||||
| bool transposeA{false}; | bool transposeA{false}; | ||||
| bool transposeB{true}; | bool transposeB{true}; | ||||
| std::vector<int> inShape; | |||||
| std::vector<int> outShape; | |||||
| float activation_min_{-FLT_MAX}; | |||||
| float activation_max_{FLT_MAX}; | |||||
| Image2DInfo inShape = Image2DInfo(nullptr); | |||||
| Image2DInfo outShape = Image2DInfo(nullptr); | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||