Merge pull request !4086 from wandongdong/mastertags/v0.7.0-beta
| @@ -20,6 +20,7 @@ | |||
| #include <vector> | |||
| #include "src/runtime/kernel/arm/fp32/arithmetic.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| namespace mindspore::kernel { | |||
| @@ -19,7 +19,7 @@ | |||
| #include <vector> | |||
| #include "ir/anf.h" | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "src/runtime/kernel/arm/base/concat_base.h" | |||
| @@ -19,7 +19,7 @@ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "src/runtime/kernel/arm/opclib/conv_parameter.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| @@ -19,7 +19,7 @@ | |||
| #include <vector> | |||
| #include "src/ir/tensor.h" | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "src/runtime/kernel/arm/opclib/conv_parameter.h" | |||
| @@ -17,10 +17,12 @@ | |||
| #include "src/runtime/kernel/opencl/kernel/depthwise_conv2d.h" | |||
| #include <string> | |||
| #include <set> | |||
| #include <utility> | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "src/runtime/kernel/arm/fp32/convolution_depthwise.h" | |||
| #include "src/runtime/kernel/arm/opclib/pack.h" | |||
| #include "include/errorcode.h" | |||
| #ifndef PROGRAM_WITH_IL | |||
| @@ -29,9 +31,12 @@ | |||
| #endif | |||
| using mindspore::schema::PrimitiveType_DepthwiseConv2D; | |||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::schema::PrimitiveType_DepthwiseConv2D; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| namespace mindspore::kernel { | |||
| @@ -72,8 +77,8 @@ int DepthwiseConv2dOpenCLKernel::Init() { | |||
| ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||
| #endif | |||
| this->InitBuffer(); | |||
| MS_LOG(DEBUG) << kernel_name << " Init Done!"; | |||
| return 0; | |||
| MS_LOG(DEBUG) << kernel_name << " Init Done! mem type=" << static_cast<int>(mem_type_); | |||
| return RET_OK; | |||
| } | |||
| int DepthwiseConv2dOpenCLKernel::InitBuffer() { | |||
| @@ -109,10 +114,46 @@ int DepthwiseConv2dOpenCLKernel::InitBuffer() { | |||
| } else { | |||
| MS_ASSERT(inputs_.size() == kInputSize1); | |||
| } | |||
| return 0; | |||
| return RET_OK; | |||
| } | |||
| int DepthwiseConv2dOpenCLKernel::ReSize() { return 0; } | |||
| int DepthwiseConv2dOpenCLKernel::ReSize() { | |||
| return RET_OK; | |||
| } | |||
| int DepthwiseConv2dOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t>* img_size) { | |||
| size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); | |||
| size_t im_dst_x, im_dst_y; | |||
| if (inputs_[0]->GetFormat() == schema::Format_NHWC4) { | |||
| im_dst_x = outputs_[0]->Width() * CO4; | |||
| im_dst_y = outputs_[0]->Height(); | |||
| } else { | |||
| im_dst_y = outputs_[0]->Height() * CO4; | |||
| im_dst_x = outputs_[0]->Width(); | |||
| } | |||
| #ifdef ENABLE_FP16 | |||
| size_t img_dtype = CL_HALF_FLOAT; | |||
| #else | |||
| size_t img_dtype = CL_FLOAT; | |||
| #endif | |||
| img_size->clear(); | |||
| std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype}; | |||
| *img_size = vec; | |||
| return RET_OK; | |||
| } | |||
| int DepthwiseConv2dOpenCLKernel::GetGlobalSize(size_t idx, std::vector<size_t>* global_size) { | |||
| size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); | |||
| std::vector <size_t> global = {(size_t) outputs_[0]->Width(), (size_t) outputs_[0]->Height(), CO4}; | |||
| *global_size = std::move(global); | |||
| return RET_OK; | |||
| } | |||
| int DepthwiseConv2dOpenCLKernel::GetLocalSize(size_t idx, const std::vector<size_t>& global_size, | |||
| std::vector<size_t>* local_size) { | |||
| size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); | |||
| std::vector <size_t> local = {1, 1, CO4}; | |||
| *local_size = std::move(local); | |||
| return RET_OK; | |||
| } | |||
| int DepthwiseConv2dOpenCLKernel::Run() { | |||
| MS_LOG(DEBUG) << this->Name() << " Running!"; | |||
| @@ -120,8 +161,9 @@ int DepthwiseConv2dOpenCLKernel::Run() { | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); | |||
| size_t CI4 = UP_DIV(inputs_[0]->Channel(), C4NUM); | |||
| std::vector<size_t> global = {(size_t)outputs_[0]->Width(), (size_t)outputs_[0]->Height(), CO4}; | |||
| std::vector<size_t> local = {1, 1, CO4}; | |||
| std::vector <size_t> global = {(size_t) outputs_[0]->Width(), (size_t) outputs_[0]->Height(), CO4}; | |||
| std::vector <size_t> local; | |||
| GetLocalSize(0, global, &local); | |||
| float relu_clip1 = 6.0; | |||
| cl_int2 kernel_size = {parameter->kernel_h_, parameter->kernel_w_}; | |||
| @@ -141,53 +183,10 @@ int DepthwiseConv2dOpenCLKernel::Run() { | |||
| ocl_runtime->SetKernelArg(kernel_, 8, dilation); | |||
| ocl_runtime->SetKernelArg(kernel_, 9, src_size); | |||
| ocl_runtime->SetKernelArg(kernel_, 10, dst_size); | |||
| if (mem_type_ == MEM_TYPE::BUF) { | |||
| ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, 4, outputs_[0]->Data()); | |||
| ocl_runtime->RunKernel(kernel_, global, local, nullptr); | |||
| } else { | |||
| cl::ImageFormat image_format; | |||
| { | |||
| image_format.image_channel_order = CL_RGBA; | |||
| image_format.image_channel_data_type = CL_FLOAT; | |||
| } | |||
| cl_int in_error_code; | |||
| size_t im_src_x, im_src_y; | |||
| size_t im_dst_x, im_dst_y; | |||
| if (inputs_[0]->GetFormat() == schema::Format_NHWC4) { | |||
| im_src_x = inputs_[0]->Width() * CI4; | |||
| im_src_y = inputs_[0]->Height(); | |||
| im_dst_x = outputs_[0]->Width() * CO4; | |||
| im_dst_y = outputs_[0]->Height(); | |||
| } else { | |||
| im_src_y = inputs_[0]->Height() * CI4; | |||
| im_src_x = inputs_[0]->Width(); | |||
| im_dst_y = outputs_[0]->Height() * CO4; | |||
| im_dst_x = outputs_[0]->Width(); | |||
| } | |||
| cl::Image2D in_mem(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, im_src_x, | |||
| im_src_y, 0, inputs_[0]->Data(), &in_error_code); | |||
| cl_int out_error_code; | |||
| cl::Image2D out_mem(*ocl_runtime->Context(), CL_MEM_WRITE_ONLY, image_format, im_dst_x, im_dst_y, 0, nullptr, | |||
| &out_error_code); | |||
| if (in_error_code != CL_SUCCESS) { | |||
| MS_LOG(DEBUG) << "in Image2D Failed, error=" << in_error_code; | |||
| return 1; | |||
| } | |||
| if (out_error_code != CL_SUCCESS) { | |||
| MS_LOG(DEBUG) << "out Image2D Failed, error= " << out_error_code; | |||
| return 1; | |||
| } | |||
| auto origin = cl::array<cl::size_type, 3U>{0, 0, 0}; | |||
| auto region = cl::array<cl::size_type, 3U>{im_dst_x, im_dst_y, 1}; | |||
| ocl_runtime->SetKernelArg(kernel_, 0, in_mem); | |||
| ocl_runtime->SetKernelArg(kernel_, 4, out_mem); | |||
| ocl_runtime->RunKernel(kernel_, global, local, nullptr); | |||
| ocl_runtime->GetDefaultCommandQueue()->enqueueReadImage(out_mem, CL_TRUE, origin, region, 0, 0, | |||
| outputs_[0]->Data()); | |||
| } | |||
| return 0; | |||
| ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, 4, outputs_[0]->Data()); | |||
| ocl_runtime->RunKernel(kernel_, global, local, nullptr); | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *OpenCLDepthwiseConv2dKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| @@ -18,17 +18,17 @@ | |||
| #define MINDSPORE_LITE_SRC_BACKEND_OPENCL_DEPTHWISE_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "src/runtime/kernel/arm/opclib/conv_parameter.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| namespace mindspore::kernel { | |||
| class DepthwiseConv2dOpenCLKernel : public LiteKernel { | |||
| class DepthwiseConv2dOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| explicit DepthwiseConv2dOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||
| : LiteKernel(parameter, inputs, outputs), | |||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs), | |||
| packed_weight_(nullptr), bias_data_(nullptr), kernel_(nullptr) {} | |||
| ~DepthwiseConv2dOpenCLKernel() override {}; | |||
| @@ -41,13 +41,18 @@ class DepthwiseConv2dOpenCLKernel : public LiteKernel { | |||
| int InitBuffer(); | |||
| int GetImageSize(size_t idx, std::vector<size_t>* img_size) override; | |||
| int GetGlobalSize(size_t idx, std::vector<size_t>* global_size) override; | |||
| int GetLocalSize(size_t idx, const std::vector<size_t>& global_size, | |||
| std::vector<size_t>* local_size) override; | |||
| private: | |||
| FLOAT_t *packed_weight_; | |||
| FLOAT_t *bias_data_; | |||
| cl::Kernel kernel_; | |||
| enum class MEM_TYPE { | |||
| BUF, IMG | |||
| } mem_type_{MEM_TYPE::BUF}; | |||
| } mem_type_{MEM_TYPE::IMG}; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -19,7 +19,7 @@ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "src/runtime/kernel/arm/opclib/conv_parameter.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| @@ -19,7 +19,7 @@ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "src/runtime/kernel/arm/opclib/fp32/pooling.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| @@ -19,7 +19,7 @@ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "src/runtime/kernel/arm/opclib/fp32/softmax.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_OPENCL_KERNEL_H_ | |||
| #define MINDSPORE_LITE_SRC_OPENCL_KERNEL_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| namespace mindspore::kernel { | |||
| class OpenCLKernel : public LiteKernel { | |||
| public: | |||
| explicit OpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||
| : LiteKernel(parameter, inputs, outputs) {} | |||
| virtual int Init() { return -1; } | |||
| virtual int Prepare() { return -1; } | |||
| virtual int InferShape() { return -1; } | |||
| virtual int ReSize() { return -1; } | |||
| virtual int Run() { return -1; } | |||
| virtual int GetImageSize(size_t idx, std::vector<size_t>* img_size) { return -1; } | |||
| virtual int GetGlobalSize(size_t idx, std::vector<size_t>* global_size) { return -1; } | |||
| virtual int GetLocalSize(size_t idx, const std::vector<size_t>& global_size, | |||
| std::vector<size_t>* local_size) { return -1; } | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_OPENCL_KERNEL_H_ | |||
| @@ -32,9 +32,10 @@ int SubGraphOpenCLKernel::Init() { | |||
| } | |||
| // Map buffer for write, it is not necessary for fine-grained | |||
| for (auto &tensor : inputs_) { | |||
| void *data = allocator_->MapBuffer(tensor->Data(), CL_MAP_WRITE, nullptr, true); | |||
| void *data = tensor->Data(); | |||
| // It is required with coarse-grained SVM | |||
| if (data != nullptr) { | |||
| data = allocator_->MapBuffer(data, CL_MAP_WRITE, nullptr, true); | |||
| tensor->SetData(data); | |||
| } else { | |||
| MS_LOG(ERROR) << "OpenCL kernel must use GPU buffer pointer, " | |||
| @@ -18,7 +18,7 @@ | |||
| #define MINDSPORE_LITE_SRC_BACKEND_OPENCL_SUBGRAPH_OPENCL_KENEL_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "src/runtime/opencl/opencl_allocator.h" | |||
| namespace mindspore::kernel { | |||
| @@ -21,6 +21,7 @@ | |||
| #include <vector> | |||
| #include "CL/cl2.hpp" | |||
| #include "utils/log_adapter.h" | |||
| #include "src/runtime/kernel/arm/opclib/op_base.h" | |||
| namespace mindspore::kernel { | |||
| @@ -81,7 +82,6 @@ std::vector<size_t> GetLocalSize(const std::vector<size_t> &global, int max_size | |||
| std::string CLErrorCode(cl_int error_code); | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_UTILS_H_ | |||
| @@ -18,6 +18,7 @@ | |||
| #include <utility> | |||
| #include "utils/log_adapter.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "include/errorcode.h" | |||
| namespace mindspore::lite::opencl { | |||
| @@ -61,7 +62,7 @@ void *OpenCLAllocator::Malloc(size_t size) { | |||
| auto svm_capabilities = ocl_runtime->GetSVMCapabilities(); | |||
| void *host_ptr = nullptr; | |||
| void *device_ptr = nullptr; | |||
| if (svm_capabilities) { | |||
| if (svm_capabilities && svm_on_) { | |||
| cl_svm_mem_flags flags = (svm_capabilities & CL_DEVICE_SVM_FINE_GRAIN_BUFFER) ? CL_MEM_SVM_FINE_GRAIN_BUFFER : 0; | |||
| flags |= (svm_capabilities & CL_DEVICE_SVM_ATOMICS) ? CL_MEM_SVM_ATOMICS : 0; | |||
| flags = flags | CL_MEM_READ_WRITE; | |||
| @@ -69,7 +70,7 @@ void *OpenCLAllocator::Malloc(size_t size) { | |||
| } else { | |||
| cl_int ret = CL_SUCCESS; | |||
| cl::Buffer *buffer = | |||
| new cl::Buffer(*ocl_runtime->Context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, size, NULL, &ret); | |||
| new cl::Buffer(*ocl_runtime->Context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, size, NULL, &ret); | |||
| if (ret != CL_SUCCESS) { | |||
| MS_LOG(ERROR) << "Create OpenCL buffer failed! (ERROR CODE: " << ret << ")"; | |||
| UnLock(); | |||
| @@ -77,7 +78,13 @@ void *OpenCLAllocator::Malloc(size_t size) { | |||
| } | |||
| device_ptr = static_cast<void *>(buffer); | |||
| host_ptr = ocl_runtime->MapBuffer(*buffer, CL_MAP_READ | CL_MAP_WRITE, size); | |||
| ocl_runtime->UnmapBuffer(*buffer, host_ptr); | |||
| if (host_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << device_ptr << ", host_ptr=" << host_ptr; | |||
| UnLock(); | |||
| return nullptr; | |||
| } | |||
| cl::Memory *mem = buffer; | |||
| ocl_runtime->UnmapBuffer(*mem, host_ptr); | |||
| } | |||
| std::unique_ptr<MemBuf> mem_buf = std::make_unique<MemBuf>(); | |||
| mem_buf->size_ = size; | |||
| @@ -90,6 +97,113 @@ void *OpenCLAllocator::Malloc(size_t size) { | |||
| return host_ptr; | |||
| } | |||
| void *OpenCLAllocator::Malloc(size_t size, const std::vector<size_t>& img_size) { | |||
| if (size > MAX_MALLOC_SIZE) { | |||
| MS_LOG(ERROR) << "MallocData out of max_size, size: " << size; | |||
| return nullptr; | |||
| } | |||
| auto ocl_runtime = opencl::OpenCLRuntime::GetInstance(); | |||
| Lock(); | |||
| auto iter = free_list_.lower_bound(size); | |||
| if (iter != free_list_.end() && (iter->second->size_ >= size) && (iter->second->size_ < (size << shift_factor_))) { | |||
| auto mem_buf = iter->second; | |||
| bool is_match{mem_buf->img_size.size() == img_size.size()}; | |||
| for (int i = 0; i < img_size.size() && is_match; ++i) { | |||
| is_match = img_size[i] == mem_buf->img_size[i]; | |||
| } | |||
| if (is_match) { | |||
| free_list_.erase(iter); | |||
| allocated_list_[mem_buf->host_ptr_] = mem_buf; | |||
| UnLock(); | |||
| MS_LOG(DEBUG) << "Malloc Image2D from free list. size: " << mem_buf->size_ | |||
| << ", host addr: " << mem_buf->host_ptr_ << ", device addr: " << mem_buf->device_ptr_; | |||
| return mem_buf->host_ptr_; | |||
| } | |||
| } | |||
| void *host_ptr = nullptr; | |||
| void *device_ptr = nullptr; | |||
| cl_int ret = CL_SUCCESS; | |||
| // CL_HALF_FLOAT, CL_FLOAT | |||
| cl::ImageFormat image_format(CL_RGBA, img_size[2]); | |||
| cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_WRITE, | |||
| image_format, img_size[0], img_size[1], 0, nullptr, &ret); | |||
| if (ret != CL_SUCCESS) { | |||
| MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")"; | |||
| UnLock(); | |||
| return nullptr; | |||
| } | |||
| device_ptr = static_cast<void *>(buffer); | |||
| std::vector<size_t> region{img_size[0], img_size[1], 1}; | |||
| host_ptr = ocl_runtime->MapBuffer(*buffer, 0, CL_MAP_READ | CL_MAP_WRITE, region); | |||
| if (host_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << device_ptr << ", host_ptr=" << host_ptr; | |||
| UnLock(); | |||
| return nullptr; | |||
| } | |||
| cl::Memory *mem = buffer; | |||
| ocl_runtime->UnmapBuffer(*mem, host_ptr); | |||
| std::unique_ptr<MemBuf> mem_buf = std::make_unique<MemBuf>(); | |||
| mem_buf->size_ = size; | |||
| mem_buf->device_ptr_ = device_ptr; | |||
| mem_buf->host_ptr_ = host_ptr; | |||
| mem_buf->img_size = img_size; | |||
| MS_LOG(DEBUG) << "Malloc a new Image2D. size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_ | |||
| << ", device addr: " << mem_buf->device_ptr_; | |||
| allocated_list_[host_ptr] = mem_buf.release(); | |||
| UnLock(); | |||
| return host_ptr; | |||
| } | |||
| void *OpenCLAllocator::CreateImageFromHost(void *data, size_t size, const std::vector<size_t>& img_size) { | |||
| if (size > MAX_MALLOC_SIZE) { | |||
| MS_LOG(ERROR) << "MallocData out of max_size, size: " << size; | |||
| return nullptr; | |||
| } | |||
| auto ocl_runtime = opencl::OpenCLRuntime::GetInstance(); | |||
| Lock(); | |||
| auto iter = free_list_.lower_bound(size); | |||
| if (iter != free_list_.end() && (iter->second->size_ >= size) && (iter->second->size_ < (size << shift_factor_))) { | |||
| auto mem_buf = iter->second; | |||
| free_list_.erase(iter); | |||
| allocated_list_[mem_buf->host_ptr_] = mem_buf; | |||
| UnLock(); | |||
| MS_LOG(DEBUG) << "Malloc Image2D from free list. size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_ | |||
| << ", device addr: " << mem_buf->device_ptr_; | |||
| return mem_buf->host_ptr_; | |||
| } | |||
| void *host_ptr = nullptr; | |||
| void *device_ptr = nullptr; | |||
| cl_int ret = CL_SUCCESS; | |||
| // CL_HALF_FLOAT, CL_FLOAT | |||
| cl::ImageFormat image_format(CL_RGBA, img_size[2]); | |||
| cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, | |||
| img_size[0], img_size[1], 0, data, &ret); | |||
| if (ret != CL_SUCCESS) { | |||
| MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")"; | |||
| UnLock(); | |||
| return nullptr; | |||
| } | |||
| device_ptr = static_cast<void *>(buffer); | |||
| std::vector<size_t> region{img_size[0], img_size[1], 1}; | |||
| host_ptr = ocl_runtime->MapBuffer(*buffer, 0, CL_MAP_READ | CL_MAP_WRITE, region); | |||
| if (host_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << device_ptr << ", host_ptr=" << host_ptr; | |||
| UnLock(); | |||
| return nullptr; | |||
| } | |||
| cl::Memory *mem = buffer; | |||
| ocl_runtime->UnmapBuffer(*mem, host_ptr); | |||
| std::unique_ptr<MemBuf> mem_buf = std::make_unique<MemBuf>(); | |||
| mem_buf->size_ = size; | |||
| mem_buf->device_ptr_ = device_ptr; | |||
| mem_buf->host_ptr_ = host_ptr; | |||
| mem_buf->img_size = img_size; | |||
| MS_LOG(DEBUG) << "Malloc a new Image2D. size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_ | |||
| << ", device addr: " << mem_buf->device_ptr_; | |||
| allocated_list_[host_ptr] = mem_buf.release(); | |||
| UnLock(); | |||
| return host_ptr; | |||
| } | |||
| void OpenCLAllocator::Free(void *buf) { | |||
| if (buf == nullptr) { | |||
| return; | |||
| @@ -163,7 +277,7 @@ void OpenCLAllocator::Clear() { | |||
| void *OpenCLAllocator::MapBuffer(void *host_ptr, int flags, void *command_queue, bool sync) { | |||
| auto ocl_runtime = opencl::OpenCLRuntime::GetInstance(); | |||
| auto svm_capabilities = ocl_runtime->GetSVMCapabilities(); | |||
| if (svm_capabilities) { | |||
| if (svm_capabilities && svm_on_) { | |||
| if (!(svm_capabilities & CL_DEVICE_SVM_FINE_GRAIN_BUFFER)) { | |||
| auto it = allocated_list_.find(host_ptr); | |||
| if (it == allocated_list_.end()) { | |||
| @@ -178,11 +292,25 @@ void *OpenCLAllocator::MapBuffer(void *host_ptr, int flags, void *command_queue, | |||
| auto it = allocated_list_.find(host_ptr); | |||
| if (it == allocated_list_.end()) { | |||
| MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << host_ptr; | |||
| UnLock(); | |||
| return nullptr; | |||
| } | |||
| MemBuf *mem_buf = it->second; | |||
| cl::Buffer *buffer = static_cast<cl::Buffer *>(mem_buf->device_ptr_); | |||
| void *new_host_ptr = ocl_runtime->MapBuffer(*buffer, flags, mem_buf->size_, nullptr, sync); | |||
| void *new_host_ptr{nullptr}; | |||
| if (mem_buf->img_size.empty()) { | |||
| cl::Buffer *buffer = static_cast<cl::Buffer *>(mem_buf->device_ptr_); | |||
| new_host_ptr = ocl_runtime->MapBuffer(*buffer, flags, mem_buf->size_, nullptr, sync); | |||
| } else { | |||
| cl::ImageFormat image_format(CL_RGBA, mem_buf->img_size[2]); | |||
| std::vector<size_t> region{mem_buf->img_size[0], mem_buf->img_size[1], 1}; | |||
| cl::Image2D *buffer = static_cast<cl::Image2D *>(mem_buf->device_ptr_); | |||
| new_host_ptr = ocl_runtime->MapBuffer(*buffer, 0, CL_MAP_READ | CL_MAP_WRITE, region); | |||
| } | |||
| if (new_host_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << mem_buf->device_ptr_ << ", host_ptr=" << host_ptr; | |||
| UnLock(); | |||
| return nullptr; | |||
| } | |||
| mem_buf->host_ptr_ = new_host_ptr; | |||
| allocated_list_.erase(it); | |||
| allocated_list_[new_host_ptr] = mem_buf; | |||
| @@ -208,5 +336,40 @@ int OpenCLAllocator::UnmapBuffer(void *host_ptr, void *command_queue) { | |||
| return ocl_runtime->UnmapBuffer(*buffer, it->second->host_ptr_, static_cast<cl::CommandQueue *>(command_queue)); | |||
| } | |||
| MEM_TYPE OpenCLAllocator::GetMemType(void *host_ptr) { | |||
| MEM_TYPE mem_type{MEM_TYPE::BUF}; | |||
| Lock(); | |||
| auto it = allocated_list_.find(host_ptr); | |||
| if (it == allocated_list_.end()) { | |||
| MS_LOG(ERROR) << "Can not found buffer :" << host_ptr; | |||
| UnLock(); | |||
| return mem_type; | |||
| } | |||
| MemBuf *mem_buf = it->second; | |||
| if (mem_buf->img_size.empty()) { | |||
| mem_type = MEM_TYPE::BUF; | |||
| } else { | |||
| mem_type = MEM_TYPE::IMG; | |||
| } | |||
| UnLock(); | |||
| return mem_type; | |||
| } | |||
| int OpenCLAllocator::GetImageSize(void *host_ptr, std::vector<size_t>* img_size) { | |||
| Lock(); | |||
| auto it = allocated_list_.find(host_ptr); | |||
| if (it == allocated_list_.end()) { | |||
| MS_LOG(ERROR) << "Can not found buffer :" << host_ptr; | |||
| UnLock(); | |||
| return RET_OK; | |||
| } | |||
| MemBuf *mem_buf = it->second; | |||
| if (!mem_buf->img_size.empty()) { | |||
| *img_size = mem_buf->img_size; | |||
| } | |||
| UnLock(); | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::lite::opencl | |||
| @@ -39,18 +39,27 @@ struct OpenclMemory { | |||
| OpenCLMemoryType mem_type{MS_HOST_BUFFER | MS_CL_BUFFER}; | |||
| }; | |||
| enum class MEM_TYPE : char { | |||
| BUF, IMG | |||
| }; | |||
| class OpenCLAllocator : public Allocator { | |||
| public: | |||
| OpenCLAllocator(); | |||
| ~OpenCLAllocator() override; | |||
| void SetContext(const AllocatorContext &ctx) override; | |||
| void *Malloc(size_t size) override; | |||
| void *Malloc(size_t size, const std::vector<size_t>& img_size); | |||
| void *CreateImageFromHost(void *host_ptr, size_t size, const std::vector<size_t>& img_size); | |||
| void Free(void *ptr) override; | |||
| size_t GetTotalSize() override; | |||
| void Clear() override; | |||
| void *GetDeviceBuffer(void *buffer); | |||
| void *MapBuffer(void *host_ptr, int flags, void *command_queue = nullptr, bool sync = true); | |||
| int UnmapBuffer(void *host_ptr, void *command_queue = nullptr); | |||
| MEM_TYPE GetMemType(void *host_ptr); | |||
| int GetImageSize(void *host_ptr, std::vector<size_t>* img_size); | |||
| private: | |||
| void Lock(); | |||
| @@ -59,6 +68,7 @@ class OpenCLAllocator : public Allocator { | |||
| size_t size_; | |||
| void *device_ptr_; | |||
| void *host_ptr_; | |||
| std::vector<size_t> img_size; | |||
| }; | |||
| std::mutex lock; | |||
| @@ -68,6 +78,7 @@ class OpenCLAllocator : public Allocator { | |||
| // 6 is empirical value | |||
| int shift_factor_ = 6; | |||
| bool lock_flag_ = false; | |||
| bool svm_on_{false}; | |||
| }; | |||
| } // namespace mindspore::lite::opencl | |||
| @@ -15,9 +15,10 @@ | |||
| */ | |||
| #include "src/runtime/opencl/opencl_executor.h" | |||
| #include "src/runtime/kernel/opencl/utils.h" | |||
| #include "src/runtime/kernel/arm/opclib/pack.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/common/ms_tensor_utils.h" | |||
| #include "include/errorcode.h" | |||
| namespace mindspore::lite::opencl { | |||
| int OpenCLExecutor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tensor::Tensor *> &outputs, | |||
| @@ -29,23 +30,32 @@ int OpenCLExecutor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tenso | |||
| MS_LOG(ERROR) << "Graph input tensor is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| if (inTensor->GetFormat() != schema::Format_NHWC4 && inTensor->GetFormat() != schema::Format_NC4HW4) { | |||
| if (inTensor->GetFormat() != schema::Format_NHWC) { | |||
| MS_LOG(ERROR) << "Model input should be NHWC, actual is " << schema::EnumNameFormat(inTensor->GetFormat()); | |||
| return RET_ERROR; | |||
| } else { | |||
| TransformTensorLayout(inTensor, schema::Format_NHWC4); | |||
| // TransformTensorLayout(inTensor, schema::Format_NC4HW4); | |||
| } | |||
| if (inTensor->GetFormat() != schema::Format_NHWC4 && inTensor->GetFormat() != schema::Format_NC4HW4 && | |||
| inTensor->GetFormat() != schema::Format_NHWC) { | |||
| MS_LOG(ERROR) << "input should be NHWC/NHWC4/NC4HW4, actual is " << schema::EnumNameFormat(inTensor->GetFormat()); | |||
| return RET_ERROR; | |||
| } else { | |||
| TransformTensorLayout(inTensor, inTensor->GetFormat(), schema::Format_NHWC4, true); | |||
| // TransformTensorLayout(inTensor, inTensor->GetFormat(), schema::Format_NC4HW4, true); | |||
| } | |||
| } | |||
| kernel::LiteKernelUtil::InitTensorRefCount(kernels); | |||
| OpenCLAllocator* op_allocator = reinterpret_cast<OpenCLAllocator*>(allocator); | |||
| for (auto *kernel : kernels) { | |||
| MS_ASSERT(nullptr != kernel); | |||
| kernel::OpenCLKernel *op_kernel = reinterpret_cast<kernel::OpenCLKernel*>(kernel); | |||
| auto &outputs = kernel->GetOutputs(); | |||
| for (auto *output : outputs) { | |||
| for (auto i = 0; i < outputs.size(); ++i) { | |||
| auto *output = outputs.at(i); | |||
| MS_ASSERT(nullptr != output); | |||
| output->MallocData(); | |||
| if (is_image2d_out_) { | |||
| std::vector<size_t> img_size; | |||
| op_kernel->GetImageSize(i, &img_size); | |||
| auto data_ptr = op_allocator->Malloc(output->Size(), img_size); | |||
| output->SetData(data_ptr); | |||
| } else { | |||
| output->MallocData(allocator); | |||
| } | |||
| } | |||
| session::CallBackParam callbackParam; | |||
| callbackParam.name_callback_param = kernel->Name(); | |||
| @@ -81,21 +91,22 @@ int OpenCLExecutor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tenso | |||
| return RET_ERROR; | |||
| } | |||
| if (outTensor->GetFormat() != schema::Format_NHWC) { | |||
| MS_LOG(ERROR) << "Model output tensor should be NHWC"; | |||
| TransformTensorLayout(outTensor, outTensor->GetFormat(), schema::Format_NHWC, false); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int OpenCLExecutor::TransformTensorLayout(tensor::Tensor *tensor, schema::Format dst_format) { | |||
| int OpenCLExecutor::TransformTensorLayout(tensor::Tensor *tensor, schema::Format src_format, | |||
| schema::Format dst_format, bool trans_dir) { | |||
| MS_ASSERT(nullptr != tensor); | |||
| MS_ASSERT(4 == tensor->shape().size()); | |||
| auto data_type = tensor->data_type(); | |||
| switch (data_type) { | |||
| case kNumberTypeInt8: | |||
| return TransformTensorLayoutUint8(tensor, dst_format); | |||
| return TransformTensorLayoutUint8(tensor, src_format, dst_format, trans_dir); | |||
| case kNumberTypeFloat32: | |||
| return TransformTensorLayoutFp32(tensor, dst_format); | |||
| return TransformTensorLayoutFp32(tensor, src_format, dst_format, trans_dir); | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| << schema::EnumNameFormat(dst_format); | |||
| @@ -104,21 +115,103 @@ int OpenCLExecutor::TransformTensorLayout(tensor::Tensor *tensor, schema::Format | |||
| return RET_OK; | |||
| } | |||
| int OpenCLExecutor::TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format dst_format) { | |||
| int OpenCLExecutor::TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format src_format, | |||
| schema::Format dst_format, bool trans_dir) { | |||
| MS_ASSERT(nullptr != tensor); | |||
| MS_ASSERT(nullptr != allocator_); | |||
| MS_ASSERT(4 == tensor->shape().size()); | |||
| if (trans_dir) { | |||
| if (is_image2d_out_) { | |||
| return TransformTensorLayoutToImage(tensor, src_format, dst_format); | |||
| } else { | |||
| return TransformTensorLayoutToBuffer(tensor, src_format, dst_format); | |||
| } | |||
| } else { | |||
| if (is_image2d_out_) { | |||
| return TransformTensorLayoutFromImage(tensor, src_format, dst_format); | |||
| } else { | |||
| return TransformTensorLayoutToBuffer(tensor, src_format, dst_format); | |||
| } | |||
| } | |||
| } | |||
| int OpenCLExecutor::TransformTensorLayoutToBuffer(tensor::Tensor *tensor, schema::Format src_format, | |||
| schema::Format dst_format) { | |||
| if (dst_format == schema::Format_NHWC4) { | |||
| auto *src_data = tensor->Data(); | |||
| auto *dst_data = allocator_->Malloc(tensor->Size()); | |||
| if (dst_data == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc data failed"; | |||
| return RET_ERROR; | |||
| size_t C4 = UP_DIV(tensor->Channel(), C4NUM); | |||
| std::vector <size_t> img_size{tensor->Width() * C4, (size_t) tensor->Height(), CL_FLOAT}; | |||
| if (src_format == schema::Format_NHWC) { | |||
| auto *dst_data = allocator_->Malloc(tensor->Size(), img_size); | |||
| if (dst_data == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc data failed"; | |||
| return RET_ERROR; | |||
| } | |||
| dst_data = reinterpret_cast<FLOAT_t *>(allocator_->MapBuffer(dst_data, CL_MAP_WRITE, nullptr, true)); | |||
| PackNHWCToNHWC4Fp32(src_data, dst_data, tensor->Batch(), tensor->Height() * tensor->Width(), tensor->Channel()); | |||
| tensor->SetData(dst_data); | |||
| allocator_->Free(src_data); | |||
| allocator_->UnmapBuffer(dst_data); | |||
| } | |||
| dst_data = reinterpret_cast<FLOAT_t *>(allocator_->MapBuffer(dst_data, CL_MAP_WRITE, nullptr, true)); | |||
| PackNHWCToNHWC4Fp32(src_data, dst_data, tensor->Batch(), tensor->Height() * tensor->Width(), tensor->Channel()); | |||
| tensor->SetData(dst_data); | |||
| tensor->SetFormat(dst_format); | |||
| return RET_OK; | |||
| } else if (dst_format == schema::Format_NHWC) { | |||
| // TODO(wandongdong): add support !! | |||
| return RET_OK; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| << schema::EnumNameFormat(dst_format) << " in float32"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| int OpenCLExecutor::TransformTensorLayoutToImage(tensor::Tensor *tensor, schema::Format src_format, | |||
| schema::Format dst_format) { | |||
| if (dst_format == schema::Format_NHWC4) { | |||
| // convert to nhwc4 | |||
| auto *src_data = tensor->Data(); | |||
| auto *dst_data{src_data}; | |||
| if (src_format == schema::Format_NHWC) { | |||
| dst_data = allocator_->Malloc(tensor->Size()); | |||
| if (dst_data == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc data failed"; | |||
| return RET_ERROR; | |||
| } | |||
| dst_data = reinterpret_cast<FLOAT_t *>(allocator_->MapBuffer(dst_data, CL_MAP_WRITE, nullptr, true)); | |||
| PackNHWCToNHWC4Fp32(src_data, dst_data, tensor->Batch(), tensor->Height() * tensor->Width(), tensor->Channel()); | |||
| tensor->SetData(dst_data); | |||
| allocator_->Free(src_data); | |||
| allocator_->UnmapBuffer(dst_data); | |||
| } | |||
| // copy to image2d | |||
| src_data = dst_data; | |||
| size_t C4 = UP_DIV(tensor->Channel(), C4NUM); | |||
| std::vector<size_t> img_size{tensor->Width() * C4, (size_t)tensor->Height(), CL_FLOAT}; | |||
| dst_data = allocator_->CreateImageFromHost(src_data, tensor->Size(), img_size); | |||
| tensor->SetData(dst_data); | |||
| allocator_->Free(src_data); | |||
| tensor->SetFormat(schema::Format_NHWC4); | |||
| return RET_OK; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| << schema::EnumNameFormat(dst_format) << " in float32"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| int OpenCLExecutor::TransformTensorLayoutFromImage(tensor::Tensor *tensor, schema::Format src_format, | |||
| schema::Format dst_format) { | |||
| if (dst_format == schema::Format_NHWC) { | |||
| auto src_data = tensor->Data(); | |||
| auto dst_data = allocator_->Malloc(tensor->Size()); | |||
| cl::Image2D *out_mem = reinterpret_cast<cl::Image2D *>(allocator_->GetDeviceBuffer(src_data)); | |||
| std::vector<size_t> img_size; | |||
| allocator_->GetImageSize(src_data, &img_size); | |||
| auto origin = cl::array < cl::size_type, 3U > {0, 0, 0}; | |||
| auto region = cl::array < cl::size_type, 3U > {img_size[0], img_size[1], 1}; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| ocl_runtime->GetDefaultCommandQueue()->enqueueReadImage(*out_mem, CL_TRUE, origin, region, 0, 0, dst_data); | |||
| tensor->SetData(dst_data); | |||
| allocator_->Free(src_data); | |||
| return RET_OK; | |||
| } else { | |||
| @@ -128,7 +221,8 @@ int OpenCLExecutor::TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Fo | |||
| } | |||
| } | |||
| int OpenCLExecutor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format dst_format) { | |||
| int OpenCLExecutor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format src_format, | |||
| schema::Format dst_format, bool is_image) { | |||
| MS_ASSERT(nullptr != tensor); | |||
| MS_ASSERT(4 == tensor->shape().size()); | |||
| // auto src_format = tensor->GetFormat(); | |||
| @@ -20,7 +20,7 @@ | |||
| #include <vector> | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "src/runtime/allocator.h" | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "src/executor.h" | |||
| #include "include/lite_session.h" | |||
| @@ -38,15 +38,25 @@ class OpenCLExecutor : Executor { | |||
| const session::KernelCallBack &before = nullptr, const session::KernelCallBack &after = nullptr); | |||
| protected: | |||
| int TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format dst_format); | |||
| int TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format, | |||
| bool trans_dir = false); | |||
| int TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format dst_format); | |||
| int TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format, | |||
| bool trans_dir = false); | |||
| int TransformTensorLayout(tensor::Tensor *tensor, schema::Format dst_format); | |||
| int TransformTensorLayout(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format, | |||
| bool trans_dir = false); | |||
| int TransformTensorLayoutToBuffer(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format); | |||
| int TransformTensorLayoutToImage(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format); | |||
| int TransformTensorLayoutFromImage(tensor::Tensor *tensor, schema::Format src_format, schema::Format dst_format); | |||
| protected: | |||
| Context *context = nullptr; | |||
| OpenCLAllocator *allocator_; | |||
| bool is_image2d_out_{true}; | |||
| }; | |||
| } // namespace mindspore::lite::opencl | |||
| @@ -124,8 +124,13 @@ int OpenCLRuntime::Init() { | |||
| const std::string device_name = device_->getInfo<CL_DEVICE_NAME>(); | |||
| const std::string device_version = device_->getInfo<CL_DEVICE_VERSION>(); | |||
| const std::string opencl_version = device_->getInfo<CL_DEVICE_OPENCL_C_VERSION>(); | |||
| cl_uint align; | |||
| size_t ret; | |||
| clGetDeviceInfo((*device_)(), CL_DEVICE_IMAGE_PITCH_ALIGNMENT, sizeof(cl_uint), &align, &ret); | |||
| MS_LOG(INFO) << "Device name:\t" << device_name; | |||
| MS_LOG(INFO) << "Opencl version:\t" << device_version; | |||
| MS_LOG(INFO) << "Image alignment:\t" << align; | |||
| MS_LOG(INFO) << "Image ret:\t" << ret; | |||
| MS_LOG(INFO) << "Highest OpenCL c version:\t" << opencl_version; | |||
| MS_LOG(INFO) << "Max work item size:\t" | |||
| << max_work_item_sizes_[0] << " : " | |||
| @@ -133,7 +138,6 @@ int OpenCLRuntime::Init() { | |||
| << max_work_item_sizes_[2]; | |||
| gpu_info_ = ParseGpuInfo(device_name, device_version); | |||
| cl_int err; | |||
| #if defined(SHARING_MEM_WITH_OPENGL) && (CL_HPP_TARGET_OPENCL_VERSION >= 120) | |||
| // create context from glcontext | |||
| @@ -164,6 +168,7 @@ int OpenCLRuntime::Init() { | |||
| support_fp16_ = CL_SUCCESS == success && fp_config > 0; | |||
| err = device_->getInfo(CL_DEVICE_SVM_CAPABILITIES, &svm_capabilities_); | |||
| svm_capabilities_ = 0; | |||
| if (err != CL_SUCCESS || svm_capabilities_ == 0) { | |||
| svm_capabilities_ = 0; | |||
| MS_LOG(INFO) << "SVM capalibilties: " | |||
| @@ -535,7 +540,19 @@ int OpenCLRuntime::MapBuffer(void *host_ptr, int flags, size_t size, cl::Command | |||
| return command_queue->enqueueMapSVM(host_ptr, sync, flags, size); | |||
| } | |||
| int OpenCLRuntime::UnmapBuffer(const cl::Buffer buffer, void *host_ptr, cl::CommandQueue *command_queue) const { | |||
| void *OpenCLRuntime::MapBuffer(const cl::Image2D buffer, bool sync, int flags, | |||
| const std::vector<size_t>& region, cl::CommandQueue *command_queue) const { | |||
| if (command_queue == nullptr) { | |||
| command_queue = default_command_queue_.get(); | |||
| } | |||
| cl::size_type row_pitch; | |||
| cl::size_type slice_pitch; | |||
| cl::array<cl::size_type, 3> origin_{0, 0, 0}; | |||
| cl::array<cl::size_type, 3> region_{region[0], region[1], region[2]}; | |||
| return command_queue->enqueueMapImage(buffer, sync, flags, origin_, region_, &row_pitch, &slice_pitch); | |||
| } | |||
| int OpenCLRuntime::UnmapBuffer(const cl::Memory buffer, void *host_ptr, cl::CommandQueue *command_queue) const { | |||
| if (command_queue == nullptr) { | |||
| command_queue = default_command_queue_.get(); | |||
| } | |||
| @@ -75,9 +75,16 @@ class OpenCLRuntime { | |||
| MS_LOG(DEBUG) << "Set kernel arg[" << index << "] SVM pointer " << value; | |||
| return clSetKernelArgSVMPointer(kernel, index, value); | |||
| } else { | |||
| cl::Buffer *buffer = reinterpret_cast<cl::Buffer *>(allocator_->GetDeviceBuffer(value)); | |||
| MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Buffer " << value; | |||
| return clSetKernelArg(kernel, index, sizeof((*buffer)()), &(*buffer)()); | |||
| MEM_TYPE mem_type = allocator_->GetMemType(value); | |||
| if (mem_type == MEM_TYPE::BUF) { | |||
| cl::Buffer *buffer = reinterpret_cast<cl::Buffer *>(allocator_->GetDeviceBuffer(value)); | |||
| MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Buffer " << value; | |||
| return clSetKernelArg(kernel, index, sizeof((*buffer)()), &(*buffer)()); | |||
| } else { | |||
| cl::Image2D *buffer = reinterpret_cast<cl::Image2D *>(allocator_->GetDeviceBuffer(value)); | |||
| MS_LOG(DEBUG) << "Set kernel arg[" << index << "] OpenCL Image2D " << value; | |||
| return clSetKernelArg(kernel, index, sizeof((*buffer)()), &(*buffer)()); | |||
| } | |||
| } | |||
| } | |||
| @@ -107,9 +114,11 @@ class OpenCLRuntime { | |||
| bool sync = false) const; | |||
| void *MapBuffer(const cl::Buffer buffer, int map_flags, size_t size, cl::CommandQueue *command_queue = nullptr, | |||
| bool sync = false) const; | |||
| void *MapBuffer(const cl::Image2D buffer, bool sync, int flags, | |||
| const std::vector<size_t>& region, cl::CommandQueue *command_queue = nullptr) const; | |||
| int MapBuffer(void *host_ptr, int map_flags, size_t size, cl::CommandQueue *command_queue = nullptr, | |||
| bool sync = false) const; | |||
| int UnmapBuffer(const cl::Buffer buffer, void *host_ptr, cl::CommandQueue *command_queue = nullptr) const; | |||
| int UnmapBuffer(const cl::Memory buffer, void *host_ptr, cl::CommandQueue *command_queue = nullptr) const; | |||
| int UnmapBuffer(void *host_ptr, cl::CommandQueue *command_queue = nullptr) const; | |||
| bool SyncCommandQueue(cl::CommandQueue *command_queue = nullptr); | |||
| @@ -35,6 +35,8 @@ | |||
| a = nullptr; \ | |||
| } | |||
| bool IMAGE2D_OPEN = true; | |||
| namespace mindspore { | |||
| class TestConvolutionDwOpenCL : public mindspore::Common { | |||
| public: | |||
| @@ -95,6 +97,18 @@ void DepthWiseTestMain(ConvParameter *conv_param, float_t *input_data, float_t * | |||
| std::vector<kernel::LiteKernel *> kernels{pKernel}; | |||
| std::vector<lite::tensor::Tensor *> inputs_{tensor_a}; | |||
| size_t C4 = UP_DIV(inputs[0]->Channel(), C4NUM); | |||
| // if (IMAGE2D_OPEN && format == schema::Format_NHWC4) { | |||
| // std::vector<size_t> img_size{inputs[0]->Width() * C4, (size_t)inputs[0]->Height(), CL_FLOAT}; | |||
| // auto in_data = allocator->Malloc(inputs[0]->Size(), img_size); | |||
| // inputs[0]->SetData(in_data); | |||
| // } else if (IMAGE2D_OPEN && format == schema::Format_NC4HW4) { | |||
| // std::vector<size_t> img_size{(size_t)inputs[0]->Width(), inputs[0]->Height() * C4, CL_FLOAT}; | |||
| // auto in_data = allocator->Malloc(inputs[0]->Size(), img_size); | |||
| // inputs[0]->SetData(in_data); | |||
| // } else { | |||
| inputs[0]->MallocData(allocator); | |||
| // } | |||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); | |||
| pGraph->Init(); | |||
| @@ -103,9 +117,9 @@ void DepthWiseTestMain(ConvParameter *conv_param, float_t *input_data, float_t * | |||
| pGraph->Run(); | |||
| if (is_compare) { | |||
| float* packed_output = reinterpret_cast<float *>(outputs[0]->Data()); | |||
| float *packed_correct_data = new float[packed_output_size]; | |||
| memset(packed_correct_data, 0, packed_output_size * sizeof(float)); | |||
| float_t* packed_output = reinterpret_cast<float *>(outputs[0]->Data()); | |||
| float_t *packed_correct_data = new float_t[packed_output_size]; | |||
| memset(packed_correct_data, 0, packed_output_size * sizeof(float_t)); | |||
| if (format == schema::Format_NC4HW4) { | |||
| PackNHWCToNC4HW4Fp32(gnd_data, packed_correct_data, conv_param->output_batch_, | |||
| conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_); | |||
| @@ -128,7 +142,7 @@ void DepthWiseTestMain(ConvParameter *conv_param, float_t *input_data, float_t * | |||
| std::cout << std::endl; | |||
| printf("==================output data=================\n"); | |||
| std::cout << std::endl; | |||
| for (int i = 0; i < packed_output_size; i++) { | |||
| for (int i = 0; i < 80/*packed_output_size*/; i++) { | |||
| std::cout << packed_output[i] << ", "; | |||
| } | |||
| std::cout << std::endl; | |||
| @@ -142,13 +156,13 @@ void DepthWiseTestMain(ConvParameter *conv_param, float_t *input_data, float_t * | |||
| SAFE_DELETE_ARRAY(packed_correct_data) | |||
| } | |||
| inputs[1]->SetData(nullptr); | |||
| inputs[2]->SetData(nullptr); | |||
| SAFE_DELETE_ARRAY(packed_input); | |||
| for (auto tensor : inputs) { | |||
| tensor->SetData(nullptr); | |||
| SAFE_DELETE_PTR(tensor) | |||
| } | |||
| for (auto tensor : outputs) { | |||
| tensor->SetData(nullptr); | |||
| SAFE_DELETE_PTR(tensor) | |||
| } | |||
| SAFE_DELETE_PTR(pKernel) | |||
| @@ -477,6 +491,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) { | |||
| std::vector<kernel::LiteKernel *> kernels{pKernel}; | |||
| std::vector<lite::tensor::Tensor *> inputs_{tensor_a}; | |||
| inputs[0]->MallocData(); | |||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); | |||
| pGraph->Init(); | |||
| @@ -516,12 +531,12 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) { | |||
| // compare | |||
| Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); | |||
| inputs[1]->SetData(nullptr); | |||
| inputs[2]->SetData(nullptr); | |||
| for (auto tensor : inputs) { | |||
| tensor->SetData(nullptr); | |||
| SAFE_DELETE_PTR(tensor) | |||
| } | |||
| for (auto tensor : outputs) { | |||
| tensor->SetData(nullptr); | |||
| SAFE_DELETE_PTR(tensor) | |||
| } | |||
| SAFE_DELETE_PTR(pKernel) | |||
| @@ -640,6 +655,7 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) { | |||
| std::vector<kernel::LiteKernel *> kernels{pKernel}; | |||
| std::vector<lite::tensor::Tensor *> inputs_{tensor_a}; | |||
| inputs[0]->MallocData(); | |||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs_, outputs, kernels, kernels, kernels); | |||
| pGraph->Init(); | |||
| @@ -687,14 +703,14 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) { | |||
| // compare | |||
| Common::CompareOutputData(packed_output, packed_correct_data, packed_output_size, 0.00001); | |||
| inputs[1]->SetData(nullptr); | |||
| inputs[2]->SetData(nullptr); | |||
| SAFE_DELETE_ARRAY(packed_input); | |||
| SAFE_DELETE_ARRAY(packed_correct_data) | |||
| for (auto tensor : inputs) { | |||
| tensor->SetData(nullptr); | |||
| SAFE_DELETE_PTR(tensor) | |||
| } | |||
| for (auto tensor : outputs) { | |||
| tensor->SetData(nullptr); | |||
| SAFE_DELETE_PTR(tensor) | |||
| } | |||
| SAFE_DELETE_PTR(pKernel) | |||
| @@ -742,35 +758,27 @@ TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) { | |||
| }; | |||
| // nhwc | |||
| float_t *input_data = new float_t[96*112*112]{ | |||
| 0.5488135 , 0.3834415 , 0.77815676, 0.9446689 , 0.6120957 , | |||
| 0.71518934, 0.79172504, 0.87001216, 0.5218483 , 0.616934 , | |||
| 0.60276335, 0.5288949 , 0.9786183 , 0.41466194, 0.94374806, | |||
| 0.5448832 , 0.56804454, 0.7991586 , 0.2645556 , 0.6818203 , | |||
| 0.4236548 , 0.92559665, 0.46147937, 0.7742337 , 0.3595079 , | |||
| 0.6458941 , 0.07103606, 0.7805292 , 0.45615032, 0.43703195, | |||
| 0.4375872 , 0.0871293 , 0.11827443, 0.56843394, 0.6976312 , | |||
| 0.891773 , 0.0202184 , 0.639921 , 0.0187898 , 0.06022547, | |||
| 0.96366274, 0.83261985, 0.14335328, 0.6176355 , 0.6667667 }; | |||
| size_t in_size = 96*112*112; | |||
| float_t *input_data = new float_t[in_size]; | |||
| memset(input_data, 0, in_size); | |||
| for (auto i = 0; i < in_size; ++i) { | |||
| input_data[i] = 1; | |||
| } | |||
| // co h w ci | |||
| float_t *weight_data = new float_t[576*3*3]{ | |||
| 0.67063785, 0.21038257, 0.12892629, | |||
| 0.31542835, 0.36371076, 0.57019675, | |||
| 0.43860152, 0.9883738 , 0.10204481, | |||
| 0.20887676, 0.16130951, 0.6531083 , | |||
| 0.2532916 , 0.46631077, 0.2444256 , | |||
| 0.15896958, 0.11037514, 0.6563296 , | |||
| 0.13818295, 0.19658236, 0.36872518, | |||
| 0.82099324, 0.09710128, 0.8379449 , | |||
| 0.09609841, 0.97645944, 0.4686512 , | |||
| 0.9767611 , 0.6048455 , 0.7392636 , | |||
| 0.03918779, 0.28280696, 0.12019656, | |||
| 0.2961402 , 0.11872772, 0.31798318, | |||
| 0.41426298, 0.06414749, 0.6924721 , | |||
| 0.56660146, 0.2653895 , 0.5232481 , | |||
| 0.09394051, 0.5759465 , 0.9292962 }; | |||
| size_t wt_size = 576*3*3; | |||
| float_t *weight_data = new float_t[wt_size]; | |||
| memset(weight_data, 0, wt_size); | |||
| for (auto i = 0; i < wt_size; ++i) { | |||
| weight_data[i] = 1; | |||
| } | |||
| size_t out_size = 96*112*112; | |||
| float_t *gnd_data = new float_t[out_size]; | |||
| memset(gnd_data, 0, out_size); | |||
| // for (auto i = 0; i < in_size; ++i) { | |||
| // gnd_data[i] = 1; | |||
| // } | |||
| for (size_t i = 0; i < src_shape.size(); ++i) { | |||
| const int MAX_RUN_TIMES = 10; | |||
| const int MAX_RUN_TIMES = 1; | |||
| for (int j = 0; j < MAX_RUN_TIMES; ++j) { | |||
| printf("========profiling depthwise, in shape(%d,%d,%d,%d), out shape(%d,%d,%d,%d), iter%d========\n", | |||
| src_shape[i][0], src_shape[i][1], src_shape[i][2], src_shape[i][3], | |||
| @@ -794,8 +802,8 @@ TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) { | |||
| conv_param->dilation_h_ = 1; | |||
| conv_param->dilation_w_ = 1; | |||
| } | |||
| DepthWiseTestMain(conv_param, input_data, weight_data, nullptr, schema::Format_NC4HW4, false); | |||
| // DepthWiseTestMain(conv_param, input_data, weight_data, nullptr, schema::Format_NHWC4, false); | |||
| // DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4, false); | |||
| DepthWiseTestMain(conv_param, input_data, weight_data, nullptr, schema::Format_NHWC4, false); | |||
| } | |||
| } | |||
| SAFE_DELETE_ARRAY(input_data); | |||
| @@ -803,4 +811,54 @@ TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) { | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| } | |||
| TEST_F(TestConvolutionDwOpenCL, Buffer2Image) { | |||
| std::vector<int> src_shape{1, 96, 64, 64}; | |||
| std::vector<int> dst_shape{1, 96, 32, 32}; | |||
| std::vector<int> filter_shape{96, 3, 3, 1}; | |||
| // nhwc | |||
| size_t in_size = 96*112*112; | |||
| float_t *input_data = new float_t[in_size]; | |||
| memset(input_data, 0, in_size); | |||
| for (auto i = 0; i < in_size; ++i) { | |||
| input_data[i] = 1; | |||
| } | |||
| // co h w ci | |||
| size_t wt_size = 576*3*3; | |||
| float_t *weight_data = new float_t[wt_size]; | |||
| memset(weight_data, 0, wt_size); | |||
| for (auto i = 0; i < wt_size; ++i) { | |||
| weight_data[i] = 1; | |||
| } | |||
| size_t out_size = 96*112*112; | |||
| float_t *gnd_data = new float_t[out_size]; | |||
| memset(gnd_data, 0, out_size); | |||
| // for (auto i = 0; i < in_size; ++i) { | |||
| // gnd_data[i] = 1; | |||
| // } | |||
| ConvParameter *conv_param = new ConvParameter(); | |||
| { | |||
| conv_param->input_batch_ = 1; | |||
| conv_param->input_h_ = src_shape[2]; | |||
| conv_param->input_w_ = src_shape[3]; | |||
| conv_param->input_channel_ = src_shape[1]; | |||
| conv_param->output_batch_ = 1; | |||
| conv_param->output_h_ = dst_shape[2]; | |||
| conv_param->output_w_ = dst_shape[3]; | |||
| conv_param->output_channel_ = dst_shape[1]; | |||
| conv_param->kernel_h_ = filter_shape[1]; | |||
| conv_param->kernel_w_ = filter_shape[2]; | |||
| conv_param->stride_h_ = conv_param->output_h_/conv_param->input_h_; | |||
| conv_param->stride_w_ = conv_param->output_w_/conv_param->input_w_; | |||
| conv_param->pad_h_ = (conv_param->kernel_h_-1)/2; | |||
| conv_param->pad_w_ = (conv_param->kernel_w_-1)/2; | |||
| conv_param->dilation_h_ = 1; | |||
| conv_param->dilation_w_ = 1; | |||
| } | |||
| // DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NC4HW4, true); | |||
| DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4, true); | |||
| SAFE_DELETE_ARRAY(input_data); | |||
| SAFE_DELETE_ARRAY(weight_data); | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| } | |||
| } // namespace mindspore | |||