diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 40b630a7b9..34ec1ff346 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -92,6 +92,19 @@ bool IsDynamicShapeGraph(FuncGraphPtr func_graph) { [](const AnfNodePtr &node) { return AnfAlgo::IsNodeDynamicShape(node); }); } +bool ExistControlNode(FuncGraphPtr func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto node_list = TopoSort(func_graph->get_return()); + std::vector control_ops = {prim::kPrimSwitch, prim::kPrimCall, prim::kPrimSwitchLayer}; + for (auto &node : node_list) { + if (std::any_of(control_ops.begin(), control_ops.end(), + [&](PrimitivePtr prim) { return AnfAlgo::CheckPrimitiveType(node, prim); })) { + return true; + } + } + return false; +} + bool EnableMindRTForAscendSubGraph(const FuncGraphManagerPtr manager, FuncGraphPtr func_graph) { MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(func_graph); @@ -100,17 +113,21 @@ bool EnableMindRTForAscendSubGraph(const FuncGraphManagerPtr manager, FuncGraphP std::string device_target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); auto task_sink = context_ptr->get_param(MS_CTX_ENABLE_TASK_SINK); std::string backend = context_ptr->backend_policy(); + auto graphs = manager->func_graphs(); + bool exist_while = + std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); }); + bool exist_ctrl = exist_while || ExistControlNode(func_graph); if (!func_graph->ContainMultiTarget() && task_sink && context_ptr->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { - auto graphs = manager->func_graphs(); - bool exist_while = - std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); }); if (device_target == kAscendDevice && backend != kMsVm && !exist_while) { return true; } } - if (device_target == kAscendDevice && func_graph->ContainMultiTarget() && !IsDynamicShapeGraph(func_graph)) { - return true; + if (func_graph->ContainMultiTarget() && task_sink && + context_ptr->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { + if (device_target == kAscendDevice && !IsDynamicShapeGraph(func_graph) && backend != kMsVm && !exist_ctrl) { + return true; + } } return false; } diff --git a/mindspore/ccsrc/runtime/framework/actor/output_actor.cc b/mindspore/ccsrc/runtime/framework/actor/output_actor.cc index 9b456c18d9..348f81abc6 100644 --- a/mindspore/ccsrc/runtime/framework/actor/output_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/output_actor.cc @@ -121,7 +121,7 @@ TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t MS_LOG(ERROR) << "The output position is of range: " << output_position; return nullptr; } - auto device_context = device_contexts_[output_position]; + auto &device_context = device_contexts_[output_position]; MS_EXCEPTION_IF_NULL(device_context); if (device_context->GetDeviceAddressType() != device_tensor->DeviceType()) { auto old_device_context = device_context;