|
|
|
@@ -123,33 +123,45 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me |
|
|
|
MS_LOG(DEBUG) << "node: " << dst_node->name << " input quant params is empty"; |
|
|
|
} |
|
|
|
// output |
|
|
|
auto output_index = dst_node->outputIndex[0]; |
|
|
|
auto tensor_output = meta_graph->allTensors[output_index].get(); |
|
|
|
|
|
|
|
auto output_quant_params = primitive->output_quant_params(); |
|
|
|
if (output_quant_params.empty()) { |
|
|
|
if (node_type != schema::PrimitiveType_QuantDTypeCast) { |
|
|
|
MS_LOG(DEBUG) << "node: " << dst_node->name << " output quant params is empty"; |
|
|
|
} |
|
|
|
} else { |
|
|
|
for (auto output_quant_param : output_quant_params[0]) { |
|
|
|
if (tensor_output->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) { |
|
|
|
std::unique_ptr<schema::QuantParamT> output_quant_param_ptr = |
|
|
|
std::make_unique<schema::QuantParamT>(output_quant_param); |
|
|
|
MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale |
|
|
|
<< " zp: " << output_quant_param_ptr->zeroPoint; |
|
|
|
tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr)); |
|
|
|
if (dst_node->outputIndex.size() != output_quant_params.size()) { |
|
|
|
MS_LOG(INFO) << "node: " << dst_node->name << " output has " << output_quant_params.size() |
|
|
|
<< " quant_params; but only " << dst_node->outputIndex.size() << " output"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
int output_idx = 0; |
|
|
|
for (const auto &output_quant_param : output_quant_params) { |
|
|
|
auto output_tensor = meta_graph->allTensors[dst_node->outputIndex[output_idx]].get(); |
|
|
|
output_idx++; |
|
|
|
for (const auto &channel_quant_param : output_quant_param) { |
|
|
|
if (output_tensor->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) { |
|
|
|
std::unique_ptr<schema::QuantParamT> output_quant_param_ptr = |
|
|
|
std::make_unique<schema::QuantParamT>(channel_quant_param); |
|
|
|
MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale |
|
|
|
<< " zp: " << output_quant_param_ptr->zeroPoint; |
|
|
|
output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
auto first_output_index = dst_node->outputIndex[0]; |
|
|
|
auto first_tensor_output = meta_graph->allTensors[first_output_index].get(); |
|
|
|
if (dst_node->quantType == schema::QuantType_PostTraining) { |
|
|
|
if (node_type != schema::PrimitiveType_QuantDTypeCast) { |
|
|
|
tensor_output->dataType = kNumberTypeInt8; |
|
|
|
first_tensor_output->dataType = kNumberTypeInt8; |
|
|
|
} else { |
|
|
|
MS_ASSERT(utils::isa<std::shared_ptr<QuantDTypeCast>>(primitive)); |
|
|
|
auto primc = utils::cast<std::shared_ptr<QuantDTypeCast>>(primitive); |
|
|
|
MS_ASSERT(primc != nullptr); |
|
|
|
if (primc->GetDstT() != kNumberTypeFloat32) { |
|
|
|
tensor_output->dataType = kNumberTypeInt8; |
|
|
|
first_tensor_output->dataType = kNumberTypeInt8; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|