Browse Source

fix matmul quantization

tags/v1.2.0-rc1
xutianchun 4 years ago
parent
commit
72daa10df6
6 changed files with 199 additions and 97 deletions
  1. +22
    -5
      mindspore/lite/src/dequant.cc
  2. +17
    -5
      mindspore/lite/src/dequant.h
  3. +3
    -2
      mindspore/lite/src/scheduler.cc
  4. +118
    -56
      mindspore/lite/tools/converter/quantizer/quantize_util.cc
  5. +33
    -25
      mindspore/lite/tools/converter/quantizer/quantize_util.h
  6. +6
    -4
      mindspore/lite/tools/converter/quantizer/weight_quantizer.cc

+ 22
- 5
mindspore/lite/src/dequant.cc View File

@@ -18,9 +18,10 @@
#include <memory> #include <memory>
#include "src/dequant.h" #include "src/dequant.h"
#include "src/huffman_decode.h" #include "src/huffman_decode.h"
#include "src/ops/matmul.h"


namespace mindspore::lite { namespace mindspore::lite {
float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
float *DequantUtil::DequantWeight(lite::Tensor *input_tensor, bool channel_first) {
MS_ASSERT(input_tensor != nullptr); MS_ASSERT(input_tensor != nullptr);
if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) { if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type(); MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type();
@@ -31,9 +32,9 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
return nullptr; return nullptr;
} }
if (input_tensor->data_type() == kNumberTypeInt16) { if (input_tensor->data_type() == kNumberTypeInt16) {
return DequantData<int16_t>(input_tensor);
return DequantData<int16_t>(input_tensor, channel_first);
} else { } else {
return DequantData<int8_t>(input_tensor);
return DequantData<int8_t>(input_tensor, channel_first);
} }
} }


@@ -65,19 +66,35 @@ int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_in
return RET_OK; return RET_OK;
} }


std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors,
std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const mindspore::lite::PrimitiveC *primitive,
const std::vector<Tensor *> &in_tensors,
TypeId data_type, bool need_restore) { TypeId data_type, bool need_restore) {
std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data; std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data;
if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) { if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) {
auto input_i = 0;
for (auto weight_tensor : in_tensors) { for (auto weight_tensor : in_tensors) {
MS_ASSERT(weight_tensor != nullptr); MS_ASSERT(weight_tensor != nullptr);
input_i++;
auto channel_first = true;
if ((schema::PrimitiveType)primitive->Type() == schema::PrimitiveType_MatMul &&
weight_tensor->shape().size() == 2) {
auto param = reinterpret_cast<mindspore::lite::MatMul *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
if (input_i == 1) {
channel_first = !param->GetTransposeA();
} else if (input_i == 2) {
channel_first = param->GetTransposeB();
} else {
MS_LOG(WARNING) << "unexpected input_i";
}
}

auto *restore_data = weight_tensor->data_c(); auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type(); auto restore_type = weight_tensor->data_type();
bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited &&
restore_data != nullptr && restore_data != nullptr &&
(restore_type == kNumberTypeInt8 || restore_type == kNumberTypeInt16); (restore_type == kNumberTypeInt8 || restore_type == kNumberTypeInt16);
if (dequant_flag) { if (dequant_flag) {
auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor);
auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor, channel_first);
if (dequant_weight == nullptr) { if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr."; MS_LOG(ERROR) << "dequant data is nullptr.";
return tensor_origin_data; return tensor_origin_data;


+ 17
- 5
mindspore/lite/src/dequant.h View File

@@ -29,17 +29,18 @@
namespace mindspore::lite { namespace mindspore::lite {
class DequantUtil { class DequantUtil {
public: public:
static float *DequantWeight(lite::Tensor *input_tensor);
static float *DequantWeight(lite::Tensor *input_tensor, bool);


static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);


static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors,
static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const mindspore::lite::PrimitiveC *primitive,
const std::vector<Tensor *> &in_tensors,
TypeId data_type, bool need_restore = true); TypeId data_type, bool need_restore = true);


static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map); static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map);


template <typename ST, typename DT = float> template <typename ST, typename DT = float>
static DT *DequantData(lite::Tensor *input_tensor) {
static DT *DequantData(lite::Tensor *input_tensor, bool channel_first = true) {
const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData()); const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData());
if (quant_datas == nullptr) { if (quant_datas == nullptr) {
MS_LOG(ERROR) << "Get quant tensor failed."; MS_LOG(ERROR) << "Get quant tensor failed.";
@@ -65,6 +66,13 @@ class DequantUtil {
} }
} else if (input_tensor->quant_params().size() != kPerTensor) { } else if (input_tensor->quant_params().size() != kPerTensor) {
auto channels = static_cast<size_t>(input_tensor->Batch()); auto channels = static_cast<size_t>(input_tensor->Batch());
if (!channel_first) {
if (input_tensor->shape().size() != 2) {
MS_LOG(ERROR) << "unexpected shape size: " << input_tensor->shape().size();
return nullptr;
}
channels = input_tensor->shape()[1];
}
if (input_tensor->quant_params().size() != channels) { if (input_tensor->quant_params().size() != channels) {
MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->quant_params().size() << channels; MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->quant_params().size() << channels;
free(dequant_datas); free(dequant_datas);
@@ -83,8 +91,12 @@ class DequantUtil {
var_corr = 1; var_corr = 1;
} }
for (size_t j = 0; j < per_channel_size; j++) { for (size_t j = 0; j < per_channel_size; j++) {
auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale;
dequant_datas[per_channel_size * i + j] = static_cast<DT>(dequant_data * var_corr + mean_corr);
auto index = per_channel_size * i + j;
if (!channel_first) {
index = channels * j + i;
}
auto dequant_data = (quant_datas[index] - zero_point) * scale;
dequant_datas[index] = static_cast<DT>(dequant_data * var_corr + mean_corr);
} }
} }
} else { } else {


+ 3
- 2
mindspore/lite/src/scheduler.cc View File

@@ -223,7 +223,8 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
if (mindspore::lite::IsSupportFloat16() && if (mindspore::lite::IsSupportFloat16() &&
((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) { ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) {
kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type}; kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type};
auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type, need_restore);
auto tensor_origin_data_map =
DequantUtil::DequantTensor(primitive, in_tensors, fp16_cpu_desc.data_type, need_restore);
auto *kernel = auto *kernel =
KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc); KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc);
DequantUtil::RestoreTensorData(tensor_origin_data_map); DequantUtil::RestoreTensorData(tensor_origin_data_map);
@@ -237,7 +238,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
desc.data_type = kNumberTypeFloat32; desc.data_type = kNumberTypeFloat32;
} }
auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type, need_restore);
auto tensor_origin_data_map = DequantUtil::DequantTensor(primitive, in_tensors, desc.data_type, need_restore);
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
DequantUtil::RestoreTensorData(tensor_origin_data_map); DequantUtil::RestoreTensorData(tensor_origin_data_map);
if (kernel != nullptr) { if (kernel != nullptr) {


+ 118
- 56
mindspore/lite/tools/converter/quantizer/quantize_util.cc View File

@@ -358,15 +358,15 @@ static bool SearchUpperBound(const std::vector<float> &data, const size_t &index
return true; return true;
} }


static float CalPercentile(const std::vector<float> &datas, const int &outlier_percent) {
const int size = datas.size();
static float CalPercentile(const std::vector<float> &data, const int &outlier_percent) {
const int size = data.size();
float val = outlier_percent / 100.0 * size; float val = outlier_percent / 100.0 * size;
int index = std::ceil(val); int index = std::ceil(val);
float result; float result;
if (index - val > 0) { if (index - val > 0) {
result = datas.at(index - 1);
result = data.at(index - 1);
} else { } else {
result = (datas.at(index - 1) + datas.at(index)) / 2;
result = (data.at(index - 1) + data.at(index)) / 2;
} }
return result; return result;
} }
@@ -522,11 +522,78 @@ std::vector<std::vector<int>> DataToVectors(const string &str) {
return result; return result;
} }


STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) {
if (post_quant_config == nullptr) {
MS_LOG(ERROR) << "post_quant_config is null.";
return RET_PARAM_INVALID;
void ParseInputShape(PostQuantConfig *post_quant_config, std::string raw_shape) {
MS_ASSERT(post_quant_config != nullptr);
auto ind = raw_shape.find('/');
while (ind != std::string::npos) {
auto shape = raw_shape.substr(0, ind);
Trim(&shape);
post_quant_config->input_shapes.push_back(DataToVectors(shape));
raw_shape = raw_shape.substr(ind + 1);
Trim(&raw_shape);
ind = raw_shape.find('/');
}
if (!raw_shape.empty()) {
post_quant_config->input_shapes.push_back(DataToVectors(raw_shape));
}
}

void ParseImagePath(PostQuantConfig *post_quant_config, std::string raw_image_paths) {
MS_ASSERT(post_quant_config != nullptr);
auto ind = raw_image_paths.find(',');
while (ind != std::string::npos) {
auto image_path = raw_image_paths.substr(0, ind);
Trim(&image_path);
post_quant_config->image_paths.push_back(image_path);
raw_image_paths = raw_image_paths.substr(ind + 1);
Trim(&raw_image_paths);
ind = raw_image_paths.find(',');
}
post_quant_config->image_paths.push_back(raw_image_paths);
}

void ParseBatchCount(PostQuantConfig *post_quant_config, std::string value) {
MS_ASSERT(post_quant_config != nullptr);
post_quant_config->batch_count = std::stoul(value);
}

void ParseThreadNum(PostQuantConfig *post_quant_config, std::string value) {
MS_ASSERT(post_quant_config != nullptr);
post_quant_config->thread_num = std::stoul(value);
}

void ParseMethodX(PostQuantConfig *post_quant_config, const std::string &value) {
MS_ASSERT(post_quant_config != nullptr);
if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) {
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
} else {
post_quant_config->method_x = value;
} }
}

void ParseMixed(PostQuantConfig *post_quant_config, std::string value) {
MS_ASSERT(post_quant_config != nullptr);
std::for_each(value.begin(), value.end(), ::tolower);
if (value == "true") {
post_quant_config->mixed = true;
}
}

void ParseMeanErrorThreshold(PostQuantConfig *post_quant_config, std::string value) {
MS_ASSERT(post_quant_config != nullptr);
post_quant_config->mean_error_threshold = std::stof(value);
}

void ParseBiasCorrection(PostQuantConfig *post_quant_config, std::string value) {
MS_ASSERT(post_quant_config != nullptr);
std::for_each(value.begin(), value.end(), ::tolower);
if (value == "true") {
post_quant_config->bias_correction = true;
}
}

STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) {
MS_ASSERT(post_quant_config != nullptr);


if (config_file.empty() || config_file.length() > PATH_MAX) { if (config_file.empty() || config_file.length() > PATH_MAX) {
MS_LOG(ERROR) << "invalid config path!"; MS_LOG(ERROR) << "invalid config path!";
@@ -552,6 +619,26 @@ STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_conf
MS_LOG(ERROR) << "config file open failed: " << config_file; MS_LOG(ERROR) << "config file open failed: " << config_file;
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }

std::string INPUT_SHAPES = "input_shapes";
std::string IMAGE_PATH = "image_path";
std::string BATCH_COUNT = "batch_count";
std::string THREAD_NUM = "thread_num";
std::string METHOD_X = "method_x";
std::string MIXED = "mixed";
std::string MEAN_ERROR_THRESHOLD = "mean_error_threshold";
std::string BIAS_CORRECTION = "bias_correction";

std::map<std::string, std::function<void(PostQuantConfig *, std::string)>> value_parser;
value_parser[INPUT_SHAPES] = ParseInputShape;
value_parser[IMAGE_PATH] = ParseImagePath;
value_parser[BATCH_COUNT] = ParseBatchCount;
value_parser[THREAD_NUM] = ParseThreadNum;
value_parser[METHOD_X] = ParseMethodX;
value_parser[MIXED] = ParseMixed;
value_parser[MEAN_ERROR_THRESHOLD] = ParseMeanErrorThreshold;
value_parser[BIAS_CORRECTION] = ParseBiasCorrection;

std::string line; std::string line;
while (std::getline(fs, line)) { while (std::getline(fs, line)) {
Trim(&line); Trim(&line);
@@ -567,54 +654,9 @@ STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_conf
auto value = line.substr(index + 1); auto value = line.substr(index + 1);
Trim(&key); Trim(&key);
Trim(&value); Trim(&value);
if (key == "image_path") {
auto &raw_image_paths = value;
auto ind = raw_image_paths.find(',');
while (ind != std::string::npos) {
auto image_path = raw_image_paths.substr(0, ind);
Trim(&image_path);
post_quant_config->image_paths.push_back(image_path);
raw_image_paths = raw_image_paths.substr(ind + 1);
Trim(&raw_image_paths);
ind = raw_image_paths.find(',');
}
post_quant_config->image_paths.push_back(raw_image_paths);
} else if (key == "batch_count") {
post_quant_config->batch_count = std::stoul(value);
} else if (key == "thread_num") {
post_quant_config->thread_num = std::stoul(value);
} else if (key == "method_x") {
if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) {
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
} else {
post_quant_config->method_x = value;
}
} else if (key == "bias_correction") {
std::for_each(value.begin(), value.end(), ::tolower);
if (value == "true") {
post_quant_config->bias_correction = true;
}
} else if (key == "mixed") {
std::for_each(value.begin(), value.end(), ::tolower);
if (value == "true") {
post_quant_config->mixed = true;
}
} else if (key == "mean_error_threshold") {
post_quant_config->mean_error_threshold = std::stof(value);
} else if (key == "input_shapes") {
auto &raw_shape = value;
auto ind = raw_shape.find('/');
while (ind != std::string::npos) {
auto shape = raw_shape.substr(0, ind);
Trim(&shape);
post_quant_config->input_shapes.push_back(DataToVectors(shape));
raw_shape = raw_shape.substr(ind + 1);
Trim(&raw_shape);
ind = raw_shape.find('/');
}
if (!raw_shape.empty()) {
post_quant_config->input_shapes.push_back(DataToVectors(raw_shape));
}
auto it = value_parser.find(key);
if (it != value_parser.end()) {
it->second(post_quant_config, value);
} else { } else {
MS_LOG(WARNING) << "unsupported parameter: " << key; MS_LOG(WARNING) << "unsupported parameter: " << key;
} }
@@ -881,4 +923,24 @@ STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int
return RET_OK; return RET_OK;
} }


void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, float *raw_datas,
bool channel_at_first, float *desired_max, float *desired_min) {
float min = FLT_MAX;
float max = -FLT_MAX;
// find min and max
for (int j = 0; j < one_filter_size; j++) {
auto index = j + i * one_filter_size;
if (!channel_at_first) {
index = j * channels + i;
}
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
}
min = std::min(min, raw_datas[index]);
max = std::max(max, raw_datas[index]);
}
*desired_max = max;
*desired_min = min;
}

} // namespace mindspore::lite::quant } // namespace mindspore::lite::quant

+ 33
- 25
mindspore/lite/tools/converter/quantizer/quantize_util.h View File

@@ -107,6 +107,9 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc


STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int new_size); STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int new_size);


void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, float *raw_datas,
bool channel_at_first, float *desired_max, float *desired_min);

template <typename T> template <typename T>
T QuantizeData(const float originData, const schema::QuantParamT *quantParam) { T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
MS_ASSERT(quantParam != nullptr); MS_ASSERT(quantParam != nullptr);
@@ -163,11 +166,19 @@ template <typename T>
STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant_type, STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant_type,
std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min, std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min,
const size_t &bit_num, const bool &k_means, std::vector<T> *quant_datas, const size_t &bit_num, const bool &k_means, std::vector<T> *quant_datas,
std::vector<float> *dequant_datas) {
std::vector<float> *dequant_datas, bool channel_at_first = true) {
auto dims = weight->tensor_shape(); auto dims = weight->tensor_shape();
size_t elem_count = weight->tensor_shape_size(); size_t elem_count = weight->tensor_shape_size();
auto *raw_datas = static_cast<float *>(weight->tensor_addr()); auto *raw_datas = static_cast<float *>(weight->tensor_addr());
auto channels = dims[0]; auto channels = dims[0];
if (!channel_at_first) {
if (dims.size() != 2) {
MS_LOG(ERROR) << "unexpected dims size: " << dims.size();
channel_at_first = true;
} else {
channels = dims[1];
}
}
if (channels == 0) { if (channels == 0) {
MS_LOG(ERROR) << "channels is zero"; MS_LOG(ERROR) << "channels is zero";
return RET_ERROR; return RET_ERROR;
@@ -181,16 +192,7 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant
for (int i = 0; i < channels; i++) { for (int i = 0; i < channels; i++) {
float min = FLT_MAX; float min = FLT_MAX;
float max = -FLT_MAX; float max = -FLT_MAX;
// find min and max
for (size_t j = 0; j < one_filter_size; j++) {
auto index = j + i * one_filter_size;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
min = std::min(min, raw_datas[index]);
max = std::max(max, raw_datas[index]);
}
GetMaxMinPerchannel(channels, one_filter_size, i, elem_count, raw_datas, channel_at_first, &max, &min);
schema::QuantParamT quant_param; schema::QuantParamT quant_param;
STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num); STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num);
if (status != RET_OK) { if (status != RET_OK) {
@@ -202,10 +204,10 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant
double average_raw = 0; double average_raw = 0;
for (uint32_t j = 0; j < one_filter_size; j++) { for (uint32_t j = 0; j < one_filter_size; j++) {
auto index = j + i * one_filter_size; auto index = j + i * one_filter_size;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
if (!channel_at_first) {
index = j * channels + i;
} }
MS_ASSERT(index < elem_count);
float raw_data = raw_datas[index]; float raw_data = raw_datas[index];
auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min); auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
(*quant_datas)[index] = quant_data; (*quant_datas)[index] = quant_data;
@@ -226,10 +228,10 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant
double variance_raw = 0; double variance_raw = 0;
for (uint32_t j = 0; j < one_filter_size; j++) { for (uint32_t j = 0; j < one_filter_size; j++) {
auto index = j + i * one_filter_size; auto index = j + i * one_filter_size;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
if (!channel_at_first) {
index = j * channels + i;
} }
MS_ASSERT(index < elem_count);
variance_dequant += std::pow((*dequant_datas)[index] - average_dequant, 2); variance_dequant += std::pow((*dequant_datas)[index] - average_dequant, 2);
variance_raw += std::pow(raw_datas[index] - average_raw, 2); variance_raw += std::pow(raw_datas[index] - average_raw, 2);
} }
@@ -339,20 +341,26 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit


std::vector<schema::QuantParamT> quant_params; std::vector<schema::QuantParamT> quant_params;
size_t elem_count = weight->tensor_shape_size(); size_t elem_count = weight->tensor_shape_size();
auto *raw_datas = static_cast<float *>(weight->tensor_addr());
if (raw_datas == nullptr) {
auto *raw_data = static_cast<float *>(weight->tensor_addr());
if (raw_data == nullptr) {
MS_LOG(ERROR) << "rawDatas is nullptr"; MS_LOG(ERROR) << "rawDatas is nullptr";
return RET_ERROR; return RET_ERROR;
} }


std::vector<T> quant_datas(elem_count);
std::vector<T> quant_data(elem_count);
std::vector<float> dequant_datas(elem_count); std::vector<float> dequant_datas(elem_count);
int ret = RET_OK; int ret = RET_OK;
if (per_channel) { if (per_channel) {
// notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC
bool channel_at_first = true;
auto op_type = (schema::PrimitiveType)primitive_c->Type();
if (op_type == schema::PrimitiveType_MatMul && weight->tensor_shape().size() == 2) {
auto matmul_op = primitive_c->primitiveT()->value.AsMatMul();
MS_ASSERT(matmul_op != nullptr);
channel_at_first = !(index == 1 && !matmul_op->transposeB);
}
// channel at first // channel at first
ret = DoPerChannelQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_datas,
&dequant_datas);
ret = DoPerChannelQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_data,
&dequant_datas, channel_at_first);
if (ret == RET_CONTINUE) { if (ret == RET_CONTINUE) {
return ret; return ret;
} else if (ret != RET_OK) { } else if (ret != RET_OK) {
@@ -360,7 +368,7 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit
return ret; return ret;
} }
} else { } else {
ret = DoPerLayerQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_datas);
ret = DoPerLayerQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_data);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Do per layer quant failed."; MS_LOG(ERROR) << "Do per layer quant failed.";
return ret; return ret;
@@ -376,7 +384,7 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit
} }
#else #else
// do bit pack // do bit pack
ret = DoBitPack(weight, bit_num, quant_datas);
ret = DoBitPack(weight, bit_num, quant_data);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Do bit pack failed."; MS_LOG(ERROR) << "Do bit pack failed.";
return ret; return ret;


+ 6
- 4
mindspore/lite/tools/converter/quantizer/weight_quantizer.cc View File

@@ -127,6 +127,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
auto already_quant = false; auto already_quant = false;
ParamValueLitePtr param_value = nullptr; ParamValueLitePtr param_value = nullptr;
ParameterPtr param_node = nullptr; ParameterPtr param_node = nullptr;
int index = 0;
for (size_t i = 1; i < cnode->size(); i++) { for (size_t i = 1; i < cnode->size(); i++) {
auto inputNode = cnode->input(i); auto inputNode = cnode->input(i);
if (inputNode->isa<Parameter>()) { if (inputNode->isa<Parameter>()) {
@@ -146,6 +147,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
param_value = nullptr; param_value = nullptr;
continue; continue;
} else { } else {
index = i;
break; break;
} }
} }
@@ -169,11 +171,11 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {


auto status = RET_ERROR; auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) { if (type_id_ == kNumberTypeInt8) {
status =
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
true, index - 1);
} else if (type_id_ == kNumberTypeInt16) { } else if (type_id_ == kNumberTypeInt16) {
status =
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
true, index - 1);
} }
if (status == RET_CONTINUE) { if (status == RET_CONTINUE) {
return RET_OK; return RET_OK;


Loading…
Cancel
Save