Browse Source

extend inputs of update if which have multiple outputs

pull/15943/head
wenfangpei 5 years ago
parent
commit
62cc2990a6
2 changed files with 30 additions and 1 deletions
  1. +29
    -1
      mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.cc
  2. +1
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.h

+ 29
- 1
mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.cc View File

@@ -50,6 +50,32 @@ AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
return result;
}

AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdate(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph) {
AnfNodePtrList result;
for (auto node : nodes) {
if (node->abstract()->isa<abstract::AbstractTuple>()) {
auto node_abstract = node->abstract()->cast<abstract::AbstractTuplePtr>()->elements();
auto num = node_abstract.size();
for (size_t i = 0; i < num; i++) {
auto idx_val = SizeToLong(i);

auto idx = NewValueNode(idx_val);
MS_EXCEPTION_IF_NULL(idx);
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(idx_val));

auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
MS_EXCEPTION_IF_NULL(tuple_getitem);
tuple_getitem->set_fullname_with_scope(node->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i));
tuple_getitem->set_abstract(node_abstract[i]);
tuple_getitem->set_kernel_info(std::make_shared<device::KernelInfo>());
result.push_back(tuple_getitem);
}
} else {
result.push_back(node);
}
}
return result;
}
bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
auto todos = GetUpdateStateList(func_graph);
bool changed = false;
@@ -58,6 +84,8 @@ bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() <= kUpdateStateRealInput) continue;
auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
// extend inputs of update if which have multiple outputs
inputs = ExtendInputsOfUpdate(inputs, func_graph);
if (inputs.size() + 2 != cnode->size() || inputs[0] != cnode->input(2)) {
AnfNodePtrList node_inputs = {cnode->input(0), cnode->input(1)};
node_inputs.insert(node_inputs.end(), inputs.begin(), inputs.end());
@@ -81,7 +109,7 @@ bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) {
for (auto node : todos) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() <= kUpdateStateRealInput) continue;
if (cnode->size() <= kUpdateStateRealInput + 1) continue;
AnfNodePtrList mt_inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
AbstractBasePtrList abs_list;
std::transform(mt_inputs.begin(), mt_inputs.end(), std::back_inserter(abs_list),


+ 1
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/update_state_formatter.h View File

@@ -39,6 +39,7 @@ class SpreadUpdateState : public Pass {
public:
SpreadUpdateState() : Pass("spread_update_state") {}
~SpreadUpdateState() override = default;
AnfNodePtrList ExtendInputsOfUpdate(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph);
bool Run(const FuncGraphPtr &func_graph) override;
};



Loading…
Cancel
Save