Browse Source

Support the params passed between subgraphs as tuple type

tags/v1.2.0-rc1
liangzelang 4 years ago
parent
commit
d99d745c0c
1 changed files with 5 additions and 5 deletions
  1. +5
    -5
      mindspore/ccsrc/backend/session/ascend_auto_monad.cc

+ 5
- 5
mindspore/ccsrc/backend/session/ascend_auto_monad.cc View File

@@ -673,7 +673,7 @@ class AscendAutoMonadConverter {
// No assign for single monad argument, return it. // No assign for single monad argument, return it.
return value; return value;
} }
return Assign(paras.front(), value, true);
return AssignAll(paras.front(), value, true);
} }
// Multi arguments. // Multi arguments.
AnfNodePtrList tuple_inputs; AnfNodePtrList tuple_inputs;
@@ -691,7 +691,7 @@ class AscendAutoMonadConverter {
if (target == value) { if (target == value) {
continue; continue;
} }
tuple_inputs.emplace_back(Assign(target, value, true));
tuple_inputs.emplace_back(AssignAll(target, value, true));
} }
return kernel_graph_->NewCNode(tuple_inputs); return kernel_graph_->NewCNode(tuple_inputs);
} }
@@ -721,10 +721,10 @@ class AscendAutoMonadConverter {
} }


// AissgnAll support tuple to tuple assign. // AissgnAll support tuple to tuple assign.
AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source) {
AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) {
if (!AnfAlgo::CheckPrimitiveType(target, prim::kPrimMakeTuple)) { if (!AnfAlgo::CheckPrimitiveType(target, prim::kPrimMakeTuple)) {
// Assign single value. // Assign single value.
return Assign(target, source);
return Assign(target, source, is_link);
} }
// Assign tuple. // Assign tuple.
std::vector<AnfNodePtr> targets = AnfAlgo::GetAllOutput(target, {prim::kPrimTupleGetItem}); std::vector<AnfNodePtr> targets = AnfAlgo::GetAllOutput(target, {prim::kPrimTupleGetItem});
@@ -736,7 +736,7 @@ class AscendAutoMonadConverter {
tuple_inputs.reserve(targets.size() + 1); tuple_inputs.reserve(targets.size() + 1);
tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (size_t i = 0; i < targets.size(); ++i) { for (size_t i = 0; i < targets.size(); ++i) {
tuple_inputs.emplace_back(Assign(targets[i], sources[i]));
tuple_inputs.emplace_back(Assign(targets[i], sources[i], is_link));
} }
return kernel_graph_->NewCNode(tuple_inputs); return kernel_graph_->NewCNode(tuple_inputs);
} }


Loading…
Cancel
Save