| @@ -32,6 +32,9 @@ | |||
| #include "src/runtime/agent/npu/npu_manager.h" | |||
| #include "src/runtime/agent/npu/optimizer/npu_pass_manager.h" | |||
| #endif | |||
| #if SUPPORT_GPU | |||
| #include "src/runtime/kernel/opencl/opencl_subgraph.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -629,8 +632,16 @@ int LiteSession::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) | |||
| MS_LOG(ERROR) << "All node in graph should be sub_graph"; | |||
| return RET_ERROR; | |||
| } | |||
| auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel); | |||
| auto ret = sub_graph->ReSize(infer_shape_interrupt); | |||
| auto ret = RET_OK; | |||
| if (kernel->subgraph_type() == kernel::kGpuSubGraph) { | |||
| #if SUPPORT_GPU | |||
| auto sub_graph = reinterpret_cast<kernel::OpenCLSubGraph *>(kernel); | |||
| ret = sub_graph->ReSize(false); | |||
| #endif | |||
| } else { | |||
| auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel); | |||
| ret = sub_graph->ReSize(infer_shape_interrupt); | |||
| } | |||
| if (ret == RET_INFER_INVALID) { | |||
| MS_LOG(INFO) << "InferShape is interrupted"; | |||
| infer_shape_interrupt = true; | |||
| @@ -289,6 +289,35 @@ __kernel void BroadcastNHWC4Add(__read_only image2d_t input_a, __read_only image | |||
| WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result); | |||
| } | |||
| __kernel void BroadcastNHWC4BiasAdd(__read_only image2d_t input_a, __read_only image2d_t input_b, | |||
| __write_only image2d_t output, const int4 a_shape, const int4 b_shape, | |||
| const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) { | |||
| int X = get_global_id(0); // C4 | |||
| int Y = get_global_id(1); // W | |||
| int Z = get_global_id(2); // H | |||
| if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) { | |||
| return; | |||
| } | |||
| int a_c = X < a_shape.w ? X : a_shape.w - 1; | |||
| int a_w = Y < a_shape.z ? Y : a_shape.z - 1; | |||
| int a_h = Z < a_shape.y ? Z : a_shape.y - 1; | |||
| FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_h)); | |||
| int b_c = X < b_shape.w ? X : b_shape.w - 1; | |||
| int b_w = Y < b_shape.z ? Y : b_shape.z - 1; | |||
| int b_h = Z < b_shape.y ? Z : b_shape.y - 1; | |||
| FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_h)); | |||
| FLT4 result; | |||
| if (broadcastC_flag == 0) { | |||
| result = a + b; | |||
| } else if (broadcastC_flag == 1) { | |||
| result = a.x + b; | |||
| } else { | |||
| result = a + b.x; | |||
| } | |||
| result = clamp(result, (FLT)(act_min), (FLT)(act_max)); | |||
| WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result); | |||
| } | |||
| __kernel void BroadcastNHWC4Sub(__read_only image2d_t input_a, __read_only image2d_t input_b, | |||
| __write_only image2d_t output, const int4 a_shape, const int4 b_shape, | |||
| const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) { | |||
| @@ -463,9 +492,9 @@ __kernel void BroadcastNHWC4FloorMod(__read_only image2d_t input_a, __read_only | |||
| } | |||
| __kernel void BroadcastNHWC4SquaredDifference(__read_only image2d_t input_a, __read_only image2d_t input_b, | |||
| __write_only image2d_t output, const int4 a_shape, const int4 b_shape, | |||
| const int4 output_shape, const int broadcastC_flag, float act_min, | |||
| float act_max) { | |||
| __write_only image2d_t output, const int4 a_shape, const int4 b_shape, | |||
| const int4 output_shape, const int broadcastC_flag, float act_min, | |||
| float act_max) { | |||
| int X = get_global_id(0); // C4 | |||
| int Y = get_global_id(1); // w | |||
| int Z = get_global_id(2); // H | |||
| @@ -1,29 +0,0 @@ | |||
| #pragma OPENCL EXTENSION cl_khr_fp16 : enable | |||
| #define C4NUM 4 | |||
| #define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) | |||
| __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | |||
| __kernel void BiasAdd(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, | |||
| __read_only image2d_t alpha, const int data_type) { | |||
| int H = input_shape.y; | |||
| int C = input_shape.w; // channel size | |||
| C = UP_DIV(C, C4NUM); | |||
| if ((C == 0 || H == 0) && data_type != 1) { | |||
| return; | |||
| } | |||
| int Y = get_global_id(0); // height id | |||
| int X = get_global_id(1); // weight id | |||
| FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y)); | |||
| FLT4 tmp = in_c4; | |||
| int index = 0; | |||
| if (data_type == 1) { // NC | |||
| index = X; | |||
| } else if (data_type == 2) { // NHWC4 | |||
| index = X % C; | |||
| } else { // NC4HW4 | |||
| index = Y / H; | |||
| } | |||
| tmp += READ_IMAGE(alpha, smp_zero, (int2)(index, 0)); | |||
| WRITE_IMAGE(output, (int2)(X, Y), tmp); | |||
| } | |||
| @@ -28,8 +28,9 @@ namespace mindspore::kernel { | |||
| class ActivationOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| ActivationOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs), | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : OpenCLKernel(parameter, inputs, outputs, ctx, primitive), | |||
| type_(reinterpret_cast<ActivationParameter *>(parameter)->type_), | |||
| alpha_(reinterpret_cast<ActivationParameter *>(parameter)->alpha_) {} | |||
| ~ActivationOpenCLKernel() override = default; | |||
| @@ -25,9 +25,7 @@ namespace mindspore::kernel { | |||
| class ArgMinMaxOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| ArgMinMaxOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~ArgMinMaxOpenCLKernel() override = default; | |||
| @@ -34,6 +34,7 @@ using mindspore::lite::opencl::MemType; | |||
| using mindspore::schema::ActivationType_NO_ACTIVATION; | |||
| using mindspore::schema::ActivationType_RELU; | |||
| using mindspore::schema::ActivationType_RELU6; | |||
| using mindspore::schema::PrimitiveType_BiasAdd; | |||
| using mindspore::schema::PrimitiveType_Eltwise; | |||
| namespace mindspore::kernel { | |||
| @@ -180,6 +181,9 @@ int ArithmeticOpenCLKernel::Prepare() { | |||
| #else | |||
| auto *param = reinterpret_cast<const ArithmeticParameter *>(op_parameter_); | |||
| if (Type() == PrimitiveType_BiasAdd) { | |||
| const_cast<ArithmeticParameter *>(param)->broadcasting_ = true; | |||
| } | |||
| element_flag_ = !param->broadcasting_; | |||
| kernel_name_ = param->broadcasting_ ? "BroadcastNHWC4" : "Element"; | |||
| kernel_name_ += schema::EnumNamePrimitiveType(Type()); | |||
| @@ -237,6 +241,7 @@ REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_LessEqual, OpenCLKernelCreato | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Greater, OpenCLKernelCreator<ArithmeticOpenCLKernel>) | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_GreaterEqual, OpenCLKernelCreator<ArithmeticOpenCLKernel>) | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Eltwise, OpenCLKernelCreator<ArithmeticOpenCLKernel>) | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_BiasAdd, OpenCLKernelCreator<ArithmeticOpenCLKernel>) | |||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Mul, OpenCLKernelCreator<ArithmeticOpenCLKernel>) | |||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Add, OpenCLKernelCreator<ArithmeticOpenCLKernel>) | |||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Sub, OpenCLKernelCreator<ArithmeticOpenCLKernel>) | |||
| @@ -255,4 +260,5 @@ REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_LessEqual, OpenCLKernelCreato | |||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Greater, OpenCLKernelCreator<ArithmeticOpenCLKernel>) | |||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, OpenCLKernelCreator<ArithmeticOpenCLKernel>) | |||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Eltwise, OpenCLKernelCreator<ArithmeticOpenCLKernel>) | |||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_BiasAdd, OpenCLKernelCreator<ArithmeticOpenCLKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -29,9 +29,7 @@ extern std::set<schema::PrimitiveType> SupportedOpenCLArithmetics; | |||
| class ArithmeticOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| ArithmeticOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~ArithmeticOpenCLKernel() override = default; | |||
| int Run() override; | |||
| @@ -41,9 +41,7 @@ namespace mindspore::kernel { | |||
| class ArithmeticSelfOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| ArithmeticSelfOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~ArithmeticSelfOpenCLKernel() override = default; | |||
| @@ -25,9 +25,7 @@ namespace mindspore::kernel { | |||
| class BatchToSpaceNDOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| BatchToSpaceNDOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~BatchToSpaceNDOpenCLKernel() override = default; | |||
| @@ -25,9 +25,7 @@ namespace mindspore::kernel { | |||
| class BatchNormOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| BatchNormOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~BatchNormOpenCLKernel() override = default; | |||
| @@ -1,136 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/opencl/kernel/biasadd.h" | |||
| #include <string> | |||
| #include <map> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "src/runtime/kernel/opencl/cl/biasadd.cl.inc" | |||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_BiasAdd; | |||
| namespace mindspore::kernel { | |||
| int BiasAddOpenCLKernel::CheckSpecs() { | |||
| if (in_tensors_.size() != 2 || out_tensors_.size() != 1) { | |||
| MS_LOG(ERROR) << "Reshape in size: " << in_tensors_.size() << ", out size: " << out_tensors_.size(); | |||
| return RET_ERROR; | |||
| } | |||
| if (in_tensors_.size() == 0) { | |||
| MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << in_tensors_.size(); | |||
| return RET_ERROR; | |||
| } | |||
| if (in_tensors_[0]->shape()[0] > 1) { | |||
| MS_LOG(ERROR) << "Input data size unsupported multi-batch."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| void BiasAddOpenCLKernel::SetConstArgs() { | |||
| int arg_idx = 2; | |||
| std::map<schema::Format, int> data_type{ | |||
| {schema::Format::Format_NC4, 1}, {schema::Format::Format_NHWC4, 2}, {schema::Format::Format_NC4HW4, 3}}; | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_idx++, input_shape_); | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_idx++, BiasAdd_); | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_idx++, data_type[schema::Format::Format_NHWC4]); | |||
| } | |||
| void BiasAddOpenCLKernel::SetGlobalLocal() { | |||
| cl_int4 global_size = input_shape_; | |||
| global_size.s[2] = UP_DIV(global_size.s[3], C4NUM) * global_size.s[2]; | |||
| std::vector<size_t> local = {1, 1}; | |||
| std::vector<size_t> global = {static_cast<size_t>(global_size.s[1]), static_cast<size_t>(global_size.s[2])}; | |||
| OpenCLKernel::AlignGlobalLocal(global, local); | |||
| } | |||
| int BiasAddOpenCLKernel::InitWeights() { | |||
| int C = in_tensors_[1]->shape()[0]; | |||
| int div_ci = UP_DIV(C, C4NUM); | |||
| auto allocator = ocl_runtime_->GetAllocator(); | |||
| size_t img_dtype = CL_FLOAT; | |||
| if (enable_fp16_) { | |||
| img_dtype = CL_HALF_FLOAT; | |||
| } | |||
| std::vector<size_t> img_size{size_t(div_ci), 1, img_dtype}; | |||
| BiasAdd_ = allocator->Malloc(div_ci * C4NUM * fp_size, img_size); | |||
| BiasAdd_ = allocator->MapBuffer(BiasAdd_, CL_MAP_WRITE, nullptr, true); | |||
| memset(BiasAdd_, 0x00, div_ci * C4NUM * fp_size); | |||
| memcpy(BiasAdd_, in_tensors_[1]->data_c(), C * fp_size); | |||
| allocator->UnmapBuffer(BiasAdd_); | |||
| return RET_OK; | |||
| } | |||
| int BiasAddOpenCLKernel::Prepare() { | |||
| in_size_ = in_tensors_[0]->shape().size(); | |||
| out_size_ = out_tensors_[0]->shape().size(); | |||
| for (int i = 0; i < in_size_; ++i) { | |||
| input_shape_.s[i + 4 - in_size_] = in_tensors_[0]->shape()[i]; | |||
| } | |||
| enable_fp16_ = ocl_runtime_->GetFp16Enable(); | |||
| fp_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float); | |||
| if (in_size_ != 4 && in_size_ != 2) { | |||
| MS_LOG(ERROR) << "BiasAdd only support dim=4 or 2, but your dim=" << in_size_; | |||
| return mindspore::lite::RET_ERROR; | |||
| } | |||
| int C = in_tensors_[0]->shape()[3]; | |||
| int Bias_Size = in_tensors_[1]->shape()[0]; | |||
| if (UP_DIV(Bias_Size, C4NUM) != UP_DIV(C, C4NUM)) { | |||
| MS_LOG(ERROR) << "BiasAdd weight channel size:" << Bias_Size << " must be equal with in_teneors channel size:" << C; | |||
| return mindspore::lite::RET_ERROR; | |||
| } | |||
| InitWeights(); | |||
| std::string source = biasadd_source; | |||
| std::string program_name = "BiasAdd"; | |||
| std::string kernel_name = "BiasAdd"; | |||
| ocl_runtime_->LoadSource(program_name, source); | |||
| ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name); | |||
| auto ret = InitWeights(); | |||
| if (ret != RET_OK) { | |||
| return ret; | |||
| } | |||
| SetGlobalLocal(); | |||
| SetConstArgs(); | |||
| MS_LOG(DEBUG) << program_name << " Init Done!"; | |||
| return mindspore::lite::RET_OK; | |||
| } | |||
| int BiasAddOpenCLKernel::Run() { | |||
| ocl_runtime_->SetKernelArg(kernel_, 0, in_tensors_[0]->data_c()); | |||
| ocl_runtime_->SetKernelArg(kernel_, 1, out_tensors_[0]->data_c()); | |||
| auto ret = ocl_runtime_->RunKernel(kernel_, global_range_, local_range_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Run kernel " << op_parameter_->name_ << " error."; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_BiasAdd, OpenCLKernelCreator<BiasAddOpenCLKernel>) | |||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_BiasAdd, OpenCLKernelCreator<BiasAddOpenCLKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -1,55 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BIASADD_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BIASADD_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include "src/tensor.h" | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "schema/model_generated.h" | |||
| namespace mindspore::kernel { | |||
| class BiasAddOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| BiasAddOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| ~BiasAddOpenCLKernel() override = default; | |||
| int Prepare() override; | |||
| int CheckSpecs() override; | |||
| void SetConstArgs() override; | |||
| void SetGlobalLocal() override; | |||
| int InitWeights() override; | |||
| int Run() override; | |||
| private: | |||
| void *BiasAdd_{nullptr}; | |||
| int in_size_{}; | |||
| int out_size_{}; | |||
| size_t fp_size{}; | |||
| cl_int4 input_shape_{}; | |||
| bool enable_fp16_{}; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_BIASADD_H_ | |||
| @@ -26,9 +26,7 @@ namespace mindspore::kernel { | |||
| class CastOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| CastOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~CastOpenCLKernel() override = default; | |||
| int Prepare() override; | |||
| @@ -25,9 +25,7 @@ namespace mindspore::kernel { | |||
| class ConcatOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| ConcatOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~ConcatOpenCLKernel() override = default; | |||
| @@ -483,9 +483,22 @@ kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vector<lite::Tenso | |||
| kernel::OpenCLKernel *kernel; | |||
| OpParameter *real_param; | |||
| auto *conv_param = reinterpret_cast<ConvParameter *>(opParameter); | |||
| if (UseFcReplaceConv(inputs, outputs, conv_param)) { | |||
| bool infer_shape_done; | |||
| if (primitive != nullptr) { | |||
| infer_shape_done = primitive->infer_flag(); | |||
| } else { | |||
| bool output_shape_setted = true; | |||
| for (auto output : outputs) { | |||
| if (output->shape().empty() || output->ElementsNum() < 0) { | |||
| output_shape_setted = false; | |||
| break; | |||
| } | |||
| } | |||
| infer_shape_done = output_shape_setted; | |||
| } | |||
| if (infer_shape_done && UseFcReplaceConv(inputs, outputs, conv_param)) { | |||
| auto *fc_param = CreateFcParam(conv_param); | |||
| kernel = new (std::nothrow) FullConnectionOpenCLKernel(fc_param, inputs, outputs); | |||
| kernel = new (std::nothrow) FullConnectionOpenCLKernel(fc_param, inputs, outputs, ctx, primitive); | |||
| real_param = fc_param; | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "Create FullConnection kernel failed."; | |||
| @@ -497,11 +510,13 @@ kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vector<lite::Tenso | |||
| MS_LOG(INFO) << "use FullConnection to replace Convolution."; | |||
| } | |||
| } else { | |||
| if (UseWinograd4x4To6x6(conv_param, inputs, outputs)) { | |||
| if (infer_shape_done && UseWinograd4x4To6x6(conv_param, inputs, outputs)) { | |||
| MS_LOG(DEBUG) << "use Winograd algorithm."; | |||
| kernel = new (std::nothrow) WinogradOpenCLKernel(reinterpret_cast<OpParameter *>(conv_param), inputs, outputs); | |||
| kernel = new (std::nothrow) | |||
| WinogradOpenCLKernel(reinterpret_cast<OpParameter *>(conv_param), inputs, outputs, ctx, primitive); | |||
| } else { | |||
| kernel = new (std::nothrow) Conv2DOpenCLKernel(reinterpret_cast<OpParameter *>(conv_param), inputs, outputs); | |||
| kernel = new (std::nothrow) | |||
| Conv2DOpenCLKernel(reinterpret_cast<OpParameter *>(conv_param), inputs, outputs, ctx, primitive); | |||
| } | |||
| real_param = reinterpret_cast<OpParameter *>(conv_param); | |||
| if (kernel == nullptr) { | |||
| @@ -510,7 +525,10 @@ kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vector<lite::Tenso | |||
| return nullptr; | |||
| } | |||
| } | |||
| if (!infer_shape_done) { | |||
| MS_LOG(WARNING) << "kernel don't infer shape yet!"; | |||
| return kernel; | |||
| } | |||
| int ret = kernel->CheckSpecs(); | |||
| if (ret != mindspore::lite::RET_OK) { | |||
| MS_LOG(ERROR) << "Init Convolution kernel failed."; | |||
| @@ -44,8 +44,9 @@ void ConvertFilter(void *src, void *dst, TypeId src_dtype, TypeId dst_dtype, Fil | |||
| class Conv2DOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| Conv2DOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs), param_(reinterpret_cast<ConvParameter *>(parameter)) { | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : OpenCLKernel(parameter, inputs, outputs, ctx, primitive), param_(reinterpret_cast<ConvParameter *>(parameter)) { | |||
| bool is_adreno = ocl_runtime_->GetGpuInfo().type == lite::opencl::GpuType::ADRENO; | |||
| filter_type_ = is_adreno ? MemType::IMG : MemType::BUF; | |||
| } | |||
| @@ -40,9 +40,8 @@ int Conv2dTransposeOpenCLKernel::CheckSpecs() { | |||
| return RET_ERROR; | |||
| } | |||
| ConvParameter *param = reinterpret_cast<ConvParameter *>(op_parameter_); | |||
| if (param->pad_l_ != param->pad_r_ || param->kernel_h_ - param->stride_h_ != 2 * param->pad_l_ || | |||
| param->pad_u_ != param->pad_d_ || param->kernel_w_ - param->stride_w_ != 2 * param->pad_u_) { | |||
| MS_LOG(ERROR) << "only support kernel - stride == 2 * pad"; | |||
| if (param->pad_l_ != param->pad_r_ || param->pad_u_ != param->pad_d_) { | |||
| MS_LOG(ERROR) << "only support symmetric padding"; | |||
| return RET_ERROR; | |||
| } | |||
| if (param->act_type_ != ActType_No && param->act_type_ != ActType_Relu && param->act_type_ != ActType_Relu6) { | |||
| @@ -27,9 +27,7 @@ namespace mindspore::kernel { | |||
| class Conv2dTransposeOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| Conv2dTransposeOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~Conv2dTransposeOpenCLKernel() override = default; | |||
| int Run() override; | |||
| @@ -28,8 +28,9 @@ namespace mindspore::kernel { | |||
| class DepthwiseConv2dOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| DepthwiseConv2dOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) { | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : OpenCLKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| bool is_adreno = ocl_runtime_->GetGpuInfo().type == lite::opencl::GpuType::ADRENO; | |||
| filter_type_ = is_adreno ? MemType::IMG : MemType::BUF; | |||
| } | |||
| @@ -26,9 +26,7 @@ namespace mindspore::kernel { | |||
| class FillOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| FillOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~FillOpenCLKernel() override = default; | |||
| @@ -26,9 +26,7 @@ namespace mindspore::kernel { | |||
| class FullConnectionOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| FullConnectionOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~FullConnectionOpenCLKernel() override = default; | |||
| int Run() override; | |||
| @@ -153,9 +153,7 @@ bool IsEltwiseAndOperatorSupported(LiteKernel *node); | |||
| class FusionEltwiseOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| FusionEltwiseOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~FusionEltwiseOpenCLKernel() override { | |||
| if (op_parameter_ != nullptr) { | |||
| @@ -50,7 +50,7 @@ int GatherOpenCLKernel::CheckSpecs() { | |||
| return RET_ERROR; | |||
| } | |||
| int indices_ndim = in_tensors_.at(1)->shape().size(); | |||
| if (indices_ndim != 1) { | |||
| if (indices_ndim > 1) { | |||
| MS_LOG(ERROR) << "GatherOpenCLKernel only supports 1D indices Tensor but get " << indices_ndim << "D."; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -25,9 +25,7 @@ namespace mindspore::kernel { | |||
| class GatherOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| GatherOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~GatherOpenCLKernel() override = default; | |||
| @@ -24,9 +24,7 @@ namespace mindspore::kernel { | |||
| class LayerNormOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| LayerNormOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~LayerNormOpenCLKernel() override = default; | |||
| @@ -215,17 +215,34 @@ kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector<lite::Tensor *> | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| kernel::OpenCLKernel *kernel; | |||
| if (IsUseStrassenMatmul(inputs)) { | |||
| bool infer_shape_done; | |||
| if (primitive != nullptr) { | |||
| infer_shape_done = primitive->infer_flag(); | |||
| } else { | |||
| bool output_shape_setted = true; | |||
| for (auto output : outputs) { | |||
| if (output->shape().empty() || output->ElementsNum() < 0) { | |||
| output_shape_setted = false; | |||
| break; | |||
| } | |||
| } | |||
| infer_shape_done = output_shape_setted; | |||
| } | |||
| if (infer_shape_done && IsUseStrassenMatmul(inputs)) { | |||
| MS_LOG(DEBUG) << "use_matmul_strassen"; | |||
| kernel = new (std::nothrow) StrassenOpenCLKernel(opParameter, inputs, outputs); | |||
| kernel = new (std::nothrow) StrassenOpenCLKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| } else { | |||
| kernel = new (std::nothrow) MatMulOpenCLKernel(opParameter, inputs, outputs); | |||
| kernel = new (std::nothrow) MatMulOpenCLKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| } | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| if (!infer_shape_done) { | |||
| MS_LOG(WARNING) << "kernel don't infer shape yet!"; | |||
| return kernel; | |||
| } | |||
| auto ret = kernel->CheckSpecs(); | |||
| if (ret != mindspore::lite::RET_OK) { | |||
| MS_LOG(ERROR) << "Check " << opParameter->name_ << " specification failed!"; | |||
| @@ -28,9 +28,7 @@ namespace mindspore::kernel { | |||
| class MatMulOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| MatMulOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~MatMulOpenCLKernel() override = default; | |||
| int Run() override; | |||
| @@ -26,9 +26,7 @@ | |||
| namespace mindspore::kernel { | |||
| class OneHotOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| OneHotOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~OneHotOpenCLKernel() override = default; | |||
| int Run() override; | |||
| @@ -29,8 +29,10 @@ namespace mindspore::kernel { | |||
| class PadOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| PadOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs), param_(reinterpret_cast<PadParameter *>(op_parameter_)) {} | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : OpenCLKernel(parameter, inputs, outputs, ctx, primitive), | |||
| param_(reinterpret_cast<PadParameter *>(op_parameter_)) {} | |||
| ~PadOpenCLKernel() override = default; | |||
| int CheckSpecs() override; | |||
| @@ -27,8 +27,10 @@ namespace mindspore::kernel { | |||
| class PoolingOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| PoolingOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs), parameter_(reinterpret_cast<PoolingParameter *>(parameter)) {} | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : OpenCLKernel(parameter, inputs, outputs, ctx, primitive), | |||
| parameter_(reinterpret_cast<PoolingParameter *>(parameter)) {} | |||
| ~PoolingOpenCLKernel() override = default; | |||
| int Run() override; | |||
| @@ -25,9 +25,7 @@ namespace mindspore::kernel { | |||
| class PowerOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| PowerOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~PowerOpenCLKernel() override = default; | |||
| @@ -27,9 +27,7 @@ namespace mindspore::kernel { | |||
| class PReluOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| PReluOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~PReluOpenCLKernel() override = default; | |||
| int Prepare() override; | |||
| @@ -26,9 +26,7 @@ | |||
| namespace mindspore::kernel { | |||
| class ReduceOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| ReduceOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~ReduceOpenCLKernel() override = default; | |||
| int Run() override; | |||
| @@ -37,15 +37,16 @@ int ReshapeOpenCLKernel::CheckSpecs() { | |||
| MS_LOG(ERROR) << "Reshape input output size unsupported."; | |||
| return RET_ERROR; | |||
| } | |||
| if (in_tensors_[0]->data_type() != kNumberTypeFloat32 && in_tensors_[0]->data_type() != kNumberTypeFloat16) { | |||
| if (in_tensors_[0]->data_type() != kNumberTypeFloat32 && in_tensors_[0]->data_type() != kNumberTypeFloat16 && | |||
| in_tensors_[0]->data_type() != kNumberTypeInt32) { | |||
| MS_LOG(ERROR) << "Unsupported data type " << in_tensors_[0]->data_type(); | |||
| return RET_ERROR; | |||
| } | |||
| if (in_tensors_[0]->shape().size() == 0 || in_tensors_[0]->shape().size() > 4) { | |||
| MS_LOG(ERROR) << "Reshape input size should in 1-4, actual: " << in_tensors_[0]->shape(); | |||
| if (in_tensors_[0]->shape().size() > 4) { | |||
| MS_LOG(ERROR) << "Reshape input size should in 0-4, actual: " << in_tensors_[0]->shape(); | |||
| return RET_ERROR; | |||
| } | |||
| if (out_tensors_[0]->shape().size() == 0 || out_tensors_[0]->shape().size() > 4) { | |||
| if (out_tensors_[0]->shape().size() > 4) { | |||
| MS_LOG(ERROR) << "Reshape output size should in 1-4, actual: " << out_tensors_[0]->shape(); | |||
| return RET_ERROR; | |||
| } | |||
| @@ -95,6 +96,16 @@ int ReshapeOpenCLKernel::Run() { | |||
| return RET_OK; | |||
| } | |||
| int ReshapeOpenCLKernel::PreProcess() { | |||
| if (Type() == PrimitiveType_Reshape) { | |||
| auto shape_tensor = in_tensors_[1]; | |||
| if (!shape_tensor->IsConst()) { | |||
| ocl_runtime_->SyncCommandQueue(); | |||
| } | |||
| } | |||
| return OpenCLKernel::PreProcess(); | |||
| } | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Reshape, OpenCLKernelCreator<ReshapeOpenCLKernel>) | |||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Reshape, OpenCLKernelCreator<ReshapeOpenCLKernel>) | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Squeeze, OpenCLKernelCreator<ReshapeOpenCLKernel>) | |||
| @@ -25,9 +25,7 @@ | |||
| namespace mindspore::kernel { | |||
| class ReshapeOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| ReshapeOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~ReshapeOpenCLKernel() override = default; | |||
| int Run() override; | |||
| @@ -35,6 +33,7 @@ class ReshapeOpenCLKernel : public OpenCLKernel { | |||
| int CheckSpecs() override; | |||
| void SetConstArgs() override; | |||
| void SetGlobalLocal() override; | |||
| int PreProcess() override; | |||
| private: | |||
| }; | |||
| @@ -26,9 +26,7 @@ | |||
| namespace mindspore::kernel { | |||
| class ResizeOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| ResizeOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~ResizeOpenCLKernel() override = default; | |||
| int Run() override; | |||
| @@ -25,9 +25,7 @@ namespace mindspore::kernel { | |||
| class ScaleOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| ScaleOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~ScaleOpenCLKernel() override; | |||
| int CheckSpecs() override; | |||
| @@ -27,8 +27,9 @@ namespace mindspore::kernel { | |||
| class SoftmaxOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| SoftmaxOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) { | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : OpenCLKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| parameter_ = reinterpret_cast<SoftmaxParameter *>(parameter); | |||
| } | |||
| @@ -25,9 +25,7 @@ namespace mindspore::kernel { | |||
| class SpaceToBatchNDOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| SpaceToBatchNDOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~SpaceToBatchNDOpenCLKernel() override = default; | |||
| @@ -26,9 +26,7 @@ | |||
| namespace mindspore::kernel { | |||
| class SpaceToDepthOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| SpaceToDepthOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~SpaceToDepthOpenCLKernel() override = default; | |||
| int Run() override; | |||
| @@ -25,9 +25,7 @@ namespace mindspore::kernel { | |||
| class SparseToDenseOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| SparseToDenseOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~SparseToDenseOpenCLKernel() override = default; | |||
| @@ -25,9 +25,7 @@ namespace mindspore::kernel { | |||
| class SplitOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| SplitOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~SplitOpenCLKernel() override = default; | |||
| @@ -25,9 +25,7 @@ namespace mindspore::kernel { | |||
| class StackOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| explicit StackOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~StackOpenCLKernel() override{}; | |||
| int Prepare() override; | |||
| @@ -25,9 +25,7 @@ namespace mindspore::kernel { | |||
| class StrassenOpenCLKernel : public MatMulOpenCLKernel { | |||
| public: | |||
| StrassenOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : MatMulOpenCLKernel(parameter, inputs, outputs) {} | |||
| using MatMulOpenCLKernel::MatMulOpenCLKernel; | |||
| ~StrassenOpenCLKernel() override = default; | |||
| public: | |||
| @@ -25,9 +25,7 @@ namespace mindspore::kernel { | |||
| class StridedSliceOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| StridedSliceOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~StridedSliceOpenCLKernel() override = default; | |||
| @@ -106,6 +106,19 @@ int ToFormatOpenCLKernel::Run() { | |||
| return RET_OK; | |||
| } | |||
| int ToFormatOpenCLKernel::InferShape() { | |||
| if (infer_shape_flag_) { | |||
| return RET_OK; | |||
| } | |||
| if (in_tensors_[0]->shape().size() == 0 || in_tensors_[0]->ElementsNum() < 0) { | |||
| MS_LOG(ERROR) << "to_format op in tensor shape is 0, infer shape failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| out_tensors_[0]->set_shape(in_tensors_[0]->shape()); | |||
| infer_shape_flag_ = true; | |||
| return RET_OK; | |||
| } | |||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_ToFormat, OpenCLKernelCreator<ToFormatOpenCLKernel>) | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_ToFormat, OpenCLKernelCreator<ToFormatOpenCLKernel>) | |||
| } // namespace mindspore::kernel | |||
| @@ -25,9 +25,7 @@ | |||
| namespace mindspore::kernel { | |||
| class ToFormatOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| ToFormatOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~ToFormatOpenCLKernel() override = default; | |||
| int Run() override; | |||
| @@ -36,6 +34,7 @@ class ToFormatOpenCLKernel : public OpenCLKernel { | |||
| int CheckSpecs() override; | |||
| void SetConstArgs() override; | |||
| void SetGlobalLocal() override; | |||
| int InferShape() override; | |||
| private: | |||
| size_t N_{1}; | |||
| @@ -29,9 +29,7 @@ enum class TransposeType { AXIS0312, AXIS0231, GENERAL }; | |||
| class TransposeOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| TransposeOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| using OpenCLKernel::OpenCLKernel; | |||
| ~TransposeOpenCLKernel() override = default; | |||
| int Run() override; | |||
| @@ -26,8 +26,9 @@ namespace mindspore::kernel { | |||
| class WinogradOpenCLKernel : public Conv2DOpenCLKernel { | |||
| public: | |||
| WinogradOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : Conv2DOpenCLKernel(parameter, inputs, outputs) { | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : Conv2DOpenCLKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| filter_type_ = MemType::BUF; | |||
| } | |||
| @@ -463,8 +463,8 @@ void CreateEltwiseKernelReplaceOld(FusionEltwiseParameter *param, LiteKernel *ol | |||
| MS_ASSERT(old); | |||
| MS_ASSERT(nodes); | |||
| MS_ASSERT(removed_set); | |||
| auto *eltwise = new (std::nothrow) | |||
| FusionEltwiseOpenCLKernel(reinterpret_cast<OpParameter *>(param), old->in_tensors(), old->out_tensors()); | |||
| auto *eltwise = new (std::nothrow) FusionEltwiseOpenCLKernel(reinterpret_cast<OpParameter *>(param), | |||
| old->in_tensors(), old->out_tensors(), nullptr, nullptr); | |||
| if (eltwise == nullptr) { | |||
| MS_LOG(ERROR) << "create FusionEltwiseOpenCLKernel error."; | |||
| return; | |||
| @@ -539,6 +539,9 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::vector<LiteKernel *> *nodes, s | |||
| } // namespace | |||
| int OpenCLSubGraph::FusionPass() { | |||
| if (!this->IsSubGraphInferShapeDone()) { | |||
| return RET_OK; | |||
| } | |||
| MS_LOG(DEBUG) << "start Fusion"; | |||
| std::vector<LiteKernel *> input_nodes; | |||
| @@ -122,6 +122,41 @@ void OpenCLKernel::PrintOutput(int print_num, const std::string &out_file) { | |||
| } | |||
| } | |||
| int OpenCLKernel::PreProcess() { | |||
| auto ret = RET_OK; | |||
| ret = ReSize(); | |||
| if (ret != RET_OK) { | |||
| return ret; | |||
| } | |||
| auto allocator = ocl_runtime_->GetAllocator(); | |||
| for (auto i = 0; i < out_tensors_.size(); ++i) { | |||
| auto *output = out_tensors_.at(i); | |||
| MS_ASSERT(output); | |||
| if (GetMemType() == lite::opencl::MemType::IMG) { | |||
| std::vector<size_t> img_size; | |||
| ret = GetImageSize(i, &img_size); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "GetImageSize failed"; | |||
| return ret; | |||
| } | |||
| auto data_ptr = allocator->Malloc(output->Size(), img_size); | |||
| if (data_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc data failed"; | |||
| return RET_ERROR; | |||
| } | |||
| output->set_data(data_ptr); | |||
| } else { | |||
| ret = output->MallocData(allocator); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "MallocData failed"; | |||
| return ret; | |||
| } | |||
| } | |||
| output->set_allocator(allocator); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int OpenCLKernel::PostProcess() { | |||
| for (auto *output : this->out_tensors()) { | |||
| MS_ASSERT(output != nullptr); | |||
| @@ -130,6 +165,45 @@ int OpenCLKernel::PostProcess() { | |||
| return FreeInWorkTensor(); | |||
| } | |||
| int OpenCLKernel::InferShape() { | |||
| if (infer_shape_flag_) { | |||
| return RET_OK; | |||
| } | |||
| if (primitive_ == nullptr) { | |||
| return RET_ERROR; | |||
| } | |||
| (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->set_infer_flag(true); | |||
| auto ret = (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_); | |||
| if (ret != RET_OK) { | |||
| (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->set_infer_flag(false); | |||
| MS_LOG(ERROR) << "InferShape fail!"; | |||
| return ret; | |||
| } | |||
| infer_shape_flag_ = true; | |||
| return RET_OK; | |||
| } | |||
| int OpenCLKernel::ReSize() { | |||
| if (infer_shape_flag_) { | |||
| return RET_OK; | |||
| } | |||
| auto ret = InferShape(); | |||
| if (ret != RET_OK) { | |||
| return ret; | |||
| } | |||
| ret = CheckSpecs(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ReSize failed for check kernel specs!"; | |||
| return ret; | |||
| } | |||
| ret = Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ReSize failed for kernel prepare!"; | |||
| return ret; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| std::vector<BaseTuningParameter> OpenCLKernel::GenerateTuningParam() { | |||
| size_t ndim = global_size_.size(); | |||
| std::vector<BaseTuningParameter> tuning_params = {}; | |||
| @@ -156,17 +156,30 @@ struct BaseTuningParameter { | |||
| class OpenCLKernel : public LiteKernel { | |||
| public: | |||
| OpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : LiteKernel(parameter, inputs, outputs, nullptr, nullptr) { | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| ocl_runtime_ = ocl_runtime_wrap_.GetInstance(); | |||
| if (primitive != nullptr) { | |||
| infer_shape_flag_ = primitive->infer_flag(); | |||
| } else { | |||
| bool output_shape_setted = true; | |||
| for (auto output : outputs) { | |||
| if (output->shape().empty() || output->ElementsNum() < 0) { | |||
| output_shape_setted = false; | |||
| break; | |||
| } | |||
| } | |||
| infer_shape_flag_ = output_shape_setted; | |||
| } | |||
| } | |||
| ~OpenCLKernel() override = default; | |||
| int AlignGlobalLocal(const std::vector<size_t> &global, const std::vector<size_t> &local); | |||
| int Prepare() override { return RET_OK; } | |||
| int PreProcess() override { return RET_ERROR; } | |||
| int PreProcess() override; | |||
| int PostProcess() override; | |||
| int ReSize() override { return RET_ERROR; } | |||
| int ReSize() override; | |||
| int Run() override { return RET_ERROR; } | |||
| virtual int CheckSpecs() { return RET_ERROR; } | |||
| @@ -189,6 +202,9 @@ class OpenCLKernel : public LiteKernel { | |||
| double GetProfilingTimeMs(); | |||
| int DequantWeight(); | |||
| void FreeDequantedWeight(); | |||
| virtual int InferShape(); | |||
| bool GetInferShapeFlag() { return infer_shape_flag_; } | |||
| void SetInferShapeFlag(bool flag) { infer_shape_flag_ = flag; } | |||
| protected: | |||
| static std::set<size_t> GenerateLocalByGlobal(size_t global_i); | |||
| @@ -213,6 +229,7 @@ class OpenCLKernel : public LiteKernel { | |||
| cl::Event event_; | |||
| void *restore_quant_data_{nullptr}; | |||
| bool dequant_flag_{false}; | |||
| bool infer_shape_flag_{false}; | |||
| private: | |||
| lite::opencl::OpenCLRuntimeWrapper ocl_runtime_wrap_; | |||
| @@ -223,12 +240,16 @@ kernel::LiteKernel *OpenCLKernelCreator(const std::vector<lite::Tensor *> &input | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| auto *kernel = new (std::nothrow) T(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||
| auto *kernel = new (std::nothrow) T(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| if (!reinterpret_cast<kernel::OpenCLKernel *>(kernel)->GetInferShapeFlag()) { | |||
| MS_LOG(WARNING) << "kernel don't infer shape yet!"; | |||
| return kernel; | |||
| } | |||
| auto ret = kernel->CheckSpecs(); | |||
| if (ret != mindspore::lite::RET_OK) { | |||
| MS_LOG(ERROR) << "Check " << opParameter->name_ << " specification failed!"; | |||
| @@ -171,6 +171,8 @@ int OpenCLSubGraph::GenToFormatOp(const std::vector<lite::Tensor *> &in_tensors, | |||
| parameter = nullptr; | |||
| return RET_ERROR; | |||
| } | |||
| static int index = 0; | |||
| in_convert_op->set_name("ToFormat_" + std::to_string(index)); | |||
| ReplaceOutTensorAndKernelToConvert(in_tensor, in_kernels.at(i), new_tensor, in_convert_op, mem_type); | |||
| @@ -302,16 +304,36 @@ void OpenCLSubGraph::GetInOutNodes() { | |||
| } | |||
| } | |||
| bool OpenCLSubGraph::IsSubGraphInferShapeDone() { | |||
| for (auto node : this->nodes_) { | |||
| auto opencl_kernel = reinterpret_cast<kernel::OpenCLKernel *>(node); | |||
| if (!opencl_kernel->GetInferShapeFlag()) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| int OpenCLSubGraph::Prepare() { | |||
| executor_ = new (std::nothrow) lite::opencl::OpenCLExecutor(); | |||
| if (executor_ == nullptr) { | |||
| MS_LOG(ERROR) << "Create OpenCLExecutor fail"; | |||
| return RET_ERROR; | |||
| } | |||
| auto ret = SubGraphKernel::Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "OpenCL prepare fail"; | |||
| return ret; | |||
| auto ret = RET_OK; | |||
| for (auto node : this->nodes_) { | |||
| if (node == nullptr) { | |||
| MS_LOG(ERROR) << "node in Subgraph is nullptr"; | |||
| return mindspore::lite::RET_NULL_PTR; | |||
| } | |||
| auto opencl_kernel = reinterpret_cast<kernel::OpenCLKernel *>(node); | |||
| if (opencl_kernel->GetInferShapeFlag()) { | |||
| ret = node->Prepare(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "prepare node " << node->name() << " failed"; | |||
| return ret; | |||
| } | |||
| } | |||
| } | |||
| auto opencl_exec = reinterpret_cast<lite::opencl::OpenCLExecutor *>(executor_); | |||
| // If tuning_mode is DEFAULT, just malloc memory for reuse. | |||
| @@ -341,7 +363,40 @@ void OpenCLSubGraph::UnInit() { | |||
| delete this->executor_; | |||
| } | |||
| int OpenCLSubGraph::ReSize() { return RET_OK; } | |||
| int OpenCLSubGraph::ReSize() { return ReSize(false); } | |||
| int OpenCLSubGraph::ReSize(bool interrupt) { | |||
| for (auto kernel : nodes_) { | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "input kernel is nullptr!"; | |||
| return RET_ERROR; | |||
| } | |||
| auto opencl_kernel = reinterpret_cast<kernel::OpenCLKernel *>(kernel); | |||
| if (kernel->subgraph_type() != kernel::kNotSubGraph) { | |||
| MS_LOG(ERROR) << "all nodes in should be kernel"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<lite::Tensor *> inputs = kernel->in_tensors(); | |||
| std::vector<lite::Tensor *> outputs = kernel->out_tensors(); | |||
| for (auto &output : outputs) { | |||
| output->FreeData(); | |||
| } | |||
| opencl_kernel->SetInferShapeFlag(false); | |||
| } | |||
| for (auto kernel : nodes_) { | |||
| auto opencl_kernel = reinterpret_cast<kernel::OpenCLKernel *>(kernel); | |||
| auto ret = opencl_kernel->ReSize(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(WARNING) << "ReSize " << opencl_kernel->name() << "failed!"; | |||
| if (interrupt) { | |||
| return ret; | |||
| } else { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int OpenCLSubGraph::Run() { | |||
| if (executor_ == nullptr) { | |||
| @@ -44,9 +44,11 @@ class OpenCLSubGraph : public SubGraphKernel { | |||
| int Prepare() override; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int ReSize(bool interrupt); | |||
| int Run() override; | |||
| int Run(const KernelCallBack &before, const KernelCallBack &after) override { return this->Run(); }; | |||
| int InsertOpsPass(); | |||
| bool IsSubGraphInferShapeDone(); | |||
| private: | |||
| void UnInit(); | |||
| @@ -58,7 +58,8 @@ const std::set<schema::PrimitiveType> ArithmeticPrimitives = {schema::PrimitiveT | |||
| schema::PrimitiveType_LessEqual, | |||
| schema::PrimitiveType_Greater, | |||
| schema::PrimitiveType_GreaterEqual, | |||
| schema::PrimitiveType_Eltwise}; | |||
| schema::PrimitiveType_Eltwise, | |||
| schema::PrimitiveType_BiasAdd}; | |||
| const std::set<schema::PrimitiveType> ArithmeticSelfPrimitives = { | |||
| schema::PrimitiveType_Abs, schema::PrimitiveType_Ceil, schema::PrimitiveType_Cos, | |||
| @@ -47,31 +47,16 @@ int OpenCLExecutor::RunOrTune(std::vector<Tensor *> &inputs, std::vector<Tensor | |||
| } | |||
| } | |||
| auto *op_kernel = reinterpret_cast<kernel::OpenCLKernel *>(kernel); | |||
| auto cur_outputs = kernel->out_tensors(); | |||
| for (auto i = 0; i < cur_outputs.size(); ++i) { | |||
| auto *output = cur_outputs.at(i); | |||
| MS_ASSERT(output); | |||
| if (op_kernel->GetMemType() == lite::opencl::MemType::IMG) { | |||
| std::vector<size_t> img_size; | |||
| ret = op_kernel->GetImageSize(i, &img_size); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "GetImageSize failed"; | |||
| return ret; | |||
| } | |||
| auto data_ptr = allocator_->Malloc(output->Size(), img_size); | |||
| if (data_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc data failed"; | |||
| return RET_ERROR; | |||
| } | |||
| output->set_data(data_ptr); | |||
| ret = kernel->PreProcess(); | |||
| if (RET_OK != ret) { | |||
| if (is_tune) { | |||
| MS_LOG(WARNING) << "PreProcess kernel failed, name: " << kernel->name() << " in tuning"; | |||
| opencl_runtime_ins->SetProfiling(profiling_tmp); | |||
| return RET_OK; | |||
| } else { | |||
| ret = output->MallocData(allocator_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "MallocData failed"; | |||
| return ret; | |||
| } | |||
| MS_LOG(ERROR) << "PreProcess kernel failed, name: " << kernel->name(); | |||
| return ret; | |||
| } | |||
| output->set_allocator(allocator_); | |||
| } | |||
| if (is_tune) { | |||
| ret = op_kernel->Tune(); | |||
| @@ -79,26 +64,20 @@ int OpenCLExecutor::RunOrTune(std::vector<Tensor *> &inputs, std::vector<Tensor | |||
| MS_LOG(ERROR) << "tuning kernel failed, name: " << kernel->name(); | |||
| return ret; | |||
| } | |||
| ret = kernel->PostProcess(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "PostProcess kernel failed, name: " << kernel->name(); | |||
| return ret; | |||
| } | |||
| } else { | |||
| ret = kernel->Run(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name(); | |||
| return ret; | |||
| } | |||
| ret = kernel->PostProcess(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "PostProcess kernel failed, name: " << kernel->name(); | |||
| return ret; | |||
| } | |||
| if (profiling_tmp) { | |||
| if (profiling_tmp) | |||
| MS_LOG(INFO) << "OpenCl kernel " << kernel->name() << "(" << kernel->type_str() | |||
| << ") execute time is: " << op_kernel->GetProfilingTimeMs() << "ms"; | |||
| } | |||
| } | |||
| ret = kernel->PostProcess(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "PostProcess kernel failed, name: " << kernel->name(); | |||
| return ret; | |||
| } | |||
| if (after != nullptr) { | |||
| if (!after(TensorVectorCast(kernel->in_tensors()), TensorVectorCast(kernel->out_tensors()), callbackParam)) { | |||
| @@ -187,7 +187,9 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in | |||
| kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())}; | |||
| #if SUPPORT_GPU | |||
| if (context_->IsGpuEnabled()) { | |||
| kernel::KernelKey gpu_desc{kGPU, desc.data_type, desc.type}; | |||
| // support more data type like int32 | |||
| kernel::KernelKey gpu_desc{kGPU, kNumberTypeFloat32, desc.type}; | |||
| if (context_->IsGpuFloat16Enabled()) gpu_desc.data_type = kNumberTypeFloat16; | |||
| auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, gpu_desc); | |||
| if (kernel != nullptr) { | |||
| MS_LOG(DEBUG) << "Get gpu op success: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " " << node->name_; | |||
| @@ -1,205 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include "src/common/log_adapter.h" | |||
| #include "common/common_test.h" | |||
| #include "mindspore/lite/src/common/file_utils.h" | |||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.h" | |||
| using mindspore::kernel::BiasAddOpenCLKernel; | |||
| using mindspore::kernel::LiteKernel; | |||
| using mindspore::kernel::OpenCLSubGraph; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| namespace mindspore { | |||
| // PrimitiveType_BiasAdd: src/ops/populate/bias_add_populate.cc | |||
| class TestBiasAddOpenCL : public CommonTest {}; | |||
| void LoadDataBiasAdd(void *dst, size_t dst_size, const std::string &file_path) { | |||
| if (file_path.empty()) { | |||
| memset(dst, 0x00, dst_size); | |||
| } else { | |||
| auto src_data = mindspore::lite::ReadFile(file_path.c_str(), &dst_size); | |||
| memcpy(dst, src_data, dst_size); | |||
| } | |||
| } | |||
| template <typename T> | |||
| void CompareOutBiasAdd(lite::Tensor *output_tensor, const std::string &standard_answer_file) { | |||
| size_t output_size = output_tensor->ElementsNum(); | |||
| auto output_data = reinterpret_cast<T *>(output_tensor->data_c()); | |||
| auto expect_data = reinterpret_cast<T *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size)); | |||
| constexpr float atol = 0.0002; | |||
| for (int i = 0; i < output_tensor->ElementsNum(); ++i) { | |||
| if (std::fabs(output_data[i] - expect_data[i]) > atol) { | |||
| printf("error at idx[%d] expect=%f output=%f\n", i, expect_data[i], output_data[i]); | |||
| printf("error at idx[%d] expect=%f output=%f\n", i, expect_data[i], output_data[i]); | |||
| printf("error at idx[%d] expect=%f output=%f\n\n\n", i, expect_data[i], output_data[i]); | |||
| return; | |||
| } | |||
| } | |||
| printf("compare success!\n"); | |||
| printf("compare success!\n"); | |||
| printf("compare success!\n\n\n"); | |||
| } | |||
| template <typename T> | |||
| void printf_tensor_BiasAdd(const std::string log, mindspore::lite::Tensor *in_data, int size) { | |||
| MS_LOG(INFO) << log; | |||
| auto input_data = reinterpret_cast<T *>(in_data->data_c()); | |||
| for (int i = 0; i < size; ++i) { | |||
| printf("%f ", input_data[i]); | |||
| } | |||
| printf("\n"); | |||
| MS_LOG(INFO) << "Print tensor done"; | |||
| } | |||
| TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) { | |||
| std::string in_file = "/data/local/tmp/in_data.bin"; | |||
| std::string weight_file = "/data/local/tmp/weight_data.bin"; | |||
| std::string standard_answer_file = "/data/local/tmp/biasadd.bin"; | |||
| MS_LOG(INFO) << "BiasAdd Begin test:"; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); | |||
| ocl_runtime->Init(); | |||
| auto data_type = kNumberTypeFloat16; | |||
| ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16); | |||
| std::vector<int> input_shape = {1, 9}; | |||
| std::vector<int> output_shape = {1, 9}; | |||
| auto tensor_type = lite::Tensor::CONST_TENSOR; | |||
| schema::Format type = schema::Format_NC; | |||
| int weight_shape = 0; | |||
| if (input_shape.size() == 4) { | |||
| weight_shape = input_shape[3]; | |||
| } else { | |||
| weight_shape = input_shape[1]; | |||
| } | |||
| auto *input_tensor = new (std::nothrow) lite::Tensor(data_type, input_shape, type, tensor_type); | |||
| if (input_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "new input tensor error!"; | |||
| return; | |||
| } | |||
| auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, type, tensor_type); | |||
| if (output_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "new output tensor error!"; | |||
| delete input_tensor; | |||
| return; | |||
| } | |||
| auto *weight_tensor = | |||
| new (std::nothrow) lite::Tensor(data_type, std::vector<int>{weight_shape}, schema::Format_NHWC, tensor_type); | |||
| if (weight_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "new weight tensor error!"; | |||
| delete output_tensor; | |||
| delete input_tensor; | |||
| return; | |||
| } | |||
| std::vector<lite::Tensor *> inputs{input_tensor, weight_tensor}; | |||
| std::vector<lite::Tensor *> outputs{output_tensor}; | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| inputs[0]->MallocData(allocator); | |||
| inputs[1]->MallocData(allocator); | |||
| LoadDataBiasAdd(input_tensor->data_c(), input_tensor->Size(), in_file); | |||
| LoadDataBiasAdd(weight_tensor->data_c(), weight_tensor->Size(), weight_file); | |||
| if (ocl_runtime->GetFp16Enable()) { | |||
| printf_tensor_BiasAdd<float16_t>("BiasAdd:FP16--input data", inputs[0], input_tensor->ElementsNum()); | |||
| printf_tensor_BiasAdd<float16_t>("BiasAdd:FP16--weight data", inputs[1], weight_tensor->ElementsNum()); | |||
| } else { | |||
| printf_tensor_BiasAdd<float>("BiasAdd:FP32--input data", inputs[0], input_tensor->ElementsNum()); | |||
| printf_tensor_BiasAdd<float>("BiasAdd:FP32--weight data", inputs[1], weight_tensor->ElementsNum()); | |||
| } | |||
| auto *param = new (std::nothrow) OpParameter(); | |||
| if (param == nullptr) { | |||
| delete input_tensor; | |||
| delete output_tensor; | |||
| delete weight_tensor; | |||
| MS_LOG(ERROR) << "new OpParameter error!"; | |||
| return; | |||
| } | |||
| auto *biasadd_kernel = | |||
| new (std::nothrow) kernel::BiasAddOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| if (biasadd_kernel == nullptr) { | |||
| MS_LOG(ERROR) << "Create biasadd kernel error."; | |||
| delete input_tensor; | |||
| delete output_tensor; | |||
| delete weight_tensor; | |||
| delete param; | |||
| return; | |||
| } | |||
| auto ret = biasadd_kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "biasadd kernel init error."; | |||
| delete input_tensor; | |||
| delete output_tensor; | |||
| delete weight_tensor; | |||
| delete param; | |||
| delete biasadd_kernel; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "initialize sub_graph"; | |||
| std::vector<kernel::LiteKernel *> kernels{biasadd_kernel}; | |||
| auto *sub_graph = new (std::nothrow) kernel::OpenCLSubGraph({input_tensor}, outputs, kernels, kernels, kernels); | |||
| if (sub_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Create sub_graph kernel error."; | |||
| delete input_tensor; | |||
| delete output_tensor; | |||
| delete weight_tensor; | |||
| delete param; | |||
| delete biasadd_kernel; | |||
| return; | |||
| } | |||
| ret = sub_graph->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "sub_graph init error."; | |||
| delete input_tensor; | |||
| delete output_tensor; | |||
| delete weight_tensor; | |||
| delete sub_graph; | |||
| delete param; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Sub graph begin running!"; | |||
| ret = sub_graph->Run(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "sub_graph run error."; | |||
| delete input_tensor; | |||
| delete output_tensor; | |||
| delete weight_tensor; | |||
| delete sub_graph; | |||
| delete param; | |||
| return; | |||
| } | |||
| if (ocl_runtime->GetFp16Enable()) { | |||
| printf_tensor_BiasAdd<float16_t>("BiasAdd:FP16--output data", outputs[0], output_tensor->ElementsNum()); | |||
| CompareOutBiasAdd<float16_t>(output_tensor, standard_answer_file); | |||
| } else { | |||
| printf_tensor_BiasAdd<float>("BiasAdd:FP32--output data", outputs[0], output_tensor->ElementsNum()); | |||
| CompareOutBiasAdd<float>(output_tensor, standard_answer_file); | |||
| } | |||
| delete input_tensor; | |||
| delete weight_tensor; | |||
| delete output_tensor; | |||
| delete sub_graph; | |||
| delete param; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -73,8 +73,8 @@ TEST_F(TestCastSelfOpenCL, Castfp32tofp16) { | |||
| std::vector<lite::Tensor *> inputs{input_tensor}; | |||
| std::vector<lite::Tensor *> outputs{output_tensor}; | |||
| auto *cast_kernel = | |||
| new (std::nothrow) kernel::CastOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| auto *cast_kernel = new (std::nothrow) | |||
| kernel::CastOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs, nullptr, nullptr); | |||
| if (cast_kernel == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::CastOpenCLKernel failed "; | |||
| for (auto tensor : inputs) { | |||
| @@ -159,8 +159,8 @@ TEST_F(TestCastSelfOpenCL, Castfp16tofp32) { | |||
| std::vector<lite::Tensor *> inputs{input_tensor}; | |||
| std::vector<lite::Tensor *> outputs{output_tensor}; | |||
| auto *cast_kernel = | |||
| new (std::nothrow) kernel::CastOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| auto *cast_kernel = new (std::nothrow) | |||
| kernel::CastOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs, nullptr, nullptr); | |||
| if (cast_kernel == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::CastOpenCLKernel failed "; | |||
| for (auto tensor : inputs) { | |||
| @@ -60,8 +60,8 @@ TEST_F(TestFillOpenCLCI, Fp32testfill) { | |||
| return; | |||
| } | |||
| auto *fill_kernel = | |||
| new (std::nothrow) kernel::FillOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| auto *fill_kernel = new (std::nothrow) | |||
| kernel::FillOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs, nullptr, nullptr); | |||
| if (fill_kernel == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::FillOpenCLKernel failed "; | |||
| delete param; | |||
| @@ -116,8 +116,8 @@ TEST_F(TestFillOpenCLCI, Fp32testshape) { | |||
| return; | |||
| } | |||
| auto *fill_kernel = | |||
| new (std::nothrow) kernel::FillOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| auto *fill_kernel = new (std::nothrow) | |||
| kernel::FillOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs, nullptr, nullptr); | |||
| if (fill_kernel == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::FillOpenCLKernel failed "; | |||
| delete param; | |||
| @@ -58,7 +58,7 @@ TEST_F(TestToFormatOpenCL, ToFormatNHWC2NCHW) { | |||
| } | |||
| std::vector<lite::Tensor *> inputs{tensor_x}; | |||
| std::vector<lite::Tensor *> outputs{tensor_out}; | |||
| auto arith_kernel_ptr = std::make_unique<kernel::ToFormatOpenCLKernel>(nullptr, inputs, outputs); | |||
| auto arith_kernel_ptr = std::make_unique<kernel::ToFormatOpenCLKernel>(nullptr, inputs, outputs, nullptr, nullptr); | |||
| auto arith_kernel = arith_kernel_ptr.get(); | |||
| if (arith_kernel == nullptr) { | |||
| MS_LOG(ERROR) << "arith_kernel create error."; | |||