| @@ -10,15 +10,7 @@ __kernel void Concat2input_NHWC4(__read_only image2d_t input0, __read_only image | |||||
| if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { | if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { | ||||
| return; | return; | ||||
| } | } | ||||
| if (axis == 0) { | |||||
| if (X < input_shape0.x * input_shape0.y) { | |||||
| FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); | |||||
| WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result); | |||||
| } else { | |||||
| FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.x * input_shape0.y))); | |||||
| WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result); | |||||
| } | |||||
| } else if (axis == 1) { | |||||
| if (axis == 1) { | |||||
| if (X < input_shape0.y) { | if (X < input_shape0.y) { | ||||
| FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); | FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); | ||||
| WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result); | WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result); | ||||
| @@ -54,21 +46,7 @@ __kernel void Concat3input_NHWC4(__read_only image2d_t input0, __read_only image | |||||
| if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { | if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { | ||||
| return; | return; | ||||
| } | } | ||||
| if (axis == 0) { | |||||
| if (X < input_shape0.x * input_shape0.y) { | |||||
| FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); | |||||
| WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0); | |||||
| } else if (X < (input_shape0.x * input_shape0.y + input_shape1.x * input_shape1.y)) { | |||||
| FLT4 result1 = | |||||
| READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.x * input_shape0.y))); | |||||
| WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1); | |||||
| } else { | |||||
| FLT4 result2 = READ_IMAGE( | |||||
| input2, smp_none, | |||||
| (int2)((Y)*input_shape2.w + Z, (X - input_shape0.x * input_shape0.y - input_shape1.x * input_shape1.y))); | |||||
| WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2); | |||||
| } | |||||
| } else if (axis == 1) { | |||||
| if (axis == 1) { | |||||
| if (X < input_shape0.y) { | if (X < input_shape0.y) { | ||||
| FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); | FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X))); | ||||
| WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0); | WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0); | ||||
| @@ -121,18 +99,7 @@ __kernel void Concat2input_NC4HW4(__read_only image2d_t input0, __read_only imag | |||||
| } | } | ||||
| int in_postion_x; | int in_postion_x; | ||||
| int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y; | int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y; | ||||
| if (axis == 0) { | |||||
| if (X < (input_shape0.x * input_shape0.y)) { | |||||
| in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; | |||||
| FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); | |||||
| WRITE_IMAGE(output, (int2)((Y), out_pos_x), result); | |||||
| } else { | |||||
| in_postion_x = ((X - input_shape0.x * input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + | |||||
| Z * input_shape1.y + ((X - input_shape0.x * input_shape0.y) % input_shape1.y); | |||||
| FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); | |||||
| WRITE_IMAGE(output, (int2)((Y), out_pos_x), result); | |||||
| } | |||||
| } else if (axis == 1) { | |||||
| if (axis == 1) { | |||||
| if (X < input_shape0.y) { | if (X < input_shape0.y) { | ||||
| in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; | in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; | ||||
| FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); | FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); | ||||
| @@ -181,25 +148,7 @@ __kernel void Concat3input_NC4HW4(__read_only image2d_t input0, __read_only imag | |||||
| } | } | ||||
| int in_postion_x; | int in_postion_x; | ||||
| int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y; | int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y; | ||||
| if (axis == 0) { | |||||
| if (X < (input_shape0.x * input_shape0.y)) { | |||||
| in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; | |||||
| FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); | |||||
| WRITE_IMAGE(output, (int2)((Y), out_pos_x), result); | |||||
| } else if (X < (input_shape0.x * input_shape0.y + input_shape1.x * input_shape1.y)) { | |||||
| in_postion_x = ((X - input_shape0.x * input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + | |||||
| Z * input_shape1.y + ((X - input_shape0.x * input_shape0.y) % input_shape1.y); | |||||
| FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x)); | |||||
| WRITE_IMAGE(output, (int2)((Y), out_pos_x), result); | |||||
| } else { | |||||
| in_postion_x = ((X - input_shape0.x * input_shape0.y - input_shape1.x * input_shape1.y) / input_shape2.y) * | |||||
| input_shape2.w * input_shape2.y + | |||||
| Z * input_shape2.y + | |||||
| (X - input_shape0.x * input_shape0.y - input_shape1.x * input_shape1.y) % input_shape2.y; | |||||
| FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x)); | |||||
| WRITE_IMAGE(output, (int2)((Y), out_pos_x), result); | |||||
| } | |||||
| } else if (axis == 1) { | |||||
| if (axis == 1) { | |||||
| if (X < input_shape0.y) { | if (X < input_shape0.y) { | ||||
| in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; | in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y; | ||||
| FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); | FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x)); | ||||
| @@ -59,10 +59,10 @@ __kernel void to_format_NHWC_to_NC4HW4_IMG_float(__global float4 *src_data, __wr | |||||
| int X = get_global_id(0); | int X = get_global_id(0); | ||||
| int Y = get_global_id(1); | int Y = get_global_id(1); | ||||
| int Z = get_global_id(2); | int Z = get_global_id(2); | ||||
| if (X >= size.x || Y >= size.y || Z >= size.z) { | |||||
| if (X >= size.x || Y >= size.y || Z >= size.z || shape.y == 0) { | |||||
| return; | return; | ||||
| } | } | ||||
| int offset = (X * shape.z + Y) * shape.w + Z * 4; | |||||
| int offset = (X / shape.y) * shape.y * shape.z * shape.w + ((X % shape.y) * shape.z + Y) * shape.w + Z * 4; | |||||
| __global float *src_addr = (__global float *)src_data; | __global float *src_addr = (__global float *)src_data; | ||||
| src_addr += offset; | src_addr += offset; | ||||
| FLT4 data = (FLT4)(0.f); | FLT4 data = (FLT4)(0.f); | ||||
| @@ -79,17 +79,18 @@ __kernel void to_format_NHWC_to_NC4HW4_IMG_float(__global float4 *src_data, __wr | |||||
| data.z = (FLT)src_addr[2]; | data.z = (FLT)src_addr[2]; | ||||
| } | } | ||||
| } | } | ||||
| WRITE_IMAGE(dst_data, (int2)(Y, Z * size.x + X), data); | |||||
| int pos_ix = (X / shape.y) * size.z * shape.y + Z * shape.y + X % shape.y; | |||||
| WRITE_IMAGE(dst_data, (int2)(Y, pos_ix), data); | |||||
| } | } | ||||
| __kernel void to_format_NHWC_to_NC4HW4_IMG_half(__global half4 *src_data, __write_only image2d_t dst_data, int4 size, | __kernel void to_format_NHWC_to_NC4HW4_IMG_half(__global half4 *src_data, __write_only image2d_t dst_data, int4 size, | ||||
| int4 shape) { | int4 shape) { | ||||
| int X = get_global_id(0); | int X = get_global_id(0); | ||||
| int Y = get_global_id(1); | int Y = get_global_id(1); | ||||
| int Z = get_global_id(2); | int Z = get_global_id(2); | ||||
| if (X >= size.x || Y >= size.y || Z >= size.z) { | |||||
| if (X >= size.x || Y >= size.y || Z >= size.z || shape.y == 0) { | |||||
| return; | return; | ||||
| } | } | ||||
| int offset = (X * shape.z + Y) * shape.w + Z * 4; | |||||
| int offset = (X / shape.y) * shape.y * shape.z * shape.w + ((X % shape.y) * shape.z + Y) * shape.w + Z * 4; | |||||
| __global half *src_addr = (__global half *)src_data; | __global half *src_addr = (__global half *)src_data; | ||||
| src_addr += offset; | src_addr += offset; | ||||
| FLT4 data = (FLT4)(0.f); | FLT4 data = (FLT4)(0.f); | ||||
| @@ -106,7 +107,8 @@ __kernel void to_format_NHWC_to_NC4HW4_IMG_half(__global half4 *src_data, __writ | |||||
| data.z = (FLT)src_addr[2]; | data.z = (FLT)src_addr[2]; | ||||
| } | } | ||||
| } | } | ||||
| WRITE_IMAGE(dst_data, (int2)(Y, Z * size.x + X), data); | |||||
| int pos_ix = (X / shape.y) * size.z * shape.y + Z * shape.y + X % shape.y; | |||||
| WRITE_IMAGE(dst_data, (int2)(Y, pos_ix), data); | |||||
| } | } | ||||
| __kernel void to_format_NHWC4_to_NHWC4_IMG_float(__global float4 *src_data, __write_only image2d_t dst_data, int4 size, | __kernel void to_format_NHWC4_to_NHWC4_IMG_float(__global float4 *src_data, __write_only image2d_t dst_data, int4 size, | ||||
| int4 shape) { | int4 shape) { | ||||
| @@ -227,11 +229,12 @@ __kernel void to_format_NC4HW4_to_NHWC_BUF_float(__read_only image2d_t src_data, | |||||
| int X = get_global_id(0); | int X = get_global_id(0); | ||||
| int Y = get_global_id(1); | int Y = get_global_id(1); | ||||
| int Z = get_global_id(2); | int Z = get_global_id(2); | ||||
| if (X >= size.x || Y >= size.y || Z >= size.z) { | |||||
| if (X >= size.x || Y >= size.y || Z >= size.z || shape.y == 0) { | |||||
| return; | return; | ||||
| } | } | ||||
| float4 data = convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(Y, Z * size.x + X))); | |||||
| int offset = (X * shape.z + Y) * shape.w + Z * 4; | |||||
| int pos_ix = (X / shape.y) * size.z * shape.y + Z * shape.y + X % shape.y; | |||||
| float4 data = convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(Y, pos_ix))); | |||||
| int offset = (X / shape.y) * shape.y * shape.z * shape.w + ((X % shape.y) * shape.z + Y) * shape.w + Z * 4; | |||||
| __global float *dst_addr = (__global float *)dst_data; | __global float *dst_addr = (__global float *)dst_data; | ||||
| dst_addr += offset; | dst_addr += offset; | ||||
| if ((Z + 1) * 4 <= shape.w) { | if ((Z + 1) * 4 <= shape.w) { | ||||
| @@ -253,11 +256,12 @@ __kernel void to_format_NC4HW4_to_NHWC_BUF_half(__read_only image2d_t src_data, | |||||
| int X = get_global_id(0); | int X = get_global_id(0); | ||||
| int Y = get_global_id(1); | int Y = get_global_id(1); | ||||
| int Z = get_global_id(2); | int Z = get_global_id(2); | ||||
| if (X >= size.x || Y >= size.y || Z >= size.z) { | |||||
| if (X >= size.x || Y >= size.y || Z >= size.z || shape.y == 0) { | |||||
| return; | return; | ||||
| } | } | ||||
| half4 data = convert_half4(READ_IMAGE(src_data, smp_zero, (int2)(Y, Z * size.x + X))); | |||||
| int offset = (X * shape.z + Y) * shape.w + Z * 4; | |||||
| int pos_ix = (X / shape.y) * size.z * shape.y + Z * shape.y + X % shape.y; | |||||
| half4 data = convert_half4(READ_IMAGE(src_data, smp_zero, (int2)(Y, pos_ix))); | |||||
| int offset = (X / shape.y) * shape.y * shape.z * shape.w + ((X % shape.y) * shape.z + Y) * shape.w + Z * 4; | |||||
| __global half *dst_addr = (__global half *)dst_data; | __global half *dst_addr = (__global half *)dst_data; | ||||
| dst_addr += offset; | dst_addr += offset; | ||||
| if ((Z + 1) * 4 <= shape.w) { | if ((Z + 1) * 4 <= shape.w) { | ||||
| @@ -49,6 +49,26 @@ int ConcatOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) | |||||
| *img_size = vec; | *img_size = vec; | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int ConcatOpenCLKernel::RunAxis0() { | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||||
| auto allocator_ = ocl_runtime->GetAllocator(); | |||||
| std::vector<size_t> img_size; | |||||
| auto dst_data = out_tensors_[0]->MutableData(); | |||||
| auto dst_origin = cl::array<cl::size_type, 3U>{0, 0, 0}; | |||||
| cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(dst_data)); | |||||
| for (int i = 0; i < in_tensors_.size(); i++) { | |||||
| auto src_data = in_tensors_[i]->MutableData(); | |||||
| allocator_->GetImageSize(src_data, &img_size); | |||||
| auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0}; | |||||
| auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1}; | |||||
| cl::Image2D *input_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data)); | |||||
| ocl_runtime->GetDefaultCommandQueue()->enqueueCopyImage(*input_image, *out_image, src_origin, dst_origin, region); | |||||
| dst_origin[1] += region[1]; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int ConcatOpenCLKernel::Init() { | int ConcatOpenCLKernel::Init() { | ||||
| if (in_tensors_[0]->shape().size() != 4) { | if (in_tensors_[0]->shape().size() != 4) { | ||||
| MS_LOG(ERROR) << " only support dim = 4 "; | MS_LOG(ERROR) << " only support dim = 4 "; | ||||
| @@ -98,6 +118,19 @@ int ConcatOpenCLKernel::Init() { | |||||
| int ConcatOpenCLKernel::ReSize() { return RET_OK; } | int ConcatOpenCLKernel::ReSize() { return RET_OK; } | ||||
| int ConcatOpenCLKernel::GetSumShape(std::vector<int> *sum_shape, std::vector<int> *in_shape) { | |||||
| std::vector<int> temp_sum = {0, 0, 0, 0}; | |||||
| for (int i = 0; i < in_tensors_.size(); ++i) { | |||||
| auto temp = in_tensors_[i]->shape(); | |||||
| for (int j = 0; j < temp.size(); ++j) { | |||||
| in_shape->push_back(temp[j]); | |||||
| temp_sum.at(j) += temp[j]; | |||||
| sum_shape->push_back(temp_sum.at(j)); | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int ConcatGetBiggestDividerWithPriority(int number, int max_divider) { | int ConcatGetBiggestDividerWithPriority(int number, int max_divider) { | ||||
| if (number % 8 == 0 && max_divider >= 8) { | if (number % 8 == 0 && max_divider >= 8) { | ||||
| return number / 8; | return number / 8; | ||||
| @@ -133,6 +166,9 @@ void ConcatGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> * | |||||
| int ConcatOpenCLKernel::Run() { | int ConcatOpenCLKernel::Run() { | ||||
| MS_LOG(DEBUG) << this->name() << " Running! "; | MS_LOG(DEBUG) << this->name() << " Running! "; | ||||
| auto param = reinterpret_cast<ConcatParameter *>(this->op_parameter_); | auto param = reinterpret_cast<ConcatParameter *>(this->op_parameter_); | ||||
| if (param->axis_ == 0) { | |||||
| return RunAxis0(); | |||||
| } | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | ||||
| auto input1_shape = in_tensors_[0]->shape(); | auto input1_shape = in_tensors_[0]->shape(); | ||||
| @@ -151,6 +187,7 @@ int ConcatOpenCLKernel::Run() { | |||||
| std::vector<size_t> local = {1, 1, 1}; // init local | std::vector<size_t> local = {1, 1, 1}; // init local | ||||
| std::vector<size_t> global = {OH, OW, OC}; | std::vector<size_t> global = {OH, OW, OC}; | ||||
| ConcatGetWorkGroup(global, &local, max_global[0]); | ConcatGetWorkGroup(global, &local, max_global[0]); | ||||
| GetSumShape(&sum_shape, &in_shape); | |||||
| int arg_cn = 0; | int arg_cn = 0; | ||||
| if (in_tensors_.size() == 2) { | if (in_tensors_.size() == 2) { | ||||
| @@ -38,10 +38,17 @@ class ConcatOpenCLKernel : public OpenCLKernel { | |||||
| int ReSize() override; | int ReSize() override; | ||||
| int Run() override; | int Run() override; | ||||
| int RunAxis0(); | |||||
| int GetImageSize(size_t idx, std::vector<size_t> *img_size) override; | int GetImageSize(size_t idx, std::vector<size_t> *img_size) override; | ||||
| int GetSumShape(std::vector<int> *sum_shape, std::vector<int> *in_shape); | |||||
| private: | private: | ||||
| cl::Kernel kernel_; | cl::Kernel kernel_; | ||||
| std::vector<int> sum_shape; | |||||
| std::vector<int> in_shape; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -62,19 +62,19 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) { | |||||
| constexpr int INPUT_NUM = 2; | constexpr int INPUT_NUM = 2; | ||||
| std::array<std::vector<int>, INPUT_NUM> input_shapes = {std::vector<int>{1, 19, 19, 96}, | std::array<std::vector<int>, INPUT_NUM> input_shapes = {std::vector<int>{1, 19, 19, 96}, | ||||
| std::vector<int>{1, 19, 19, 96}}; | std::vector<int>{1, 19, 19, 96}}; | ||||
| std::vector<int> output_shape = {2, 19, 19, 96}; | |||||
| std::vector<int> output_shape = {1, 19, 19, 192}; | |||||
| auto data_type = kNumberTypeFloat16; | auto data_type = kNumberTypeFloat16; | ||||
| auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode); | auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode); | ||||
| std::vector<lite::Tensor *> inputs; | std::vector<lite::Tensor *> inputs; | ||||
| for (auto &shape : input_shapes) { | for (auto &shape : input_shapes) { | ||||
| auto input_temp = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC4, tensor_type); | |||||
| auto input_temp = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC, tensor_type); | |||||
| inputs.push_back(input_temp); | inputs.push_back(input_temp); | ||||
| if (input_temp == nullptr) { | if (input_temp == nullptr) { | ||||
| MS_LOG(INFO) << " new input_tensor failed "; | MS_LOG(INFO) << " new input_tensor failed "; | ||||
| return; | return; | ||||
| } | } | ||||
| } | } | ||||
| auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, schema::Format_NHWC4, tensor_type); | |||||
| auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type); | |||||
| if (output_tensor == nullptr) { | if (output_tensor == nullptr) { | ||||
| MS_LOG(INFO) << " new output_tensor failed "; | MS_LOG(INFO) << " new output_tensor failed "; | ||||
| for (auto tensor : inputs) { | for (auto tensor : inputs) { | ||||
| @@ -97,7 +97,7 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) { | |||||
| } | } | ||||
| return; | return; | ||||
| } | } | ||||
| param->axis_ = 0; | |||||
| param->axis_ = 3; | |||||
| auto *concat_kernel = | auto *concat_kernel = | ||||
| new (std::nothrow) kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | new (std::nothrow) kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | ||||
| if (concat_kernel == nullptr) { | if (concat_kernel == nullptr) { | ||||
| @@ -111,6 +111,7 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) { | |||||
| delete param; | delete param; | ||||
| return; | return; | ||||
| } | } | ||||
| concat_kernel->SetFormatType(schema::Format_NC4HW4); | |||||
| concat_kernel->Init(); | concat_kernel->Init(); | ||||
| // to do allocate memory for inputs and outputs | // to do allocate memory for inputs and outputs | ||||
| for (auto &input_tensor : inputs) { | for (auto &input_tensor : inputs) { | ||||
| @@ -229,8 +230,9 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) { | |||||
| delete param; | delete param; | ||||
| return; | return; | ||||
| } | } | ||||
| concat_kernel->SetFormatType(schema::Format_NC4HW4); | |||||
| concat_kernel->Init(); | concat_kernel->Init(); | ||||
| // to do allocate memory for inputs and outputs | |||||
| // to do allocate memory for inputs | |||||
| for (auto &input_tensor : inputs) { | for (auto &input_tensor : inputs) { | ||||
| input_tensor->MallocData(allocator); | input_tensor->MallocData(allocator); | ||||
| } | } | ||||