diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fullconnection.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fullconnection.cl
index d09f93cfde..1cd7ff08be 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/cl/fullconnection.cl
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fullconnection.cl
@@ -4,23 +4,17 @@
 __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

 __kernel void FullConnection(__read_only image2d_t input, __write_only image2d_t output, __global FLT16 *weight,
-                             __read_only image2d_t bias, int4 in_shape, int2 out_shape, int act_type) {
+                             __read_only image2d_t bias, int N, int CI4, int CO4, int2 in_img_shape, int act_type) {
   int gidx = get_global_id(0);  // CO4
   int gidz = get_global_id(2);  // N
   int lidx = get_local_id(0);
   int lidy = get_local_id(1);
-  int ci4 = UP_DIV(in_shape.w, C4NUM);
-  int hwci4 = ci4 * in_shape.y * in_shape.z;
-  int wci4 = ci4 * in_shape.z;
-  int co4 = UP_DIV(out_shape.y, C4NUM);
-  int n = out_shape.x;
-  bool inside = gidx < co4 && gidz < n;
+  bool inside = gidx < CO4 && gidz < N;
   FLT4 result = (FLT4)(0.0f);
-  for (uint i = lidy; i < hwci4 && inside; i += 4) {
-    int index_h = i / wci4;
-    int index_wci4 = i % wci4;
-    FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index_wci4, gidz * in_shape.y + index_h));
-    FLT16 w = weight[i * co4 + gidx];
+  for (uint i = lidy; i < CI4 && inside; i += 4) {
+    int index = gidz * CI4 + i;
+    FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index % in_img_shape.y, index / in_img_shape.y));
+    FLT16 w = weight[i * CO4 + gidx];
     result.x += dot(v, w.s0123);
     result.y += dot(v, w.s4567);
     result.z += dot(v, w.s89ab);
@@ -46,3 +40,45 @@ __kernel void FullConnection(__read_only image2d_t input, __write_only image2d_t
     WRITE_IMAGE(output, (int2)(gidx, gidz), result);
   }
 }
+
+__kernel void FullConnectionWeightVar(__read_only image2d_t input, __write_only image2d_t output,
+                                      __read_only image2d_t weight, __read_only image2d_t bias, int N, int CI4, int CO4,
+                                      int2 in_img_shape, int act_type) {
+  int gidx = get_global_id(0);  // CO4
+  int gidz = get_global_id(2);  // N
+  int lidx = get_local_id(0);
+  int lidy = get_local_id(1);
+  bool inside = gidx < CO4 && gidz < N;
+  FLT4 result = (FLT4)(0.0f);
+  for (uint i = lidy; i < CI4 && inside; i += 4) {
+    int index = gidz * CI4 + i;
+    FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index % in_img_shape.y, index / in_img_shape.y));
+    FLT4 weight0 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4));
+    result.x += dot(v, weight0);
+    FLT4 weight1 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4 + 1));
+    result.y += dot(v, weight1);
+    FLT4 weight2 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4 + 2));
+    result.z += dot(v, weight2);
+    FLT4 weight3 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4 + 3));
+    result.w += dot(v, weight3);
+  }
+  __local FLT4 temp[32][4];
+  temp[lidx][lidy] = result;
+  barrier(CLK_LOCAL_MEM_FENCE);
+  if (lidy == 0 && inside) {
+    result += temp[lidx][1];
+    result += temp[lidx][2];
+    result += temp[lidx][3];
+    result += READ_IMAGE(bias, smp_zero, (int2)(gidx, 0));
+    if (act_type == ActivationType_RELU) {
+      result = max(result, (FLT4)(0.0f));
+    } else if (act_type == ActivationType_RELU6) {
+      result = clamp(result, (FLT4)(0.0f), (FLT4)(6.0f));
+    } else if (act_type == ActivationType_TANH) {
+      FLT4 exp0 = exp(result);
+      FLT4 exp1 = exp(-result);
+      result = (exp0 - exp1) / (exp0 + exp1);
+    }
+    WRITE_IMAGE(output, (int2)(gidx, gidz), result);
+  }
+}
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc
index a288e660fb..1d3f2cbadd 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc
@@ -42,15 +42,39 @@ int FullConnectionOpenCLKernel::CheckSpecs() {
     MS_LOG(ERROR) << "fullconnection only support a_transpose_=false yet.";
     return RET_ERROR;
   }
-  if ((in_tensors_[0]->shape().size() != 4 && in_tensors_[0]->shape().size() != 2) ||
-      (out_tensors_[0]->shape().size() != 4 && out_tensors_[0]->shape().size() != 2)) {
-    MS_LOG(ERROR) << "fullconnection only support input output shape size = 2 or 4";
+  auto out_gpu_info = GpuTensorInfo(out_tensors_[0]);
+  if (out_gpu_info.H != 1 || out_gpu_info.W != 1) {
+    MS_LOG(ERROR) << "fullconnection only support 2d output shape or 4d output but H=W=1";
     return RET_ERROR;
   }
   if (param->act_type_ != ActType_No && param->act_type_ != ActType_Relu && param->act_type_ != ActType_Relu6) {
     MS_LOG(ERROR) << "Unsupported activation type " << param->act_type_;
     return RET_ERROR;
   }
+  N_ = out_gpu_info.N;
+  CO_ = out_gpu_info.C;
+  auto intensor_shape = GpuTensorInfo(in_tensors_[0]);
+  int input_nhw = intensor_shape.N * intensor_shape.H * intensor_shape.W;
+  if (input_nhw < N_) {
+    MS_LOG(ERROR) << "Unsupported fullconnection shape";
+  }
+  if (!in_tensors_.at(kWeightIndex)->IsConst()) {
+    weight_var_ = true;
+    if (!param->b_transpose_) {
+      MS_LOG(ERROR) << "If fullconnection input weight is not constant, b_transpose_ should be true.";
+      return RET_ERROR;
+    }
+    if (in_tensors_.at(kWeightIndex)->shape().size() != 2) {
+      MS_LOG(ERROR) << "If fullconnection input weight is not constant, it should be 2d.";
+      return RET_ERROR;
+    }
+    if (intensor_shape.C != in_tensors_.at(kWeightIndex)->shape()[1]) {
+      MS_LOG(ERROR)
+        << "If fullconnection input weight is not constant, input channel should equal to weight in_channel.";
+      return RET_ERROR;
+    }
+  }
+  CI_remainder_ = input_nhw / N_;
   return RET_OK;
 }

@@ -61,8 +85,9 @@ int FullConnectionOpenCLKernel::Prepare() {
   enable_fp16_ = ocl_runtime_->GetFp16Enable();
   std::string kernel_name = "FullConnection";
-  inShape = GpuTensorInfo(in_tensors_[0]);
-  outShape = GpuTensorInfo(out_tensors_[0]);
+  if (weight_var_) {
+    kernel_name = "FullConnectionWeightVar";
+  }
 #ifdef PROGRAM_WITH_IL
   kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
 #else
@@ -82,23 +107,26 @@
 }

 int FullConnectionOpenCLKernel::InitWeights() {
-  if (!in_tensors_.at(kWeightIndex)->IsConst()) {
-    MS_LOG(ERROR) << "FullConnection don't support non-constant filter yet.";
-    return RET_ERROR;
+  if (!weight_var_) {
+    auto ret = InitFilter();
+    if (ret != RET_OK) {
+      return ret;
+    }
   }
+  return InitBias();
+}  // namespace mindspore::kernel
+
+int FullConnectionOpenCLKernel::InitFilter() {
   auto allocator = ocl_runtime_->GetAllocator();
-  int ci = inShape.C;
-  int ci4 = UP_DIV(ci, C4NUM);
-  int co = outShape.C;
-  int co4 = UP_DIV(co, C4NUM);
-  int h = inShape.H;
-  int w = inShape.W;
+  auto intensor_shape = GpuTensorInfo(in_tensors_[0]);
+  int co4 = UP_DIV(CO_, C4NUM);
+  int nhw_remainder = intensor_shape.N * intensor_shape.H * intensor_shape.W / N_;
   size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
-  padWeight_ = allocator->Malloc(h * w * ci4 * co4 * C4NUM * C4NUM * dtype_size);
+  padWeight_ = allocator->Malloc(nhw_remainder * intensor_shape.Slice * co4 * C4NUM * C4NUM * dtype_size);
   padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true);
   auto padWeightFp32 = reinterpret_cast<float *>(padWeight_);
   auto padWeightFp16 = reinterpret_cast<float16_t *>(padWeight_);
-  memset(padWeight_, 0x00, h * w * ci4 * co4 * C4NUM * C4NUM * dtype_size);
+  memset(padWeight_, 0x00, nhw_remainder * intensor_shape.Slice * co4 * C4NUM * C4NUM * dtype_size);
   auto originWeightFp32 = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->data_c());
   auto originWeightFp16 = reinterpret_cast<float16_t *>(in_tensors_.at(kWeightIndex)->data_c());
   bool isModelFp16 = in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16;
@@ -107,36 +135,33 @@ int FullConnectionOpenCLKernel::InitWeights() {
   // HWCICO -> (HWCI4)(CO4)(4 from CO)(4 from CI)
   // if tranposeB, COHWCI -> (HWCI4)(CO4)(4 from CO)(4 from CI)
   int index = 0;
-  for (int hh = 0; hh < h; hh++) {
-    for (int ww = 0; ww < w; ww++) {
-      int baseHW = hh * w + ww;
-      for (int i = 0; i < ci4; ++i) {
-        for (int j = 0; j < co4; ++j) {
-          for (int k = 0; k < C4NUM; ++k) {
-            for (int l = 0; l < C4NUM; ++l) {
-              int src_ci = i * C4NUM + l;
-              int src_co = j * C4NUM + k;
-              if (src_ci < ci && src_co < co) {
-                int originId = baseHW * ci * co + src_ci * co + src_co;
-                if (transposeB) {
-                  originId = src_co * ci * h * w + baseHW * ci + src_ci;
-                }
-                if (enable_fp16_) {
-                  if (!isModelFp16) {
-                    padWeightFp16[index++] = originWeightFp32[originId];
-                  } else {
-                    padWeightFp16[index++] = originWeightFp16[originId];
-                  }
+  for (int nhw = 0; nhw < nhw_remainder; nhw++) {
+    for (int i = 0; i < intensor_shape.Slice; ++i) {
+      for (int j = 0; j < co4; ++j) {
+        for (int k = 0; k < C4NUM; ++k) {
+          for (int l = 0; l < C4NUM; ++l) {
+            int src_ci = i * C4NUM + l;
+            int src_co = j * C4NUM + k;
+            if (src_ci < intensor_shape.C && src_co < CO_) {
+              int originId = (nhw * intensor_shape.C + src_ci) * CO_ + src_co;
+              if (transposeB) {
+                originId = src_co * intensor_shape.C * nhw_remainder + nhw * intensor_shape.C + src_ci;
+              }
+              if (enable_fp16_) {
+                if (!isModelFp16) {
+                  padWeightFp16[index++] = originWeightFp32[originId];
                 } else {
-                  if (!isModelFp16) {
-                    padWeightFp32[index++] = originWeightFp32[originId];
-                  } else {
-                    padWeightFp32[index++] = originWeightFp16[originId];
-                  }
+                  padWeightFp16[index++] = originWeightFp16[originId];
                 }
               } else {
-                index++;
+                if (!isModelFp16) {
+                  padWeightFp32[index++] = originWeightFp32[originId];
+                } else {
+                  padWeightFp32[index++] = originWeightFp16[originId];
+                }
               }
+            } else {
+              index++;
             }
           }
         }
@@ -144,8 +169,14 @@
     }
   }
   allocator->UnmapBuffer(padWeight_);
+  return RET_OK;
+}

+int FullConnectionOpenCLKernel::InitBias() {
   // pad FC Bias
+  auto allocator = ocl_runtime_->GetAllocator();
+  int co4 = UP_DIV(CO_, C4NUM);
+  size_t dtype_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
   size_t im_dst_x, im_dst_y;
   im_dst_x = co4;
   im_dst_y = 1;
@@ -163,15 +194,15 @@
       return RET_ERROR;
     }
     if (in_tensors_[2]->data_type() == kNumberTypeFloat32 && enable_fp16_) {
-      for (int i = 0; i < co; i++) {
+      for (int i = 0; i < CO_; i++) {
        reinterpret_cast<float16_t *>(bias_)[i] = reinterpret_cast<float *>(in_tensors_[2]->data_c())[i];
      }
    } else if (in_tensors_[2]->data_type() == kNumberTypeFloat16 && !enable_fp16_) {
-      for (int i = 0; i < co; i++) {
+      for (int i = 0; i < CO_; i++) {
        reinterpret_cast<float *>(bias_)[i] = reinterpret_cast<float16_t *>(in_tensors_[2]->data_c())[i];
      }
    } else {
-      memcpy(bias_, in_tensors_[2]->data_c(), co * dtype_size);
+      memcpy(bias_, in_tensors_[2]->data_c(), CO_ * dtype_size);
     }
   }
   allocator->UnmapBuffer(bias_);
@@ -180,20 +211,27 @@

 void FullConnectionOpenCLKernel::SetGlobalLocal() {
   local_size_ = {32, 4, 1};
-  global_size_ = {UP_DIV(outShape.C, C4NUM), 4, outShape.N};
+  size_t CO = CO_;
+  size_t N = N_;
+  global_size_ = {UP_DIV(CO, C4NUM), 4, N};
   AlignGlobalLocal(global_size_, local_size_);
 }

 void FullConnectionOpenCLKernel::SetConstArgs() {
-  int arg_count = 2;
-  cl_int4 in_shape = {static_cast<int>(inShape.N), static_cast<int>(inShape.H), static_cast<int>(inShape.W),
-                      static_cast<int>(inShape.C)};
-  cl_int2 out_shape = {static_cast<int>(outShape.N), static_cast<int>(outShape.C)};
-  ocl_runtime_->SetKernelArg(kernel_, arg_count++, padWeight_, lite::opencl::MemType::BUF);
-  auto *param = reinterpret_cast<MatMulParameter *>(op_parameter_);
+  if (!weight_var_) {
+    ocl_runtime_->SetKernelArg(kernel_, 2, padWeight_, lite::opencl::MemType::BUF);
+  }
+  int arg_count = 3;
   ocl_runtime_->SetKernelArg(kernel_, arg_count++, bias_);
-  ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_shape);
-  ocl_runtime_->SetKernelArg(kernel_, arg_count++, out_shape);
+  ocl_runtime_->SetKernelArg(kernel_, arg_count++, N_);
+  auto intensor_shape = GpuTensorInfo(in_tensors_[0]);
+  int CI4 = CI_remainder_ * intensor_shape.Slice;
+  ocl_runtime_->SetKernelArg(kernel_, arg_count++, CI4);
+  ocl_runtime_->SetKernelArg(kernel_, arg_count++, UP_DIV(CO_, C4NUM));
+  auto in_shape_info = GpuTensorInfo(in_tensors_[0]);
+  cl_int2 in_img_shape = {static_cast<int>(in_shape_info.height), static_cast<int>(in_shape_info.width)};
+  ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_img_shape);
+  auto *param = reinterpret_cast<MatMulParameter *>(op_parameter_);
   ocl_runtime_->SetKernelArg(kernel_, arg_count, static_cast<int>(param->act_type_));
 }

@@ -202,6 +240,9 @@ int FullConnectionOpenCLKernel::Run() {
   int arg_count = 0;
   ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_tensors_[0]->data_c());
   ocl_runtime_->SetKernelArg(kernel_, arg_count++, out_tensors_[0]->data_c());
+  if (weight_var_) {
+    ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_tensors_[1]->data_c());
+  }
   ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.h
index 275ce3f12a..9463f15068 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.h
@@ -40,13 +40,17 @@ class FullConnectionOpenCLKernel : public OpenCLKernel {
   int Tune() override { return lite::RET_OK; }

  private:
+  int InitFilter();
+  int InitBias();
   void *padWeight_{nullptr};
   void *bias_{nullptr};
   bool enable_fp16_{false};
   bool transposeA{false};
   bool transposeB{true};
-  GpuTensorInfo inShape = GpuTensorInfo(nullptr);
-  GpuTensorInfo outShape = GpuTensorInfo(nullptr);
+  bool weight_var_{false};
+  int N_{1};
+  int CI_remainder_{1};
+  int CO_{1};
 };

 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.cc
index 0ecd1f3498..af58526190 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.cc
@@ -68,7 +68,7 @@ cl_float4 ReduceOpenCLKernel::GenC4Mask() {

 int ReduceOpenCLKernel::CheckSpecs() {
   if (in_tensors_[0]->shape()[0] > 1) {
-    MS_LOG(ERROR) << "reduce op only support n=2";
+    MS_LOG(ERROR) << "reduce op only support n = 1";
     return RET_PARAM_INVALID;
   }
   auto reduce_param = reinterpret_cast<ReduceParameter *>(op_parameter_);
@@ -76,6 +76,10 @@ int ReduceOpenCLKernel::CheckSpecs() {
     MS_LOG(ERROR) << "not supported reduce type:" << reduce_param->mode_;
     return RET_PARAM_INVALID;
   }
+  if (reduce_param->num_axes_ == 1 && reduce_param->axes_[0] == 3 && in_tensors_[0]->shape()[2] == 1) {
+    reduce_param->num_axes_ = 2;
+    reduce_param->axes_[1] = 2;
+  }
   if (reduce_param->num_axes_ != 2) {
     MS_LOG(ERROR) << "reduce op only support axes=2";
     return RET_PARAM_INVALID;
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/fullconnection_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/fullconnection_tests.cc
index f48a853e14..1a4aebe578 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/fullconnection_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/fullconnection_tests.cc
@@ -24,7 +24,7 @@ namespace {
 // PrimitiveType_FullConnection: src/ops/populate/full_connection_populate.cc
 OpParameter *CreateParameter(std::vector<int> *input_shape, std::vector<int> *weight_shape,
                              std::vector<int> *bias_shape, std::vector<int> *output_shape, int ndim, int ci, int co,
-                             int n = 1, int h = 1, int w = 1) {
+                             int n = 1, int h = 1, int w = 1, int in_n = 1) {
   auto *param = test::CreateParameter<MatMulParameter>(schema::PrimitiveType_FullConnection);
   param->a_transpose_ = false;
   param->b_transpose_ = true;
@@ -41,6 +41,11 @@ OpParameter *CreateParameter(std::vector<int> *input_shape, std::vector<int> *we
     *output_shape = {n, co};
     *weight_shape = {co, h * w * ci};
     *bias_shape = {co};
+  } else if (ndim == 3) {
+    *input_shape = {in_n, w, ci};
+    *output_shape = {n, co};
+    *weight_shape = {co, in_n * w * ci / n};
+    *bias_shape = {co};
   }
   return reinterpret_cast<OpParameter *>(param);
 }
@@ -87,4 +92,47 @@ TEST_F(TestOpenCL_FullConnection, 4D) {
   }
 }

+TEST_F(TestOpenCL_FullConnection, 3D) {
+  int ndim = 3;
+  int ci = 3;
+  int co = 4;
+  int n = 2;
+  int h = 1;
+  int w = 4;
+  int in_n = 1;
+  float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+  float weight_data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+  float bias_data[] = {1, 1, 1, 1};
+  float output_data[] = {16, 16, 16, 16, 52, 52, 52, 52};
+
+  for (auto fp16_enable : {false, true}) {
+    std::vector<int> input_shape, weight_shape, bias_shape, output_shape;
+    auto *param = CreateParameter(&input_shape, &weight_shape, &bias_shape, &output_shape, ndim, ci, co, n, h, w, in_n);
+    TestMain({{input_shape, input_data, VAR},
+              {weight_shape, weight_data, CONST_TENSOR},
+              {bias_shape, bias_data, CONST_TENSOR}},
+             {output_shape, output_data}, param, fp16_enable);
+  }
+}
+
+TEST_F(TestOpenCL_FullConnection, 3DWeightVar) {
+  int ndim = 3;
+  int ci = 6;
+  int co = 4;
+  int n = 2;
+  int h = 1;
+  int w = 2;
+  int in_n = 1;
+  float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+  float weight_data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+  float bias_data[] = {1, 1, 1, 1};
+  float output_data[] = {16, 16, 16, 16, 52, 52, 52, 52};
+
+  for (auto fp16_enable : {false, true}) {
+    std::vector<int> input_shape, weight_shape, bias_shape, output_shape;
+    auto *param = CreateParameter(&input_shape, &weight_shape, &bias_shape, &output_shape, ndim, ci, co, n, h, w, in_n);
+    TestMain({{input_shape, input_data, VAR}, {weight_shape, weight_data, VAR}, {bias_shape, bias_data, CONST_TENSOR}},
+             {output_shape, output_data}, param, fp16_enable);
+  }
+}
 }  // namespace mindspore::lite::opencl::test
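Reviewer note: the weight repacking in InitFilter() (comment "HWCICO -> (HWCI4)(CO4)(4 from CO)(4 from CI)") is easier to follow outside the diff. Below is a minimal standalone C++ sketch of the b_transpose_ path only; UpDiv, kC4Num, and PackTransposedWeight are illustrative names invented for this note, not part of the patch or the MindSpore API.

// Sketch: repack a [co][nhw * ci] weight matrix into (nhw * ci4) x co4 tiles of
// 4x4 floats, i.e. dst index = (((n * ci4 + i) * co4 + j) * 4 + k) * 4 + l,
// which is the layout the kernel's weight[i * CO4 + gidx] FLT16 fetch assumes.
#include <cstdio>
#include <vector>

constexpr int kC4Num = 4;
static int UpDiv(int x, int y) { return (x + y - 1) / y; }

static std::vector<float> PackTransposedWeight(const std::vector<float> &src, int nhw, int ci, int co) {
  int ci4 = UpDiv(ci, kC4Num);
  int co4 = UpDiv(co, kC4Num);
  std::vector<float> dst(static_cast<size_t>(nhw) * ci4 * co4 * kC4Num * kC4Num, 0.0f);
  size_t index = 0;
  for (int n = 0; n < nhw; ++n) {
    for (int i = 0; i < ci4; ++i) {
      for (int j = 0; j < co4; ++j) {
        for (int k = 0; k < kC4Num; ++k) {    // 4 from CO
          for (int l = 0; l < kC4Num; ++l) {  // 4 from CI
            int src_ci = i * kC4Num + l;
            int src_co = j * kC4Num + k;
            if (src_ci < ci && src_co < co) {
              // b_transpose_ layout: src is [co][nhw * ci]
              dst[index] = src[src_co * ci * nhw + n * ci + src_ci];
            }
            ++index;  // out-of-range slots stay zero-padded
          }
        }
      }
    }
  }
  return dst;
}

int main() {
  // Toy case matching the 3DWeightVar unit test: co = 4, nhw * ci = 1 * 6, all-ones weight.
  std::vector<float> w(4 * 6, 1.0f);
  std::vector<float> packed = PackTransposedWeight(w, /*nhw=*/1, /*ci=*/6, /*co=*/4);
  std::printf("packed %zu values\n", packed.size());  // 1 * 2 * 1 * 16 = 32
  return 0;
}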