From 68c7ba09d95fc08c0a13ca8eb475e1a93d8a363e Mon Sep 17 00:00:00 2001
From: wandongdong
Date: Mon, 7 Dec 2020 19:10:42 -0800
Subject: [PATCH] support weight quant for opencl

---
 mindspore/lite/src/CMakeLists.txt             |   1 +
 .../src/runtime/kernel/arm/base/dequant.h     |  17 +-
 .../runtime/kernel/opencl/kernel/conv2d.cc    |  15 +-
 .../kernel/opencl/kernel/depthwise_conv2d.cc  |  18 +-
 .../kernel/opencl/kernel/fullconnection.cc    |   5 +
 .../runtime/kernel/opencl/kernel/matmul.cc    |   5 +
 .../runtime/kernel/opencl/opencl_kernel.cc    | 235 ++++++++++++++++++
 .../src/runtime/kernel/opencl/opencl_kernel.h | 187 +-------------
 mindspore/lite/test/CMakeLists.txt            |   1 +
 ...odels_fp16_gpu.cfg => models_gpu_fp16.cfg} |   0
 ...odels_fp32_gpu.cfg => models_gpu_fp32.cfg} |   0
 .../lite/test/models_gpu_weightquant.cfg      |   1 +
 mindspore/lite/test/run_benchmark_nets.sh     |  47 ++--
 13 files changed, 315 insertions(+), 217 deletions(-)
 create mode 100644 mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc
 rename mindspore/lite/test/{models_fp16_gpu.cfg => models_gpu_fp16.cfg} (100%)
 rename mindspore/lite/test/{models_fp32_gpu.cfg => models_gpu_fp32.cfg} (100%)
 create mode 100644 mindspore/lite/test/models_gpu_weightquant.cfg

diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt
index b4cbd79109..d81cb4be87 100644
--- a/mindspore/lite/src/CMakeLists.txt
+++ b/mindspore/lite/src/CMakeLists.txt
@@ -41,6 +41,7 @@ set(LITE_SRC
 if (SUPPORT_GPU)
     set(LITE_SRC
         ${LITE_SRC}
+        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/opencl_kernel.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/opencl_subgraph.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/opencl_fusion.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/utils.cc
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/dequant.h b/mindspore/lite/src/runtime/kernel/arm/base/dequant.h
index 6aaa8f2ca4..641f5b66bd 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/dequant.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/dequant.h
@@ -31,14 +31,14 @@ class DequantUtil {
 
   static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
 
-  template <typename T>
-  static float *DequantData(lite::Tensor *input_tensor) {
-    const auto *quant_datas = static_cast<const T *>(input_tensor->MutableData());
+  template <typename ST, typename DT = float>
+  static DT *DequantData(lite::Tensor *input_tensor) {
+    const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData());
     if (quant_datas == nullptr) {
       MS_LOG(ERROR) << "Get quant tensor failed.";
       return nullptr;
     }
-    auto *dequant_datas = static_cast<float *>(malloc(input_tensor->ElementsNum() * sizeof(float)));
+    DT *dequant_datas = static_cast<DT *>(malloc(input_tensor->ElementsNum() * sizeof(DT)));
     if (dequant_datas == nullptr) {
       MS_LOG(ERROR) << "Malloc failed.";
       return nullptr;
@@ -53,8 +53,7 @@ class DequantUtil {
         auto zero_point = param.zeroPoint;
         auto matrix_size = input_tensor->ElementsNum() / per_batch_size;
         for (int64_t j = 0; j < matrix_size; j++) {
-          dequant_datas[i * matrix_size + j] =
-            static_cast<float>((quant_datas[i * matrix_size + j] - zero_point) * scale);
+          dequant_datas[i * matrix_size + j] = static_cast<DT>((quant_datas[i * matrix_size + j] - zero_point) * scale);
         }
       }
     } else if (input_tensor->quant_params().size() != kPerTensor) {
@@ -78,7 +77,7 @@ class DequantUtil {
         }
         for (size_t j = 0; j < per_channel_size; j++) {
           auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale;
-          dequant_datas[per_channel_size * i + j] = static_cast<float>(dequant_data * var_corr + mean_corr);
+          dequant_datas[per_channel_size * i + j] = static_cast<DT>(dequant_data * var_corr + mean_corr);
         }
       }
     } else {
@@ -95,9 +94,9 @@ class DequantUtil {
           free(dequant_datas);
           return nullptr;
         }
-        dequant_datas[j] = static_cast<float>(param.clusters[index - INT8_MIN]);
+        dequant_datas[j] = static_cast<DT>(param.clusters[index - INT8_MIN]);
       } else {
-        dequant_datas[j] = static_cast<float>((quant_datas[j] - zero_point) * scale);
+        dequant_datas[j] = static_cast<DT>((quant_datas[j] - zero_point) * scale);
       }
     }
   }
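Note on the template above: splitting the single type parameter into a stored type ST and a destination type DT (defaulting to float) is what lets the OpenCL kernels below dequantize int8/int16 weights straight into float16 buffers. A minimal standalone sketch of the per-tensor branch only — QuantParam and Dequant here are illustrative stand-ins, not the MindSpore Lite types, and the real function also handles per-channel and clustered weights:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct QuantParam {  // stand-in for the tensor's quant parameter
      float scale;
      int zero_point;
    };

    // ST = stored integer type, DT = dequantized type, mirroring the new
    // template <typename ST, typename DT = float> signature above.
    template <typename ST, typename DT = float>
    std::vector<DT> Dequant(const std::vector<ST> &q, const QuantParam &p) {
      std::vector<DT> out(q.size());
      for (size_t j = 0; j < q.size(); ++j) {
        // Per-tensor path: (q - zeroPoint) * scale, computed into DT.
        out[j] = static_cast<DT>((q[j] - p.zero_point) * p.scale);
      }
      return out;
    }

    int main() {
      std::vector<int8_t> w = {-128, 0, 127};
      auto out = Dequant<int8_t, float>(w, {0.05f, 0});
      for (float v : out) printf("%g\n", v);  // -6.4 0 6.35
      return 0;
    }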
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc
index 453ce0035c..d603edb0e5 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc
@@ -200,6 +200,10 @@ int Conv2DOpenCLKernel::GenerateWinogradFilter() {
 
 int Conv2DOpenCLKernel::InitFilter() {
   auto allocator = ocl_runtime_->GetAllocator();
+  auto ret = DequantWeight();
+  if (ret != RET_OK) {
+    return ret;
+  }
 
   // allocate memory
   size_t packed_weight_size;
@@ -225,7 +229,7 @@ int Conv2DOpenCLKernel::InitFilter() {
       ConvertConvWeight4DTo7D<float16_t, float>(weight_tensor->data_c(), packed_weight_, CO_, KH_, KW_, CI_,
                                                 block_size_.C);
     }
-  } else {
+  } else if (weight_tensor->data_type() == kNumberTypeFloat32) {
     if (use_fp16_) {
       ConvertConvWeight4DTo7D<float, float16_t>(weight_tensor->data_c(), packed_weight_, CO_, KH_, KW_, CI_,
                                                 block_size_.C);
@@ -233,6 +237,15 @@ int Conv2DOpenCLKernel::InitFilter() {
       ConvertConvWeight4DTo7D<float, float>(weight_tensor->data_c(), packed_weight_, CO_, KH_, KW_, CI_,
                                             block_size_.C);
     }
+  } else {  // int8 or int16
+    if (use_fp16_) {
+      ConvertConvWeight4DTo7D<float16_t, float16_t>(weight_tensor->data_c(), packed_weight_, CO_, KH_, KW_, CI_,
+                                                    block_size_.C);
+    } else {
+      ConvertConvWeight4DTo7D<float, float>(weight_tensor->data_c(), packed_weight_, CO_, KH_, KW_, CI_,
+                                            block_size_.C);
+    }
+    FreeDequantedWeight();
   }
 }
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
index 90e33983d9..e142ca65dc 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
@@ -92,6 +92,10 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
     MS_LOG(ERROR) << "DepthwiseConv2d don't support non-constant filter yet.";
     return RET_ERROR;
   }
+  auto ret = DequantWeight();
+  if (ret != RET_OK) {
+    return ret;
+  }
   auto parameter = reinterpret_cast<ConvParameter *>(op_parameter_);
   auto allocator = ocl_runtime_->GetAllocator();
   bool is_fp16 = ocl_runtime_->GetFp16Enable();
@@ -111,9 +115,10 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
     } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat32) {
       std::function<float16_t(float)> to_dtype = [](float x) -> float16_t { return static_cast<float16_t>(x); };
       PackNCHWToNC4HW4<float, float16_t>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
-    } else {
-      MS_LOG(ERROR) << "Only support float16/float32, actual data type " << in_tensors_.at(kWeightIndex)->data_type();
-      return mindspore::lite::RET_ERROR;
+    } else {  // int8 or int16
+      std::function<int16_t(int16_t)> to_dtype = [](int16_t x) -> int16_t { return x; };
+      PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
+      FreeDequantedWeight();
     }
   } else {
     packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(float));
@@ -124,9 +129,10 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
     } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16) {
       std::function<float(float16_t)> to_dtype = [](float16_t x) -> float { return static_cast<float>(x); };
       PackNCHWToNC4HW4<float16_t, float>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
-    } else {
-      MS_LOG(ERROR) << "Only support float16/float32, actual data type " << in_tensors_.at(kWeightIndex)->data_type();
-      return mindspore::lite::RET_ERROR;
+    } else {  // int8 or int16
+      std::function<float(float)> to_dtype = [](float x) -> float { return x; };
+      PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
+      FreeDequantedWeight();
     }
   }
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc
index 58b0078da1..949be78e37 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc
@@ -121,6 +121,10 @@ int FullConnectionOpenCLKernel::InitWeights() {
 }  // namespace mindspore::kernel
 
 int FullConnectionOpenCLKernel::InitFilter() {
+  auto ret = DequantWeight();
+  if (ret != RET_OK) {
+    return ret;
+  }
   auto allocator = ocl_runtime_->GetAllocator();
   auto intensor_shape = GpuTensorInfo(in_tensors_[0]);
   int co4 = UP_DIV(CO_, C4NUM);
@@ -173,6 +177,7 @@ int FullConnectionOpenCLKernel::InitFilter() {
     }
   }
   allocator->UnmapBuffer(padWeight_);
+  FreeDequantedWeight();
   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
index 1e9dcf936e..8894eac2a9 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
@@ -84,6 +84,10 @@ int MatMulOpenCLKernel::InitWeights() {
     MS_LOG(ERROR) << "Matmul don't support non-constant filter yet.";
     return RET_ERROR;
   }
+  auto ret = DequantWeight();
+  if (ret != RET_OK) {
+    return ret;
+  }
   auto allocator = ocl_runtime_->GetAllocator();
   int ci = inShape[3];
   int ci4 = UP_DIV(ci, C4NUM);
@@ -143,6 +147,7 @@ int MatMulOpenCLKernel::InitWeights() {
     }
   }
   allocator->UnmapBuffer(padWeight_);
+  FreeDequantedWeight();
   return RET_OK;
 }
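All four kernels touched so far (Conv2D, DepthwiseConv2D, FullConnection, MatMul) follow the same ordering: dequantize the weight tensor first, pack from the resulting float buffer, then release that buffer. A compilable sketch of that contract — the three functions are placeholder stand-ins, not the real kernel methods:

    #include <cstdio>

    constexpr int RET_OK = 0;
    constexpr int RET_ERROR = 1;

    // Placeholders for the methods this patch wires into weight init.
    int DequantWeight() { return RET_OK; }  // int8/int16 -> fp16/fp32 buffer
    void PackWeights() {}                   // layout conversion (e.g. 4D -> 7D)
    void FreeDequantedWeight() {}           // frees the temporary buffer

    int InitFilterSketch() {
      if (DequantWeight() != RET_OK) {  // 1. dequantize (no-op if not quantized)
        return RET_ERROR;
      }
      PackWeights();          // 2. pack from the dequantized data
      FreeDequantedWeight();  // 3. drop the temporary buffer
      return RET_OK;
    }

    int main() { return InitFilterSketch(); }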
diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc
new file mode 100644
index 0000000000..84c93fe2bc
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc
@@ -0,0 +1,235 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/runtime/kernel/opencl/opencl_kernel.h"
+#include "src/runtime/kernel/arm/base/dequant.h"
+
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_OK;
+
+namespace mindspore::kernel {
+
+int OpenCLKernel::AlignGlobalLocal(const std::vector<size_t> &global, const std::vector<size_t> &local) {
+  std::vector<size_t> internal_global_ws = global;
+  for (size_t i = 0; i < local.size(); ++i) {
+    internal_global_ws.at(i) = UP_ROUND(global.at(i), local.at(i));
+  }
+
+  MS_LOG(DEBUG) << "global size: " << global.size() << ", local size: " << local.size();
+  for (size_t i = 0; i < global.size(); i++) {
+    MS_LOG(DEBUG) << "global[" << i << "] = " << global.at(i);
+  }
+  for (size_t i = 0; i < local.size(); i++) {
+    MS_LOG(DEBUG) << "local[" << i << "] = " << local.at(i);
+  }
+  if (local.empty()) {
+    local_range_ = cl::NullRange;
+  }
+  if (global.size() == 1) {
+    global_range_ = cl::NDRange(internal_global_ws.at(0));
+    if (!local.empty()) {
+      local_range_ = cl::NDRange(local.at(0));
+    }
+  } else if (global.size() == 2) {
+    global_range_ = cl::NDRange(internal_global_ws.at(0), internal_global_ws.at(1));
+    if (!local.empty()) {
+      local_range_ = cl::NDRange(local.at(0), local.at(1));
+    }
+  } else if (global.size() == 3) {
+    global_range_ = cl::NDRange(internal_global_ws.at(0), internal_global_ws.at(1), internal_global_ws.at(2));
+    if (!local.empty()) {
+      local_range_ = cl::NDRange(local.at(0), local.at(1), local.at(2));
+    }
+  } else {
+    MS_LOG(ERROR) << "Not supported NDRange!";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int OpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
+  MS_ASSERT(img_size);
+  if (idx >= out_tensors_.size()) {
+    return RET_ERROR;
+  }
+  auto img_info = GpuTensorInfo(out_tensors_[idx]);
+  size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
+  *img_size = {img_info.width, img_info.height, img_dtype};
+  return RET_OK;
+}
+
+std::vector<BaseTuningParameter> OpenCLKernel::GenerateTuningParam() {
+  size_t ndim = global_size_.size();
+  std::vector<BaseTuningParameter> tuning_params = {};
+  if (ndim == 0) {
+    MS_LOG(ERROR) << "Generate tuning param failed, global_size_ is null.";
+    return tuning_params;
+  }
+  BaseTuningParameter default_tuning_param = BaseTuningParameter();
+  default_tuning_param.local_size = local_size_;
+  tuning_params.push_back(default_tuning_param);
+  std::vector<size_t> max_work_items = ocl_runtime_->GetWorkItemSize();
+  size_t max_workgroup_size = ocl_runtime_->GetMaxWorkGroupSize(kernel_);
+  const size_t MIN_WORKGROUP_SIZE = 8;
+  std::set<size_t> candidate_x = GenerateLocalByGlobal(global_size_[0]);
+  std::set<size_t> candidate_y = {1};
+  std::set<size_t> candidate_z = {1};
+  if (ndim > 1) {
+    candidate_y = GenerateLocalByGlobal(global_size_[1]);
+  }
+  if (ndim > 2) {
+    candidate_z = GenerateLocalByGlobal(global_size_[2]);
+  }
+  for (auto x : candidate_x) {
+    if (x <= max_work_items[0]) {
+      for (auto y : candidate_y) {
+        if (y <= max_work_items[1]) {
+          for (auto z : candidate_z) {
+            auto group_size = x * y * z;
+            if (z <= max_work_items[2] && group_size <= max_workgroup_size && group_size >= MIN_WORKGROUP_SIZE) {
+              BaseTuningParameter tuning_param = BaseTuningParameter();
+              tuning_param.local_size = {x, y, z};
+              tuning_params.push_back(tuning_param);
+            }
+          }
+        }
+      }
+    }
+  }
+  return tuning_params;
+}
+
+int OpenCLKernel::AssignTuningParam(const BaseTuningParameter &param) {
+  std::vector<size_t> local_size_tmp = param.local_size;
+  if (local_size_tmp.size() > global_size_.size()) {
+    local_size_tmp = std::vector<size_t>(local_size_tmp.begin(), local_size_tmp.begin() + global_size_.size());
+  }
+  AlignGlobalLocal(global_size_, local_size_tmp);
+  return RET_OK;
+}
+
+int OpenCLKernel::Tune() {
+  if (!ocl_runtime_->isProfiling()) {
+    MS_LOG(WARNING) << "Tuning mode require opencl runtime profiling.";
+    return RET_OK;
+  }
+  lite::opencl::TuningMode mode = ocl_runtime_->GetTuningMode();
+  if (mode == lite::opencl::TuningMode::DEFAULT) {
+    return RET_OK;
+  }
+  static const std::set<int> FAST_MODE_OPS = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
+                                              schema::PrimitiveType_DeConv2D};
+  if (mode == lite::opencl::TuningMode::FAST && FAST_MODE_OPS.find(op_parameter_->type_) == FAST_MODE_OPS.end()) {
+    return RET_OK;
+  }
+  auto tuning_params = GenerateTuningParam();
+  if (tuning_params.empty()) {
+    MS_LOG(WARNING) << "Tuning param size is 0.";
+    return RET_OK;
+  }
+  int index = -1;
+  double min_time = MAX_PROFILING_TIME_MILLI_SECOND;
+  for (int i = 0; i < tuning_params.size(); i++) {
+    AssignTuningParam(tuning_params[i]);
+    auto ret = Run();
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Tuning " << name() << " failed for tuning param " << tuning_params[i];
+      return ret;
+    }
+    double current_time = GetProfilingTimeMs();
+    MS_LOG(DEBUG) << "Tuning " << name() << " param (" << tuning_params[i] << ") exectime " << current_time << "ms";
+    if (current_time < min_time) {
+      min_time = current_time;
+      index = i;
+    }
+  }
+  if (index != -1) {
+    MS_LOG(INFO) << "Tuning " << name() << " result: param (" << tuning_params[index] << ") exectime " << min_time
+                 << "ms";
+    AssignTuningParam(tuning_params[index]);
+  } else {
+    MS_LOG(WARNING) << "Cannot find suitable param.";
+  }
+  return RET_OK;
+}
+
+double OpenCLKernel::GetProfilingTimeMs() {
+  if (!ocl_runtime_->isProfiling()) {
+    return MAX_PROFILING_TIME_MILLI_SECOND;
+  }
+  cl_ulong time_start;
+  cl_ulong time_end;
+  event_.getProfilingInfo(CL_PROFILING_COMMAND_START, &time_start);
+  event_.getProfilingInfo(CL_PROFILING_COMMAND_END, &time_end);
+  cl_ulong time_ns = time_end - time_start;
+  return static_cast<double>(time_ns) * 1e-6;
+}
+
+std::set<size_t> OpenCLKernel::GenerateLocalByGlobal(size_t global_i) {
+  std::set<size_t> local_ = {};
+  int index = 1;
+  while (index <= global_i) {
+    local_.insert(index);
+    index *= 2;
+  }
+  for (size_t i = 1; i <= 16; i++) {
+    if (global_i % i == 0) {
+      local_.insert(i);
+    }
+  }
+  return local_;
+}
+int OpenCLKernel::DequantWeight() {
+  bool is_fp16 = ocl_runtime_->GetFp16Enable();
+  auto *weight_tensor = in_tensors_.at(kWeightIndex);
+  auto *restore_data = weight_tensor->data_c();
+  dequant_flag_ =
+    !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
+  if (dequant_flag_) {
+    void *dequant_weight{nullptr};
+    bool set_flag{true};
+    if (is_fp16) {
+      if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) {
+        dequant_weight = kernel::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor);
+      } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) {
+        dequant_weight = kernel::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor);
+      } else {
+        set_flag = false;
+      }
+    } else {
+      if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) {
+        dequant_weight = kernel::DequantUtil::DequantData<int8_t, float>(weight_tensor);
+      } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) {
+        dequant_weight = kernel::DequantUtil::DequantData<int16_t, float>(weight_tensor);
+      } else {
+        set_flag = false;
+      }
+    }
+    if (set_flag && dequant_weight == nullptr) {
+      MS_LOG(ERROR) << "dequant data failed.";
+      return RET_ERROR;
+    }
+    weight_tensor->set_data(dequant_weight);
+  }
+  return RET_OK;
+}
+void OpenCLKernel::FreeDequantedWeight() {
+  auto *weight_tensor = in_tensors_.at(kWeightIndex);
+  if (dequant_flag_) {
+    free(weight_tensor->data_c());
+  }
+}
+}  // namespace mindspore::kernel
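AlignGlobalLocal above pads every global work dimension up to a multiple of the chosen local size before building the cl::NDRange. A small demo of that rounding, assuming UP_ROUND(x, y) is the usual smallest-multiple-of-y-not-below-x macro:

    #include <cstdio>
    #include <vector>

    // Assumed behaviour of UP_ROUND(x, y).
    size_t UpRound(size_t x, size_t y) { return ((x + y - 1) / y) * y; }

    int main() {
      std::vector<size_t> global = {300, 17, 5};
      std::vector<size_t> local = {16, 4, 1};
      for (size_t i = 0; i < local.size(); ++i) {
        printf("global[%zu]: %zu -> %zu\n", i, global[i], UpRound(global[i], local[i]));
      }
      // Prints 300 -> 304, 17 -> 20, 5 -> 5. The padded sizes feed
      // cl::NDRange, so every work-group is full; kernels are expected to
      // bounds-check the padded region.
      return 0;
    }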
diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
index 93db7888c4..a09647fc9d 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
@@ -25,6 +25,7 @@
 #include "src/lite_kernel.h"
 #include "include/errorcode.h"
 #include "src/runtime/opencl/opencl_runtime.h"
+#include "src/runtime/kernel/arm/base/dequant.h"
 
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
@@ -159,43 +160,7 @@ class OpenCLKernel : public LiteKernel {
     ocl_runtime_ = ocl_runtime_wrap_.GetInstance();
   }
   ~OpenCLKernel() override = default;
-  int AlignGlobalLocal(const std::vector<size_t> &global, const std::vector<size_t> &local) {
-    std::vector<size_t> internal_global_ws = global;
-    for (size_t i = 0; i < local.size(); ++i) {
-      internal_global_ws.at(i) = UP_ROUND(global.at(i), local.at(i));
-    }
-
-    MS_LOG(DEBUG) << "global size: " << global.size() << ", local size: " << local.size();
-    for (size_t i = 0; i < global.size(); i++) {
-      MS_LOG(DEBUG) << "global[" << i << "] = " << global.at(i);
-    }
-    for (size_t i = 0; i < local.size(); i++) {
-      MS_LOG(DEBUG) << "local[" << i << "] = " << local.at(i);
-    }
-    if (local.empty()) {
-      local_range_ = cl::NullRange;
-    }
-    if (global.size() == 1) {
-      global_range_ = cl::NDRange(internal_global_ws.at(0));
-      if (!local.empty()) {
-        local_range_ = cl::NDRange(local.at(0));
-      }
-    } else if (global.size() == 2) {
-      global_range_ = cl::NDRange(internal_global_ws.at(0), internal_global_ws.at(1));
-      if (!local.empty()) {
-        local_range_ = cl::NDRange(local.at(0), local.at(1));
-      }
-    } else if (global.size() == 3) {
-      global_range_ = cl::NDRange(internal_global_ws.at(0), internal_global_ws.at(1), internal_global_ws.at(2));
-      if (!local.empty()) {
-        local_range_ = cl::NDRange(local.at(0), local.at(1), local.at(2));
-      }
-    } else {
-      MS_LOG(ERROR) << "Not supported NDRange!";
-      return RET_ERROR;
-    }
-    return RET_OK;
-  }
+  int AlignGlobalLocal(const std::vector<size_t> &global, const std::vector<size_t> &local);
 
   int Prepare() override { return RET_OK; }
   int PreProcess() override { return RET_ERROR; }
@@ -210,135 +175,20 @@ class OpenCLKernel : public LiteKernel {
   virtual int GetLocalSize(size_t idx, const std::vector<size_t> &global_size, std::vector<size_t> *local_size) {
     return RET_ERROR;
   }
-  int GetImageSize(size_t idx, std::vector<size_t> *img_size) {
-    MS_ASSERT(img_size);
-    if (idx >= out_tensors_.size()) {
-      return RET_ERROR;
-    }
-    auto img_info = GpuTensorInfo(out_tensors_[idx]);
-    size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
-    *img_size = {img_info.width, img_info.height, img_dtype};
-    return RET_OK;
-  }
+  virtual std::vector<BaseTuningParameter> GenerateTuningParam();
+  virtual int AssignTuningParam(const BaseTuningParameter &param);
+  virtual int Tune();
+  int GetImageSize(size_t idx, std::vector<size_t> *img_size);
 
   lite::opencl::MemType GetMemType() { return out_mem_type_; }
   void SetMemType(lite::opencl::MemType mem_type) { out_mem_type_ = mem_type; }
   OpParameter *GetParameter() { return op_parameter_; }
+  double GetProfilingTimeMs();
+  int DequantWeight();
+  void FreeDequantedWeight();
 
-  virtual std::vector<BaseTuningParameter> GenerateTuningParam() {
-    size_t ndim = global_size_.size();
-    std::vector<BaseTuningParameter> tuning_params = {};
-    if (ndim == 0) {
-      MS_LOG(ERROR) << "Generate tuning param failed, global_size_ is null.";
-      return tuning_params;
-    }
-    BaseTuningParameter default_tuning_param = BaseTuningParameter();
-    default_tuning_param.local_size = local_size_;
-    tuning_params.push_back(default_tuning_param);
-    std::vector<size_t> max_work_items = ocl_runtime_->GetWorkItemSize();
-    size_t max_workgroup_size = ocl_runtime_->GetMaxWorkGroupSize(kernel_);
-    const size_t MIN_WORKGROUP_SIZE = 8;
-    std::set<size_t> candidate_x = GenerateLocalByGlobal(global_size_[0]);
-    std::set<size_t> candidate_y = {1};
-    std::set<size_t> candidate_z = {1};
-    if (ndim > 1) {
-      candidate_y = GenerateLocalByGlobal(global_size_[1]);
-    }
-    if (ndim > 2) {
-      candidate_z = GenerateLocalByGlobal(global_size_[2]);
-    }
-    for (auto x : candidate_x) {
-      if (x <= max_work_items[0]) {
-        for (auto y : candidate_y) {
-          if (y <= max_work_items[1]) {
-            for (auto z : candidate_z) {
-              auto group_size = x * y * z;
-              if (z <= max_work_items[2] && group_size <= max_workgroup_size && group_size >= MIN_WORKGROUP_SIZE) {
-                BaseTuningParameter tuning_param = BaseTuningParameter();
-                tuning_param.local_size = {x, y, z};
-                tuning_params.push_back(tuning_param);
-              }
-            }
-          }
-        }
-      }
-    }
-    return tuning_params;
-  }
-
-  virtual int AssignTuningParam(const BaseTuningParameter &param) {
-    std::vector<size_t> local_size_tmp = param.local_size;
-    if (local_size_tmp.size() > global_size_.size()) {
-      local_size_tmp = std::vector<size_t>(local_size_tmp.begin(), local_size_tmp.begin() + global_size_.size());
-    }
-    AlignGlobalLocal(global_size_, local_size_tmp);
-    return RET_OK;
-  }
-
-  virtual int Tune() {
-    if (!ocl_runtime_->isProfiling()) {
-      MS_LOG(WARNING) << "Tuning mode require opencl runtime profiling.";
-      return RET_OK;
-    }
-    lite::opencl::TuningMode mode = ocl_runtime_->GetTuningMode();
-    if (mode == lite::opencl::TuningMode::DEFAULT) {
-      return RET_OK;
-    }
-    static const std::set<int> FAST_MODE_OPS = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
-                                                schema::PrimitiveType_DeConv2D};
-    if (mode == lite::opencl::TuningMode::FAST && FAST_MODE_OPS.find(op_parameter_->type_) == FAST_MODE_OPS.end()) {
-      return RET_OK;
-    }
-    auto key = Key();
-    auto finded = tuned_param_cache_.find(key);
-    if (finded != tuned_param_cache_.end()) {
-      auto cache_param = finded->second;
-      MS_LOG(INFO) << "Tuning " << name() << ", found cached param(" << cache_param << ")";
-      return RET_OK;
-    }
-    auto tuning_params = GenerateTuningParam();
-    if (tuning_params.empty()) {
-      MS_LOG(WARNING) << "Tuning param size is 0.";
-      return RET_OK;
-    }
-    int index = -1;
-    double min_time = MAX_PROFILING_TIME_MILLI_SECOND;
-    for (int i = 0; i < tuning_params.size(); i++) {
-      AssignTuningParam(tuning_params[i]);
-      auto ret = Run();
-      if (ret != RET_OK) {
-        MS_LOG(ERROR) << "Tuning " << name() << " failed for tuning param " << tuning_params[i];
-        return ret;
-      }
-      double current_time = GetProfilingTimeMs();
-      MS_LOG(DEBUG) << "Tuning " << name() << " param (" << tuning_params[i] << ") exectime " << current_time << "ms";
-      if (current_time < min_time) {
-        min_time = current_time;
-        index = i;
-      }
-    }
-    if (index != -1) {
-      MS_LOG(INFO) << "Tuning " << name() << " result: param (" << tuning_params[index] << ") exectime " << min_time
-                   << "ms";
-      AssignTuningParam(tuning_params[index]);
-      tuned_param_cache_[key] = tuning_params[index];
-    } else {
-      MS_LOG(WARNING) << "Cannot find suitable param.";
-    }
-    return RET_OK;
-  }
-
-  double GetProfilingTimeMs() {
-    if (!ocl_runtime_->isProfiling()) {
-      return MAX_PROFILING_TIME_MILLI_SECOND;
-    }
-    cl_ulong time_start;
-    cl_ulong time_end;
-    event_.getProfilingInfo(CL_PROFILING_COMMAND_START, &time_start);
-    event_.getProfilingInfo(CL_PROFILING_COMMAND_END, &time_end);
-    cl_ulong time_ns = time_end - time_start;
-    return static_cast<double>(time_ns) * 1e-6;
-  }
+ protected:
+  static std::set<size_t> GenerateLocalByGlobal(size_t global_i);
 
   virtual std::string Key() {
     std::string key = type_str();
@@ -358,20 +208,7 @@ class OpenCLKernel : public LiteKernel {
   std::vector<size_t> local_size_;
   cl::Kernel kernel_;
   cl::Event event_;
-  static std::set<size_t> GenerateLocalByGlobal(size_t global_i) {
-    std::set<size_t> local_ = {};
-    int index = 1;
-    while (index <= global_i) {
-      local_.insert(index);
-      index *= 2;
-    }
-    for (size_t i = 1; i <= 16; i++) {
-      if (global_i % i == 0) {
-        local_.insert(i);
-      }
-    }
-    return local_;
-  }
+  bool dequant_flag_{false};
 
  private:
  lite::opencl::OpenCLRuntimeWrapper ocl_runtime_wrap_;
diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index 4472c49d0a..097a8806de 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -83,6 +83,7 @@ if (SUPPORT_GPU)
     set(KERNEL_OP_SRC
         ${KERNEL_OP_SRC}
         ${GPU_KERNEL_OP_SRC}
+        ${LITE_DIR}/src/runtime/kernel/opencl/opencl_kernel.cc
         ${LITE_DIR}/src/runtime/kernel/opencl/opencl_subgraph.cc
        ${LITE_DIR}/src/runtime/kernel/opencl/opencl_fusion.cc
        ${LITE_DIR}/src/runtime/kernel/opencl/utils.cc
diff --git a/mindspore/lite/test/models_fp16_gpu.cfg b/mindspore/lite/test/models_gpu_fp16.cfg
similarity index 100%
rename from mindspore/lite/test/models_fp16_gpu.cfg
rename to mindspore/lite/test/models_gpu_fp16.cfg
diff --git a/mindspore/lite/test/models_fp32_gpu.cfg b/mindspore/lite/test/models_gpu_fp32.cfg
similarity index 100%
rename from mindspore/lite/test/models_fp32_gpu.cfg
rename to mindspore/lite/test/models_gpu_fp32.cfg
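The new dequant_flag_ member encodes an ownership rule: DequantWeight() sets it only when it malloc'ed a replacement buffer, and FreeDequantedWeight() frees the tensor's data only under that flag, since otherwise the pointer still refers to model-owned memory. A toy illustration — WeightHolder is hypothetical, and unlike the patch it also clears the pointer and flag after freeing:

    #include <cstdlib>

    struct WeightHolder {
      void *data = nullptr;
      bool dequant_flag = false;  // true only when data was malloc'ed by us

      void FreeDequanted() {
        if (dequant_flag) {  // mirror of FreeDequantedWeight(): free only owned data
          free(data);
          data = nullptr;        // defensive extras vs. the patch,
          dequant_flag = false;  // which leaves pointer and flag as-is
        }
      }
    };

    int main() {
      WeightHolder model_owned;     // data would point into the model buffer
      model_owned.FreeDequanted();  // no-op: flag is false

      WeightHolder dequanted;
      dequanted.data = malloc(16);  // as after a successful DequantWeight()
      dequanted.dequant_flag = true;
      dequanted.FreeDequanted();    // freed exactly once
      return 0;
    }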
diff --git a/mindspore/lite/test/models_gpu_weightquant.cfg b/mindspore/lite/test/models_gpu_weightquant.cfg
new file mode 100644
index 0000000000..54ae30da47
--- /dev/null
+++ b/mindspore/lite/test/models_gpu_weightquant.cfg
@@ -0,0 +1 @@
+ml_face_openclose.tflite
\ No newline at end of file
diff --git a/mindspore/lite/test/run_benchmark_nets.sh b/mindspore/lite/test/run_benchmark_nets.sh
index a794f4af55..ffec24df51 100644
--- a/mindspore/lite/test/run_benchmark_nets.sh
+++ b/mindspore/lite/test/run_benchmark_nets.sh
@@ -525,12 +525,11 @@ function Run_x86() {
         model_name_len=${#model_name}
         input_params=${line:model_name_len+1}
         input_num=${input_params%%;*}
-        input_shape=${input_params##*;}
         input_files=''
         output_file=''
         if [[ -z "$input_files" || $input_files == 1 ]] && [ -e ${ms_models_path}/${model_name}'.ms.bin' ]; then
            input_files=$model_name'.ms.bin'
-        elif [[ ! -z "$input_files" && $input_files > 1 ]]; then
+        elif [[ ! -z "$input_files" && $input_files -gt 1 ]]; then
            for i in $(seq 1 $input_num)
            do
              input_files=$input_files$model_name'.ms.bin_'$i','
@@ -788,12 +787,11 @@ function Run_x86_sse() {
         model_name_len=${#model_name}
         input_params=${line:model_name_len+1}
         input_num=${input_params%%;*}
-        input_shape=${input_params##*;}
         input_files=''
         output_file=''
         if [[ -z "$input_files" || $input_files == 1 ]] && [ -e ${ms_models_path}/${model_name}'.ms.bin' ]; then
            input_files=$model_name'.ms.bin'
-        elif [[ ! -z "$input_files" && $input_files > 1 ]]; then
+        elif [[ ! -z "$input_files" && $input_files -gt 1 ]]; then
            for i in $(seq 1 $input_num)
            do
              input_files=$input_files$model_name'.ms.bin_'$i','
@@ -1147,18 +1145,7 @@ function Run_arm64() {
        else
            run_result='arm64_gpu: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
        fi
-        # run benchmark test without clib data
-        #echo ${model_name}
-        echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
-        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> "${run_arm64_log_file}"
-        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> adb_run_cmd.txt
-        adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}"
-        if [ $? = 0 ]; then
-            run_result='arm64_gpu: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
-        else
-            run_result='arm64_gpu: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
-        fi
-    done < ${models_tflite_gpu_config}
+    done < ${models_gpu_fp32_config}
 
     # Run GPU fp16 converted models:
     while read line; do
@@ -1176,19 +1163,27 @@ function Run_arm64() {
        else
            run_result='arm64_gpu_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
        fi
-        # run benchmark test without clib data
+        #sleep 1
+    done < ${models_gpu_fp16_config}
+
+    # Run GPU weightquant converted models:
+    while read line; do
+        model_name=${line}
+        if [[ $model_name == \#* ]]; then
+            continue
+        fi
        echo ${model_name} >> "${run_arm64_log_file}"
        echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
-        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2 --enableFp16=true' >> "${run_arm64_log_file}"
-        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2 --enableFp16=true' >> adb_run_cmd.txt
+        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'_weightquant.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold=5' >> "${run_arm64_log_file}"
+        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'_weightquant.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold=5' >> adb_run_cmd.txt
        adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}"
        if [ $? = 0 ]; then
-            run_result='arm64_gpu_fp16: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
+            run_result='arm64_gpu_weightquant: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
        else
-            run_result='arm64_gpu_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
+            run_result='arm64_gpu_weightquant: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
        fi
        #sleep 1
-    done < ${models_fp16_gpu_config}
+    done < ${models_gpu_weightquant_config}
 
     # Run mindir converted models:
     while read line; do
@@ -1274,12 +1269,11 @@ function Run_arm64() {
         model_name_len=${#model_name}
         input_params=${line:model_name_len+1}
         input_num=${input_params%%;*}
-        input_shape=${input_params##*;}
         input_files=''
         output_file=''
         if [[ -z "$input_files" || $input_files == 1 ]] && [ -e ${ms_models_path}/${model_name}'.ms.bin' ]; then
            input_files=$model_name'.ms.bin'
-        elif [[ ! -z "$input_files" && $input_files > 1 ]]; then
+        elif [[ ! -z "$input_files" && $input_files -gt 1 ]]; then
            for i in $(seq 1 $input_num)
            do
              input_files=$input_files$model_name'.ms.bin_'$i','
@@ -1445,9 +1439,10 @@ models_tflite_fp16_config=${basepath}/models_tflite_fp16.cfg
 models_mindspore_config=${basepath}/models_mindspore.cfg
 models_mindspore_train_config=${basepath}/models_mindspore_train.cfg
 models_mindspore_mixbit_config=${basepath}/models_mindspore_mixbit.cfg
-models_tflite_gpu_config=${basepath}/models_fp32_gpu.cfg
+models_gpu_fp32_config=${basepath}/models_gpu_fp32.cfg
+models_gpu_fp16_config=${basepath}/models_gpu_fp16.cfg
+models_gpu_weightquant_config=${basepath}/models_gpu_weightquant.cfg
 models_mindspore_weightquant_config=${basepath}/models_mindspore_weightquant.cfg
-models_fp16_gpu_config=${basepath}/models_fp16_gpu.cfg
 models_arm32_config=${basepath}/models_arm32.cfg
 models_compatibility_config=${basepath}/models_compatibility.cfg
 models_only_for_process_config=${basepath}/models_with_several_inputs_or_without_outputs.cfg