diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d_transpose.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d_transpose.cl index 722c7d564f..fd5423c298 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d_transpose.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/conv2d_transpose.cl @@ -3,7 +3,7 @@ __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP __kernel void conv2d_transpose2x2_NHWC4(__read_only image2d_t src_data, __global FLT16 *weight, __read_only image2d_t biases, __write_only image2d_t dst_data, int2 kernel_size, int2 stride, int2 padding, int4 src_size, int4 dst_size) { - int h = get_global_id(2); + int h = get_global_id(0); int kh = h % 2; int src_h = h / 2; src_h = src_h * 2; @@ -11,7 +11,7 @@ __kernel void conv2d_transpose2x2_NHWC4(__read_only image2d_t src_data, __global int kw = w % 2; int src_w = w / 2; src_w = src_w * 2; - int co = get_global_id(0); + int co = get_global_id(2); if (src_h * 2 >= dst_size.x || src_w * 2 >= dst_size.y || co >= dst_size.z) return; FLT4 r0 = (FLT4)(0.f); FLT4 r1 = (FLT4)(0.f); @@ -59,7 +59,7 @@ __kernel void conv2d_transpose2x2_NHWC4(__read_only image2d_t src_data, __global __kernel void conv2d_transpose2x2_NC4HW4(__read_only image2d_t src_data, __global FLT16 *weight, __read_only image2d_t biases, __write_only image2d_t dst_data, int2 kernel_size, int2 stride, int2 padding, int4 src_size, int4 dst_size) { - int h = get_global_id(2); + int h = get_global_id(0); int kh = h % 2; int src_h = h / 2; src_h = src_h * 2; @@ -67,7 +67,7 @@ __kernel void conv2d_transpose2x2_NC4HW4(__read_only image2d_t src_data, __globa int kw = w % 2; int src_w = w / 2; src_w = src_w * 2; - int co = get_global_id(0); + int co = get_global_id(2); if (src_h * 2 >= dst_size.x || src_w * 2 >= dst_size.y || co >= dst_size.z) return; FLT4 r0 = (FLT4)(0.f); FLT4 r1 = (FLT4)(0.f); @@ -115,7 +115,7 @@ __kernel void conv2d_transpose2x2_NC4HW4(__read_only image2d_t src_data, __globa __kernel void conv2d_transpose_NHWC4(__read_only image2d_t src_data, __global FLT16 *weight, __read_only image2d_t biases, __write_only image2d_t dst_data, int2 kernel_size, int2 stride, int2 padding, int4 src_size, int4 dst_size) { - int dst_h = get_global_id(2); + int dst_h = get_global_id(0); int rem_h = dst_h % stride.x; int ceil_h = dst_h / stride.x; dst_h = ceil_h * stride.x * 2 + rem_h; @@ -123,7 +123,7 @@ __kernel void conv2d_transpose_NHWC4(__read_only image2d_t src_data, __global FL int rem_w = dst_w % stride.y; int ceil_w = dst_w / stride.y; dst_w = ceil_w * stride.y * 2 + rem_w; - int dst_c = get_global_id(0); + int dst_c = get_global_id(2); if (dst_h >= dst_size.x || dst_w >= dst_size.y || dst_c >= dst_size.z) return; int weight_base = dst_c * src_size.z * kernel_size.x * kernel_size.y; FLT4 r0 = (FLT4)(0.f); diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl index 1141cb7171..6ae366a8bd 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl @@ -4,188 +4,89 @@ #define divide_no_check(a, b) (a / b) __constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; -__kernel void SoftMax_NHWC4_BUF(__read_only image2d_t input, __global FLT4 *output, const int4 input_shape) { - int X = get_global_id(0); // H - int Y = get_global_id(1); // W - int H = input_shape.x; - int W = input_shape.y; - int C = input_shape.z; - int S = input_shape.w; +__kernel void SoftMaxAxis3_NHWC4(__read_only image2d_t input, __write_only image2d_t output, const float4 mask, + const int4 input_shape) { + int X = get_global_id(1); // H + int Y = get_global_id(0); // W + int H = input_shape.y; + int W = input_shape.z; + int C4 = input_shape.w; if (X >= H || Y >= W) return; float sum = 0.0f; - for (int d = 0; d < S; ++d) { - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * S + d, X))); - sum += exp(t.x); - if (d * 4 + 1 < C) sum += exp(t.y); - if (d * 4 + 2 < C) sum += exp(t.z); - if (d * 4 + 3 < C) sum += exp(t.w); - } - - for (int d = 0; d < S; ++d) { - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * S + d, X))); - t = divide_no_check(exp(t), sum); - __global FLT *output_flt = (__global FLT *)output; - output_flt += (X * W + Y) * C + 4 * d; - FLT4 result = TO_FLT4(t); - output_flt[0] = result.x; - if (d * 4 + 1 < C) output_flt[1] += result.y; - if (d * 4 + 2 < C) output_flt[2] += result.z; - if (d * 4 + 3 < C) output_flt[3] += result.w; - } + for (int d = 0; d < C4 - 1; ++d) { + float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, X))); + sum += dot(exp(t), (float4)(1.f)); + } + float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, X))); + sum += dot(exp(t), mask); + for (int d = 0; d < C4 - 1; ++d) { + float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, X))); + result = exp(result) / sum; + WRITE_IMAGE(output, (int2)(Y * C4 + d, X), TO_FLT4(result)); + } + float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, X))); + result = exp(result) / sum; + result = result * mask; + WRITE_IMAGE(output, (int2)(Y * C4 + C4 - 1, X), TO_FLT4(result)); } -__kernel void SoftMax_NHWC4_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) { - int X = get_global_id(0); // H - int Y = get_global_id(1); // W - int H = input_shape.x; - int W = input_shape.y; - int C = input_shape.z; - int S = input_shape.w; +__kernel void SoftMaxAxis1_NHWC4(__read_only image2d_t input, __write_only image2d_t output, const float4 mask, + const int4 input_shape) { + int X = get_global_id(1); // W + int Y = get_global_id(0); // C4 + int H = input_shape.y; + int W = input_shape.z; + int C4 = input_shape.w; - if (X >= H || Y >= W) return; + if (X >= W || Y >= C4) return; - float sum = 0.0f; - for (int d = 0; d < S; ++d) { - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * S + d, X))); - sum += exp(t.x); - if (d * 4 + 1 < C) sum += exp(t.y); - if (d * 4 + 2 < C) sum += exp(t.z); - if (d * 4 + 3 < C) sum += exp(t.w); + float4 sum = 0.0f; + for (int d = 0; d < H; ++d) { + float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(X * C4 + Y, d))); + sum += exp(t); } - - for (int d = 0; d < S; ++d) { - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * S + d, X))); - t = exp(t) / sum; - WRITE_IMAGE(output, (int2)(Y * S + d, X), TO_FLT4(t)); + for (int d = 0; d < H; ++d) { + float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(X * C4 + Y, d))); + result = exp(result) / sum; + WRITE_IMAGE(output, (int2)(X * C4 + Y, d), TO_FLT4(result)); } } -__kernel void SoftMax_NC4HW4_BUF(__read_only image2d_t input, __global FLT4 *output, const int4 input_shape) { - int X = get_global_id(0); // H - int Y = get_global_id(1); // W - int H = input_shape.x; - int W = input_shape.y; - int C = input_shape.z; - int S = input_shape.w; - - if (X >= H || Y >= W) return; - - float sum = 0.0f; - for (int d = 0; d < S; ++d) { - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y, d * H + X))); - sum += exp(t.x); - if (d * 4 + 1 < C) sum += exp(t.y); - if (d * 4 + 2 < C) sum += exp(t.z); - if (d * 4 + 3 < C) sum += exp(t.w); - } - - for (int d = 0; d < S; ++d) { - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y, d * H + X))); - t = divide_no_check(exp(t), sum); - __global FLT *output_flt = (__global FLT *)output; - output_flt += (X * W + Y) * C + 4 * d; - FLT4 result = TO_FLT4(t); - output_flt[0] = result.x; - if (d * 4 + 1 < C) output_flt[1] += result.y; - if (d * 4 + 2 < C) output_flt[2] += result.z; - if (d * 4 + 3 < C) output_flt[3] += result.w; - } -} +__kernel void SoftMaxAxis2_NHWC4(__read_only image2d_t input, __write_only image2d_t output, const float4 mask, + const int4 input_shape) { + int X = get_global_id(1); // H + int Y = get_global_id(0); // C4 + int H = input_shape.y; + int W = input_shape.z; + int C4 = input_shape.w; -__kernel void SoftMax_NC4HW4_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) { - int X = get_global_id(0); // H - int Y = get_global_id(1); // W - int H = input_shape.x; - int W = input_shape.y; - int C = input_shape.z; - int S = input_shape.w; - - if (X >= H || Y >= W) return; - - float sum = 0.0f; - for (int d = 0; d < S; ++d) { - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y, d * H + X))); - sum += exp(t.x); - if (d * 4 + 1 < C) sum += exp(t.y); - if (d * 4 + 2 < C) sum += exp(t.z); - if (d * 4 + 3 < C) sum += exp(t.w); - } - - for (int d = 0; d < S; ++d) { - float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y, d * H + X))); - t = exp(t) / sum; - WRITE_IMAGE(output, (int2)(Y, d * H + X), TO_FLT4(t)); - } -} - -__kernel void SoftMax1x1_NHWC4_BUF(__read_only image2d_t input, __global FLT4 *output, const float4 mask, - const int slices, const int slices_x32) { - int tid = get_local_id(0); - float sum = 0.0f; - for (size_t i = tid; i < slices - 1; i += 32) { - float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(i, 0))); - sum += dot((float4)(1.0f), exp(src)); - } - if ((slices - 1) % 32 == tid) { - float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(slices - 1, 0))); - sum += dot(convert_float4(mask), exp(src)); - } + if (X >= H || Y >= C4) return; - __local float4 tmp[8]; - __local float *tmpx1 = (__local float *)tmp; - tmpx1[tid] = sum; - barrier(CLK_LOCAL_MEM_FENCE); - if (tid == 0) { - sum = dot((float4)(1.0f), tmp[0]); - sum += dot((float4)(1.0f), tmp[1]); - sum += dot((float4)(1.0f), tmp[2]); - sum += dot((float4)(1.0f), tmp[3]); - sum += dot((float4)(1.0f), tmp[4]); - sum += dot((float4)(1.0f), tmp[5]); - sum += dot((float4)(1.0f), tmp[6]); - sum += dot((float4)(1.0f), tmp[7]); - tmpx1[0] = divide_no_check(1.0f, sum); + float4 sum = 0.0f; + for (int d = 0; d < W; ++d) { + float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(d * C4 + Y, X))); + sum += exp(t); } - barrier(CLK_LOCAL_MEM_FENCE); - sum = tmpx1[0]; - for (size_t i = tid; i < slices - 1; i += 32) { - float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(i, 0))); - result = exp(result) * sum; - output[i] = TO_FLT4(result); - } - if ((slices - 1) % 32 == tid) { - float4 result_float = convert_float4(READ_IMAGE(input, smp_zero, (int2)(slices - 1, 0))); - result_float = exp(result_float) * sum; - FLT4 result = TO_FLT4(result_float); - __global FLT4 *remain_ptr4 = output; - remain_ptr4 += slices - 1; - __global FLT *remain_ptr = (__global FLT *)remain_ptr4; - remain_ptr[0] = result.x; - if (mask.y > 0.f) { - remain_ptr[1] = result.y; - } - if (mask.z > 0.f) { - remain_ptr[2] = result.z; - } - if (mask.w > 0.f) { - remain_ptr[3] = result.w; - } + for (int d = 0; d < W; ++d) { + float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(d * C4 + Y, X))); + result = exp(result) / sum; + WRITE_IMAGE(output, (int2)(d * C4 + Y, X), TO_FLT4(result)); } } -__kernel void SoftMax1x1_NHWC4_IMG(__read_only image2d_t input, __write_only image2d_t output, const float4 mask, - const int slices, const int slices_x32) { +__kernel void SoftMax1x1_NHWC4(__read_only image2d_t input, __write_only image2d_t output, const float4 mask, + const int4 input_shape) { int tid = get_local_id(0); + int C4 = input_shape.w; float sum = 0.0f; - for (size_t i = tid; i < slices - 1; i += 32) { + for (size_t i = tid; i < C4 - 1; i += 32) { float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(i, 0))); sum += dot((float4)(1.0f), exp(src)); } - if ((slices - 1) % 32 == tid) { - float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(slices - 1, 0))); - + if ((C4 - 1) % 32 == tid) { + float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(C4 - 1, 0))); sum += dot(convert_float4(mask), exp(src)); } @@ -206,102 +107,9 @@ __kernel void SoftMax1x1_NHWC4_IMG(__read_only image2d_t input, __write_only ima } barrier(CLK_LOCAL_MEM_FENCE); sum = tmpx1[0]; - for (size_t i = tid; i < slices; i += 32) { + for (size_t i = tid; i < C4; i += 32) { float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(i, 0))); result = exp(result) * sum; WRITE_IMAGE(output, (int2)(i, 0), TO_FLT4(result)); } } - -__kernel void SoftMax1x1_NC4HW4_BUF(__read_only image2d_t input, __global FLT4 *output, const float4 mask, - const int slices, const int slices_x32) { - int tid = get_local_id(0); - float sum = 0.0f; - for (size_t i = tid; i < slices - 1; i += 32) { - float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(0, i))); - sum += dot((float4)(1.0f), exp(src)); - } - if ((slices - 1) % 32 == tid) { - float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(0, slices - 1))); - - sum += dot(convert_float4(mask), exp(src)); - } - - __local float4 tmp[8]; - __local float *tmpx1 = (__local float *)tmp; - tmpx1[tid] = sum; - barrier(CLK_LOCAL_MEM_FENCE); - if (tid == 0) { - sum = dot((float4)(1.0f), tmp[0]); - sum += dot((float4)(1.0f), tmp[1]); - sum += dot((float4)(1.0f), tmp[2]); - sum += dot((float4)(1.0f), tmp[3]); - sum += dot((float4)(1.0f), tmp[4]); - sum += dot((float4)(1.0f), tmp[5]); - sum += dot((float4)(1.0f), tmp[6]); - sum += dot((float4)(1.0f), tmp[7]); - tmpx1[0] = divide_no_check(1.0f, sum); - } - barrier(CLK_LOCAL_MEM_FENCE); - sum = tmpx1[0]; - for (size_t i = tid; i < slices - 1; i += 32) { - float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(0, i))); - result = exp(result) * sum; - output[i] = TO_FLT4(result); - } - if ((slices - 1) % 32 == tid) { - float4 result_float = convert_float4(READ_IMAGE(input, smp_zero, (int2)(0, slices - 1))); - result_float = exp(result_float) * sum; - FLT4 result = TO_FLT4(result_float); - __global FLT4 *remain_ptr4 = output; - remain_ptr4 += slices - 1; - __global FLT *remain_ptr = (__global FLT *)remain_ptr4; - remain_ptr[0] = result.x; - if (mask.y > 0.f) { - remain_ptr[1] = result.y; - } - if (mask.z > 0.f) { - remain_ptr[2] = result.z; - } - if (mask.w > 0.f) { - remain_ptr[3] = result.w; - } - } -} - -__kernel void SoftMax1x1_NC4HW4_IMG(__read_only image2d_t input, __write_only image2d_t output, const float4 mask, - const int slices, const int slices_x32) { - int tid = get_local_id(0); - float sum = 0.0f; - for (size_t i = tid; i < slices - 1; i += 32) { - float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(0, i))); - sum += dot((float4)(1.0f), exp(src)); - } - if ((slices - 1) % 32 == tid) { - float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(0, slices - 1))); - sum += dot(convert_float4(mask), exp(src)); - } - - __local float4 tmp[8]; - __local float *tmpx1 = (__local float *)tmp; - tmpx1[tid] = sum; - barrier(CLK_LOCAL_MEM_FENCE); - if (tid == 0) { - sum = dot((float4)(1.0f), tmp[0]); - sum += dot((float4)(1.0f), tmp[1]); - sum += dot((float4)(1.0f), tmp[2]); - sum += dot((float4)(1.0f), tmp[3]); - sum += dot((float4)(1.0f), tmp[4]); - sum += dot((float4)(1.0f), tmp[5]); - sum += dot((float4)(1.0f), tmp[6]); - sum += dot((float4)(1.0f), tmp[7]); - tmpx1[0] = divide_no_check(1.0f, sum); - } - barrier(CLK_LOCAL_MEM_FENCE); - sum = tmpx1[0]; - for (size_t i = tid; i < slices; i += 32) { - float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(0, i))); - result = exp(result) * sum; - WRITE_IMAGE(output, (int2)(0, i), TO_FLT4(result)); - } -} diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc index 14471f599c..31d273b186 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc @@ -189,8 +189,7 @@ int Conv2dTransposeOpenCLKernel::Run() { int w = in_tensors_[0]->shape()[2]; // local size should less than MAX_GROUP_SIZE std::vector local = {16, 1, 16}; - std::vector global = {UP_ROUND(co4, local[0]), UP_ROUND((size_t)UP_ROUND(ow / 2, stride_w), local[1]), - UP_ROUND((size_t)UP_ROUND(oh / 2, stride_h), local[2])}; + std::vector global = {(size_t)UP_ROUND(oh / 2, stride_h), (size_t)UP_ROUND(ow / 2, stride_w), (size_t)co4}; cl_int2 kernel_size = {kh, kw}; cl_int2 stride = {stride_h, stride_w}; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc index 86f6424f81..acb3cd6d97 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc @@ -20,6 +20,7 @@ #include "include/errorcode.h" #include "src/kernel_registry.h" #include "src/runtime/kernel/opencl/utils.h" +#include "nnacl/softmax_parameter.h" #ifndef PROGRAM_WITH_IL #include "src/runtime/kernel/opencl/cl/softmax.cl.inc" #endif @@ -40,9 +41,21 @@ std::vector SoftmaxOpenCLKernel::GetMaskForLastChannel(int channels) { } int SoftmaxOpenCLKernel::InitGlobalSize() { - const size_t global_x = out_tensors_[0]->shape()[1]; - const size_t global_y = out_tensors_[0]->shape()[2]; - const size_t global_z = 1; + size_t global_x, global_y, global_z; + global_z = 1; + if (axis_ == 1) { + global_x = UP_DIV(nhwc_shape_[3], C4NUM); + global_y = nhwc_shape_[2]; + } else if (axis_ == 2) { + global_x = UP_DIV(nhwc_shape_[3], C4NUM); + global_y = nhwc_shape_[1]; + } else if (axis_ == 3) { + global_x = nhwc_shape_[2]; + global_y = nhwc_shape_[1]; + } else { + global_x = 1; + global_y = 1; + } global_size_ = {global_x, global_y, global_z}; return lite::RET_OK; } @@ -65,16 +78,7 @@ int SoftmaxOpenCLKernel::SetWorkGroupSize1x1() { int SoftmaxOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { size_t im_dst_x, im_dst_y; auto out_shape = out_tensors_[0]->shape(); - int n = 1, h = 1, w = 1, c = 1; - if (out_shape.size() == 2) { - n = out_shape[0]; - c = out_shape[1]; - } else if (out_shape.size() == 4) { - n = out_shape[0]; - h = out_shape[1]; - w = out_shape[2]; - c = out_shape[3]; - } + int n = nhwc_shape_[0], h = nhwc_shape_[1], w = nhwc_shape_[2], c = nhwc_shape_[3]; if (op_format_ == schema::Format_NHWC4) { im_dst_x = w * UP_DIV(c, C4NUM); im_dst_y = n * h; @@ -98,38 +102,39 @@ int SoftmaxOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) int SoftmaxOpenCLKernel::Init() { std::string kernel_name = "SoftMax"; std::string program_name = "SoftMax"; - + auto softmax_param = reinterpret_cast(op_parameter_); + axis_ = softmax_param->axis_; + auto in_shape = in_tensors_[0]->shape(); + if (in_shape.size() > 4) { + MS_LOG(ERROR) << "Init `Softmax` kernel failed: Unsupported shape size: " << in_shape.size(); + return RET_ERROR; + } + if (axis_ < 0) { + axis_ = in_shape.size() + axis_; + } + axis_ += 4 - in_shape.size(); + if (axis_ != 1 && axis_ != 2 && axis_ != 3) { + MS_LOG(ERROR) << "Init `Softmax` kernel failed: softmax axis should be H W or C"; + return RET_ERROR; + } + nhwc_shape_ = GetNHWCShape(in_shape); std::string source = softmax_source; enable_fp16_ = ocl_runtime_->GetFp16Enable(); // framework not set this param yet! just use default. - if (in_tensors_[0]->shape().size() == 4) { + if (nhwc_shape_[1] == 1 && nhwc_shape_[2] == 1 && axis_ == 3) { // support 4d tensor - onexone_flag_ = false; - } else if (in_tensors_[0]->shape().size() == 2) { - // support 2d tensor + onexone_flag_ = true; kernel_name += "1x1"; program_name += "1x1"; - onexone_flag_ = true; } else { - MS_LOG(ERROR) << "Init `Softmax` kernel failed: Unsupported shape size: " << in_tensors_[0]->shape().size(); - return RET_ERROR; + onexone_flag_ = false; + kernel_name += "Axis" + std::to_string(axis_); + program_name += "Axis" + std::to_string(axis_); } kernel_name += "_" + std::string(EnumNameFormat(op_format_)); #ifdef PROGRAM_WITH_IL kernel_ = ocl_runtime->GetKernelFromBinary(kernel_name); #else - if (!is_image_out_) { - out_mem_type_ = OpenCLMemType::BUF; - } else { - out_mem_type_ = OpenCLMemType::IMG; - } - if (out_mem_type_ == OpenCLMemType::BUF) { - kernel_name += "_BUF"; - program_name += "_BUF"; - } else { - kernel_name += "_IMG"; - program_name += "_IMG"; - } std::set build_options; ocl_runtime_->LoadSource(program_name, source); ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options); @@ -138,9 +143,6 @@ int SoftmaxOpenCLKernel::Init() { out_ori_format_ = out_tensors_[0]->GetFormat(); in_tensors_[0]->SetFormat(op_format_); out_tensors_[0]->SetFormat(op_format_); - if (!is_image_out_) { - out_tensors_[0]->SetFormat(out_ori_format_); - } MS_LOG(DEBUG) << kernel_name << " Init Done!"; return lite::RET_OK; } @@ -149,34 +151,18 @@ int SoftmaxOpenCLKernel::Run() { MS_LOG(DEBUG) << this->name() << " Running!"; int arg_idx = 0; + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->data_c()); + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c()); + int channel = nhwc_shape_[3]; + int c4 = UP_DIV(channel, C4NUM); + auto mask_ = GetMaskForLastChannel(channel); + cl_float4 mask = {mask_[0], mask_[1], mask_[2], mask_[3]}; + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, mask); + cl_int4 input_shape = {nhwc_shape_[0], nhwc_shape_[1], nhwc_shape_[2], c4}; + ocl_runtime_->SetKernelArg(kernel_, arg_idx, input_shape); if (onexone_flag_) { - int channel_size = in_tensors_[0]->shape()[1]; - int slices = UP_DIV(channel_size, C4NUM); - cl_int slices_x32 = UP_DIV(slices, 32); - auto mask_ = GetMaskForLastChannel(channel_size); - cl_float4 mask = {mask_[0], mask_[1], mask_[2], mask_[3]}; - - ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->data_c()); - if (is_image_out_) { - ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c()); - } else { - ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c(), lite::opencl::MemType::BUF); - } - ocl_runtime_->SetKernelArg(kernel_, arg_idx++, mask); - ocl_runtime_->SetKernelArg(kernel_, arg_idx++, slices); - ocl_runtime_->SetKernelArg(kernel_, arg_idx, slices_x32); SetWorkGroupSize1x1(); } else { - int slices = UP_DIV(out_tensors_[0]->shape()[3], C4NUM); - cl_int4 input_shape = {in_tensors_[0]->shape()[1], in_tensors_[0]->shape()[2], in_tensors_[0]->shape()[3], slices}; - - ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->data_c()); - if (is_image_out_) { - ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c()); - } else { - ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c(), lite::opencl::MemType::BUF); - } - ocl_runtime_->SetKernelArg(kernel_, arg_idx, input_shape); SetWorkGroupSize(); } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h index d949aa8092..136fdd89a9 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h @@ -49,8 +49,9 @@ class SoftmaxOpenCLKernel : public OpenCLKernel { bool onexone_flag_{false}; std::vector local_size_; std::vector global_size_; - bool is_image_out_{true}; bool enable_fp16_{false}; + int axis_{0}; + std::vector nhwc_shape_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc index b82806ad20..006577634b 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc @@ -29,7 +29,8 @@ class TestSoftmaxOpenCL : public mindspore::CommonTest { TestSoftmaxOpenCL() {} }; -void RunTestCaseSoftmax(const std::vector &shape, void *input_data, void *output_data, bool enable_fp16) { +void RunTestCaseSoftmax(const std::vector &shape, void *input_data, void *output_data, bool enable_fp16, + int axis) { auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); size_t dtype_size = enable_fp16 ? sizeof(float16_t) : sizeof(float); @@ -68,7 +69,14 @@ void RunTestCaseSoftmax(const std::vector &shape, void *input_data, void *o } std::vector inputs{tensor_x}; std::vector outputs{tensor_out}; - auto arith_kernel_ptr = std::make_unique(nullptr, inputs, outputs); + auto opParameter = static_cast(malloc(sizeof(SoftmaxParameter))); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "opParameter create error."; + return; + } + opParameter->axis_ = axis; + auto arith_kernel_ptr = + std::make_unique(reinterpret_cast(opParameter), inputs, outputs); auto arith_kernel = arith_kernel_ptr.release(); if (arith_kernel == nullptr) { MS_LOG(ERROR) << "arith_kernel create error."; @@ -112,7 +120,7 @@ TEST_F(TestSoftmaxOpenCL, Softmax2DFp32) { std::vector input_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; std::vector output_data = {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f}; - RunTestCaseSoftmax(shape, input_data.data(), output_data.data(), false); + RunTestCaseSoftmax(shape, input_data.data(), output_data.data(), false, 1); } TEST_F(TestSoftmaxOpenCL, Softmax2DFp16) { @@ -122,7 +130,7 @@ TEST_F(TestSoftmaxOpenCL, Softmax2DFp16) { std::vector input_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; std::vector output_data = {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f}; - RunTestCaseSoftmax(shape, input_data.data(), output_data.data(), true); + RunTestCaseSoftmax(shape, input_data.data(), output_data.data(), true, 1); } TEST_F(TestSoftmaxOpenCL, Softmax4DFp32) { @@ -134,7 +142,7 @@ TEST_F(TestSoftmaxOpenCL, Softmax4DFp32) { std::vector input_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; std::vector output_data = {0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f}; - RunTestCaseSoftmax(shape, input_data.data(), output_data.data(), false); + RunTestCaseSoftmax(shape, input_data.data(), output_data.data(), false, 3); } TEST_F(TestSoftmaxOpenCL, Softmax4DFp16) { @@ -146,6 +154,18 @@ TEST_F(TestSoftmaxOpenCL, Softmax4DFp16) { std::vector input_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; std::vector output_data = {0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f}; - RunTestCaseSoftmax(shape, input_data.data(), output_data.data(), true); + RunTestCaseSoftmax(shape, input_data.data(), output_data.data(), true, 3); +} + +TEST_F(TestSoftmaxOpenCL, Softmax4DAxis1Fp32) { + int n = 1; + int h = 2; + int w = 1; + int c = 1; + std::vector shape = {n, h, w, c}; + std::vector input_data = {1.0f, 1.0f}; + std::vector output_data = {0.5f, 0.5f}; + + RunTestCaseSoftmax(shape, input_data.data(), output_data.data(), false, 1); } } // namespace mindspore