diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc index cd06b8666b..40e9bc80a3 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc @@ -50,6 +50,12 @@ std::vector ArithmeticOpenCLKernel::InitGlobalSize() const { void ArithmeticOpenCLKernel::Image2dGetWorkGroupSize() { local_size_ = {16, 16}; + if (out_tensors_[0]->shape().size() == 2) { + size_t H = out_tensors_[0]->shape()[0]; + size_t W = UP_DIV(out_tensors_[0]->shape()[1], C4NUM); + global_size_ = {W, H}; + return; + } if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) { size_t H = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Channel(), C4NUM); size_t W = out_tensors_[0]->Width(); @@ -74,18 +80,23 @@ void ArithmeticOpenCLKernel::BufferGetWorkGroupSize() { int ArithmeticOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { size_t im_dst_x, im_dst_y; - if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) { - im_dst_x = out_tensors_[0]->Width(); - im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Channel(), C4NUM); - } else if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) { - im_dst_x = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM); - im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height(); - } else if (out_tensors_[0]->GetFormat() == schema::Format_NC4) { - im_dst_y = out_tensors_[0]->Batch(); - im_dst_x = UP_DIV(out_tensors_[0]->Channel(), C4NUM); + if (out_tensors_[0]->shape().size() == 2) { + im_dst_x = UP_DIV(out_tensors_[0]->shape()[1], C4NUM); + im_dst_y = out_tensors_[0]->shape()[0]; } else { - MS_LOG(ERROR) << "Unsupport data format " << out_tensors_[0]->GetFormat(); - return RET_ERROR; + if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) { + im_dst_x = out_tensors_[0]->Width(); + im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Channel(), C4NUM); + } else if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) { + im_dst_x = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM); + im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height(); + } else if (out_tensors_[0]->GetFormat() == schema::Format_NC4) { + im_dst_y = out_tensors_[0]->Batch(); + im_dst_x = UP_DIV(out_tensors_[0]->Channel(), C4NUM); + } else { + MS_LOG(ERROR) << "Unsupport data format " << out_tensors_[0]->GetFormat(); + return RET_ERROR; + } } size_t img_dtype = CL_FLOAT; @@ -335,22 +346,7 @@ int ArithmeticOpenCLKernel::Run() { } ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c()); - int H = 0; - int W = 0; - if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) { - H = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Channel(), C4NUM); - W = out_tensors_[0]->Width(); - } else if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) { - H = out_tensors_[0]->Batch() * out_tensors_[0]->Height(); - W = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM); - } else if (out_tensors_[0]->GetFormat() == schema::Format_NC4) { - H = out_tensors_[0]->Batch(); - W = UP_DIV(out_tensors_[0]->Channel(), C4NUM); - } else { - MS_LOG(ERROR) << "Error output type " << out_tensors_[0]->GetFormat(); - return RET_ERROR; - } - cl_int2 output_shape{W, H}; + cl_int2 output_shape{static_cast(global_size_[0]), static_cast(global_size_[1])}; ocl_runtime_->SetKernelArg(kernel_, arg_idx++, output_shape); ocl_runtime_->RunKernel(kernel_, global_size_, local_size_, nullptr); return RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc index 746e5211c4..3a3a2db6ee 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc @@ -162,4 +162,5 @@ kernel::LiteKernel *OpenCLBiasAddKernelCreator(const std::vector } REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_BiasAdd, OpenCLBiasAddKernelCreator) +REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_BiasAdd, OpenCLBiasAddKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc index 52d271c164..063ab515e2 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc @@ -172,4 +172,5 @@ kernel::LiteKernel *OpenCLPReluKernelCreator(const std::vector & } REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_PReLU, OpenCLPReluKernelCreator) +REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_PReLU, OpenCLPReluKernelCreator) } // namespace mindspore::kernel