|
|
|
@@ -136,6 +136,39 @@ std::vector<AnfNodePtr> ReorderVirtualNode(const std::vector<AnfNodePtr> &nodes, |
|
|
|
return result;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> GetNextNodes(const AnfNodePtr &node, std::map<AnfNodePtr, size_t> *nodes_ref,
|
|
|
|
std::vector<AnfNodePtr> *result) {
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
MS_EXCEPTION_IF_NULL(nodes_ref);
|
|
|
|
MS_EXCEPTION_IF_NULL(result);
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
auto node_inputs = cnode->inputs();
|
|
|
|
if (!IsPrimitiveCNode(node, prim::kPrimSwitch)) {
|
|
|
|
std::reverse(node_inputs.begin(), node_inputs.end());
|
|
|
|
return node_inputs;
|
|
|
|
}
|
|
|
|
std::vector<AnfNodePtr> extend_inputs;
|
|
|
|
for (auto &input : node_inputs) {
|
|
|
|
MS_EXCEPTION_IF_NULL(input);
|
|
|
|
if (IsPrimitiveCNode(input, prim::kPrimPartial)) {
|
|
|
|
auto iter = nodes_ref->find(input);
|
|
|
|
if (iter != nodes_ref->end() && iter->second == 1) {
|
|
|
|
iter->second--;
|
|
|
|
result->emplace_back(input);
|
|
|
|
auto partial_cnode = input->cast<CNodePtr>();
|
|
|
|
MS_EXCEPTION_IF_NULL(partial_cnode);
|
|
|
|
auto partial_inputs = partial_cnode->inputs();
|
|
|
|
std::reverse(partial_inputs.begin(), partial_inputs.end());
|
|
|
|
(void)extend_inputs.insert(extend_inputs.end(), partial_inputs.begin(), partial_inputs.end());
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
extend_inputs.emplace_back(input);
|
|
|
|
}
|
|
|
|
return extend_inputs;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
std::vector<AnfNodePtr> result;
|
|
|
|
@@ -158,13 +191,8 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string & |
|
|
|
if (!node->isa<CNode>()) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
auto node_inputs = cnode->inputs();
|
|
|
|
if (!IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
|
|
|
|
std::reverse(node_inputs.begin(), node_inputs.end());
|
|
|
|
}
|
|
|
|
for (auto &input : node_inputs) {
|
|
|
|
auto next_nodes = GetNextNodes(node, &nodes_ref, &result);
|
|
|
|
for (auto &input : next_nodes) {
|
|
|
|
MS_EXCEPTION_IF_NULL(input);
|
|
|
|
auto iter = nodes_ref.find(input);
|
|
|
|
if (iter != nodes_ref.end()) {
|
|
|
|
@@ -621,8 +649,9 @@ std::vector<GraphSegmentPtr> GraphPartition::Partition(const FuncGraphPtr &graph |
|
|
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
|
|
auto enable_loop_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK);
|
|
|
|
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
|
|
|
if (contain_multi_target) {
|
|
|
|
if (contain_multi_target || !enable_loop_sink) {
|
|
|
|
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT)) {
|
|
|
|
auto other_target = GetOtherTarget(nodes);
|
|
|
|
nodes = ParallelSort(graph, default_target, other_target);
|
|
|
|
|