@@ -168,8 +168,9 @@ int ConcatOpenCLKernel::ConvertWeightToTensor() {
     if (in_tensor->IsConst()) {
       std::vector<char> weight(in_shape.Image2DSize, 0);
       bool src_is_fp16 = in_tensor->data_type() == kNumberTypeFloat16;
-      PackNHWCToNHWC4(in_tensor->data_c(), weight.data(), src_is_fp16, fp16_enable, in_shape);
-      size_t dtype = fp16_enable ? CL_HALF_FLOAT : CL_FLOAT;
+      PackNHWCToNHWC4(in_tensor->data_c(), weight.data(), src_is_fp16,
+                      fp16_enable && in_tensor->data_type() != kNumberTypeInt32, in_shape);
+      size_t dtype = fp16_enable && in_tensor->data_type() != kNumberTypeInt32 ? CL_HALF_FLOAT : CL_FLOAT;
       ImageSize img_size{in_shape.width, in_shape.height, dtype};
       auto weight_ptr_ = allocator->Malloc(img_size, weight.data());
       weight_ptrs_.push_back(weight_ptr_);
@@ -206,7 +207,7 @@ int ConcatOpenCLKernel::Prepare() {
   std::string source = concat_source;
   std::string program_name = "Concat";
   ocl_runtime_->LoadSource(program_name, source);
-  ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name);
+  ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, {}, out_tensors_[0]->data_type());
   MS_LOG(DEBUG) << kernel_name << " Init Done!";
   SetConstArgs();
   SetGlobalLocal();
@@ -42,10 +42,6 @@ int Conv2dTransposeOpenCLKernel::CheckSpecs() {
     return RET_ERROR;
   }
   ConvParameter *param = reinterpret_cast<ConvParameter *>(op_parameter_);
-  if (param->pad_l_ != param->pad_r_ || param->pad_u_ != param->pad_d_) {
-    MS_LOG(ERROR) << "only support symmetric padding";
-    return RET_ERROR;
-  }
   if (param->act_type_ != ActType_No && param->act_type_ != ActType_Relu && param->act_type_ != ActType_Relu6) {
     MS_LOG(ERROR) << "Unsupported activation type " << param->act_type_;
     return RET_ERROR;
@@ -102,8 +98,8 @@ void Conv2dTransposeOpenCLKernel::SetConstArgs() {
   int co = out_tensors_[0]->shape()[3];
   int kh = param->kernel_h_;
   int kw = param->kernel_w_;
-  int pad_h = param->pad_l_;
-  int pad_w = param->pad_u_;
+  int pad_h = param->pad_u_;
+  int pad_w = param->pad_l_;
   int stride_h = param->stride_h_;
   int stride_w = param->stride_w_;
   int oh = out_tensors_[0]->shape()[1];
@@ -155,7 +155,7 @@ class FusionEltwiseOpenCLKernel : public OpenCLKernel {
   ~FusionEltwiseOpenCLKernel() override {
     if (op_parameter_ != nullptr) {
-      delete op_parameter_;
+      delete reinterpret_cast<FusionEltwiseParameter *>(op_parameter_);
       op_parameter_ = nullptr;
     }
   }
@@ -107,7 +107,7 @@ int GatherOpenCLKernel::Prepare() {
 #else
   std::string program_name = "gather";
   ocl_runtime_->LoadSource(program_name, gather_source);
-  ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name);
+  ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, {}, out_tensors_[0]->data_type());
 #endif
   if (in_tensors_.at(1)->IsConst()) {
     intensor1_is_tensor = false;
@@ -79,7 +79,7 @@ int ReshapeOpenCLKernel::Prepare() {
   std::string source = reshape_source;
   std::string program_name = "reshape";
   ocl_runtime_->LoadSource(program_name, source);
-  ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name);
+  ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, {}, out_tensors_[0]->data_type());
 #endif
   SetGlobalLocal();
@@ -172,7 +172,7 @@ int OpenCLSubGraph::GenToFormatOp(const std::vector<lite::Tensor *> &in_tensors,
       return RET_ERROR;
     }
     static int index = 0;
-    in_convert_op->set_name("ToFormat_" + std::to_string(index));
+    in_convert_op->set_name("ToFormat_" + std::to_string(index++));
     ReplaceOutTensorAndKernelToConvert(in_tensor, in_kernels.at(i), new_tensor, in_convert_op, mem_type);
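The post-increment is what makes each generated converter name unique; without it every inserted op would be called "ToFormat_0". A tiny standalone sketch of the naming scheme (illustrative helper, not the subgraph's actual API):

#include <string>

// Illustrative only: a function-local static counter, post-incremented on each
// call, yields "ToFormat_0", "ToFormat_1", ... across the process lifetime.
std::string NextToFormatName() {
  static int index = 0;
  return "ToFormat_" + std::to_string(index++);
}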
@@ -318,6 +318,14 @@ bool OpenCLSubGraph::IsSubGraphInferShapeDone() {
 }

 int OpenCLSubGraph::Prepare() {
+  for (const auto tensor : in_tensors_) {
+    MS_ASSERT(tensor);
+    tensor->set_allocator(allocator_);
+  }
+  for (const auto tensor : out_tensors_) {
+    MS_ASSERT(tensor);
+    tensor->set_allocator(allocator_);
+  }
   executor_ = new (std::nothrow) lite::opencl::OpenCLExecutor();
   if (executor_ == nullptr) {
     MS_LOG(ERROR) << "Create OpenCLExecutor fail";
@@ -363,9 +363,9 @@ bool OpenCLRuntime::SetFp16Enable(bool enable) {
 }

 int OpenCLRuntime::BuildKernel(cl::Kernel &kernel, const std::string &program_name, const std::string &kernel_name,
-                               const std::vector<std::string> &build_options_ext) {
+                               const std::vector<std::string> &build_options_ext, TypeId data_type) {
   std::string build_option = default_build_option_;
-  if (fp16_enable_) {
+  if (fp16_enable_ && data_type != kNumberTypeInt32) {
     build_option +=
       " -DFP16_ENABLE=1 -DFLT=half -DFLT4=half4 -DFLT16=half16 -DAS_FLT4=as_half4 -DAS_UINT4=as_ushort4 -DUINT4=ushort4"
       " -DWRITE_IMAGE=write_imageh -DREAD_IMAGE=read_imageh -DTO_FLT=convert_half -DTO_FLT4=convert_half4";
@@ -25,6 +25,7 @@ j* you may not use this file except in compliance with the License.
 #include <string>
 #include <utility>
 #include <type_traits>
+#include "dtype/type_id.h"
 #include "src/common/log_adapter.h"
 #include "src/runtime/opencl/opencl_wrapper.h"
 #include "src/runtime/opencl/opencl_allocator.h"
@@ -118,7 +119,7 @@ class OpenCLRuntime {
   std::vector<unsigned char> GetProgramBinary(const cl::Program &program);
   bool LoadSource(const std::string &program_name, const std::string &source);
   int BuildKernel(cl::Kernel &kernel, const std::string &program_name, const std::string &kernel_name,
-                  const std::vector<std::string> &build_options_ext = {});
+                  const std::vector<std::string> &build_options_ext = {}, TypeId data_type = kNumberTypeFloat32);
   int RunKernel(const cl::Kernel &kernel, const cl::NDRange &global, const cl::NDRange &local,
                 cl::CommandQueue *command_queue = nullptr, cl::Event *event = nullptr);
   int ReadOrWriteImage(void *buffer, void *data, bool is_read);
@@ -22,7 +22,7 @@ class TestOpenCL_Conv2dTranspose : public CommonTest {};
 namespace {
 // PrimitiveType_DeConv2D: src/ops/populate/deconv2d_populate.cc
-OpParameter *CreateParameter(int n, int h, int w, int ci, int co, int kh, int kw, int pad,
+OpParameter *CreateParameter(int n, int h, int w, int ci, int co, int kh, int kw, std::vector<int> pad, int oh, int ow,
                              std::vector<int> *input_shape, std::vector<int> *weight_shape,
                              std::vector<int> *bias_shape, std::vector<int> *output_shape) {
   auto *param = test::CreateParameter<ConvParameter>(schema::PrimitiveType_DeConv2D);
@@ -30,16 +30,15 @@ OpParameter *CreateParameter(int n, int h, int w, int ci, int co, int kh, int kw
   param->kernel_h_ = kh;
   param->kernel_w_ = kw;
   param->stride_h_ = 2;
   param->stride_w_ = 2;
-  param->pad_u_ = pad;
-  param->pad_d_ = pad;
-  param->pad_l_ = pad;
-  param->pad_r_ = pad;
+  MS_ASSERT(pad.size() == 4);
+  param->pad_u_ = pad[0];
+  param->pad_d_ = pad[1];
+  param->pad_l_ = pad[2];
+  param->pad_r_ = pad[3];
   param->dilation_h_ = 1;
   param->dilation_w_ = 1;
   param->act_type_ = ActType_No;
-  int oh = 2 * h - 1 + 2 * (kh - 1 - pad) - kh + 1;
-  int ow = 2 * w - 1 + 2 * (kw - 1 - pad) - kw + 1;
   *input_shape = {n, h, w, ci};
   *weight_shape = {co, kh, kw, ci};
   *bias_shape = {co};
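The removed oh/ow computation only holds for stride 2 with a single symmetric pad value, which is why the tests now pass the expected output size explicitly alongside a per-side pad vector. As a worked check of the old relation against test0 (a standalone snippet, not library code):

#include <cassert>

// Old formula that lived in CreateParameter, valid only for stride 2 and one
// symmetric pad value:
//   oh = 2 * h - 1 + 2 * (kh - 1 - pad) - kh + 1
// For test0 (h = w = 2, kh = kw = 2, pad = 0) it yields 4, matching the explicit
// oh = ow = 4 now passed in. With asymmetric pads such as {0, 1, 0, 1} a single
// scalar no longer determines the size, hence the explicit arguments.
int main() {
  int h = 2, kh = 2, pad = 0;
  int oh = 2 * h - 1 + 2 * (kh - 1 - pad) - kh + 1;
  assert(oh == 4);
  return 0;
}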
@@ -52,11 +51,13 @@ TEST_F(TestOpenCL_Conv2dTranspose, test0) {
   int n = 1;
   int h = 2;
   int w = 2;
+  int oh = 4;
+  int ow = 4;
   int ci = 2;
   int co = 1;
   int kh = 2;
   int kw = 2;
-  int pad = 0;
+  std::vector<int> pad = {0, 0, 0, 0};
   float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7};
   float weight_data[] = {1, 2, 3, 4, 5, 6, 7, 8};
   float bias_data[] = {0.5};
@@ -65,7 +66,36 @@ TEST_F(TestOpenCL_Conv2dTranspose, test0) {
   for (auto fp16_enable : {false, true}) {
     std::vector<int> input_shape, weight_shape, bias_shape, output_shape;
     auto *param =
-      CreateParameter(n, h, w, ci, co, kh, kw, pad, &input_shape, &weight_shape, &bias_shape, &output_shape);
+      CreateParameter(n, h, w, ci, co, kh, kw, pad, oh, ow, &input_shape, &weight_shape, &bias_shape, &output_shape);
+    TestMain({{input_shape, input_data, VAR},
+              {weight_shape, weight_data, CONST_TENSOR},
+              {bias_shape, bias_data, CONST_TENSOR}},
+             {output_shape, output_data}, param, fp16_enable);
+  }
+}
+
+TEST_F(TestOpenCL_Conv2dTranspose, test1) {
+  int n = 1;
+  int h = 3;
+  int w = 3;
+  int oh = 6;
+  int ow = 6;
+  int ci = 2;
+  int co = 1;
+  int kh = 2;
+  int kw = 2;
+  std::vector<int> pad = {0, 1, 0, 1};
+  float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
+  float weight_data[] = {0, 2, 4, 6, 1, 3, 5, 7};
+  float bias_data[] = {0.5};
+  float output_data[] = {1.5, 3.5, 3.5, 13.5, 5.5, 23.5, 5.5, 7.5, 23.5, 33.5, 41.5, 59.5,
+                         7.5, 33.5, 9.5, 43.5, 11.5, 53.5, 59.5, 85.5, 77.5, 111.5, 95.5, 137.5,
+                         13.5, 63.5, 15.5, 73.5, 17.5, 83.5, 113.5, 163.5, 131.5, 189.5, 149.5, 215.5};
+
+  for (auto fp16_enable : {false, true}) {
+    std::vector<int> input_shape, weight_shape, bias_shape, output_shape;
+    auto *param =
+      CreateParameter(n, h, w, ci, co, kh, kw, pad, oh, ow, &input_shape, &weight_shape, &bias_shape, &output_shape);
     TestMain({{input_shape, input_data, VAR},
               {weight_shape, weight_data, CONST_TENSOR},
              {bias_shape, bias_data, CONST_TENSOR}},