| @@ -39,10 +39,7 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector | |||
| STATUS QuantCast::Run(FuncGraphPtr graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| auto cnodes = graph->GetOrderedCnodes(); | |||
| bool first = true; | |||
| for (auto &cnode : cnodes) { | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| auto curnode_quant_type = schema::QuantType_QUANT_NONE; | |||
| @@ -51,34 +48,30 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { | |||
| } else { | |||
| curnode_quant_type = primitive_c->GetQuantType(); | |||
| } | |||
| if (first) { | |||
| if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { | |||
| auto value_node = | |||
| NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitive_c->GetInputQuantParams().front()); | |||
| std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)}; | |||
| auto quant_cast_cnode = graph->NewCNode(op_inputs); | |||
| quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); | |||
| cnode->set_input(1, quant_cast_cnode); | |||
| MS_LOG(DEBUG) << "Add quant cast at front. " | |||
| << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type; | |||
| } | |||
| first = false; | |||
| continue; | |||
| } | |||
| for (size_t i = 1; i < cnode->inputs().size(); i++) { | |||
| auto input_node = cnode->input(i); | |||
| if (!input_node->isa<CNode>()) { | |||
| continue; | |||
| auto is_graph_input = false; | |||
| if (input_node->isa<Parameter>()) { | |||
| if (!input_node->cast<ParameterPtr>()->has_default()) { | |||
| is_graph_input = true; | |||
| } | |||
| } | |||
| auto input_cnode = std::dynamic_pointer_cast<CNode>(input_node); | |||
| auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0)); | |||
| if (input_cnode_primitive_c == nullptr) { | |||
| MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " | |||
| << " PrimitiveC is null"; | |||
| if (!input_node->isa<CNode>() && !is_graph_input) { | |||
| continue; | |||
| } | |||
| auto input_cnode_quant_type = input_cnode_primitive_c->GetQuantType(); | |||
| auto input_cnode_quant_type = schema::QuantType_QUANT_NONE; | |||
| std::shared_ptr<PrimitiveC> input_cnode_primitive_c = nullptr; | |||
| if (!is_graph_input) { | |||
| auto input_cnode = std::dynamic_pointer_cast<CNode>(input_node); | |||
| input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0)); | |||
| if (input_cnode_primitive_c == nullptr) { | |||
| MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " | |||
| << " PrimitiveC is null"; | |||
| continue; | |||
| } | |||
| input_cnode_quant_type = input_cnode_primitive_c->GetQuantType(); | |||
| } | |||
| if (curnode_quant_type != input_cnode_quant_type) { | |||
| ValueNodePtr value_node = nullptr; | |||
| @@ -94,22 +87,22 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { | |||
| if (value_node == nullptr) { | |||
| MS_LOG(WARNING) << "value_node is null! " | |||
| << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " | |||
| << " input_" << i << ": " << input_cnode->fullname_with_scope() | |||
| << " input_" << i << ": " | |||
| << " quant_type:" << input_cnode_quant_type; | |||
| continue; | |||
| } | |||
| std::vector<AnfNodePtr> op_inputs = {value_node, input_cnode}; | |||
| std::vector<AnfNodePtr> op_inputs = {value_node, input_node}; | |||
| auto quant_cast_cnode = graph->NewCNode(op_inputs); | |||
| quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast_" + std::to_string(i)); | |||
| cnode->set_input(i, quant_cast_cnode); | |||
| MS_LOG(DEBUG) << "Add quant cast. " | |||
| << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type | |||
| << " input_" << i << ": " << input_cnode->fullname_with_scope() | |||
| << " input_" << i << ": " | |||
| << " quant_type:" << input_cnode_quant_type; | |||
| } else { | |||
| MS_LOG(DEBUG) << "No need to add quant cast. " | |||
| << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type | |||
| << " input_" << i << ": " << input_cnode->fullname_with_scope() | |||
| << " input_" << i << ": " | |||
| << " quant_type:" << input_cnode_quant_type; | |||
| } | |||
| } | |||