Browse Source

!29034 Ignore Partial(DeadNode) in backend routine.

Merge pull request !29034 from 张清华/eliminate_tuple_unused_item2
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
a3dbc84375
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 53 additions and 15 deletions
  1. +0
    -1
      mindspore/ccsrc/pipeline/jit/parse/resolve.cc
  2. +24
    -12
      mindspore/ccsrc/runtime/framework/control_node_parser.cc
  3. +29
    -2
      mindspore/ccsrc/runtime/framework/control_node_scheduler.cc

+ 0
- 1
mindspore/ccsrc/pipeline/jit/parse/resolve.cc View File

@@ -373,7 +373,6 @@ py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symb
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph is nullptr.";
}
SymbolResolver symbol_resolver(name_space, symbol, node);
symbol_resolver.Resolve();
if (!symbol_resolver.Resolve()) {
MS_LOG(EXCEPTION) << "Fail to resolve node, NodeInfo.";
}


+ 24
- 12
mindspore/ccsrc/runtime/framework/control_node_parser.cc View File

@@ -1001,16 +1001,19 @@ void ControlNodeParser::ParseDeviceContextForPartialNode(const std::vector<AnfNo
if (inputs.size() <= kPartialFuncGraphPos) {
MS_LOG(EXCEPTION) << "Invalid input size for partial node:" << cnode->DebugString();
}

auto &func_node = inputs[kPartialFuncGraphPos];
// Ignore if the node is 'Partial(DeadNode,)'.
auto func_value = GetValueNode<StringImmPtr>(func_node);
if (func_value != nullptr && func_value->value() == kDeadNodeName) {
MS_LOG(DEBUG) << "Ignore partial dead node:" << cnode->DebugString();
continue;
}
// Fetch the funcgraph in partial node.
const auto &func_graph_node = inputs[kPartialFuncGraphPos];
MS_EXCEPTION_IF_NULL(func_graph_node);
if ((!func_graph_node->isa<ValueNode>()) || (!IsValueNode<FuncGraph>(func_graph_node))) {
MS_LOG(EXCEPTION) << "Invalid funcgraph node:" << func_graph_node->DebugString()
const auto &func_graph = GetValueNode<FuncGraphPtr>(func_node);
if (func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Invalid funcgraph node:" << func_node->DebugString()
<< " for partial node:" << cnode->DebugString();
}
const auto &func_graph = GetValueNode<FuncGraphPtr>(func_graph_node);
MS_EXCEPTION_IF_NULL(func_graph);

// Fetch the device contexts for the formal parameters in the funcgraph of partial node.
auto iter = func_graph_to_device_contexts_.find(func_graph);
@@ -1326,12 +1329,21 @@ void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr>
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &inputs = cnode->inputs();
if (inputs.size() <= kPartialFuncGraphPos || (!inputs[kPartialFuncGraphPos]->isa<ValueNode>()) ||
(!IsValueNode<FuncGraph>(inputs[kPartialFuncGraphPos]))) {
MS_LOG(EXCEPTION) << "Invalid partial node:" << node->DebugString();
if (inputs.size() <= kPartialFuncGraphPos) {
MS_LOG(EXCEPTION) << "Invalid input size for partial node:" << node->DebugString();
}
auto &func_node = inputs[kPartialFuncGraphPos];
// Ignore if the node is 'Partial(DeadNode,)'.
auto func_value = GetValueNode<StringImmPtr>(func_node);
if (func_value != nullptr && func_value->value() == kDeadNodeName) {
MS_LOG(DEBUG) << "Ignore partial dead node:" << node->DebugString();
continue;
}
const auto &func_graph = GetValueNode<FuncGraphPtr>(func_node);
if (func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Invalid funcgraph node:" << func_node->DebugString()
<< " for partial node:" << node->DebugString();
}
const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[kPartialFuncGraphPos]);
MS_EXCEPTION_IF_NULL(func_graph);
const auto &parameters = func_graph->parameters();
if (inputs.size() - kPartialInputStartPos > parameters.size()) {
MS_LOG(EXCEPTION) << "Invalid partial input size:" << inputs.size()


+ 29
- 2
mindspore/ccsrc/runtime/framework/control_node_scheduler.cc View File

@@ -138,6 +138,30 @@ std::vector<SwitchActorPtr> ControlNodeScheduler::BuildSwitchActor(const GraphCo
return switch_actors;
}

namespace {
bool IsValidPartialCNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr) {
return false;
}
const auto &inputs = cnode->inputs();
if (inputs.size() <= kPartialFuncGraphPos) {
return false;
}
if (!IsPrimitive(inputs[kAnfPrimitiveIndex], prim::kPrimPartial)) {
return false;
}
// Ignore if the node is 'Partial(DeadNode,)'.
auto func_value = GetValueNode<StringImmPtr>(inputs[kPartialFuncGraphPos]);
if (func_value != nullptr && func_value->value() == kDeadNodeName) {
MS_LOG(DEBUG) << "Ignore partial dead node:" << cnode->DebugString();
return false;
}
return true;
}
} // namespace

std::vector<GatherActorPtr> ControlNodeScheduler::BuildGatherActor(const GraphCompilerInfo &graph_compiler_info) {
std::vector<GatherActorPtr> gather_actors;
const auto &control_nodes = graph_compiler_info.control_nodes_;
@@ -146,7 +170,7 @@ std::vector<GatherActorPtr> ControlNodeScheduler::BuildGatherActor(const GraphCo

for (const auto &control_node : control_nodes) {
// Partial node and call node will be converted to gather actor.
if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial) || AnfAlgo::IsCallNode(control_node)) {
if (IsValidPartialCNode(control_node) || AnfAlgo::IsCallNode(control_node)) {
const auto &actor_name = GetActorName(control_node);
const auto &parameters = FetchInputNodeByCNode(control_node);
const auto &gather_actor =
@@ -634,7 +658,10 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act
// Link arrow from gather actor
const auto &actor_name = GetActorName(from_node);
const auto &actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
if (actor == nullptr) {
MS_LOG(DEBUG) << "No actor of " << actor_name;
return;
}
const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
MS_EXCEPTION_IF_NULL(gather_actor);
LinkPartialArrow(gather_actor, to_actor, from_node_with_index.second, to_node_with_index.second);


Loading…
Cancel
Save