Browse Source

!28504 fix sub_graph_sink hete bug

Merge pull request !28504 from baihuawei/fix_hetert0104
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
4e054d2f9e
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 23 additions and 6 deletions
  1. +22
    -5
      mindspore/ccsrc/pipeline/jit/action.cc
  2. +1
    -1
      mindspore/ccsrc/runtime/framework/actor/output_actor.cc

+ 22
- 5
mindspore/ccsrc/pipeline/jit/action.cc View File

@@ -92,6 +92,19 @@ bool IsDynamicShapeGraph(FuncGraphPtr func_graph) {
[](const AnfNodePtr &node) { return AnfAlgo::IsNodeDynamicShape(node); });
}

bool ExistControlNode(FuncGraphPtr func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto node_list = TopoSort(func_graph->get_return());
std::vector<PrimitivePtr> control_ops = {prim::kPrimSwitch, prim::kPrimCall, prim::kPrimSwitchLayer};
for (auto &node : node_list) {
if (std::any_of(control_ops.begin(), control_ops.end(),
[&](PrimitivePtr prim) { return AnfAlgo::CheckPrimitiveType(node, prim); })) {
return true;
}
}
return false;
}

bool EnableMindRTForAscendSubGraph(const FuncGraphManagerPtr manager, FuncGraphPtr func_graph) {
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(func_graph);
@@ -100,17 +113,21 @@ bool EnableMindRTForAscendSubGraph(const FuncGraphManagerPtr manager, FuncGraphP
std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
std::string backend = context_ptr->backend_policy();
auto graphs = manager->func_graphs();
bool exist_while =
std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); });
bool exist_ctrl = exist_while || ExistControlNode(func_graph);
if (!func_graph->ContainMultiTarget() && task_sink &&
context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
auto graphs = manager->func_graphs();
bool exist_while =
std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); });
if (device_target == kAscendDevice && backend != kMsVm && !exist_while) {
return true;
}
}
if (device_target == kAscendDevice && func_graph->ContainMultiTarget() && !IsDynamicShapeGraph(func_graph)) {
return true;
if (func_graph->ContainMultiTarget() && task_sink &&
context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
if (device_target == kAscendDevice && !IsDynamicShapeGraph(func_graph) && backend != kMsVm && !exist_ctrl) {
return true;
}
}
return false;
}


+ 1
- 1
mindspore/ccsrc/runtime/framework/actor/output_actor.cc View File

@@ -121,7 +121,7 @@ TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t
MS_LOG(ERROR) << "The output position is of range: " << output_position;
return nullptr;
}
auto device_context = device_contexts_[output_position];
auto &device_context = device_contexts_[output_position];
MS_EXCEPTION_IF_NULL(device_context);
if (device_context->GetDeviceAddressType() != device_tensor->DeviceType()) {
auto old_device_context = device_context;


Loading…
Cancel
Save