|
|
|
@@ -19,6 +19,17 @@ |
|
|
|
#include "session/ascend_control_parser.h" |
|
|
|
#include "session/anf_runtime_algorithm.h" |
|
|
|
|
|
|
|
static constexpr size_t kCNodePrim = 0; |
|
|
|
static constexpr size_t kCNodeCallArg = 1; |
|
|
|
static constexpr size_t kCNodeSwitchCond = 1; |
|
|
|
static constexpr size_t kCNodeSwitchTrue = 2; |
|
|
|
static constexpr size_t kCNodeSwitchFalse = 3; |
|
|
|
static constexpr size_t kCNodeSwitchLength = 4; |
|
|
|
static constexpr size_t kCNodePartialLength = 2; |
|
|
|
static constexpr size_t kCNodePartialFunc = 1; |
|
|
|
static constexpr size_t kCNodeSwitchLayerBranch = 2; |
|
|
|
static constexpr size_t kCNodeSwitchLayerLength = 3; |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace session { |
|
|
|
|
|
|
|
@@ -61,7 +72,7 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) { |
|
|
|
ChildGraphDataAssign(graph_id_map); |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr AscendControlParser::GetNextRealKernel(std::vector<CNodePtr> list, size_t start) { |
|
|
|
CNodePtr AscendControlParser::GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) { |
|
|
|
for (size_t i = start; i < list.size() - 1; ++i) { |
|
|
|
if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) { |
|
|
|
return list[i]; |
|
|
|
@@ -83,11 +94,11 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr |
|
|
|
memo->insert(kg.get()); |
|
|
|
|
|
|
|
// 2. args replace placeholder |
|
|
|
LinkParentGraph(kg, last_node, last_label, memo); |
|
|
|
LinkParentGraph(kg, last_node, last_label); |
|
|
|
|
|
|
|
// 3. topological sort |
|
|
|
kg->SetExecOrderByDefault(); |
|
|
|
std::vector<CNodePtr> nodes = kg->execution_order(); |
|
|
|
const std::vector<CNodePtr> &nodes = kg->execution_order(); |
|
|
|
if (nodes.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!"; |
|
|
|
} |
|
|
|
@@ -149,9 +160,9 @@ void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, |
|
|
|
} |
|
|
|
|
|
|
|
void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, |
|
|
|
const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo) { |
|
|
|
const CNodePtr &last_label) { |
|
|
|
auto origin_return = kg->get_return(); |
|
|
|
std::vector<AnfNodePtr> origin_return_inputs = origin_return->inputs(); |
|
|
|
const std::vector<AnfNodePtr> &origin_return_inputs = origin_return->inputs(); |
|
|
|
// if entry graph, replace return with make_tuple |
|
|
|
if (from_graph_call_node == nullptr || last_label == nullptr) { |
|
|
|
MS_LOG(INFO) << kg->ToString() << " is entry graph."; |
|
|
|
@@ -173,7 +184,7 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP |
|
|
|
MS_LOG(INFO) << "process call func " << cur_node->DebugString(); |
|
|
|
|
|
|
|
// 1 get kernel graph |
|
|
|
auto origin_inputs = cur_node->inputs(); |
|
|
|
const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs(); |
|
|
|
std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))}; |
|
|
|
if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) { |
|
|
|
MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode"; |
|
|
|
@@ -217,15 +228,14 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod |
|
|
|
InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); |
|
|
|
} |
|
|
|
// 3 recurse sub graph |
|
|
|
auto origin_switch_inputs = cur_node->inputs(); |
|
|
|
const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs(); |
|
|
|
std::vector<AnfNodePtr> new_switch_inputs = { |
|
|
|
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), |
|
|
|
origin_switch_inputs[kCNodeSwitchCond]}; |
|
|
|
for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { |
|
|
|
// 3.1 branch kernel graph and args |
|
|
|
CNodePtr partial; |
|
|
|
KernelGraphPtr branch_fg; |
|
|
|
std::tie(partial, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); |
|
|
|
std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); |
|
|
|
// 3.2 recurse sub graph |
|
|
|
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); |
|
|
|
new_switch_inputs.push_back(branch_label); |
|
|
|
@@ -249,9 +259,9 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull |
|
|
|
auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch); |
|
|
|
MS_EXCEPTION_IF_NULL(branch_tuple); |
|
|
|
if (!branch_tuple->isa<CNode>()) { |
|
|
|
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; |
|
|
|
MS_LOG(EXCEPTION) << branch_tuple->DebugString() << " is not a CNode"; |
|
|
|
} |
|
|
|
auto branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs(); |
|
|
|
const std::vector<AnfNodePtr> &branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs(); |
|
|
|
// 1 return label |
|
|
|
auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))}); |
|
|
|
// 2 add depend relationship |
|
|
|
@@ -260,15 +270,14 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull |
|
|
|
InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); |
|
|
|
} |
|
|
|
// 3 recurse sub graph |
|
|
|
auto origin_switch_inputs = cur_node->inputs(); |
|
|
|
const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs(); |
|
|
|
std::vector<AnfNodePtr> new_switch_inputs = { |
|
|
|
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), |
|
|
|
origin_switch_inputs[kCNodeSwitchCond]}; |
|
|
|
for (size_t i = 0; i < branch_partial.size(); ++i) { |
|
|
|
// 3.1 branch kernel graph and args |
|
|
|
CNodePtr partial; |
|
|
|
KernelGraphPtr branch_fg; |
|
|
|
std::tie(partial, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); |
|
|
|
std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); |
|
|
|
// 3.2 recurse sub graph |
|
|
|
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); |
|
|
|
new_switch_inputs.push_back(branch_label); |
|
|
|
@@ -315,18 +324,6 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul |
|
|
|
InsertDependToGraph(kg, NOT_NULL(assign_node)); |
|
|
|
} |
|
|
|
|
|
|
|
NotNull<AnfNodePtr> AscendControlParser::GetRealInput(NotNull<KernelGraphPtr> from_graph, |
|
|
|
NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> param) { |
|
|
|
std::set<AnfNodePtr> args_list = to_graph->GetRealInput(param); |
|
|
|
for (auto arg : args_list) { |
|
|
|
if (arg->func_graph() == from_graph.get()) { |
|
|
|
return NOT_NULL(arg); |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(EXCEPTION) << to_graph->ToString() << " input " << param->DebugString() << " not from " |
|
|
|
<< from_graph->ToString(); |
|
|
|
} |
|
|
|
|
|
|
|
void AscendControlParser::LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph, |
|
|
|
NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param) { |
|
|
|
if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple) && IsPrimitiveCNode(param, prim::kPrimMakeTuple)) { |
|
|
|
@@ -369,10 +366,10 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe |
|
|
|
return {}; |
|
|
|
} |
|
|
|
memo->insert(graph.get()); |
|
|
|
|
|
|
|
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order(); |
|
|
|
graph->SetExecOrderByDefault(); |
|
|
|
|
|
|
|
std::vector<CNodePtr> cnodes = graph->execution_order(); |
|
|
|
const std::vector<CNodePtr> &cnodes = graph->execution_order(); |
|
|
|
std::map<uint32_t, CNodePtr> label_map; |
|
|
|
std::map<CNodePtr, std::vector<uint32_t>> label_switch_map; |
|
|
|
std::tie(label_map, label_switch_map) = GetLabelNode(cnodes); |
|
|
|
@@ -388,10 +385,10 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe |
|
|
|
std::find_if(label_map.begin(), label_map.end(), |
|
|
|
[node](const std::map<uint32_t, CNodePtr>::value_type iter) { return iter.second == node; }); |
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { |
|
|
|
if (!CheckLabelIndex(label_iter->first, 0, label_iter->second, graph)) { |
|
|
|
if (label_iter == label_map.end() || !CheckLabelIndex(label_iter->first, 0, label_iter->second, graph)) { |
|
|
|
MS_LOG(EXCEPTION) << "Check label index fail"; |
|
|
|
} |
|
|
|
auto child_graph = graph->child_graph_order()[label_iter->first]; |
|
|
|
auto child_graph = child_graph_order[label_iter->first]; |
|
|
|
if (child_graph == graph->parent_graph()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
@@ -407,7 +404,7 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe |
|
|
|
if (!CheckLabelIndex(label_iter->first + i, label_list[i], label_iter->second, graph)) { |
|
|
|
MS_LOG(EXCEPTION) << "Check label index fail"; |
|
|
|
} |
|
|
|
auto child_graph = graph->child_graph_order()[label_iter->first + i]; |
|
|
|
auto child_graph = child_graph_order[label_iter->first + i]; |
|
|
|
if (child_graph == graph->parent_graph()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
@@ -426,10 +423,11 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe |
|
|
|
|
|
|
|
bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, |
|
|
|
NotNull<KernelGraphPtr> graph) { |
|
|
|
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order(); |
|
|
|
// check index and child order size |
|
|
|
if (graph->child_graph_order().size() <= static_cast<size_t>(order_index)) { |
|
|
|
if (child_graph_order.size() <= IntToSize(order_index)) { |
|
|
|
MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size " |
|
|
|
<< graph->child_graph_order().size() << " goto index " << order_index; |
|
|
|
<< child_graph_order.size() << " goto index " << order_index; |
|
|
|
} |
|
|
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(cur_label, prim::kPrimLabelGoto)) { |
|
|
|
@@ -443,7 +441,7 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i |
|
|
|
label_index = label_goto_index; |
|
|
|
} |
|
|
|
// get start_label_set_index of child graph |
|
|
|
auto child_graph = graph->child_graph_order()[order_index]; |
|
|
|
auto child_graph = child_graph_order[order_index]; |
|
|
|
MS_EXCEPTION_IF_NULL(child_graph); |
|
|
|
auto start_label_set = child_graph->get_start_label(); |
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, start_label_set)) { |
|
|
|
@@ -468,8 +466,7 @@ std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t |
|
|
|
uint32_t index = 0; |
|
|
|
for (auto &node : nodes) { |
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { |
|
|
|
label_map[index] = node; |
|
|
|
++index; |
|
|
|
label_map[index++] = node; |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { |
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) { |
|
|
|
MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list"; |
|
|
|
@@ -479,8 +476,7 @@ std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t |
|
|
|
std::vector<uint32_t> label_list = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList)); |
|
|
|
label_switch_map.insert({node, label_list}); |
|
|
|
for (size_t i = 0; i < label_list.size(); ++i) { |
|
|
|
label_map[index] = node; |
|
|
|
++index; |
|
|
|
label_map[index++] = node; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|