|
|
|
@@ -92,6 +92,19 @@ bool IsDynamicShapeGraph(FuncGraphPtr func_graph) { |
|
|
|
[](const AnfNodePtr &node) { return AnfAlgo::IsNodeDynamicShape(node); }); |
|
|
|
} |
|
|
|
|
|
|
|
bool ExistControlNode(FuncGraphPtr func_graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
auto node_list = TopoSort(func_graph->get_return()); |
|
|
|
std::vector<PrimitivePtr> control_ops = {prim::kPrimSwitch, prim::kPrimCall, prim::kPrimSwitchLayer}; |
|
|
|
for (auto &node : node_list) { |
|
|
|
if (std::any_of(control_ops.begin(), control_ops.end(), |
|
|
|
[&](PrimitivePtr prim) { return AnfAlgo::CheckPrimitiveType(node, prim); })) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
bool EnableMindRTForAscendSubGraph(const FuncGraphManagerPtr manager, FuncGraphPtr func_graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
@@ -100,17 +113,21 @@ bool EnableMindRTForAscendSubGraph(const FuncGraphManagerPtr manager, FuncGraphP |
|
|
|
std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); |
|
|
|
auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK); |
|
|
|
std::string backend = context_ptr->backend_policy(); |
|
|
|
auto graphs = manager->func_graphs(); |
|
|
|
bool exist_while = |
|
|
|
std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); }); |
|
|
|
bool exist_ctrl = exist_while || ExistControlNode(func_graph); |
|
|
|
if (!func_graph->ContainMultiTarget() && task_sink && |
|
|
|
context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) { |
|
|
|
auto graphs = manager->func_graphs(); |
|
|
|
bool exist_while = |
|
|
|
std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); }); |
|
|
|
if (device_target == kAscendDevice && backend != kMsVm && !exist_while) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
if (device_target == kAscendDevice && func_graph->ContainMultiTarget() && !IsDynamicShapeGraph(func_graph)) { |
|
|
|
return true; |
|
|
|
if (func_graph->ContainMultiTarget() && task_sink && |
|
|
|
context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) { |
|
|
|
if (device_target == kAscendDevice && !IsDynamicShapeGraph(func_graph) && backend != kMsVm && !exist_ctrl) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|