diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 438f79e276..2f2e746438 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -32,6 +32,9 @@ #include "src/runtime/agent/npu/npu_manager.h" #include "src/runtime/agent/npu/optimizer/npu_pass_manager.h" #endif +#if SUPPORT_GPU +#include "src/runtime/kernel/opencl/opencl_subgraph.h" +#endif namespace mindspore { namespace lite { @@ -629,8 +632,16 @@ int LiteSession::ReSizeKernels(const std::vector &kernels) MS_LOG(ERROR) << "All node in graph should be sub_graph"; return RET_ERROR; } - auto sub_graph = reinterpret_cast(kernel); - auto ret = sub_graph->ReSize(infer_shape_interrupt); + auto ret = RET_OK; + if (kernel->subgraph_type() == kernel::kGpuSubGraph) { +#if SUPPORT_GPU + auto sub_graph = reinterpret_cast(kernel); + ret = sub_graph->ReSize(false); +#endif + } else { + auto sub_graph = reinterpret_cast(kernel); + ret = sub_graph->ReSize(infer_shape_interrupt); + } if (ret == RET_INFER_INVALID) { MS_LOG(INFO) << "InferShape is interrupted"; infer_shape_interrupt = true; diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl index 74f862aa3c..19cce26000 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl @@ -289,6 +289,35 @@ __kernel void BroadcastNHWC4Add(__read_only image2d_t input_a, __read_only image WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result); } +__kernel void BroadcastNHWC4BiasAdd(__read_only image2d_t input_a, __read_only image2d_t input_b, + __write_only image2d_t output, const int4 a_shape, const int4 b_shape, + const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) { + int X = get_global_id(0); // C4 + int Y = get_global_id(1); // W + int Z = get_global_id(2); // H + if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) { + return; + } + int a_c = X < a_shape.w ? X : a_shape.w - 1; + int a_w = Y < a_shape.z ? Y : a_shape.z - 1; + int a_h = Z < a_shape.y ? Z : a_shape.y - 1; + FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h)); + int b_c = X < b_shape.w ? X : b_shape.w - 1; + int b_w = Y < b_shape.z ? Y : b_shape.z - 1; + int b_h = Z < b_shape.y ? Z : b_shape.y - 1; + FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h)); + FLT4 result; + if (broadcastC_flag == 0) { + result = a + b; + } else if (broadcastC_flag == 1) { + result = a.x + b; + } else { + result = a + b.x; + } + result = clamp(result, (FLT)(act_min), (FLT)(act_max)); + WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result); +} + __kernel void BroadcastNHWC4Sub(__read_only image2d_t input_a, __read_only image2d_t input_b, __write_only image2d_t output, const int4 a_shape, const int4 b_shape, const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) { @@ -463,9 +492,9 @@ __kernel void BroadcastNHWC4FloorMod(__read_only image2d_t input_a, __read_only } __kernel void BroadcastNHWC4SquaredDifference(__read_only image2d_t input_a, __read_only image2d_t input_b, - __write_only image2d_t output, const int4 a_shape, const int4 b_shape, - const int4 output_shape, const int broadcastC_flag, float act_min, - float act_max) { + __write_only image2d_t output, const int4 a_shape, const int4 b_shape, + const int4 output_shape, const int broadcastC_flag, float act_min, + float act_max) { int X = get_global_id(0); // C4 int Y = get_global_id(1); // w int Z = get_global_id(2); // H diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/biasadd.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/biasadd.cl deleted file mode 100644 index 0000878e1d..0000000000 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/biasadd.cl +++ /dev/null @@ -1,29 +0,0 @@ -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#define C4NUM 4 -#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) -__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; - -__kernel void BiasAdd(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, - __read_only image2d_t alpha, const int data_type) { - int H = input_shape.y; - int C = input_shape.w; // channel size - C = UP_DIV(C, C4NUM); - if ((C == 0 || H == 0) && data_type != 1) { - return; - } - int Y = get_global_id(0); // height id - int X = get_global_id(1); // weight id - - FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y)); - FLT4 tmp = in_c4; - int index = 0; - if (data_type == 1) { // NC - index = X; - } else if (data_type == 2) { // NHWC4 - index = X % C; - } else { // NC4HW4 - index = Y / H; - } - tmp += READ_IMAGE(alpha, smp_zero, (int2)(index, 0)); - WRITE_IMAGE(output, (int2)(X, Y), tmp); -} diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h index bc8232c4dc..fb6693ef97 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h @@ -28,8 +28,9 @@ namespace mindspore::kernel { class ActivationOpenCLKernel : public OpenCLKernel { public: ActivationOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs), + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : OpenCLKernel(parameter, inputs, outputs, ctx, primitive), type_(reinterpret_cast(parameter)->type_), alpha_(reinterpret_cast(parameter)->alpha_) {} ~ActivationOpenCLKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h index 51d7c07858..feb5782f97 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h @@ -25,9 +25,7 @@ namespace mindspore::kernel { class ArgMinMaxOpenCLKernel : public OpenCLKernel { public: - ArgMinMaxOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~ArgMinMaxOpenCLKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc index 78e866ee7d..e8a3b197b1 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc @@ -34,6 +34,7 @@ using mindspore::lite::opencl::MemType; using mindspore::schema::ActivationType_NO_ACTIVATION; using mindspore::schema::ActivationType_RELU; using mindspore::schema::ActivationType_RELU6; +using mindspore::schema::PrimitiveType_BiasAdd; using mindspore::schema::PrimitiveType_Eltwise; namespace mindspore::kernel { @@ -180,6 +181,9 @@ int ArithmeticOpenCLKernel::Prepare() { #else auto *param = reinterpret_cast(op_parameter_); + if (Type() == PrimitiveType_BiasAdd) { + const_cast(param)->broadcasting_ = true; + } element_flag_ = !param->broadcasting_; kernel_name_ = param->broadcasting_ ? "BroadcastNHWC4" : "Element"; kernel_name_ += schema::EnumNamePrimitiveType(Type()); @@ -237,6 +241,7 @@ REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_LessEqual, OpenCLKernelCreato REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Greater, OpenCLKernelCreator) REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_GreaterEqual, OpenCLKernelCreator) REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Eltwise, OpenCLKernelCreator) +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_BiasAdd, OpenCLKernelCreator) REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Mul, OpenCLKernelCreator) REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Add, OpenCLKernelCreator) REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Sub, OpenCLKernelCreator) @@ -255,4 +260,5 @@ REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_LessEqual, OpenCLKernelCreato REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Greater, OpenCLKernelCreator) REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, OpenCLKernelCreator) REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Eltwise, OpenCLKernelCreator) +REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_BiasAdd, OpenCLKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h index c4d9f504da..ac679589e1 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h @@ -29,9 +29,7 @@ extern std::set SupportedOpenCLArithmetics; class ArithmeticOpenCLKernel : public OpenCLKernel { public: - ArithmeticOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~ArithmeticOpenCLKernel() override = default; int Run() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic_self.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic_self.h index f6c1c8ed11..6e2988b94c 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic_self.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic_self.h @@ -41,9 +41,7 @@ namespace mindspore::kernel { class ArithmeticSelfOpenCLKernel : public OpenCLKernel { public: - ArithmeticSelfOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~ArithmeticSelfOpenCLKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/batch_to_space_nd.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/batch_to_space_nd.h index c6db117a17..b8875913d6 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/batch_to_space_nd.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/batch_to_space_nd.h @@ -25,9 +25,7 @@ namespace mindspore::kernel { class BatchToSpaceNDOpenCLKernel : public OpenCLKernel { public: - BatchToSpaceNDOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~BatchToSpaceNDOpenCLKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.h index 19bce9907a..57301b2918 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.h @@ -25,9 +25,7 @@ namespace mindspore::kernel { class BatchNormOpenCLKernel : public OpenCLKernel { public: - BatchNormOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~BatchNormOpenCLKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc deleted file mode 100644 index acf2f8afb4..0000000000 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc +++ /dev/null @@ -1,136 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/runtime/kernel/opencl/kernel/biasadd.h" -#include -#include -#include -#include - -#include "src/kernel_registry.h" -#include "include/errorcode.h" -#include "src/runtime/opencl/opencl_runtime.h" -#include "src/runtime/kernel/opencl/cl/biasadd.cl.inc" - -using mindspore::kernel::KERNEL_ARCH::kGPU; -using mindspore::lite::KernelRegistrar; -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; -using mindspore::schema::PrimitiveType_BiasAdd; - -namespace mindspore::kernel { - -int BiasAddOpenCLKernel::CheckSpecs() { - if (in_tensors_.size() != 2 || out_tensors_.size() != 1) { - MS_LOG(ERROR) << "Reshape in size: " << in_tensors_.size() << ", out size: " << out_tensors_.size(); - return RET_ERROR; - } - if (in_tensors_.size() == 0) { - MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << in_tensors_.size(); - return RET_ERROR; - } - if (in_tensors_[0]->shape()[0] > 1) { - MS_LOG(ERROR) << "Input data size unsupported multi-batch."; - return RET_ERROR; - } - return RET_OK; -} - -void BiasAddOpenCLKernel::SetConstArgs() { - int arg_idx = 2; - std::map data_type{ - {schema::Format::Format_NC4, 1}, {schema::Format::Format_NHWC4, 2}, {schema::Format::Format_NC4HW4, 3}}; - ocl_runtime_->SetKernelArg(kernel_, arg_idx++, input_shape_); - ocl_runtime_->SetKernelArg(kernel_, arg_idx++, BiasAdd_); - ocl_runtime_->SetKernelArg(kernel_, arg_idx++, data_type[schema::Format::Format_NHWC4]); -} - -void BiasAddOpenCLKernel::SetGlobalLocal() { - cl_int4 global_size = input_shape_; - global_size.s[2] = UP_DIV(global_size.s[3], C4NUM) * global_size.s[2]; - std::vector local = {1, 1}; - std::vector global = {static_cast(global_size.s[1]), static_cast(global_size.s[2])}; - OpenCLKernel::AlignGlobalLocal(global, local); -} - -int BiasAddOpenCLKernel::InitWeights() { - int C = in_tensors_[1]->shape()[0]; - int div_ci = UP_DIV(C, C4NUM); - auto allocator = ocl_runtime_->GetAllocator(); - size_t img_dtype = CL_FLOAT; - if (enable_fp16_) { - img_dtype = CL_HALF_FLOAT; - } - std::vector img_size{size_t(div_ci), 1, img_dtype}; - BiasAdd_ = allocator->Malloc(div_ci * C4NUM * fp_size, img_size); - BiasAdd_ = allocator->MapBuffer(BiasAdd_, CL_MAP_WRITE, nullptr, true); - memset(BiasAdd_, 0x00, div_ci * C4NUM * fp_size); - memcpy(BiasAdd_, in_tensors_[1]->data_c(), C * fp_size); - allocator->UnmapBuffer(BiasAdd_); - return RET_OK; -} - -int BiasAddOpenCLKernel::Prepare() { - in_size_ = in_tensors_[0]->shape().size(); - out_size_ = out_tensors_[0]->shape().size(); - for (int i = 0; i < in_size_; ++i) { - input_shape_.s[i + 4 - in_size_] = in_tensors_[0]->shape()[i]; - } - enable_fp16_ = ocl_runtime_->GetFp16Enable(); - fp_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float); - if (in_size_ != 4 && in_size_ != 2) { - MS_LOG(ERROR) << "BiasAdd only support dim=4 or 2, but your dim=" << in_size_; - return mindspore::lite::RET_ERROR; - } - int C = in_tensors_[0]->shape()[3]; - int Bias_Size = in_tensors_[1]->shape()[0]; - if (UP_DIV(Bias_Size, C4NUM) != UP_DIV(C, C4NUM)) { - MS_LOG(ERROR) << "BiasAdd weight channel size:" << Bias_Size << " must be equal with in_teneors channel size:" << C; - return mindspore::lite::RET_ERROR; - } - InitWeights(); - std::string source = biasadd_source; - std::string program_name = "BiasAdd"; - std::string kernel_name = "BiasAdd"; - ocl_runtime_->LoadSource(program_name, source); - ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name); - - auto ret = InitWeights(); - if (ret != RET_OK) { - return ret; - } - SetGlobalLocal(); - SetConstArgs(); - MS_LOG(DEBUG) << program_name << " Init Done!"; - return mindspore::lite::RET_OK; -} - -int BiasAddOpenCLKernel::Run() { - ocl_runtime_->SetKernelArg(kernel_, 0, in_tensors_[0]->data_c()); - ocl_runtime_->SetKernelArg(kernel_, 1, out_tensors_[0]->data_c()); - auto ret = ocl_runtime_->RunKernel(kernel_, global_range_, local_range_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Run kernel " << op_parameter_->name_ << " error."; - return RET_ERROR; - } - return RET_OK; -} - -REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_BiasAdd, OpenCLKernelCreator) -REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_BiasAdd, OpenCLKernelCreator) -} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.h deleted file mode 100644 index 31955dd3ac..0000000000 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BIASADD_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BIASADD_H_ - -#include -#include - -#include "src/tensor.h" -#include "src/runtime/kernel/opencl/opencl_kernel.h" -#include "schema/model_generated.h" - -namespace mindspore::kernel { - -class BiasAddOpenCLKernel : public OpenCLKernel { - public: - BiasAddOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} - ~BiasAddOpenCLKernel() override = default; - - int Prepare() override; - - int CheckSpecs() override; - void SetConstArgs() override; - void SetGlobalLocal() override; - int InitWeights() override; - int Run() override; - - private: - void *BiasAdd_{nullptr}; - int in_size_{}; - int out_size_{}; - size_t fp_size{}; - cl_int4 input_shape_{}; - bool enable_fp16_{}; -}; - -} // namespace mindspore::kernel - -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BIASADD_H_ diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.h index 53f46fd4e1..44e50ed06d 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.h @@ -26,9 +26,7 @@ namespace mindspore::kernel { class CastOpenCLKernel : public OpenCLKernel { public: - CastOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~CastOpenCLKernel() override = default; int Prepare() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h index f685edc48f..dd5960d4ac 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h @@ -25,9 +25,7 @@ namespace mindspore::kernel { class ConcatOpenCLKernel : public OpenCLKernel { public: - ConcatOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~ConcatOpenCLKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc index 7c8f5b60cd..db48a334f2 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc @@ -483,9 +483,22 @@ kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vector(opParameter); - if (UseFcReplaceConv(inputs, outputs, conv_param)) { + bool infer_shape_done; + if (primitive != nullptr) { + infer_shape_done = primitive->infer_flag(); + } else { + bool output_shape_setted = true; + for (auto output : outputs) { + if (output->shape().empty() || output->ElementsNum() < 0) { + output_shape_setted = false; + break; + } + } + infer_shape_done = output_shape_setted; + } + if (infer_shape_done && UseFcReplaceConv(inputs, outputs, conv_param)) { auto *fc_param = CreateFcParam(conv_param); - kernel = new (std::nothrow) FullConnectionOpenCLKernel(fc_param, inputs, outputs); + kernel = new (std::nothrow) FullConnectionOpenCLKernel(fc_param, inputs, outputs, ctx, primitive); real_param = fc_param; if (kernel == nullptr) { MS_LOG(ERROR) << "Create FullConnection kernel failed."; @@ -497,11 +510,13 @@ kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vector(conv_param), inputs, outputs); + kernel = new (std::nothrow) + WinogradOpenCLKernel(reinterpret_cast(conv_param), inputs, outputs, ctx, primitive); } else { - kernel = new (std::nothrow) Conv2DOpenCLKernel(reinterpret_cast(conv_param), inputs, outputs); + kernel = new (std::nothrow) + Conv2DOpenCLKernel(reinterpret_cast(conv_param), inputs, outputs, ctx, primitive); } real_param = reinterpret_cast(conv_param); if (kernel == nullptr) { @@ -510,7 +525,10 @@ kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vectorCheckSpecs(); if (ret != mindspore::lite::RET_OK) { MS_LOG(ERROR) << "Init Convolution kernel failed."; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.h index 8a01b8f4cc..266fb6c023 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.h @@ -44,8 +44,9 @@ void ConvertFilter(void *src, void *dst, TypeId src_dtype, TypeId dst_dtype, Fil class Conv2DOpenCLKernel : public OpenCLKernel { public: Conv2DOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs), param_(reinterpret_cast(parameter)) { + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : OpenCLKernel(parameter, inputs, outputs, ctx, primitive), param_(reinterpret_cast(parameter)) { bool is_adreno = ocl_runtime_->GetGpuInfo().type == lite::opencl::GpuType::ADRENO; filter_type_ = is_adreno ? MemType::IMG : MemType::BUF; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc index faa1bdd374..f5d50538a5 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc @@ -40,9 +40,8 @@ int Conv2dTransposeOpenCLKernel::CheckSpecs() { return RET_ERROR; } ConvParameter *param = reinterpret_cast(op_parameter_); - if (param->pad_l_ != param->pad_r_ || param->kernel_h_ - param->stride_h_ != 2 * param->pad_l_ || - param->pad_u_ != param->pad_d_ || param->kernel_w_ - param->stride_w_ != 2 * param->pad_u_) { - MS_LOG(ERROR) << "only support kernel - stride == 2 * pad"; + if (param->pad_l_ != param->pad_r_ || param->pad_u_ != param->pad_d_) { + MS_LOG(ERROR) << "only support symmetric padding"; return RET_ERROR; } if (param->act_type_ != ActType_No && param->act_type_ != ActType_Relu && param->act_type_ != ActType_Relu6) { diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h index 05009c11f4..d8064dd19a 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h @@ -27,9 +27,7 @@ namespace mindspore::kernel { class Conv2dTransposeOpenCLKernel : public OpenCLKernel { public: - Conv2dTransposeOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~Conv2dTransposeOpenCLKernel() override = default; int Run() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h index 1ceb1a732c..a09e04f50b 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.h @@ -28,8 +28,9 @@ namespace mindspore::kernel { class DepthwiseConv2dOpenCLKernel : public OpenCLKernel { public: DepthwiseConv2dOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) { + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : OpenCLKernel(parameter, inputs, outputs, ctx, primitive) { bool is_adreno = ocl_runtime_->GetGpuInfo().type == lite::opencl::GpuType::ADRENO; filter_type_ = is_adreno ? MemType::IMG : MemType::BUF; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.h index 9c346487a5..60db7dd9bb 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.h @@ -26,9 +26,7 @@ namespace mindspore::kernel { class FillOpenCLKernel : public OpenCLKernel { public: - FillOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~FillOpenCLKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.h index 9463f15068..bb2a1fd294 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.h @@ -26,9 +26,7 @@ namespace mindspore::kernel { class FullConnectionOpenCLKernel : public OpenCLKernel { public: - FullConnectionOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~FullConnectionOpenCLKernel() override = default; int Run() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.h index b8eb26aaa2..b7ce6fed0b 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fusion_eltwise.h @@ -153,9 +153,7 @@ bool IsEltwiseAndOperatorSupported(LiteKernel *node); class FusionEltwiseOpenCLKernel : public OpenCLKernel { public: - FusionEltwiseOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~FusionEltwiseOpenCLKernel() override { if (op_parameter_ != nullptr) { diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.cc index a43a93e179..9584dc1e9f 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.cc @@ -50,7 +50,7 @@ int GatherOpenCLKernel::CheckSpecs() { return RET_ERROR; } int indices_ndim = in_tensors_.at(1)->shape().size(); - if (indices_ndim != 1) { + if (indices_ndim > 1) { MS_LOG(ERROR) << "GatherOpenCLKernel only supports 1D indices Tensor but get " << indices_ndim << "D."; return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.h index bde9e4515f..e264627858 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.h @@ -25,9 +25,7 @@ namespace mindspore::kernel { class GatherOpenCLKernel : public OpenCLKernel { public: - GatherOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~GatherOpenCLKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.h index 3bc57c12c1..8ff86721ab 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.h @@ -24,9 +24,7 @@ namespace mindspore::kernel { class LayerNormOpenCLKernel : public OpenCLKernel { public: - LayerNormOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~LayerNormOpenCLKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc index a36d16c88a..442b2c49b5 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc @@ -215,17 +215,34 @@ kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector const lite::InnerContext *ctx, const kernel::KernelKey &desc, const mindspore::lite::PrimitiveC *primitive) { kernel::OpenCLKernel *kernel; - if (IsUseStrassenMatmul(inputs)) { + bool infer_shape_done; + if (primitive != nullptr) { + infer_shape_done = primitive->infer_flag(); + } else { + bool output_shape_setted = true; + for (auto output : outputs) { + if (output->shape().empty() || output->ElementsNum() < 0) { + output_shape_setted = false; + break; + } + } + infer_shape_done = output_shape_setted; + } + if (infer_shape_done && IsUseStrassenMatmul(inputs)) { MS_LOG(DEBUG) << "use_matmul_strassen"; - kernel = new (std::nothrow) StrassenOpenCLKernel(opParameter, inputs, outputs); + kernel = new (std::nothrow) StrassenOpenCLKernel(opParameter, inputs, outputs, ctx, primitive); } else { - kernel = new (std::nothrow) MatMulOpenCLKernel(opParameter, inputs, outputs); + kernel = new (std::nothrow) MatMulOpenCLKernel(opParameter, inputs, outputs, ctx, primitive); } if (kernel == nullptr) { MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; free(opParameter); return nullptr; } + if (!infer_shape_done) { + MS_LOG(WARNING) << "kernel don't infer shape yet!"; + return kernel; + } auto ret = kernel->CheckSpecs(); if (ret != mindspore::lite::RET_OK) { MS_LOG(ERROR) << "Check " << opParameter->name_ << " specification failed!"; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h index ab9ec6807f..ab58197334 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h @@ -28,9 +28,7 @@ namespace mindspore::kernel { class MatMulOpenCLKernel : public OpenCLKernel { public: - MatMulOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~MatMulOpenCLKernel() override = default; int Run() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.h index 23c6bf73ac..98fcbb7fe5 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.h @@ -26,9 +26,7 @@ namespace mindspore::kernel { class OneHotOpenCLKernel : public OpenCLKernel { public: - OneHotOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~OneHotOpenCLKernel() override = default; int Run() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/pad.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/pad.h index 4578e6ee04..33e05cf89d 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/pad.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/pad.h @@ -29,8 +29,10 @@ namespace mindspore::kernel { class PadOpenCLKernel : public OpenCLKernel { public: PadOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs), param_(reinterpret_cast(op_parameter_)) {} + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : OpenCLKernel(parameter, inputs, outputs, ctx, primitive), + param_(reinterpret_cast(op_parameter_)) {} ~PadOpenCLKernel() override = default; int CheckSpecs() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h index f4a0785fe9..58641347de 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h @@ -27,8 +27,10 @@ namespace mindspore::kernel { class PoolingOpenCLKernel : public OpenCLKernel { public: PoolingOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs), parameter_(reinterpret_cast(parameter)) {} + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : OpenCLKernel(parameter, inputs, outputs, ctx, primitive), + parameter_(reinterpret_cast(parameter)) {} ~PoolingOpenCLKernel() override = default; int Run() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/power.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/power.h index 04b2a7318a..469bc2e334 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/power.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/power.h @@ -25,9 +25,7 @@ namespace mindspore::kernel { class PowerOpenCLKernel : public OpenCLKernel { public: - PowerOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~PowerOpenCLKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.h index 49086e4936..5300059b1b 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.h @@ -27,9 +27,7 @@ namespace mindspore::kernel { class PReluOpenCLKernel : public OpenCLKernel { public: - PReluOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~PReluOpenCLKernel() override = default; int Prepare() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.h index 85c81312f8..d17b2ed4e9 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.h @@ -26,9 +26,7 @@ namespace mindspore::kernel { class ReduceOpenCLKernel : public OpenCLKernel { public: - ReduceOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~ReduceOpenCLKernel() override = default; int Run() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc index 72bb13e7d0..496bbae6b3 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc @@ -37,15 +37,16 @@ int ReshapeOpenCLKernel::CheckSpecs() { MS_LOG(ERROR) << "Reshape input output size unsupported."; return RET_ERROR; } - if (in_tensors_[0]->data_type() != kNumberTypeFloat32 && in_tensors_[0]->data_type() != kNumberTypeFloat16) { + if (in_tensors_[0]->data_type() != kNumberTypeFloat32 && in_tensors_[0]->data_type() != kNumberTypeFloat16 && + in_tensors_[0]->data_type() != kNumberTypeInt32) { MS_LOG(ERROR) << "Unsupported data type " << in_tensors_[0]->data_type(); return RET_ERROR; } - if (in_tensors_[0]->shape().size() == 0 || in_tensors_[0]->shape().size() > 4) { - MS_LOG(ERROR) << "Reshape input size should in 1-4, actual: " << in_tensors_[0]->shape(); + if (in_tensors_[0]->shape().size() > 4) { + MS_LOG(ERROR) << "Reshape input size should in 0-4, actual: " << in_tensors_[0]->shape(); return RET_ERROR; } - if (out_tensors_[0]->shape().size() == 0 || out_tensors_[0]->shape().size() > 4) { + if (out_tensors_[0]->shape().size() > 4) { MS_LOG(ERROR) << "Reshape output size should in 1-4, actual: " << out_tensors_[0]->shape(); return RET_ERROR; } @@ -95,6 +96,16 @@ int ReshapeOpenCLKernel::Run() { return RET_OK; } +int ReshapeOpenCLKernel::PreProcess() { + if (Type() == PrimitiveType_Reshape) { + auto shape_tensor = in_tensors_[1]; + if (!shape_tensor->IsConst()) { + ocl_runtime_->SyncCommandQueue(); + } + } + return OpenCLKernel::PreProcess(); +} + REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Reshape, OpenCLKernelCreator) REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Reshape, OpenCLKernelCreator) REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Squeeze, OpenCLKernelCreator) diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.h index 98cf0978ee..527d08ca7d 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.h @@ -25,9 +25,7 @@ namespace mindspore::kernel { class ReshapeOpenCLKernel : public OpenCLKernel { public: - ReshapeOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~ReshapeOpenCLKernel() override = default; int Run() override; @@ -35,6 +33,7 @@ class ReshapeOpenCLKernel : public OpenCLKernel { int CheckSpecs() override; void SetConstArgs() override; void SetGlobalLocal() override; + int PreProcess() override; private: }; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/resize.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/resize.h index 3a42882f2a..529cfc285b 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/resize.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/resize.h @@ -26,9 +26,7 @@ namespace mindspore::kernel { class ResizeOpenCLKernel : public OpenCLKernel { public: - ResizeOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~ResizeOpenCLKernel() override = default; int Run() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.h index 2792068b01..8c03f8b1dd 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/scale.h @@ -25,9 +25,7 @@ namespace mindspore::kernel { class ScaleOpenCLKernel : public OpenCLKernel { public: - ScaleOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~ScaleOpenCLKernel() override; int CheckSpecs() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h index f89fd58776..9de897e64a 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h @@ -27,8 +27,9 @@ namespace mindspore::kernel { class SoftmaxOpenCLKernel : public OpenCLKernel { public: SoftmaxOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) { + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : OpenCLKernel(parameter, inputs, outputs, ctx, primitive) { parameter_ = reinterpret_cast(parameter); } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_batch_nd.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_batch_nd.h index 2bd9b0c352..6ec5707fa1 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_batch_nd.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_batch_nd.h @@ -25,9 +25,7 @@ namespace mindspore::kernel { class SpaceToBatchNDOpenCLKernel : public OpenCLKernel { public: - SpaceToBatchNDOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~SpaceToBatchNDOpenCLKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.h index 671b9dedb6..69c1f94967 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.h @@ -26,9 +26,7 @@ namespace mindspore::kernel { class SpaceToDepthOpenCLKernel : public OpenCLKernel { public: - SpaceToDepthOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~SpaceToDepthOpenCLKernel() override = default; int Run() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.h index c2b514c6f8..b29c1efb8b 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.h @@ -25,9 +25,7 @@ namespace mindspore::kernel { class SparseToDenseOpenCLKernel : public OpenCLKernel { public: - SparseToDenseOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~SparseToDenseOpenCLKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/split.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/split.h index 6892e1d305..3c4da24f71 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/split.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/split.h @@ -25,9 +25,7 @@ namespace mindspore::kernel { class SplitOpenCLKernel : public OpenCLKernel { public: - SplitOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~SplitOpenCLKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/stack.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/stack.h index 85bb66881b..c626a0bc0f 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/stack.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/stack.h @@ -25,9 +25,7 @@ namespace mindspore::kernel { class StackOpenCLKernel : public OpenCLKernel { public: - explicit StackOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~StackOpenCLKernel() override{}; int Prepare() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.h index 3210848d85..db7432d100 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.h @@ -25,9 +25,7 @@ namespace mindspore::kernel { class StrassenOpenCLKernel : public MatMulOpenCLKernel { public: - StrassenOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : MatMulOpenCLKernel(parameter, inputs, outputs) {} + using MatMulOpenCLKernel::MatMulOpenCLKernel; ~StrassenOpenCLKernel() override = default; public: diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/strided_slice.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/strided_slice.h index 7016360998..9b6aa49151 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/strided_slice.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/strided_slice.h @@ -25,9 +25,7 @@ namespace mindspore::kernel { class StridedSliceOpenCLKernel : public OpenCLKernel { public: - StridedSliceOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~StridedSliceOpenCLKernel() override = default; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc index 2a786b92dd..a7ef74e216 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc @@ -106,6 +106,19 @@ int ToFormatOpenCLKernel::Run() { return RET_OK; } +int ToFormatOpenCLKernel::InferShape() { + if (infer_shape_flag_) { + return RET_OK; + } + if (in_tensors_[0]->shape().size() == 0 || in_tensors_[0]->ElementsNum() < 0) { + MS_LOG(ERROR) << "to_format op in tensor shape is 0, infer shape failed!"; + return RET_ERROR; + } + out_tensors_[0]->set_shape(in_tensors_[0]->shape()); + infer_shape_flag_ = true; + return RET_OK; +} + REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_ToFormat, OpenCLKernelCreator) REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_ToFormat, OpenCLKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h index 41eef5a5ee..a84f10a363 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h @@ -25,9 +25,7 @@ namespace mindspore::kernel { class ToFormatOpenCLKernel : public OpenCLKernel { public: - ToFormatOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~ToFormatOpenCLKernel() override = default; int Run() override; @@ -36,6 +34,7 @@ class ToFormatOpenCLKernel : public OpenCLKernel { int CheckSpecs() override; void SetConstArgs() override; void SetGlobalLocal() override; + int InferShape() override; private: size_t N_{1}; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h index cb44101f75..b337a1218e 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h @@ -29,9 +29,7 @@ enum class TransposeType { AXIS0312, AXIS0231, GENERAL }; class TransposeOpenCLKernel : public OpenCLKernel { public: - TransposeOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : OpenCLKernel(parameter, inputs, outputs) {} + using OpenCLKernel::OpenCLKernel; ~TransposeOpenCLKernel() override = default; int Run() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/winograd.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/winograd.h index b654537fce..e174c594ec 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/winograd.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/winograd.h @@ -26,8 +26,9 @@ namespace mindspore::kernel { class WinogradOpenCLKernel : public Conv2DOpenCLKernel { public: WinogradOpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : Conv2DOpenCLKernel(parameter, inputs, outputs) { + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : Conv2DOpenCLKernel(parameter, inputs, outputs, ctx, primitive) { filter_type_ = MemType::BUF; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc index 7642eb685a..e12a6a240c 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc @@ -463,8 +463,8 @@ void CreateEltwiseKernelReplaceOld(FusionEltwiseParameter *param, LiteKernel *ol MS_ASSERT(old); MS_ASSERT(nodes); MS_ASSERT(removed_set); - auto *eltwise = new (std::nothrow) - FusionEltwiseOpenCLKernel(reinterpret_cast(param), old->in_tensors(), old->out_tensors()); + auto *eltwise = new (std::nothrow) FusionEltwiseOpenCLKernel(reinterpret_cast(param), + old->in_tensors(), old->out_tensors(), nullptr, nullptr); if (eltwise == nullptr) { MS_LOG(ERROR) << "create FusionEltwiseOpenCLKernel error."; return; @@ -539,6 +539,9 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::vector *nodes, s } // namespace int OpenCLSubGraph::FusionPass() { + if (!this->IsSubGraphInferShapeDone()) { + return RET_OK; + } MS_LOG(DEBUG) << "start Fusion"; std::vector input_nodes; diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc index a1d8665b60..3adfc05ae8 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc @@ -122,6 +122,41 @@ void OpenCLKernel::PrintOutput(int print_num, const std::string &out_file) { } } +int OpenCLKernel::PreProcess() { + auto ret = RET_OK; + ret = ReSize(); + if (ret != RET_OK) { + return ret; + } + auto allocator = ocl_runtime_->GetAllocator(); + for (auto i = 0; i < out_tensors_.size(); ++i) { + auto *output = out_tensors_.at(i); + MS_ASSERT(output); + if (GetMemType() == lite::opencl::MemType::IMG) { + std::vector img_size; + ret = GetImageSize(i, &img_size); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GetImageSize failed"; + return ret; + } + auto data_ptr = allocator->Malloc(output->Size(), img_size); + if (data_ptr == nullptr) { + MS_LOG(ERROR) << "Malloc data failed"; + return RET_ERROR; + } + output->set_data(data_ptr); + } else { + ret = output->MallocData(allocator); + if (ret != RET_OK) { + MS_LOG(ERROR) << "MallocData failed"; + return ret; + } + } + output->set_allocator(allocator); + } + return RET_OK; +} + int OpenCLKernel::PostProcess() { for (auto *output : this->out_tensors()) { MS_ASSERT(output != nullptr); @@ -130,6 +165,45 @@ int OpenCLKernel::PostProcess() { return FreeInWorkTensor(); } +int OpenCLKernel::InferShape() { + if (infer_shape_flag_) { + return RET_OK; + } + if (primitive_ == nullptr) { + return RET_ERROR; + } + (const_cast(primitive_))->set_infer_flag(true); + auto ret = (const_cast(primitive_))->InferShape(in_tensors_, out_tensors_); + if (ret != RET_OK) { + (const_cast(primitive_))->set_infer_flag(false); + MS_LOG(ERROR) << "InferShape fail!"; + return ret; + } + infer_shape_flag_ = true; + return RET_OK; +} + +int OpenCLKernel::ReSize() { + if (infer_shape_flag_) { + return RET_OK; + } + auto ret = InferShape(); + if (ret != RET_OK) { + return ret; + } + ret = CheckSpecs(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ReSize failed for check kernel specs!"; + return ret; + } + ret = Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ReSize failed for kernel prepare!"; + return ret; + } + return RET_OK; +} + std::vector OpenCLKernel::GenerateTuningParam() { size_t ndim = global_size_.size(); std::vector tuning_params = {}; diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h index 5e68c6c7d8..65551b397a 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h @@ -156,17 +156,30 @@ struct BaseTuningParameter { class OpenCLKernel : public LiteKernel { public: OpenCLKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs, nullptr, nullptr) { + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { ocl_runtime_ = ocl_runtime_wrap_.GetInstance(); + if (primitive != nullptr) { + infer_shape_flag_ = primitive->infer_flag(); + } else { + bool output_shape_setted = true; + for (auto output : outputs) { + if (output->shape().empty() || output->ElementsNum() < 0) { + output_shape_setted = false; + break; + } + } + infer_shape_flag_ = output_shape_setted; + } } ~OpenCLKernel() override = default; int AlignGlobalLocal(const std::vector &global, const std::vector &local); int Prepare() override { return RET_OK; } - int PreProcess() override { return RET_ERROR; } + int PreProcess() override; int PostProcess() override; - int ReSize() override { return RET_ERROR; } + int ReSize() override; int Run() override { return RET_ERROR; } virtual int CheckSpecs() { return RET_ERROR; } @@ -189,6 +202,9 @@ class OpenCLKernel : public LiteKernel { double GetProfilingTimeMs(); int DequantWeight(); void FreeDequantedWeight(); + virtual int InferShape(); + bool GetInferShapeFlag() { return infer_shape_flag_; } + void SetInferShapeFlag(bool flag) { infer_shape_flag_ = flag; } protected: static std::set GenerateLocalByGlobal(size_t global_i); @@ -213,6 +229,7 @@ class OpenCLKernel : public LiteKernel { cl::Event event_; void *restore_quant_data_{nullptr}; bool dequant_flag_{false}; + bool infer_shape_flag_{false}; private: lite::opencl::OpenCLRuntimeWrapper ocl_runtime_wrap_; @@ -223,12 +240,16 @@ kernel::LiteKernel *OpenCLKernelCreator(const std::vector &input const std::vector &outputs, OpParameter *opParameter, const lite::InnerContext *ctx, const kernel::KernelKey &desc, const mindspore::lite::PrimitiveC *primitive) { - auto *kernel = new (std::nothrow) T(reinterpret_cast(opParameter), inputs, outputs); + auto *kernel = new (std::nothrow) T(reinterpret_cast(opParameter), inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; free(opParameter); return nullptr; } + if (!reinterpret_cast(kernel)->GetInferShapeFlag()) { + MS_LOG(WARNING) << "kernel don't infer shape yet!"; + return kernel; + } auto ret = kernel->CheckSpecs(); if (ret != mindspore::lite::RET_OK) { MS_LOG(ERROR) << "Check " << opParameter->name_ << " specification failed!"; diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc index 6fed22b189..0dc088288a 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc @@ -171,6 +171,8 @@ int OpenCLSubGraph::GenToFormatOp(const std::vector &in_tensors, parameter = nullptr; return RET_ERROR; } + static int index = 0; + in_convert_op->set_name("ToFormat_" + std::to_string(index)); ReplaceOutTensorAndKernelToConvert(in_tensor, in_kernels.at(i), new_tensor, in_convert_op, mem_type); @@ -302,16 +304,36 @@ void OpenCLSubGraph::GetInOutNodes() { } } +bool OpenCLSubGraph::IsSubGraphInferShapeDone() { + for (auto node : this->nodes_) { + auto opencl_kernel = reinterpret_cast(node); + if (!opencl_kernel->GetInferShapeFlag()) { + return false; + } + } + return true; +} + int OpenCLSubGraph::Prepare() { executor_ = new (std::nothrow) lite::opencl::OpenCLExecutor(); if (executor_ == nullptr) { MS_LOG(ERROR) << "Create OpenCLExecutor fail"; return RET_ERROR; } - auto ret = SubGraphKernel::Prepare(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "OpenCL prepare fail"; - return ret; + auto ret = RET_OK; + for (auto node : this->nodes_) { + if (node == nullptr) { + MS_LOG(ERROR) << "node in Subgraph is nullptr"; + return mindspore::lite::RET_NULL_PTR; + } + auto opencl_kernel = reinterpret_cast(node); + if (opencl_kernel->GetInferShapeFlag()) { + ret = node->Prepare(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "prepare node " << node->name() << " failed"; + return ret; + } + } } auto opencl_exec = reinterpret_cast(executor_); // If tuning_mode is DEFAULT, just malloc memory for reuse. @@ -341,7 +363,40 @@ void OpenCLSubGraph::UnInit() { delete this->executor_; } -int OpenCLSubGraph::ReSize() { return RET_OK; } +int OpenCLSubGraph::ReSize() { return ReSize(false); } + +int OpenCLSubGraph::ReSize(bool interrupt) { + for (auto kernel : nodes_) { + if (kernel == nullptr) { + MS_LOG(ERROR) << "input kernel is nullptr!"; + return RET_ERROR; + } + auto opencl_kernel = reinterpret_cast(kernel); + if (kernel->subgraph_type() != kernel::kNotSubGraph) { + MS_LOG(ERROR) << "all nodes in should be kernel"; + return RET_ERROR; + } + std::vector inputs = kernel->in_tensors(); + std::vector outputs = kernel->out_tensors(); + for (auto &output : outputs) { + output->FreeData(); + } + opencl_kernel->SetInferShapeFlag(false); + } + for (auto kernel : nodes_) { + auto opencl_kernel = reinterpret_cast(kernel); + auto ret = opencl_kernel->ReSize(); + if (ret != RET_OK) { + MS_LOG(WARNING) << "ReSize " << opencl_kernel->name() << "failed!"; + if (interrupt) { + return ret; + } else { + break; + } + } + } + return RET_OK; +} int OpenCLSubGraph::Run() { if (executor_ == nullptr) { diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h index 7e737968e3..6c2024a80e 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h @@ -44,9 +44,11 @@ class OpenCLSubGraph : public SubGraphKernel { int Prepare() override; int Init() override; int ReSize() override; + int ReSize(bool interrupt); int Run() override; int Run(const KernelCallBack &before, const KernelCallBack &after) override { return this->Run(); }; int InsertOpsPass(); + bool IsSubGraphInferShapeDone(); private: void UnInit(); diff --git a/mindspore/lite/src/runtime/kernel/opencl/utils.cc b/mindspore/lite/src/runtime/kernel/opencl/utils.cc index 2537c3e9ae..a3c3c273c1 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/utils.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/utils.cc @@ -58,7 +58,8 @@ const std::set ArithmeticPrimitives = {schema::PrimitiveT schema::PrimitiveType_LessEqual, schema::PrimitiveType_Greater, schema::PrimitiveType_GreaterEqual, - schema::PrimitiveType_Eltwise}; + schema::PrimitiveType_Eltwise, + schema::PrimitiveType_BiasAdd}; const std::set ArithmeticSelfPrimitives = { schema::PrimitiveType_Abs, schema::PrimitiveType_Ceil, schema::PrimitiveType_Cos, diff --git a/mindspore/lite/src/runtime/opencl/opencl_executor.cc b/mindspore/lite/src/runtime/opencl/opencl_executor.cc index 046e76f0f7..ce9b4bc187 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_executor.cc +++ b/mindspore/lite/src/runtime/opencl/opencl_executor.cc @@ -47,31 +47,16 @@ int OpenCLExecutor::RunOrTune(std::vector &inputs, std::vector(kernel); - auto cur_outputs = kernel->out_tensors(); - for (auto i = 0; i < cur_outputs.size(); ++i) { - auto *output = cur_outputs.at(i); - MS_ASSERT(output); - if (op_kernel->GetMemType() == lite::opencl::MemType::IMG) { - std::vector img_size; - ret = op_kernel->GetImageSize(i, &img_size); - if (ret != RET_OK) { - MS_LOG(ERROR) << "GetImageSize failed"; - return ret; - } - auto data_ptr = allocator_->Malloc(output->Size(), img_size); - if (data_ptr == nullptr) { - MS_LOG(ERROR) << "Malloc data failed"; - return RET_ERROR; - } - output->set_data(data_ptr); + ret = kernel->PreProcess(); + if (RET_OK != ret) { + if (is_tune) { + MS_LOG(WARNING) << "PreProcess kernel failed, name: " << kernel->name() << " in tuning"; + opencl_runtime_ins->SetProfiling(profiling_tmp); + return RET_OK; } else { - ret = output->MallocData(allocator_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "MallocData failed"; - return ret; - } + MS_LOG(ERROR) << "PreProcess kernel failed, name: " << kernel->name(); + return ret; } - output->set_allocator(allocator_); } if (is_tune) { ret = op_kernel->Tune(); @@ -79,26 +64,20 @@ int OpenCLExecutor::RunOrTune(std::vector &inputs, std::vectorname(); return ret; } - ret = kernel->PostProcess(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "PostProcess kernel failed, name: " << kernel->name(); - return ret; - } } else { ret = kernel->Run(); if (ret != RET_OK) { MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name(); return ret; } - ret = kernel->PostProcess(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "PostProcess kernel failed, name: " << kernel->name(); - return ret; - } - if (profiling_tmp) { + if (profiling_tmp) MS_LOG(INFO) << "OpenCl kernel " << kernel->name() << "(" << kernel->type_str() << ") execute time is: " << op_kernel->GetProfilingTimeMs() << "ms"; - } + } + ret = kernel->PostProcess(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PostProcess kernel failed, name: " << kernel->name(); + return ret; } if (after != nullptr) { if (!after(TensorVectorCast(kernel->in_tensors()), TensorVectorCast(kernel->out_tensors()), callbackParam)) { diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 6523110f60..551b720b28 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -187,7 +187,9 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector &in kernel::KernelKey desc{kCPU, data_type, static_cast(primitive->Type())}; #if SUPPORT_GPU if (context_->IsGpuEnabled()) { - kernel::KernelKey gpu_desc{kGPU, desc.data_type, desc.type}; + // support more data type like int32 + kernel::KernelKey gpu_desc{kGPU, kNumberTypeFloat32, desc.type}; + if (context_->IsGpuFloat16Enabled()) gpu_desc.data_type = kNumberTypeFloat16; auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, gpu_desc); if (kernel != nullptr) { MS_LOG(DEBUG) << "Get gpu op success: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " " << node->name_; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/biasadd_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/biasadd_tests.cc deleted file mode 100644 index 037aac107e..0000000000 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/biasadd_tests.cc +++ /dev/null @@ -1,205 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "src/common/log_adapter.h" -#include "common/common_test.h" -#include "mindspore/lite/src/common/file_utils.h" -#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" -#include "mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h" -#include "mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.h" - -using mindspore::kernel::BiasAddOpenCLKernel; -using mindspore::kernel::LiteKernel; -using mindspore::kernel::OpenCLSubGraph; -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; - -namespace mindspore { - -// PrimitiveType_BiasAdd: src/ops/populate/bias_add_populate.cc - -class TestBiasAddOpenCL : public CommonTest {}; - -void LoadDataBiasAdd(void *dst, size_t dst_size, const std::string &file_path) { - if (file_path.empty()) { - memset(dst, 0x00, dst_size); - } else { - auto src_data = mindspore::lite::ReadFile(file_path.c_str(), &dst_size); - memcpy(dst, src_data, dst_size); - } -} - -template -void CompareOutBiasAdd(lite::Tensor *output_tensor, const std::string &standard_answer_file) { - size_t output_size = output_tensor->ElementsNum(); - auto output_data = reinterpret_cast(output_tensor->data_c()); - auto expect_data = reinterpret_cast(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size)); - constexpr float atol = 0.0002; - for (int i = 0; i < output_tensor->ElementsNum(); ++i) { - if (std::fabs(output_data[i] - expect_data[i]) > atol) { - printf("error at idx[%d] expect=%f output=%f\n", i, expect_data[i], output_data[i]); - printf("error at idx[%d] expect=%f output=%f\n", i, expect_data[i], output_data[i]); - printf("error at idx[%d] expect=%f output=%f\n\n\n", i, expect_data[i], output_data[i]); - return; - } - } - printf("compare success!\n"); - printf("compare success!\n"); - printf("compare success!\n\n\n"); -} - -template -void printf_tensor_BiasAdd(const std::string log, mindspore::lite::Tensor *in_data, int size) { - MS_LOG(INFO) << log; - auto input_data = reinterpret_cast(in_data->data_c()); - for (int i = 0; i < size; ++i) { - printf("%f ", input_data[i]); - } - printf("\n"); - MS_LOG(INFO) << "Print tensor done"; -} - -TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) { - std::string in_file = "/data/local/tmp/in_data.bin"; - std::string weight_file = "/data/local/tmp/weight_data.bin"; - std::string standard_answer_file = "/data/local/tmp/biasadd.bin"; - MS_LOG(INFO) << "BiasAdd Begin test:"; - auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); - ocl_runtime->Init(); - auto data_type = kNumberTypeFloat16; - ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16); - std::vector input_shape = {1, 9}; - std::vector output_shape = {1, 9}; - auto tensor_type = lite::Tensor::CONST_TENSOR; - schema::Format type = schema::Format_NC; - int weight_shape = 0; - if (input_shape.size() == 4) { - weight_shape = input_shape[3]; - } else { - weight_shape = input_shape[1]; - } - auto *input_tensor = new (std::nothrow) lite::Tensor(data_type, input_shape, type, tensor_type); - if (input_tensor == nullptr) { - MS_LOG(ERROR) << "new input tensor error!"; - return; - } - auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, type, tensor_type); - if (output_tensor == nullptr) { - MS_LOG(ERROR) << "new output tensor error!"; - delete input_tensor; - return; - } - auto *weight_tensor = - new (std::nothrow) lite::Tensor(data_type, std::vector{weight_shape}, schema::Format_NHWC, tensor_type); - if (weight_tensor == nullptr) { - MS_LOG(ERROR) << "new weight tensor error!"; - delete output_tensor; - delete input_tensor; - return; - } - std::vector inputs{input_tensor, weight_tensor}; - std::vector outputs{output_tensor}; - auto allocator = ocl_runtime->GetAllocator(); - inputs[0]->MallocData(allocator); - inputs[1]->MallocData(allocator); - LoadDataBiasAdd(input_tensor->data_c(), input_tensor->Size(), in_file); - LoadDataBiasAdd(weight_tensor->data_c(), weight_tensor->Size(), weight_file); - if (ocl_runtime->GetFp16Enable()) { - printf_tensor_BiasAdd("BiasAdd:FP16--input data", inputs[0], input_tensor->ElementsNum()); - printf_tensor_BiasAdd("BiasAdd:FP16--weight data", inputs[1], weight_tensor->ElementsNum()); - } else { - printf_tensor_BiasAdd("BiasAdd:FP32--input data", inputs[0], input_tensor->ElementsNum()); - printf_tensor_BiasAdd("BiasAdd:FP32--weight data", inputs[1], weight_tensor->ElementsNum()); - } - - auto *param = new (std::nothrow) OpParameter(); - if (param == nullptr) { - delete input_tensor; - delete output_tensor; - delete weight_tensor; - MS_LOG(ERROR) << "new OpParameter error!"; - return; - } - auto *biasadd_kernel = - new (std::nothrow) kernel::BiasAddOpenCLKernel(reinterpret_cast(param), inputs, outputs); - if (biasadd_kernel == nullptr) { - MS_LOG(ERROR) << "Create biasadd kernel error."; - delete input_tensor; - delete output_tensor; - delete weight_tensor; - delete param; - return; - } - auto ret = biasadd_kernel->Init(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "biasadd kernel init error."; - delete input_tensor; - delete output_tensor; - delete weight_tensor; - delete param; - delete biasadd_kernel; - return; - } - - MS_LOG(INFO) << "initialize sub_graph"; - std::vector kernels{biasadd_kernel}; - auto *sub_graph = new (std::nothrow) kernel::OpenCLSubGraph({input_tensor}, outputs, kernels, kernels, kernels); - if (sub_graph == nullptr) { - MS_LOG(ERROR) << "Create sub_graph kernel error."; - delete input_tensor; - delete output_tensor; - delete weight_tensor; - delete param; - delete biasadd_kernel; - return; - } - ret = sub_graph->Init(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "sub_graph init error."; - delete input_tensor; - delete output_tensor; - delete weight_tensor; - delete sub_graph; - delete param; - return; - } - MS_LOG(INFO) << "Sub graph begin running!"; - ret = sub_graph->Run(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "sub_graph run error."; - delete input_tensor; - delete output_tensor; - delete weight_tensor; - delete sub_graph; - delete param; - return; - } - - if (ocl_runtime->GetFp16Enable()) { - printf_tensor_BiasAdd("BiasAdd:FP16--output data", outputs[0], output_tensor->ElementsNum()); - CompareOutBiasAdd(output_tensor, standard_answer_file); - } else { - printf_tensor_BiasAdd("BiasAdd:FP32--output data", outputs[0], output_tensor->ElementsNum()); - CompareOutBiasAdd(output_tensor, standard_answer_file); - } - delete input_tensor; - delete weight_tensor; - delete output_tensor; - delete sub_graph; - delete param; -} -} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc index f34eaf8055..783cd9ea9e 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc @@ -73,8 +73,8 @@ TEST_F(TestCastSelfOpenCL, Castfp32tofp16) { std::vector inputs{input_tensor}; std::vector outputs{output_tensor}; - auto *cast_kernel = - new (std::nothrow) kernel::CastOpenCLKernel(reinterpret_cast(param), inputs, outputs); + auto *cast_kernel = new (std::nothrow) + kernel::CastOpenCLKernel(reinterpret_cast(param), inputs, outputs, nullptr, nullptr); if (cast_kernel == nullptr) { MS_LOG(INFO) << " new kernel::CastOpenCLKernel failed "; for (auto tensor : inputs) { @@ -159,8 +159,8 @@ TEST_F(TestCastSelfOpenCL, Castfp16tofp32) { std::vector inputs{input_tensor}; std::vector outputs{output_tensor}; - auto *cast_kernel = - new (std::nothrow) kernel::CastOpenCLKernel(reinterpret_cast(param), inputs, outputs); + auto *cast_kernel = new (std::nothrow) + kernel::CastOpenCLKernel(reinterpret_cast(param), inputs, outputs, nullptr, nullptr); if (cast_kernel == nullptr) { MS_LOG(INFO) << " new kernel::CastOpenCLKernel failed "; for (auto tensor : inputs) { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/fill_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/fill_tests.cc index 917073ec64..1611c8ffac 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/fill_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/fill_tests.cc @@ -60,8 +60,8 @@ TEST_F(TestFillOpenCLCI, Fp32testfill) { return; } - auto *fill_kernel = - new (std::nothrow) kernel::FillOpenCLKernel(reinterpret_cast(param), inputs, outputs); + auto *fill_kernel = new (std::nothrow) + kernel::FillOpenCLKernel(reinterpret_cast(param), inputs, outputs, nullptr, nullptr); if (fill_kernel == nullptr) { MS_LOG(INFO) << " new kernel::FillOpenCLKernel failed "; delete param; @@ -116,8 +116,8 @@ TEST_F(TestFillOpenCLCI, Fp32testshape) { return; } - auto *fill_kernel = - new (std::nothrow) kernel::FillOpenCLKernel(reinterpret_cast(param), inputs, outputs); + auto *fill_kernel = new (std::nothrow) + kernel::FillOpenCLKernel(reinterpret_cast(param), inputs, outputs, nullptr, nullptr); if (fill_kernel == nullptr) { MS_LOG(INFO) << " new kernel::FillOpenCLKernel failed "; delete param; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/to_format_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/to_format_tests.cc index a5aaf26ca8..ffbd200a84 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/to_format_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/to_format_tests.cc @@ -58,7 +58,7 @@ TEST_F(TestToFormatOpenCL, ToFormatNHWC2NCHW) { } std::vector inputs{tensor_x}; std::vector outputs{tensor_out}; - auto arith_kernel_ptr = std::make_unique(nullptr, inputs, outputs); + auto arith_kernel_ptr = std::make_unique(nullptr, inputs, outputs, nullptr, nullptr); auto arith_kernel = arith_kernel_ptr.get(); if (arith_kernel == nullptr) { MS_LOG(ERROR) << "arith_kernel create error.";