|
|
|
@@ -1001,16 +1001,19 @@ void ControlNodeParser::ParseDeviceContextForPartialNode(const std::vector<AnfNo |
|
|
|
if (inputs.size() <= kPartialFuncGraphPos) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid input size for partial node:" << cnode->DebugString(); |
|
|
|
} |
|
|
|
|
|
|
|
auto &func_node = inputs[kPartialFuncGraphPos]; |
|
|
|
// Ignore if the node is 'Partial(DeadNode,)'. |
|
|
|
auto func_value = GetValueNode<StringImmPtr>(func_node); |
|
|
|
if (func_value != nullptr && func_value->value() == kDeadNodeName) { |
|
|
|
MS_LOG(DEBUG) << "Ignore partial dead node:" << cnode->DebugString(); |
|
|
|
continue; |
|
|
|
} |
|
|
|
// Fetch the funcgraph in partial node. |
|
|
|
const auto &func_graph_node = inputs[kPartialFuncGraphPos]; |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph_node); |
|
|
|
if ((!func_graph_node->isa<ValueNode>()) || (!IsValueNode<FuncGraph>(func_graph_node))) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid funcgraph node:" << func_graph_node->DebugString() |
|
|
|
const auto &func_graph = GetValueNode<FuncGraphPtr>(func_node); |
|
|
|
if (func_graph == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid funcgraph node:" << func_node->DebugString() |
|
|
|
<< " for partial node:" << cnode->DebugString(); |
|
|
|
} |
|
|
|
const auto &func_graph = GetValueNode<FuncGraphPtr>(func_graph_node); |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
|
|
|
|
// Fetch the device contexts for the formal parameters in the funcgraph of partial node. |
|
|
|
auto iter = func_graph_to_device_contexts_.find(func_graph); |
|
|
|
@@ -1326,12 +1329,21 @@ void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr> |
|
|
|
const auto &cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
const auto &inputs = cnode->inputs(); |
|
|
|
if (inputs.size() <= kPartialFuncGraphPos || (!inputs[kPartialFuncGraphPos]->isa<ValueNode>()) || |
|
|
|
(!IsValueNode<FuncGraph>(inputs[kPartialFuncGraphPos]))) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid partial node:" << node->DebugString(); |
|
|
|
if (inputs.size() <= kPartialFuncGraphPos) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid input size for partial node:" << node->DebugString(); |
|
|
|
} |
|
|
|
auto &func_node = inputs[kPartialFuncGraphPos]; |
|
|
|
// Ignore if the node is 'Partial(DeadNode,)'. |
|
|
|
auto func_value = GetValueNode<StringImmPtr>(func_node); |
|
|
|
if (func_value != nullptr && func_value->value() == kDeadNodeName) { |
|
|
|
MS_LOG(DEBUG) << "Ignore partial dead node:" << node->DebugString(); |
|
|
|
continue; |
|
|
|
} |
|
|
|
const auto &func_graph = GetValueNode<FuncGraphPtr>(func_node); |
|
|
|
if (func_graph == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid funcgraph node:" << func_node->DebugString() |
|
|
|
<< " for partial node:" << node->DebugString(); |
|
|
|
} |
|
|
|
const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[kPartialFuncGraphPos]); |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
const auto ¶meters = func_graph->parameters(); |
|
|
|
if (inputs.size() - kPartialInputStartPos > parameters.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid partial input size:" << inputs.size() |
|
|
|
|