| @@ -39,10 +39,7 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector | |||||
| STATUS QuantCast::Run(FuncGraphPtr graph) { | STATUS QuantCast::Run(FuncGraphPtr graph) { | ||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| auto cnodes = graph->GetOrderedCnodes(); | auto cnodes = graph->GetOrderedCnodes(); | ||||
| bool first = true; | |||||
| for (auto &cnode : cnodes) { | for (auto &cnode : cnodes) { | ||||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | ||||
| auto curnode_quant_type = schema::QuantType_QUANT_NONE; | auto curnode_quant_type = schema::QuantType_QUANT_NONE; | ||||
| @@ -51,34 +48,30 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { | |||||
| } else { | } else { | ||||
| curnode_quant_type = primitive_c->GetQuantType(); | curnode_quant_type = primitive_c->GetQuantType(); | ||||
| } | } | ||||
| if (first) { | |||||
| if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { | |||||
| auto value_node = | |||||
| NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitive_c->GetInputQuantParams().front()); | |||||
| std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)}; | |||||
| auto quant_cast_cnode = graph->NewCNode(op_inputs); | |||||
| quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); | |||||
| cnode->set_input(1, quant_cast_cnode); | |||||
| MS_LOG(DEBUG) << "Add quant cast at front. " | |||||
| << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type; | |||||
| } | |||||
| first = false; | |||||
| continue; | |||||
| } | |||||
| for (size_t i = 1; i < cnode->inputs().size(); i++) { | for (size_t i = 1; i < cnode->inputs().size(); i++) { | ||||
| auto input_node = cnode->input(i); | auto input_node = cnode->input(i); | ||||
| if (!input_node->isa<CNode>()) { | |||||
| continue; | |||||
| auto is_graph_input = false; | |||||
| if (input_node->isa<Parameter>()) { | |||||
| if (!input_node->cast<ParameterPtr>()->has_default()) { | |||||
| is_graph_input = true; | |||||
| } | |||||
| } | } | ||||
| auto input_cnode = std::dynamic_pointer_cast<CNode>(input_node); | |||||
| auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0)); | |||||
| if (input_cnode_primitive_c == nullptr) { | |||||
| MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " | |||||
| << " PrimitiveC is null"; | |||||
| if (!input_node->isa<CNode>() && !is_graph_input) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto input_cnode_quant_type = input_cnode_primitive_c->GetQuantType(); | |||||
| auto input_cnode_quant_type = schema::QuantType_QUANT_NONE; | |||||
| std::shared_ptr<PrimitiveC> input_cnode_primitive_c = nullptr; | |||||
| if (!is_graph_input) { | |||||
| auto input_cnode = std::dynamic_pointer_cast<CNode>(input_node); | |||||
| input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0)); | |||||
| if (input_cnode_primitive_c == nullptr) { | |||||
| MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " | |||||
| << " PrimitiveC is null"; | |||||
| continue; | |||||
| } | |||||
| input_cnode_quant_type = input_cnode_primitive_c->GetQuantType(); | |||||
| } | |||||
| if (curnode_quant_type != input_cnode_quant_type) { | if (curnode_quant_type != input_cnode_quant_type) { | ||||
| ValueNodePtr value_node = nullptr; | ValueNodePtr value_node = nullptr; | ||||
| @@ -94,22 +87,22 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { | |||||
| if (value_node == nullptr) { | if (value_node == nullptr) { | ||||
| MS_LOG(WARNING) << "value_node is null! " | MS_LOG(WARNING) << "value_node is null! " | ||||
| << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " | << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " | ||||
| << " input_" << i << ": " << input_cnode->fullname_with_scope() | |||||
| << " input_" << i << ": " | |||||
| << " quant_type:" << input_cnode_quant_type; | << " quant_type:" << input_cnode_quant_type; | ||||
| continue; | continue; | ||||
| } | } | ||||
| std::vector<AnfNodePtr> op_inputs = {value_node, input_cnode}; | |||||
| std::vector<AnfNodePtr> op_inputs = {value_node, input_node}; | |||||
| auto quant_cast_cnode = graph->NewCNode(op_inputs); | auto quant_cast_cnode = graph->NewCNode(op_inputs); | ||||
| quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast_" + std::to_string(i)); | quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast_" + std::to_string(i)); | ||||
| cnode->set_input(i, quant_cast_cnode); | cnode->set_input(i, quant_cast_cnode); | ||||
| MS_LOG(DEBUG) << "Add quant cast. " | MS_LOG(DEBUG) << "Add quant cast. " | ||||
| << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type | << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type | ||||
| << " input_" << i << ": " << input_cnode->fullname_with_scope() | |||||
| << " input_" << i << ": " | |||||
| << " quant_type:" << input_cnode_quant_type; | << " quant_type:" << input_cnode_quant_type; | ||||
| } else { | } else { | ||||
| MS_LOG(DEBUG) << "No need to add quant cast. " | MS_LOG(DEBUG) << "No need to add quant cast. " | ||||
| << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type | << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type | ||||
| << " input_" << i << ": " << input_cnode->fullname_with_scope() | |||||
| << " input_" << i << ": " | |||||
| << " quant_type:" << input_cnode_quant_type; | << " quant_type:" << input_cnode_quant_type; | ||||
| } | } | ||||
| } | } | ||||