
!9629 support weight quant for opencl

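In short, per the diff below: this change lets the OpenCL backend run weight-quantized models. A new opencl_kernel.cc moves the large inline methods out of opencl_kernel.h and adds DequantWeight()/FreeDequantedWeight() helpers on top of a DequantData template generalized to emit float16 as well as float; the conv2d, depthwise_conv2d, fullconnection and matmul kernels dequantize int8/int16 weights before packing them; and the GPU benchmark configs are renamed to models_gpu_fp32/fp16 with a new models_gpu_weightquant.cfg and a matching test loop.
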
From: @ddwsky
Committed by mindspore-ci-bot (5 years ago); tagged v1.1.0
Commit: e5d48a2786
13 changed files with 315 additions and 217 deletions:

  1. mindspore/lite/src/CMakeLists.txt (+1, -0)
  2. mindspore/lite/src/runtime/kernel/arm/base/dequant.h (+8, -9)
  3. mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc (+14, -1)
  4. mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc (+12, -6)
  5. mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc (+5, -0)
  6. mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc (+5, -0)
  7. mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc (+235, -0, new file)
  8. mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h (+12, -175)
  9. mindspore/lite/test/CMakeLists.txt (+1, -0)
 10. mindspore/lite/test/models_gpu_fp16.cfg (+0, -0, renamed)
 11. mindspore/lite/test/models_gpu_fp32.cfg (+0, -0, renamed)
 12. mindspore/lite/test/models_gpu_weightquant.cfg (+1, -0, new file)
 13. mindspore/lite/test/run_benchmark_nets.sh (+21, -26)

mindspore/lite/src/CMakeLists.txt (+1, -0)

@@ -42,6 +42,7 @@ set(LITE_SRC
 if (SUPPORT_GPU)
     set(LITE_SRC
         ${LITE_SRC}
+        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/opencl_kernel.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/opencl_subgraph.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/opencl_fusion.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/utils.cc


mindspore/lite/src/runtime/kernel/arm/base/dequant.h (+8, -9)

@@ -31,14 +31,14 @@ class DequantUtil {
 
   static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
 
-  template <typename T>
-  static float *DequantData(lite::Tensor *input_tensor) {
-    const auto *quant_datas = static_cast<const T *>(input_tensor->MutableData());
+  template <typename ST, typename DT = float>
+  static DT *DequantData(lite::Tensor *input_tensor) {
+    const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData());
     if (quant_datas == nullptr) {
       MS_LOG(ERROR) << "Get quant tensor failed.";
       return nullptr;
     }
-    auto *dequant_datas = static_cast<float *>(malloc(input_tensor->ElementsNum() * sizeof(float)));
+    DT *dequant_datas = static_cast<DT *>(malloc(input_tensor->ElementsNum() * sizeof(DT)));
     if (dequant_datas == nullptr) {
       MS_LOG(ERROR) << "Malloc failed.";
       return nullptr;
@@ -53,8 +53,7 @@ class DequantUtil {
         auto zero_point = param.zeroPoint;
         auto matrix_size = input_tensor->ElementsNum() / per_batch_size;
         for (int64_t j = 0; j < matrix_size; j++) {
-          dequant_datas[i * matrix_size + j] =
-            static_cast<float>((quant_datas[i * matrix_size + j] - zero_point) * scale);
+          dequant_datas[i * matrix_size + j] = static_cast<DT>((quant_datas[i * matrix_size + j] - zero_point) * scale);
         }
       }
     } else if (input_tensor->quant_params().size() != kPerTensor) {
@@ -78,7 +77,7 @@ class DequantUtil {
         }
         for (size_t j = 0; j < per_channel_size; j++) {
           auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale;
-          dequant_datas[per_channel_size * i + j] = static_cast<float>(dequant_data * var_corr + mean_corr);
+          dequant_datas[per_channel_size * i + j] = static_cast<DT>(dequant_data * var_corr + mean_corr);
         }
       }
     } else {
@@ -95,9 +94,9 @@ class DequantUtil {
             free(dequant_datas);
             return nullptr;
           }
-          dequant_datas[j] = static_cast<float>(param.clusters[index - INT8_MIN]);
+          dequant_datas[j] = static_cast<DT>(param.clusters[index - INT8_MIN]);
         } else {
-          dequant_datas[j] = static_cast<float>((quant_datas[j] - zero_point) * scale);
+          dequant_datas[j] = static_cast<DT>((quant_datas[j] - zero_point) * scale);
         }
       }
     }

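The effect of the dequant.h change is to make the destination type a template parameter: ST is the stored (quantized) type as before, and DT defaults to float so existing callers are unchanged, while the OpenCL fp16 path can instantiate DT = float16_t. A minimal standalone sketch of the per-tensor branch (illustrative names, not the MindSpore API; double stands in for float16_t, which is not portable C++):

    // Illustrative sketch, not MindSpore Lite source.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Same shape as the new template: ST = stored (quantized) type,
    // DT = destination type, defaulting to float as before.
    template <typename ST, typename DT = float>
    std::vector<DT> DequantPerTensor(const ST *q, size_t n, float scale, int zero_point) {
      std::vector<DT> out(n);
      for (size_t i = 0; i < n; ++i) {
        out[i] = static_cast<DT>((q[i] - zero_point) * scale);
      }
      return out;
    }

    int main() {
      const int8_t q[] = {-128, 0, 127};
      auto as_f32 = DequantPerTensor<int8_t>(q, 3, 0.5f, 0);          // old behavior: DT = float
      auto as_f64 = DequantPerTensor<int8_t, double>(q, 3, 0.5f, 0);  // new: any DT (float16_t in the kernels)
      printf("%f %f\n", as_f32[2], as_f64[2]);                        // 63.500000 both ways
    }
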

mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc (+14, -1)

@@ -211,6 +211,10 @@ int Conv2DOpenCLKernel::GenerateWinogradFilter() {
 
 int Conv2DOpenCLKernel::InitFilter() {
   auto allocator = ocl_runtime_->GetAllocator();
+  auto ret = DequantWeight();
+  if (ret != RET_OK) {
+    return ret;
+  }
 
   // allocate memory
   size_t packed_weight_size;
@@ -236,7 +240,7 @@ int Conv2DOpenCLKernel::InitFilter() {
         ConvertConvWeight4DTo7D<float16_t, float>(weight_tensor->data_c(), packed_weight_, CO_, KH_, KW_, CI_,
                                                   block_size_.C);
       }
-    } else {
+    } else if (weight_tensor->data_type() == kNumberTypeFloat32) {
       if (use_fp16_) {
         ConvertConvWeight4DTo7D<float, float16_t>(weight_tensor->data_c(), packed_weight_, CO_, KH_, KW_, CI_,
                                                   block_size_.C);
@@ -244,6 +248,15 @@ int Conv2DOpenCLKernel::InitFilter() {
         ConvertConvWeight4DTo7D<float, float>(weight_tensor->data_c(), packed_weight_, CO_, KH_, KW_, CI_,
                                               block_size_.C);
       }
+    } else { // int8 or int16
+      if (use_fp16_) {
+        ConvertConvWeight4DTo7D<float16_t, float16_t>(weight_tensor->data_c(), packed_weight_, CO_, KH_, KW_, CI_,
+                                                      block_size_.C);
+      } else {
+        ConvertConvWeight4DTo7D<float, float>(weight_tensor->data_c(), packed_weight_, CO_, KH_, KW_, CI_,
+                                              block_size_.C);
+      }
+      FreeDequantedWeight();
     }
   }

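The same three-step pattern recurs in the depthwise_conv2d, fullconnection and matmul hunks below: dequantize into a temporary buffer, pack from weight_tensor->data_c() (which now points at fp16/fp32 data whatever the stored type was), then release the temporary. A mocked-up sketch of that lifecycle (the class and buffer size here are illustrative, not the real kernel):

    // Illustrative stand-in for an OpenCL kernel's InitFilter()/InitWeights() flow.
    #include <cstdio>
    #include <cstdlib>

    struct MockKernel {
      void *weight_data = nullptr;  // stands in for weight_tensor->data_c()
      bool dequant_flag_ = false;   // set when a temporary dequantized buffer was allocated

      int DequantWeight() {
        // int8/int16 weights -> temporary fp16/fp32 buffer (here: 16 floats)
        weight_data = malloc(16 * sizeof(float));
        dequant_flag_ = (weight_data != nullptr);
        return dequant_flag_ ? 0 : 1;  // RET_OK / RET_ERROR
      }
      void Pack() { /* copy weight_data into the GPU-side packed layout */ }
      void FreeDequantedWeight() {
        if (dequant_flag_) free(weight_data);  // temp buffer only lives until packing
      }
    };

    int main() {
      MockKernel k;
      if (k.DequantWeight() != 0) return 1;
      k.Pack();                 // packed copy now lives in the allocator's memory
      k.FreeDequantedWeight();  // safe: the kernel reads only the packed copy afterwards
      puts("filter initialized");
    }
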


mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc (+12, -6)

@@ -92,6 +92,10 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
     MS_LOG(ERROR) << "DepthwiseConv2d don't support non-constant filter yet.";
     return RET_ERROR;
   }
+  auto ret = DequantWeight();
+  if (ret != RET_OK) {
+    return ret;
+  }
   auto parameter = reinterpret_cast<ConvParameter *>(op_parameter_);
   auto allocator = ocl_runtime_->GetAllocator();
   bool is_fp16 = ocl_runtime_->GetFp16Enable();
@@ -111,9 +115,10 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
     } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat32) {
       std::function<float16_t(float)> to_dtype = [](float x) -> float16_t { return static_cast<float16_t>(x); };
       PackNCHWToNC4HW4<float, float16_t>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
-    } else {
-      MS_LOG(ERROR) << "Only support float16/float32, actual data type " << in_tensors_.at(kWeightIndex)->data_type();
-      return mindspore::lite::RET_ERROR;
+    } else { // int8 or int16
+      std::function<int16_t(int16_t)> to_dtype = [](int16_t x) -> int16_t { return x; };
+      PackNCHWToNC4HW4<int16_t, int16_t>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
+      FreeDequantedWeight();
     }
   } else {
     packed_weight_ = allocator->Malloc(pack_weight_size * sizeof(float));
@@ -124,9 +129,10 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
     } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16) {
       std::function<float(float16_t)> to_dtype = [](float16_t x) -> float { return static_cast<float>(x); };
       PackNCHWToNC4HW4<float16_t, float>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
-    } else {
-      MS_LOG(ERROR) << "Only support float16/float32, actual data type " << in_tensors_.at(kWeightIndex)->data_type();
-      return mindspore::lite::RET_ERROR;
+    } else { // int8 or int16
+      std::function<float(float)> to_dtype = [](float x) -> float { return x; };
+      PackNCHWToNC4HW4<float, float>(origin_weight, packed_weight_, 1, plane, out_tensors_[0]->Channel(), to_dtype);
+      FreeDequantedWeight();
     }
   }



mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc (+5, -0)

@@ -135,6 +135,10 @@ int FullConnectionOpenCLKernel::InitWeights() {
 }  // namespace mindspore::kernel
 
 int FullConnectionOpenCLKernel::InitFilter() {
+  auto ret = DequantWeight();
+  if (ret != RET_OK) {
+    return ret;
+  }
   auto allocator = ocl_runtime_->GetAllocator();
   auto intensor_shape = GpuTensorInfo(in_tensors_[0]);
   int co4 = UP_DIV(CO_, C4NUM);
@@ -187,6 +191,7 @@ int FullConnectionOpenCLKernel::InitFilter() {
     }
   }
   allocator->UnmapBuffer(padWeight_);
+  FreeDequantedWeight();
   return RET_OK;
 }



mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc (+5, -0)

@@ -84,6 +84,10 @@ int MatMulOpenCLKernel::InitWeights() {
     MS_LOG(ERROR) << "Matmul don't support non-constant filter yet.";
     return RET_ERROR;
   }
+  auto ret = DequantWeight();
+  if (ret != RET_OK) {
+    return ret;
+  }
   auto allocator = ocl_runtime_->GetAllocator();
   int ci = inShape[3];
   int ci4 = UP_DIV(ci, C4NUM);
@@ -143,6 +147,7 @@ int MatMulOpenCLKernel::InitWeights() {
     }
   }
   allocator->UnmapBuffer(padWeight_);
+  FreeDequantedWeight();
   return RET_OK;
 }



mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc (+235, -0, new file)

@@ -0,0 +1,235 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/runtime/kernel/arm/base/dequant.h"

using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;

namespace mindspore::kernel {

int OpenCLKernel::AlignGlobalLocal(const std::vector<size_t> &global, const std::vector<size_t> &local) {
  std::vector<size_t> internal_global_ws = global;
  for (size_t i = 0; i < local.size(); ++i) {
    internal_global_ws.at(i) = UP_ROUND(global.at(i), local.at(i));
  }

  MS_LOG(DEBUG) << "global size: " << global.size() << ", local size: " << local.size();
  for (size_t i = 0; i < global.size(); i++) {
    MS_LOG(DEBUG) << "global[" << i << "] = " << global.at(i);
  }
  for (size_t i = 0; i < local.size(); i++) {
    MS_LOG(DEBUG) << "local[" << i << "] = " << local.at(i);
  }
  if (local.empty()) {
    local_range_ = cl::NullRange;
  }
  if (global.size() == 1) {
    global_range_ = cl::NDRange(internal_global_ws.at(0));
    if (!local.empty()) {
      local_range_ = cl::NDRange(local.at(0));
    }
  } else if (global.size() == 2) {
    global_range_ = cl::NDRange(internal_global_ws.at(0), internal_global_ws.at(1));
    if (!local.empty()) {
      local_range_ = cl::NDRange(local.at(0), local.at(1));
    }
  } else if (global.size() == 3) {
    global_range_ = cl::NDRange(internal_global_ws.at(0), internal_global_ws.at(1), internal_global_ws.at(2));
    if (!local.empty()) {
      local_range_ = cl::NDRange(local.at(0), local.at(1), local.at(2));
    }
  } else {
    MS_LOG(ERROR) << "Not supported NDRange!";
    return RET_ERROR;
  }
  return RET_OK;
}

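AlignGlobalLocal pads each global work-size dimension up to a multiple of the chosen local size before building the cl::NDRange objects (pre-2.0 OpenCL requires an explicit local work size to evenly divide the global size). A standalone sketch of the rounding, assuming UP_ROUND(x, y) is the usual round-up-to-a-multiple macro:

    // Illustrative sketch, not MindSpore Lite source.
    #include <cstdio>
    #include <vector>

    // Round x up to the nearest multiple of y (what UP_ROUND is assumed to do).
    size_t UpRound(size_t x, size_t y) { return ((x + y - 1) / y) * y; }

    int main() {
      std::vector<size_t> global = {1000, 3};
      std::vector<size_t> local = {16, 1};
      for (size_t i = 0; i < local.size(); ++i) {
        global[i] = UpRound(global[i], local[i]);
      }
      printf("%zu %zu\n", global[0], global[1]);  // 1008 3: each axis now divisible by its local size
    }
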
int OpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
  MS_ASSERT(img_size);
  if (idx >= out_tensors_.size()) {
    return RET_ERROR;
  }
  auto img_info = GpuTensorInfo(out_tensors_[idx]);
  size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
  *img_size = {img_info.width, img_info.height, img_dtype};
  return RET_OK;
}

std::vector<BaseTuningParameter> OpenCLKernel::GenerateTuningParam() {
  size_t ndim = global_size_.size();
  std::vector<BaseTuningParameter> tuning_params = {};
  if (ndim == 0) {
    MS_LOG(ERROR) << "Generate tuning param failed, global_size_ is null.";
    return tuning_params;
  }
  BaseTuningParameter default_tuning_param = BaseTuningParameter();
  default_tuning_param.local_size = local_size_;
  tuning_params.push_back(default_tuning_param);
  std::vector<size_t> max_work_items = ocl_runtime_->GetWorkItemSize();
  size_t max_workgroup_size = ocl_runtime_->GetMaxWorkGroupSize(kernel_);
  const size_t MIN_WORKGROUP_SIZE = 8;
  std::set<size_t> candidate_x = GenerateLocalByGlobal(global_size_[0]);
  std::set<size_t> candidate_y = {1};
  std::set<size_t> candidate_z = {1};
  if (ndim > 1) {
    candidate_y = GenerateLocalByGlobal(global_size_[1]);
  }
  if (ndim > 2) {
    candidate_z = GenerateLocalByGlobal(global_size_[2]);
  }
  for (auto x : candidate_x) {
    if (x <= max_work_items[0]) {
      for (auto y : candidate_y) {
        if (y <= max_work_items[1]) {
          for (auto z : candidate_z) {
            auto group_size = x * y * z;
            if (z <= max_work_items[2] && group_size <= max_workgroup_size && group_size >= MIN_WORKGROUP_SIZE) {
              BaseTuningParameter tuning_param = BaseTuningParameter();
              tuning_param.local_size = {x, y, z};
              tuning_params.push_back(tuning_param);
            }
          }
        }
      }
    }
  }
  return tuning_params;
}

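The per-axis candidate sets come from GenerateLocalByGlobal, defined near the end of this file: powers of two up to the global size, plus every divisor of the global size up to 16. A runnable copy for illustration:

    // Illustrative copy of the candidate generation, not MindSpore Lite source.
    #include <cstdio>
    #include <set>

    std::set<size_t> GenerateLocalByGlobal(size_t global_i) {
      std::set<size_t> local;
      size_t p = 1;
      while (p <= global_i) {  // powers of two up to the global size
        local.insert(p);
        p *= 2;
      }
      for (size_t i = 1; i <= 16; i++) {  // small exact divisors
        if (global_i % i == 0) {
          local.insert(i);
        }
      }
      return local;
    }

    int main() {
      for (size_t c : GenerateLocalByGlobal(24)) printf("%zu ", c);
      printf("\n");  // 1 2 3 4 6 8 12 16
    }

GenerateTuningParam then crosses the per-axis candidates and keeps only combinations that respect the device work-item limits and fall between MIN_WORKGROUP_SIZE and the kernel's maximum workgroup size.
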
int OpenCLKernel::AssignTuningParam(const BaseTuningParameter &param) {
  std::vector<size_t> local_size_tmp = param.local_size;
  if (local_size_tmp.size() > global_size_.size()) {
    local_size_tmp = std::vector<size_t>(local_size_tmp.begin(), local_size_tmp.begin() + global_size_.size());
  }
  AlignGlobalLocal(global_size_, local_size_tmp);
  return RET_OK;
}

int OpenCLKernel::Tune() {
  if (!ocl_runtime_->isProfiling()) {
    MS_LOG(WARNING) << "Tuning mode require opencl runtime profiling.";
    return RET_OK;
  }
  lite::opencl::TuningMode mode = ocl_runtime_->GetTuningMode();
  if (mode == lite::opencl::TuningMode::DEFAULT) {
    return RET_OK;
  }
  static const std::set<int> FAST_MODE_OPS = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
                                              schema::PrimitiveType_DeConv2D};
  if (mode == lite::opencl::TuningMode::FAST && FAST_MODE_OPS.find(op_parameter_->type_) == FAST_MODE_OPS.end()) {
    return RET_OK;
  }
  auto tuning_params = GenerateTuningParam();
  if (tuning_params.empty()) {
    MS_LOG(WARNING) << "Tuning param size is 0.";
    return RET_OK;
  }
  int index = -1;
  double min_time = MAX_PROFILING_TIME_MILLI_SECOND;
  for (int i = 0; i < tuning_params.size(); i++) {
    AssignTuningParam(tuning_params[i]);
    auto ret = Run();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Tuning " << name() << " failed for tuning param " << tuning_params[i];
      return ret;
    }
    double current_time = GetProfilingTimeMs();
    MS_LOG(DEBUG) << "Tuning " << name() << " param (" << tuning_params[i] << ") exectime " << current_time << "ms";
    if (current_time < min_time) {
      min_time = current_time;
      index = i;
    }
  }
  if (index != -1) {
    MS_LOG(INFO) << "Tuning " << name() << " result: param (" << tuning_params[index] << ") exectime " << min_time
                 << "ms";
    AssignTuningParam(tuning_params[index]);
  } else {
    MS_LOG(WARNING) << "Cannot find suitable param.";
  }
  return RET_OK;
}

double OpenCLKernel::GetProfilingTimeMs() {
  if (!ocl_runtime_->isProfiling()) {
    return MAX_PROFILING_TIME_MILLI_SECOND;
  }
  cl_ulong time_start;
  cl_ulong time_end;
  event_.getProfilingInfo(CL_PROFILING_COMMAND_START, &time_start);
  event_.getProfilingInfo(CL_PROFILING_COMMAND_END, &time_end);
  cl_ulong time_ns = time_end - time_start;
  return static_cast<double>(time_ns) * 1e-6;
}

std::set<size_t> OpenCLKernel::GenerateLocalByGlobal(size_t global_i) {
  std::set<size_t> local_ = {};
  int index = 1;
  while (index <= global_i) {
    local_.insert(index);
    index *= 2;
  }
  for (size_t i = 1; i <= 16; i++) {
    if (global_i % i == 0) {
      local_.insert(i);
    }
  }
  return local_;
}
int OpenCLKernel::DequantWeight() {
  bool is_fp16 = ocl_runtime_->GetFp16Enable();
  auto *weight_tensor = in_tensors_.at(kWeightIndex);
  auto *restore_data = weight_tensor->data_c();
  dequant_flag_ =
    !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
  if (dequant_flag_) {
    void *dequant_weight{nullptr};
    bool set_flag{true};
    if (is_fp16) {
      if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) {
        dequant_weight = kernel::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor);
      } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) {
        dequant_weight = kernel::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor);
      } else {
        set_flag = false;
      }
    } else {
      if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) {
        dequant_weight = kernel::DequantUtil::DequantData<int8_t, float>(weight_tensor);
      } else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) {
        dequant_weight = kernel::DequantUtil::DequantData<int16_t, float>(weight_tensor);
      } else {
        set_flag = false;
      }
    }
    if (set_flag && dequant_weight == nullptr) {
      MS_LOG(ERROR) << "dequant data failed.";
      return RET_ERROR;
    }
    weight_tensor->set_data(dequant_weight);
  }
  return RET_OK;
}

void OpenCLKernel::FreeDequantedWeight() {
  auto *weight_tensor = in_tensors_.at(kWeightIndex);
  if (dequant_flag_) {
    free(weight_tensor->data_c());
  }
}
}  // namespace mindspore::kernel

mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h (+12, -175)

@@ -25,6 +25,7 @@
 #include "src/lite_kernel.h"
 #include "include/errorcode.h"
 #include "src/runtime/opencl/opencl_runtime.h"
+#include "src/runtime/kernel/arm/base/dequant.h"
 
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
@@ -159,43 +160,7 @@ class OpenCLKernel : public LiteKernel {
     ocl_runtime_ = ocl_runtime_wrap_.GetInstance();
   }
   ~OpenCLKernel() override = default;
-  int AlignGlobalLocal(const std::vector<size_t> &global, const std::vector<size_t> &local) {
-    std::vector<size_t> internal_global_ws = global;
-    for (size_t i = 0; i < local.size(); ++i) {
-      internal_global_ws.at(i) = UP_ROUND(global.at(i), local.at(i));
-    }
-
-    MS_LOG(DEBUG) << "global size: " << global.size() << ", local size: " << local.size();
-    for (size_t i = 0; i < global.size(); i++) {
-      MS_LOG(DEBUG) << "global[" << i << "] = " << global.at(i);
-    }
-    for (size_t i = 0; i < local.size(); i++) {
-      MS_LOG(DEBUG) << "local[" << i << "] = " << local.at(i);
-    }
-    if (local.empty()) {
-      local_range_ = cl::NullRange;
-    }
-    if (global.size() == 1) {
-      global_range_ = cl::NDRange(internal_global_ws.at(0));
-      if (!local.empty()) {
-        local_range_ = cl::NDRange(local.at(0));
-      }
-    } else if (global.size() == 2) {
-      global_range_ = cl::NDRange(internal_global_ws.at(0), internal_global_ws.at(1));
-      if (!local.empty()) {
-        local_range_ = cl::NDRange(local.at(0), local.at(1));
-      }
-    } else if (global.size() == 3) {
-      global_range_ = cl::NDRange(internal_global_ws.at(0), internal_global_ws.at(1), internal_global_ws.at(2));
-      if (!local.empty()) {
-        local_range_ = cl::NDRange(local.at(0), local.at(1), local.at(2));
-      }
-    } else {
-      MS_LOG(ERROR) << "Not supported NDRange!";
-      return RET_ERROR;
-    }
-    return RET_OK;
-  }
+  int AlignGlobalLocal(const std::vector<size_t> &global, const std::vector<size_t> &local);
 
   int Prepare() override { return RET_OK; }
   int PreProcess() override { return RET_ERROR; }
@@ -210,135 +175,20 @@ class OpenCLKernel : public LiteKernel {
   virtual int GetLocalSize(size_t idx, const std::vector<size_t> &global_size, std::vector<size_t> *local_size) {
     return RET_ERROR;
   }
-  int GetImageSize(size_t idx, std::vector<size_t> *img_size) {
-    MS_ASSERT(img_size);
-    if (idx >= out_tensors_.size()) {
-      return RET_ERROR;
-    }
-    auto img_info = GpuTensorInfo(out_tensors_[idx]);
-    size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
-    *img_size = {img_info.width, img_info.height, img_dtype};
-    return RET_OK;
-  }
+  virtual std::vector<BaseTuningParameter> GenerateTuningParam();
+  virtual int AssignTuningParam(const BaseTuningParameter &param);
+  virtual int Tune();
 
+  int GetImageSize(size_t idx, std::vector<size_t> *img_size);
   lite::opencl::MemType GetMemType() { return out_mem_type_; }
   void SetMemType(lite::opencl::MemType mem_type) { out_mem_type_ = mem_type; }
   OpParameter *GetParameter() { return op_parameter_; }
+  double GetProfilingTimeMs();
+  int DequantWeight();
+  void FreeDequantedWeight();
 
-  virtual std::vector<BaseTuningParameter> GenerateTuningParam() {
-    size_t ndim = global_size_.size();
-    std::vector<BaseTuningParameter> tuning_params = {};
-    if (ndim == 0) {
-      MS_LOG(ERROR) << "Generate tuning param failed, global_size_ is null.";
-      return tuning_params;
-    }
-    BaseTuningParameter default_tuning_param = BaseTuningParameter();
-    default_tuning_param.local_size = local_size_;
-    tuning_params.push_back(default_tuning_param);
-    std::vector<size_t> max_work_items = ocl_runtime_->GetWorkItemSize();
-    size_t max_workgroup_size = ocl_runtime_->GetMaxWorkGroupSize(kernel_);
-    const size_t MIN_WORKGROUP_SIZE = 8;
-    std::set<size_t> candidate_x = GenerateLocalByGlobal(global_size_[0]);
-    std::set<size_t> candidate_y = {1};
-    std::set<size_t> candidate_z = {1};
-    if (ndim > 1) {
-      candidate_y = GenerateLocalByGlobal(global_size_[1]);
-    }
-    if (ndim > 2) {
-      candidate_z = GenerateLocalByGlobal(global_size_[2]);
-    }
-    for (auto x : candidate_x) {
-      if (x <= max_work_items[0]) {
-        for (auto y : candidate_y) {
-          if (y <= max_work_items[1]) {
-            for (auto z : candidate_z) {
-              auto group_size = x * y * z;
-              if (z <= max_work_items[2] && group_size <= max_workgroup_size && group_size >= MIN_WORKGROUP_SIZE) {
-                BaseTuningParameter tuning_param = BaseTuningParameter();
-                tuning_param.local_size = {x, y, z};
-                tuning_params.push_back(tuning_param);
-              }
-            }
-          }
-        }
-      }
-    }
-    return tuning_params;
-  }
-
-  virtual int AssignTuningParam(const BaseTuningParameter &param) {
-    std::vector<size_t> local_size_tmp = param.local_size;
-    if (local_size_tmp.size() > global_size_.size()) {
-      local_size_tmp = std::vector<size_t>(local_size_tmp.begin(), local_size_tmp.begin() + global_size_.size());
-    }
-    AlignGlobalLocal(global_size_, local_size_tmp);
-    return RET_OK;
-  }
-
-  virtual int Tune() {
-    if (!ocl_runtime_->isProfiling()) {
-      MS_LOG(WARNING) << "Tuning mode require opencl runtime profiling.";
-      return RET_OK;
-    }
-    lite::opencl::TuningMode mode = ocl_runtime_->GetTuningMode();
-    if (mode == lite::opencl::TuningMode::DEFAULT) {
-      return RET_OK;
-    }
-    static const std::set<int> FAST_MODE_OPS = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
-                                                schema::PrimitiveType_DeConv2D};
-    if (mode == lite::opencl::TuningMode::FAST && FAST_MODE_OPS.find(op_parameter_->type_) == FAST_MODE_OPS.end()) {
-      return RET_OK;
-    }
-    auto key = Key();
-    auto finded = tuned_param_cache_.find(key);
-    if (finded != tuned_param_cache_.end()) {
-      auto cache_param = finded->second;
-      MS_LOG(INFO) << "Tuning " << name() << ", found cached param(" << cache_param << ")";
-      return RET_OK;
-    }
-    auto tuning_params = GenerateTuningParam();
-    if (tuning_params.empty()) {
-      MS_LOG(WARNING) << "Tuning param size is 0.";
-      return RET_OK;
-    }
-    int index = -1;
-    double min_time = MAX_PROFILING_TIME_MILLI_SECOND;
-    for (int i = 0; i < tuning_params.size(); i++) {
-      AssignTuningParam(tuning_params[i]);
-      auto ret = Run();
-      if (ret != RET_OK) {
-        MS_LOG(ERROR) << "Tuning " << name() << " failed for tuning param " << tuning_params[i];
-        return ret;
-      }
-      double current_time = GetProfilingTimeMs();
-      MS_LOG(DEBUG) << "Tuning " << name() << " param (" << tuning_params[i] << ") exectime " << current_time << "ms";
-      if (current_time < min_time) {
-        min_time = current_time;
-        index = i;
-      }
-    }
-    if (index != -1) {
-      MS_LOG(INFO) << "Tuning " << name() << " result: param (" << tuning_params[index] << ") exectime " << min_time
-                   << "ms";
-      AssignTuningParam(tuning_params[index]);
-      tuned_param_cache_[key] = tuning_params[index];
-    } else {
-      MS_LOG(WARNING) << "Cannot find suitable param.";
-    }
-    return RET_OK;
-  }
-
-  double GetProfilingTimeMs() {
-    if (!ocl_runtime_->isProfiling()) {
-      return MAX_PROFILING_TIME_MILLI_SECOND;
-    }
-    cl_ulong time_start;
-    cl_ulong time_end;
-    event_.getProfilingInfo(CL_PROFILING_COMMAND_START, &time_start);
-    event_.getProfilingInfo(CL_PROFILING_COMMAND_END, &time_end);
-    cl_ulong time_ns = time_end - time_start;
-    return static_cast<double>(time_ns) * 1e-6;
-  }
+ protected:
+  static std::set<size_t> GenerateLocalByGlobal(size_t global_i);
 
   virtual std::string Key() {
     std::string key = type_str();
@@ -358,20 +208,7 @@ class OpenCLKernel : public LiteKernel {
   std::vector<size_t> local_size_;
   cl::Kernel kernel_;
   cl::Event event_;
-  static std::set<size_t> GenerateLocalByGlobal(size_t global_i) {
-    std::set<size_t> local_ = {};
-    int index = 1;
-    while (index <= global_i) {
-      local_.insert(index);
-      index *= 2;
-    }
-    for (size_t i = 1; i <= 16; i++) {
-      if (global_i % i == 0) {
-        local_.insert(i);
-      }
-    }
-    return local_;
-  }
+  bool dequant_flag_{false};
 
  private:
   lite::opencl::OpenCLRuntimeWrapper ocl_runtime_wrap_;

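For orientation: after this change the header only declares what the new opencl_kernel.cc above defines (AlignGlobalLocal, GetImageSize, GenerateTuningParam, AssignTuningParam, Tune, GetProfilingTimeMs, GenerateLocalByGlobal), and every OpenCL kernel gains DequantWeight(), FreeDequantedWeight() and the dequant_flag_ member. One difference visible in the move: the old inline Tune() consulted a tuned_param_cache_ keyed by Key(), which the out-of-line version does not carry over.
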
mindspore/lite/test/CMakeLists.txt (+1, -0)

@@ -82,6 +82,7 @@ if (SUPPORT_GPU)
     set(KERNEL_OP_SRC
         ${KERNEL_OP_SRC}
         ${GPU_KERNEL_OP_SRC}
+        ${LITE_DIR}/src/runtime/kernel/opencl/opencl_kernel.cc
         ${LITE_DIR}/src/runtime/kernel/opencl/opencl_subgraph.cc
         ${LITE_DIR}/src/runtime/kernel/opencl/opencl_fusion.cc
        ${LITE_DIR}/src/runtime/kernel/opencl/utils.cc


mindspore/lite/test/models_fp16_gpu.cfg → mindspore/lite/test/models_gpu_fp16.cfg (renamed)


mindspore/lite/test/models_fp32_gpu.cfg → mindspore/lite/test/models_gpu_fp32.cfg (renamed)


mindspore/lite/test/models_gpu_weightquant.cfg (+1, -0, new file)

@@ -0,0 +1 @@
+ml_face_openclose.tflite

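This single model drives the new arm64 weightquant loop added to run_benchmark_nets.sh below: each listed name is benchmarked as <name>_weightquant.ms on the GPU with --enableFp16=true, and its output is compared against the recorded golden data under a relaxed --accuracyThreshold=5.
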
mindspore/lite/test/run_benchmark_nets.sh (+21, -26)

@@ -525,12 +525,11 @@ function Run_x86() {
         model_name_len=${#model_name}
         input_params=${line:model_name_len+1}
         input_num=${input_params%%;*}
-        input_shape=${input_params##*;}
         input_files=''
         output_file=''
         if [[ -z "$input_files" || $input_files == 1 ]] && [ -e ${ms_models_path}/${model_name}'.ms.bin' ]; then
             input_files=$model_name'.ms.bin'
-        elif [[ ! -z "$input_files" && $input_files > 1 ]]; then
+        elif [[ ! -z "$input_files" && $input_files -gt 1 ]]; then
            for i in $(seq 1 $input_num)
            do
                input_files=$input_files$model_name'.ms.bin_'$i','
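A note on the comparison fix here (repeated in the Run_x86_sse and Run_arm64 hunks below): inside bash's [[ ]], > is a lexicographic string comparison, so for example [[ 10 > 9 ]] is false; -gt is the numeric test. The unused input_shape assignment is dropped in the same hunks.
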
@@ -788,12 +787,11 @@ function Run_x86_sse() {
         model_name_len=${#model_name}
         input_params=${line:model_name_len+1}
         input_num=${input_params%%;*}
-        input_shape=${input_params##*;}
         input_files=''
         output_file=''
         if [[ -z "$input_files" || $input_files == 1 ]] && [ -e ${ms_models_path}/${model_name}'.ms.bin' ]; then
             input_files=$model_name'.ms.bin'
-        elif [[ ! -z "$input_files" && $input_files > 1 ]]; then
+        elif [[ ! -z "$input_files" && $input_files -gt 1 ]]; then
            for i in $(seq 1 $input_num)
            do
                input_files=$input_files$model_name'.ms.bin_'$i','
@@ -1147,18 +1145,7 @@ function Run_arm64() {
         else
             run_result='arm64_gpu: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
         fi
-        # run benchmark test without clib data
-        #echo ${model_name}
-        echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
-        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> "${run_arm64_log_file}"
-        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2' >> adb_run_cmd.txt
-        adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}"
-        if [ $? = 0 ]; then
-            run_result='arm64_gpu: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
-        else
-            run_result='arm64_gpu: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
-        fi
-    done < ${models_tflite_gpu_config}
+    done < ${models_gpu_fp32_config}
 
     # Run GPU fp16 converted models:
     while read line; do
@@ -1176,19 +1163,27 @@ function Run_arm64() {
         else
             run_result='arm64_gpu_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
         fi
-        # run benchmark test without clib data
-        #sleep 1
-        echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
-        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2 --enableFp16=true' >> "${run_arm64_log_file}"
-        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'.ms --warmUpLoopCount=1 --loopCount=2 --enableFp16=true' >> adb_run_cmd.txt
-        adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}"
-        if [ $? = 0 ]; then
-            run_result='arm64_gpu_fp16: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
-        else
-            run_result='arm64_gpu_fp16: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
-        fi
-        #sleep 1
-    done < ${models_fp16_gpu_config}
+    done < ${models_gpu_fp16_config}
+
+    # Run GPU weightquant converted models:
+    while read line; do
+        model_name=${line}
+        if [[ $model_name == \#* ]]; then
+            continue
+        fi
+        echo ${model_name} >> "${run_arm64_log_file}"
+        echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt
+        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'_weightquant.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold=5' >> "${run_arm64_log_file}"
+        echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --device=GPU --modelFile='${model_name}'_weightquant.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --enableFp16=true --accuracyThreshold=5' >> adb_run_cmd.txt
+        adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}"
+        if [ $? = 0 ]; then
+            run_result='arm64_gpu_weightquant: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file}
+        else
+            run_result='arm64_gpu_weightquant: '${model_name}' failed'; echo ${run_result} >> ${run_benchmark_result_file}; return 1
+        fi
+    done < ${models_gpu_weightquant_config}
 
     # Run mindir converted models:
     while read line; do
@@ -1274,12 +1269,11 @@ function Run_arm64() {
         model_name_len=${#model_name}
         input_params=${line:model_name_len+1}
         input_num=${input_params%%;*}
-        input_shape=${input_params##*;}
         input_files=''
         output_file=''
         if [[ -z "$input_files" || $input_files == 1 ]] && [ -e ${ms_models_path}/${model_name}'.ms.bin' ]; then
            input_files=$model_name'.ms.bin'
-        elif [[ ! -z "$input_files" && $input_files > 1 ]]; then
+        elif [[ ! -z "$input_files" && $input_files -gt 1 ]]; then
            for i in $(seq 1 $input_num)
            do
                input_files=$input_files$model_name'.ms.bin_'$i','
@@ -1445,9 +1439,10 @@ models_tflite_fp16_config=${basepath}/models_tflite_fp16.cfg
 models_mindspore_config=${basepath}/models_mindspore.cfg
 models_mindspore_train_config=${basepath}/models_mindspore_train.cfg
 models_mindspore_mixbit_config=${basepath}/models_mindspore_mixbit.cfg
-models_tflite_gpu_config=${basepath}/models_fp32_gpu.cfg
+models_gpu_fp32_config=${basepath}/models_gpu_fp32.cfg
+models_gpu_fp16_config=${basepath}/models_gpu_fp16.cfg
+models_gpu_weightquant_config=${basepath}/models_gpu_weightquant.cfg
 models_mindspore_weightquant_config=${basepath}/models_mindspore_weightquant.cfg
-models_fp16_gpu_config=${basepath}/models_fp16_gpu.cfg
 models_arm32_config=${basepath}/models_arm32.cfg
 models_compatibility_config=${basepath}/models_compatibility.cfg
 models_only_for_process_config=${basepath}/models_with_several_inputs_or_without_outputs.cfg

