| @@ -190,8 +190,8 @@ if (PLATFORM_ARM64) | |||
| endif () | |||
| if (BUILD_MINDDATA STREQUAL "lite" OR BUILD_MINDDATA STREQUAL "full") | |||
| # TODO: add sentencepiece dependency | |||
| #include(${TOP_DIR}/cmake/external_libs/sentencepiece.cmake) | |||
| # add sentencepiece dependency | |||
| # include(${TOP_DIR}/cmake/external_libs/sentencepiece.cmake) | |||
| # opencv | |||
| set(OpenCV_DIR ${TOP_DIR}/third_party/opencv/build) | |||
| find_package(OpenCV REQUIRED) | |||
| @@ -96,7 +96,7 @@ endif () | |||
| if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND PLATFORM_ARM) | |||
| add_custom_command(TARGET mindspore-lite POST_BUILD | |||
| COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip | |||
| ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so) | |||
| ${CMAKE_BINARY_DIR}/src/libmindspore-lite.so) | |||
| endif () | |||
| if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release") | |||
| @@ -124,10 +124,10 @@ endif () | |||
| if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND (PLATFORM_ARM64)) | |||
| add_custom_command(TARGET mindspore-lite-optimize POST_BUILD COMMAND | |||
| ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip | |||
| ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite-optimize.so) | |||
| ${CMAKE_BINARY_DIR}/src/libmindspore-lite-optimize.so) | |||
| add_custom_command(TARGET mindspore-lite-fp16 POST_BUILD COMMAND | |||
| ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip | |||
| ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite-fp16.so) | |||
| ${CMAKE_BINARY_DIR}/src/libmindspore-lite-fp16.so) | |||
| endif () | |||
| @@ -1,41 +1,80 @@ | |||
| #pragma OPENCL EXTENSION cl_khr_fp16 : enable | |||
| #define SLICES 4 | |||
| #define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) | |||
| __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | |||
| __kernel void PRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, | |||
| __read_only image2d_t alpha, const int data_type, const int bias_dim) { | |||
| int H = input_shape.y; | |||
| int C = input_shape.w; // channel size | |||
| C = UP_DIV(C, SLICES); | |||
| if (C == 0 || H == 0) { | |||
| #define NHWC4 2 | |||
| #define NC4HW4 100 | |||
| __kernel void PRelu_scalar(__read_only image2d_t input, __write_only image2d_t output, float weight, int4 shape, | |||
| int data_format) { | |||
| int h = get_global_id(0); | |||
| int w = get_global_id(1); | |||
| int slice = get_global_id(2); | |||
| int H = shape.y; | |||
| int W = shape.z; | |||
| int SLICES = shape.w; | |||
| if (h >= H || w >= W || slice >= SLICES) { | |||
| return; | |||
| } | |||
| int Y = get_global_id(0); // height id | |||
| int X = get_global_id(1); // weight id | |||
| FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y)); | |||
| FLT4 tmp; | |||
| int index = 0; | |||
| if (data_type == 1) { // NHWC4 | |||
| index = X % C; | |||
| } else if (data_type == 2) { // NC4HW4 | |||
| index = Y / H; | |||
| int x, y; | |||
| if (data_format == 2) { | |||
| x = w * SLICES + slice; | |||
| y = h; | |||
| } else { | |||
| x = w; | |||
| y = slice * H + h; | |||
| } | |||
| FLT4 out = READ_IMAGE(input, smp_zero, (int2)(x, y)); | |||
| if (out.x < 0) { | |||
| out.x *= weight; | |||
| } | |||
| if (out.y < 0) { | |||
| out.y *= weight; | |||
| } | |||
| if (out.z < 0) { | |||
| out.z *= weight; | |||
| } | |||
| if (out.w < 0) { | |||
| out.w *= weight; | |||
| } | |||
| WRITE_IMAGE(output, (int2)(x, y), out); | |||
| } | |||
| __kernel void PRelu_vector(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight_vector, | |||
| int4 shape, int data_format) { | |||
| int h = get_global_id(0); | |||
| int w = get_global_id(1); | |||
| int slice = get_global_id(2); | |||
| int H = shape.y; | |||
| int W = shape.z; | |||
| int SLICES = shape.w; | |||
| if (h >= H || w >= W || slice >= SLICES) { | |||
| return; | |||
| } | |||
| if (bias_dim == 1) { | |||
| index = 0; | |||
| } | |||
| FLT4 weight = READ_IMAGE(alpha, smp_zero, (int2)(index, 0)); | |||
| FLT4 bias = weight; | |||
| if (bias_dim == 1) { | |||
| bias.y = weight.x; | |||
| bias.z = weight.x; | |||
| bias.w = weight.x; | |||
| } | |||
| tmp.x = in_c4.x > 0.0f ? in_c4.x : in_c4.x * bias.x; | |||
| tmp.y = in_c4.y > 0.0f ? in_c4.y : in_c4.y * bias.y; | |||
| tmp.z = in_c4.z > 0.0f ? in_c4.z : in_c4.z * bias.z; | |||
| tmp.w = in_c4.w > 0.0f ? in_c4.w : in_c4.w * bias.w; | |||
| WRITE_IMAGE(output, (int2)(X, Y), tmp); | |||
| FLT4 weight = weight_vector[slice]; | |||
| int x, y; | |||
| if (data_format == 2) { | |||
| x = w * SLICES + slice; | |||
| y = h; | |||
| } else { | |||
| x = w; | |||
| y = slice * H + h; | |||
| } | |||
| FLT4 out = READ_IMAGE(input, smp_zero, (int2)(x, y)); | |||
| if (out.x < 0) { | |||
| out.x *= weight.x; | |||
| } | |||
| if (out.y < 0) { | |||
| out.y *= weight.y; | |||
| } | |||
| if (out.z < 0) { | |||
| out.z *= weight.z; | |||
| } | |||
| if (out.w < 0) { | |||
| out.w *= weight.w; | |||
| } | |||
| WRITE_IMAGE(output, (int2)(x, y), out); | |||
| } | |||
| @@ -359,6 +359,8 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNHWC4() { | |||
| code += "#define padBottom " + std::to_string(padBottom) + "\n"; | |||
| code += "#define padLeft " + std::to_string(padLeft) + "\n"; | |||
| code += "#define padRight " + std::to_string(padRight) + "\n"; | |||
| code += "#define dilationH " + std::to_string(param->dilation_h_) + "\n"; | |||
| code += "#define dilationW " + std::to_string(param->dilation_w_) + "\n"; | |||
| code += "#define CI_SLICES " + std::to_string(CI_SLICES_) + "\n"; | |||
| code += "#define CO_SLICES " + std::to_string(CO_SLICES_) + "\n\n"; | |||
| @@ -398,10 +400,10 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNHWC4() { | |||
| code += | |||
| " for (int kh = 0; kh < KH; ++kh)\n" | |||
| " {\n" | |||
| " int ih = kh + oh * strideH - padTop;\n" | |||
| " int ih = kh * dilationH + oh * strideH - padTop;\n" | |||
| " for (int kw = 0; kw < KW; ++kw)\n" | |||
| " {\n" | |||
| " int iw = kw + ow * strideW - padLeft;\n" | |||
| " int iw = kw * dilationW + ow * strideW - padLeft;\n" | |||
| " if (ih >= 0 && ih < IH && iw >= 0 && iw < IW)\n" | |||
| " {\n" | |||
| " for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++)\n" | |||
| @@ -491,7 +493,9 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNC4HW4() { | |||
| code += " #define strideH " + std::to_string(strideH) + "\n"; | |||
| code += " #define strideW " + std::to_string(strideW) + "\n"; | |||
| code += " #define padTop " + std::to_string(padTop) + "\n"; | |||
| code += " #define padLeft " + std::to_string(padLeft) + "\n\n"; | |||
| code += " #define padLeft " + std::to_string(padLeft) + "\n"; | |||
| code += " #define dilationH " + std::to_string(param->dilation_h_) + "\n"; | |||
| code += " #define dilationW " + std::to_string(param->dilation_w_) + "\n"; | |||
| code += | |||
| " if (n_oh >= N_OH || ow >= OW || co_slice >= CO_SLICES) {\n" | |||
| @@ -513,7 +517,7 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNC4HW4() { | |||
| "\n" | |||
| " for (int kh = 0; kh < KH; ++kh)\n" | |||
| " {\n" | |||
| " int ih = kh + oh * strideH - padTop;\n" | |||
| " int ih = kh * dilationH + oh * strideH - padTop;\n" | |||
| " for (int kw = 0; kw < KW; ++kw)\n" | |||
| " {\n"; | |||
| @@ -523,7 +527,7 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNC4HW4() { | |||
| "{\n"; | |||
| } | |||
| code += " int iw0 = kw + (ow + 0) * strideW - padLeft;\n"; | |||
| code += " int iw0 = kw * dilationW + (ow + 0) * strideW - padLeft;\n"; | |||
| if (check_ow) { | |||
| code += | |||
| " if (last_is_double)\n" | |||
| @@ -531,7 +535,7 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNC4HW4() { | |||
| } | |||
| code += | |||
| " int iw1 = kw + (ow + 1) * strideW - padLeft;\n" | |||
| " int iw1 = kw * dilationW + (ow + 1) * strideW - padLeft;\n" | |||
| " for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++)\n" | |||
| " {\n" | |||
| " FLT4 in0 = READ_IMAGE(input, smp_zero, (int2)(iw0, (n * CI_SLICES + ci_slice) * IH + ih));\n" | |||
| @@ -916,4 +920,5 @@ kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vector<lite::Tenso | |||
| } | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Conv2D, OpenCLConvolutionKernelCreator) | |||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Conv2D, OpenCLConvolutionKernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -18,7 +18,6 @@ | |||
| #include <set> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| @@ -36,85 +35,116 @@ namespace mindspore::kernel { | |||
| void PReluOpenCLKernel::InitBuffer() { | |||
| auto allocator = ocl_runtime_->GetAllocator(); | |||
| int elem_num = in_tensors_[0]->shape().size() == 2 ? in_tensors_[0]->shape()[1] : in_tensors_[0]->shape()[3]; | |||
| int elem_num_c4 = UP_DIV(elem_num, C4NUM); | |||
| size_t img_dtype = CL_FLOAT; | |||
| if (enable_fp16_) { | |||
| img_dtype = CL_HALF_FLOAT; | |||
| } | |||
| std::vector<size_t> img_size{size_t(elem_num_c4), 1, img_dtype}; | |||
| PReluWeight_ = allocator->Malloc(elem_num_c4 * C4NUM * fp_size, img_size); | |||
| PReluWeight_ = allocator->MapBuffer(PReluWeight_, CL_MAP_WRITE, nullptr, true); | |||
| memset(PReluWeight_, 0x00, elem_num_c4 * C4NUM * fp_size); | |||
| if (enable_fp16_) { | |||
| if (in_tensors_[1]->data_type() == kNumberTypeFloat32) { | |||
| auto PReluWeight_fp16 = reinterpret_cast<uint16_t *>(PReluWeight_); | |||
| auto in_tensor_data_fp32 = reinterpret_cast<float *>(in_tensors_[1]->data_c()); | |||
| for (int i = 0; i < elem_num; i++) { | |||
| PReluWeight_fp16[i] = static_cast<float16_t>(in_tensor_data_fp32[i]); | |||
| } | |||
| auto weight_tensor = in_tensors_[1]; | |||
| if (weight_is_scalar) { | |||
| if (weight_tensor->data_type() == kNumberTypeFloat16) { | |||
| weight_scalar_ = static_cast<float>(*reinterpret_cast<float16_t *>(weight_tensor->data_c())); | |||
| } else { | |||
| memcpy(PReluWeight_, in_tensors_[1]->data_c(), elem_num * fp_size); | |||
| weight_scalar_ = *reinterpret_cast<float *>(weight_tensor->data_c()); | |||
| } | |||
| } else { | |||
| if (in_tensors_[1]->data_type() == kNumberTypeFloat16) { | |||
| auto PReluWeight_fp32 = reinterpret_cast<float *>(PReluWeight_); | |||
| auto in_tensor_data_fp16 = reinterpret_cast<float16_t *>(in_tensors_[1]->data_c()); | |||
| for (int i = 0; i < elem_num; i++) { | |||
| PReluWeight_fp32[i] = static_cast<float>(in_tensor_data_fp16[i]); | |||
| auto sizeof_FLT = enable_fp16_ ? sizeof(float16_t) : sizeof(float); | |||
| size_t weight_size = UP_ROUND(C_, C4NUM) * sizeof_FLT; | |||
| weight_vector_ = allocator->Malloc(weight_size); | |||
| allocator->MapBuffer(weight_vector_, CL_MAP_WRITE, nullptr, true); | |||
| memset(weight_vector_, 0x00, weight_size); | |||
| if (weight_tensor->data_type() == kNumberTypeFloat16) { | |||
| if (enable_fp16_) { | |||
| memcpy(weight_vector_, weight_tensor->data_c(), C_ * sizeof_FLT); | |||
| } else { | |||
| auto weight_fp32 = reinterpret_cast<float *>(weight_vector_); | |||
| auto origin_bias_fp16 = reinterpret_cast<float16_t *>(weight_tensor->data_c()); | |||
| for (int i = 0; i < C_; ++i) { | |||
| weight_fp32[i] = static_cast<float>(origin_bias_fp16[i]); | |||
| } | |||
| } | |||
| } else { | |||
| memcpy(PReluWeight_, in_tensors_[1]->data_c(), elem_num * fp_size); | |||
| if (enable_fp16_) { | |||
| auto weight_fp16 = reinterpret_cast<float16_t *>(weight_vector_); | |||
| auto origin_bias_fp32 = reinterpret_cast<float *>(weight_tensor->data_c()); | |||
| for (int i = 0; i < C_; ++i) { | |||
| weight_fp16[i] = static_cast<float16_t>(origin_bias_fp32[i]); | |||
| } | |||
| } else { | |||
| memcpy(weight_vector_, weight_tensor->data_c(), C_ * sizeof_FLT); | |||
| } | |||
| } | |||
| allocator->UnmapBuffer(weight_vector_); | |||
| } | |||
| allocator->UnmapBuffer(PReluWeight_); | |||
| } | |||
| int PReluOpenCLKernel::Init() { | |||
| if (in_tensors_[0]->shape().size() != 4) { | |||
| MS_LOG(ERROR) << "PRelu only support dim=4, but your dim=" << in_tensors_[0]->shape().size(); | |||
| auto input_tensor = in_tensors_[0]; | |||
| auto weight_tensor = in_tensors_[1]; | |||
| if (input_tensor->shape().size() != 4) { | |||
| MS_LOG(ERROR) << "PRelu only support dim=4, but your dim=" << input_tensor->shape().size(); | |||
| return RET_ERROR; | |||
| } | |||
| batch_size_ = input_tensor->Batch(); | |||
| C_ = input_tensor->Channel(); | |||
| H_ = input_tensor->Height(); | |||
| W_ = input_tensor->Width(); | |||
| if (input_tensor->GetFormat() != schema::Format_NC4HW4 && input_tensor->GetFormat() != schema::Format_NHWC4) { | |||
| MS_LOG(ERROR) << "PRelu only support Format_NC4HW4 and Format_NHWC4"; | |||
| return RET_ERROR; | |||
| } | |||
| int C_Weight = in_tensors_[1]->shape()[0]; | |||
| int C = in_tensors_[0]->shape()[3]; | |||
| if (C_Weight != 1 && UP_DIV(C_Weight, C4NUM) != UP_DIV(C, C4NUM)) { | |||
| if (batch_size_ != 1) { | |||
| MS_LOG(ERROR) << "Init PRelu kernel failed: Unsupported multi-batch."; | |||
| return RET_ERROR; | |||
| } | |||
| auto weight_channel = weight_tensor->shape()[0]; | |||
| if (weight_channel != 1 && weight_channel != C_) { | |||
| MS_LOG(ERROR) | |||
| << "PRelu weight channel size must be 1 or must be equal with in_teneors channel size, but your weight size is " | |||
| << C_Weight << " and your input channel size is " << C; | |||
| << weight_channel << " and your input channel size is " << C_; | |||
| return RET_ERROR; | |||
| } | |||
| for (int i = 0; i < in_tensors_[0]->shape().size(); ++i) { | |||
| input_shape_.s[i] = in_tensors_[0]->shape()[i]; | |||
| weight_is_scalar = weight_channel == 1; | |||
| if (weight_tensor->data_type() != kNumberTypeFloat16 && weight_tensor->data_type() != kNumberTypeFloat32) { | |||
| MS_LOG(ERROR) << "PRelu weight must be float32 or float16"; | |||
| return RET_ERROR; | |||
| } | |||
| enable_fp16_ = ocl_runtime_->GetFp16Enable(); | |||
| in_ori_format_ = input_tensor->GetFormat(); | |||
| out_ori_format_ = out_tensors_[0]->GetFormat(); | |||
| input_tensor->SetFormat(op_format_); | |||
| out_tensors_[0]->SetFormat(op_format_); | |||
| std::set<std::string> build_options; | |||
| std::string source = prelu_source; | |||
| std::string program_name = "PRelu"; | |||
| std::string kernel_name = "PRelu"; | |||
| enable_fp16_ = ocl_runtime_->GetFp16Enable(); | |||
| fp_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float); | |||
| InitBuffer(); | |||
| std::string kernel_name = "PRelu_" + std::string(weight_is_scalar ? "scalar" : "vector"); | |||
| ocl_runtime_->LoadSource(program_name, source); | |||
| ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||
| in_ori_format_ = in_tensors_[0]->GetFormat(); | |||
| in_tensors_[0]->SetFormat(op_format_); | |||
| out_ori_format_ = out_tensors_[0]->GetFormat(); | |||
| out_tensors_[0]->SetFormat(op_format_); | |||
| InitBuffer(); | |||
| MS_LOG(DEBUG) << program_name << " init Done!"; | |||
| return RET_OK; | |||
| } | |||
| int PReluOpenCLKernel::Run() { | |||
| MS_LOG(DEBUG) << op_parameter_->name_ << " Running!"; | |||
| std::map<schema::Format, int> data_type{{schema::Format::Format_NHWC4, 1}, {schema::Format::Format_NC4HW4, 2}}; | |||
| auto CO_SLICES_ = UP_DIV(C_, C4NUM); | |||
| cl_int4 shape = {batch_size_, H_, W_, CO_SLICES_}; | |||
| int arg_idx = 0; | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->data_c()); | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c()); | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_idx++, input_shape_); | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_idx++, PReluWeight_); | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_idx++, data_type[op_format_]); | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_idx++, reinterpret_cast<int>(in_tensors_[1]->shape()[0])); | |||
| std::vector<size_t> local = {1, 1}; | |||
| std::vector<size_t> global = {static_cast<size_t>(global_shape_.s[1]), static_cast<size_t>(global_shape_.s[2])}; | |||
| if (weight_is_scalar) { | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_idx++, weight_scalar_); | |||
| } else { | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_idx++, weight_vector_); | |||
| } | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_idx++, shape); | |||
| if (op_format_ == schema::Format_NHWC4) { | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_idx++, 2); | |||
| } else { // Format_NC4HW4 = 100 | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_idx++, 100); | |||
| } | |||
| std::vector<size_t> local = {4, 4, 1}; | |||
| std::vector<size_t> global = {static_cast<size_t>(H_), static_cast<size_t>(W_), static_cast<size_t>(CO_SLICES_)}; | |||
| auto ret = ocl_runtime_->RunKernel(kernel_, global, local, nullptr); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Run kernel " << op_parameter_->name_ << " error."; | |||
| @@ -124,22 +154,26 @@ int PReluOpenCLKernel::Run() { | |||
| } | |||
| int PReluOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||
| size_t img_dtype = CL_FLOAT; | |||
| if (enable_fp16_) { | |||
| img_dtype = CL_HALF_FLOAT; | |||
| } | |||
| global_shape_ = input_shape_; | |||
| if (op_format_ == schema::Format::Format_NC4HW4) { | |||
| global_shape_.s[1] = UP_DIV(input_shape_.s[3], C4NUM) * input_shape_.s[1]; | |||
| } else if (op_format_ == schema::Format::Format_NHWC4) { | |||
| global_shape_.s[2] = UP_DIV(input_shape_.s[3], C4NUM) * input_shape_.s[2]; | |||
| size_t im_dst_x, im_dst_y; | |||
| auto CO_SLICES_ = UP_DIV(C_, C4NUM); | |||
| if (in_tensors_[0]->GetFormat() == schema::Format_NHWC4) { | |||
| if (W_ * CO_SLICES_ <= MAX_IMAGE2D_SIZE) { | |||
| { | |||
| im_dst_y = batch_size_ * H_; | |||
| im_dst_x = W_ * CO_SLICES_; | |||
| } | |||
| } else { | |||
| im_dst_y = W_; | |||
| im_dst_x = batch_size_ * H_ * CO_SLICES_; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "op_format_:" << op_format_ << " is do not support!"; | |||
| return RET_ERROR; | |||
| im_dst_y = batch_size_ * CO_SLICES_ * H_; | |||
| im_dst_x = W_; | |||
| } | |||
| size_t img_dtype = enable_fp16_ ? CL_HALF_FLOAT : CL_FLOAT; | |||
| img_size->clear(); | |||
| img_size->push_back(global_shape_.s[2]); | |||
| img_size->push_back(global_shape_.s[1]); | |||
| img_size->push_back(im_dst_x); | |||
| img_size->push_back(im_dst_y); | |||
| img_size->push_back(img_dtype); | |||
| return RET_OK; | |||
| } | |||
| @@ -152,16 +186,11 @@ kernel::LiteKernel *OpenCLPReluKernelCreator(const std::vector<lite::Tensor *> & | |||
| MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size(); | |||
| return nullptr; | |||
| } | |||
| if (inputs[0]->shape()[0] > 1) { | |||
| MS_LOG(ERROR) << "Init PRelu kernel failed: Unsupported multi-batch."; | |||
| return nullptr; | |||
| } | |||
| auto *kernel = new (std::nothrow) PReluOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init PRelu kernel failed!"; | |||
| @@ -39,11 +39,14 @@ class PReluOpenCLKernel : public OpenCLKernel { | |||
| private: | |||
| cl::Kernel kernel_; | |||
| void *PReluWeight_; | |||
| cl_int4 input_shape_; | |||
| cl_int4 global_shape_; | |||
| size_t fp_size; | |||
| bool enable_fp16_{false}; | |||
| int batch_size_{}; | |||
| int C_{}; | |||
| int H_{}; | |||
| int W_{}; | |||
| void *weight_vector_{nullptr}; | |||
| float weight_scalar_{0.f}; | |||
| bool weight_is_scalar{false}; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -235,14 +235,16 @@ void PrintTensor(lite::Tensor *tensor, int num, const std::string &out_file) { | |||
| if (tensor->data_c() == nullptr) { | |||
| return; | |||
| } | |||
| auto runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); | |||
| auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper(); | |||
| auto runtime = runtime_wrapper.GetInstance(); | |||
| runtime->SyncCommandQueue(); | |||
| auto allocator = runtime->GetAllocator(); | |||
| auto origin_data = tensor->data_c(); | |||
| allocator->MapBuffer(origin_data, CL_MAP_READ, nullptr, true); | |||
| allocator->MapBuffer(origin_data, CL_MAP_READ | CL_MAP_WRITE, nullptr, true); | |||
| tensor->SetData(origin_data); | |||
| auto Batch = tensor->Batch(); | |||
| auto Height = tensor->shape().size() == 4 ? tensor->Height() : 1; | |||
| auto Width = tensor->shape().size() == 4 ? tensor->Width() : 1; | |||
| auto SLICES = UP_DIV(tensor->Channel(), C4NUM); | |||
| @@ -250,17 +252,8 @@ void PrintTensor(lite::Tensor *tensor, int num, const std::string &out_file) { | |||
| auto dtype_size = tensor->data_type() == kNumberTypeFloat16 ? sizeof(cl_half4) : sizeof(cl_float4); | |||
| auto row_pitch = (Width * SLICES + alignment - 1) / alignment * alignment * dtype_size; | |||
| auto row_size = Width * SLICES * dtype_size; | |||
| std::cout << "tensor->GetFormat() =" << tensor->GetFormat() << "\n"; | |||
| std::cout << "Height =" << Height << "\n"; | |||
| std::cout << "Width =" << Width << "\n"; | |||
| std::cout << "SLICES =" << SLICES << "\n"; | |||
| std::cout << "image_alignment =" << alignment << "\n"; | |||
| std::cout << "dtype_size =" << dtype_size << "\n"; | |||
| std::cout << "row_pitch =" << row_pitch << "\n"; | |||
| std::cout << "row_size =" << row_size << "\n"; | |||
| std::cout << "tensor->Size() =" << tensor->Size() << "\n"; | |||
| std::vector<char> data(tensor->Size()); | |||
| for (int i = 0; i < Height; ++i) { | |||
| for (int i = 0; i < Batch * Height; ++i) { | |||
| memcpy(static_cast<char *>(data.data()) + i * row_size, static_cast<char *>(origin_data) + i * row_pitch, row_size); | |||
| } | |||