|
|
|
@@ -23,10 +23,10 @@ |
|
|
|
#include <vector> |
|
|
|
#include <unordered_map> |
|
|
|
|
|
|
|
#include "base/core_ops.h" |
|
|
|
#include "ir/func_graph.h" |
|
|
|
#include "ir/primitive.h" |
|
|
|
#include "utils/context/ms_context.h" |
|
|
|
#include "base/core_ops.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
// namespace to support intermediate representation definition |
|
|
|
@@ -191,6 +191,41 @@ std::string get_id(const AnfNodePtr &node) { |
|
|
|
void reset_id() { node_ids.clear(); } |
|
|
|
} // namespace id_generator |
|
|
|
|
|
|
|
namespace { |
|
|
|
std::string GetMaketupleNodeTarget(const CNodePtr &cnode) { |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
auto func_graph = cnode->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
auto manager = func_graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
auto users = manager->node_users()[cnode]; |
|
|
|
std::string first_user_target = GetCNodeTarget(users.back().first); |
|
|
|
bool is_used_by_different_target = |
|
|
|
std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) -> bool { |
|
|
|
return GetCNodeTarget(u.first) != first_user_target; |
|
|
|
}); |
|
|
|
if (!is_used_by_different_target) { |
|
|
|
return first_user_target; |
|
|
|
} |
|
|
|
|
|
|
|
auto inputs = cnode->inputs(); |
|
|
|
std::vector<AnfNodePtr> real_inputs; |
|
|
|
std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs)); |
|
|
|
std::string first_input_target = GetCNodeTarget(real_inputs[0]); |
|
|
|
bool is_from_different_target = |
|
|
|
std::any_of(std::begin(real_inputs), std::end(real_inputs), |
|
|
|
[&first_input_target](const AnfNodePtr &n) -> bool { return GetCNodeTarget(n) != first_input_target; }); |
|
|
|
if (!is_from_different_target) { |
|
|
|
return first_input_target; |
|
|
|
} |
|
|
|
|
|
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
std::string default_target = context_ptr->device_target(); |
|
|
|
return default_target; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
std::string GetCNodeTarget(const AnfNodePtr &node) { |
|
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
@@ -220,10 +255,9 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { |
|
|
|
if (att_target != nullptr) { |
|
|
|
if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || |
|
|
|
IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || |
|
|
|
IsPrimitive(attr_input, prim::kPrimMakeTuple) || IsPrimitive(attr_input, prim::kPrimStateSetItem) || |
|
|
|
IsPrimitive(attr_input, prim::kPrimDepend) || IsPrimitive(attr_input, prim::kPrimTupleGetItem) || |
|
|
|
IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) || |
|
|
|
IsPrimitive(attr_input, prim::kPrimPartial)) { |
|
|
|
IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || |
|
|
|
IsPrimitive(attr_input, prim::kPrimTupleGetItem) || IsPrimitive(attr_input, prim::kPrimControlDepend) || |
|
|
|
IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimPartial)) { |
|
|
|
primitive->EraseAttr("primitive_target"); |
|
|
|
return default_target; |
|
|
|
} |
|
|
|
@@ -236,6 +270,9 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { |
|
|
|
} |
|
|
|
return target; |
|
|
|
} |
|
|
|
if (IsPrimitive(node, prim::kPrimMakeTuple)) { |
|
|
|
return GetMaketupleNodeTarget(cnode); |
|
|
|
} |
|
|
|
return default_target; |
|
|
|
} |
|
|
|
} // namespace mindspore |