|
|
|
@@ -148,6 +148,23 @@ WeightPtr Util::MakeWeightPtr(const std::shared_ptr<std::vector<float>> &data, b |
|
|
|
return weight_ptr; |
|
|
|
} |
|
|
|
|
|
|
|
std::string Util::GetPrimitiveName(const CNodePtr &cnode) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
if (inputs.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Inputs of node " << cnode->fullname_with_scope() << " is empty."; |
|
|
|
return ""; |
|
|
|
} |
|
|
|
auto fn = inputs[0]; |
|
|
|
if (!IsValueNode<Primitive>(fn)) { |
|
|
|
return ""; |
|
|
|
} |
|
|
|
|
|
|
|
auto node_prim = GetValueNode<PrimitivePtr>(fn); |
|
|
|
MS_EXCEPTION_IF_NULL(node_prim); |
|
|
|
return node_prim->name(); |
|
|
|
} |
|
|
|
|
|
|
|
void Util::DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name, |
|
|
|
const std::string &fused_cnode_name) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
@@ -158,7 +175,7 @@ void Util::DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_nam |
|
|
|
std::vector<int64_t> indices; |
|
|
|
for (const AnfNodePtr &node : node_list) { |
|
|
|
if (node != nullptr && node->isa<CNode>()) { |
|
|
|
if (AnfAlgo::GetCNodeName(node) == cnode_name) { |
|
|
|
if (GetPrimitiveName(node->cast<CNodePtr>()) == cnode_name) { |
|
|
|
single_nodes.push_back(node); |
|
|
|
|
|
|
|
auto weight_name_value_node = |
|
|
|
|