From 2da8370f185aade83dcf4ac4e6e9a59e47ea4f2b Mon Sep 17 00:00:00 2001 From: chenzupeng Date: Mon, 26 Oct 2020 15:58:32 +0800 Subject: [PATCH] optimize arithmetic --- .../runtime/kernel/opencl/cl/arithmetic.cl | 72 ++++++++----------- .../kernel/opencl/kernel/arithmetic.cc | 4 ++ 2 files changed, 32 insertions(+), 44 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl index 8f4cbccd41..c0f53be904 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl @@ -269,25 +269,21 @@ __kernel void BroadcastNHWC4Add_IMG(__read_only image2d_t input_a, __read_only i const int4 output_shape, float act_min, float act_max) { int X = get_global_id(0); // C4 int Y = get_global_id(1); // W - int Z = get_global_id(2); // N * H - int N = Z / output_shape.y; - int H = Z % output_shape.y; - if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.x * output_shape.y) { + int Z = get_global_id(2); // H + if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) { return; } int a_c = X < a_shape.w ? X : a_shape.w - 1; int a_w = Y < a_shape.z ? Y : a_shape.z - 1; - int a_h = H < a_shape.y ? H : a_shape.y - 1; - int a_n = N < a_shape.x ? N : a_shape.x - 1; - FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h)); + int a_h = Z < a_shape.y ? Z : a_shape.y - 1; + FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h)); int b_c = X < b_shape.w ? X : b_shape.w - 1; int b_w = Y < b_shape.z ? Y : b_shape.z - 1; - int b_h = H < b_shape.y ? H : b_shape.y - 1; - int b_n = N < b_shape.x ? N : b_shape.x - 1; - FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h)); + int b_h = Z < b_shape.y ? Z : b_shape.y - 1; + FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h)); FLT4 result = a + b; result = clamp(result, (FLT)(act_min), (FLT)(act_max)); - WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, N * output_shape.y + H), result); + WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result); } __kernel void BroadcastNHWC4Sub_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b, @@ -295,25 +291,21 @@ __kernel void BroadcastNHWC4Sub_IMG(__read_only image2d_t input_a, __read_only i const int4 output_shape, float act_min, float act_max) { int X = get_global_id(0); // C4 int Y = get_global_id(1); // W - int Z = get_global_id(2); // N * H - int N = Z / output_shape.y; - int H = Z % output_shape.y; - if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.x * output_shape.y) { + int Z = get_global_id(2); // H + if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) { return; } int a_c = X < a_shape.w ? X : a_shape.w - 1; int a_w = Y < a_shape.z ? Y : a_shape.z - 1; - int a_h = H < a_shape.y ? H : a_shape.y - 1; - int a_n = N < a_shape.x ? N : a_shape.x - 1; - FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h)); + int a_h = Z < a_shape.y ? Z : a_shape.y - 1; + FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h)); int b_c = X < b_shape.w ? X : b_shape.w - 1; int b_w = Y < b_shape.z ? Y : b_shape.z - 1; - int b_h = H < b_shape.y ? H : b_shape.y - 1; - int b_n = N < b_shape.x ? N : b_shape.x - 1; - FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h)); + int b_h = Z < b_shape.y ? Z : b_shape.y - 1; + FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h)); FLT4 result = a - b; result = clamp(result, (FLT)(act_min), (FLT)(act_max)); - WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, N * output_shape.y + H), result); + WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result); } __kernel void BroadcastNHWC4Mul_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b, @@ -321,25 +313,21 @@ __kernel void BroadcastNHWC4Mul_IMG(__read_only image2d_t input_a, __read_only i const int4 output_shape, float act_min, float act_max) { int X = get_global_id(0); // C4 int Y = get_global_id(1); // W - int Z = get_global_id(2); // N * H - int N = Z / output_shape.y; - int H = Z % output_shape.y; - if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.x * output_shape.y) { + int Z = get_global_id(2); // H + if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) { return; } int a_c = X < a_shape.w ? X : a_shape.w - 1; int a_w = Y < a_shape.z ? Y : a_shape.z - 1; - int a_h = H < a_shape.y ? H : a_shape.y - 1; - int a_n = N < a_shape.x ? N : a_shape.x - 1; - FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h)); + int a_h = Z < a_shape.y ? Z : a_shape.y - 1; + FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h)); int b_c = X < b_shape.w ? X : b_shape.w - 1; int b_w = Y < b_shape.z ? Y : b_shape.z - 1; - int b_h = H < b_shape.y ? H : b_shape.y - 1; - int b_n = N < b_shape.x ? N : b_shape.x - 1; - FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h)); + int b_h = Z < b_shape.y ? Z : b_shape.y - 1; + FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h)); FLT4 result = a * b; result = clamp(result, (FLT)(act_min), (FLT)(act_max)); - WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, N * output_shape.y + H), result); + WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result); } __kernel void BroadcastNHWC4Div_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b, @@ -347,25 +335,21 @@ __kernel void BroadcastNHWC4Div_IMG(__read_only image2d_t input_a, __read_only i const int4 output_shape, float act_min, float act_max) { int X = get_global_id(0); // C4 int Y = get_global_id(1); // W - int Z = get_global_id(2); // N * H - int N = Z / output_shape.y; - int H = Z % output_shape.y; - if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.x * output_shape.y) { + int Z = get_global_id(2); // H + if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) { return; } int a_c = X < a_shape.w ? X : a_shape.w - 1; int a_w = Y < a_shape.z ? Y : a_shape.z - 1; - int a_h = H < a_shape.y ? H : a_shape.y - 1; - int a_n = N < a_shape.x ? N : a_shape.x - 1; - FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h)); + int a_h = Z < a_shape.y ? Z : a_shape.y - 1; + FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h)); int b_c = X < b_shape.w ? X : b_shape.w - 1; int b_w = Y < b_shape.z ? Y : b_shape.z - 1; - int b_h = H < b_shape.y ? H : b_shape.y - 1; - int b_n = N < b_shape.x ? N : b_shape.x - 1; - FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h)); + int b_h = Z < b_shape.y ? Z : b_shape.y - 1; + FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h)); FLT4 result = a / b; result = clamp(result, (FLT)(act_min), (FLT)(act_max)); - WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, N * output_shape.y + H), result); + WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result); } __kernel void BroadcastAnd_IMG(__read_only image2d_t input_a, float b, __write_only image2d_t output, diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc index 9850e40e51..a1c82b8c68 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc @@ -136,6 +136,10 @@ int ArithmeticOpenCLKernel::Init() { if (arithmetic_parameter->broadcasting_) { element_flag_ = false; kernel_name = "BroadcastNHWC4"; + if (out_tensors_[0]->shape()[0] > 1) { + MS_LOG(ERROR) << "Broadcasting don't support N > 1"; + return RET_ERROR; + } } else { kernel_name = "Element"; }