diff --git a/mindspore/lite/tools/converter/quantizer/quant_cast.cc b/mindspore/lite/tools/converter/quantizer/quant_cast.cc index d6775146c0..a6c0085fae 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_cast.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_cast.cc @@ -39,10 +39,7 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector STATUS QuantCast::Run(FuncGraphPtr graph) { MS_ASSERT(graph != nullptr); - auto cnodes = graph->GetOrderedCnodes(); - bool first = true; - for (auto &cnode : cnodes) { auto primitive_c = GetValueNode>(cnode->input(0)); auto curnode_quant_type = schema::QuantType_QUANT_NONE; @@ -51,34 +48,30 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { } else { curnode_quant_type = primitive_c->GetQuantType(); } - if (first) { - if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { - auto value_node = - NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitive_c->GetInputQuantParams().front()); - std::vector op_inputs = {value_node, cnode->input(1)}; - auto quant_cast_cnode = graph->NewCNode(op_inputs); - quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); - cnode->set_input(1, quant_cast_cnode); - MS_LOG(DEBUG) << "Add quant cast at front. " - << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type; - } - first = false; - continue; - } for (size_t i = 1; i < cnode->inputs().size(); i++) { auto input_node = cnode->input(i); - if (!input_node->isa()) { - continue; + auto is_graph_input = false; + if (input_node->isa()) { + if (!input_node->cast()->has_default()) { + is_graph_input = true; + } } - auto input_cnode = std::dynamic_pointer_cast(input_node); - auto input_cnode_primitive_c = GetValueNode>(input_cnode->input(0)); - if (input_cnode_primitive_c == nullptr) { - MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " - << " PrimitiveC is null"; + if (!input_node->isa() && !is_graph_input) { continue; } - auto input_cnode_quant_type = input_cnode_primitive_c->GetQuantType(); + auto input_cnode_quant_type = schema::QuantType_QUANT_NONE; + std::shared_ptr input_cnode_primitive_c = nullptr; + if (!is_graph_input) { + auto input_cnode = std::dynamic_pointer_cast(input_node); + input_cnode_primitive_c = GetValueNode>(input_cnode->input(0)); + if (input_cnode_primitive_c == nullptr) { + MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " + << " PrimitiveC is null"; + continue; + } + input_cnode_quant_type = input_cnode_primitive_c->GetQuantType(); + } if (curnode_quant_type != input_cnode_quant_type) { ValueNodePtr value_node = nullptr; @@ -94,22 +87,22 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { if (value_node == nullptr) { MS_LOG(WARNING) << "value_node is null! " << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " - << " input_" << i << ": " << input_cnode->fullname_with_scope() + << " input_" << i << ": " << " quant_type:" << input_cnode_quant_type; continue; } - std::vector op_inputs = {value_node, input_cnode}; + std::vector op_inputs = {value_node, input_node}; auto quant_cast_cnode = graph->NewCNode(op_inputs); quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast_" + std::to_string(i)); cnode->set_input(i, quant_cast_cnode); MS_LOG(DEBUG) << "Add quant cast. " << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type - << " input_" << i << ": " << input_cnode->fullname_with_scope() + << " input_" << i << ": " << " quant_type:" << input_cnode_quant_type; } else { MS_LOG(DEBUG) << "No need to add quant cast. " << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type - << " input_" << i << ": " << input_cnode->fullname_with_scope() + << " input_" << i << ": " << " quant_type:" << input_cnode_quant_type; } }