Merge pull request !4139 from pengyongrong/concat_debug_prtags/v0.7.0-beta
| @@ -1,54 +1,44 @@ | |||
| #pragma OPENCL EXTENSION cl_khr_fp16 : enable | |||
| __kernel void Concat(__global float *input0, __global float *input1, __global float *output, const int4 input_shape0, | |||
| const int4 input_shape1, const int4 output_shape, const int axis) { | |||
| uint oh = get_global_id(0); | |||
| uint ow = get_global_id(1); | |||
| uint oc = get_global_id(2); | |||
| uint index_output; | |||
| uint input_idx; | |||
| if ((oh >= output_shape.y || oh < 0) || (ow >= output_shape.z || ow < 0) || (oc >= output_shape.w || oc < 0)) { | |||
| return; | |||
| #define FLT4 float4 | |||
| __constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; | |||
| __kernel void Concat(__write_only image2d_t output_image2d, __read_only image2d_t input0_image2d, | |||
| __read_only image2d_t input1_image2d, int2 shared_int0, int4 shared_out) { | |||
| int X = get_global_id(0); // H | |||
| int Y = get_global_id(1); // W | |||
| int S = 0; | |||
| if (X >= shared_out.y || Y >= shared_out.z) return; | |||
| for (int i = 0; i < shared_int0.x; i++) { | |||
| FLT4 result0 = read_imagef(input0_image2d, smp_none, (int2)((Y)*shared_int0.x + (i), (X))); | |||
| write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result0); | |||
| S++; | |||
| } | |||
| if (axis == 3) { | |||
| index_output = oh * output_shape.z * output_shape.w + ow * output_shape.w + oc; | |||
| if (oc < input_shape0.w) { | |||
| input_idx = (input_shape0.z * oh + ow) * input_shape0.w + oc; | |||
| output[index_output] = input0[input_idx]; | |||
| } else if ((input_shape0.w <= oc) && oc < (input_shape0.w + input_shape1.w)) { | |||
| input_idx = (input_shape1.z * oh + ow) * input_shape1.w + (oc - input_shape0.w); | |||
| output[index_output] = input1[input_idx]; | |||
| } else { | |||
| output[index_output] = 0; | |||
| } | |||
| for (int i = 0; i < shared_int0.y; i++) { | |||
| FLT4 result1 = read_imagef(input1_image2d, smp_none, (int2)((Y)*shared_int0.y + (i), (X))); | |||
| write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result1); | |||
| S++; | |||
| } | |||
| } | |||
| __kernel void Concat3input(__global float *input0, __global float *input1, __global float *input2, | |||
| __global float *output, const int4 input_shape0, const int4 input_shape1, | |||
| const int4 input_shape2, const int4 output_shape, const int axis) { | |||
| uint oh = get_global_id(0); | |||
| uint ow = get_global_id(1); | |||
| uint oc = get_global_id(2); | |||
| uint index_output; | |||
| uint input_idx; | |||
| if ((oh >= output_shape.y || oh < 0) || (ow >= output_shape.z || ow < 0) || (oc >= output_shape.w || oc < 0)) { | |||
| return; | |||
| __kernel void Concat3input(__write_only image2d_t output_image2d, __read_only image2d_t input0_image2d, | |||
| __read_only image2d_t input1_image2d, __read_only image2d_t input2_image2d, int3 shared_int0, | |||
| int4 shared_out) { | |||
| int X = get_global_id(0); // H | |||
| int Y = get_global_id(1); // W | |||
| int S = 0; | |||
| if (X >= shared_out.y || Y >= shared_out.z) return; | |||
| for (int i = 0; i < shared_int0.x; i++) { | |||
| FLT4 result0 = read_imagef(input0_image2d, smp_none, (int2)((Y)*shared_int0.x + (i), (X))); | |||
| write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result0); | |||
| S++; | |||
| } | |||
| index_output = oh * output_shape.z * output_shape.w + ow * output_shape.w + oc; | |||
| if (oc < (input_shape0.w + input_shape1.w)) { | |||
| if (oc < input_shape0.w) { | |||
| input_idx = (input_shape0.z * oh + ow) * input_shape0.w + oc; | |||
| output[index_output] = input0[input_idx]; | |||
| } else { | |||
| input_idx = (input_shape1.z * oh + ow) * input_shape1.w + (oc - input_shape0.w); | |||
| output[index_output] = input1[input_idx]; | |||
| } | |||
| } else { | |||
| if ((input_shape0.w + input_shape1.w + input_shape2.w) <= oc) { | |||
| output[index_output] = 0; | |||
| } else { | |||
| input_idx = (input_shape2.z * oh + ow) * input_shape2.w + (oc - input_shape0.w - input_shape1.w); | |||
| output[index_output] = input2[input_idx]; | |||
| } | |||
| for (int i = 0; i < shared_int0.y; i++) { | |||
| FLT4 result1 = read_imagef(input1_image2d, smp_none, (int2)((Y)*shared_int0.y + (i), (X))); | |||
| write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result1); | |||
| S++; | |||
| } | |||
| for (int i = 0; i < shared_int0.z; i++) { | |||
| FLT4 result2 = read_imagef(input2_image2d, smp_none, (int2)((Y)*shared_int0.z + (i), (X))); | |||
| write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result2); | |||
| S++; | |||
| } | |||
| } | |||
| @@ -13,6 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <cstring> | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <set> | |||
| @@ -27,6 +28,26 @@ using mindspore::schema::PrimitiveType_Concat; | |||
| namespace mindspore::kernel { | |||
| int ConcatOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||
| size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); | |||
| size_t im_dst_x, im_dst_y; | |||
| if (inputs_[0]->GetFormat() == schema::Format_NHWC4) { | |||
| im_dst_x = outputs_[0]->Width() * CO4; | |||
| im_dst_y = outputs_[0]->Height(); | |||
| } else { | |||
| im_dst_y = outputs_[0]->Height() * CO4; | |||
| im_dst_x = outputs_[0]->Width(); | |||
| } | |||
| #ifdef ENABLE_FP16 | |||
| size_t img_dtype = CL_HALF_FLOAT; | |||
| #else | |||
| size_t img_dtype = CL_FLOAT; | |||
| #endif | |||
| img_size->clear(); | |||
| std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype}; | |||
| *img_size = vec; | |||
| return 1; | |||
| } | |||
| int ConcatOpenCLKernel::Init() { | |||
| if (inputs_[0]->shape().size() != 4) { | |||
| MS_LOG(ERROR) << "only support dim=4"; | |||
| @@ -132,72 +153,45 @@ int ConcatOpenCLKernel::Run() { | |||
| } | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| std::vector<size_t> local; | |||
| std::vector<size_t> global; | |||
| MS_LOG(INFO) << " judge the numbers of input vector"; | |||
| auto input0_shape = inputs_[0]->shape(); | |||
| auto input1_shape = inputs_[1]->shape(); | |||
| auto input2_shape = inputs_[2]->shape(); | |||
| auto output_shape = outputs_[0]->shape(); | |||
| cl_int2 input0_shape2_ = {DivideRoundUp(input0_shape[3], 4), DivideRoundUp(input1_shape[3], 4)}; // change | |||
| cl_int3 input0_shape3_ = {DivideRoundUp(input0_shape[3], 4), DivideRoundUp(input1_shape[3], 4), | |||
| DivideRoundUp(input2_shape[3], 4)}; | |||
| cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], DivideRoundUp(output_shape[3], 4)}; | |||
| uint32_t OH = output_shape[0] * output_shape[1]; // N*H | |||
| uint32_t OW = output_shape[2]; | |||
| std::vector<size_t> local = {1, 1}; | |||
| std::vector<size_t> global = {OH, OW}; | |||
| // ConcatGetWorkGroup(global, &local, 512); | |||
| int arg_cn = 0; | |||
| if (inputs_.size() == 2) { | |||
| auto input0_shape = inputs_[0]->shape(); | |||
| auto input1_shape = inputs_[1]->shape(); | |||
| auto output_shape = outputs_[0]->shape(); | |||
| cl_int4 input0_shape_ = {input0_shape[0], input0_shape[1], input0_shape[2], input0_shape[3]}; | |||
| cl_int4 input1_shape_ = {input1_shape[0], input1_shape[1], input1_shape[2], input1_shape[3]}; | |||
| cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], output_shape[3]}; | |||
| uint32_t OH = output_shape[0] * output_shape[1]; // N*H | |||
| uint32_t OW = output_shape[2]; | |||
| uint32_t OC = output_shape[3]; | |||
| global = {OH, OW, OC}; // HWC | |||
| ConcatGetWorkGroup(global, &local, 384); | |||
| std::cout << "local size=:" << std::endl; | |||
| for (int i = 0; i < local.size(); i++) { | |||
| std::cout << local[i] << " "; | |||
| } | |||
| std::cout << std::endl; | |||
| int arg_cn = 0; | |||
| MS_LOG(INFO) << " SetKernelArg"; | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[1]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape_); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, input1_shape_); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape2_); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->axis_); | |||
| } | |||
| if (inputs_.size() == 3) { | |||
| auto input0_shape = inputs_[0]->shape(); | |||
| auto input1_shape = inputs_[1]->shape(); | |||
| auto input2_shape = inputs_[2]->shape(); | |||
| auto output_shape = outputs_[0]->shape(); | |||
| cl_int4 input0_shape_ = {input0_shape[0], input0_shape[1], input0_shape[2], input0_shape[3]}; | |||
| cl_int4 input1_shape_ = {input1_shape[0], input1_shape[1], input1_shape[2], input1_shape[3]}; | |||
| cl_int4 input2_shape_ = {input2_shape[0], input2_shape[1], input2_shape[2], input2_shape[3]}; | |||
| cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], output_shape[3]}; | |||
| uint32_t OH = output_shape[0] * output_shape[1]; // N*H | |||
| uint32_t OW = output_shape[2]; | |||
| uint32_t OC = output_shape[3]; | |||
| global = {OH, OW, OC}; // HWC | |||
| ConcatGetWorkGroup(global, &local, 384); | |||
| std::cout << "local size=:" << std::endl; | |||
| for (int i = 0; i < local.size(); i++) { | |||
| std::cout << local[i] << " "; | |||
| } | |||
| std::cout << std::endl; | |||
| int arg_cn = 0; | |||
| } else if (inputs_.size() == 3) { | |||
| MS_LOG(INFO) << " SetKernelArg"; | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[1]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[2]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape_); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, input1_shape_); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, input2_shape_); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape3_); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->axis_); | |||
| } | |||
| ocl_runtime->RunKernel(kernel_, global, local, nullptr); | |||
| return 0; | |||
| } | |||
| } // namespace mindspore::kernel | |||
| kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_Concat_H_ | |||
| #define MINDSPORE_LITE_SRC_BACKEND_OPENCL_Concat_H_ | |||
| #ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_CONCAT_H_ | |||
| #define MINDSPORE_LITE_SRC_BACKEND_OPENCL_CONCAT_H_ | |||
| #include <vector> | |||
| #include "ir/anf.h" | |||
| @@ -25,11 +25,11 @@ | |||
| namespace mindspore::kernel { | |||
| class ConcatOpenCLKernel : public LiteKernel { | |||
| class ConcatOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| explicit ConcatOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||
| : LiteKernel(parameter, inputs, outputs) {} | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| ~ConcatOpenCLKernel() override{}; | |||
| @@ -40,6 +40,7 @@ class ConcatOpenCLKernel : public LiteKernel { | |||
| int Run_axis0(); | |||
| int Run() override; | |||
| int GetImageSize(size_t idx, std::vector<size_t> *img_size) override; | |||
| private: | |||
| cl::Kernel kernel_; | |||
| @@ -69,8 +69,8 @@ void *OpenCLAllocator::Malloc(size_t size) { | |||
| host_ptr = clSVMAlloc((*ocl_runtime->Context())(), flags, size, 0); | |||
| } else { | |||
| cl_int ret = CL_SUCCESS; | |||
| cl::Buffer *buffer = | |||
| new cl::Buffer(*ocl_runtime->Context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, size, NULL, &ret); | |||
| cl::Buffer *buffer = new cl::Buffer(*ocl_runtime->Context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, | |||
| size, NULL, &ret); | |||
| if (ret != CL_SUCCESS) { | |||
| MS_LOG(ERROR) << "Create OpenCL buffer failed! (ERROR CODE: " << ret << ")"; | |||
| UnLock(); | |||
| @@ -125,8 +125,8 @@ void *OpenCLAllocator::Malloc(size_t size, const std::vector<size_t>& img_size) | |||
| cl_int ret = CL_SUCCESS; | |||
| // CL_HALF_FLOAT, CL_FLOAT | |||
| cl::ImageFormat image_format(CL_RGBA, img_size[2]); | |||
| cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_WRITE, | |||
| image_format, img_size[0], img_size[1], 0, nullptr, &ret); | |||
| cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_WRITE, image_format, | |||
| img_size[0], img_size[1], 0, nullptr, &ret); | |||
| if (ret != CL_SUCCESS) { | |||
| MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")"; | |||
| UnLock(); | |||
| @@ -164,20 +164,26 @@ void *OpenCLAllocator::CreateImageFromHost(void *data, size_t size, const std::v | |||
| auto iter = free_list_.lower_bound(size); | |||
| if (iter != free_list_.end() && (iter->second->size_ >= size) && (iter->second->size_ < (size << shift_factor_))) { | |||
| auto mem_buf = iter->second; | |||
| free_list_.erase(iter); | |||
| allocated_list_[mem_buf->host_ptr_] = mem_buf; | |||
| UnLock(); | |||
| MS_LOG(DEBUG) << "Malloc Image2D from free list. size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_ | |||
| << ", device addr: " << mem_buf->device_ptr_; | |||
| return mem_buf->host_ptr_; | |||
| bool is_match{mem_buf->img_size.size() == img_size.size()}; | |||
| for (int i = 0; i < img_size.size() && is_match; ++i) { | |||
| is_match = img_size[i] == mem_buf->img_size[i]; | |||
| } | |||
| if (is_match) { | |||
| free_list_.erase(iter); | |||
| allocated_list_[mem_buf->host_ptr_] = mem_buf; | |||
| UnLock(); | |||
| MS_LOG(DEBUG) << "Malloc Image2D from free list. size: " << mem_buf->size_ | |||
| << ", host addr: " << mem_buf->host_ptr_ << ", device addr: " << mem_buf->device_ptr_; | |||
| return mem_buf->host_ptr_; | |||
| } | |||
| } | |||
| void *host_ptr = nullptr; | |||
| void *device_ptr = nullptr; | |||
| cl_int ret = CL_SUCCESS; | |||
| // CL_HALF_FLOAT, CL_FLOAT | |||
| cl::ImageFormat image_format(CL_RGBA, img_size[2]); | |||
| cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, | |||
| img_size[0], img_size[1], 0, data, &ret); | |||
| cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, | |||
| image_format, img_size[0], img_size[1], 0, data, &ret); | |||
| if (ret != CL_SUCCESS) { | |||
| MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")"; | |||
| UnLock(); | |||
| @@ -372,4 +378,3 @@ int OpenCLAllocator::GetImageSize(void *host_ptr, std::vector<size_t>* img_size) | |||
| } | |||
| } // namespace mindspore::lite::opencl | |||
| @@ -21,7 +21,6 @@ | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h" | |||
| int DivideRoundUp(int n, int div) { | |||
| int q = n / div; | |||
| return n % div == 0 ? q : q + 1; | |||
| @@ -77,15 +76,26 @@ void ConcatComputeByCPU_3input_dim4_axis3(float *input0, float *input1, float *i | |||
| postion = i * output_shape[1] * output_shape[2] * output_shape[3] + j * output_shape[2] * output_shape[3] + | |||
| k * output_shape[3]; | |||
| for (int w = 0; w < output_shape[3]; w++) { | |||
| if (w < input_shape0[3] + input_shape1[3]) { | |||
| output[postion++] = (w < input_shape0[3]) ? input0[index0++] : input1[index1++]; | |||
| if (w < input_shape0[3]) { | |||
| int align = DivideRoundUp(input_shape0[3], 4) * 4; | |||
| index0 = i * input_shape0[1] * input_shape0[2] * align + j * input_shape0[2] * align + k * align + w; | |||
| output[postion++] = input0[index0]; | |||
| } else if (w >= input_shape0[3] && w < (input_shape0[3] + input_shape1[3])) { | |||
| int align = DivideRoundUp(input_shape1[3], 4) * 4; | |||
| index1 = i * input_shape1[1] * input_shape1[2] * align + j * input_shape1[2] * align + k * align + w - | |||
| input_shape0[3]; | |||
| output[postion++] = input1[index1]; | |||
| } else if ((input_shape0[3] + input_shape1[3]) <= w && | |||
| w < (input_shape0[3] + input_shape1[3] + input_shape2[3])) { | |||
| output[postion++] = input2[index2++]; | |||
| int align = DivideRoundUp(input_shape2[3], 4) * 4; | |||
| index2 = i * input_shape2[1] * input_shape2[2] * align + j * input_shape2[2] * align + k * align + w - | |||
| input_shape0[3] - input_shape1[3]; | |||
| output[postion++] = input2[index2]; | |||
| } else { | |||
| for (int ind = input_shape0[3] + input_shape1[3]; ind < output_shape[3]; ind++) { | |||
| for (int ind = input_shape0[3] + input_shape1[3] + input_shape2[3]; ind < output_shape[3]; ind++) { | |||
| output[postion++] = 0; | |||
| } | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| @@ -96,18 +106,31 @@ void ConcatComputeByCPU_3input_dim4_axis3(float *input0, float *input1, float *i | |||
| namespace mindspore { | |||
| class TestConcatOpenCL : public mindspore::Common { | |||
| public: | |||
| TestConcatOpenCL(){} | |||
| TestConcatOpenCL() {} | |||
| }; | |||
| template <typename T> | |||
| void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bound) { | |||
| for (size_t i = 0; i < size; i++) { | |||
| T abs = fabs(output_data[i] - correct_data[i]); | |||
| // printf("i=%d %.3f %.3f\n", i, output_data[i], correct_data[i]); | |||
| ASSERT_LE(abs, err_bound); | |||
| } | |||
| } | |||
| TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) { | |||
| MS_LOG(INFO) << "begin test"; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| ocl_runtime->Init(); | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| MS_LOG(INFO) << "init tensors"; | |||
| constexpr int INPUT_NUM = 3; | |||
| std::array<std::vector<int>, INPUT_NUM> input_shapes = { | |||
| std::vector<int>{1, 240, 240, 16}, std::vector<int>{1, 240, 240, 16}, std::vector<int>{1, 240, 240, 64}}; | |||
| std::vector<int> output_shape = {1, 240, 240, 96}; | |||
| constexpr int INPUT_NUM = 2; | |||
| // std::array<std::vector<int>, INPUT_NUM> input_shapes = { | |||
| // std::vector<int>{1, 120, 120, 16}, std::vector<int>{1, 120, 120, 16},std::vector<int>{1, 120, 120, 96}}; | |||
| std::array<std::vector<int>, INPUT_NUM> input_shapes = {std::vector<int>{1, 32, 512, 48}, | |||
| std::vector<int>{1, 32, 512, 48}}; | |||
| std::vector<int> output_shape = {1, 32, 512, 96}; | |||
| output_shape[3] = DivideRoundUp(output_shape[3], 4) * 4; | |||
| auto data_type = kNumberTypeFloat32; | |||
| auto tensor_type = schema::NodeType_ValueNode; | |||
| @@ -118,32 +141,30 @@ TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) { | |||
| auto *output_tensor = new lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type); | |||
| std::vector<lite::tensor::Tensor *> outputs{output_tensor}; | |||
| std::cout << "input_shapes size=: " << input_shapes.size() << std::endl; | |||
| MS_LOG(INFO) << "initialize tensors"; | |||
| std::cout << "initialize tensors"; | |||
| auto param = new ConcatParameter(); | |||
| param->axis_ = 3; | |||
| auto *concat_kernel = new kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| concat_kernel->Init(); | |||
| MS_LOG(INFO) << "initialize sub_graph"; | |||
| std::vector<kernel::LiteKernel *> kernels{concat_kernel}; | |||
| auto *sub_graph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||
| // to do allocate memory for inputs and outputs | |||
| for (auto &input_tensor : inputs) { | |||
| input_tensor->MallocData(allocator); | |||
| } | |||
| sub_graph->Init(); | |||
| unsigned int seed = 123; | |||
| MS_LOG(INFO) << "initialize input data"; | |||
| srand(time(NULL)); | |||
| for (auto &input_tensor : inputs) { | |||
| auto input_data = reinterpret_cast<float *>(input_tensor->Data()); | |||
| static unsigned int seed = 123; | |||
| for (int i = 0; i < input_tensor->ElementsNum(); ++i) { | |||
| input_data[i] = static_cast<float>(rand_r(&seed) % 10 + 1); | |||
| } | |||
| printf("\n"); | |||
| } | |||
| MS_LOG(INFO) << "==================output data================"; | |||
| sub_graph->Run(); | |||
| auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->Data()); | |||
| printf("\n"); | |||
| // compute the result for CPU | |||
| auto *input_data0 = reinterpret_cast<float *>(inputs[0]->Data()); | |||
| auto *input_data1 = reinterpret_cast<float *>(inputs[1]->Data()); | |||
| std::vector<float> output_data_cpu(output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3]); | |||
| @@ -156,8 +177,10 @@ TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) { | |||
| ConcatComputeByCPU_3input_dim4_axis3(input_data0, input_data1, input_data2, output_data_cpu.data(), input_shapes[0], | |||
| input_shapes[1], input_shapes[2], output_shape, param->axis_); | |||
| } | |||
| printf("\n"); | |||
| CompareOutputData(output_data_gpu, output_data_cpu.data(), output_tensor->ElementsNum(), 0.00001); | |||
| MS_LOG(INFO) << "Testconcat passed"; | |||
| std::cout << "==================output data================" << std::endl; | |||
| sub_graph->Run(); | |||
| auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->Data()); | |||
| CompareOutputData1(output_data_gpu, output_data_cpu.data(), output_tensor->ElementsNum(), 0.00001); | |||
| } | |||
| } // namespace mindspore | |||