|
|
|
@@ -170,51 +170,59 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsAllGraphInValueSequence(const std::vector<ValuePtr> &value_vec) { |
|
|
|
for (auto &elem : value_vec) { |
|
|
|
if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) { |
|
|
|
const auto &vec = GetValue<std::vector<ValuePtr>>(elem); |
|
|
|
auto is_graph = IsAllGraphInValueSequence(vec); |
|
|
|
if (!is_graph) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} else if (!elem->isa<FuncGraph>()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, |
|
|
|
const std::vector<ValuePtr> &value_vec) { |
|
|
|
std::vector<AnfNodePtr> nodes; |
|
|
|
nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); |
|
|
|
for (auto &elem : value_vec) { |
|
|
|
AnfNodePtr node = nullptr; |
|
|
|
if (elem->isa<ValueTuple>() || elem->isa<ValueList>()) { |
|
|
|
const auto &vec = GetValue<std::vector<ValuePtr>>(elem); |
|
|
|
node = TransformToMakeTupleNodes(manager, func_graph, vec); |
|
|
|
} else if (elem->isa<FuncGraph>()) { |
|
|
|
FuncGraphPtr new_fg = elem->cast<FuncGraphPtr>(); |
|
|
|
manager->AddFuncGraph(new_fg); |
|
|
|
node = NewValueNode(new_fg); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expect funcgraph, got " << elem->ToString(); |
|
|
|
} |
|
|
|
nodes.emplace_back(node); |
|
|
|
} |
|
|
|
auto cnode = func_graph->NewCNode(nodes); |
|
|
|
return cnode; |
|
|
|
} |
|
|
|
|
|
|
|
// transform the ValueTuple or ValueList of graph node to make tuple of const graph node |
|
|
|
bool TransformVectorGraphValueNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, |
|
|
|
bool TransformVectorGraphValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, |
|
|
|
const ValueNodePtr &value_node, AnfNodePtr *const transformed) { |
|
|
|
MS_EXCEPTION_IF_NULL(value_node); |
|
|
|
const auto &value_vec = GetValue<std::vector<ValuePtr>>(value_node->value()); |
|
|
|
bool has_graph_in_list = false; |
|
|
|
for (auto &elemv : value_vec) { |
|
|
|
MS_EXCEPTION_IF_NULL(elemv); |
|
|
|
if (elemv->isa<FuncGraph>()) { |
|
|
|
FuncGraphPtr new_fg = elemv->cast<FuncGraphPtr>(); |
|
|
|
manager->AddFuncGraph(new_fg); |
|
|
|
has_graph_in_list = true; |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (has_graph_in_list) { |
|
|
|
MS_LOG(EXCEPTION) << "List has graph in it, but not all is graph"; |
|
|
|
} |
|
|
|
if (!IsAllGraphInValueSequence(value_vec)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// The celllist or ordered_cell will be parsed as valuetuple of const graph in it, |
|
|
|
// So if has graph in list, try to replace the node with make tuple of graph value node. |
|
|
|
if (has_graph_in_list) { |
|
|
|
// change the vector of graph to be make_list of graph value node |
|
|
|
std::vector<AnfNodePtr> list_vec; |
|
|
|
auto make_list_op = NewValueNode(prim::kPrimMakeTuple); |
|
|
|
list_vec.emplace_back(make_list_op); |
|
|
|
(void)std::transform(std::begin(value_vec), std::end(value_vec), std::back_inserter(list_vec), |
|
|
|
[](const ValuePtr &value) { return NewValueNode(value); }); |
|
|
|
FuncGraphPtr cnode_graph = nullptr; |
|
|
|
auto users = manager->node_users()[node]; |
|
|
|
for (auto &use : users) { |
|
|
|
auto use_node = use.first; |
|
|
|
MS_EXCEPTION_IF_NULL(use_node); |
|
|
|
if (use_node->isa<CNode>()) { |
|
|
|
cnode_graph = use_node->func_graph(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (cnode_graph) { |
|
|
|
CNodePtr list_app = cnode_graph->NewCNode(list_vec); |
|
|
|
// replace the ret ptr to be make_list of graph value node |
|
|
|
*transformed = list_app; |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Can not find apply for node use when replacing node of vector of graph"; |
|
|
|
} |
|
|
|
} |
|
|
|
// we do this because the graphmanger won't investigate the graph inside valuetuple, |
|
|
|
// change the vector of graph to be make_tuple of graph value node |
|
|
|
auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec); |
|
|
|
// replace the ret ptr to be make tuple of graph value node |
|
|
|
*transformed = node_tuple_graphs; |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -245,7 +253,8 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr |
|
|
|
|
|
|
|
// if the constant node is constant of vector of graph ,add graph to manager |
|
|
|
if (IsValueNode<ValueTuple>(resolved_node) || IsValueNode<ValueList>(resolved_node)) { |
|
|
|
(void)TransformVectorGraphValueNode(manager, node, resolved_node->cast<ValueNodePtr>(), &resolved_node); |
|
|
|
(void)TransformVectorGraphValueNode(manager, node->func_graph(), resolved_node->cast<ValueNodePtr>(), |
|
|
|
&resolved_node); |
|
|
|
} |
|
|
|
|
|
|
|
TraceManager::EndTrace(); |
|
|
|
|