Merge pull request !4135 from chenzhongming/new_mastertags/v0.7.0-beta
| @@ -1,5 +1,5 @@ | |||
| __kernel void MaxPooling2d(__global float4 *input, __global float4 *output, const int4 input_shape, | |||
| const int4 output_shape, const int2 stride, const int2 kernel_size, const int2 padding) { | |||
| __kernel void MaxPooling2d_BUF(__global float4 *input, __global float4 *output, const int4 input_shape, | |||
| const int4 output_shape, const int2 stride, const int2 kernel_size, const int2 padding) { | |||
| // axis to dst tensor coordinate | |||
| int X = get_global_id(0); | |||
| int Y = get_global_id(1); | |||
| @@ -31,38 +31,37 @@ __kernel void MaxPooling2d(__global float4 *input, __global float4 *output, cons | |||
| output[(output_shape.y * X + Y) * output_shape.w + Z] = maximum; | |||
| } | |||
| // __constant sampler_t sample_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; | |||
| __constant sampler_t sample_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; | |||
| //__kernel void MaxPooling2dImage2d(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, | |||
| // const int4 output_shape, const int2 stride, const int2 kernel_size, | |||
| // const int2 padding) { | |||
| // // axis to dst tensor coordinate | |||
| // int X = get_global_id(0); | |||
| // int Y = get_global_id(1); | |||
| // int Z = get_global_id(2); | |||
| // | |||
| // // boundary check | |||
| // if (X >= output_shape.x || Y >= output_shape.y || Z >= output_shape.w) { | |||
| // return; | |||
| // } | |||
| // | |||
| // float4 maximum = (float4)(-10000.0f); | |||
| // int xs = X * stride.x + padding.x; | |||
| // int ys = Y * stride.y + padding.y; | |||
| // | |||
| // for (int ky = 0; ky < kernel_size.y; ++ky) { | |||
| // int y_c = ys + ky; | |||
| // if (y_c < 0 || y_c >= input_shape.y) { | |||
| // continue; | |||
| // } | |||
| // for (int kx = 0; kx < kernel_size.x; ++kx) { | |||
| // int x_c = xs + kx; | |||
| // if (x_c < 0 || x_c >= input_shape.x) { | |||
| // continue; | |||
| // } | |||
| // float4 src = read_imagef(input, sample_none, (int2)(x_c, y_c * input_shape.w + Z)); | |||
| // maximum = max(src, maximum); | |||
| // } | |||
| // } | |||
| // write_imagef(output, (int2)(X, Y * output_shape.w + Z), maximum); | |||
| //} | |||
| // Max pooling on image2d memory: each work-item reduces one pooling window and writes a single float4 texel. | |||
| // Window positions outside [0, input_shape.x) x [0, input_shape.y) are skipped. | |||
| __kernel void MaxPooling2d_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, | |||
|                                const int4 output_shape, const int2 stride, const int2 kernel_size, const int2 padding) { | |||
|   // global ids address the destination tensor | |||
|   const int out_x = get_global_id(0); | |||
|   const int out_y = get_global_id(1); | |||
|   const int slice = get_global_id(2); | |||
|   // work-items outside the output extent have nothing to do | |||
|   const bool inside = out_x < output_shape.x && out_y < output_shape.y && slice < output_shape.w; | |||
|   if (!inside) { | |||
|     return; | |||
|   } | |||
|   // running maximum, seeded with -10000.0f in every lane | |||
|   float4 best = (float4)(-10000.0f); | |||
|   // first source coordinate covered by this window | |||
|   const int x_begin = out_x * stride.x + padding.x; | |||
|   const int y_begin = out_y * stride.y + padding.y; | |||
|   for (int kx = 0; kx < kernel_size.x; ++kx) { | |||
|     const int x_c = x_begin + kx; | |||
|     if (x_c >= 0 && x_c < input_shape.x) { | |||
|       for (int ky = 0; ky < kernel_size.y; ++ky) { | |||
|         const int y_c = y_begin + ky; | |||
|         if (y_c >= 0 && y_c < input_shape.y) { | |||
|           // the image row index combines the y coordinate with the slice index | |||
|           const int2 src_pos = (int2)(x_c, y_c * input_shape.w + slice); | |||
|           const float4 src = read_imagef(input, sample_none, src_pos); | |||
|           best = max(src, best); | |||
|         } | |||
|       } | |||
|     } | |||
|   } | |||
|   const int2 dst_pos = (int2)(out_x, out_y * output_shape.w + slice); | |||
|   write_imagef(output, dst_pos, best); | |||
| } | |||
| @@ -13,13 +13,14 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/opencl/kernel/arithmetic.h" | |||
| #include <set> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/opencl/utils.h" | |||
| #include "src/runtime/kernel/opencl/kernel/arithmetic.h" | |||
| #ifndef PROGRAM_WITH_IL | |||
| #include "src/runtime/kernel/opencl/cl/fp32/arithmetic_buffer.cl.inc" | |||
| #include "src/runtime/kernel/opencl/cl/fp32/arithmetic_image2d.cl.inc" | |||
| @@ -41,8 +42,8 @@ std::vector<size_t> ArithmeticOpenCLKernel::InitGlobalSize() const { | |||
| void ArithmeticOpenCLKernel::Image2dGetWorkGroupSize() { | |||
| global_size_ = InitGlobalSize(); | |||
| int max_work_group_size = runtime_->GetKernelMaxWorkGroupSize(kernel_(), (*runtime_->Device())()); | |||
| local_size_ = GetLocalSize(global_size_, max_work_group_size); | |||
| global_size_ = GetGlobalSize(local_size_, global_size_); | |||
| local_size_ = GetCommonLocalSize(global_size_, max_work_group_size); | |||
| global_size_ = GetCommonGlobalSize(local_size_, global_size_); | |||
| } | |||
| void ArithmeticOpenCLKernel::BufferGetWorkGroupSize() { | |||
| @@ -31,12 +31,11 @@ | |||
| #endif | |||
| using mindspore::schema::PrimitiveType_DepthwiseConv2D; | |||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_DepthwiseConv2D; | |||
| namespace mindspore::kernel { | |||
| @@ -117,11 +116,9 @@ int DepthwiseConv2dOpenCLKernel::InitBuffer() { | |||
| return RET_OK; | |||
| } | |||
| int DepthwiseConv2dOpenCLKernel::ReSize() { | |||
| return RET_OK; | |||
| } | |||
| int DepthwiseConv2dOpenCLKernel::ReSize() { return RET_OK; } | |||
| int DepthwiseConv2dOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t>* img_size) { | |||
| int DepthwiseConv2dOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||
| size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); | |||
| size_t im_dst_x, im_dst_y; | |||
| if (inputs_[0]->GetFormat() == schema::Format_NHWC4) { | |||
| @@ -141,16 +138,18 @@ int DepthwiseConv2dOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t>* i | |||
| *img_size = vec; | |||
| return RET_OK; | |||
| } | |||
| int DepthwiseConv2dOpenCLKernel::GetGlobalSize(size_t idx, std::vector<size_t>* global_size) { | |||
| int DepthwiseConv2dOpenCLKernel::GetGlobalSize(size_t idx, std::vector<size_t> *global_size) { | |||
| size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); | |||
| std::vector <size_t> global = {(size_t) outputs_[0]->Width(), (size_t) outputs_[0]->Height(), CO4}; | |||
| std::vector<size_t> global = {(size_t)outputs_[0]->Width(), (size_t)outputs_[0]->Height(), CO4}; | |||
| *global_size = std::move(global); | |||
| return RET_OK; | |||
| } | |||
| int DepthwiseConv2dOpenCLKernel::GetLocalSize(size_t idx, const std::vector<size_t>& global_size, | |||
| std::vector<size_t>* local_size) { | |||
| int DepthwiseConv2dOpenCLKernel::GetLocalSize(size_t idx, const std::vector<size_t> &global_size, | |||
| std::vector<size_t> *local_size) { | |||
| size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); | |||
| std::vector <size_t> local = {1, 1, CO4}; | |||
| std::vector<size_t> local = {1, 1, CO4}; | |||
| *local_size = std::move(local); | |||
| return RET_OK; | |||
| } | |||
| @@ -161,8 +160,8 @@ int DepthwiseConv2dOpenCLKernel::Run() { | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); | |||
| size_t CI4 = UP_DIV(inputs_[0]->Channel(), C4NUM); | |||
| std::vector <size_t> global = {(size_t) outputs_[0]->Width(), (size_t) outputs_[0]->Height(), CO4}; | |||
| std::vector <size_t> local; | |||
| std::vector<size_t> global = {(size_t)outputs_[0]->Width(), (size_t)outputs_[0]->Height(), CO4}; | |||
| std::vector<size_t> local; | |||
| GetLocalSize(0, global, &local); | |||
| float relu_clip1 = 6.0; | |||
| @@ -28,11 +28,10 @@ namespace mindspore::kernel { | |||
| class DepthwiseConv2dOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| explicit DepthwiseConv2dOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs), | |||
| packed_weight_(nullptr), bias_data_(nullptr), kernel_(nullptr) {} | |||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs), packed_weight_(nullptr), bias_data_(nullptr), kernel_(nullptr) {} | |||
| ~DepthwiseConv2dOpenCLKernel() override {}; | |||
| ~DepthwiseConv2dOpenCLKernel() override{}; | |||
| int Init() override; | |||
| @@ -42,20 +41,16 @@ class DepthwiseConv2dOpenCLKernel : public OpenCLKernel { | |||
| int InitBuffer(); | |||
| int GetImageSize(size_t idx, std::vector<size_t>* img_size) override; | |||
| int GetGlobalSize(size_t idx, std::vector<size_t>* global_size) override; | |||
| int GetLocalSize(size_t idx, const std::vector<size_t>& global_size, | |||
| std::vector<size_t>* local_size) override; | |||
| int GetImageSize(size_t idx, std::vector<size_t> *img_size) override; | |||
| int GetGlobalSize(size_t idx, std::vector<size_t> *global_size) override; | |||
| int GetLocalSize(size_t idx, const std::vector<size_t> &global_size, std::vector<size_t> *local_size) override; | |||
| private: | |||
| FLOAT_t *packed_weight_; | |||
| FLOAT_t *bias_data_; | |||
| cl::Kernel kernel_; | |||
| enum class MEM_TYPE { | |||
| BUF, IMG | |||
| } mem_type_{MEM_TYPE::IMG}; | |||
| enum class MEM_TYPE { BUF, IMG } mem_type_{MEM_TYPE::IMG}; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_DEPTHWISE_H_ | |||
| @@ -64,12 +64,18 @@ int PoolingOpenCLKernel::Init() { | |||
| #ifdef PROGRAM_WITH_IL | |||
| ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); | |||
| #else | |||
| if (mem_type_ == MEM_TYPE::BUF) { | |||
| kernel_name += "_BUF"; | |||
| } else { | |||
| kernel_name += "_IMG"; | |||
| } | |||
| std::set<std::string> build_options; | |||
| ocl_runtime->LoadSource(program_name, source); | |||
| ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||
| #endif | |||
| outputs_[0]->SetFormat(schema::Format_NHWC4); | |||
| MS_LOG(DEBUG) << kernel_name << " Init Done!"; | |||
| return RET_OK; | |||
| } | |||
| @@ -81,8 +87,30 @@ std::vector<size_t> PoolingOpenCLKernel::InitGlobalSize() const { | |||
| return global; | |||
| } | |||
| int PoolingOpenCLKernel::InitBuffer() { return 0; } | |||
| int PoolingOpenCLKernel::ReSize() { return 0; } | |||
| // Reports the image extent for the output tensor as {x, y, channel data type} through `img_size`. | |||
| // `idx` is not consulted. Always returns RET_OK. | |||
| int PoolingOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||
|   // number of C4NUM-wide channel groups needed to cover the output channels | |||
|   const size_t channel_slices = UP_DIV(outputs_[0]->Channel(), C4NUM); | |||
|   const bool is_nhwc4 = inputs_[0]->GetFormat() == schema::Format_NHWC4; | |||
|   size_t image_x = outputs_[0]->Height(); | |||
|   size_t image_y = outputs_[0]->Width(); | |||
|   // the channel groups are folded into y for NHWC4 input and into x otherwise | |||
|   if (is_nhwc4) { | |||
|     image_y *= channel_slices; | |||
|   } else { | |||
|     image_x *= channel_slices; | |||
|   } | |||
| #ifdef ENABLE_FP16 | |||
|   const size_t pixel_type = CL_HALF_FLOAT; | |||
| #else | |||
|   const size_t pixel_type = CL_FLOAT; | |||
| #endif | |||
|   *img_size = {image_x, image_y, pixel_type}; | |||
|   return RET_OK; | |||
| } | |||
| int PoolingOpenCLKernel::InitBuffer() { return RET_OK; } | |||
| int PoolingOpenCLKernel::ReSize() { return RET_OK; } | |||
| int PoolingOpenCLKernel::Run() { | |||
| MS_LOG(DEBUG) << this->Name() << " Running!"; | |||
| @@ -110,12 +138,11 @@ int PoolingOpenCLKernel::Run() { | |||
| std::vector<size_t> local_size; | |||
| std::vector<size_t> global_size = InitGlobalSize(); | |||
| int max_work_group_size = ocl_runtime->GetKernelMaxWorkGroupSize(kernel_(), (*ocl_runtime->Device())()); | |||
| local_size = GetLocalSize(global_size, max_work_group_size); | |||
| global_size = GetGlobalSize(local_size, global_size); | |||
| local_size = GetCommonLocalSize(global_size, max_work_group_size); | |||
| global_size = GetCommonGlobalSize(local_size, global_size); | |||
| // run opencl kernel | |||
| ocl_runtime->RunKernel(kernel_, global_size, local_size, nullptr); | |||
| return RET_OK; | |||
| } | |||
| @@ -25,11 +25,11 @@ | |||
| namespace mindspore::kernel { | |||
| class PoolingOpenCLKernel : public LiteKernel { | |||
| class PoolingOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| explicit PoolingOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||
| : LiteKernel(parameter, inputs, outputs) { | |||
| : OpenCLKernel(parameter, inputs, outputs) { | |||
| parameter_ = reinterpret_cast<PoolingParameter *>(parameter); | |||
| } | |||
| ~PoolingOpenCLKernel() override{}; | |||
| @@ -38,10 +38,11 @@ class PoolingOpenCLKernel : public LiteKernel { | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int InitBuffer(); | |||
| int GetImageSize(size_t idx, std::vector<size_t> *img_size) override; | |||
| private: | |||
| std::vector<size_t> InitGlobalSize() const; | |||
| enum class MEM_TYPE { BUF, IMG } mem_type_{MEM_TYPE::IMG}; | |||
| PoolingParameter *parameter_; | |||
| cl::Kernel kernel_; | |||
| }; | |||
| @@ -49,4 +50,3 @@ class PoolingOpenCLKernel : public LiteKernel { | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_POOLING_H_ | |||
| @@ -22,7 +22,7 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| std::vector<size_t> GetGlobalSize(const std::vector<size_t> &local, const std::vector<size_t> &global) { | |||
| std::vector<size_t> GetCommonGlobalSize(const std::vector<size_t> &local, const std::vector<size_t> &global) { | |||
| std::vector<size_t> result(3, 1); | |||
| for (int i = 0; i < 3; ++i) { | |||
| result[i] = AlignByN(global[i], local[i]); | |||
| @@ -30,7 +30,7 @@ std::vector<size_t> GetGlobalSize(const std::vector<size_t> &local, const std::v | |||
| return result; | |||
| } | |||
| std::vector<size_t> GetLocalSize(const std::vector<size_t> &global, int max_size) { | |||
| std::vector<size_t> GetCommonLocalSize(const std::vector<size_t> &global, int max_size) { | |||
| size_t wg_z = GetBiggestDividerWithPriority(global[2], 8); | |||
| size_t wg_xy_size = max_size / wg_z; | |||
| size_t wg_x = std::min(DivideRoundUp(global[0], 2), wg_xy_size); | |||
| @@ -75,10 +75,10 @@ T AlignByN(T number, N n) { | |||
| } | |||
| // GetGlobalSize | |||
| std::vector<size_t> GetGlobalSize(const std::vector<size_t> &local, const std::vector<size_t> &global); | |||
| std::vector<size_t> GetCommonGlobalSize(const std::vector<size_t> &local, const std::vector<size_t> &global); | |||
| // GetLocalSize | |||
| std::vector<size_t> GetLocalSize(const std::vector<size_t> &global, int max_size); | |||
| std::vector<size_t> GetCommonLocalSize(const std::vector<size_t> &global, int max_size); | |||
| std::string CLErrorCode(cl_int error_code); | |||
| @@ -43,33 +43,42 @@ TEST_F(TestMaxPoolingOpenCL, MaxPool_1_32_512_96) { | |||
| MS_LOG(INFO) << "ocl runtime"; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| ocl_runtime->Init(); | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| MS_LOG(INFO) << "PoolingParameter"; | |||
| auto param = new PoolingParameter; | |||
| InitParameter(param); | |||
| // define tensor | |||
| MS_LOG(INFO) << "define tensor"; | |||
| MS_LOG(INFO) << "define tensor1"; | |||
| std::vector<int> input_shape = {1, 16, 256, 192}; | |||
| std::vector<int> output_shape = {1, 8, 128, 192}; | |||
| auto data_type = kNumberTypeFloat32; | |||
| auto tensorType = schema::NodeType_ValueNode; | |||
| MS_LOG(INFO) << "define tensor2"; | |||
| auto input_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensorType); | |||
| auto output_tensor = new lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensorType); | |||
| MS_LOG(INFO) << "define input"; | |||
| std::vector<lite::tensor::Tensor *> inputs{input_tensor}; | |||
| std::vector<lite::tensor::Tensor *> outputs{output_tensor}; | |||
| // run | |||
| MS_LOG(INFO) << "pooling_kernel"; | |||
| auto *pooling_kernel = new kernel::PoolingOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| MS_LOG(INFO) << "pooling_kernel init"; | |||
| pooling_kernel->Init(); | |||
| std::vector<kernel::LiteKernel *> kernels{pooling_kernel}; | |||
| inputs[0]->MallocData(allocator); | |||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||
| MS_LOG(INFO) << "pGraph init"; | |||
| pGraph->Init(); | |||
| // load data | |||
| MS_LOG(INFO) << "load data"; | |||
| MS_LOG(INFO) << "load data1"; | |||
| std::string input_file = "maxpool_in.bin"; | |||
| std::string expect_file = "maxpool_out.bin"; | |||
| MS_LOG(INFO) << "load data2"; | |||
| LoadTestData(input_tensor->Data(), input_tensor->Size(), input_file); | |||
| auto *input_data = reinterpret_cast<float *>(input_tensor->Data()); | |||
| printf("input[0:10]:"); | |||
| @@ -81,6 +90,7 @@ TEST_F(TestMaxPoolingOpenCL, MaxPool_1_32_512_96) { | |||
| pGraph->Run(); | |||
| MS_LOG(INFO) << "compare result"; | |||
| std::cout << "compare result" << std::endl; | |||
| CompareOutput(output_tensor, expect_file); | |||
| } | |||
| @@ -24,9 +24,14 @@ namespace mindspore { | |||
| // Fills `dst` with test input data. | |||
| //   dst       - destination buffer; the caller provides at least `dst_size` writable bytes | |||
| //   dst_size  - capacity of `dst` in bytes; never more than this many bytes are written | |||
| //   file_path - file to load; when empty the buffer is zero-filled instead | |||
| // A read failure is logged and leaves `dst` untouched. | |||
| void LoadTestData(void *dst, size_t dst_size, const std::string &file_path) { | |||
|   if (file_path.empty()) { | |||
|     memset(dst, 0x00, dst_size); | |||
|     return; | |||
|   } | |||
|   // ReadFile takes a size_t out-parameter (presumably the number of bytes read -- confirm). | |||
|   // Use a separate variable so the capacity of `dst` is not overwritten. | |||
|   size_t file_size = 0; | |||
|   auto src_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &file_size)); | |||
|   if (src_data == nullptr) { | |||
|     MS_LOG(ERROR) << "read file empty."; | |||
|     return; | |||
|   } | |||
|   // Clamp the copy so a file larger than the destination cannot overrun it. | |||
|   const size_t copy_size = file_size < dst_size ? file_size : dst_size; | |||
|   memcpy(dst, src_data, copy_size); | |||
|   // NOTE(review): ownership of the buffer returned by ReadFile is not visible here -- confirm whether it must be released. | |||
| } | |||