diff --git a/mindspore/lite/src/dequant.cc b/mindspore/lite/src/dequant.cc index 1c37aa340d..1591281a05 100644 --- a/mindspore/lite/src/dequant.cc +++ b/mindspore/lite/src/dequant.cc @@ -18,9 +18,10 @@ #include #include "src/dequant.h" #include "src/huffman_decode.h" +#include "src/ops/matmul.h" namespace mindspore::lite { -float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { +float *DequantUtil::DequantWeight(lite::Tensor *input_tensor, bool channel_first) { MS_ASSERT(input_tensor != nullptr); if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) { MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type(); @@ -31,9 +32,9 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { return nullptr; } if (input_tensor->data_type() == kNumberTypeInt16) { - return DequantData(input_tensor); + return DequantData(input_tensor, channel_first); } else { - return DequantData(input_tensor); + return DequantData(input_tensor, channel_first); } } @@ -65,19 +66,35 @@ int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_in return RET_OK; } -std::map> DequantUtil::DequantTensor(const std::vector &in_tensors, +std::map> DequantUtil::DequantTensor(const mindspore::lite::PrimitiveC *primitive, + const std::vector &in_tensors, TypeId data_type, bool need_restore) { std::map> tensor_origin_data; if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) { + auto input_i = 0; for (auto weight_tensor : in_tensors) { MS_ASSERT(weight_tensor != nullptr); + input_i++; + auto channel_first = true; + if ((schema::PrimitiveType)primitive->Type() == schema::PrimitiveType_MatMul && + weight_tensor->shape().size() == 2) { + auto param = reinterpret_cast(const_cast(primitive)); + if (input_i == 1) { + channel_first = !param->GetTransposeA(); + } else if (input_i == 2) { + channel_first = param->GetTransposeB(); + } else { + MS_LOG(WARNING) << "unexpected input_i"; + } + } + auto *restore_data = weight_tensor->data_c(); auto restore_type = weight_tensor->data_type(); bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr && (restore_type == kNumberTypeInt8 || restore_type == kNumberTypeInt16); if (dequant_flag) { - auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor); + auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor, channel_first); if (dequant_weight == nullptr) { MS_LOG(ERROR) << "dequant data is nullptr."; return tensor_origin_data; diff --git a/mindspore/lite/src/dequant.h b/mindspore/lite/src/dequant.h index 5191a769bd..a45aa03620 100644 --- a/mindspore/lite/src/dequant.h +++ b/mindspore/lite/src/dequant.h @@ -29,17 +29,18 @@ namespace mindspore::lite { class DequantUtil { public: - static float *DequantWeight(lite::Tensor *input_tensor); + static float *DequantWeight(lite::Tensor *input_tensor, bool); static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); - static std::map> DequantTensor(const std::vector &in_tensors, + static std::map> DequantTensor(const mindspore::lite::PrimitiveC *primitive, + const std::vector &in_tensors, TypeId data_type, bool need_restore = true); static void RestoreTensorData(const std::map> &tensor_origin_data_map); template - static DT *DequantData(lite::Tensor *input_tensor) { + static DT *DequantData(lite::Tensor *input_tensor, bool channel_first = true) { const auto *quant_datas = static_cast(input_tensor->MutableData()); if (quant_datas == nullptr) { MS_LOG(ERROR) << "Get quant tensor failed."; @@ -65,6 +66,13 @@ class DequantUtil { } } else if (input_tensor->quant_params().size() != kPerTensor) { auto channels = static_cast(input_tensor->Batch()); + if (!channel_first) { + if (input_tensor->shape().size() != 2) { + MS_LOG(ERROR) << "unexpected shape size: " << input_tensor->shape().size(); + return nullptr; + } + channels = input_tensor->shape()[1]; + } if (input_tensor->quant_params().size() != channels) { MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->quant_params().size() << channels; free(dequant_datas); @@ -83,8 +91,12 @@ class DequantUtil { var_corr = 1; } for (size_t j = 0; j < per_channel_size; j++) { - auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale; - dequant_datas[per_channel_size * i + j] = static_cast
(dequant_data * var_corr + mean_corr); + auto index = per_channel_size * i + j; + if (!channel_first) { + index = channels * j + i; + } + auto dequant_data = (quant_datas[index] - zero_point) * scale; + dequant_datas[index] = static_cast
(dequant_data * var_corr + mean_corr); } } } else { diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index af7e445f54..6d0d979873 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -223,7 +223,8 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector &in if (mindspore::lite::IsSupportFloat16() && ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) { kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type}; - auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type, need_restore); + auto tensor_origin_data_map = + DequantUtil::DequantTensor(primitive, in_tensors, fp16_cpu_desc.data_type, need_restore); auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc); DequantUtil::RestoreTensorData(tensor_origin_data_map); @@ -237,7 +238,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector &in MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; desc.data_type = kNumberTypeFloat32; } - auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type, need_restore); + auto tensor_origin_data_map = DequantUtil::DequantTensor(primitive, in_tensors, desc.data_type, need_restore); auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); DequantUtil::RestoreTensorData(tensor_origin_data_map); if (kernel != nullptr) { diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 1fa132aa88..2e5dae9871 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -358,15 +358,15 @@ static bool SearchUpperBound(const std::vector &data, const size_t &index return true; } -static float CalPercentile(const std::vector &datas, const int &outlier_percent) { - const int size = datas.size(); +static float CalPercentile(const std::vector &data, const int &outlier_percent) { + const int size = data.size(); float val = outlier_percent / 100.0 * size; int index = std::ceil(val); float result; if (index - val > 0) { - result = datas.at(index - 1); + result = data.at(index - 1); } else { - result = (datas.at(index - 1) + datas.at(index)) / 2; + result = (data.at(index - 1) + data.at(index)) / 2; } return result; } @@ -522,11 +522,78 @@ std::vector> DataToVectors(const string &str) { return result; } -STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) { - if (post_quant_config == nullptr) { - MS_LOG(ERROR) << "post_quant_config is null."; - return RET_PARAM_INVALID; +void ParseInputShape(PostQuantConfig *post_quant_config, std::string raw_shape) { + MS_ASSERT(post_quant_config != nullptr); + auto ind = raw_shape.find('/'); + while (ind != std::string::npos) { + auto shape = raw_shape.substr(0, ind); + Trim(&shape); + post_quant_config->input_shapes.push_back(DataToVectors(shape)); + raw_shape = raw_shape.substr(ind + 1); + Trim(&raw_shape); + ind = raw_shape.find('/'); + } + if (!raw_shape.empty()) { + post_quant_config->input_shapes.push_back(DataToVectors(raw_shape)); + } +} + +void ParseImagePath(PostQuantConfig *post_quant_config, std::string raw_image_paths) { + MS_ASSERT(post_quant_config != nullptr); + auto ind = raw_image_paths.find(','); + while (ind != std::string::npos) { + auto image_path = raw_image_paths.substr(0, ind); + Trim(&image_path); + post_quant_config->image_paths.push_back(image_path); + raw_image_paths = raw_image_paths.substr(ind + 1); + Trim(&raw_image_paths); + ind = raw_image_paths.find(','); + } + post_quant_config->image_paths.push_back(raw_image_paths); +} + +void ParseBatchCount(PostQuantConfig *post_quant_config, std::string value) { + MS_ASSERT(post_quant_config != nullptr); + post_quant_config->batch_count = std::stoul(value); +} + +void ParseThreadNum(PostQuantConfig *post_quant_config, std::string value) { + MS_ASSERT(post_quant_config != nullptr); + post_quant_config->thread_num = std::stoul(value); +} + +void ParseMethodX(PostQuantConfig *post_quant_config, const std::string &value) { + MS_ASSERT(post_quant_config != nullptr); + if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) { + MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value."; + } else { + post_quant_config->method_x = value; } +} + +void ParseMixed(PostQuantConfig *post_quant_config, std::string value) { + MS_ASSERT(post_quant_config != nullptr); + std::for_each(value.begin(), value.end(), ::tolower); + if (value == "true") { + post_quant_config->mixed = true; + } +} + +void ParseMeanErrorThreshold(PostQuantConfig *post_quant_config, std::string value) { + MS_ASSERT(post_quant_config != nullptr); + post_quant_config->mean_error_threshold = std::stof(value); +} + +void ParseBiasCorrection(PostQuantConfig *post_quant_config, std::string value) { + MS_ASSERT(post_quant_config != nullptr); + std::for_each(value.begin(), value.end(), ::tolower); + if (value == "true") { + post_quant_config->bias_correction = true; + } +} + +STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) { + MS_ASSERT(post_quant_config != nullptr); if (config_file.empty() || config_file.length() > PATH_MAX) { MS_LOG(ERROR) << "invalid config path!"; @@ -552,6 +619,26 @@ STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_conf MS_LOG(ERROR) << "config file open failed: " << config_file; return RET_PARAM_INVALID; } + + std::string INPUT_SHAPES = "input_shapes"; + std::string IMAGE_PATH = "image_path"; + std::string BATCH_COUNT = "batch_count"; + std::string THREAD_NUM = "thread_num"; + std::string METHOD_X = "method_x"; + std::string MIXED = "mixed"; + std::string MEAN_ERROR_THRESHOLD = "mean_error_threshold"; + std::string BIAS_CORRECTION = "bias_correction"; + + std::map> value_parser; + value_parser[INPUT_SHAPES] = ParseInputShape; + value_parser[IMAGE_PATH] = ParseImagePath; + value_parser[BATCH_COUNT] = ParseBatchCount; + value_parser[THREAD_NUM] = ParseThreadNum; + value_parser[METHOD_X] = ParseMethodX; + value_parser[MIXED] = ParseMixed; + value_parser[MEAN_ERROR_THRESHOLD] = ParseMeanErrorThreshold; + value_parser[BIAS_CORRECTION] = ParseBiasCorrection; + std::string line; while (std::getline(fs, line)) { Trim(&line); @@ -567,54 +654,9 @@ STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_conf auto value = line.substr(index + 1); Trim(&key); Trim(&value); - if (key == "image_path") { - auto &raw_image_paths = value; - auto ind = raw_image_paths.find(','); - while (ind != std::string::npos) { - auto image_path = raw_image_paths.substr(0, ind); - Trim(&image_path); - post_quant_config->image_paths.push_back(image_path); - raw_image_paths = raw_image_paths.substr(ind + 1); - Trim(&raw_image_paths); - ind = raw_image_paths.find(','); - } - post_quant_config->image_paths.push_back(raw_image_paths); - } else if (key == "batch_count") { - post_quant_config->batch_count = std::stoul(value); - } else if (key == "thread_num") { - post_quant_config->thread_num = std::stoul(value); - } else if (key == "method_x") { - if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) { - MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value."; - } else { - post_quant_config->method_x = value; - } - } else if (key == "bias_correction") { - std::for_each(value.begin(), value.end(), ::tolower); - if (value == "true") { - post_quant_config->bias_correction = true; - } - } else if (key == "mixed") { - std::for_each(value.begin(), value.end(), ::tolower); - if (value == "true") { - post_quant_config->mixed = true; - } - } else if (key == "mean_error_threshold") { - post_quant_config->mean_error_threshold = std::stof(value); - } else if (key == "input_shapes") { - auto &raw_shape = value; - auto ind = raw_shape.find('/'); - while (ind != std::string::npos) { - auto shape = raw_shape.substr(0, ind); - Trim(&shape); - post_quant_config->input_shapes.push_back(DataToVectors(shape)); - raw_shape = raw_shape.substr(ind + 1); - Trim(&raw_shape); - ind = raw_shape.find('/'); - } - if (!raw_shape.empty()) { - post_quant_config->input_shapes.push_back(DataToVectors(raw_shape)); - } + auto it = value_parser.find(key); + if (it != value_parser.end()) { + it->second(post_quant_config, value); } else { MS_LOG(WARNING) << "unsupported parameter: " << key; } @@ -881,4 +923,24 @@ STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int return RET_OK; } +void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, float *raw_datas, + bool channel_at_first, float *desired_max, float *desired_min) { + float min = FLT_MAX; + float max = -FLT_MAX; + // find min and max + for (int j = 0; j < one_filter_size; j++) { + auto index = j + i * one_filter_size; + if (!channel_at_first) { + index = j * channels + i; + } + if (index >= elem_count) { + MS_LOG(ERROR) << "over flow!"; + } + min = std::min(min, raw_datas[index]); + max = std::max(max, raw_datas[index]); + } + *desired_max = max; + *desired_min = min; +} + } // namespace mindspore::lite::quant diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index aec6f7cd26..21f7417598 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -107,6 +107,9 @@ std::vector KMeans(float *data, size_t elem_count, size_t k, size_t epoc STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int new_size); +void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, float *raw_datas, + bool channel_at_first, float *desired_max, float *desired_min); + template T QuantizeData(const float originData, const schema::QuantParamT *quantParam) { MS_ASSERT(quantParam != nullptr); @@ -163,11 +166,19 @@ template STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant_type, std::vector *quant_params, const int &quant_max, const int &quant_min, const size_t &bit_num, const bool &k_means, std::vector *quant_datas, - std::vector *dequant_datas) { + std::vector *dequant_datas, bool channel_at_first = true) { auto dims = weight->tensor_shape(); size_t elem_count = weight->tensor_shape_size(); auto *raw_datas = static_cast(weight->tensor_addr()); auto channels = dims[0]; + if (!channel_at_first) { + if (dims.size() != 2) { + MS_LOG(ERROR) << "unexpected dims size: " << dims.size(); + channel_at_first = true; + } else { + channels = dims[1]; + } + } if (channels == 0) { MS_LOG(ERROR) << "channels is zero"; return RET_ERROR; @@ -181,16 +192,7 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant for (int i = 0; i < channels; i++) { float min = FLT_MAX; float max = -FLT_MAX; - // find min and max - for (size_t j = 0; j < one_filter_size; j++) { - auto index = j + i * one_filter_size; - if (index >= elem_count) { - MS_LOG(ERROR) << "over flow!"; - return RET_ERROR; - } - min = std::min(min, raw_datas[index]); - max = std::max(max, raw_datas[index]); - } + GetMaxMinPerchannel(channels, one_filter_size, i, elem_count, raw_datas, channel_at_first, &max, &min); schema::QuantParamT quant_param; STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num); if (status != RET_OK) { @@ -202,10 +204,10 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant double average_raw = 0; for (uint32_t j = 0; j < one_filter_size; j++) { auto index = j + i * one_filter_size; - if (index >= elem_count) { - MS_LOG(ERROR) << "over flow!"; - return RET_ERROR; + if (!channel_at_first) { + index = j * channels + i; } + MS_ASSERT(index < elem_count); float raw_data = raw_datas[index]; auto quant_data = QuantizeData(raw_data, quant_param, quant_max, quant_min); (*quant_datas)[index] = quant_data; @@ -226,10 +228,10 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant double variance_raw = 0; for (uint32_t j = 0; j < one_filter_size; j++) { auto index = j + i * one_filter_size; - if (index >= elem_count) { - MS_LOG(ERROR) << "over flow!"; - return RET_ERROR; + if (!channel_at_first) { + index = j * channels + i; } + MS_ASSERT(index < elem_count); variance_dequant += std::pow((*dequant_datas)[index] - average_dequant, 2); variance_raw += std::pow(raw_datas[index] - average_raw, 2); } @@ -339,20 +341,26 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr quant_params; size_t elem_count = weight->tensor_shape_size(); - auto *raw_datas = static_cast(weight->tensor_addr()); - if (raw_datas == nullptr) { + auto *raw_data = static_cast(weight->tensor_addr()); + if (raw_data == nullptr) { MS_LOG(ERROR) << "rawDatas is nullptr"; return RET_ERROR; } - std::vector quant_datas(elem_count); + std::vector quant_data(elem_count); std::vector dequant_datas(elem_count); int ret = RET_OK; if (per_channel) { - // notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC + bool channel_at_first = true; + auto op_type = (schema::PrimitiveType)primitive_c->Type(); + if (op_type == schema::PrimitiveType_MatMul && weight->tensor_shape().size() == 2) { + auto matmul_op = primitive_c->primitiveT()->value.AsMatMul(); + MS_ASSERT(matmul_op != nullptr); + channel_at_first = !(index == 1 && !matmul_op->transposeB); + } // channel at first - ret = DoPerChannelQuant(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_datas, - &dequant_datas); + ret = DoPerChannelQuant(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_data, + &dequant_datas, channel_at_first); if (ret == RET_CONTINUE) { return ret; } else if (ret != RET_OK) { @@ -360,7 +368,7 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_datas); + ret = DoPerLayerQuant(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_data); if (ret != RET_OK) { MS_LOG(ERROR) << "Do per layer quant failed."; return ret; @@ -376,7 +384,7 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptrsize(); i++) { auto inputNode = cnode->input(i); if (inputNode->isa()) { @@ -146,6 +147,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { param_value = nullptr; continue; } else { + index = i; break; } } @@ -169,11 +171,11 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { auto status = RET_ERROR; if (type_id_ == kNumberTypeInt8) { - status = - QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); + status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, + true, index - 1); } else if (type_id_ == kNumberTypeInt16) { - status = - QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); + status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, + true, index - 1); } if (status == RET_CONTINUE) { return RET_OK;