
!7631 fix post training quant with multi-output op

Merge pull request !7631 from xutianchun/mulit
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
parent commit 5532de75ef
5 changed files with 185 additions and 82 deletions
1. mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc (+135 -58)
2. mindspore/lite/tools/converter/quantizer/post_training_quantizer.h (+8 -7)
3. mindspore/lite/tools/converter/quantizer/quant_cast.cc (+1 -1)
4. mindspore/lite/tools/converter/quantizer/quantize_util.cc (+39 -16)
5. mindspore/lite/tools/converter/quantizer/quantize_util.h (+2 -0)
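
The heart of the change is a data-structure swap in the calibrator: a node with several outputs (Split, or anything consumed through TupleGetItem) needs one set of calibration statistics per output tensor, not one per node. Condensed from the header diff below, old and new member declarations with an illustrative comment:

// Before: exactly one DivergInfo per quantized node, so a second
// output tensor had nowhere to record its min/max and histogram.
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> output_diverg_info_;

// After: one DivergInfo per output tensor, keyed by node name;
// outputs_diverg_info_[op_name][i] tracks the i-th output.
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> outputs_diverg_info_;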

mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc (+135 -58)

@@ -289,24 +289,22 @@ std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *Calibrator::GetInp
return &this->input_diverg_info_;
}

std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *Calibrator::GetOutputDivergInfo() {
return &this->output_diverg_info_;
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetOutputDivergInfo() {
return &this->outputs_diverg_info_;
}

STATUS Calibrator::RecordMaxValue(const std::string &op_name, const vector<float> &data,
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
auto got = (*diverg_info).find(op_name);
if (got != (*diverg_info).end()) {
((*got).second)->RecordMaxValue(data);
((*got).second)->RecordMaxValueArray(data);
}
STATUS Calibrator::RecordMaxValue(const vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info) {
diverg_info->RecordMaxValue(data);
diverg_info->RecordMaxValueArray(data);
return RET_OK;
}

STATUS Calibrator::ComputeThreshold() {
for (auto iter = this->output_diverg_info_.begin(); iter != this->output_diverg_info_.end(); iter++) {
DivergInfo *info = iter->second.get();
info->ComputeThreshold();
for (auto &kv : this->outputs_diverg_info_) {
auto &outputs_diverg_info = kv.second;
for (auto &diverg_info : outputs_diverg_info) {
diverg_info->ComputeThreshold();
}
}
// node A's input may be node B's output, no need to re-compute the node A's input quant param which is the same as node B's output quant param
for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) {
@@ -317,14 +315,21 @@ STATUS Calibrator::ComputeThreshold() {
auto input = cnode->input(1);
if (input->isa<mindspore::CNode>()) {
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input);
for (const auto &output_diverg_info : output_diverg_info_) {
auto output_diverg_cnode = output_diverg_info.second->cnode;
if (output_diverg_cnode == input_cnode) {
*info = *(output_diverg_info.second);
info->cnode = cnode;
already_computed = true;
for (const auto &outputs_diverg_info : outputs_diverg_info_) {
if (already_computed) {
break;
}
for (const auto &output_diverg_info : outputs_diverg_info.second) {
auto output_diverg_cnode = output_diverg_info->cnode;
if (output_diverg_cnode == input_cnode) {
if (NodePrimitiveType(input_cnode) != schema::PrimitiveType_TupleGetItem) {
*info = *output_diverg_info;
info->cnode = cnode;
already_computed = true;
break;
}
}
}
}
}
if (!already_computed) {
@@ -334,6 +339,16 @@ STATUS Calibrator::ComputeThreshold() {
return RET_OK;
}

STATUS Calibrator::UpdateOutputDivergInverval(
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info) {
for (auto &kv : *diverg_info) {
for (auto &info : kv.second) {
info->UpdateInterval();
}
}
return RET_OK;
}

STATUS Calibrator::UpdateDivergInverval(std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
for (auto iter = (*diverg_info).begin(); iter != (*diverg_info).end(); iter++) {
DivergInfo *info = iter->second.get();
@@ -342,12 +357,8 @@ STATUS Calibrator::UpdateDivergInverval(std::unordered_map<std::string, std::uni
return RET_OK;
}

STATUS Calibrator::UpdateDataFrequency(const std::string &op_name, const vector<float> &data,
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
auto got = (*diverg_info).find(op_name);
if (got != (*diverg_info).end()) {
((*got).second)->UpdateHistogram(data);
}
STATUS Calibrator::UpdateDataFrequency(const vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info) {
diverg_info->UpdateHistogram(data);
return RET_OK;
}

@@ -362,8 +373,8 @@ STATUS Calibrator::AddQuantizedOp(CNodePtr node) {
std::unique_ptr<DivergInfo> output_diverg = std::unique_ptr<DivergInfo>(
new DivergInfo(node, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, config_param_.method_x));

input_diverg_info_.insert(std::make_pair(string(node_name), std::move(input_diverg)));
output_diverg_info_.insert(std::make_pair(string(node_name), std::move(output_diverg)));
input_diverg_info_.insert(std::make_pair(node_name, std::move(input_diverg)));
outputs_diverg_info_[node_name].push_back(std::move(output_diverg));
return RET_OK;
}

@@ -377,7 +388,6 @@ void Calibrator::AddImage(const string &file, size_t index) {
return stat(file.c_str(), &buf) == 0;
};
if (exist(file)) {
MS_LOG(INFO) << "load image: " << file;
this->images_[index].push_back(file);
} else {
MS_LOG(WARNING) << "invalid image file path: " << file;
@@ -411,6 +421,7 @@ STATUS Calibrator::GenerateInputData(int input_index, int image_index, mindspore
STATUS Calibrator::CollectImages() {
this->images_.resize(config_param_.image_paths.size());
auto input_i = 0;
bool multi_input = config_param_.image_paths.size() > 1;
for (const auto &image_path : config_param_.image_paths) {
DIR *root = opendir(image_path.c_str());
if (root == nullptr) {
@@ -420,13 +431,14 @@ STATUS Calibrator::CollectImages() {
struct dirent *image_dir = readdir(root);
size_t count = 0;
while (image_dir != nullptr) {
if (image_dir->d_name[0] != '.') {
const std::string file_name = image_path + "/" + image_dir->d_name;
if (config_param_.batch_count == 0) {
this->AddImage(file_name, input_i);
string file_name(image_dir->d_name);
if (file_name != "." && file_name != "..") {
const std::string file_path = image_path + "/" + file_name;
if (multi_input || config_param_.batch_count == 0) {
this->AddImage(file_path, input_i);
count++;
} else if (count < config_param_.batch_count) {
this->AddImage(file_name, input_i);
this->AddImage(file_path, input_i);
count++;
} else {
break;
@@ -434,6 +446,10 @@ STATUS Calibrator::CollectImages() {
}
image_dir = readdir(root);
}
std::sort(images_[input_i].begin(), images_[input_i].end());
if (config_param_.batch_count != 0 && config_param_.batch_count < images_[input_i].size()) {
images_[input_i].resize(config_param_.batch_count);
}
closedir(root);
input_i++;
}
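
With this hunk, calibration images are first collected, then sorted for a deterministic order, and only afterwards truncated to batch_count (a value of 0 means "use every image"); for multi-input models every directory is read in full before truncation. A self-contained sketch of that post-processing, assuming the file names are already gathered per input:

#include <algorithm>
#include <string>
#include <vector>

// Sort for a stable calibration order, then truncate to batch_count.
// batch_count == 0 keeps everything, mirroring the diff's semantics.
void FinalizeImageList(std::vector<std::string> *images, size_t batch_count) {
  std::sort(images->begin(), images->end());
  if (batch_count != 0 && batch_count < images->size()) {
    images->resize(batch_count);
  }
}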
@@ -487,6 +503,7 @@ STATUS Calibrator::ReadConfig() {
config_param_.image_paths.push_back(image_path);
raw_image_paths = raw_image_paths.substr(ind + 1);
Trim(&raw_image_paths);
ind = raw_image_paths.find(',');
}
config_param_.image_paths.push_back(raw_image_paths);
} else if (key == "batch_count") {
@@ -550,7 +567,7 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in
STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct MaxMin *max_min,
std::shared_ptr<PrimitiveC> lite_primitive) {
if (!lite_primitive->GetInputQuantParams().empty()) {
return RET_OK;
MS_LOG(DEBUG) << "input quant params not empty"; // multi-input op: like concat
}
schema::QuantParamT quant_param;
quant_param.scale = scale;
@@ -567,7 +584,7 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct M
STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min,
std::shared_ptr<PrimitiveC> lite_primitive) {
if (!lite_primitive->GetOutputQuantParams().empty()) {
return RET_OK;
MS_LOG(DEBUG) << "output quant params not empty"; // multi-output op: like split
}
schema::QuantParamT quant_param;
quant_param.scale = scale;
@@ -737,9 +754,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
auto input_scale = this->calibrator_->GetScale(this->calibrator_->GetInputDivergInfo());
auto input_zero_point = this->calibrator_->GetZeropoint(this->calibrator_->GetInputDivergInfo());

auto output_min_max = this->calibrator_->GetMinMax(this->calibrator_->GetOutputDivergInfo());
auto output_scale = this->calibrator_->GetScale(this->calibrator_->GetOutputDivergInfo());
auto output_zeropoint = this->calibrator_->GetZeropoint(this->calibrator_->GetOutputDivergInfo());
auto outputs_diverg_info = calibrator_->GetOutputDivergInfo();

auto cnodes = funcGraph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
@@ -760,8 +775,35 @@ STATUS PostTrainingQuantizer::QuantNode() {

auto op_type = (schema::PrimitiveType)primitive_c->Type();
MS_LOG(DEBUG) << "OpName: " << op_name;
if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D &&
op_type != PrimitiveType_FullConnection) {
if (op_type == PrimitiveType_TupleGetItem) {
auto index_node = cnode->input(2);
auto index_value_node = std::dynamic_pointer_cast<mindspore::ValueNode>(index_node);
if (index_value_node == nullptr) {
MS_LOG(WARNING) << "index value node is null";
continue;
}
size_t index = GetValue<int>(index_value_node->value());

auto input_node = cnode->input(1);
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
if (input_cnode_primitive_c == nullptr) {
MS_LOG(WARNING) << "input_cnode_primitive_c is null";
continue;
}
if (input_cnode_primitive_c->GetOutputQuantParams().size() > index) {
auto quant_param = input_cnode_primitive_c->GetOutputQuantParams()[index];
primitive_c->AddInputQuantParam(quant_param);
primitive_c->AddOutputQuantParam(quant_param);
} else {
MS_LOG(WARNING) << "this TupleGetItem node's input node: " << input_cnode->fullname_with_scope()
<< "'s output quant_params size: " << input_cnode_primitive_c->GetOutputQuantParams().size()
<< ", but index: " << index;
}
primitive_c->SetQuantType(schema::QuantType_PostTraining);
continue;
} else if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D &&
op_type != PrimitiveType_FullConnection) {
for (size_t i = 1; i < cnode->inputs().size(); i++) {
auto input_node = cnode->input(i);
if (!input_node->isa<mindspore::CNode>()) {
@@ -821,18 +863,25 @@ STATUS PostTrainingQuantizer::QuantNode() {
DoBiasQuant(bias, primitive_c);
}
}
// do output quant
double OutputScale = output_scale[cnode];
int32_t OutputZeropoint = output_zeropoint[cnode];
DoQuantOutput(OutputScale, OutputZeropoint, &output_min_max[cnode], primitive_c);
primitive_c->SetQuantType(schema::QuantType_PostTraining);
// do output quant, there may multi-output
auto &infos = (*outputs_diverg_info)[op_name];
for (auto &info : infos) {
auto output_scale = info->GetScale().second;
auto output_zp = info->GetZeropoint().second;
struct MaxMin output_min_max {};
output_min_max.max = info->max;
output_min_max.min = info->min;

DoQuantOutput(output_scale, output_zp, &output_min_max, primitive_c);
primitive_c->SetQuantType(schema::QuantType_PostTraining);
}
}
return RET_OK;
}

STATUS PostTrainingQuantizer::UpdateDivergInverval() {
this->calibrator_->UpdateDivergInverval(this->calibrator_->GetInputDivergInfo());
this->calibrator_->UpdateDivergInverval(this->calibrator_->GetOutputDivergInfo());
this->calibrator_->UpdateOutputDivergInverval(this->calibrator_->GetOutputDivergInfo());
return RET_OK;
}
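
For a TupleGetItem node the quantizer calibrates nothing itself; it forwards the producer's per-output quant param, indexed by the getter's constant index, to both sides of the getter. A condensed restatement of that branch (error handling elided; same identifiers as the diff):

// t = TupleGetItem(producer, index): reuse the quant param computed
// for producer's index-th output as t's input AND output param, so
// consumers of t see a tensor quantized exactly like that output.
size_t index = GetValue<int>(index_value_node->value());
auto producer_params = input_cnode_primitive_c->GetOutputQuantParams();
if (producer_params.size() > index) {
  primitive_c->AddInputQuantParam(producer_params[index]);
  primitive_c->AddOutputQuantParam(producer_params[index]);
}
primitive_c->SetQuantType(schema::QuantType_PostTraining);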

@@ -869,7 +918,6 @@ STATUS PostTrainingQuantizer::PreProcess() {
for (auto &cnode : cnodes) {
AnfNodePtr anf = std::dynamic_pointer_cast<AnfNode>(cnode);
if (strategy.CanOpPostQuantized(anf)) {
MS_LOG(INFO) << "node: " << cnode->fullname_with_scope() << " will be quantized";
calibrator_->AddQuantizedOp(cnode);
}
}
@@ -917,6 +965,10 @@ STATUS PostTrainingQuantizer::DoInference() {
KernelCallBack beforeCallBack = [&](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
const CallBackParam &callParam) -> bool {
auto diverg_info_map = calibrator_->GetInputDivergInfo();
if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
return true;
}
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
return false;
}
@@ -924,21 +976,35 @@ STATUS PostTrainingQuantizer::DoInference() {
const float *tData = static_cast<const float *>(tensor->MutableData());
size_t elem_count = tensor->ElementsNum();
vector<float> data(tData, tData + elem_count);
this->calibrator_->RecordMaxValue(callParam.node_name, data, this->calibrator_->GetInputDivergInfo());
this->calibrator_->RecordMaxValue(data, (*diverg_info_map)[callParam.node_name]);
return true;
};
// func
KernelCallBack afterCallBack = [&](const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
const CallBackParam &callParam) -> bool {
auto diverg_info_map = calibrator_->GetOutputDivergInfo();
if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
return true;
}
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
return false;
}
auto tensor = afterOutputs[0];
const float *tensor_data = static_cast<const float *>(tensor->MutableData());
size_t elem_count = tensor->ElementsNum();
vector<float> data(tensor_data, tensor_data + elem_count);
this->calibrator_->RecordMaxValue(callParam.node_name, data, this->calibrator_->GetOutputDivergInfo());
if (afterOutputs.size() > 1) {
auto output_diverg = std::make_unique<DivergInfo>();
*output_diverg = *((*diverg_info_map)[callParam.node_name][0]);
for (size_t i = 1; i < afterOutputs.size(); i++) {
(*diverg_info_map)[callParam.node_name].push_back(std::move(output_diverg));
}
}
size_t output_i = 0;
for (const auto &tensor : afterOutputs) {
const float *tensor_data = static_cast<const float *>(tensor->MutableData());
size_t elem_count = tensor->ElementsNum();
vector<float> data(tensor_data, tensor_data + elem_count);
this->calibrator_->RecordMaxValue(data, (*diverg_info_map)[callParam.node_name][output_i]);
output_i++;
}
return true;
};
auto status = fp32_session_->RunGraph(beforeCallBack, afterCallBack);
@@ -1317,6 +1383,10 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
KernelCallBack beforeCallBack = [&](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
const CallBackParam &callParam) {
auto diverg_info_map = calibrator_->GetInputDivergInfo();
if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
return true;
}
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
return false;
}
@@ -1324,21 +1394,28 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
const float *tensor_data = static_cast<const float *>(tensor->MutableData());
size_t shape_size = tensor->ElementsNum();
vector<float> data(tensor_data, tensor_data + shape_size);
this->calibrator_->UpdateDataFrequency(callParam.node_name, data, this->calibrator_->GetInputDivergInfo());
this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[callParam.node_name]);
return true;
};

KernelCallBack afterCallBack = [&](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
const CallBackParam &call_param) {
auto diverg_info_map = calibrator_->GetOutputDivergInfo();
if (diverg_info_map->find(call_param.node_name) == diverg_info_map->end()) {
return true;
}
if (PostTrainingQuantizer::CheckFp32TensorVec(call_param.node_name, after_outputs) != RET_OK) {
return false;
}
auto tensor = after_outputs[0];
const float *tenosr_data = static_cast<const float *>(tensor->MutableData());
size_t shape_size = tensor->ElementsNum();
vector<float> data(tenosr_data, tenosr_data + shape_size);
this->calibrator_->UpdateDataFrequency(call_param.node_name, data, this->calibrator_->GetOutputDivergInfo());
int output_i = 0;
for (const auto &tensor : after_outputs) {
const float *tensor_data = static_cast<const float *>(tensor->MutableData());
size_t elem_count = tensor->ElementsNum();
vector<float> data(tensor_data, tensor_data + elem_count);
this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[call_param.node_name][output_i]);
output_i++;
}
return true;
};
auto status = fp32_session_->RunGraph(beforeCallBack, afterCallBack);
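
Both inference callbacks above share one pattern: the DivergInfo vector for a node is lazily widened the first time that node is seen to produce more than one output, then statistics are recorded per output index. A minimal sketch of the widening step, assuming DivergInfo is copy-assignable (which is why the header below adds a default constructor); cloning happens inside the loop so each output gets its own independent copy:

// Ensure one DivergInfo per output tensor; extra entries start as
// copies of entry 0 (same bins, bit width, method) and then diverge
// as per-output min/max values and histograms are recorded.
auto &infos = (*diverg_info_map)[node_name];  // node_name: callParam.node_name
while (infos.size() < output_count) {         // output_count: afterOutputs.size()
  auto clone = std::make_unique<DivergInfo>();
  *clone = *infos[0];
  infos.push_back(std::move(clone));
}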


mindspore/lite/tools/converter/quantizer/post_training_quantizer.h (+8 -7)

@@ -132,7 +132,7 @@ struct DivergInfo {
std::vector<float> max_datas;
std::pair<float, float> percent_result{0.0, 0.0};
float scale_tmp = 0;
DivergInfo() = default;
DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min, const std::string &method_x) {
this->method_x = method_x;
this->cnode = cnode;
@@ -187,13 +187,14 @@ class Calibrator {

STATUS AddQuantizedOp(CNodePtr node);

STATUS RecordMaxValue(const std::string &op_name, const std::vector<float> &data,
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
STATUS RecordMaxValue(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);

STATUS UpdateDivergInverval(std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);

STATUS UpdateDataFrequency(const std::string &op_name, const std::vector<float> &data,
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
STATUS UpdateOutputDivergInverval(
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info);

STATUS UpdateDataFrequency(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);
void Dump();

STATUS ComputeThreshold();
@@ -208,7 +209,7 @@ class Calibrator {

std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *GetInputDivergInfo();

std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *GetOutputDivergInfo();
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetOutputDivergInfo();

private:
std::vector<std::vector<std::string>> images_; // multi_input, each input has multiple input data
@@ -219,7 +220,7 @@ class Calibrator {

std::unordered_map<std::string, std::unique_ptr<DivergInfo>> input_diverg_info_;

std::unordered_map<std::string, std::unique_ptr<DivergInfo>> output_diverg_info_;
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> outputs_diverg_info_;

size_t bit_num_;
int quant_max_;
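
The header diff also narrows the Calibrator API: RecordMaxValue and UpdateDataFrequency now take the one DivergInfo they operate on, and the name-based map lookup is hoisted into the session callbacks. The resulting caller-side shape, sketched with a hypothetical ExtractFloats helper standing in for the tensor-to-vector copy:

// Resolve the node's DivergInfo list once per callback invocation;
// nodes that were not registered for quantization are skipped.
auto diverg_info_map = calibrator_->GetOutputDivergInfo();
auto it = diverg_info_map->find(node_name);
if (it == diverg_info_map->end()) {
  return true;  // not a quantized op, nothing to record
}
for (size_t i = 0; i < outputs.size(); ++i) {
  calibrator_->RecordMaxValue(ExtractFloats(outputs[i]), it->second[i]);
}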


mindspore/lite/tools/converter/quantizer/quant_cast.cc (+1 -1)

@@ -85,7 +85,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
if (curnode_quant_type == schema::QuantType_PostTraining &&
input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
value_node =
NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitive_c->GetInputQuantParams().front());
NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitive_c->GetInputQuantParams()[i - 1]);
} else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
input_cnode_quant_type == schema::QuantType_PostTraining) {
value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32,
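
This one-line fix matters for multi-input ops such as Concat: the QuantCast inserted in front of input i must carry that input's own scale and zero point, not always the first input's. The indexing convention, sketched (input 0 of a CNode is the primitive itself, so data input i maps to quant param i - 1):

// cnode->input(0) holds the primitive; data inputs start at 1, and
// input quant params are stored per data input, hence the i - 1.
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
  auto param = primitive_c->GetInputQuantParams()[i - 1];
  auto value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, param);
  // splice the cast node in front of cnode->input(i)
}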


mindspore/lite/tools/converter/quantizer/quantize_util.cc (+39 -16)

@@ -90,22 +90,32 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
return false;
}
auto cnode = std::dynamic_pointer_cast<CNode>(node);

auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(WARNING) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
return false;
}

auto type = (schema::PrimitiveType)primitive_c->Type();
MS_LOG(INFO) << "Primitive type: " << type;
static const std::vector<schema::PrimitiveType> uint8OpList = {
schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D,
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
/*schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax,*/
schema::PrimitiveType_Reshape, schema::PrimitiveType_FullConnection, schema::PrimitiveType_MatMul,
schema::PrimitiveType_Activation};
return IsContain(uint8OpList, type);
auto type = NodePrimitiveType(cnode);
static const std::vector<schema::PrimitiveType> int8OpList = {
schema::PrimitiveType_Nchw2Nhwc,
schema::PrimitiveType_Nhwc2Nchw,
schema::PrimitiveType_Conv2D,
schema::PrimitiveType_DepthwiseConv2D,
schema::PrimitiveType_Add,
schema::PrimitiveType_Pooling,
schema::PrimitiveType_Concat,
schema::PrimitiveType_Split,
schema::PrimitiveType_TupleGetItem,
schema::PrimitiveType_Reshape,
schema::PrimitiveType_FullConnection,
schema::PrimitiveType_MatMul,
schema::PrimitiveType_Crop,
schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_Activation,
};
bool contain = IsContain(int8OpList, type);
if (!contain) {
MS_LOG(INFO) << "not quant, " << cnode->fullname_with_scope()
<< " of type: " << schema::EnumNamePrimitiveType(type);
}
return contain;
}

bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
@@ -431,6 +441,19 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc
quantParam->clusters = clusters;
return clusters_index;
}

schema::PrimitiveType NodePrimitiveType(CNodePtr cnode) {
if (cnode == nullptr) {
MS_LOG(ERROR) << "cnode is null";
return schema::PrimitiveType_NONE;
}
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is null";
return schema::PrimitiveType_NONE;
}
return (schema::PrimitiveType)primitive_c->Type();
}
} // namespace quant
} // namespace lite
} // namespace mindspore
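
The new NodePrimitiveType helper folds the two null checks (the cnode, then its primitive) into one call, and both earlier files now lean on it. Typical call sites, condensed from this patch:

// In ComputeThreshold: do not copy a producer's DivergInfo when the
// producer is a TupleGetItem; its params are forwarded in QuantNode.
if (NodePrimitiveType(input_cnode) != schema::PrimitiveType_TupleGetItem) {
  // reuse the producer's output statistics for this input
}

// In CanOpPostQuantized: gate quantization on the supported-op list.
bool quantizable = IsContain(int8OpList, NodePrimitiveType(cnode));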

mindspore/lite/tools/converter/quantizer/quantize_util.h (+2 -0)

@@ -287,6 +287,8 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
}

STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION);

schema::PrimitiveType NodePrimitiveType(CNodePtr cnode);
} // namespace quant
} // namespace lite
} // namespace mindspore

