| @@ -23,10 +23,10 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include "base/core_ops.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| #include "utils/context/ms_context.h" | #include "utils/context/ms_context.h" | ||||
| #include "base/core_ops.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // namespace to support intermediate representation definition | // namespace to support intermediate representation definition | ||||
| @@ -191,6 +191,41 @@ std::string get_id(const AnfNodePtr &node) { | |||||
| void reset_id() { node_ids.clear(); } | void reset_id() { node_ids.clear(); } | ||||
| } // namespace id_generator | } // namespace id_generator | ||||
| namespace { | |||||
| std::string GetMaketupleNodeTarget(const CNodePtr &cnode) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto func_graph = cnode->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| auto manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto users = manager->node_users()[cnode]; | |||||
| std::string first_user_target = GetCNodeTarget(users.back().first); | |||||
| bool is_used_by_different_target = | |||||
| std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) -> bool { | |||||
| return GetCNodeTarget(u.first) != first_user_target; | |||||
| }); | |||||
| if (!is_used_by_different_target) { | |||||
| return first_user_target; | |||||
| } | |||||
| auto inputs = cnode->inputs(); | |||||
| std::vector<AnfNodePtr> real_inputs; | |||||
| std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs)); | |||||
| std::string first_input_target = GetCNodeTarget(real_inputs[0]); | |||||
| bool is_from_different_target = | |||||
| std::any_of(std::begin(real_inputs), std::end(real_inputs), | |||||
| [&first_input_target](const AnfNodePtr &n) -> bool { return GetCNodeTarget(n) != first_input_target; }); | |||||
| if (!is_from_different_target) { | |||||
| return first_input_target; | |||||
| } | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| std::string default_target = context_ptr->device_target(); | |||||
| return default_target; | |||||
| } | |||||
| } // namespace | |||||
| std::string GetCNodeTarget(const AnfNodePtr &node) { | std::string GetCNodeTarget(const AnfNodePtr &node) { | ||||
| auto context_ptr = MsContext::GetInstance(); | auto context_ptr = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context_ptr); | MS_EXCEPTION_IF_NULL(context_ptr); | ||||
| @@ -220,10 +255,9 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { | |||||
| if (att_target != nullptr) { | if (att_target != nullptr) { | ||||
| if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || | if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || | ||||
| IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || | IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || | ||||
| IsPrimitive(attr_input, prim::kPrimMakeTuple) || IsPrimitive(attr_input, prim::kPrimStateSetItem) || | |||||
| IsPrimitive(attr_input, prim::kPrimDepend) || IsPrimitive(attr_input, prim::kPrimTupleGetItem) || | |||||
| IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) || | |||||
| IsPrimitive(attr_input, prim::kPrimPartial)) { | |||||
| IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || | |||||
| IsPrimitive(attr_input, prim::kPrimTupleGetItem) || IsPrimitive(attr_input, prim::kPrimControlDepend) || | |||||
| IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimPartial)) { | |||||
| primitive->EraseAttr("primitive_target"); | primitive->EraseAttr("primitive_target"); | ||||
| return default_target; | return default_target; | ||||
| } | } | ||||
| @@ -236,6 +270,9 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { | |||||
| } | } | ||||
| return target; | return target; | ||||
| } | } | ||||
| if (IsPrimitive(node, prim::kPrimMakeTuple)) { | |||||
| return GetMaketupleNodeTarget(cnode); | |||||
| } | |||||
| return default_target; | return default_target; | ||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||