|
|
|
@@ -147,21 +147,18 @@ void ConcatGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> * |
|
|
|
local->push_back(z); |
|
|
|
} |
|
|
|
int ConcatOpenCLKernel::Run() { |
|
|
|
MS_LOG(DEBUG) << this->Name() << " Running!"; |
|
|
|
auto param = reinterpret_cast<ConcatParameter *>(this->opParameter); |
|
|
|
if (param->axis_ == 0) { |
|
|
|
return Run_axis0(); |
|
|
|
} |
|
|
|
|
|
|
|
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); |
|
|
|
MS_LOG(INFO) << " judge the numbers of input vector"; |
|
|
|
auto input0_shape = inputs_[0]->shape(); |
|
|
|
auto input1_shape = inputs_[1]->shape(); |
|
|
|
auto input2_shape = inputs_[2]->shape(); |
|
|
|
auto output_shape = outputs_[0]->shape(); |
|
|
|
|
|
|
|
cl_int2 input0_shape2_ = {DivideRoundUp(input0_shape[3], 4), DivideRoundUp(input1_shape[3], 4)}; // change |
|
|
|
cl_int3 input0_shape3_ = {DivideRoundUp(input0_shape[3], 4), DivideRoundUp(input1_shape[3], 4), |
|
|
|
DivideRoundUp(input2_shape[3], 4)}; |
|
|
|
cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], DivideRoundUp(output_shape[3], 4)}; |
|
|
|
|
|
|
|
uint32_t OH = output_shape[0] * output_shape[1]; // N*H |
|
|
|
@@ -173,14 +170,15 @@ int ConcatOpenCLKernel::Run() { |
|
|
|
|
|
|
|
int arg_cn = 0; |
|
|
|
if (inputs_.size() == 2) { |
|
|
|
MS_LOG(INFO) << " SetKernelArg"; |
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data()); |
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data()); |
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[1]->Data()); |
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape2_); |
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_); |
|
|
|
} else if (inputs_.size() == 3) { |
|
|
|
MS_LOG(INFO) << " SetKernelArg"; |
|
|
|
auto input2_shape = inputs_[2]->shape(); |
|
|
|
cl_int3 input0_shape3_ = {DivideRoundUp(input0_shape[3], 4), DivideRoundUp(input1_shape[3], 4), |
|
|
|
DivideRoundUp(input2_shape[3], 4)}; |
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data()); |
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data()); |
|
|
|
ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[1]->Data()); |
|
|
|
|