| @@ -23,10 +23,10 @@ | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "base/core_ops.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "base/core_ops.h" | |||
| namespace mindspore { | |||
| // namespace to support intermediate representation definition | |||
| @@ -191,6 +191,41 @@ std::string get_id(const AnfNodePtr &node) { | |||
| void reset_id() { node_ids.clear(); } | |||
| } // namespace id_generator | |||
| namespace { | |||
| std::string GetMaketupleNodeTarget(const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto func_graph = cnode->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto users = manager->node_users()[cnode]; | |||
| std::string first_user_target = GetCNodeTarget(users.back().first); | |||
| bool is_used_by_different_target = | |||
| std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) -> bool { | |||
| return GetCNodeTarget(u.first) != first_user_target; | |||
| }); | |||
| if (!is_used_by_different_target) { | |||
| return first_user_target; | |||
| } | |||
| auto inputs = cnode->inputs(); | |||
| std::vector<AnfNodePtr> real_inputs; | |||
| std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs)); | |||
| std::string first_input_target = GetCNodeTarget(real_inputs[0]); | |||
| bool is_from_different_target = | |||
| std::any_of(std::begin(real_inputs), std::end(real_inputs), | |||
| [&first_input_target](const AnfNodePtr &n) -> bool { return GetCNodeTarget(n) != first_input_target; }); | |||
| if (!is_from_different_target) { | |||
| return first_input_target; | |||
| } | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| std::string default_target = context_ptr->device_target(); | |||
| return default_target; | |||
| } | |||
| } // namespace | |||
| std::string GetCNodeTarget(const AnfNodePtr &node) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| @@ -220,10 +255,9 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { | |||
| if (att_target != nullptr) { | |||
| if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || | |||
| IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || | |||
| IsPrimitive(attr_input, prim::kPrimMakeTuple) || IsPrimitive(attr_input, prim::kPrimStateSetItem) || | |||
| IsPrimitive(attr_input, prim::kPrimDepend) || IsPrimitive(attr_input, prim::kPrimTupleGetItem) || | |||
| IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) || | |||
| IsPrimitive(attr_input, prim::kPrimPartial)) { | |||
| IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || | |||
| IsPrimitive(attr_input, prim::kPrimTupleGetItem) || IsPrimitive(attr_input, prim::kPrimControlDepend) || | |||
| IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimPartial)) { | |||
| primitive->EraseAttr("primitive_target"); | |||
| return default_target; | |||
| } | |||
| @@ -236,6 +270,9 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { | |||
| } | |||
| return target; | |||
| } | |||
| if (IsPrimitive(node, prim::kPrimMakeTuple)) { | |||
| return GetMaketupleNodeTarget(cnode); | |||
| } | |||
| return default_target; | |||
| } | |||
| } // namespace mindspore | |||