From bde334411cc12b1ddb643066c5f3634e542d9757 Mon Sep 17 00:00:00 2001
From: chenzupeng
Date: Mon, 31 Aug 2020 17:36:14 +0800
Subject: [PATCH] fix transpose softmax reshape bug

---
 .../src/runtime/kernel/opencl/cl/reshape.cl   |   9 +-
 .../src/runtime/kernel/opencl/cl/softmax.cl   | 140 ++++++++++++++--
 .../runtime/kernel/opencl/cl/softmax1x1.cl    | 104 ------------
 .../src/runtime/kernel/opencl/cl/transpose.cl |  12 +-
 .../runtime/kernel/opencl/kernel/pooling2d.cc |   3 +-
 .../runtime/kernel/opencl/kernel/reshape.cc   |  32 +++-
 .../runtime/kernel/opencl/kernel/softmax.cc   |  65 ++++---
 .../runtime/kernel/opencl/kernel/softmax.h    |   1 +
 .../runtime/kernel/opencl/kernel/to_format.cc |   6 +-
 .../runtime/kernel/opencl/kernel/transpose.cc |  12 +-
 .../runtime/kernel/opencl/reshape_tests.cc    |   6 +-
 .../runtime/kernel/opencl/softmax_tests.cc    | 158 +++++++++++-------
 .../runtime/kernel/opencl/transpose_tests.cc  |   4 +-
 13 files changed, 320 insertions(+), 232 deletions(-)
 delete mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/softmax1x1.cl

diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/reshape.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/reshape.cl
index b51c514856..c292a5dee8 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/cl/reshape.cl
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/reshape.cl
@@ -1,11 +1,14 @@
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
-__kernel void reshape(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {
+__kernel void reshape(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size, int4 size_out) {
   int X = get_global_id(0);
   int Y = get_global_id(1);
   int Z = get_global_id(2);
-  if (X >= size.x || Y >= size.y || Z >= size.z) {
+  if (X >= size_out.x || Y >= size_out.y || Z >= size_out.z) {
     return;
   }
-  WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
+  int out_index = X * size_out.y + Y;
+  int ih = out_index / size.y;
+  int iw = out_index % size.y;
+  WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), READ_IMAGE(src_data, smp_zero, (int2)(iw * size.z + Z, ih)));
 }
diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl
index 5ecff6a4d7..08dcc1f2fc 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/softmax.cl
@@ -1,16 +1,21 @@
-__kernel void SoftMax_BUF(__global float4 *input, __global float4 *output, const int4 input_shape) {
-  int X = get_global_id(0);
-  int Y = get_global_id(1);
+#ifdef cl_khr_fp16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
+__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
+__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+__kernel void SoftMax_BUF(__read_only image2d_t input, __global FLT4 *output, const int4 input_shape) {
+  int X = get_global_id(0);  // H
+  int Y = get_global_id(1);  // W
   int H = input_shape.x;
   int W = input_shape.y;
   int C = input_shape.z;
   int S = input_shape.w;

-  if (X >= W || Y >= H) return;
+  if (X >= H || Y >= W) return;

-  float sum = 0.0f;
+  FLT sum = 0.0f;
   for (int d = 0; d < S; ++d) {
-    float4 t = input[(Y * W + X * H) * C + d];
+    FLT4 t = READ_IMAGE(input, smp_zero, (int2)(Y * S + d, X));
     sum += exp(t.x);
     if (d * 4 + 1 < C) sum += exp(t.y);
     if (d * 4 + 2 < C) sum += 
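For reference, a minimal host-side sketch of the index mapping used by the new reshape kernel above (not part of the patch, names illustrative): `size` holds the input (h, w, c4) and `size_out` the output (oh, ow, c4); a work-item that owns output pixel (X, Y) flattens its position to X * size_out.y + Y and splits that against the input width to recover the source pixel, carrying the c4 slice Z through unchanged. This assumes input and output share the same slice count, as in the NHWC-to-NC reshapes the tests exercise.

#include <utility>

// Hypothetical helper mirroring the out_index / ih / iw arithmetic in reshape.cl:
// returns the input pixel (ih, iw) that feeds output pixel (oh_idx, ow_idx).
std::pair<int, int> ReshapeSrcPixel(int oh_idx, int ow_idx, int out_w, int in_w) {
  int out_index = oh_idx * out_w + ow_idx;        // flattened position in the output plane
  return {out_index / in_w, out_index % in_w};    // (ih, iw) in the input plane
}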
exp(t.z); @@ -18,15 +23,17 @@ __kernel void SoftMax_BUF(__global float4 *input, __global float4 *output, const } for (int d = 0; d < S; ++d) { - float4 t = input[(Y * W + X * H) * C + d]; + FLT4 t = READ_IMAGE(input, smp_zero, (int2)(Y * S + d, X)); t = exp(t) / sum; - float4 result = convert_float4(t); - output[(Y * W + X * H) * C + d] = result; + __global FLT *output_flt = (__global FLT *)output; + output_flt += (X * W + Y) * C + 4 * d; + output_flt[0] = t.x; + if (d * 4 + 1 < C) output_flt[1] += t.y; + if (d * 4 + 2 < C) output_flt[2] += t.z; + if (d * 4 + 3 < C) output_flt[3] += t.w; } } -__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; - __kernel void SoftMax_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) { int X = get_global_id(0); int Y = get_global_id(1); @@ -34,7 +41,7 @@ __kernel void SoftMax_IMG(__read_only image2d_t input, __write_only image2d_t ou float sum = 0.0f; for (int d = 0; d < input_shape.w; ++d) { - float4 t = read_imagef(input, smp_none, (int2)(Y * input_shape.w + d, X)); + FLT4 t = READ_IMAGE(input, smp_none, (int2)(Y * input_shape.w + d, X)); sum += exp(t.x); if (d * 4 + 1 < input_shape.z) sum += exp(t.y); if (d * 4 + 2 < input_shape.z) sum += exp(t.z); @@ -42,9 +49,112 @@ __kernel void SoftMax_IMG(__read_only image2d_t input, __write_only image2d_t ou } for (int d = 0; d < input_shape.w; ++d) { - float4 t = read_imagef(input, smp_none, (int2)(Y * input_shape.w + d, X)); + FLT4 t = READ_IMAGE(input, smp_none, (int2)(Y * input_shape.w + d, X)); t = exp(t) / sum; - float4 result = convert_float4(t); - write_imagef(output, (int2)(Y * input_shape.w + d, X), result); + FLT4 result = TO_FLT4(t); + WRITE_IMAGE(output, (int2)(Y * input_shape.w + d, X), result); + } +} + +__kernel void SoftMax1x1_IMG(__read_only image2d_t input, __write_only image2d_t output, const FLT4 mask, + const int slices, const int slices_x32) { + int tid = get_local_id(0); + int slices_count = 0; + int offset = 0; + FLT sum = 0.0f; + do { + int z = offset + tid; + if (z < slices) { + FLT4 mask_temp = z == slices - 1 ? 
mask : (FLT4)(1.0f); + FLT4 src = READ_IMAGE(input, smp_none, (int2)(0, 0)); + sum += dot(mask_temp, exp(src)); + offset += 32; + } + slices_count++; + } while (slices_count < slices_x32); + + __local FLT4 tmp[8]; + __local FLT *tmpx1 = (__local FLT *)tmp; + tmpx1[tid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (tid == 0) { + sum = dot((FLT4)(1.0f), tmp[0]); + sum += dot((FLT4)(1.0f), tmp[1]); + sum += dot((FLT4)(1.0f), tmp[2]); + sum += dot((FLT4)(1.0f), tmp[3]); + sum += dot((FLT4)(1.0f), tmp[4]); + sum += dot((FLT4)(1.0f), tmp[5]); + sum += dot((FLT4)(1.0f), tmp[6]); + sum += dot((FLT4)(1.0f), tmp[7]); + tmpx1[0] = 1.0f / sum; + } + barrier(CLK_LOCAL_MEM_FENCE); + sum = tmpx1[0]; + + offset = 0; + slices_count = 0; + do { + int z = offset + tid; + if (z < slices) { + FLT4 res = TO_FLT4(exp(READ_IMAGE(input, smp_none, (int2)(0, 0))) * sum); + WRITE_IMAGE(output, (int2)(0, 0), res); + offset += 32; + } + slices_count++; + } while (slices_count < slices_x32); +} + +__kernel void SoftMax1x1_BUF(__read_only image2d_t input, __global FLT4 *output, const float4 mask, const int slices, + const int slices_x32) { + int tid = get_local_id(0); + FLT sum = 0.0f; + for (size_t i = tid; i < slices - 1; i += 32) { + FLT4 src = READ_IMAGE(input, smp_zero, (int2)(i, 0)); + sum += dot((FLT4)(1.0f), exp(src)); + } + if ((slices - 1) % 32 == tid) { + FLT4 src = READ_IMAGE(input, smp_zero, (int2)(slices - 1, 0)); + + sum += dot(TO_FLT4(mask), exp(src)); + } + + __local FLT4 tmp[8]; + __local FLT *tmpx1 = (__local FLT *)tmp; + tmpx1[tid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (tid == 0) { + sum = dot((FLT4)(1.0f), tmp[0]); + sum += dot((FLT4)(1.0f), tmp[1]); + sum += dot((FLT4)(1.0f), tmp[2]); + sum += dot((FLT4)(1.0f), tmp[3]); + sum += dot((FLT4)(1.0f), tmp[4]); + sum += dot((FLT4)(1.0f), tmp[5]); + sum += dot((FLT4)(1.0f), tmp[6]); + sum += dot((FLT4)(1.0f), tmp[7]); + tmpx1[0] = 1.0f / sum; + } + barrier(CLK_LOCAL_MEM_FENCE); + sum = tmpx1[0]; + for (size_t i = tid; i < slices - 1; i += 32) { + FLT4 result = READ_IMAGE(input, smp_zero, (int2)(i, 0)); + result = exp(result) * sum; + output[i] = result; + } + if ((slices - 1) % 32 == tid) { + FLT4 result = READ_IMAGE(input, smp_zero, (int2)(slices - 1, 0)); + result = exp(result) * sum; + __global FLT4 *remain_ptr4 = output; + remain_ptr4 += slices - 1; + __global FLT *remain_ptr = (__global FLT *)remain_ptr4; + remain_ptr[0] = result.x; + if (mask.y > 0.f) { + remain_ptr[1] = result.y; + } + if (mask.z > 0.f) { + remain_ptr[2] = result.z; + } + if (mask.w > 0.f) { + remain_ptr[3] = result.w; + } } } diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/softmax1x1.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/softmax1x1.cl deleted file mode 100644 index 68cdad1f70..0000000000 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/softmax1x1.cl +++ /dev/null @@ -1,104 +0,0 @@ -__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; -__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; -// what is mask and args.slices_x32 -__kernel void SoftMax1x1_IMG(__read_only image2d_t input, __write_only image2d_t output, const float4 mask, - const int slices, const int slices_x32) { - int tid = get_local_id(0); - int slices_count = 0; - int offset = 0; - float sum = 0.0f; - do { - int z = offset + tid; - if (z < slices) { - float4 mask_temp = z == slices - 1 ? 
mask : (float4)(1.0f); - float4 src = read_imagef(input, smp_none, (int2)(0, 0)); - sum += dot(mask_temp, exp(src)); - offset += 32; - } - slices_count++; - } while (slices_count < slices_x32); - - __local float4 tmp[8]; - __local float *tmpx1 = (__local float *)tmp; - tmpx1[tid] = sum; - barrier(CLK_LOCAL_MEM_FENCE); - if (tid == 0) { - sum = dot((float4)(1.0f), tmp[0]); - sum += dot((float4)(1.0f), tmp[1]); - sum += dot((float4)(1.0f), tmp[2]); - sum += dot((float4)(1.0f), tmp[3]); - sum += dot((float4)(1.0f), tmp[4]); - sum += dot((float4)(1.0f), tmp[5]); - sum += dot((float4)(1.0f), tmp[6]); - sum += dot((float4)(1.0f), tmp[7]); - tmpx1[0] = 1.0f / sum; - } - barrier(CLK_LOCAL_MEM_FENCE); - sum = tmpx1[0]; - - offset = 0; - slices_count = 0; - do { - int z = offset + tid; - if (z < slices) { - float4 res = convert_float4(exp(read_imagef(input, smp_none, (int2)(0, 0))) * sum); - write_imagef(output, (int2)(0, 0), res); - offset += 32; - } - slices_count++; - } while (slices_count < slices_x32); -} - -__kernel void SoftMax1x1_BUF(__read_only image2d_t input, __global float4 *output, const float4 mask, const int slices, - const int slices_x32) { - int tid = get_local_id(0); - float sum = 0.0f; - for (size_t i = tid; i < slices - 1; i += 32) { - float4 src = read_imagef(input, smp_zero, (int2)(i, 0)); - sum += dot((float4)(1.0f), exp(src)); - } - if ((slices - 1) % 32 == tid) { - float4 src = read_imagef(input, smp_zero, (int2)(slices - 1, 0)); - sum += dot(mask, exp(src)); - } - - __local float4 tmp[8]; - __local float *tmpx1 = (__local float *)tmp; - tmpx1[tid] = sum; - barrier(CLK_LOCAL_MEM_FENCE); - if (tid == 0) { - sum = dot((float4)(1.0f), tmp[0]); - sum += dot((float4)(1.0f), tmp[1]); - sum += dot((float4)(1.0f), tmp[2]); - sum += dot((float4)(1.0f), tmp[3]); - sum += dot((float4)(1.0f), tmp[4]); - sum += dot((float4)(1.0f), tmp[5]); - sum += dot((float4)(1.0f), tmp[6]); - sum += dot((float4)(1.0f), tmp[7]); - tmpx1[0] = 1.0f / sum; - } - barrier(CLK_LOCAL_MEM_FENCE); - sum = tmpx1[0]; - for (size_t i = tid; i < slices - 1; i += 32) { - float4 result = read_imagef(input, smp_zero, (int2)(i, 0)); - result = exp(result) * sum; - output[i] = result; - } - if ((slices - 1) % 32 == tid) { - float4 result = read_imagef(input, smp_zero, (int2)(slices - 1, 0)); - result = exp(result) * sum; - __global float4 *remain_ptr4 = output; - remain_ptr4 += slices - 1; - __global float *remain_ptr = (__global float *)remain_ptr4; - remain_ptr[0] = result.x; - if (mask.y > 0.f) { - remain_ptr[1] = result.y; - } - if (mask.z > 0.f) { - remain_ptr[2] = result.z; - } - if (mask.w > 0.f) { - remain_ptr[3] = result.w; - } - } -} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/transpose.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/transpose.cl index 0076b5fdb5..ac11eaa1e3 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/transpose.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/transpose.cl @@ -1,4 +1,6 @@ +#ifdef cl_khr_fp16 #pragma OPENCL EXTENSION cl_khr_fp16 : enable +#endif __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; __kernel void transpose_IMG(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 HW, int2 C) { int X = get_global_id(0); @@ -41,7 +43,7 @@ __kernel void transpose_IMG(__read_only image2d_t src_data, __write_only image2d WRITE_IMAGE(dst_data, (int2)(X, 4 * Y + 3), result[3]); } -__kernel void transpose_BUF(__read_only image2d_t src_data, global FLT4 *dst_data, int2 HW, int2 C) { +__kernel 
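The rewritten reads below treat the source as an NHWC4 image instead of a flat H*W column, which is why transpose_BUF now takes W as an extra argument: element hw = 4 * X + i at channel slice Y lives at image coordinate (w * C4 + slice, h). A small sketch of that coordinate computation, illustrative only and not part of the patch:

// Image coordinate of (hw, slice) for an NHWC4 tensor stored as an image of
// height H and width W * C4; mirrors (4 * X + i) % W * C.y + Y and (4 * X + i) / W.
struct ImageCoord {
  int x;
  int y;
};

ImageCoord Nhwc4Coord(int hw, int slice, int w, int c4) {
  return {(hw % w) * c4 + slice, hw / w};
}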
void transpose_BUF(__read_only image2d_t src_data, global FLT4 *dst_data, int2 HW, int2 C, int W) { int X = get_global_id(0); int Y = get_global_id(1); if (X >= HW.y || Y >= C.y) { @@ -52,10 +54,10 @@ __kernel void transpose_BUF(__read_only image2d_t src_data, global FLT4 *dst_dat result[1] = (FLT4)(0.0f); result[2] = (FLT4)(0.0f); result[3] = (FLT4)(0.0f); - FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X)); - FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 1)); - FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 2)); - FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * X + 3)); + FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)((4 * X) % W * C.y + Y, (4 * X) / W)); + FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)((4 * X + 1) % W * C.y + Y, (4 * X + 1) / W)); + FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)((4 * X + 2) % W * C.y + Y, (4 * X + 2) / W)); + FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)((4 * X + 3) % W * C.y + Y, (4 * X + 3) / W)); result[0].x = x0.x; result[0].y = x1.x; result[0].z = x2.x; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc index baabd30a72..1e1c18e6eb 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc @@ -65,7 +65,8 @@ int PoolingOpenCLKernel::Init() { kernel_ = ocl_runtime->GetKernelFromBinary(kernel_name); #else if (out_mem_type_ == OpenCLMemType::BUF) { - kernel_name += "_BUF"; + MS_LOG(ERROR) << "buffer output not support yet."; + return RET_ERROR; } else { kernel_name += "_IMG"; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc index 3254d0758d..a504d6a84e 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc @@ -68,10 +68,16 @@ int ReshapeOpenCLKernel::ReSize() { return RET_OK; } int ReshapeOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { size_t im_dst_x, im_dst_y; - std::vector shapex = in_tensors_[0]->shape(); - int h = shapex[1]; - int w = shapex[2]; - int c = shapex[3]; + std::vector shapex = out_tensors_[0]->shape(); + int h, w, c; + if (shapex.size() == 2) { + h = w = 1; + c = shapex[1]; + } else { + h = shapex[1]; + w = shapex[2]; + c = shapex[3]; + } im_dst_x = w * UP_DIV(c, C4NUM); im_dst_y = h; size_t img_dtype = CL_FLOAT; @@ -91,13 +97,23 @@ int ReshapeOpenCLKernel::Run() { int w = shapex[2]; int c = shapex[3]; int c4 = UP_DIV(c, C4NUM); + int oh, ow; + if (out_tensors_[0]->shape().size() == 2) { + oh = ow = 1; + } else { + oh = out_tensors_[0]->shape()[1]; + ow = out_tensors_[0]->shape()[2]; + } auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); std::vector local = {}; - std::vector global = {(size_t)h, (size_t)w, (size_t)c4}; + std::vector global = {(size_t)oh, (size_t)ow, (size_t)c4}; cl_int4 size = {h, w, c4, 1}; - ocl_runtime->SetKernelArg(kernel_, 0, in_tensors_[0]->Data()); - ocl_runtime->SetKernelArg(kernel_, 1, out_tensors_[0]->Data()); - ocl_runtime->SetKernelArg(kernel_, 2, size); + cl_int4 size_out = {oh, ow, c4, 1}; + int arg_idx = 0; + ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, size); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, size_out); ocl_runtime->RunKernel(kernel_, global, 
local, nullptr); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc index 9422c08e60..39a869bf5a 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc @@ -23,7 +23,6 @@ #include "src/runtime/kernel/opencl/utils.h" #ifndef PROGRAM_WITH_IL #include "src/runtime/kernel/opencl/cl/softmax.cl.inc" -#include "src/runtime/kernel/opencl/cl/softmax1x1.cl.inc" #endif using mindspore::kernel::KERNEL_ARCH::kGPU; @@ -42,8 +41,8 @@ std::vector SoftmaxOpenCLKernel::GetMaskForLastChannel(int channels) { } int SoftmaxOpenCLKernel::InitGlobalSize() { - const size_t global_x = out_tensors_[0]->Height(); - const size_t global_y = out_tensors_[0]->Width(); + const size_t global_x = out_tensors_[0]->shape()[1]; + const size_t global_y = out_tensors_[0]->shape()[2]; const size_t global_z = 1; global_size_ = {global_x, global_y, global_z}; return lite::RET_OK; @@ -74,11 +73,10 @@ int SoftmaxOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) im_dst_x = out_tensors_[0]->Width() * CO4; im_dst_y = out_tensors_[0]->Height(); } -#ifdef ENABLE_FP16 - size_t img_dtype = CL_HALF_FLOAT; -#else size_t img_dtype = CL_FLOAT; -#endif + if (enable_fp16_) { + img_dtype = CL_HALF_FLOAT; + } img_size->clear(); std::vector vec{im_dst_x, im_dst_y, img_dtype}; *img_size = vec; @@ -90,27 +88,28 @@ int SoftmaxOpenCLKernel::Init() { std::string program_name = "SoftMax"; std::string source = softmax_source; runtime_ = lite::opencl::OpenCLRuntime::GetInstance(); + enable_fp16_ = runtime_->GetFp16Enable(); // framework not set this param yet! just use default. - if (parameter_->axis_ == -1) { - parameter_->axis_ = 1; - } - if (in_tensors_[0]->shape().size() == 4 && parameter_->axis_ == 3) { + if (in_tensors_[0]->shape().size() == 4) { // support 4d tensor onexone_flag_ = false; - } else if (in_tensors_[0]->shape().size() == 2 && parameter_->axis_ == 1) { + } else if (in_tensors_[0]->shape().size() == 2) { // support 2d tensor kernel_name += "1x1"; program_name += "1x1"; - source = softmax1x1_source; onexone_flag_ = true; } else { - MS_LOG(EXCEPTION) << "Init `Softmax` kernel failed: Unsupported axis: " << parameter_->axis_; + MS_LOG(ERROR) << "Init `Softmax` kernel failed: Unsupported shape size: " << in_tensors_[0]->shape().size(); + return RET_ERROR; } #ifdef PROGRAM_WITH_IL kernel_ = ocl_runtime->GetKernelFromBinary(kernel_name); #else if (!is_image_out_) { out_mem_type_ = OpenCLMemType::BUF; + } else { + MS_LOG(ERROR) << "image2d output not support yet."; + return RET_ERROR; } if (out_mem_type_ == OpenCLMemType::BUF) { kernel_name += "_BUF"; @@ -124,12 +123,23 @@ int SoftmaxOpenCLKernel::Init() { runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options); #endif in_ori_format_ = in_tensors_[0]->GetFormat(); - in_tensors_[0]->SetFormat(schema::Format_NHWC4); out_ori_format_ = out_tensors_[0]->GetFormat(); - out_tensors_[0]->SetFormat(schema::Format_NHWC4); - if (!is_image_out_) { - out_ori_format_ = schema::Format_NC; - out_tensors_[0]->SetFormat(schema::Format_NC); + if (in_tensors_[0]->shape().size() == 2) { + in_tensors_[0]->SetFormat(schema::Format_NC4); + } else { + in_tensors_[0]->SetFormat(schema::Format_NHWC4); + } + + if (is_image_out_) { + if (out_tensors_[0]->shape().size() == 2) { + out_ori_format_ = schema::Format_NC; + out_tensors_[0]->SetFormat(schema::Format_NC4); + } else { + out_ori_format_ = schema::Format_NHWC; + 
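Both SoftMax1x1 kernels depend on the per-lane channel mask that Run() passes below: 1.0 for lanes of the last C4 slice that hold a real channel, 0.0 for padding, so dot(mask, exp(src)) keeps padded lanes out of the sum. The body of GetMaskForLastChannel is not shown in this hunk; the following is only a sketch consistent with how the mask is used:

#include <vector>

// Hypothetical reconstruction of the last-slice mask, e.g. channels = 10 gives
// {1, 1, 0, 0}; an exact multiple of four gives all ones.
std::vector<float> LastSliceMask(int channels) {
  std::vector<float> mask(4, 0.0f);
  int valid = channels % 4 == 0 ? 4 : channels % 4;
  for (int i = 0; i < valid; ++i) {
    mask[i] = 1.0f;
  }
  return mask;
}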
out_tensors_[0]->SetFormat(schema::Format_NHWC4); + } + } else { + out_tensors_[0]->SetFormat(out_ori_format_); } MS_LOG(DEBUG) << kernel_name << " Init Done!"; return lite::RET_OK; @@ -147,17 +157,25 @@ int SoftmaxOpenCLKernel::Run() { cl_float4 mask = {mask_[0], mask_[1], mask_[2], mask_[3]}; runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); - runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); + if (is_image_out_) { + runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); + } else { + runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data(), lite::opencl::MemType::BUF); + } runtime_->SetKernelArg(kernel_, arg_idx++, mask); runtime_->SetKernelArg(kernel_, arg_idx++, slices); runtime_->SetKernelArg(kernel_, arg_idx, slices_x32); SetWorkGroupSize1x1(); } else { - int slices = UP_DIV(out_tensors_[0]->Channel(), C4NUM); - cl_int4 input_shape = {in_tensors_[0]->Height(), in_tensors_[0]->Width(), in_tensors_[0]->Channel(), slices}; + int slices = UP_DIV(out_tensors_[0]->shape()[3], C4NUM); + cl_int4 input_shape = {in_tensors_[0]->shape()[1], in_tensors_[0]->shape()[2], in_tensors_[0]->shape()[3], slices}; runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); - runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); + if (is_image_out_) { + runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); + } else { + runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data(), lite::opencl::MemType::BUF); + } runtime_->SetKernelArg(kernel_, arg_idx, input_shape); SetWorkGroupSize(); } @@ -193,4 +211,5 @@ kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector local_size_; std::vector global_size_; bool is_image_out_{false}; + bool enable_fp16_{false}; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc index 93b67fba4e..7932c313c3 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc @@ -119,11 +119,9 @@ int ToFormatOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size im_dst_x = w * UP_DIV(c, C4NUM); im_dst_y = h; } else if (out_tensors_[0]->GetFormat() == schema::Format_NC4) { - const int h = 1; - const int w = 1; int c = shapex[1]; - im_dst_x = w * UP_DIV(c, C4NUM); - im_dst_y = h; + im_dst_x = UP_DIV(c, C4NUM); + im_dst_y = 1; } else { MS_LOG(ERROR) << "Unsupported format. 
" << out_tensors_[0]->GetFormat(); return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc index d5b679630b..0d78330357 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc @@ -69,7 +69,7 @@ int TransposeOpenCLKernel::ReSize() { return RET_OK; } int TransposeOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { size_t im_dst_x, im_dst_y; - im_dst_x = UP_DIV(out_tensors_[0]->Height() * out_tensors_[0]->Width(), C4NUM); + im_dst_x = out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Width(), C4NUM); im_dst_y = out_tensors_[0]->Channel(); size_t img_dtype = CL_FLOAT; if (enable_fp16_) { @@ -96,10 +96,12 @@ int TransposeOpenCLKernel::Run() { cl_int2 HW = {h * w, hw4}; cl_int2 C = {c, c4}; - ocl_runtime->SetKernelArg(kernel_, 0, in_tensors_[0]->Data()); - ocl_runtime->SetKernelArg(kernel_, 1, out_tensors_[0]->Data()); - ocl_runtime->SetKernelArg(kernel_, 2, HW); - ocl_runtime->SetKernelArg(kernel_, 3, C); + int arg_idx = 0; + ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, HW); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, C); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, w); ocl_runtime->RunKernel(kernel_, global, local, nullptr); return RET_OK; } diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc index 0172df1a31..6c21bc61c8 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc @@ -86,14 +86,14 @@ void RunTestCaseReshape(const std::vector &shape, void *input_data, void *o inputs[0]->SetData(nullptr); outputs[0]->SetData(nullptr); - MS_LOG(INFO) << "Test ReshapeFp32 passed"; + MS_LOG(INFO) << "Test Reshape passed"; lite::opencl::OpenCLRuntime::DeleteInstance(); } TEST_F(TestReshapeOpenCL, ReshapeFp32) { int c = 7; std::vector shape = {c}; - std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; std::vector output_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; RunTestCaseReshape(shape, input_data.data(), output_data.data(), false); @@ -102,7 +102,7 @@ TEST_F(TestReshapeOpenCL, ReshapeFp32) { TEST_F(TestReshapeOpenCL, ReshapeFp16) { int c = 7; std::vector shape = {c}; - std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; std::vector output_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; RunTestCaseReshape(shape, input_data.data(), output_data.data(), true); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc index 2913d322d9..fc2ce9dd48 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc @@ -17,94 +17,134 @@ #include #include "mindspore/core/utils/log_adapter.h" #include "common/common_test.h" +#include "mindspore/lite/src/common/file_utils.h" #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" 
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h" #include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h" namespace mindspore { +class TestSoftmaxOpenCL : public mindspore::CommonTest { + public: + TestSoftmaxOpenCL() {} +}; -class TestSoftmaxOpenCL : public mindspore::CommonTest {}; - -void RunTestCase(std::vector input_shape, std::vector output_shape, std::string input_file, - std::string expect_file, SoftmaxParameter *param, schema::Format format) { +void RunTestCaseSoftmax(const std::vector &shape, void *input_data, void *output_data, bool enable_fp16) { auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); ocl_runtime->Init(); + size_t dtype_size = sizeof(float); + if (enable_fp16) { + ocl_runtime->SetFp16Enable(true); + dtype_size = sizeof(float16_t); + } auto allocator = ocl_runtime->GetAllocator(); - - // define tensor - MS_LOG(INFO) << "defineTensor"; - auto data_type = kNumberTypeFloat32; - auto tensorType = schema::NodeType_ValueNode; - auto input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, format, tensorType); - auto output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, format, tensorType); - if (input_tensor == nullptr) { - MS_LOG(ERROR) << "input tensor null"; + int n, h, w, c; + bool is_2d = false; + if (shape.size() == 2) { + is_2d = true; + h = w = 1; + n = shape[0]; + c = shape[1]; + } else { + n = shape[0]; + h = shape[1]; + w = shape[2]; + c = shape[3]; + } + std::vector input_shape = {n, h, w, c}; + if (is_2d) { + input_shape = {n, c}; + } + auto input_format = is_2d ? schema::Format_NC : schema::Format_NHWC; + auto input_dtype = enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32; + auto tensor_x_ptr = std::make_unique(TypeId(input_dtype), input_shape, input_format); + auto tensor_x = tensor_x_ptr.get(); + if (tensor_x == nullptr) { + MS_LOG(ERROR) << "tensor_x create error."; return; } - if (output_tensor == nullptr) { - MS_LOG(ERROR) << "output tensor null"; + auto tensor_out_ptr = std::make_unique(TypeId(input_dtype), input_shape, input_format); + auto tensor_out = tensor_out_ptr.get(); + if (tensor_out == nullptr) { + MS_LOG(ERROR) << "tensor_out create error."; return; } - std::vector inputs{input_tensor}; - std::vector outputs{output_tensor}; - - // run - MS_LOG(INFO) << "NewOpenCLKernel"; - auto *kernel = new kernel::SoftmaxOpenCLKernel(reinterpret_cast(param), inputs, outputs); - if (kernel == nullptr) { - MS_LOG(ERROR) << "kernel null"; + std::vector inputs{tensor_x}; + std::vector outputs{tensor_out}; + auto arith_kernel_ptr = std::make_unique(nullptr, inputs, outputs); + auto arith_kernel = arith_kernel_ptr.get(); + if (arith_kernel == nullptr) { + MS_LOG(ERROR) << "arith_kernel create error."; return; } - MS_LOG(INFO) << "KernelInit"; - kernel->Init(); + arith_kernel->Init(); - std::vector kernels{kernel}; inputs[0]->MallocData(allocator); - auto *pGraph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + + std::vector kernels{arith_kernel}; + auto pGraph_ptr = std::make_unique(inputs, outputs, kernels, kernels, kernels); + auto pGraph = pGraph_ptr.get(); if (pGraph == nullptr) { - MS_LOG(ERROR) << "pGraph null"; + MS_LOG(ERROR) << "pGraph create error."; return; } - MS_LOG(INFO) << "pGraphinit"; pGraph->Init(); - - // load data - MS_LOG(INFO) << "load data1"; - LoadTestData(input_tensor->Data(), input_tensor->Size(), input_file); - auto *input_data = reinterpret_cast(input_tensor->Data()); - 
printf("\ninput[0:10]:"); - for (int i = 0; i < 10; i++) { - printf("[%d]:%.3f ", i, input_data[i]); - } - printf("\n\n"); - - MS_LOG(INFO) << "Run"; + memcpy(inputs[0]->Data(), input_data, inputs[0]->ElementsNum() * dtype_size); pGraph->Run(); - MS_LOG(INFO) << "compare result"; - CompareOutput(output_tensor, expect_file, static_cast(1e-5)); - for (auto tensor : inputs) { - delete tensor; + if (enable_fp16) { + CompareOutput(outputs[0]->Data(), output_data, outputs[0]->ElementsNum(), static_cast(1e-3), 2e-2); + } else { + CompareOutput(outputs[0]->Data(), output_data, outputs[0]->ElementsNum(), static_cast(1e-5)); } - for (auto tensor : outputs) { - delete tensor; - } - delete kernel; - delete pGraph; + inputs[0]->SetData(nullptr); + outputs[0]->SetData(nullptr); + + MS_LOG(INFO) << "Test Softmax passed"; lite::opencl::OpenCLRuntime::DeleteInstance(); } -TEST_F(TestSoftmaxOpenCL, Softmax_1) { - std::vector input_shape = {1, 2, 2, 8}; - std::vector output_shape = {1, 2, 2, 8}; - std::string input_file = "softmax_in.bin"; - std::string expect_file = "softmax_out.bin"; - auto param = new (std::nothrow) SoftmaxParameter; - param->axis_ = 3; - schema::Format format = schema::Format_NHWC4; +TEST_F(TestSoftmaxOpenCL, Softmax2DFp32) { + int n = 1; + int c = 10; + std::vector shape = {n, c}; + std::vector input_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + std::vector output_data = {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f}; + + RunTestCaseSoftmax(shape, input_data.data(), output_data.data(), false); +} + +TEST_F(TestSoftmaxOpenCL, Softmax2DFp16) { + int n = 1; + int c = 10; + std::vector shape = {n, c}; + std::vector input_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + std::vector output_data = {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f}; - RunTestCase(input_shape, output_shape, input_file, expect_file, param, format); + RunTestCaseSoftmax(shape, input_data.data(), output_data.data(), true); } +TEST_F(TestSoftmaxOpenCL, Softmax4DFp32) { + int n = 1; + int h = 2; + int w = 1; + int c = 5; + std::vector shape = {n, h, w, c}; + std::vector input_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + std::vector output_data = {0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f}; + + RunTestCaseSoftmax(shape, input_data.data(), output_data.data(), false); +} + +TEST_F(TestSoftmaxOpenCL, Softmax4DFp16) { + int n = 1; + int h = 2; + int w = 1; + int c = 5; + std::vector shape = {n, h, w, c}; + std::vector input_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + std::vector output_data = {0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f}; + + RunTestCaseSoftmax(shape, input_data.data(), output_data.data(), true); +} } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc index 925b45de20..c0dbcf3b22 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc @@ -117,8 +117,8 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) { } TEST_F(TestTransposeOpenCL, TransposeFp16) { - int h = 4; - int w = 1; + int h = 2; + int w = 2; int c = 3; std::vector shape = {h, w, c}; std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};