Merge pull request !2931 from zhoufeng/liantiao1 tags/v0.6.0-beta
| @@ -40,6 +40,9 @@ using kernel::KernelBuildInfoPtr; | |||
| using kernel::KernelMod; | |||
| using kernel::KernelModPtr; | |||
| namespace { | |||
| constexpr size_t kNopNodeInputSize = 2; | |||
| constexpr size_t kNopNodeRealInputIndex = 1; | |||
| std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) { | |||
| MS_EXCEPTION_IF_NULL(shape); | |||
| std::vector<size_t> shape_size_t; | |||
| @@ -48,6 +51,26 @@ std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) { | |||
| } | |||
| } // namespace | |||
| // Returns the real (data) input of a TupleGetItem node: input(kRealInputNodeIndexInTupleGetItem). | |||
| // The node is null-checked and its size() must equal kTupleGetItemInputSize; otherwise an exception is logged. | |||
| AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) { | |||
| MS_EXCEPTION_IF_NULL(tuple_get_item); | |||
| if (tuple_get_item->size() != kTupleGetItemInputSize) { | |||
| // Keep the historical message prefix, then add expected/actual sizes and the node itself for diagnosis. | |||
| MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs! Expected node size " << kTupleGetItemInputSize | |||
| << ", but got " << tuple_get_item->size() << ", node: " << tuple_get_item->DebugString(); | |||
| } | |||
| return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem); | |||
| } | |||
| // Returns the output index selected by a TupleGetItem node. | |||
| // The index is read as an int from the ValueNode at input(kInputNodeOutputIndexInTupleGetItem) and converted | |||
| // with IntToSize. The node is null-checked and its size() must equal kTupleGetItemInputSize. | |||
| size_t AnfRuntimeAlgorithm::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) { | |||
| MS_EXCEPTION_IF_NULL(tuple_get_item); | |||
| if (tuple_get_item->size() != kTupleGetItemInputSize) { | |||
| // Keep the historical message prefix, then add expected/actual sizes and the node itself for diagnosis. | |||
| MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs! Expected node size " << kTupleGetItemInputSize | |||
| << ", but got " << tuple_get_item->size() << ", node: " << tuple_get_item->DebugString(); | |||
| } | |||
| auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem); | |||
| MS_EXCEPTION_IF_NULL(output_index_value_node); | |||
| // The index input is expected to be a ValueNode; the cast result is null-checked before use. | |||
| auto value_node = output_index_value_node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| const int output_index = GetValue<int>(value_node->value()); | |||
| return IntToSize(output_index); | |||
| } | |||
| KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| if (anf_node->isa<ValueNode>()) { | |||
| @@ -83,49 +106,47 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz | |||
| } | |||
| } | |||
| KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index, | |||
| KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, int index, | |||
| bool visit_nop_node, | |||
| const std::vector<PrimitivePtr> &return_types) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| for (const auto &prim_type : return_types) { | |||
| if (CheckPrimitiveType(anf_node, prim_type)) { | |||
| return std::make_pair(anf_node, index); | |||
| } | |||
| if (std::any_of(return_types.begin(), return_types.end(), [&anf_node](const PrimitivePtr &prim_type) -> bool { | |||
| return CheckPrimitiveType(anf_node, prim_type); | |||
| })) { | |||
| return KernelWithIndex(anf_node, index); | |||
| } | |||
| if (anf_node->isa<ValueNode>()) { | |||
| return std::make_pair(anf_node, 0); | |||
| } else if (anf_node->isa<Parameter>()) { | |||
| return std::make_pair(anf_node, 0); | |||
| } else if (anf_node->isa<CNode>()) { | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto input0 = cnode->input(0); | |||
| MS_EXCEPTION_IF_NULL(input0); | |||
| if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { | |||
| if (cnode->inputs().size() != kTupleGetItemInputSize) { | |||
| MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; | |||
| } | |||
| auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem); | |||
| MS_EXCEPTION_IF_NULL(input2); | |||
| auto value_node = input2->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| int item_idx = GetValue<int>(value_node->value()); | |||
| return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx), | |||
| visit_nop_node, return_types); | |||
| } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { | |||
| return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types); | |||
| } else if (opt::IsNopNode(cnode) && visit_nop_node) { | |||
| if (cnode->inputs().size() == 2) { | |||
| return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node"; | |||
| if (!anf_node->isa<CNode>()) { | |||
| return KernelWithIndex(anf_node, 0); | |||
| } | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) { | |||
| auto item_with_index_tmp = VisitKernelWithReturnType(GetTupleGetItemRealInput(cnode), | |||
| GetTupleGetItemOutIndex(cnode), visit_nop_node, return_types); | |||
| if (CheckPrimitiveType(item_with_index_tmp.first, prim::kPrimMakeTuple)) { | |||
| MS_EXCEPTION_IF_NULL(item_with_index_tmp.first); | |||
| auto make_tuple = item_with_index_tmp.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| const std::vector<AnfNodePtr> &make_tuple_inputs = make_tuple->inputs(); | |||
| size_t make_tuple_input_index = item_with_index_tmp.second + 1; | |||
| if (make_tuple_input_index >= make_tuple_inputs.size()) { | |||
| MS_LOG(EXCEPTION) << "Index[" << make_tuple_input_index << "] out of range[" << make_tuple_inputs.size() | |||
| << "]."; | |||
| } | |||
| } else { | |||
| return std::make_pair(anf_node, index); | |||
| return VisitKernelWithReturnType(make_tuple_inputs[make_tuple_input_index], 0, visit_nop_node, return_types); | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "The input is invalid"; | |||
| return item_with_index_tmp; | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimDepend) || CheckPrimitiveType(cnode, prim::kPrimControlDepend)) { | |||
| return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, visit_nop_node, return_types); | |||
| } | |||
| if (opt::IsNopNode(cnode) && visit_nop_node) { | |||
| if (cnode->size() != kNopNodeInputSize) { | |||
| MS_LOG(EXCEPTION) << "Invalid nop node " << cnode->DebugString(); | |||
| } | |||
| return VisitKernelWithReturnType(cnode->input(kNopNodeRealInputIndex), 0, visit_nop_node, return_types); | |||
| } | |||
| return KernelWithIndex(anf_node, index); | |||
| } | |||
| std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node, | |||
| @@ -591,7 +612,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, | |||
| if (opt::IsNopNode(node) && visit_nop_node) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->inputs().size() == 2) { | |||
| if (cnode->size() == kNopNodeInputSize) { | |||
| return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node"; | |||
| @@ -613,7 +634,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod | |||
| if (opt::IsNopNode(node) && visit_nop_node) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->inputs().size() == 2) { | |||
| if (cnode->inputs().size() == kNopNodeInputSize) { | |||
| return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node."; | |||
| @@ -806,7 +827,7 @@ bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) { | |||
| IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) || | |||
| IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) || | |||
| IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) || | |||
| IsPrimitive(input, prim::kPrimReturn); | |||
| IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial); | |||
| return !is_virtual_node; | |||
| } | |||
| @@ -1117,5 +1138,14 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, s | |||
| } | |||
| return GetCNodeOutputPrecision(kernel_with_index.first); | |||
| } | |||
| // Tells whether the primitive of node (its input at kAnfPrimitiveIndex) is LabelGoto or LabelSwitch. | |||
| // node is null-checked, and an exception is logged when it has no inputs at all. | |||
| bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->inputs().empty()) { | |||
| MS_LOG(EXCEPTION) << "Illegal null input of cnode."; | |||
| } | |||
| const auto primitive_input = node->input(kAnfPrimitiveIndex); | |||
| if (IsPrimitive(primitive_input, prim::kPrimLabelGoto)) { | |||
| return true; | |||
| } | |||
| return IsPrimitive(primitive_input, prim::kPrimLabelSwitch); | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -42,9 +42,12 @@ using DeviceAddress = device::DeviceAddress; | |||
| using DeviceAddressPtr = device::DeviceAddressPtr; | |||
| class AnfRuntimeAlgorithm { | |||
| public: | |||
| // get real input node of tuple_get_item | |||
| static AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item); | |||
| static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item); | |||
| // get input_anf_node's real kernel by recurse | |||
| static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index); | |||
| static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index, | |||
| static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, int output_index, | |||
| bool visit_nop_node = false, | |||
| const std::vector<PrimitivePtr> &return_types = { | |||
| prim::kPrimMakeTuple}); | |||
| @@ -205,6 +208,7 @@ class AnfRuntimeAlgorithm { | |||
| static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node); | |||
| // get fix output precision from prev node, input_idx is the input index of current node related to prev node. | |||
| static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx); | |||
| static bool IsCondControlKernel(const CNodePtr &node); | |||
| }; | |||
| } // namespace session | |||
| using AnfAlgo = session::AnfRuntimeAlgorithm; | |||
| @@ -17,6 +17,7 @@ | |||
| #include "backend/session/ascend_control_parser.h" | |||
| #include <utility> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "utils/union_find_set.h" | |||
| #include "runtime/device/ascend/ascend_label_assign.h" | |||
| @@ -31,94 +32,11 @@ static constexpr size_t kCNodePartialLength = 2; | |||
| static constexpr size_t kCNodePartialFunc = 1; | |||
| static constexpr size_t kCNodeSwitchLayerBranch = 2; | |||
| static constexpr size_t kCNodeSwitchLayerLength = 3; | |||
| static constexpr size_t kCNodeAssignTarget = 1; | |||
| static constexpr size_t kCNodeAssignSource = 2; | |||
| namespace mindspore { | |||
| namespace session { | |||
| // Finds the jump node in the execution order of parent_graph that targets child_graph. | |||
| // A LabelGoto matches when its input kCNodeCallArg equals the start label of child_graph; a LabelSwitch | |||
| // matches when its input kCNodeSwitchFalse or kCNodeSwitchTrue does. When nothing matches, the last | |||
| // LabelGoto/LabelSwitch seen is returned, and an exception is logged if there was none at all. | |||
| static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) { | |||
| CNodePtr fallback = nullptr; | |||
| auto &order = parent_graph->execution_order(); | |||
| for (auto &candidate : order) { | |||
| bool hits_child = false; | |||
| if (IsPrimitiveCNode(candidate, prim::kPrimLabelGoto)) { | |||
| hits_child = child_graph->get_start_label() == candidate->input(kCNodeCallArg); | |||
| } else if (IsPrimitiveCNode(candidate, prim::kPrimLabelSwitch)) { | |||
| hits_child = child_graph->get_start_label() == candidate->input(kCNodeSwitchFalse) || | |||
| child_graph->get_start_label() == candidate->input(kCNodeSwitchTrue); | |||
| } else { | |||
| // Not a jump node: it can neither match nor serve as the fallback. | |||
| continue; | |||
| } | |||
| if (hits_child) { | |||
| return candidate; | |||
| } | |||
| fallback = candidate; | |||
| } | |||
| if (fallback == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString(); | |||
| } | |||
| return fallback; | |||
| } | |||
| // Adds to union_find_set every Parameter found in kg->real_inputs(): the first element of each entry and | |||
| // the elements of its argument list. Then recurses into kg->child_graph_order(). | |||
| // memo records the graphs already processed, so each graph is handled at most once. | |||
| static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| if (memo->find(kg.get()) != memo->end()) { | |||
| return; | |||
| } | |||
| memo->insert(kg.get()); | |||
| const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs(); | |||
| for (auto &iter : real_inputs) { | |||
| auto &para = iter.first; | |||
| MS_EXCEPTION_IF_NULL(para); | |||
| if (para->isa<Parameter>()) { | |||
| union_find_set->Add(para); | |||
| } | |||
| for (auto &arg : iter.second) { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| // Only Parameter arguments are registered; other argument nodes are ignored. | |||
| if (arg->isa<Parameter>()) { | |||
| union_find_set->Add(arg); | |||
| } | |||
| } | |||
| } | |||
| for (auto &child : kg->child_graph_order()) { | |||
| InitUnionFindSet(NOT_NULL(child), union_find_set, memo); | |||
| } | |||
| } | |||
| // For every entry of kg->real_inputs(), unions each Parameter argument with the entry's first element in | |||
| // union_find_set, skipping arguments that are present in kg->unreuse_args(). | |||
| // Then recurses into kg->child_graph_order(); memo records the graphs already processed. | |||
| static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| if (memo->find(kg.get()) != memo->end()) { | |||
| return; | |||
| } | |||
| memo->insert(kg.get()); | |||
| // Bind the container once so that find() and end() are guaranteed to come from the same object. | |||
| const auto &unreuse_args = kg->unreuse_args(); | |||
| const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs(); | |||
| for (auto &iter : real_inputs) { | |||
| auto &para = iter.first; | |||
| for (auto &arg : iter.second) { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| if (!arg->isa<Parameter>()) { | |||
| continue; | |||
| } | |||
| if (unreuse_args.find(arg) != unreuse_args.end()) { | |||
| continue; | |||
| } | |||
| union_find_set->Union(arg, para); | |||
| } | |||
| } | |||
| for (auto &child : kg->child_graph_order()) { | |||
| UnionParentParameter(NOT_NULL(child), union_find_set, memo); | |||
| } | |||
| } | |||
| // Builds the union-find set for root_kg in two passes over the graph tree: | |||
| // InitUnionFindSet registers the elements, then UnionParentParameter merges them. | |||
| // Each pass starts with an empty memo of visited graphs. | |||
| static UnionFindSet<AnfNodePtr> MakeUnionFindSet(NotNull<KernelGraphPtr> root_kg) { | |||
| UnionFindSet<AnfNodePtr> result; | |||
| { | |||
| std::set<KernelGraphPtr> memo; | |||
| InitUnionFindSet(root_kg, NOT_NULL(&result), NOT_NULL(&memo)); | |||
| } | |||
| { | |||
| std::set<KernelGraphPtr> memo; | |||
| UnionParentParameter(root_kg, NOT_NULL(&result), NOT_NULL(&memo)); | |||
| } | |||
| return result; | |||
| } | |||
| static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> main_parameter, | |||
| const std::set<AnfNodePtr> &parameter_reuse_set, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| @@ -135,8 +53,9 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> | |||
| continue; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(para); | |||
| MS_LOG(INFO) << "Replace " << para->DebugString() << " of graph " << AnfAlgo::GetGraphId(para.get()) << " to " | |||
| << main_parameter->DebugString() << " of graph " << AnfAlgo::GetGraphId(main_parameter.get().get()); | |||
| MS_LOG(INFO) << "In " << kg->ToString() << " replace " << para->DebugString() << " of graph " | |||
| << AnfAlgo::GetGraphId(para.get()) << " to " << main_parameter->DebugString() << " of graph " | |||
| << AnfAlgo::GetGraphId(main_parameter.get().get()); | |||
| kg->ReplaceNode(NOT_NULL(para), main_parameter); | |||
| } | |||
| @@ -145,7 +64,7 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> | |||
| } | |||
| } | |||
| static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNodePtr key, | |||
| static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNodePtr &key, | |||
| const std::set<AnfNodePtr> &parameter_reuse_set) { | |||
| AnfNodePtr main_parameter = key; | |||
| std::set<AnfNodePtr> root_inputs_set; | |||
| @@ -160,8 +79,19 @@ static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNod | |||
| return main_parameter; | |||
| } | |||
| static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet<AnfNodePtr> *> parameter_set) { | |||
| auto parameter_reuse_sets = parameter_set->GetSets(); | |||
| static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, | |||
| const std::vector<std::pair<AnfNodePtr, AnfNodePtr>> &link_list) { | |||
| // make union find set | |||
| UnionFindSet<AnfNodePtr> union_find_set; | |||
| for (auto &[param, arg] : link_list) { | |||
| union_find_set.Add(param); | |||
| union_find_set.Add(arg); | |||
| } | |||
| for (auto &[param, arg] : link_list) { | |||
| union_find_set.Union(param, arg); | |||
| } | |||
| auto parameter_reuse_sets = union_find_set.GetSets(); | |||
| for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) { | |||
| if (parameter_reuse_set.size() <= 1) { | |||
| continue; | |||
| @@ -172,7 +102,7 @@ static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet | |||
| } | |||
| } | |||
| CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) { | |||
| static CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) { | |||
| for (size_t i = start; i < list.size() - 1; ++i) { | |||
| if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) { | |||
| return list[i]; | |||
| @@ -181,71 +111,287 @@ CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) { | |||
| return nullptr; | |||
| } | |||
| // Records every LabelSet node of exec_order in label_id_to_label_set, keyed by its kAttrLabelIndex attr. | |||
| // An exception is logged when a LabelSet node lacks the attr or when a label id is already in the map. | |||
| static void UpdateLabelIdToLabelSetMap(const std::vector<CNodePtr> &exec_order, | |||
| const NotNull<std::map<uint32_t, CNodePtr> *> label_id_to_label_set) { | |||
| for (const auto &node : exec_order) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (IsPrimitiveCNode(node, prim::kPrimLabelSet)) { | |||
| if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { | |||
| MS_LOG(EXCEPTION) << node->DebugString() << " has no attr kAttrLabelIndex"; | |||
| } | |||
| const uint32_t label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex); | |||
| auto known = label_id_to_label_set->find(label_id); | |||
| if (known != label_id_to_label_set->end()) { | |||
| MS_LOG(EXCEPTION) << "There are more than one node has same label id " << label_id | |||
| << ", node: " << known->second->DebugString() << " and " << node->DebugString(); | |||
| } | |||
| (*label_id_to_label_set)[label_id] = node; | |||
| } | |||
| } | |||
| } | |||
| // Collects the LabelSet nodes that jump_node can transfer control to. | |||
| // For a LabelGoto node the single target id is read from attr kAttrLabelIndex; for a LabelSwitch node the | |||
| // ids are read from attr kAttrLabelSwitchList. Each id is resolved through label_id_to_label_set and the | |||
| // result keeps the order of the ids. | |||
| // An exception is logged for any other node kind, for a missing attr, and for an id that is not in the map. | |||
| static std::vector<CNodePtr> GetTargetLabelSetNodes(NotNull<CNodePtr> jump_node, | |||
| const std::map<uint32_t, CNodePtr> &label_id_to_label_set) { | |||
| std::vector<uint32_t> target_label_list; | |||
| if (IsPrimitiveCNode(jump_node.get(), prim::kPrimLabelGoto)) { | |||
| if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, jump_node)) { | |||
| MS_LOG(EXCEPTION) << jump_node->DebugString() << " has no attr kAttrLabelIndex"; | |||
| } | |||
| uint32_t label_id = AnfAlgo::GetNodeAttr<uint32_t>(jump_node.get(), kAttrLabelIndex); | |||
| target_label_list.push_back(label_id); | |||
| } else if (IsPrimitiveCNode(jump_node.get(), prim::kPrimLabelSwitch)) { | |||
| if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, jump_node)) { | |||
| // Name the attr that is actually missing (the message used to name the primitive instead). | |||
| MS_LOG(EXCEPTION) << jump_node->DebugString() << " has no attr kAttrLabelSwitchList"; | |||
| } | |||
| target_label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(jump_node.get(), kAttrLabelSwitchList); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Unknown type jump node " << jump_node->DebugString(); | |||
| } | |||
| std::vector<CNodePtr> target_labelset_nodes; | |||
| target_labelset_nodes.reserve(target_label_list.size()); | |||
| for (auto label_id : target_label_list) { | |||
| auto iter = label_id_to_label_set.find(label_id); | |||
| if (iter == label_id_to_label_set.end()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find LabelSet node with label id " << label_id; | |||
| } | |||
| target_labelset_nodes.push_back(iter->second); | |||
| } | |||
| return target_labelset_nodes; | |||
| } | |||
| // Removes the first occurrence of node from *exec_order. | |||
| // node is null-checked, and an exception is logged when it does not appear in the list. | |||
| static void EraseNodeFromExecOrder(const AnfNodePtr &node, const NotNull<std::vector<CNodePtr> *> exec_order) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto &order = *exec_order; | |||
| const auto position = std::find(order.begin(), order.end(), node); | |||
| if (position == order.end()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find " << node->DebugString() << " in exec order."; | |||
| } | |||
| order.erase(position); | |||
| } | |||
| void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) { | |||
| std::set<KernelGraphPtr> memo; | |||
| std::vector<std::pair<AnfNodePtr, AnfNodePtr>> link_list; | |||
| // Insert Assign | |||
| ChildGraphDataAssign(kg, NOT_NULL(&link_list), NOT_NULL(&memo)); | |||
| // Reuse Parameter | |||
| ReuseParameter(kg, link_list); | |||
| // replace call by label goto / label switch | |||
| memo.clear(); | |||
| (void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); | |||
| // assign label resource | |||
| device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg); | |||
| std::map<uint32_t, KernelGraphPtr> graph_id_map; | |||
| for (auto &g : memo) { | |||
| MS_EXCEPTION_IF_NULL(g); | |||
| if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) { | |||
| MS_LOG(EXCEPTION) << "Two graph has same graph id " << g->graph_id() | |||
| << ", graph: " << graph_id_map[g->graph_id()]->ToString() << " " << g->ToString(); | |||
| } | |||
| void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph, | |||
| const std::set<KernelGraphPtr> &graph_list) { | |||
| std::vector<CNodePtr> exec_order = root_graph->execution_order(); | |||
| std::set<CNodePtr> search_list(exec_order.begin(), exec_order.end()); | |||
| std::set<AnfNodePtr> root_inputs(root_graph->inputs().begin(), root_graph->inputs().end()); | |||
| auto ref_map = root_graph->GetRefMap(); | |||
| ReferenceCounter parameter_count([](int32_t read, int32_t write) -> bool { return write == 1; }); | |||
| std::multimap<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> ref_multimap; | |||
| std::transform(ref_map.begin(), ref_map.end(), std::inserter(ref_multimap, ref_multimap.end()), | |||
| [](const std::pair<std::pair<AnfNodePtr, size_t>, std::pair<AnfNodePtr, size_t>> &p) | |||
| -> std::pair<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> { | |||
| return {p.first.first, {p.first.second, p.second.first, p.second.second}}; | |||
| }); | |||
| std::set<CNodePtr> all_nodes; | |||
| std::map<AnfNodePtr, CNodePtr> para_to_written_node; | |||
| for (auto &graph : graph_list) { | |||
| auto out = graph->get_return(); | |||
| MS_EXCEPTION_IF_NULL(out); | |||
| search_list.insert(out->cast<CNodePtr>()); | |||
| auto nodes = TopoSort(out); | |||
| for (auto &node : nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode != nullptr) { | |||
| all_nodes.insert(cnode); | |||
| } | |||
| } | |||
| } | |||
| // prepare reference count | |||
| for (auto &node : search_list) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| // if assign node | |||
| std::set<AnfNodePtr> refed_parameters; | |||
| for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) { | |||
| refed_parameters.insert(std::get<1>(iter->second)); | |||
| } | |||
| for (auto &in : node->inputs()) { | |||
| auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first; | |||
| if (!visit_node->isa<Parameter>() || root_inputs.find(visit_node) != root_inputs.end()) { | |||
| continue; | |||
| } | |||
| if (refed_parameters.find(visit_node) != refed_parameters.end()) { | |||
| parameter_count.AddWriteCount(visit_node, 1); | |||
| para_to_written_node[visit_node] = node; | |||
| } else { | |||
| parameter_count.AddReadCount(visit_node, 1); | |||
| } | |||
| } | |||
| graph_id_map[g->graph_id()] = g; | |||
| } | |||
| // Insert Assign | |||
| ChildGraphDataAssign(graph_id_map); | |||
| // Make UnionFindSet | |||
| UnionFindSet<AnfNodePtr> parameter_set = MakeUnionFindSet(kg); | |||
| // Reuse Parameter | |||
| ReuseParameter(kg, NOT_NULL(&parameter_set)); | |||
| while (parameter_count.HasValidElem()) { | |||
| auto [para, read, written] = parameter_count.GetOneValidElem(); | |||
| MS_LOG(INFO) << para->DebugString() << " was read " << read << " times, written " << written << " times."; | |||
| auto assign_iter = para_to_written_node.find(para); | |||
| if (assign_iter == para_to_written_node.end()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find assign node that write " << para->DebugString(); | |||
| } | |||
| auto &assign_node = assign_iter->second; | |||
| MS_EXCEPTION_IF_NULL(assign_node); | |||
| if (!IsPrimitiveCNode(assign_node, prim::kPrimAssign)) { | |||
| parameter_count.EraseElem(para); | |||
| continue; | |||
| } | |||
| MS_LOG(INFO) << "Erase " << assign_node->DebugString(5); | |||
| EraseNodeFromExecOrder(assign_node, NOT_NULL(&exec_order)); | |||
| auto source = AnfAlgo::VisitKernelWithReturnType(assign_node->input(kCNodeAssignSource), 0).first; | |||
| parameter_count.AddReadCount(source, -1); | |||
| parameter_count.AddWriteCount(para, -1); | |||
| for (auto &node : all_nodes) { | |||
| for (size_t i = 0; i < node->size(); ++i) { | |||
| if (node->input(i) == para) { | |||
| MS_LOG_INFO << "Replace " << node->DebugString() << " input " << i << " by " << source->DebugString(); | |||
| node->set_input(i, source); | |||
| } | |||
| } | |||
| } | |||
| parameter_count.AddReadCount(source, 1); | |||
| parameter_count.AddReadCount(para, -1); | |||
| } | |||
| root_graph->set_execution_order(exec_order); | |||
| } | |||
| // Erases redundant label nodes from the execution order of root_graph. | |||
| // Pass 1 counts, per LabelSet node, one write for the node itself and one read for every LabelGoto/LabelSwitch | |||
| // targeting it, plus one read when the node directly before it is not such a jump node. | |||
| // It also remembers the last node that led to each LabelSet. | |||
| // Pass 2 drains the counter's valid elements (presumably those accepted by the read <= 1 predicate - confirm | |||
| // against ReferenceCounter). The LabelSet is erased when nothing leads to it or when a LabelGoto leads to it; | |||
| // in the latter case that LabelGoto is erased too. | |||
| void AscendControlParser::EraseLabel(NotNull<KernelGraphPtr> root_graph) { | |||
| std::vector<CNodePtr> exec_order = root_graph->execution_order(); | |||
| std::map<uint32_t, CNodePtr> label_id_to_label_set; | |||
| UpdateLabelIdToLabelSetMap(exec_order, NOT_NULL(&label_id_to_label_set)); | |||
| ReferenceCounter label_refs([](int32_t read, int32_t write) -> bool { return read <= 1; }); | |||
| std::map<AnfNodePtr, CNodePtr> label_to_source; | |||
| CNodePtr previous = nullptr; | |||
| for (auto &cur_node : exec_order) { | |||
| MS_EXCEPTION_IF_NULL(cur_node); | |||
| if (AnfAlgo::IsCondControlKernel(cur_node)) { | |||
| for (auto &label_set : GetTargetLabelSetNodes(NOT_NULL(cur_node), label_id_to_label_set)) { | |||
| label_refs.AddReadCount(label_set, 1); | |||
| label_to_source[label_set] = cur_node; | |||
| } | |||
| } else if (IsPrimitiveCNode(cur_node, prim::kPrimLabelSet)) { | |||
| label_refs.AddWriteCount(cur_node, 1); | |||
| // The label is also reached by falling through from a preceding node that is not a jump. | |||
| const bool falls_through = previous != nullptr && !AnfAlgo::IsCondControlKernel(previous); | |||
| if (falls_through) { | |||
| label_refs.AddReadCount(cur_node, 1); | |||
| label_to_source[cur_node] = previous; | |||
| } | |||
| } | |||
| previous = cur_node; | |||
| } | |||
| while (label_refs.HasValidElem()) { | |||
| auto [label_set, read, written] = label_refs.GetOneValidElem(); | |||
| MS_LOG(INFO) << label_set->DebugString() << " was read " << read << " times, written " << written << " times."; | |||
| CNodePtr jump_node = nullptr; | |||
| if (read > 0) { | |||
| auto source_iter = label_to_source.find(label_set); | |||
| if (source_iter == label_to_source.end()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find node jump to " << label_set->DebugString(); | |||
| } | |||
| jump_node = source_iter->second; | |||
| } | |||
| const bool jumped_by_goto = jump_node != nullptr && IsPrimitiveCNode(jump_node, prim::kPrimLabelGoto); | |||
| if (jump_node == nullptr || jumped_by_goto) { | |||
| MS_LOG(INFO) << "Erase node " << label_set->DebugString(); | |||
| EraseNodeFromExecOrder(label_set, NOT_NULL(&exec_order)); | |||
| } | |||
| if (jumped_by_goto) { | |||
| MS_LOG(INFO) << "Erase node " << jump_node->DebugString(); | |||
| EraseNodeFromExecOrder(jump_node, NOT_NULL(&exec_order)); | |||
| } | |||
| label_refs.EraseElem(label_set); | |||
| } | |||
| root_graph->set_execution_order(exec_order); | |||
| } | |||
| // Runs RecurseGraph on root_graph to fill memo, then calls EraseParameter with that graph set and | |||
| // EraseLabel on root_graph. The return value of RecurseGraph is deliberately ignored. | |||
| void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) { | |||
| std::set<KernelGraphPtr> memo; | |||
| static_cast<void>(RecurseGraph(root_graph, NOT_NULL(&memo))); | |||
| const std::set<KernelGraphPtr> &visited_graphs = memo; | |||
| EraseParameter(root_graph, visited_graphs); | |||
| EraseLabel(root_graph); | |||
| } | |||
| void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) { | |||
| for (auto &iter : graph_id_map) { | |||
| auto &kg = iter.second; | |||
| MS_LOG(INFO) << "Data assign graph:" << kg->graph_id(); | |||
| MS_EXCEPTION_IF_NULL(kg); | |||
| std::set<std::pair<AnfNodePtr, AnfNodePtr>> memo; | |||
| const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs(); | |||
| for (auto &it : real_inputs) { | |||
| auto &parameter = it.first; | |||
| auto &args = it.second; | |||
| for (auto &arg : args) { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| if (memo.find({parameter, arg}) != memo.end()) { | |||
| continue; | |||
| } else { | |||
| memo.emplace(parameter, arg); | |||
| } | |||
| auto unreuse_args_map = kg->unreuse_args(); | |||
| auto unreuse_arg_iter = unreuse_args_map.find(arg); | |||
| if (unreuse_arg_iter == unreuse_args_map.end()) { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| if (!arg->isa<Parameter>()) { | |||
| MS_LOG(EXCEPTION) << "Reused arg must be parameter, arg:" << arg->DebugString() << "."; | |||
| } | |||
| MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString() | |||
| << ", arg:" << arg->DebugString(); | |||
| // Resolves the possible targets of a call node as (graph, arguments) pairs. | |||
| // - If input kCNodeCallArg is a KernelGraph value node, the single target is that graph and the arguments | |||
| //   are the call inputs that follow it. | |||
| // - If it is a Switch cnode, every switch input after kCNodeSwitchCond is parsed with ParsePartial and | |||
| //   contributes one target. | |||
| // An exception is logged for a non-call node, too few inputs, or any other kind of call argument. | |||
| std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlParser::ParseCallNode( | |||
| NotNull<CNodePtr> call_node) { | |||
| using CallTarget = std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>; | |||
| std::vector<CallTarget> ret; | |||
| if (!IsPrimitiveCNode(call_node.get(), prim::kPrimCall)) { | |||
| MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " is not a call node."; | |||
| } | |||
| if (call_node->size() <= kCNodeCallArg) { | |||
| MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " has invalid inputs size " << call_node->size(); | |||
| } | |||
| const std::vector<AnfNodePtr> &call_node_inputs = call_node->inputs(); | |||
| auto call_arg = call_node_inputs[kCNodeCallArg]; | |||
| MS_EXCEPTION_IF_NULL(call_arg); | |||
| if (IsValueNode<KernelGraph>(call_arg)) { | |||
| // Direct call: everything after the callee is an argument. | |||
| std::vector<AnfNodePtr> direct_args(call_node_inputs.begin() + kCNodeCallArg + 1, call_node_inputs.end()); | |||
| ret.emplace_back(GetValueNode<KernelGraphPtr>(call_arg), std::move(direct_args)); | |||
| return ret; | |||
| } | |||
| if (IsPrimitiveCNode(call_arg, prim::kPrimSwitch)) { | |||
| auto switch_cnode = call_arg->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(switch_cnode); | |||
| const std::vector<AnfNodePtr> &switch_inputs = switch_cnode->inputs(); | |||
| if (switch_inputs.size() <= kCNodeSwitchCond) { | |||
| MS_LOG(EXCEPTION) << "Node " << switch_cnode->DebugString() << " has invalid inputs size " | |||
| << switch_inputs.size(); | |||
| } | |||
| auto iter = switch_inputs.begin() + kCNodeSwitchCond + 1; | |||
| while (iter != switch_inputs.end()) { | |||
| const auto &[branch_graph, branch_args] = ParsePartial(NOT_NULL(*iter)); | |||
| ret.emplace_back(branch_graph, branch_args); | |||
| ++iter; | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Unsupport call node: " << call_node->DebugString(5); | |||
| } | |||
| return ret; | |||
| } | |||
| void AscendControlParser::ChildGraphDataAssign( | |||
| NotNull<KernelGraphPtr> kg, const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| if (memo->find(kg) != memo->end()) { | |||
| return; | |||
| } | |||
| memo->insert(kg.get()); | |||
| MS_LOG(INFO) << "Start link data for " << kg->ToString(); | |||
| const std::vector<CNodePtr> &nodes = kg->execution_order(); | |||
| for (auto &node : nodes) { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimCall)) { | |||
| continue; | |||
| } | |||
| auto child_graph_list = ParseCallNode(NOT_NULL(node)); | |||
| for (auto &[child_graph, args] : child_graph_list) { | |||
| MS_EXCEPTION_IF_NULL(child_graph); | |||
| const std::vector<AnfNodePtr> &params = child_graph->inputs(); | |||
| if (args.size() != params.size()) { | |||
| MS_LOG(EXCEPTION) << child_graph->ToString() << " needs " << params.size() << " inputs but call node " | |||
| << node->DebugString(5) << " gives " << args.size(); | |||
| } | |||
| for (size_t i = 0; i < args.size(); ++i) { | |||
| if (args[i]->isa<Parameter>() && memo->find(child_graph) == memo->end()) { | |||
| MS_LOG(INFO) << args[i]->DebugString() << " to " << params[i]->DebugString() | |||
| << " should be reused, continue."; | |||
| link_list->emplace_back(args[i], params[i]); | |||
| continue; | |||
| } | |||
| auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get())); | |||
| if (target_graph_iter == graph_id_map.end()) { | |||
| MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found."; | |||
| } | |||
| InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg), | |||
| NOT_NULL(parameter)); | |||
| InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i])); | |||
| } | |||
| } | |||
| kg->SetExecOrderByDefault(); | |||
| } | |||
| kg->SetExecOrderByDefault(); | |||
| for (auto &child_graph : kg->child_graph_order()) { | |||
| ChildGraphDataAssign(NOT_NULL(child_graph), link_list, memo); | |||
| } | |||
| } | |||
| @@ -325,7 +471,7 @@ void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNul | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), | |||
| return_node->input(kFirstDataInputIndex), attch_node.get()}; | |||
| auto depend_node = kg->NewCNode(inputs); | |||
| return_node->set_input(1, depend_node); | |||
| return_node->set_input(kFirstDataInputIndex, depend_node); | |||
| } | |||
| void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node, | |||
| @@ -381,6 +527,7 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP | |||
| new_inputs.push_back(sub_label); | |||
| cur_node->set_inputs(new_inputs); | |||
| cur_node->set_abstract(nullptr); | |||
| AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>({call_kg}), cur_node.get()); | |||
| MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString(); | |||
| } | |||
| @@ -409,9 +556,12 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod | |||
| std::vector<AnfNodePtr> new_switch_inputs = { | |||
| std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), | |||
| origin_switch_inputs[kCNodeSwitchCond]}; | |||
| std::vector<KernelGraphPtr> child_graphs; | |||
| for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { | |||
| // 3.1 branch kernel graph and args | |||
| KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||
| KernelGraphPtr branch_fg; | |||
| std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||
| child_graphs.push_back(branch_fg); | |||
| // 3.2 recurse sub graph | |||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | |||
| new_switch_inputs.push_back(branch_label); | |||
| @@ -420,6 +570,7 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod | |||
| cur_node->set_inputs(new_switch_inputs); | |||
| cur_node->set_abstract(nullptr); | |||
| AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>(child_graphs), cur_node.get()); | |||
| MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString(); | |||
| } | |||
| @@ -453,9 +604,12 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull | |||
| std::vector<AnfNodePtr> new_switch_inputs = { | |||
| std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), | |||
| origin_switch_inputs[kCNodeSwitchCond]}; | |||
| std::vector<KernelGraphPtr> child_graphs; | |||
| for (size_t i = 0; i < branch_partial.size(); ++i) { | |||
| // 3.1 branch kernel graph and args | |||
| KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||
| KernelGraphPtr branch_fg; | |||
| std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||
| child_graphs.push_back(branch_fg); | |||
| // 3.2 recurse sub graph | |||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | |||
| new_switch_inputs.push_back(branch_label); | |||
| @@ -463,13 +617,14 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull | |||
| new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); | |||
| cur_node->set_inputs(new_switch_inputs); | |||
| cur_node->set_abstract(nullptr); | |||
| AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>(child_graphs), cur_node.get()); | |||
| MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString(); | |||
| } | |||
| KernelGraphPtr AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) { | |||
| std::tuple<KernelGraphPtr, std::vector<AnfNodePtr>> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) { | |||
| if (!node.get()->isa<CNode>()) { | |||
| if (IsValueNode<KernelGraph>(node)) { | |||
| return GetValueNode<KernelGraphPtr>(node); | |||
| return {GetValueNode<KernelGraphPtr>(node), {}}; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); | |||
| } | |||
| @@ -485,12 +640,11 @@ KernelGraphPtr AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) { | |||
| MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << "."; | |||
| } | |||
| auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]); | |||
| return branch_kg; | |||
| return {branch_kg, std::vector<AnfNodePtr>(partial_inputs.begin() + kCNodePartialFunc + 1, partial_inputs.end())}; | |||
| } | |||
| void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, | |||
| NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> from, | |||
| NotNull<AnfNodePtr> to) { | |||
| void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, const AnfNodePtr &jump_node, | |||
| NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to) { | |||
| std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); | |||
| std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); | |||
| MS_LOG(INFO) << "Insert multi-assign from [" << from->DebugString() << "] to [" << to->DebugString() << "]"; | |||
| @@ -500,22 +654,35 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> fr | |||
| } | |||
| for (size_t i = 0; i < from_outputs.size(); i++) { | |||
| auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); | |||
| if (assign_node != nullptr) { | |||
| auto jump_node = GetJumpNode(from_graph, to_graph); | |||
| const auto &from_graph_exe_order = from_graph->execution_order(); | |||
| auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node); | |||
| if (jump_node_iter == from_graph_exe_order.end()) { | |||
| MS_EXCEPTION_IF_NULL(jump_node); | |||
| MS_LOG(EXCEPTION) << "Can't find node:" << jump_node->DebugString() << " in graph:" << from_graph->graph_id(); | |||
| } | |||
| // insert assign between jump_node -1 and jump_node | |||
| if (jump_node_iter != from_graph_exe_order.begin()) { | |||
| InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); | |||
| } | |||
| if (jump_node != nullptr) { | |||
| InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); | |||
| const auto &from_graph_exe_order = from_graph->execution_order(); | |||
| std::vector<CNodePtr> real_exe_order(from_graph_exe_order.size()); | |||
| size_t real_exe_order_size = 0; | |||
| std::copy_if(from_graph_exe_order.begin(), from_graph_exe_order.end(), real_exe_order.begin(), | |||
| [&real_exe_order_size](const CNodePtr &node) -> bool { | |||
| return (IsPrimitiveCNode(node, prim::kPrimSwitch) || IsPrimitiveCNode(node, prim::kPrimPartial)) | |||
| ? false | |||
| : (++real_exe_order_size, true); | |||
| }); | |||
| real_exe_order.resize(real_exe_order_size); | |||
| if (jump_node == nullptr) { | |||
| if (!real_exe_order.empty()) { | |||
| InsertControlDependToGraph(from_graph, NOT_NULL(*(real_exe_order.rbegin())), NOT_NULL(assign_node)); | |||
| } else { | |||
| InsertDependToGraph(from_graph, NOT_NULL(assign_node)); | |||
| } | |||
| continue; | |||
| } | |||
| auto jump_node_iter = std::find(real_exe_order.begin(), real_exe_order.end(), jump_node); | |||
| if (jump_node_iter == real_exe_order.end()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find jump node " << jump_node->DebugString() << " in graph " | |||
| << from_graph->ToString(); | |||
| } | |||
| // insert assign between jump_node -1 and jump_node | |||
| if (jump_node_iter != real_exe_order.begin()) { | |||
| InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); | |||
| } | |||
| InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); | |||
| } | |||
| } | |||
| @@ -618,26 +785,45 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i | |||
| } | |||
| } | |||
| // Resets the execution order of `kg`, then rebuilds its child graph order from the kernel graphs | |||
| // referenced by its call nodes. Each child graph that is not the parent of `kg` gets `kg` as its parent. | |||
| void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) { | |||
| MS_LOG(INFO) << "Graph id:" << kg->graph_id(); | |||
| kg->SetExecOrderByDefault(); | |||
| auto call_nodes = kg->FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name())); | |||
| std::vector<KernelGraphPtr> child_graph_order; | |||
| for (auto &call_node : call_nodes) { | |||
| MS_EXCEPTION_IF_NULL(call_node); | |||
| auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>()); | |||
| for (const auto &child_graph : call_child_graphs) { | |||
| MS_EXCEPTION_IF_NULL(child_graph); | |||
| if (child_graph != kg->parent_graph()) { | |||
| child_graph->set_parent_graph(kg.get()); | |||
| } | |||
| child_graph_order.push_back(child_graph); | |||
| } | |||
| } | |||
| for (size_t i = 0; i < child_graph_order.size(); i++) { | |||
| MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]"; | |||
| } | |||
| kg->set_child_graph_order(child_graph_order); | |||
| } | |||
| // Adds `num` to the read count (first member of the stored pair) of `key`. | |||
| // operator[] value-initializes a missing entry to {0, 0}, so a new key ends up as {num, 0}. | |||
| void AscendControlParser::ReferenceCounter::AddReadCount(const AnfNodePtr &key, int32_t num) { | |||
| count_[key].first += num; | |||
| } | |||
| // Adds `num` to the write count (second member of the stored pair) of `key`; a new key ends up as {0, num}. | |||
| void AscendControlParser::ReferenceCounter::AddWriteCount(const AnfNodePtr &key, int32_t num) { | |||
| count_[key].second += num; | |||
| } | |||
| // Removes the counters of `key`; a key that is not present is ignored. | |||
| void AscendControlParser::ReferenceCounter::EraseElem(const AnfNodePtr &key) { count_.erase(key); } | |||
| // Returns true when at least one entry's (read, write) counts satisfy predicate_. | |||
| bool AscendControlParser::ReferenceCounter::HasValidElem() const { | |||
| // `const auto &` binds to the map's own value_type, so no converting copy is made per element | |||
| auto it = std::find_if(count_.begin(), count_.end(), [this](const auto &entry) -> bool { | |||
| const auto &[read, written] = entry.second; | |||
| return predicate_(read, written); | |||
| }); | |||
| return it != count_.end(); | |||
| } | |||
| // Returns (node, read count, write count) of the first entry, in map order, that satisfies predicate_. | |||
| // Throws (MS_LOG(EXCEPTION)) when no entry satisfies it. | |||
| std::tuple<AnfNodePtr, int32_t, int32_t> AscendControlParser::ReferenceCounter::GetOneValidElem() const { | |||
| auto it = std::find_if(count_.begin(), count_.end(), [this](const auto &entry) -> bool { | |||
| const auto &[read, written] = entry.second; | |||
| return predicate_(read, written); | |||
| }); | |||
| if (it == count_.end()) { | |||
| MS_LOG(EXCEPTION) << "No valid parameter."; | |||
| } | |||
| return {it->first, it->second.first, it->second.second}; | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -20,6 +20,8 @@ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <tuple> | |||
| #include <utility> | |||
| #include <functional> | |||
| #include "backend/session/kernel_graph.h" | |||
| #include "utils/base_ref.h" | |||
| #include "utils/contract.h" | |||
| @@ -29,16 +31,23 @@ namespace mindspore { | |||
| namespace session { | |||
| class AscendControlParser { | |||
| public: | |||
| static void ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map); | |||
| static void LinkGraph(NotNull<KernelGraphPtr> kg); | |||
| static void InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node); | |||
| static void InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node, | |||
| NotNull<AnfNodePtr> second_node); | |||
| static void ExecutorValidate(NotNull<KernelGraphPtr> root_graph); | |||
| static void UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg); | |||
| static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, const AnfNodePtr &jump_node, | |||
| NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | |||
| private: | |||
| class ReferenceCounter; | |||
| static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list); | |||
| static void EraseLabel(NotNull<KernelGraphPtr> root_graph); | |||
| static void ChildGraphDataAssign(NotNull<KernelGraphPtr> kg, | |||
| const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static NotNull<CNodePtr> GetStartLabel(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, | |||
| const CNodePtr &last_label); | |||
| static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, | |||
| @@ -53,11 +62,10 @@ class AscendControlParser { | |||
| static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, | |||
| const CNodePtr &last_label); | |||
| static KernelGraphPtr ParsePartial(NotNull<AnfNodePtr> node); | |||
| static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph, | |||
| NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | |||
| static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | |||
| static std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ParseCallNode(NotNull<CNodePtr> call_node); | |||
| static std::tuple<KernelGraphPtr, std::vector<AnfNodePtr>> ParsePartial(NotNull<AnfNodePtr> node); | |||
| // root graph order | |||
| static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, | |||
| @@ -65,6 +73,19 @@ class AscendControlParser { | |||
| static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||
| }; | |||
| // Keeps a pair of int32_t counters (read, write) per node and selects nodes through a caller-supplied predicate. | |||
| class AscendControlParser::ReferenceCounter { | |||
| public: | |||
| // func: predicate invoked with the (read, write) counts of an entry; an entry is "valid" when it returns true. | |||
| // The by-value argument is moved into the member so the callable is not copied a second time. | |||
| explicit ReferenceCounter(std::function<bool(int32_t, int32_t)> func) : predicate_(std::move(func)) {} | |||
| // Adds `num` to the read count of `key`. | |||
| void AddReadCount(const AnfNodePtr &key, int32_t num); | |||
| // Adds `num` to the write count of `key`. | |||
| void AddWriteCount(const AnfNodePtr &key, int32_t num); | |||
| // Drops the counters stored for `key`. | |||
| void EraseElem(const AnfNodePtr &key); | |||
| // True when some entry satisfies the predicate. | |||
| bool HasValidElem() const; | |||
| // Returns (node, read count, write count) of one entry that satisfies the predicate. | |||
| std::tuple<AnfNodePtr, int32_t, int32_t> GetOneValidElem() const; | |||
| private: | |||
| std::function<bool(int32_t, int32_t)> predicate_; | |||
| // node -> (read count, write count) | |||
| std::map<AnfNodePtr, std::pair<int32_t, int32_t>> count_; | |||
| }; | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -289,6 +289,17 @@ static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph, | |||
| // this action should from bottom to top | |||
| graph->UpdateCallRealInput(); | |||
| } | |||
| // Wraps the current output of `root_graph` in a MakeTuple node and installs that node as the new output. | |||
| // Nothing is changed when the return node has no input at kReturnDataIndex. | |||
| void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) { | |||
| auto ret = root_graph->get_return(); | |||
| MS_EXCEPTION_IF_NULL(ret); | |||
| if (ret->size() > kReturnDataIndex) { | |||
| std::vector<AnfNodePtr> tuple_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()}; | |||
| root_graph->set_output(root_graph->NewCNode(tuple_inputs)); | |||
| } | |||
| } | |||
| } // namespace | |||
| GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| @@ -305,22 +316,39 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| std::vector<KernelGraphPtr> all_graphs; | |||
| auto root_graph = ConstructKernelGraph(func_graph, &all_graphs); | |||
| BackendOptimization(all_graphs); | |||
| // split switch | |||
| SplitGraphs(NOT_NULL(root_graph)); | |||
| // an empty graph does not enter the backend | |||
| if (root_graph->execution_order().empty()) { | |||
| MS_LOG(INFO) << root_graph->ToString() << " is empty graph."; | |||
| InsertMakeTupleForOutput(NOT_NULL(root_graph)); | |||
| root_graph->set_executable(false); | |||
| InitRuntimeResource(); | |||
| return root_graph->graph_id(); | |||
| } | |||
| // create parameter for multiple branch | |||
| std::set<KernelGraphPtr> memo; | |||
| CreateMultiBranchOutput(NOT_NULL(root_graph), NOT_NULL(&memo)); | |||
| memo.clear(); | |||
| // insert goto labels and label_sets | |||
| LinkChildGraphs(NOT_NULL(root_graph)); | |||
| // resource initialize | |||
| InitRuntimeResource(); | |||
| // recurse compile child root_graph | |||
| std::set<KernelGraphPtr> memo; | |||
| RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo)); | |||
| IrFusionPass(NOT_NULL(root_graph), NOT_NULL(&memo)); | |||
| memo.clear(); | |||
| SelectKernel(NOT_NULL(root_graph)); | |||
| memo.clear(); | |||
| HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo)); | |||
| memo.clear(); | |||
| AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo)); | |||
| memo.clear(); | |||
| UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo)); | |||
| memo.clear(); | |||
| // add make_tuple to the output graph | |||
| InsertMakeTupleForOutput(NOT_NULL(root_graph)); | |||
| // validate the root graph, including generating the execution order and so on | |||
| RootGraphExecutorValidate(NOT_NULL(root_graph)); | |||
| // adjust kernel | |||
| @@ -1682,7 +1710,7 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri | |||
| bool split_flag = false; | |||
| auto apply_list = GetCNodes(TopoSort(graph->get_return())); | |||
| // update the root graph child graph order | |||
| AscendControlParser::UpdateChildGraphOrder(graph); | |||
| graph->UpdateChildGraphOrder(); | |||
| // get child list from current graph | |||
| std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(apply_list, cut_prims); | |||
| if (child_graph_lists.size() > 1) { | |||
| @@ -1714,7 +1742,7 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri | |||
| } | |||
| split_flag = true; | |||
| } | |||
| AscendControlParser::UpdateChildGraphOrder(graph); | |||
| graph->UpdateChildGraphOrder(); | |||
| UpdateRealInput(graph, split_flag, memo); | |||
| MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end"; | |||
| } | |||
| @@ -1753,5 +1781,216 @@ void AscendSession::RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const Not | |||
| } | |||
| } | |||
| } | |||
| // For every call node of `graph` (child graphs are handled first, each graph once via `memo`): | |||
| // creates a parameter that receives the call's result, records it as a child graph result, inserts assigns | |||
| // from each called child graph's output to that parameter, and finally replaces every use of the call node | |||
| // among the graph's CNodes by Depend(parameter, call). | |||
| void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| if (memo->find(graph.get()) != memo->end()) { | |||
| return; | |||
| } | |||
| memo->insert(graph.get()); | |||
| graph->UpdateChildGraphOrder(); | |||
| for (auto &sub_graph : graph->child_graph_order()) { | |||
| CreateMultiBranchOutput(NOT_NULL(sub_graph), memo); | |||
| } | |||
| std::map<AnfNodePtr, AnfNodePtr> call_to_depend; | |||
| auto cnode_list = GetCNodes(TopoSort(graph->get_return())); | |||
| for (auto &cnode : cnode_list) { | |||
| if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) { | |||
| continue; | |||
| } | |||
| // the graph inputs are saved before the parameter is created and written back right after it | |||
| auto saved_inputs = graph->inputs(); | |||
| auto output_param = CreateNewParameterFromCNode(cnode, true, graph.get().get()); | |||
| MS_EXCEPTION_IF_NULL(graph->MutableInputs()); | |||
| *(graph->MutableInputs()) = saved_inputs; | |||
| graph->AddChildGraphResult(output_param); | |||
| std::vector<AnfNodePtr> depend_inputs = { | |||
| graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name()))), output_param, cnode}; | |||
| auto depend = graph->NewCNode(depend_inputs); | |||
| call_to_depend.emplace(cnode, depend); | |||
| MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node " << cnode->DebugString() | |||
| << ", depend node is " << depend->DebugString(); | |||
| // transfer each child graph's output into the new parameter through assigns | |||
| auto called_graphs = AnfAlgo::GetCallNodeKernelGraph(cnode); | |||
| for (auto &called_graph : called_graphs) { | |||
| MS_EXCEPTION_IF_NULL(called_graph); | |||
| if (!called_graph->get_output_null()) { | |||
| auto called_output = called_graph->output(); | |||
| AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(called_graph), nullptr, NOT_NULL(called_output), | |||
| NOT_NULL(output_param)); | |||
| } | |||
| } | |||
| } | |||
| // rewrite the inputs of all collected CNodes: call -> depend(parameter, call) | |||
| for (auto &cnode : cnode_list) { | |||
| for (size_t idx = 0; idx < cnode->size(); ++idx) { | |||
| auto found = call_to_depend.find(cnode->input(idx)); | |||
| if (found == call_to_depend.end()) { | |||
| continue; | |||
| } | |||
| cnode->set_input(idx, found->second); | |||
| } | |||
| } | |||
| } | |||
| // Runs the IR fusion, basic-op fusion and graph-kernel optimizations on `graph`, resets its execution order, | |||
| // optionally dumps the graph IR, then recurses into the child graph order. `memo` holds visited graphs. | |||
| void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| const bool visited = memo->find(graph) != memo->end(); | |||
| if (visited) { | |||
| return; | |||
| } | |||
| memo->insert(graph.get()); | |||
| opt::AscendBackendIRFusionOptimization(graph); | |||
| opt::AscendBackendFuseBasicOpt(graph, true); | |||
| opt::AscendBackendGraphKernelOpt(graph, true); | |||
| graph->SetExecOrderByDefault(); | |||
| auto ms_ctx = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_ctx); | |||
| bool dump_enabled = ms_ctx->save_graphs_flag(); | |||
| auto dump_dir = ms_ctx->save_graphs_path(); | |||
| if (dump_enabled) { | |||
| // an empty configured path means the current directory | |||
| if (dump_dir.empty()) { | |||
| dump_dir = "."; | |||
| } | |||
| std::string ir_file = | |||
| dump_dir + "/" + "select_kernel_before" + "_graph_" + std::to_string(graph->graph_id()) + ".ir"; | |||
| DumpIR(ir_file, graph.get()); | |||
| } | |||
| for (auto &sub_graph : graph->child_graph_order()) { | |||
| IrFusionPass(NOT_NULL(sub_graph), memo); | |||
| } | |||
| } | |||
| // Selects kernel info for `root_graph` and the child graphs reached from it via RecurseSelectKernelInfo. | |||
| // In graph mode, a warning is logged for the number of nodes selected with raised precision and for the | |||
| // number selected with reduced precision. | |||
| void AscendSession::SelectKernel(NotNull<KernelGraphPtr> root_graph) { | |||
| MS_LOG(INFO) << "Start select kernel."; | |||
| size_t raise_precision_count = 0; | |||
| size_t reduce_precision_count = 0; | |||
| std::set<KernelGraphPtr> memo; | |||
| (void)RecurseSelectKernelInfo(root_graph, NOT_NULL(&memo), &raise_precision_count, &reduce_precision_count); | |||
| memo.clear(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->execution_mode() == kGraphMode) { | |||
| if (raise_precision_count > 0) { | |||
| MS_LOG(WARNING) << "There has " << raise_precision_count | |||
| << " node/nodes used raise precision to selected the kernel!"; | |||
| } | |||
| if (reduce_precision_count > 0) { | |||
| // report the reduce-precision counter here (the raise-precision counter was printed by mistake) | |||
| MS_LOG(WARNING) << "There has " << reduce_precision_count | |||
| << " node/nodes used reduce precision to selected the kernel!"; | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Finish!"; | |||
| } | |||
| // Selects kernel info for every node in the execution order of `graph`. For a conditional control kernel | |||
| // carrying the kAttrChildGraph attribute, the attached child graphs are processed first. | |||
| // raise_precision_count / reduce_precision_count are incremented according to the status of each selection. | |||
| // `memo` holds visited graphs. When graph saving is enabled, the graph IR is dumped at the end. | |||
| void AscendSession::RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph, | |||
| NotNull<std::set<KernelGraphPtr> *> const memo, | |||
| size_t *const raise_precision_count, | |||
| size_t *const reduce_precision_count) const { | |||
| const bool visited = memo->find(graph) != memo->end(); | |||
| if (visited) { | |||
| return; | |||
| } | |||
| memo->insert(graph.get()); | |||
| MS_LOG(INFO) << "Start to select kernel info in graph: " << graph->graph_id(); | |||
| for (const auto &kernel : graph->execution_order()) { | |||
| if (AnfAlgo::IsCondControlKernel(kernel) && AnfAlgo::HasNodeAttr(kAttrChildGraph, kernel)) { | |||
| std::vector<KernelGraphPtr> branch_graphs = | |||
| AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(kernel, kAttrChildGraph); | |||
| for (auto &branch_graph : branch_graphs) { | |||
| RecurseSelectKernelInfo(NOT_NULL(branch_graph), memo, raise_precision_count, reduce_precision_count); | |||
| } | |||
| } | |||
| auto select_status = device::ascend::SelectKernelInfo(kernel); | |||
| if (select_status == device::ascend::kStatusRaisePrecision) { | |||
| ++(*raise_precision_count); | |||
| } else if (select_status == device::ascend::kStatusReducePrecision) { | |||
| ++(*reduce_precision_count); | |||
| } | |||
| MS_LOG(INFO) << "Select ApplyKernel: " << kernel->DebugString(); | |||
| } | |||
| auto ms_ctx = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_ctx); | |||
| bool dump_enabled = ms_ctx->save_graphs_flag(); | |||
| auto dump_dir = ms_ctx->save_graphs_path(); | |||
| if (dump_enabled) { | |||
| // an empty configured path means the current directory | |||
| if (dump_dir.empty()) { | |||
| dump_dir = "."; | |||
| } | |||
| std::string ir_file = | |||
| dump_dir + "/" + "select_kernel_after" + "_graph_" + std::to_string(graph->graph_id()) + ".ir"; | |||
| DumpIR(ir_file, graph.get()); | |||
| } | |||
| MS_LOG(INFO) << "Finish selecting kernel info in graph: " << graph->graph_id(); | |||
| } | |||
| // Converts `graph` through predictmodel::StepConvertGraph, runs the single-graph HardwareOptimize overload | |||
| // on it, then recurses into its child graph order. `memo` holds visited graphs. | |||
| void AscendSession::HardwareOptimize(NotNull<KernelGraphPtr> graph, | |||
| NotNull<std::set<KernelGraphPtr> *> const memo) const { | |||
| const bool visited = memo->find(graph) != memo->end(); | |||
| if (visited) { | |||
| return; | |||
| } | |||
| memo->insert(graph.get()); | |||
| MS_LOG(INFO) << "Start to do HardwareOptimize in graph: " << graph->graph_id(); | |||
| // convert kernel graph to model, then optimize this graph itself | |||
| predictmodel::StepConvertGraph(graph.get()); | |||
| HardwareOptimize(graph.get()); | |||
| for (auto &sub_graph : graph->child_graph_order()) { | |||
| HardwareOptimize(NOT_NULL(sub_graph), memo); | |||
| } | |||
| MS_LOG(INFO) << "Finish doing HardwareOptimize in graph: " << graph->graph_id(); | |||
| } | |||
| // Asks the Ascend kernel runtime of device_id_ to assign static memory for the inputs and value nodes | |||
| // of `graph`, then recurses into its child graph order. `memo` holds visited graphs. | |||
| void AscendSession::AssignStaticMemory(NotNull<KernelGraphPtr> graph, | |||
| NotNull<std::set<KernelGraphPtr> *> const memo) const { | |||
| const bool visited = memo->find(graph) != memo->end(); | |||
| if (visited) { | |||
| return; | |||
| } | |||
| memo->insert(graph.get()); | |||
| MS_LOG(INFO) << "Start to assign static memory for parameter in graph: " << graph->graph_id(); | |||
| auto kernel_runtime = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); | |||
| MS_EXCEPTION_IF_NULL(kernel_runtime); | |||
| // inputs first, then value nodes | |||
| kernel_runtime->AssignStaticMemoryInput(graph.get().get()); | |||
| kernel_runtime->AssignStaticMemoryValueNode(graph.get().get()); | |||
| for (auto &sub_graph : graph->child_graph_order()) { | |||
| AssignStaticMemory(NOT_NULL(sub_graph), memo); | |||
| } | |||
| MS_LOG(INFO) << "Finish assigning static memory for parameter in graph: " << graph->graph_id(); | |||
| } | |||
| // Copies the ref pairs of every child graph (after processing that child recursively) into `graph`. | |||
| // A pair whose key is already present in `graph` is skipped with a warning. `memo` holds visited graphs. | |||
| void AscendSession::UpdateRefOutputMap(NotNull<KernelGraphPtr> graph, | |||
| NotNull<std::set<KernelGraphPtr> *> const memo) const { | |||
| const bool visited = memo->find(graph) != memo->end(); | |||
| if (visited) { | |||
| return; | |||
| } | |||
| memo->insert(graph.get()); | |||
| for (auto &sub_graph : graph->child_graph_order()) { | |||
| UpdateRefOutputMap(NOT_NULL(sub_graph), memo); | |||
| // merge the child's ref map into this graph | |||
| auto sub_ref_map = sub_graph->GetRefMap(); | |||
| for (auto &ref_pair : sub_ref_map) { | |||
| if (!graph->IsInRefOutputMap(ref_pair.first)) { | |||
| graph->AddRefCorrespondPairs(ref_pair.first, ref_pair.second); | |||
| } else { | |||
| MS_LOG(WARNING) << "The ref pair <" << ref_pair.first.first->DebugString() << ", " << ref_pair.first.second | |||
| << "> is already in " << graph->ToString(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -151,6 +151,15 @@ class AscendSession : public SessionBasic { | |||
| // sync initial tensors' data to device | |||
| void SyncInitialTenosrToDevice(); | |||
| void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph); | |||
| // create parameter to receive data from multiple branch output | |||
| void CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo); | |||
| void SelectKernel(NotNull<KernelGraphPtr> root_graph); | |||
| void RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> const memo, | |||
| size_t *const raise_precision_count, size_t *const reduce_precision_count) const; | |||
| void IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo); | |||
| void HardwareOptimize(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const; | |||
| void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const; | |||
| void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const; | |||
| // member variables | |||
| // key is final_graph_id,value is child graph execute order of final graph | |||
| @@ -616,8 +616,8 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de | |||
| if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { | |||
| depend_mode = AnfAlgo::GetNodeAttr<int>(cnode, kControlDependMode); | |||
| } | |||
| MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString() | |||
| << "], depend_mode :" << depend_mode << "."; | |||
| MS_LOG(DEBUG) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString() | |||
| << "], depend_mode :" << depend_mode << "."; | |||
| if (prior_node->isa<Parameter>() && depend_mode == 1) { | |||
| prior_nodes = GetOutputNodes(prior_node); | |||
| } | |||
| @@ -647,7 +647,8 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de | |||
| } | |||
| MS_EXCEPTION_IF_NULL(first_node); | |||
| MS_EXCEPTION_IF_NULL(second_node); | |||
| MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString(); | |||
| MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString() | |||
| << ",second node:" << second_node->DebugString(); | |||
| AddDependEdge(second_node, first_node, 1); | |||
| } | |||
| } | |||
| @@ -991,6 +992,30 @@ bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const { | |||
| return false; | |||
| } | |||
| // Resets the execution order, then rebuilds child_graph_order_ from the kernel graphs referenced by this | |||
| // graph's call nodes. Each child graph that is not parent_graph_ gets this graph as its parent. | |||
| void KernelGraph::UpdateChildGraphOrder() { | |||
| MS_LOG(INFO) << "Update " << ToString() << " child graph order."; | |||
| SetExecOrderByDefault(); | |||
| auto call_node_list = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name())); | |||
| std::vector<KernelGraphPtr> ordered_children; | |||
| for (auto &call_node : call_node_list) { | |||
| MS_EXCEPTION_IF_NULL(call_node); | |||
| auto called_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>()); | |||
| for (const auto &called_graph : called_graphs) { | |||
| MS_EXCEPTION_IF_NULL(called_graph); | |||
| if (called_graph != parent_graph_) { | |||
| // shared_from_this() is cast down to KernelGraph and null-checked before use | |||
| auto self = std::dynamic_pointer_cast<KernelGraph>(shared_from_this()); | |||
| MS_EXCEPTION_IF_NULL(self); | |||
| called_graph->set_parent_graph(self); | |||
| } | |||
| ordered_children.push_back(called_graph); | |||
| } | |||
| } | |||
| size_t position = 0; | |||
| for (const auto &ordered_child : ordered_children) { | |||
| MS_LOG(INFO) << "Child graph[" << position << "][id:" << ordered_child->graph_id() << "]"; | |||
| ++position; | |||
| } | |||
| child_graph_order_ = ordered_children; | |||
| } | |||
| // Returns the printable name of this graph: "kernel_graph_" followed by the decimal graph id. | |||
| std::string KernelGraph::ToString() const { | |||
| std::string graph_name("kernel_graph_"); | |||
| graph_name += std::to_string(graph_id_); | |||
| return graph_name; | |||
| } | |||
| // Destructor: asks the KernelRuntimeManager instance to clear the graph resource recorded under graph_id_. | |||
| KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); } | |||
| @@ -156,6 +156,12 @@ class KernelGraph : public FuncGraph { | |||
| bool IsFinalOutputKernel(const AnfNodePtr &node) const; | |||
| uint32_t current_epoch() const { return current_epoch_; } | |||
| void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; } | |||
| void UpdateChildGraphOrder(); | |||
| const std::vector<AnfNodePtr> &child_graph_result() const { return child_graph_result_; } | |||
| void AddChildGraphResult(const AnfNodePtr ¶meter) { child_graph_result_.push_back(parameter); } | |||
| void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) { | |||
| child_graph_result_ = child_graph_result; | |||
| } | |||
| private: | |||
| // remove value node from graph | |||
| @@ -173,6 +179,7 @@ class KernelGraph : public FuncGraph { | |||
| void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends); | |||
| std::shared_ptr<std::vector<AnfNodePtr>> inputs_; | |||
| std::vector<AnfNodePtr> child_graph_result_; | |||
| std::vector<CNodePtr> execution_order_; | |||
| uint32_t graph_id_; | |||
| uint32_t stream_distinction_label_; | |||
| @@ -74,7 +74,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne | |||
| return input_tensors[input_idx]; | |||
| } | |||
| } | |||
| MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << "has no output addr"; | |||
| MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr"; | |||
| } | |||
| } | |||
| // If the process reaches here, item_with_index is a real node (a Parameter or an executable CNode). | |||
| @@ -107,8 +107,8 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne | |||
| return tensor; | |||
| } | |||
| BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, | |||
| const std::vector<tensor::TensorPtr> &input_tensors) { | |||
| BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, | |||
| const std::vector<tensor::TensorPtr> &input_tensors) { | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]"; | |||
| auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0); | |||
| @@ -120,7 +120,7 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| VectorRef ret; | |||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||
| auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors); | |||
| auto out = CreateTensorForOutput(cnode->input(i), graph, input_tensors); | |||
| ret.push_back(out); | |||
| } | |||
| return ret; | |||
| @@ -133,25 +133,6 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, | |||
| return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors); | |||
| } | |||
| BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph, | |||
| const std::vector<tensor::TensorPtr> &input_tensors) { | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| if (!AnfAlgo::IsRealKernel(anf)) { | |||
| MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] should be a executable kernel"; | |||
| } | |||
| if (anf->isa<ValueNode>()) { | |||
| return CreateOneTensor(anf, 0, graph, input_tensors); | |||
| } | |||
| VectorRef ret; | |||
| if (anf->isa<CNode>() && AnfAlgo::GetCNodeName(anf) != prim::kPrimMakeTuple->name()) { | |||
| for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf); ++i) { | |||
| auto out = CreateOneTensor(anf, i, graph, input_tensors); | |||
| ret.emplace_back(out); | |||
| } | |||
| } | |||
| return ret; | |||
| } | |||
| ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| @@ -880,20 +861,11 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap | |||
| const std::vector<tensor::TensorPtr> &input_tensors) const { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| if (!kernel_graph->child_graph_order().empty()) { | |||
| // use the last child graph output as the root graph output | |||
| UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors); | |||
| return; | |||
| } | |||
| auto anf_outputs = kernel_graph->outputs(); | |||
| for (auto &item : anf_outputs) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| MS_LOG(INFO) << "Update output[" << item->DebugString() << "]"; | |||
| if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) { | |||
| outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors)); | |||
| continue; | |||
| } | |||
| outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors)); | |||
| outputs->emplace_back(CreateTensorForOutput(item, *kernel_graph, input_tensors)); | |||
| } | |||
| } | |||
| @@ -294,6 +294,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(mem_manager_); | |||
| auto graph_inputs = graph->inputs(); | |||
| auto graph_valid_input = graph->valid_inputs(); | |||
| graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end()); | |||
| std::vector<AnfNodePtr> need_alloc_nodes; | |||
| for (size_t i = 0; i < graph_inputs.size(); ++i) { | |||
| auto item = graph_inputs[i]; | |||
| @@ -240,6 +240,7 @@ constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag"; | |||
| constexpr auto kAttrOffset = "offset"; | |||
| constexpr auto kAttrPsKey = "ps_key"; | |||
| constexpr auto kAttrOptimizerType = "optim_type"; | |||
| constexpr auto kAttrChildGraph = "child_graph"; | |||
| // attr value | |||
| constexpr auto kValueTargetSwitch = "target_switch"; | |||