|
|
|
@@ -50,6 +50,32 @@ AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) { |
|
|
|
return result; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdate(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph) { |
|
|
|
AnfNodePtrList result; |
|
|
|
for (auto node : nodes) { |
|
|
|
if (node->abstract()->isa<abstract::AbstractTuple>()) { |
|
|
|
auto node_abstract = node->abstract()->cast<abstract::AbstractTuplePtr>()->elements(); |
|
|
|
auto num = node_abstract.size(); |
|
|
|
for (size_t i = 0; i < num; i++) { |
|
|
|
auto idx_val = SizeToLong(i); |
|
|
|
|
|
|
|
auto idx = NewValueNode(idx_val); |
|
|
|
MS_EXCEPTION_IF_NULL(idx); |
|
|
|
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(idx_val)); |
|
|
|
|
|
|
|
auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); |
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem); |
|
|
|
tuple_getitem->set_fullname_with_scope(node->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i)); |
|
|
|
tuple_getitem->set_abstract(node_abstract[i]); |
|
|
|
tuple_getitem->set_kernel_info(std::make_shared<device::KernelInfo>()); |
|
|
|
result.push_back(tuple_getitem); |
|
|
|
} |
|
|
|
} else { |
|
|
|
result.push_back(node); |
|
|
|
} |
|
|
|
} |
|
|
|
return result; |
|
|
|
} |
|
|
|
bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) { |
|
|
|
auto todos = GetUpdateStateList(func_graph); |
|
|
|
bool changed = false; |
|
|
|
@@ -58,6 +84,8 @@ bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (cnode->size() <= kUpdateStateRealInput) continue; |
|
|
|
auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput); |
|
|
|
// extend inputs of update if which have multiple outputs |
|
|
|
inputs = ExtendInputsOfUpdate(inputs, func_graph); |
|
|
|
if (inputs.size() + 2 != cnode->size() || inputs[0] != cnode->input(2)) { |
|
|
|
AnfNodePtrList node_inputs = {cnode->input(0), cnode->input(1)}; |
|
|
|
node_inputs.insert(node_inputs.end(), inputs.begin(), inputs.end()); |
|
|
|
@@ -81,7 +109,7 @@ bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) { |
|
|
|
for (auto node : todos) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (cnode->size() <= kUpdateStateRealInput) continue; |
|
|
|
if (cnode->size() <= kUpdateStateRealInput + 1) continue; |
|
|
|
AnfNodePtrList mt_inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput); |
|
|
|
AbstractBasePtrList abs_list; |
|
|
|
std::transform(mt_inputs.begin(), mt_inputs.end(), std::back_inserter(abs_list), |
|
|
|
|