|
|
|
@@ -536,7 +536,7 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct |
|
|
|
} |
|
|
|
|
|
|
|
STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value, |
|
|
|
bool depthwise) { |
|
|
|
bool perchanel, bool depthwise) { |
|
|
|
// const vector<int> dims = filter->dims; |
|
|
|
// perlayer |
|
|
|
if (!weight->isa<Parameter>()) { |
|
|
|
@@ -544,9 +544,17 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P |
|
|
|
return RET_PARAM_INVALID; |
|
|
|
} |
|
|
|
auto parameter = std::dynamic_pointer_cast<Parameter>(weight); |
|
|
|
if (parameter == nullptr) { |
|
|
|
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not cast to Parameter"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param()); |
|
|
|
if (paramValue == nullptr) { |
|
|
|
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
auto status = QuantFilter(paramValue, primitiveT_value, QuantType_PostTraining, quant_max, quant_min, bit_num, |
|
|
|
per_channel_, depthwise); |
|
|
|
perchanel, depthwise); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "QuantFilter failed: " << status; |
|
|
|
return status; |
|
|
|
@@ -690,11 +698,29 @@ STATUS PostTrainingQuantizer::QuantNode() { |
|
|
|
auto op_name = cnode->fullname_with_scope(); |
|
|
|
auto op_type = primitiveT_value->GetPrimitiveT()->value.type; |
|
|
|
MS_LOG(INFO) << "OpName: " << op_name; |
|
|
|
if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D) { |
|
|
|
if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D && |
|
|
|
op_type != PrimitiveType_FullConnection) { |
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); i++) { |
|
|
|
auto input_node = cnode->input(i); |
|
|
|
if (!input_node->isa<mindspore::CNode>()) { |
|
|
|
MS_LOG(WARNING) << "node: " << cnode_name << " input " << i << " not a cnode"; |
|
|
|
MS_LOG(DEBUG) << "node: " << cnode_name << " input " << i << " not a cnode"; |
|
|
|
// get dtype |
|
|
|
auto abstractBase = input_node->abstract(); |
|
|
|
if (abstractBase == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope(); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { |
|
|
|
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << input_node->fullname_with_scope(); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); |
|
|
|
if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) { |
|
|
|
MS_LOG(DEBUG) << "this parameter do quant"; |
|
|
|
DoWeightQuant(input_node, primitiveT_value, false, false); |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "this parameter no need to do quant"; |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node); |
|
|
|
@@ -704,8 +730,15 @@ STATUS PostTrainingQuantizer::QuantNode() { |
|
|
|
<< " PrimitiveTValue is null"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) { |
|
|
|
primitiveT_value->AddInputQuantParam(quant_param); |
|
|
|
if (!input_cnode_primitiveT_value->GetOutputQuantParams().empty()) { |
|
|
|
for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) { |
|
|
|
primitiveT_value->AddInputQuantParam(quant_param); |
|
|
|
} |
|
|
|
} else { |
|
|
|
// do input quant |
|
|
|
double scale = input_scale[cnode]; |
|
|
|
int32_t zp = input_zero_point[cnode]; |
|
|
|
DoQuantInput(scale, zp, &input_min_max[cnode], primitiveT_value); |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
@@ -715,8 +748,12 @@ STATUS PostTrainingQuantizer::QuantNode() { |
|
|
|
DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitiveT_value); |
|
|
|
// do weight quant |
|
|
|
auto weight = cnode->input(2); |
|
|
|
bool depthwise = op_type == PrimitiveType_DeDepthwiseConv2D; |
|
|
|
DoWeightQuant(weight, primitiveT_value, depthwise); |
|
|
|
bool depthwise = op_type == PrimitiveType_DepthwiseConv2D; |
|
|
|
bool perchannel = per_channel_; |
|
|
|
if (op_type == PrimitiveType_FullConnection) { |
|
|
|
perchannel = false; |
|
|
|
} |
|
|
|
DoWeightQuant(weight, primitiveT_value, perchannel, depthwise); |
|
|
|
// do bias quant |
|
|
|
if (cnode->inputs().size() == 4) { |
|
|
|
auto bias = cnode->input(3); |
|
|
|
|