
static check

tags/v1.1.0
jianghui58 committed 5 years ago
commit b22901d84c
12 changed files with 180 additions and 106 deletions
  1. +2   -5   mindspore/lite/tools/converter/quantizer/bitpacking.h
  2. +2   -8   mindspore/lite/tools/converter/quantizer/calc_quant_param.cc
  3. +2   -5   mindspore/lite/tools/converter/quantizer/calc_quant_param.h
  4. +111 -37  mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
  5. +3   -7   mindspore/lite/tools/converter/quantizer/post_training_quantizer.h
  6. +1   -3   mindspore/lite/tools/converter/quantizer/quant_cast.cc
  7. +1   -4   mindspore/lite/tools/converter/quantizer/quant_cast.h
  8. +44  -13  mindspore/lite/tools/converter/quantizer/quantize_util.cc
  9. +7   -9   mindspore/lite/tools/converter/quantizer/quantize_util.h
  10. +0  -1   mindspore/lite/tools/converter/quantizer/quantizer.h
  11. +5  -8   mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
  12. +2  -6   mindspore/lite/tools/converter/quantizer/weight_quantizer.h

+2 -5  mindspore/lite/tools/converter/quantizer/bitpacking.h

@@ -22,8 +22,7 @@
#include <vector>
#include <cassert>

namespace mindspore {
namespace lite {
namespace mindspore::lite {
class BitPack {
public:
~BitPack() = default;
@@ -68,7 +67,5 @@ class BitPack {
}
}
};
} // namespace lite
} // namespace mindspore

} // namespace mindspore::lite
#endif
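
The recurring header change above (and in the files that follow) is the switch from nested namespace blocks to the C++17 nested namespace definition. A minimal, self-contained sketch of the two equivalent forms, using placeholder constants rather than the real BitPack class:

// Pre-C++17: one block (and one closing brace) per namespace level.
namespace mindspore {
namespace lite {
constexpr int kExampleBitWidth = 8;  // placeholder entity for illustration
}  // namespace lite
}  // namespace mindspore

// C++17 nested namespace definition: declares the same mindspore::lite scope.
namespace mindspore::lite {
constexpr int kExampleBitWidthPacked = 16;  // placeholder entity for illustration
}  // namespace mindspore::lite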

+2 -8  mindspore/lite/tools/converter/quantizer/calc_quant_param.cc

@@ -80,9 +80,6 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
inputParamDone++;
continue;
}
MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i));

MS_ASSERT(tensor != nullptr);
if (!tensor->data.empty() && !IsContain(graph->inputIndex, node.inputIndex.at(i))) {
auto status = ComputeConstQuantParam((*tensor), quantParam.get());
if (status != RET_OK) {
@@ -104,10 +101,7 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
outputParamDone++;
continue;
}

if (!tensor->data.empty()) {
MS_ASSERT(false);
}
MS_ASSERT(tensor->data.empty());
}
return RET_OK;
}
@@ -487,7 +481,7 @@ class CalcActivation : public QuantParamCalcer {
}
}
};
QuantParamCalcRegister::~QuantParamCalcRegister() {}
QuantParamCalcRegister::~QuantParamCalcRegister() = default;

QuantParamCalcRegister::QuantParamCalcRegister() {
bool hasError = false;


+2 -5  mindspore/lite/tools/converter/quantizer/calc_quant_param.h

@@ -22,8 +22,7 @@
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace lite {
namespace mindspore::lite {
static constexpr int CONVLUTION_INPUT_NUM = 3;

class QuantParamCalcer {
@@ -77,7 +76,5 @@ class QuantParamCalcRegister {
QuantParamCalcRegister();
std::unordered_map<schema::PrimitiveType, std::shared_ptr<QuantParamCalcer>> _registerMap;
};
} // namespace lite
} // namespace mindspore

} // namespace mindspore::lite
#endif

+111 -37  mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc

@@ -43,9 +43,7 @@
using std::string;
using std::vector;

namespace mindspore {
namespace lite {
namespace quant {
namespace mindspore::lite::quant {
STATUS DivergInfo::RecordMaxValue(const std::vector<float> &datas) {
for (float data : datas) {
max = std::max(data, max);
@@ -71,6 +69,7 @@ STATUS DivergInfo::RecordMaxValueArray(const std::vector<float> &datas) {

void DivergInfo::UpdateInterval() {
auto max_value = std::max(fabs(this->max), fabs(this->min));
MS_ASSERT(bin_num != 0);
this->interval = max_value / static_cast<float>(bin_num);
}

@@ -79,6 +78,10 @@ STATUS DivergInfo::UpdateHistogram(const std::vector<float> &data) {
if (value == 0) {
continue;
}
if (this->interval == 0) {
MS_LOG(ERROR) << "divisor 'interval' cannot be 0.";
return RET_ERROR;
}
int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval), bin_num - 1);
this->histogram[bin_index]++;
}
@@ -117,9 +120,7 @@ STATUS DivergInfo::ComputeThreshold() {
std::vector<float> expanded_histogram(i, 0);
reference_histogram[i - 1] += after_threshold_sum;
after_threshold_sum -= this->histogram[i];

const float bin_interval = static_cast<float>(i) / static_cast<float>(quant_bint_nums);

// merge i bins to target bins
for (int j = 0; j < quant_bint_nums; ++j) {
const float start = j * bin_interval;
@@ -224,7 +225,7 @@ std::pair<CNodePtr, float> DivergInfo::GetScale() {
MS_ASSERT(quant_max - quant_min != 0);
float scale = (max_value - min_value) / (quant_max - quant_min);
this->scale_tmp = scale;
MS_ASSERT(scale != 0);
MS_ASSERT(fabs(scale) <= 0.0f);
return std::make_pair(this->cnode, scale);
}

@@ -237,8 +238,8 @@ std::pair<CNodePtr, int32_t> DivergInfo::GetZeropoint() {
} else {
MS_LOG(WARNING) << "unexpectd quant range, quant_min: " << quant_min << " quant_max: " << quant_max;
}

if (this->method_x == kMethodOutlier) {
MS_ASSERT(fabs(scale_tmp) <= 0.0f);
zero_point = std::round(quant_max - percent_result.second / scale_tmp);
}
return std::make_pair(this->cnode, zero_point);
@@ -246,6 +247,7 @@ std::pair<CNodePtr, int32_t> DivergInfo::GetZeropoint() {

std::unordered_map<CNodePtr, float> Calibrator::GetScale(
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
MS_ASSERT(diverg_info != nullptr);
std::unordered_map<CNodePtr, float> result;
for (auto &iter : *diverg_info) {
DivergInfo *info = iter.second.get();
@@ -257,6 +259,7 @@ std::unordered_map<CNodePtr, float> Calibrator::GetScale(

std::unordered_map<CNodePtr, int32_t> Calibrator::GetZeropoint(
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
MS_ASSERT(diverg_info != nullptr);
std::unordered_map<CNodePtr, int32_t> result;
for (auto &iter : *diverg_info) {
DivergInfo *info = iter.second.get();
@@ -268,6 +271,7 @@ std::unordered_map<CNodePtr, int32_t> Calibrator::GetZeropoint(

std::map<CNodePtr, MaxMin> Calibrator::GetMinMax(
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
MS_ASSERT(diverg_info != nullptr);
std::map<CNodePtr, MaxMin> result;
for (auto &iter : *diverg_info) {
DivergInfo *info = iter.second.get();
@@ -314,7 +318,6 @@ STATUS Calibrator::ComputeThreshold() {
auto &input_infos = kv.second;
for (size_t i = 0; i < input_infos.size(); i++) {
auto cnode = input_infos[i]->cnode;

bool already_computed = false;
auto input = cnode->input(i + 1);
if (input->isa<mindspore::CNode>()) {
@@ -346,6 +349,7 @@ STATUS Calibrator::ComputeThreshold() {

STATUS Calibrator::UpdateDivergInverval(
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info) {
MS_ASSERT(diverg_info != nullptr);
for (auto &kv : *diverg_info) {
for (auto &info : kv.second) {
info->UpdateInterval();
@@ -355,6 +359,7 @@ STATUS Calibrator::UpdateDivergInverval(
}

STATUS Calibrator::UpdateDataFrequency(const vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info) {
MS_ASSERT(diverg_info != nullptr);
diverg_info->UpdateHistogram(data);
return RET_OK;
}
@@ -391,15 +396,30 @@ void Calibrator::AddImage(const string &file, size_t index) {
}
}

STATUS Calibrator::GenerateInputData(int input_index, int image_index, mindspore::tensor::MSTensor *tensor) const {
STATUS Calibrator::GenerateInputData(size_t input_index, size_t image_index,
mindspore::tensor::MSTensor *tensor) const {
MS_ASSERT(tensor != nullptr);
if (input_index >= images_.size()) {
MS_LOG(ERROR) << "images_ size: " << images_.size() << " but input_index: " << input_index;
return RET_ERROR;
}
if (image_index >= images_[input_index].size()) {
MS_LOG(ERROR) << "images_[input_index] size: " << images_[input_index].size()
<< " but image_index: " << image_index;
return RET_ERROR;
}
string path = images_[input_index][image_index];
MS_LOG(INFO) << "read image: " << path;
size_t size;
char *bin_buf = ReadFile(path.c_str(), &size);
if (bin_buf == nullptr) {
MS_LOG(ERROR) << "ReadFile return nullptr";
return RET_NULL_PTR;
}
auto data = tensor->MutableData();
if (data == nullptr) {
MS_LOG(ERROR) << "Get tensor MutableData return nullptr";
return RET_ERROR;
return RET_NULL_PTR;
}
if (size != tensor->Size()) {
MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size
@@ -543,6 +563,7 @@ Calibrator::Calibrator(string path, size_t bit_num, int quant_max, int quant_min
PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type,
bool per_channel)
: Quantizer(std::move(graph)) {
MS_ASSERT(graph != nullptr);
this->per_channel_ = per_channel;
this->bit_num = bit_num;
this->target_type_ = target_type;
@@ -564,6 +585,8 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in

STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min,
const std::shared_ptr<PrimitiveC> &lite_primitive) const {
MS_ASSERT(max_min != nullptr);
MS_ASSERT(lite_primitive != nullptr);
schema::QuantParamT quant_param;
quant_param.scale = scale;
quant_param.zeroPoint = zeropoint;
@@ -579,6 +602,8 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, stru

STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min,
const std::shared_ptr<PrimitiveC> &lite_primitive) const {
MS_ASSERT(max_min != nullptr);
MS_ASSERT(lite_primitive != nullptr);
schema::QuantParamT quant_param;
quant_param.scale = scale;
quant_param.zeroPoint = zeropoint;
@@ -594,6 +619,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct

STATUS PostTrainingQuantizer::DoWeightQuant(const AnfNodePtr &weight, std::shared_ptr<PrimitiveC> primitive_c,
bool perchanel) const {
MS_ASSERT(weight != nullptr);
MS_ASSERT(lite_primitive != nullptr);
// perlayer
if (!weight->isa<Parameter>()) {
MS_LOG(ERROR) << "not a parameter";
@@ -602,12 +629,12 @@ STATUS PostTrainingQuantizer::DoWeightQuant(const AnfNodePtr &weight, std::share
auto parameter = std::dynamic_pointer_cast<Parameter>(weight);
if (parameter == nullptr) {
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not cast to Parameter";
return RET_ERROR;
return RET_NULL_PTR;
}
ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
if (paramValue == nullptr) {
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
return RET_ERROR;
return RET_NULL_PTR;
}
auto status = QuantFilter<int8_t>(paramValue, std::move(primitive_c), QuantType_PostTraining, quant_max, quant_min,
bit_num, perchanel);
@@ -619,13 +646,17 @@ STATUS PostTrainingQuantizer::DoWeightQuant(const AnfNodePtr &weight, std::share
auto abstractBase = parameter->abstract();
if (abstractBase == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << parameter->name();
return RET_ERROR;
return RET_NULL_PTR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << parameter->name();
return RET_ERROR;
}
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
if (abstractTensor == nullptr || abstractTensor->element() == nullptr) {
MS_LOG(ERROR) << "abstractTensor is nullptr, " << parameter->name();
return RET_NULL_PTR;
}
abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt8));
return RET_OK;
}
@@ -635,11 +666,11 @@ STATUS PostTrainingQuantizer::DoBiasQuant(const AnfNodePtr &bias, const std::sha
MS_LOG(ERROR) << "null pointer!";
return RET_NULL_PTR;
}

auto bias_parameter_ptr = std::dynamic_pointer_cast<Parameter>(bias);
MS_ASSERT(bias_parameter_ptr != nullptr);
auto bias_default_param = bias_parameter_ptr->default_param();
auto bias_param = std::dynamic_pointer_cast<ParamValueLite>(bias_default_param);
MS_ASSERT(bias_parameter_ptr != nullptr);
auto active_weight_quant_params = primitive_c->GetInputQuantParams();
if (active_weight_quant_params.size() != 2) {
MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size();
@@ -694,6 +725,10 @@ STATUS PostTrainingQuantizer::DoBiasQuant(const AnfNodePtr &bias, const std::sha
if (bias_scales.size() == shape_size) {
for (size_t i = 0; i < shape_size; i++) {
bias_scale_tmp = bias_scales[i];
if (fabs(bias_scale_tmp) <= 0.0f) {
MS_LOG(ERROR) << "divisor 'bias_scale_tmp' cannot be 0.";
return RET_ERROR;
}
if (std::abs(raw_datas[i] / bias_scale_tmp) >= quanted_bias_abs_limit) {
MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << active_weight_quant_params[1][i].scale
<< " is too small, need to update";
@@ -719,6 +754,10 @@ STATUS PostTrainingQuantizer::DoBiasQuant(const AnfNodePtr &bias, const std::sha
max_raw_data = std::abs(raw_datas[i]);
}
}
if (fabs(bias_scale_tmp) <= 0.0f) {
MS_LOG(ERROR) << "divisor 'bias_scale_tmp' cannot be 0.";
return RET_ERROR;
}
if (std::abs(max_raw_data / bias_scale_tmp) >= quanted_bias_abs_limit) {
MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << active_weight_quant_params[1][0].scale
<< " is too small, need to update";
@@ -759,6 +798,10 @@ STATUS PostTrainingQuantizer::DoBiasQuant(const AnfNodePtr &bias, const std::sha
return RET_ERROR;
}
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
if (abstractTensor == nullptr || abstractTensor->element() == nullptr) {
MS_LOG(ERROR) << "abstractTensor is nullptr" << bias_parameter_ptr->name();
return RET_NULL_PTR;
}
abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt32));
return RET_OK;
}
@@ -770,16 +813,13 @@ STATUS PostTrainingQuantizer::QuantNode() {
auto cnodes = funcGraph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
auto op_name = cnode->fullname_with_scope();
if (this->calibrator_->GetInputDivergInfo()->find(op_name) == this->calibrator_->GetInputDivergInfo()->end()) {
MS_LOG(INFO) << op_name << " can not do quant";
continue;
}
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
continue;
}
if (inputs_diverg_info->find(op_name) == inputs_diverg_info->end()) {
MS_LOG(INFO) << op_name << " can not do quant";
primitive_c->SetQuantType(schema::QuantType_QUANT_NONE);
continue;
}
@@ -794,9 +834,10 @@ STATUS PostTrainingQuantizer::QuantNode() {
continue;
}
size_t index = GetValue<int>(index_value_node->value());

auto input_node = cnode->input(1);
MS_ASSERT(input_node != nullptr);
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
MS_ASSERT(input_cnode != nullptr);
auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
if (input_cnode_primitive_c == nullptr) {
MS_LOG(WARNING) << "input_cnode_primitive_c is null";
@@ -818,6 +859,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
op_type != PrimitiveType_FullConnection && op_type != PrimitiveType_LayerNorm) {
for (size_t i = 1; i < cnode->inputs().size(); i++) {
auto input_node = cnode->input(i);
MS_ASSERT(input_node != nullptr);
bool is_graph_input = false;
if (input_node->isa<Parameter>()) {
if (!input_node->cast<ParameterPtr>()->has_default()) {
@@ -866,6 +908,10 @@ STATUS PostTrainingQuantizer::QuantNode() {
return RET_ERROR;
}
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
if (abstractTensor == nullptr || abstractTensor->element() == nullptr) {
MS_LOG(ERROR) << "abstractTensor is nullptr, " << input_node->fullname_with_scope();
return RET_NULL_PTR;
}
if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) {
MS_LOG(DEBUG) << "this parameter do quant";
DoWeightQuant(input_node, primitive_c, false);
@@ -950,6 +996,10 @@ STATUS PostTrainingQuantizer::PreProcess() {
auto cnodes = funcGraph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
AnfNodePtr anf = std::dynamic_pointer_cast<AnfNode>(cnode);
if (anf == nullptr) {
MS_LOG(ERROR) << " cnode is null";
return RET_NULL_PTR;
}
if (strategy.CanOpPostQuantized(anf)) {
calibrator_->AddQuantizedOp(cnode);
}
@@ -965,11 +1015,13 @@ STATUS PostTrainingQuantizer::PreProcess() {

STATUS PostTrainingQuantizer::CheckFp32TensorVec(const std::string &node_name,
const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) const {
MS_ASSERT(tensor_vec != nullptr);
if (tensor_vec.empty()) {
MS_LOG(ERROR) << "node: " << node_name << " input tensors is 0";
return RET_ERROR;
}
auto *tensor = tensor_vec[0];
MS_ASSERT(tensor != nullptr);
if (tensor->data_type() != kNumberTypeFloat32) {
MS_LOG(WARNING) << "node: " << node_name << " will not quantize"
<< " tensor data_type: " << tensor->data_type();
@@ -1021,7 +1073,9 @@ STATUS PostTrainingQuantizer::DoInference() {
}
for (size_t i = 0; i < (*diverg_info_map)[callParam.node_name].size(); i++) {
auto tensor = beforeInputs[i];
MS_ASSERT(tensor != nullptr);
const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
MS_ASSERT(tensor_data != nullptr);
size_t elem_count = tensor->ElementsNum();
vector<float> data(tensor_data, tensor_data + elem_count);
this->calibrator_->RecordMaxValue(data, (*diverg_info_map)[callParam.node_name][i]);
@@ -1091,13 +1145,13 @@ STATUS PostTrainingQuantizer::Int8Inference() {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
auto tensor = beforeInputs[0];
MS_ASSERT(tensor != nullptr);
auto lite_tensor = dynamic_cast<mindspore::lite::Tensor *>(tensor);
MS_ASSERT(lite_tensor != nullptr);
if (tensor->data_type() != kNumberTypeInt8) {
MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
return false;
}

// do quantization: activation is always per layer quantized
std::vector<int8_t> quant_datas;
auto quant_params = lite_tensor->GetQuantParams();
@@ -1138,13 +1192,13 @@ STATUS PostTrainingQuantizer::Int8Inference() {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
auto tensor = afterOutputs[0];
MS_ASSERT(tensor != nullptr);
auto lite_tensor = dynamic_cast<mindspore::lite::Tensor *>(tensor);
MS_ASSERT(lite_tensor != nullptr);
if (tensor->data_type() != kNumberTypeInt8) {
MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
return false;
}

const int8_t *tensor_data = static_cast<int8_t *>(tensor->MutableData());
size_t elem_count = tensor->ElementsNum();
auto shapes = tensor->shape();
@@ -1180,6 +1234,10 @@ STATUS PostTrainingQuantizer::Int8Inference() {
auto float_data = scale * (tensor_data[index] - zp);
sum += float_data;
}
if (one_filter_size == 0) {
MS_LOG(ERROR) << "divisor 'one_filter_size' cannot be 0.";
return RET_ERROR;
}
sum = sum / one_filter_size;
dequant_op_output_ch_mean[i] = sum;
}
@@ -1231,6 +1289,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
return false;
}
auto tensor = beforeInputs[0];
MS_ASSERT(tensor != nullptr);
size_t elem_count = tensor->ElementsNum();
std::vector<float> fp32_op_input(elem_count);
auto ret =
@@ -1254,6 +1313,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
return false;
}
auto tensor = afterOutputs[0];
MS_ASSERT(tensor != nullptr);
const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
size_t elem_count = tensor->ElementsNum();
auto shapes = tensor->shape();
@@ -1279,6 +1339,10 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
}
sum += tensor_data[index];
}
if (one_filter_size == 0) {
MS_LOG(ERROR) << "divisor 'one_filter_size' cannot be 0.";
return false;
}
sum = sum / one_filter_size;
fp32_op_output_ch_mean[i] = sum;
}
@@ -1301,6 +1365,10 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "int8 inference failed!";
return RET_ERROR;
}
if (calibrator_->GetBatchNum() == 0) {
MS_LOG(ERROR) << "divisor 'batch_count' cannot be 0.";
return RET_ERROR;
}
for (auto &key_value : op_bias_diff_map) {
std::for_each(key_value.second.begin(), key_value.second.end(),
[this](float &data) { data = data / calibrator_->GetBatchNum(); });
@@ -1338,6 +1406,10 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {

for (int i = 0; i < bias_param->tensor_shape_size(); i++) {
auto scale = bias_quant_params[i].scale;
if (fabs(scale) <= 0.0f) {
MS_LOG(ERROR) << "divisor 'scale' cannot be 0.";
return RET_ERROR;
}
double after_correct = std::round(bias_diff[i] / scale) + bias_datas[i];
const constexpr int32_t corrected_bias_abs_limit = 0.6 * INT32_MAX;
if (after_correct > corrected_bias_abs_limit) {
@@ -1359,6 +1431,10 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
MS_LOG(INFO) << op_name << " add bias input";
// need to add bias input
auto parameter = func_graph->add_parameter();
if (parameter == nullptr) {
MS_LOG(ERROR) << "parameter is nullptr.";
return RET_NULL_PTR;
}
ShapeVector shape;
shape.push_back(bias_diff.size());
auto type_ptr = TypeIdToType(kNumberTypeFloat32);
@@ -1380,11 +1456,10 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "new char[] failed";
return RET_MEMORY_FAILED;
}
ret = ::memcpy_s(tensor_data, size * sizeof(char), bias_diff.data(), size * sizeof(char));
ret = memcpy_s(tensor_data, size * sizeof(char), bias_diff.data(), size * sizeof(char));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s error: " << ret;
free(tensor_data);
tensor_data = nullptr;
delete[] tensor_data;
return false;
}
param_value->set_tensor_addr(tensor_data);
@@ -1398,8 +1473,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
auto conv2d = primitive_c->GetPrimitiveT()->value.AsConv2D();
if (conv2d == nullptr) {
MS_LOG(ERROR) << "conv2d is null";
free(tensor_data);
tensor_data = nullptr;
delete[] tensor_data;
return RET_ERROR;
}
conv2d->hasBias = true;
@@ -1407,14 +1481,12 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
auto depthwise_conv2d = primitive_c->GetPrimitiveT()->value.AsDepthwiseConv2D();
if (depthwise_conv2d == nullptr) {
MS_LOG(ERROR) << "conv2d is null";
free(tensor_data);
tensor_data = nullptr;
delete[] tensor_data;
return RET_ERROR;
}
depthwise_conv2d->hasBias = true;
}
free(tensor_data);
tensor_data = nullptr;
delete[] tensor_data;
} else {
MS_LOG(ERROR) << "unexpected input_quant_params size: " << input_quant_params.size();
continue;
@@ -1455,7 +1527,9 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
}
for (size_t i = 0; i < (*diverg_info_map)[callParam.node_name].size(); i++) {
auto tensor = beforeInputs[i];
MS_ASSERT(tensor != nullptr);
const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
MS_ASSERT(tensor_data != nullptr);
size_t elem_count = tensor->ElementsNum();
vector<float> data(tensor_data, tensor_data + elem_count);
this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[callParam.node_name][i]);
@@ -1476,6 +1550,7 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
int output_i = 0;
for (const auto &tensor : after_outputs) {
const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
MS_ASSERT(tensor_data != nullptr);
size_t elem_count = tensor->ElementsNum();
vector<float> data(tensor_data, tensor_data + elem_count);
this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[call_param.node_name][output_i]);
@@ -1637,6 +1712,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
}

bool PostTrainingQuantizer::OpInputDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
MS_ASSERT(data != nullptr);
std::lock_guard<std::mutex> lg(mutex_op_input);
if (type == STORE) {
if (fp32_op_input_map.find(op_name) != fp32_op_input_map.end()) {
@@ -1661,6 +1737,7 @@ bool PostTrainingQuantizer::OpInputDataHandle(OperationType type, const string &

bool PostTrainingQuantizer::OpOutputChMeanDataHandle(OperationType type, const string &op_name,
std::vector<float> *data) {
MS_ASSERT(data != nullptr);
std::lock_guard<std::mutex> lg(mutex_op_output);
if (type == STORE) {
if (fp32_op_output_ch_mean_map.find(op_name) != fp32_op_output_ch_mean_map.end()) {
@@ -1682,7 +1759,4 @@ bool PostTrainingQuantizer::OpOutputChMeanDataHandle(OperationType type, const s
}
return false;
}

} // namespace quant
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite::quant
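
Two smaller fixes in the hunks above are easy to miss: the bias buffer is allocated with new char[], so every release path now uses delete[] instead of free(), and the qualified ::memcpy_s call becomes plain memcpy_s. A minimal sketch of the allocation/release pairing only (Example and its size parameter are hypothetical):

#include <cstddef>
#include <new>

void Example(std::size_t size) {
  auto *tensor_data = new (std::nothrow) char[size];
  if (tensor_data == nullptr) {
    return;
  }
  // ... fill tensor_data and hand it off on success ...
  delete[] tensor_data;  // storage came from new char[], so delete[] (never free()) must release it
}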

+3 -7  mindspore/lite/tools/converter/quantizer/post_training_quantizer.h

@@ -29,9 +29,7 @@
#include "tools/converter/converter.h"
#include "include/ms_tensor.h"

namespace mindspore {
namespace lite {
namespace quant {
namespace mindspore::lite::quant {
class Calibrator;

struct MaxMin {
@@ -178,7 +176,7 @@ class Calibrator {

STATUS CollectImages();

STATUS GenerateInputData(int input_index, int image_index, mindspore::tensor::MSTensor *tensor) const;
STATUS GenerateInputData(size_t input_index, size_t image_index, mindspore::tensor::MSTensor *tensor) const;

size_t GetBatchNum() const { return config_param_.batch_count; }

@@ -230,7 +228,5 @@ class Calibrator {

void AddImage(const std::string &file, size_t index);
};
} // namespace quant
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_POSTRAINING_QUANTIZER_H

+1 -3  mindspore/lite/tools/converter/quantizer/quant_cast.cc

@@ -20,7 +20,6 @@
#include "src/ops/primitive_c.h"

namespace mindspore::lite::quant {

ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params) {
std::unique_ptr<schema::PrimitiveT> primitive = std::make_unique<schema::PrimitiveT>();
schema::QuantDTypeCastT quant_dtype_cast;
@@ -37,7 +36,7 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector
return NewValueNode(primTValue);
}

STATUS QuantCast::Run(FuncGraphPtr graph) {
STATUS QuantCast::Run(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto cnodes = graph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
@@ -109,5 +108,4 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
}
return RET_OK;
}

} // namespace mindspore::lite::quant

+1 -4  mindspore/lite/tools/converter/quantizer/quant_cast.h

@@ -23,18 +23,15 @@
#include "mindspore/core/ir/func_graph.h"

namespace mindspore::lite::quant {

class QuantCast {
public:
QuantCast() = default;
~QuantCast() = default;
STATUS Run(FuncGraphPtr graph);
STATUS Run(const FuncGraphPtr &graph);
void SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; }

private:
TypeId inputDataDType = kNumberTypeFloat32;
};

} // namespace mindspore::lite::quant

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER__QUANT_CAST_H

+44 -13  mindspore/lite/tools/converter/quantizer/quantize_util.cc

@@ -30,9 +30,7 @@
using std::string;
using std::vector;

namespace mindspore {
namespace lite {
namespace quant {
namespace mindspore::lite::quant {
const std::vector<schema::PrimitiveType> QuantStrategy::conv_types = {
schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_Conv2D,
schema::PrimitiveType_DepthwiseConv2D};
@@ -42,6 +40,7 @@ QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThr
: mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {}

bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
MS_ASSERT(node != nullptr);
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
@@ -53,7 +52,6 @@ bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
if (node->size() < 3) {
return false;
}

auto inputNode = node->input(2);
if (!inputNode->isa<Parameter>()) {
return false;
@@ -63,7 +61,6 @@ bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
if (abstract_base == nullptr) {
return false;
}

if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
return false;
@@ -81,11 +78,11 @@ bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
MS_LOG(INFO) << "channel less mConvWeightQuantChannelThreshold!" << weight_shape[0];
return false;
}

return true;
}

bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
MS_ASSERT(node != nullptr);
if (!node->isa<CNode>()) {
return false;
}
@@ -121,6 +118,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
}

bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
MS_ASSERT(node != nullptr);
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
@@ -210,7 +208,15 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl

auto quantMinFloat = static_cast<double>(quant_min);
auto quantMaxFloat = static_cast<double>(quant_max);
if (fabs(quantMaxFloat - quantMinFloat) <= 0.0f) {
MS_LOG(ERROR) << "divisor cannot be 0";
return RET_ERROR;
}
double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
if (fabs(scale) <= 0.0f) {
MS_LOG(ERROR) << "divisor 'scale' cannot be 0";
return RET_ERROR;
}
const double zeroPointFromMin = quantMinFloat - mMin / scale;
int zeroPoint = static_cast<int32_t>(std::round(zeroPointFromMin));

@@ -262,7 +268,15 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
const int8_t quantMax = std::numeric_limits<int8_t>::max();
auto quantMinFloat = static_cast<double>(quantMin);
auto quantMaxFloat = static_cast<double>(quantMax);
if (fabs(quantMaxFloat - quantMinFloat) <= 0.0f) {
MS_LOG(ERROR) << "divisor cannot be 0";
return RET_ERROR;
}
double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
if (fabs(scale) <= 0.0f) {
MS_LOG(ERROR) << "divisor 'scale' cannot be 0";
return RET_ERROR;
}
const double zeroPointFromMin = quantMinFloat - mMin / scale;
const double zeroPointFromMax = quantMaxFloat - mMax / scale;
const double zpFromMinError = std::abs(quantMinFloat) + std::abs(mMin / scale);
@@ -300,8 +314,16 @@ static bool SearchLowerBound(const std::vector<float> &data, const size_t &index
if (max_tmp - data.at(index) < delta) {
return false;
}
if (fabs(max_tmp - *min_tmp) <= 0.0f || fabs(length - *min_idx) <= 0.0f) {
MS_LOG(ERROR) << "divisor cannot be 0";
return false;
}
float range_ratio = (data.at(index) - *min_tmp) / (max_tmp - *min_tmp);
float index_ratio = static_cast<float>(index - *min_idx) / (length - *min_idx);
if (fabs(index_ratio) <= 0.0f) {
MS_LOG(ERROR) << "divisor cannot be 0";
return false;
}
if (index_ratio > 0 && range_ratio / index_ratio > ratio) {
*min_idx = index;
*min_tmp = data.at(index);
@@ -315,8 +337,16 @@ static bool SearchUpperBound(const std::vector<float> &data, const size_t &index
if (data.at(index) - min_tmp < delta) {
return false;
}
if (fabs(*max_tmp - min_tmp) <= 0.0f || fabs(length - *max_idx) <= 0.0f) {
MS_LOG(ERROR) << "divisor cannot be 0";
return false;
}
float range_ratio = (*max_tmp - data.at(index)) / (*max_tmp - min_tmp);
float index_ratio = static_cast<float>(index - *max_idx) / (length - *max_idx);
if (fabs(index_ratio) <= 0.0f) {
MS_LOG(ERROR) << "divisor cannot be 0";
return false;
}
if (index_ratio > 0 && range_ratio / index_ratio > ratio) {
*max_idx = index;
*max_tmp = data.at(index);
@@ -328,7 +358,7 @@ static float CalPercentile(const std::vector<float> &datas, const int &outlier_p
const int size = datas.size();
float val = outlier_percent / 100.0 * size;
int index = std::ceil(val);
float result = 0.0;
float result;
if (index - val > 0) {
result = datas.at(index - 1);
} else {
@@ -374,6 +404,7 @@ static std::vector<float> InitClusters(float *data, size_t elem_count, size_t k)
return clusters;
}
// init cluster
MS_ASSERT(k != 1);
float ratio = static_cast<float>(data_unique.size()) / (k - 1);
std::sort(data_unique.begin(), data_unique.end());
for (size_t i = 0; i < k; i++) {
@@ -388,6 +419,8 @@ static std::vector<float> InitClusters(float *data, size_t elem_count, size_t k)
}

std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epochs, schema::QuantParamT *quantParam) {
MS_ASSERT(data != nullptr);
MS_ASSERT(quantParam != nullptr);
std::vector<float> clusters = InitClusters(data, elem_count, k);
std::vector<int8_t> clusters_index{};
double error{0};
@@ -412,7 +445,7 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc
clusters_data[index].emplace_back(data[i]);
}
for (size_t j = 0; j < clusters.size(); j++) {
if (clusters_data[j].size() > 0) {
if (!clusters_data[j].empty()) {
clusters[j] = std::accumulate(clusters_data[j].begin(), clusters_data[j].end(), 0.0) / clusters_data[j].size();
}
}
@@ -421,7 +454,7 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc
error_cur += pow(data[j] - clusters[clusters_index[j]], 2);
}
error_cur = pow(error_cur / elem_count, 0.5);
if (std::abs((error_cur - error) / error_cur) < 1e-6) {
if (std::abs((error_cur - error) / error_cur) <= 0.0f) {
break;
}
error = error_cur;
@@ -430,7 +463,7 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc
return clusters_index;
}

schema::PrimitiveType NodePrimitiveType(CNodePtr cnode) {
schema::PrimitiveType NodePrimitiveType(const CNodePtr &cnode) {
if (cnode == nullptr) {
MS_LOG(ERROR) << "cnode is null";
return schema::PrimitiveType_NONE;
@@ -442,6 +475,4 @@ schema::PrimitiveType NodePrimitiveType(CNodePtr cnode) {
}
return (schema::PrimitiveType)primitive_c->Type();
}
} // namespace quant
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite::quant
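
Most of the additions to this file, like those in post_training_quantizer.cc, follow one pattern: validate a floating-point divisor before dividing and report an error instead of producing inf/NaN. A minimal sketch of that guard with a hypothetical SafeScale helper, using the same fabs(x) <= 0.0f test as the diff:

#include <cmath>

// Computes (max - min) / (quant_max - quant_min), refusing to divide by zero.
bool SafeScale(double min, double max, double quant_min, double quant_max, double *scale) {
  const double range = quant_max - quant_min;
  if (std::fabs(range) <= 0.0) {
    return false;  // divisor cannot be 0
  }
  *scale = (max - min) / range;
  return true;
}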

+7 -9  mindspore/lite/tools/converter/quantizer/quantize_util.h

@@ -36,9 +36,7 @@
#include "abstract/dshape.h"
#include "tools/converter/quantizer/bitpacking.h"

namespace mindspore {
namespace lite {
namespace quant {
namespace mindspore::lite::quant {
static constexpr size_t UINT8_QUANTIZATION = 8;
static constexpr size_t WEIGHT_INDEX = 1;

@@ -96,7 +94,7 @@ T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
}

return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
double tmp = 0.0f;
double tmp;
if (originData > maxLimit) {
tmp = maxLimit;
} else if (originData < minLimit) {
@@ -130,8 +128,10 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
}();
}
template <typename T>
STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<PrimitiveC> &primitive_c, QuantType quantType,
int quant_max, int quant_min, size_t bitNum, bool per_channel, bool k_means = false) {
MS_ASSERT(weight != nullptr);
MS_ASSERT(primitive_c != nullptr);
auto dims = weight->tensor_shape();
auto op_type = (schema::PrimitiveType)primitive_c->Type();
if (per_channel) {
@@ -320,8 +320,6 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
return RET_OK;
}

schema::PrimitiveType NodePrimitiveType(CNodePtr cnode);
} // namespace quant
} // namespace lite
} // namespace mindspore
schema::PrimitiveType NodePrimitiveType(const CNodePtr &cnode);
} // namespace mindspore::lite::quant
#endif

+0 -1  mindspore/lite/tools/converter/quantizer/quantizer.h

@@ -77,5 +77,4 @@ class FbQuantizer {
schema::MetaGraphT *graph = nullptr;
};
} // namespace mindspore::lite::quant

#endif

+5 -8  mindspore/lite/tools/converter/quantizer/weight_quantizer.cc

@@ -24,9 +24,7 @@
using std::string;
using std::vector;

namespace mindspore {
namespace lite {
namespace quant {
namespace mindspore::lite::quant {
bool WeightQuantizer::IsPosNum(const std::string &str) {
for (size_t i = 0; i < str.size(); i++) {
if (str.at(i) < '0' || str.at(i) > '9') {
@@ -65,7 +63,7 @@ WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize,
auto quantSize = static_cast<size_t>(std::stoull(weightSize));
this->bitNum = static_cast<size_t>(std::stoull(bitNum));
auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
mStrategy.reset(new QuantStrategy(quantSize, convQuantWeightChannelThreshold));
mStrategy = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold);
quant_max = (1 << (unsigned int)(this->bitNum - 1)) - 1;
quant_min = -(1 << (unsigned int)(this->bitNum - 1));
// parse type_id
@@ -217,7 +215,8 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
}

STATUS WeightQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
auto ret = RET_OK;
MS_ASSERT(funcGraph != nullptr);
STATUS ret;
auto cnodes = funcGraph->GetOrderedCnodes();
ret = DoConvQuantize(cnodes);
if (ret != RET_OK) {
@@ -231,6 +230,4 @@ STATUS WeightQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
}
return ret;
}
} // namespace quant
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite::quant
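
The mStrategy change above replaces reset(new QuantStrategy(...)) with std::make_unique, avoiding a bare new. A minimal sketch of the same idiom with a hypothetical Strategy type:

#include <cstddef>
#include <memory>

struct Strategy {
  Strategy(std::size_t weight_size, std::size_t channel_threshold)
      : weight_size_(weight_size), channel_threshold_(channel_threshold) {}
  std::size_t weight_size_;
  std::size_t channel_threshold_;
};

int main() {
  std::unique_ptr<Strategy> strategy;
  strategy.reset(new Strategy(0, 16));           // old style: explicit new
  strategy = std::make_unique<Strategy>(0, 16);  // preferred: no bare new in user code
  return 0;
}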

+2 -6  mindspore/lite/tools/converter/quantizer/weight_quantizer.h

@@ -28,9 +28,7 @@
#include "base/base.h"
#include "abstract/dshape.h"

namespace mindspore {
namespace lite {
namespace quant {
namespace mindspore::lite::quant {
class WeightQuantizer : public Quantizer {
public:
WeightQuantizer(FuncGraphPtr graph, const std::string &weightSize, const std::string &covWeightChannelThreshold,
@@ -51,7 +49,5 @@ class WeightQuantizer : public Quantizer {
std::unique_ptr<QuantStrategy> mStrategy;
size_t bitNum;
};
} // namespace quant
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite::quant
#endif
