From: @ddwsky Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -42,6 +42,7 @@ set(LITE_SRC | |||
| if (SUPPORT_GPU) | |||
| set(LITE_SRC | |||
| ${LITE_SRC} | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/opencl_kernel.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/opencl_subgraph.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/opencl_fusion.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/utils.cc | |||
| @@ -31,14 +31,14 @@ class DequantUtil { | |||
| static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); | |||
| template <typename T> | |||
| static float *DequantData(lite::Tensor *input_tensor) { | |||
| const auto *quant_datas = static_cast<const T *>(input_tensor->MutableData()); | |||
| template <typename ST, typename DT = float> | |||
| static DT *DequantData(lite::Tensor *input_tensor) { | |||
| const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData()); | |||
| if (quant_datas == nullptr) { | |||
| MS_LOG(ERROR) << "Get quant tensor failed."; | |||
| return nullptr; | |||
| } | |||
| auto *dequant_datas = static_cast<float *>(malloc(input_tensor->ElementsNum() * sizeof(float))); | |||
| DT *dequant_datas = static_cast<DT *>(malloc(input_tensor->ElementsNum() * sizeof(DT))); | |||
| if (dequant_datas == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc failed."; | |||
| return nullptr; | |||
| @@ -53,8 +53,7 @@ class DequantUtil { | |||
| auto zero_point = param.zeroPoint; | |||
| auto matrix_size = input_tensor->ElementsNum() / per_batch_size; | |||
| for (int64_t j = 0; j < matrix_size; j++) { | |||
| dequant_datas[i * matrix_size + j] = | |||
| static_cast<float>((quant_datas[i * matrix_size + j] - zero_point) * scale); | |||
| dequant_datas[i * matrix_size + j] = static_cast<DT>((quant_datas[i * matrix_size + j] - zero_point) * scale); | |||
| } | |||
| } | |||
| } else if (input_tensor->quant_params().size() != kPerTensor) { | |||
| @@ -78,7 +77,7 @@ class DequantUtil { | |||
| } | |||
| for (size_t j = 0; j < per_channel_size; j++) { | |||
| auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale; | |||
| dequant_datas[per_channel_size * i + j] = static_cast<float>(dequant_data * var_corr + mean_corr); | |||
| dequant_datas[per_channel_size * i + j] = static_cast<DT>(dequant_data * var_corr + mean_corr); | |||
| } | |||
| } | |||
| } else { | |||
| @@ -95,9 +94,9 @@ class DequantUtil { | |||
| free(dequant_datas); | |||
| return nullptr; | |||
| } | |||
| dequant_datas[j] = static_cast<float>(param.clusters[index - INT8_MIN]); | |||
| dequant_datas[j] = static_cast<DT>(param.clusters[index - INT8_MIN]); | |||
| } else { | |||
| dequant_datas[j] = static_cast<float>((quant_datas[j] - zero_point) * scale); | |||
| dequant_datas[j] = static_cast<DT>((quant_datas[j] - zero_point) * scale); | |||
| } | |||
| } | |||
| } | |||
| @@ -211,6 +211,10 @@ int Conv2DOpenCLKernel::GenerateWinogradFilter() { | |||
| int Conv2DOpenCLKernel::InitFilter() { | |||
| auto allocator = ocl_runtime_->GetAllocator(); | |||
| auto ret = DequantWeight(); | |||
| if (ret != RET_OK) { | |||
| return ret; | |||
| } | |||
| // allocate memory | |||
| size_t packed_weight_size; | |||
| @@ -236,7 +240,7 @@ int Conv2DOpenCLKernel::InitFilter() { | |||
| ConvertConvWeight4DTo7D<float16_t, float>(weight_tensor->data_c(), packed_weight_, CO_, KH_, KW_, CI_, | |||
| block_size_.C); | |||
| } | |||
| } else { | |||
| } else if (weight_tensor->data_type() == kNumberTypeFloat32) { | |||
| if (use_fp16_) { | |||
| ConvertConvWeight4DTo7D<float, float16_t>(weight_tensor->data_c(), packed_weight_, CO_, KH_, KW_, CI_, | |||
| block_size_.C); | |||
| @@ -244,6 +248,15 @@ int Conv2DOpenCLKernel::InitFilter() { | |||
| ConvertConvWeight4DTo7D<float, float>(weight_tensor->data_c(), packed_weight_, CO_, KH_, KW_, CI_, | |||
| block_size_.C); | |||
| } | |||
| } else { // int8 or int16 | |||
| if (use_fp16_) { | |||
| ConvertConvWeight4DTo7D<float16_t, float16_t>(weight_tensor->data_c(), packed_weight_, CO_, KH_, KW_, CI_, | |||
| block_size_.C); | |||
| } else { | |||
| ConvertConvWeight4DTo7D<float, float>(weight_tensor->data_c(), packed_weight_, CO_, KH_, KW_, CI_, | |||
| block_size_.C); | |||
| } | |||
| FreeDequantedWeight(); | |||
| } | |||
| } | |||
| @@ -92,6 +92,10 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() { | |||
| MS_LOG(ERROR) << "DepthwiseConv2d don't support non-constant filter yet."; | |||
| return RET_ERROR; | |||
| } | |||
| auto ret = DequantWeight(); | |||
| if (ret != RET_OK) { | |||
| return ret; | |||
| } | |||
| auto parameter = reinterpret_cast<ConvParameter *>(op_parameter_); | |||
| auto allocator = ocl_runtime_->GetAllocator(); | |||
| bool is_fp16 = ocl_runtime_->GetFp16Enable(); | |||
| @@ -111,9 +115,10 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() { | |||
| } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat32) { | |||
| std::function<float16_t(float)> to_dtype = [](float x) -> float16_t { return static_cast<float16_t>(x); }; | |||
| PackNCHWToNC4HW4<float, float16_t>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype); | |||
| } else { | |||
| MS_LOG(ERROR) << "Only support float16/float32, actual data type " << in_tensors_.at(kWeightIndex)->data_type(); | |||
| return mindspore::lite::RET_ERROR; | |||
| } else { // int8 or int16 | |||
| std::function<int16_t(int16_t)> to_dtype = [](int16_t x) -> int16_t { return x; }; | |||
| PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype); | |||
| FreeDequantedWeight(); | |||
| } | |||
| } else { | |||
| packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(float)); | |||
| @@ -124,9 +129,10 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() { | |||
| } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16) { | |||
| std::function<float(float16_t)> to_dtype = [](float16_t x) -> float { return static_cast<float>(x); }; | |||
| PackNCHWToNC4HW4<float16_t, float>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype); | |||
| } else { | |||
| MS_LOG(ERROR) << "Only support float16/float32, actual data type " << in_tensors_.at(kWeightIndex)->data_type(); | |||
| return mindspore::lite::RET_ERROR; | |||
| } else { // int8 or int16 | |||
| std::function<float(float)> to_dtype = [](float x) -> float { return x; }; | |||
| PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype); | |||
| FreeDequantedWeight(); | |||
| } | |||
| } | |||
| @@ -135,6 +135,10 @@ int FullConnectionOpenCLKernel::InitWeights() { | |||
| } // namespace mindspore::kernel | |||
| int FullConnectionOpenCLKernel::InitFilter() { | |||
| auto ret = DequantWeight(); | |||
| if (ret != RET_OK) { | |||
| return ret; | |||
| } | |||
| auto allocator = ocl_runtime_->GetAllocator(); | |||
| auto intensor_shape = GpuTensorInfo(in_tensors_[0]); | |||
| int co4 = UP_DIV(CO_, C4NUM); | |||
| @@ -187,6 +191,7 @@ int FullConnectionOpenCLKernel::InitFilter() { | |||
| } | |||
| } | |||
| allocator->UnmapBuffer(padWeight_); | |||
| FreeDequantedWeight(); | |||
| return RET_OK; | |||
| } | |||
| @@ -84,6 +84,10 @@ int MatMulOpenCLKernel::InitWeights() { | |||
| MS_LOG(ERROR) << "Matmul don't support non-constant filter yet."; | |||
| return RET_ERROR; | |||
| } | |||
| auto ret = DequantWeight(); | |||
| if (ret != RET_OK) { | |||
| return ret; | |||
| } | |||
| auto allocator = ocl_runtime_->GetAllocator(); | |||
| int ci = inShape[3]; | |||
| int ci4 = UP_DIV(ci, C4NUM); | |||
| @@ -143,6 +147,7 @@ int MatMulOpenCLKernel::InitWeights() { | |||
| } | |||
| } | |||
| allocator->UnmapBuffer(padWeight_); | |||
| FreeDequantedWeight(); | |||
| return RET_OK; | |||
| } | |||
| @@ -0,0 +1,235 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| namespace mindspore::kernel { | |||
| int OpenCLKernel::AlignGlobalLocal(const std::vector<size_t> &global, const std::vector<size_t> &local) { | |||
| std::vector<size_t> internal_global_ws = global; | |||
| for (size_t i = 0; i < local.size(); ++i) { | |||
| internal_global_ws.at(i) = UP_ROUND(global.at(i), local.at(i)); | |||
| } | |||
| MS_LOG(DEBUG) << "global size: " << global.size() << ", local size: " << local.size(); | |||
| for (size_t i = 0; i < global.size(); i++) { | |||
| MS_LOG(DEBUG) << "global[" << i << "] = " << global.at(i); | |||
| } | |||
| for (size_t i = 0; i < local.size(); i++) { | |||
| MS_LOG(DEBUG) << "local[" << i << "] = " << local.at(i); | |||
| } | |||
| if (local.empty()) { | |||
| local_range_ = cl::NullRange; | |||
| } | |||
| if (global.size() == 1) { | |||
| global_range_ = cl::NDRange(internal_global_ws.at(0)); | |||
| if (!local.empty()) { | |||
| local_range_ = cl::NDRange(local.at(0)); | |||
| } | |||
| } else if (global.size() == 2) { | |||
| global_range_ = cl::NDRange(internal_global_ws.at(0), internal_global_ws.at(1)); | |||
| if (!local.empty()) { | |||
| local_range_ = cl::NDRange(local.at(0), local.at(1)); | |||
| } | |||
| } else if (global.size() == 3) { | |||
| global_range_ = cl::NDRange(internal_global_ws.at(0), internal_global_ws.at(1), internal_global_ws.at(2)); | |||
| if (!local.empty()) { | |||
| local_range_ = cl::NDRange(local.at(0), local.at(1), local.at(2)); | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "Not supported NDRange!"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int OpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||
| MS_ASSERT(img_size); | |||
| if (idx >= out_tensors_.size()) { | |||
| return RET_ERROR; | |||
| } | |||
| auto img_info = GpuTensorInfo(out_tensors_[idx]); | |||
| size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT; | |||
| *img_size = {img_info.width, img_info.height, img_dtype}; | |||
| return RET_OK; | |||
| } | |||
| std::vector<BaseTuningParameter> OpenCLKernel::GenerateTuningParam() { | |||
| size_t ndim = global_size_.size(); | |||
| std::vector<BaseTuningParameter> tuning_params = {}; | |||
| if (ndim == 0) { | |||
| MS_LOG(ERROR) << "Generate tuning param failed, global_size_ is null."; | |||
| return tuning_params; | |||
| } | |||
| BaseTuningParameter default_tuning_param = BaseTuningParameter(); | |||
| default_tuning_param.local_size = local_size_; | |||
| tuning_params.push_back(default_tuning_param); | |||
| std::vector<size_t> max_work_items = ocl_runtime_->GetWorkItemSize(); | |||
| size_t max_workgroup_size = ocl_runtime_->GetMaxWorkGroupSize(kernel_); | |||
| const size_t MIN_WORKGROUP_SIZE = 8; | |||
| std::set<size_t> candidate_x = GenerateLocalByGlobal(global_size_[0]); | |||
| std::set<size_t> candidate_y = {1}; | |||
| std::set<size_t> candidate_z = {1}; | |||
| if (ndim > 1) { | |||
| candidate_y = GenerateLocalByGlobal(global_size_[1]); | |||
| } | |||
| if (ndim > 2) { | |||
| candidate_z = GenerateLocalByGlobal(global_size_[2]); | |||
| } | |||
| for (auto x : candidate_x) { | |||
| if (x <= max_work_items[0]) { | |||
| for (auto y : candidate_y) { | |||
| if (y <= max_work_items[1]) { | |||
| for (auto z : candidate_z) { | |||
| auto group_size = x * y * z; | |||
| if (z <= max_work_items[2] && group_size <= max_workgroup_size && group_size >= MIN_WORKGROUP_SIZE) { | |||
| BaseTuningParameter tuning_param = BaseTuningParameter(); | |||
| tuning_param.local_size = {x, y, z}; | |||
| tuning_params.push_back(tuning_param); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return tuning_params; | |||
| } | |||
| int OpenCLKernel::AssignTuningParam(const BaseTuningParameter ¶m) { | |||
| std::vector<size_t> local_size_tmp = param.local_size; | |||
| if (local_size_tmp.size() > global_size_.size()) { | |||
| local_size_tmp = std::vector<size_t>(local_size_tmp.begin(), local_size_tmp.begin() + global_size_.size()); | |||
| } | |||
| AlignGlobalLocal(global_size_, local_size_tmp); | |||
| return RET_OK; | |||
| } | |||
| int OpenCLKernel::Tune() { | |||
| if (!ocl_runtime_->isProfiling()) { | |||
| MS_LOG(WARNING) << "Tuning mode require opencl runtime profiling."; | |||
| return RET_OK; | |||
| } | |||
| lite::opencl::TuningMode mode = ocl_runtime_->GetTuningMode(); | |||
| if (mode == lite::opencl::TuningMode::DEFAULT) { | |||
| return RET_OK; | |||
| } | |||
| static const std::set<int> FAST_MODE_OPS = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, | |||
| schema::PrimitiveType_DeConv2D}; | |||
| if (mode == lite::opencl::TuningMode::FAST && FAST_MODE_OPS.find(op_parameter_->type_) == FAST_MODE_OPS.end()) { | |||
| return RET_OK; | |||
| } | |||
| auto tuning_params = GenerateTuningParam(); | |||
| if (tuning_params.empty()) { | |||
| MS_LOG(WARNING) << "Tuning param size is 0."; | |||
| return RET_OK; | |||
| } | |||
| int index = -1; | |||
| double min_time = MAX_PROFILING_TIME_MILLI_SECOND; | |||
| for (int i = 0; i < tuning_params.size(); i++) { | |||
| AssignTuningParam(tuning_params[i]); | |||
| auto ret = Run(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Tuning " << name() << " failed for tuning param " << tuning_params[i]; | |||
| return ret; | |||
| } | |||
| double current_time = GetProfilingTimeMs(); | |||
| MS_LOG(DEBUG) << "Tuning " << name() << " param (" << tuning_params[i] << ") exectime " << current_time << "ms"; | |||
| if (current_time < min_time) { | |||
| min_time = current_time; | |||
| index = i; | |||
| } | |||
| } | |||
| if (index != -1) { | |||
| MS_LOG(INFO) << "Tuning " << name() << " result: param (" << tuning_params[index] << ") exectime " << min_time | |||
| << "ms"; | |||
| AssignTuningParam(tuning_params[index]); | |||
| } else { | |||
| MS_LOG(WARNING) << "Cannot find suitable param."; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| double OpenCLKernel::GetProfilingTimeMs() { | |||
| if (!ocl_runtime_->isProfiling()) { | |||
| return MAX_PROFILING_TIME_MILLI_SECOND; | |||
| } | |||
| cl_ulong time_start; | |||
| cl_ulong time_end; | |||
| event_.getProfilingInfo(CL_PROFILING_COMMAND_START, &time_start); | |||
| event_.getProfilingInfo(CL_PROFILING_COMMAND_END, &time_end); | |||
| cl_ulong time_ns = time_end - time_start; | |||
| return static_cast<double>(time_ns) * 1e-6; | |||
| } | |||
| std::set<size_t> OpenCLKernel::GenerateLocalByGlobal(size_t global_i) { | |||
| std::set<size_t> local_ = {}; | |||
| int index = 1; | |||
| while (index <= global_i) { | |||
| local_.insert(index); | |||
| index *= 2; | |||
| } | |||
| for (size_t i = 1; i <= 16; i++) { | |||
| if (global_i % i == 0) { | |||
| local_.insert(i); | |||
| } | |||
| } | |||
| return local_; | |||
| } | |||
| int OpenCLKernel::DequantWeight() { | |||
| bool is_fp16 = ocl_runtime_->GetFp16Enable(); | |||
| auto *weight_tensor = in_tensors_.at(kWeightIndex); | |||
| auto *restore_data = weight_tensor->data_c(); | |||
| dequant_flag_ = | |||
| !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr; | |||
| if (dequant_flag_) { | |||
| void *dequant_weight{nullptr}; | |||
| bool set_flag{true}; | |||
| if (is_fp16) { | |||
| if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) { | |||
| dequant_weight = kernel::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor); | |||
| } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) { | |||
| dequant_weight = kernel::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor); | |||
| } else { | |||
| set_flag = false; | |||
| } | |||
| } else { | |||
| if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) { | |||
| dequant_weight = kernel::DequantUtil::DequantData<int8_t, float>(weight_tensor); | |||
| } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) { | |||
| dequant_weight = kernel::DequantUtil::DequantData<int16_t, float>(weight_tensor); | |||
| } else { | |||
| set_flag = false; | |||
| } | |||
| } | |||
| if (set_flag && dequant_weight == nullptr) { | |||
| MS_LOG(ERROR) << "dequant data failed."; | |||
| return RET_ERROR; | |||
| } | |||
| weight_tensor->set_data(dequant_weight); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| void OpenCLKernel::FreeDequantedWeight() { | |||
| auto *weight_tensor = in_tensors_.at(kWeightIndex); | |||
| if (dequant_flag_) { | |||
| free(weight_tensor->data_c()); | |||
| } | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -25,6 +25,7 @@ | |||
| #include "src/lite_kernel.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "src/runtime/kernel/arm/base/dequant.h" | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| @@ -159,43 +160,7 @@ class OpenCLKernel : public LiteKernel { | |||
| ocl_runtime_ = ocl_runtime_wrap_.GetInstance(); | |||
| } | |||
| ~OpenCLKernel() override = default; | |||
| int AlignGlobalLocal(const std::vector<size_t> &global, const std::vector<size_t> &local) { | |||
| std::vector<size_t> internal_global_ws = global; | |||
| for (size_t i = 0; i < local.size(); ++i) { | |||
| internal_global_ws.at(i) = UP_ROUND(global.at(i), local.at(i)); | |||
| } | |||
| MS_LOG(DEBUG) << "global size: " << global.size() << ", local size: " << local.size(); | |||
| for (size_t i = 0; i < global.size(); i++) { | |||
| MS_LOG(DEBUG) << "global[" << i << "] = " << global.at(i); | |||
| } | |||
| for (size_t i = 0; i < local.size(); i++) { | |||
| MS_LOG(DEBUG) << "local[" << i << "] = " << local.at(i); | |||
| } | |||
| if (local.empty()) { | |||
| local_range_ = cl::NullRange; | |||
| } | |||
| if (global.size() == 1) { | |||
| global_range_ = cl::NDRange(internal_global_ws.at(0)); | |||
| if (!local.empty()) { | |||
| local_range_ = cl::NDRange(local.at(0)); | |||
| } | |||
| } else if (global.size() == 2) { | |||
| global_range_ = cl::NDRange(internal_global_ws.at(0), internal_global_ws.at(1)); | |||
| if (!local.empty()) { | |||
| local_range_ = cl::NDRange(local.at(0), local.at(1)); | |||
| } | |||
| } else if (global.size() == 3) { | |||
| global_range_ = cl::NDRange(internal_global_ws.at(0), internal_global_ws.at(1), internal_global_ws.at(2)); | |||
| if (!local.empty()) { | |||
| local_range_ = cl::NDRange(local.at(0), local.at(1), local.at(2)); | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "Not supported NDRange!"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int AlignGlobalLocal(const std::vector<size_t> &global, const std::vector<size_t> &local); | |||
| int Prepare() override { return RET_OK; } | |||
| int PreProcess() override { return RET_ERROR; } | |||
| @@ -210,135 +175,20 @@ class OpenCLKernel : public LiteKernel { | |||
| virtual int GetLocalSize(size_t idx, const std::vector<size_t> &global_size, std::vector<size_t> *local_size) { | |||
| return RET_ERROR; | |||
| } | |||
| int GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||
| MS_ASSERT(img_size); | |||
| if (idx >= out_tensors_.size()) { | |||
| return RET_ERROR; | |||
| } | |||
| auto img_info = GpuTensorInfo(out_tensors_[idx]); | |||
| size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT; | |||
| *img_size = {img_info.width, img_info.height, img_dtype}; | |||
| return RET_OK; | |||
| } | |||
| virtual std::vector<BaseTuningParameter> GenerateTuningParam(); | |||
| virtual int AssignTuningParam(const BaseTuningParameter ¶m); | |||
| virtual int Tune(); | |||
| int GetImageSize(size_t idx, std::vector<size_t> *img_size); | |||
| lite::opencl::MemType GetMemType() { return out_mem_type_; } | |||
| void SetMemType(lite::opencl::MemType mem_type) { out_mem_type_ = mem_type; } | |||
| OpParameter *GetParameter() { return op_parameter_; } | |||
| double GetProfilingTimeMs(); | |||
| int DequantWeight(); | |||
| void FreeDequantedWeight(); | |||
| virtual std::vector<BaseTuningParameter> GenerateTuningParam() { | |||
| size_t ndim = global_size_.size(); | |||
| std::vector<BaseTuningParameter> tuning_params = {}; | |||
| if (ndim == 0) { | |||
| MS_LOG(ERROR) << "Generate tuning param failed, global_size_ is null."; | |||
| return tuning_params; | |||
| } | |||
| BaseTuningParameter default_tuning_param = BaseTuningParameter(); | |||
| default_tuning_param.local_size = local_size_; | |||
| tuning_params.push_back(default_tuning_param); | |||
| std::vector<size_t> max_work_items = ocl_runtime_->GetWorkItemSize(); | |||
| size_t max_workgroup_size = ocl_runtime_->GetMaxWorkGroupSize(kernel_); | |||
| const size_t MIN_WORKGROUP_SIZE = 8; | |||
| std::set<size_t> candidate_x = GenerateLocalByGlobal(global_size_[0]); | |||
| std::set<size_t> candidate_y = {1}; | |||
| std::set<size_t> candidate_z = {1}; | |||
| if (ndim > 1) { | |||
| candidate_y = GenerateLocalByGlobal(global_size_[1]); | |||
| } | |||
| if (ndim > 2) { | |||
| candidate_z = GenerateLocalByGlobal(global_size_[2]); | |||
| } | |||
| for (auto x : candidate_x) { | |||
| if (x <= max_work_items[0]) { | |||
| for (auto y : candidate_y) { | |||
| if (y <= max_work_items[1]) { | |||
| for (auto z : candidate_z) { | |||
| auto group_size = x * y * z; | |||
| if (z <= max_work_items[2] && group_size <= max_workgroup_size && group_size >= MIN_WORKGROUP_SIZE) { | |||
| BaseTuningParameter tuning_param = BaseTuningParameter(); | |||
| tuning_param.local_size = {x, y, z}; | |||
| tuning_params.push_back(tuning_param); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return tuning_params; | |||
| } | |||
| virtual int AssignTuningParam(const BaseTuningParameter ¶m) { | |||
| std::vector<size_t> local_size_tmp = param.local_size; | |||
| if (local_size_tmp.size() > global_size_.size()) { | |||
| local_size_tmp = std::vector<size_t>(local_size_tmp.begin(), local_size_tmp.begin() + global_size_.size()); | |||
| } | |||
| AlignGlobalLocal(global_size_, local_size_tmp); | |||
| return RET_OK; | |||
| } | |||
| virtual int Tune() { | |||
| if (!ocl_runtime_->isProfiling()) { | |||
| MS_LOG(WARNING) << "Tuning mode require opencl runtime profiling."; | |||
| return RET_OK; | |||
| } | |||
| lite::opencl::TuningMode mode = ocl_runtime_->GetTuningMode(); | |||
| if (mode == lite::opencl::TuningMode::DEFAULT) { | |||
| return RET_OK; | |||
| } | |||
| static const std::set<int> FAST_MODE_OPS = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, | |||
| schema::PrimitiveType_DeConv2D}; | |||
| if (mode == lite::opencl::TuningMode::FAST && FAST_MODE_OPS.find(op_parameter_->type_) == FAST_MODE_OPS.end()) { | |||
| return RET_OK; | |||
| } | |||
| auto key = Key(); | |||
| auto finded = tuned_param_cache_.find(key); | |||
| if (finded != tuned_param_cache_.end()) { | |||
| auto cache_param = finded->second; | |||
| MS_LOG(INFO) << "Tuning " << name() << ", found cached param(" << cache_param << ")"; | |||
| return RET_OK; | |||
| } | |||
| auto tuning_params = GenerateTuningParam(); | |||
| if (tuning_params.empty()) { | |||
| MS_LOG(WARNING) << "Tuning param size is 0."; | |||
| return RET_OK; | |||
| } | |||
| int index = -1; | |||
| double min_time = MAX_PROFILING_TIME_MILLI_SECOND; | |||
| for (int i = 0; i < tuning_params.size(); i++) { | |||
| AssignTuningParam(tuning_params[i]); | |||
| auto ret = Run(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Tuning " << name() << " failed for tuning param " << tuning_params[i]; | |||
| return ret; | |||
| } | |||
| double current_time = GetProfilingTimeMs(); | |||
| MS_LOG(DEBUG) << "Tuning " << name() << " param (" << tuning_params[i] << ") exectime " << current_time << "ms"; | |||
| if (current_time < min_time) { | |||
| min_time = current_time; | |||
| index = i; | |||
| } | |||
| } | |||
| if (index != -1) { | |||
| MS_LOG(INFO) << "Tuning " << name() << " result: param (" << tuning_params[index] << ") exectime " << min_time | |||
| << "ms"; | |||
| AssignTuningParam(tuning_params[index]); | |||
| tuned_param_cache_[key] = tuning_params[index]; | |||
| } else { | |||
| MS_LOG(WARNING) << "Cannot find suitable param."; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| double GetProfilingTimeMs() { | |||
| if (!ocl_runtime_->isProfiling()) { | |||
| return MAX_PROFILING_TIME_MILLI_SECOND; | |||
| } | |||
| cl_ulong time_start; | |||
| cl_ulong time_end; | |||
| event_.getProfilingInfo(CL_PROFILING_COMMAND_START, &time_start); | |||
| event_.getProfilingInfo(CL_PROFILING_COMMAND_END, &time_end); | |||
| cl_ulong time_ns = time_end - time_start; | |||
| return static_cast<double>(time_ns) * 1e-6; | |||
| } | |||
| protected: | |||
| static std::set<size_t> GenerateLocalByGlobal(size_t global_i); | |||
| virtual std::string Key() { | |||
| std::string key = type_str(); | |||
| @@ -358,20 +208,7 @@ class OpenCLKernel : public LiteKernel { | |||
| std::vector<size_t> local_size_; | |||
| cl::Kernel kernel_; | |||
| cl::Event event_; | |||
| static std::set<size_t> GenerateLocalByGlobal(size_t global_i) { | |||
| std::set<size_t> local_ = {}; | |||
| int index = 1; | |||
| while (index <= global_i) { | |||
| local_.insert(index); | |||
| index *= 2; | |||
| } | |||
| for (size_t i = 1; i <= 16; i++) { | |||
| if (global_i % i == 0) { | |||
| local_.insert(i); | |||
| } | |||
| } | |||
| return local_; | |||
| } | |||
| bool dequant_flag_{false}; | |||
| private: | |||
| lite::opencl::OpenCLRuntimeWrapper ocl_runtime_wrap_; | |||
| @@ -82,6 +82,7 @@ if (SUPPORT_GPU) | |||
| set(KERNEL_OP_SRC | |||
| ${KERNEL_OP_SRC} | |||
| ${GPU_KERNEL_OP_SRC} | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/opencl_kernel.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/opencl_subgraph.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/opencl_fusion.cc | |||
| ${LITE_DIR}/src/runtime/kernel/opencl/utils.cc | |||
| @@ -0,0 +1 @@ | |||
| ml_face_openclose.tflite | |||
| @@ -525,12 +525,11 @@ function Run_x86() { | |||
| model_name_len=${#model_name} | |||
| input_params=${line:model_name_len+1} | |||
| input_num=${input_params%%;*} | |||
| input_shape=${input_params##*;} | |||
| input_files='' | |||
| output_file='' | |||
| if [[ -z "$input_files" || $input_files == 1 ]] && [ -e ${ms_models_path}/${model_name}'.ms.bin' ]; then | |||
| input_files=$model_name'.ms.bin' | |||
| elif [[ ! -z "$input_files" && $input_files > 1 ]]; then | |||
| elif [[ ! -z "$input_files" && $input_files -gt 1 ]]; then | |||
| for i in $(seq 1 $input_num) | |||
| do | |||
| input_files=$input_files$model_name'.ms.bin_'$i',' | |||
| @@ -788,12 +787,11 @@ function Run_x86_sse() { | |||
| model_name_len=${#model_name} | |||
| input_params=${line:model_name_len+1} | |||
| input_num=${input_params%%;*} | |||
| input_shape=${input_params##*;} | |||
| input_files='' | |||
| output_file='' | |||
| if [[ -z "$input_files" || $input_files == 1 ]] && [ -e ${ms_models_path}/${model_name}'.ms.bin' ]; then | |||
| input_files=$model_name'.ms.bin' | |||
| elif [[ ! -z "$input_files" && $input_files > 1 ]]; then | |||
| elif [[ ! -z "$input_files" && $input_files -gt 1 ]]; then | |||
| for i in $(seq 1 $input_num) | |||
| do | |||
| input_files=$input_files$model_name'.ms.bin_'$i',' | |||
| @@ -1147,18 +1145,7 @@ function Run_arm64() { | |||
| else | |||
| run_result='arm64_gpu: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 | |||
| fi | |||
| # run benchmark test without clib data | |||
| #echo ${model_name} | |||
| echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt | |||
| echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> "${run_arm64_log_file}" | |||
| echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> adb_run_cmd.txt | |||
| adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}" | |||
| if [ $? = 0 ]; then | |||
| run_result='arm64_gpu: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} | |||
| else | |||
| run_result='arm64_gpu: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 | |||
| fi | |||
| done < ${models_tflite_gpu_config} | |||
| done < ${models_gpu_fp32_config} | |||
| # Run GPU fp16 converted models: | |||
| while read line; do | |||
| @@ -1176,19 +1163,27 @@ function Run_arm64() { | |||
| else | |||
| run_result='arm64_gpu_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 | |||
| fi | |||
| # run benchmark test without clib data | |||
| #sleep 1 | |||
| done < ${models_gpu_fp16_config} | |||
| # Run GPU weightquant converted models: | |||
| while read line; do | |||
| model_name=${line} | |||
| if [[ $model_name == \#* ]]; then | |||
| continue | |||
| fi | |||
| echo ${model_name} >> "${run_arm64_log_file}" | |||
| echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt | |||
| echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2 --enableFp16=true' >> "${run_arm64_log_file}" | |||
| echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2 --enableFp16=true' >> adb_run_cmd.txt | |||
| echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'_weightquant.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold=5' >> "${run_arm64_log_file}" | |||
| echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'_weightquant.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold=5' >> adb_run_cmd.txt | |||
| adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}" | |||
| if [ $? = 0 ]; then | |||
| run_result='arm64_gpu_fp16: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} | |||
| run_result='arm64_gpu_weightquant: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} | |||
| else | |||
| run_result='arm64_gpu_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 | |||
| run_result='arm64_gpu_weightquant: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1 | |||
| fi | |||
| #sleep 1 | |||
| done < ${models_fp16_gpu_config} | |||
| done < ${models_gpu_weightquant_config} | |||
| # Run mindir converted models: | |||
| while read line; do | |||
| @@ -1274,12 +1269,11 @@ function Run_arm64() { | |||
| model_name_len=${#model_name} | |||
| input_params=${line:model_name_len+1} | |||
| input_num=${input_params%%;*} | |||
| input_shape=${input_params##*;} | |||
| input_files='' | |||
| output_file='' | |||
| if [[ -z "$input_files" || $input_files == 1 ]] && [ -e ${ms_models_path}/${model_name}'.ms.bin' ]; then | |||
| input_files=$model_name'.ms.bin' | |||
| elif [[ ! -z "$input_files" && $input_files > 1 ]]; then | |||
| elif [[ ! -z "$input_files" && $input_files -gt 1 ]]; then | |||
| for i in $(seq 1 $input_num) | |||
| do | |||
| input_files=$input_files$model_name'.ms.bin_'$i',' | |||
| @@ -1445,9 +1439,10 @@ models_tflite_fp16_config=${basepath}/models_tflite_fp16.cfg | |||
| models_mindspore_config=${basepath}/models_mindspore.cfg | |||
| models_mindspore_train_config=${basepath}/models_mindspore_train.cfg | |||
| models_mindspore_mixbit_config=${basepath}/models_mindspore_mixbit.cfg | |||
| models_tflite_gpu_config=${basepath}/models_fp32_gpu.cfg | |||
| models_gpu_fp32_config=${basepath}/models_gpu_fp32.cfg | |||
| models_gpu_fp16_config=${basepath}/models_gpu_fp16.cfg | |||
| models_gpu_weightquant_config=${basepath}/models_gpu_weightquant.cfg | |||
| models_mindspore_weightquant_config=${basepath}/models_mindspore_weightquant.cfg | |||
| models_fp16_gpu_config=${basepath}/models_fp16_gpu.cfg | |||
| models_arm32_config=${basepath}/models_arm32.cfg | |||
| models_compatibility_config=${basepath}/models_compatibility.cfg | |||
| models_only_for_process_config=${basepath}/models_with_several_inputs_or_without_outputs.cfg | |||