diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/pooling2d.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/pooling2d.cl index d345617c0c..130e296409 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/pooling2d.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/pooling2d.cl @@ -7,12 +7,13 @@ __kernel void AvgPooling2d_NHWC4_IMG(__read_only image2d_t input, __write_only i const int4 output_shape, const int2 stride, const int2 kernel_size, const int2 padding) { // axis to dst tensor coordinate - int X = get_global_id(2); - int Y = get_global_id(1); - int Z = get_global_id(0); - + int X = get_global_id(2); // N*H + int Y = get_global_id(1); // W + int Z = get_global_id(0); // C4 + int N = X / output_shape.y; + X = X % output_shape.y; // boundary check - if (X >= output_shape.x || Y >= output_shape.y || Z >= output_shape.w) { + if (N >= output_shape.x || X >= output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { return; } @@ -23,28 +24,30 @@ __kernel void AvgPooling2d_NHWC4_IMG(__read_only image2d_t input, __write_only i for (int ky = 0; ky < kernel_size.y; ++ky) { int y_c = ys + ky; - bool outside_y = y_c < 0 || y_c >= input_shape.y; + bool outside_y = y_c < 0 || y_c >= input_shape.z; for (int kx = 0; kx < kernel_size.x; ++kx) { int x_c = xs + kx; - bool outside = outside_y || x_c < 0 || x_c >= input_shape.x; - r += !outside ? READ_IMAGE(input, smp_zero, (int2)(y_c * input_shape.w + Z, x_c)) : (FLT4)(0.0f); + bool outside = outside_y || x_c < 0 || x_c >= input_shape.y; + r += + !outside ? READ_IMAGE(input, smp_zero, (int2)(y_c * input_shape.w + Z, N * input_shape.y + x_c)) : (FLT4)(0.0f); window_size += !outside ? 1.0f : 0.0f; } } FLT4 result = TO_FLT4(divide_no_check(r, window_size)); - WRITE_IMAGE(output, (int2)(Y * output_shape.w + Z, X), result); + WRITE_IMAGE(output, (int2)(Y * output_shape.w + Z, N * output_shape.y + X), result); } __kernel void AvgPooling2d_ReLU_NHWC4_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, const int4 output_shape, const int2 stride, const int2 kernel_size, const int2 padding) { // axis to dst tensor coordinate - int X = get_global_id(2); - int Y = get_global_id(1); - int Z = get_global_id(0); - + int X = get_global_id(2); // N*H + int Y = get_global_id(1); // W + int Z = get_global_id(0); // C4 + int N = X / output_shape.y; + X = X % output_shape.y; // boundary check - if (X >= output_shape.x || Y >= output_shape.y || Z >= output_shape.w) { + if (N >= output_shape.x || X >= output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { return; } @@ -55,28 +58,30 @@ __kernel void AvgPooling2d_ReLU_NHWC4_IMG(__read_only image2d_t input, __write_o for (int ky = 0; ky < kernel_size.y; ++ky) { int y_c = ys + ky; - bool outside_y = y_c < 0 || y_c >= input_shape.y; + bool outside_y = y_c < 0 || y_c >= input_shape.z; for (int kx = 0; kx < kernel_size.x; ++kx) { int x_c = xs + kx; - bool outside = outside_y || x_c < 0 || x_c >= input_shape.x; - r += !outside ? READ_IMAGE(input, smp_zero, (int2)(y_c * input_shape.w + Z, x_c)) : (FLT4)(0.0f); + bool outside = outside_y || x_c < 0 || x_c >= input_shape.y; + r += + !outside ? READ_IMAGE(input, smp_zero, (int2)(y_c * input_shape.w + Z, N * input_shape.y + x_c)) : (FLT4)(0.0f); window_size += !outside ? 1.0f : 0.0f; } } FLT4 result = TO_FLT4(divide_no_check(r, window_size)); - WRITE_IMAGE(output, (int2)(Y * output_shape.w + Z, X), max(result, (FLT4)(0.f))); + WRITE_IMAGE(output, (int2)(Y * output_shape.w + Z, N * output_shape.y + X), max(result, (FLT4)(0.f))); } __kernel void MaxPooling2d_NHWC4_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, const int4 output_shape, const int2 stride, const int2 kernel_size, const int2 padding) { // axis to dst tensor coordinate - int X = get_global_id(2); - int Y = get_global_id(1); - int Z = get_global_id(0); - + int X = get_global_id(2); // N*H + int Y = get_global_id(1); // W + int Z = get_global_id(0); // C4 + int N = X / output_shape.y; + X = X % output_shape.y; // boundary check - if (X >= output_shape.x || Y >= output_shape.y || Z >= output_shape.w) { + if (N >= output_shape.x || X >= output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { return; } @@ -85,27 +90,28 @@ __kernel void MaxPooling2d_NHWC4_IMG(__read_only image2d_t input, __write_only i int ys = Y * stride.y - padding.y; for (int ky = 0; ky < kernel_size.y; ++ky) { int y_c = ys + ky; - if (y_c < 0 || y_c >= input_shape.y) continue; + if (y_c < 0 || y_c >= input_shape.z) continue; for (int kx = 0; kx < kernel_size.x; ++kx) { int x_c = xs + kx; - if (x_c < 0 || x_c >= input_shape.x) continue; - FLT4 src = READ_IMAGE(input, smp_zero, (int2)(y_c * input_shape.w + Z, x_c)); + if (x_c < 0 || x_c >= input_shape.y) continue; + FLT4 src = READ_IMAGE(input, smp_zero, (int2)(y_c * input_shape.w + Z, N * input_shape.y + x_c)); maximum = max(src, maximum); } } - WRITE_IMAGE(output, (int2)(Y * output_shape.w + Z, X), maximum); + WRITE_IMAGE(output, (int2)(Y * output_shape.w + Z, N * output_shape.y + X), maximum); } __kernel void MaxPooling2d_ReLU_NHWC4_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, const int4 output_shape, const int2 stride, const int2 kernel_size, const int2 padding) { // axis to dst tensor coordinate - int X = get_global_id(2); - int Y = get_global_id(1); - int Z = get_global_id(0); - + int X = get_global_id(2); // N*H + int Y = get_global_id(1); // W + int Z = get_global_id(0); // C4 + int N = X / output_shape.y; + X = X % output_shape.y; // boundary check - if (X >= output_shape.x || Y >= output_shape.y || Z >= output_shape.w) { + if (N >= output_shape.x || X >= output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { return; } @@ -114,13 +120,13 @@ __kernel void MaxPooling2d_ReLU_NHWC4_IMG(__read_only image2d_t input, __write_o int ys = Y * stride.y - padding.y; for (int ky = 0; ky < kernel_size.y; ++ky) { int y_c = ys + ky; - if (y_c < 0 || y_c >= input_shape.y) continue; + if (y_c < 0 || y_c >= input_shape.z) continue; for (int kx = 0; kx < kernel_size.x; ++kx) { int x_c = xs + kx; - if (x_c < 0 || x_c >= input_shape.x) continue; - FLT4 src = READ_IMAGE(input, smp_zero, (int2)(y_c * input_shape.w + Z, x_c)); + if (x_c < 0 || x_c >= input_shape.y) continue; + FLT4 src = READ_IMAGE(input, smp_zero, (int2)(y_c * input_shape.w + Z, N * input_shape.y + x_c)); maximum = max(src, maximum); } } - WRITE_IMAGE(output, (int2)(Y * output_shape.w + Z, X), max(maximum, (FLT4)(0.f))); + WRITE_IMAGE(output, (int2)(Y * output_shape.w + Z, N * output_shape.y + X), max(maximum, (FLT4)(0.f))); } diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/resize.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/resize.cl index 2ff5fa0e90..b1ad1a08a5 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/resize.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/resize.cl @@ -4,16 +4,18 @@ __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; __kernel void resize_nearest_neighbor_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 in_size, int4 out_size, float2 scale_factor) { - int X = get_global_id(2); // H + int X = get_global_id(2); // H * N int Y = get_global_id(1); // W int Z = get_global_id(0); // C4 - if (X >= out_size.y || Y >= out_size.z || Z >= out_size.w) { + if (X >= out_size.x * out_size.y || Y >= out_size.z || Z >= out_size.w) { return; } + int N = X / out_size.y; + X = X % out_size.y; int src_x = (int)(X * scale_factor.x); int src_y = (int)(Y * scale_factor.y); - WRITE_IMAGE(dst_data, (int2)(Y * out_size.w + Z, X), - READ_IMAGE(src_data, smp_zero, (int2)(src_y * in_size.w + Z, src_x))); + WRITE_IMAGE(dst_data, (int2)(Y * out_size.w + Z, N * out_size.y + X), + READ_IMAGE(src_data, smp_zero, (int2)(src_y * in_size.w + Z, N * in_size.y + src_x))); } __kernel void resize_nearest_neighbor_NC4HW4(__read_only image2d_t src_data, __write_only image2d_t dst_data, @@ -32,25 +34,27 @@ __kernel void resize_nearest_neighbor_NC4HW4(__read_only image2d_t src_data, __w __kernel void resize_bilinear_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 in_size, int4 out_size, float2 scale_factor) { - int X = get_global_id(2); // H + int X = get_global_id(2); // H * N int Y = get_global_id(1); // W int Z = get_global_id(0); // C4 - if (X >= out_size.y || Y >= out_size.z || Z >= out_size.w) { + if (X >= out_size.x * out_size.y || Y >= out_size.z || Z >= out_size.w) { return; } + int N = X / out_size.y; + X = X % out_size.y; float scale_x = X * scale_factor.x; float scale_y = Y * scale_factor.y; int src_x = (int)(scale_x); int src_y = (int)(scale_y); int src_x_1 = min(src_x + 1, in_size.y - 1); int src_y_1 = min(src_y + 1, in_size.z - 1); - FLT4 src0 = READ_IMAGE(src_data, smp_zero, (int2)(src_y * in_size.w + Z, src_x)); - FLT4 src1 = READ_IMAGE(src_data, smp_zero, (int2)(src_y_1 * in_size.w + Z, src_x)); - FLT4 src2 = READ_IMAGE(src_data, smp_zero, (int2)(src_y * in_size.w + Z, src_x_1)); - FLT4 src3 = READ_IMAGE(src_data, smp_zero, (int2)(src_y_1 * in_size.w + Z, src_x_1)); + FLT4 src0 = READ_IMAGE(src_data, smp_zero, (int2)(src_y * in_size.w + Z, N * in_size.y + src_x)); + FLT4 src1 = READ_IMAGE(src_data, smp_zero, (int2)(src_y_1 * in_size.w + Z, N * in_size.y + src_x)); + FLT4 src2 = READ_IMAGE(src_data, smp_zero, (int2)(src_y * in_size.w + Z, N * in_size.y + src_x_1)); + FLT4 src3 = READ_IMAGE(src_data, smp_zero, (int2)(src_y_1 * in_size.w + Z, N * in_size.y + src_x_1)); FLT4 result = mix(mix(src0, src1, TO_FLT(scale_y - src_y)), mix(src2, src3, TO_FLT(scale_y - src_y)), TO_FLT(scale_x - src_x)); - WRITE_IMAGE(dst_data, (int2)(Y * out_size.w + Z, X), result); + WRITE_IMAGE(dst_data, (int2)(Y * out_size.w + Z, N * out_size.y + X), result); } __kernel void resize_bilinear_NC4HW4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 in_size, diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl index 8d2b572141..79ed7524b7 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl @@ -8,20 +8,21 @@ __kernel void SoftmaxAxis3_NHWC4(__read_only image2d_t input, __write_only image const int4 input_shape) { int X = get_global_id(1); // H int Y = get_global_id(0); // W + int n = get_global_id(2); // N int H = input_shape.y; int W = input_shape.z; int C4 = input_shape.w; - if (X >= H || Y >= W) return; + if (n >= input_shape.x || X >= H || Y >= W) return; // get max - float4 last = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, X))); + float4 last = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, n * H + X))); float input_max = last.x; if (mask.y > 0.5f) input_max = max(input_max, last.y); if (mask.z > 0.5f) input_max = max(input_max, last.z); if (mask.w > 0.5f) input_max = max(input_max, last.w); for (int d = 0; d < C4 - 1; ++d) { - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, X))); + float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, n * H + X))); input_max = max(input_max, t.x); input_max = max(input_max, t.y); input_max = max(input_max, t.z); @@ -31,41 +32,42 @@ __kernel void SoftmaxAxis3_NHWC4(__read_only image2d_t input, __write_only image float sum = 0.0f; for (int d = 0; d < C4 - 1; ++d) { - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, X))); + float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, n * H + X))); sum += dot(exp(t - input_max_f4), (float4)(1.f)); } - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, X))); + float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, n * H + X))); sum += dot(exp(min(t - input_max_f4, 0)), mask); for (int d = 0; d < C4 - 1; ++d) { - float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, X))); + float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, n * H + X))); result = exp(result - input_max_f4) / sum; - WRITE_IMAGE(output, (int2)(Y * C4 + d, X), TO_FLT4(result)); + WRITE_IMAGEOUT(output, (int2)(Y * C4 + d, n * H + X), OUT_FLT4(result)); } - float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, X))); + float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, n * H + X))); result = exp(min(result - input_max_f4, 0)) / sum; result = result * mask; - WRITE_IMAGEOUT(output, (int2)(Y * C4 + C4 - 1, X), OUT_FLT4(result)); + WRITE_IMAGEOUT(output, (int2)(Y * C4 + C4 - 1, n * H + X), OUT_FLT4(result)); } __kernel void SoftmaxAxis1_NHWC4(__read_only image2d_t input, __write_only image2d_t output, const float4 mask, const int4 input_shape) { int X = get_global_id(1); // W int Y = get_global_id(0); // C4 + int n = get_global_id(2); // N int H = input_shape.y; int W = input_shape.z; int C4 = input_shape.w; - if (X >= W || Y >= C4) return; + if (n >= input_shape.x || X >= W || Y >= C4) return; float4 sum = 0.0f; for (int d = 0; d < H; ++d) { - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(X * C4 + Y, d))); + float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(X * C4 + Y, n * H + d))); sum += exp(t); } for (int d = 0; d < H; ++d) { - float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(X * C4 + Y, d))); + float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(X * C4 + Y, n * H + d))); result = exp(result) / sum; - WRITE_IMAGEOUT(output, (int2)(X * C4 + Y, d), OUT_FLT4(result)); + WRITE_IMAGEOUT(output, (int2)(X * C4 + Y, n * H + d), OUT_FLT4(result)); } } @@ -73,35 +75,38 @@ __kernel void SoftmaxAxis2_NHWC4(__read_only image2d_t input, __write_only image const int4 input_shape) { int X = get_global_id(1); // H int Y = get_global_id(0); // C4 + int n = get_global_id(2); // n int H = input_shape.y; int W = input_shape.z; int C4 = input_shape.w; - if (X >= H || Y >= C4) return; + if (n >= input_shape.x || X >= H || Y >= C4) return; float4 sum = 0.0f; for (int d = 0; d < W; ++d) { - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(d * C4 + Y, X))); + float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(d * C4 + Y, n * H + X))); sum += exp(t); } for (int d = 0; d < W; ++d) { - float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(d * C4 + Y, X))); + float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(d * C4 + Y, n * H + X))); result = exp(result) / sum; - WRITE_IMAGEOUT(output, (int2)(d * C4 + Y, X), OUT_FLT4(result)); + WRITE_IMAGEOUT(output, (int2)(d * C4 + Y, n * H + X), OUT_FLT4(result)); } } __kernel void Softmax1x1_NHWC4(__read_only image2d_t input, __write_only image2d_t output, const float4 mask, const int4 input_shape) { int tid = get_local_id(0); + int n = get_global_id(1); + if (n >= input_shape.x) return; int C4 = input_shape.w; float sum = 0.0f; for (size_t i = tid; i < C4 - 1; i += 32) { - float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(i, 0))); + float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(i, n))); sum += dot((float4)(1.0f), exp(src)); } if ((C4 - 1) % 32 == tid) { - float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(C4 - 1, 0))); + float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(C4 - 1, n))); sum += dot(convert_float4(mask), exp(src)); } @@ -123,8 +128,8 @@ __kernel void Softmax1x1_NHWC4(__read_only image2d_t input, __write_only image2d barrier(CLK_LOCAL_MEM_FENCE); sum = tmpx1[0]; for (size_t i = tid; i < C4; i += 32) { - float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(i, 0))); + float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(i, n))); result = exp(result) * sum; - WRITE_IMAGEOUT(output, (int2)(i, 0), OUT_FLT4(result)); + WRITE_IMAGEOUT(output, (int2)(i, n), OUT_FLT4(result)); } } diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/space_to_batch_nd.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/space_to_batch_nd.cl index f005793344..d07c26f700 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/space_to_batch_nd.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/space_to_batch_nd.cl @@ -4,8 +4,11 @@ __kernel void space_to_batch_nd_NHWC4(__read_only image2d_t src_data, __write_on int4 dst_size, int2 block_size, int4 paddings) { int X = get_global_id(0); // c int Y = get_global_id(1); // w - int Z = get_global_id(2); // h - if (X >= dst_size.x || Y >= dst_size.y || Y >= dst_size.z) { + int Z = get_global_id(2); // h * n_i + // (N,H*BH,W*BW,C) to (BH*BW*N,H,W,C) + int N_I = Z / dst_size.z; + Z = Z % dst_size.z; + if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z || N_I >= src_size.w) { return; } for (int i = 0; i < block_size.x; ++i) { @@ -13,8 +16,10 @@ __kernel void space_to_batch_nd_NHWC4(__read_only image2d_t src_data, __write_on int w_org = Y * block_size.y + j - paddings.z; int h_org = Z * block_size.x + i - paddings.x; FLT4 res_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f); - res_data = READ_IMAGE(src_data, smp_zero, (int2)(w_org * dst_size.x + X, h_org)); - WRITE_IMAGE(dst_data, (int2)(Y * dst_size.x + X, (i * block_size.y + j) * dst_size.z + Z), res_data); + if (h_org >= 0 && h_org < src_size.z) + res_data = READ_IMAGE(src_data, smp_zero, (int2)(w_org * dst_size.x + X, N_I * src_size.z + h_org)); + WRITE_IMAGE(dst_data, (int2)(Y * dst_size.x + X, ((i * block_size.y + j) * src_size.w + N_I) * dst_size.z + Z), + res_data); } } } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc index 416f5405e4..e38b4e0b2f 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc @@ -42,6 +42,10 @@ int PoolingOpenCLKernel::CheckSpecs() { MS_LOG(ERROR) << "in size: " << in_tensors_.size() << ", out size: " << out_tensors_.size(); return RET_ERROR; } + if (in_tensors_[0]->shape().size() != 4) { + MS_LOG(ERROR) << "Only support 4d tensor."; + return RET_ERROR; + } if (parameter_->pool_mode_ != PoolMode_MaxPool && parameter_->pool_mode_ != PoolMode_AvgPool) { MS_LOG(ERROR) << "Init `Pooling2d` kernel failed, unsupported pool mode!"; return RET_ERROR; @@ -88,7 +92,7 @@ int PoolingOpenCLKernel::Prepare() { } void PoolingOpenCLKernel::SetGlobalLocal() { - const size_t global_x = out_tensors_[0]->shape()[1]; + const size_t global_x = out_tensors_[0]->shape()[1] * out_tensors_[0]->shape()[0]; const size_t global_y = out_tensors_[0]->shape()[2]; const size_t global_z = UP_DIV(out_tensors_[0]->shape()[3], C4NUM); global_size_ = {global_z, global_y, global_x}; @@ -98,8 +102,8 @@ void PoolingOpenCLKernel::SetGlobalLocal() { void PoolingOpenCLKernel::SetConstArgs() { int slices = UP_DIV(out_tensors_[0]->shape()[3], C4NUM); - cl_int4 input_shape = {in_tensors_[0]->shape()[1], in_tensors_[0]->shape()[2], in_tensors_[0]->shape()[3], slices}; - cl_int4 output_shape = {out_tensors_[0]->shape()[1], out_tensors_[0]->shape()[2], out_tensors_[0]->shape()[3], + cl_int4 input_shape = {in_tensors_[0]->shape()[0], in_tensors_[0]->shape()[1], in_tensors_[0]->shape()[2], slices}; + cl_int4 output_shape = {out_tensors_[0]->shape()[0], out_tensors_[0]->shape()[1], out_tensors_[0]->shape()[2], slices}; cl_int2 stride = {parameter_->stride_h_, parameter_->stride_w_}; cl_int2 kernel_size = {parameter_->window_h_, parameter_->window_w_}; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/resize.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/resize.cc index b83e6d87fa..a1ec824b35 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/resize.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/resize.cc @@ -106,7 +106,7 @@ void ResizeOpenCLKernel::SetConstArgs() { void ResizeOpenCLKernel::SetGlobalLocal() { local_size_ = {}; auto out_shape = GpuTensorInfo(out_tensors_[0]); - global_size_ = {out_shape.Slice, out_shape.W, out_shape.H}; + global_size_ = {out_shape.Slice, out_shape.W, out_shape.H * out_shape.N}; AlignGlobalLocal(global_size_, local_size_); } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc index 62e50637ba..0e14de1614 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc @@ -54,10 +54,6 @@ int SoftmaxOpenCLKernel::CheckSpecs() { MS_LOG(ERROR) << "Init Softmax kernel failed: Unsupported shape size: " << in_shape.size(); return RET_ERROR; } - if (in_shape[0] > 1) { - MS_LOG(ERROR) << "Init Softmax kernel failed: Unsupported multi-batch."; - return RET_ERROR; - } if (axis_ < 0) { axis_ = in_shape.size() + axis_; } @@ -104,8 +100,8 @@ int SoftmaxOpenCLKernel::Prepare() { void SoftmaxOpenCLKernel::SetGlobalLocal() { if (onexone_flag_) { - local_size_ = {32}; - global_size_ = {32}; + local_size_ = {32, 1}; + global_size_ = {32, out_shape_.N}; } else { size_t global_x, global_y; if (axis_ == 1) { @@ -121,7 +117,7 @@ void SoftmaxOpenCLKernel::SetGlobalLocal() { global_x = 1; global_y = 1; } - global_size_ = {global_x, global_y}; + global_size_ = {global_x, global_y, out_shape_.N}; local_size_ = {}; } AlignGlobalLocal(global_size_, local_size_); diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_batch_nd.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_batch_nd.cc index 8cef795797..a75c64c51e 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_batch_nd.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_batch_nd.cc @@ -85,7 +85,8 @@ void SpaceToBatchNDOpenCLKernel::SetGlobalLocal() { size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM); cl_int4 dst_size = {(cl_int)CO4, out_tensors_[0]->Width(), out_tensors_[0]->Height(), out_tensors_[0]->Batch()}; local_size_ = {1, 1, 1}; - global_size_ = {(size_t)dst_size.s[0], (size_t)dst_size.s[1], (size_t)dst_size.s[2]}; + global_size_ = {(size_t)dst_size.s[0], (size_t)dst_size.s[1], + (size_t)dst_size.s[2] * (size_t)(in_tensors_[0]->Batch())}; OpenCLKernel::AlignGlobalLocal(global_size_, local_size_); } diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc index 0b32b62614..bac832129d 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc @@ -239,10 +239,11 @@ int OpenCLSubGraph::Init() { MS_ASSERT(tensor); tensor->set_allocator(allocator_); } - std::map> pass_manager{ + std::vector>> pass_manager{ + {"FusionPass", std::bind(&OpenCLSubGraph::FusionPass, this)}, {"InsertOpsPass", std::bind(&OpenCLSubGraph::InsertOpsPass, this)}, {"UpdateTensorDataTypePass", std::bind(&OpenCLSubGraph::UpdateTensorDataTypePass, this)}, - {"FusionPass", std::bind(&OpenCLSubGraph::FusionPass, this)}}; + }; for (auto iv : pass_manager) { auto ret = iv.second(); if (ret != RET_OK) { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/pooling_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/pooling_tests.cc index 2141e086f4..80fac1b918 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/pooling_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/pooling_tests.cc @@ -23,7 +23,7 @@ class TestOpenCL_Pooling : public CommonTest {}; namespace { // PrimitiveType_Pooling: src/ops/populate/pooling_populate.cc OpParameter *CreateParameter(PoolMode pool_mode, int window_h, int window_w, int stride_h, int stride_w, int pad_u, - int pad_d, int pad_l, int pad_r, RoundMode round_mode = RoundMode_No, + int pad_d, int pad_l, int pad_r, RoundMode round_mode = RoundMode_Floor, ActType act_type = ActType_No) { auto *param = test::CreateParameter(schema::PrimitiveType_MaxPoolFusion); param->global_ = false; @@ -65,4 +65,27 @@ TEST_F(TestOpenCL_Pooling, Max) { } } +TEST_F(TestOpenCL_Pooling, AvgMultiBatch) { + std::vector input_shape = {2, 2, 2, 4}; + std::vector output_shape = {2, 1, 1, 4}; + float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + float output_data[] = {6, 7, 8, 9, 6, 7, 8, 9}; + for (auto fp16_enable : {false, true}) { + auto *param = CreateParameter(PoolMode_AvgPool, 2, 2, 2, 2, 0, 0, 0, 0); + TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); + } +} + +TEST_F(TestOpenCL_Pooling, MaxMultiBatch) { + std::vector input_shape = {2, 2, 2, 4}; + std::vector output_shape = {2, 1, 1, 4}; + float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + float output_data[] = {12, 13, 14, 15, 12, 13, 14, 15}; + for (auto fp16_enable : {false, true}) { + auto *param = CreateParameter(PoolMode_MaxPool, 2, 2, 2, 2, 0, 0, 0, 0); + TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); + } +} } // namespace mindspore::lite::opencl::test diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc index dc9ec81e70..e849da46ee 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/resize_tests.cc @@ -83,4 +83,53 @@ TEST_F(TestOpenCL_Resize, NEAREST) { } } +TEST_F(TestOpenCL_Resize, BilinearBatch) { + schema::ResizeMethod method = schema::ResizeMethod_LINEAR; + int oh = 4; + int ow = 4; + bool align_corners = false; + + std::vector input_shape = {2, 2, 2, 1}; + std::vector output_shape = {2, oh, ow, 1}; + float input_data[] = {0, 1, 2, 3, 0, 1, 2, 3}; + float output_data[] = {0, 0.5, 1, 1, 1, 1.5, 2, 2, 2, 2.5, 3, 3, 2, 2.5, 3, 3, + 0, 0.5, 1, 1, 1, 1.5, 2, 2, 2, 2.5, 3, 3, 2, 2.5, 3, 3}; + for (auto fp16_enable : {false, true}) { + auto *param = CreateParameter(method, oh, ow, align_corners); + TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); + } +} + +TEST_F(TestOpenCL_Resize, Bilinear_AlignCornersBatch) { + schema::ResizeMethod method = schema::ResizeMethod_LINEAR; + int oh = 3; + int ow = 3; + bool align_corners = true; + + std::vector input_shape = {2, 2, 2, 1}; + std::vector output_shape = {2, oh, ow, 1}; + float input_data[] = {0, 1, 2, 3, 0, 1, 2, 3}; + float output_data[] = {0, 0.5, 1, 1, 1.5, 2, 2, 2.5, 3, 0, 0.5, 1, 1, 1.5, 2, 2, 2.5, 3}; + for (auto fp16_enable : {false, true}) { + auto *param = CreateParameter(method, oh, ow, align_corners); + TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); + } +} + +TEST_F(TestOpenCL_Resize, NEARESTBatch) { + schema::ResizeMethod method = schema::ResizeMethod_NEAREST; + int oh = 4; + int ow = 4; + bool align_corners = false; + + std::vector input_shape = {2, 2, 2, 1}; + std::vector output_shape = {2, oh, ow, 1}; + float input_data[] = {0, 1, 2, 3, 0, 1, 2, 3}; + float output_data[] = {0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 3, 3, 2, 2, 3, 3, + 0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 3, 3, 2, 2, 3, 3}; + for (auto fp16_enable : {false, true}) { + auto *param = CreateParameter(method, oh, ow, align_corners); + TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); + } +} } // namespace mindspore::lite::opencl::test diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc index b1ff2d5039..420e341c13 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc @@ -71,4 +71,47 @@ TEST_F(TestOpenCL_SoftMax, 4D_axis1) { } } +TEST_F(TestOpenCL_SoftMax, 2D_axis1_N) { + int axis = 1; + std::vector input_shape = {2, 10}; + std::vector output_shape = input_shape; + float input_data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + float output_data[] = {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, + 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1}; + + for (auto fp16_enable : {false, true}) { + auto *param = CreateParameter(axis); + TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, + fp16_enable ? 2e-2 : 1e-5); + } +} + +TEST_F(TestOpenCL_SoftMax, 4D_axis3_N) { + int axis = 3; + std::vector input_shape = {2, 2, 1, 5}; + std::vector output_shape = input_shape; + float input_data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + float output_data[] = {0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, + 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2}; + + for (auto fp16_enable : {false, true}) { + auto *param = CreateParameter(axis); + TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, + fp16_enable ? 2e-2 : 1e-5); + } +} + +TEST_F(TestOpenCL_SoftMax, 4D_axis1_N) { + int axis = 1; + std::vector input_shape = {2, 2, 1, 1}; + std::vector output_shape = input_shape; + float input_data[] = {1, 1, 1, 1}; + float output_data[] = {0.5, 0.5, 0.5, 0.5}; + + for (auto fp16_enable : {false, true}) { + auto *param = CreateParameter(axis); + TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, + fp16_enable ? 2e-2 : 1e-5); + } +} } // namespace mindspore::lite::opencl::test diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/space_to_batch_nd_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/space_to_batch_nd_tests.cc index c8c127837d..91da782331 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/space_to_batch_nd_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/space_to_batch_nd_tests.cc @@ -32,6 +32,7 @@ OpParameter *CreateParameter(const std::vector &block_sizes, const std::vec for (int i = 0; i < paddings.size(); ++i) { param->paddings_[i] = paddings[i]; } + param->m_ = 2; return reinterpret_cast(param); } @@ -83,4 +84,27 @@ TEST_F(TestOpenCL_SpaceToBatch, H2W2Pad2222) { TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); } } + +TEST_F(TestOpenCL_SpaceToBatch, H2W2Pad2222MultiBatch) { + std::vector input_shape{2, 6, 6, 1}; + std::vector block_sizes = {2, 2}; + std::vector paddings = {2, 2, 2, 2}; + auto output_shape = InferShape(input_shape, block_sizes, paddings); + float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71}; + float output_data[] = {0, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0, 12, 14, 16, 0, 0, 24, 26, 28, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 36, 38, 40, 0, 0, 48, 50, 52, 0, 0, 60, 62, 64, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 3, 5, 0, 0, 13, 15, 17, 0, 0, 25, 27, 29, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 37, 39, 41, 0, 0, 49, 51, 53, 0, 0, 61, 63, 65, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 6, 8, 10, 0, 0, 18, 20, 22, 0, 0, 30, 32, 34, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 42, 44, 46, 0, 0, 54, 56, 58, 0, 0, 66, 68, 70, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 7, 9, 11, 0, 0, 19, 21, 23, 0, 0, 31, 33, 35, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 43, 45, 47, 0, 0, 55, 57, 59, 0, 0, 67, 69, 71, 0, 0, 0, 0, 0, 0}; + + for (auto fp16_enable : {false, true}) { + auto *param = CreateParameter(block_sizes, paddings); + TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); + } +} } // namespace mindspore::lite::opencl::test