| @@ -65,6 +65,7 @@ const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign"); | |||
| const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd"); | |||
| const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub"); | |||
| const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select"); | |||
| const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call"); | |||
| const PrimitivePtr kPrimDistribute = std::make_shared<Primitive>("distribute"); | |||
| const PrimitivePtr kPrimDot = std::make_shared<Primitive>("dot"); | |||
| @@ -71,6 +71,7 @@ extern const PrimitivePtr kPrimAssign; | |||
| extern const PrimitivePtr kPrimAssignAdd; | |||
| extern const PrimitivePtr kPrimAssignSub; | |||
| extern const PrimitivePtr kPrimSelect; | |||
| extern const PrimitivePtr kPrimCall; | |||
| extern const PrimitivePtr kPrimDistribute; | |||
| extern const PrimitivePtr kPrimDot; | |||
| @@ -271,7 +271,9 @@ size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) { | |||
| size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| TypePtr type = node->Type(); | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| if (type == nullptr) { | |||
| return 0; | |||
| } | |||
| if (type->isa<Tuple>()) { | |||
| auto tuple_type = type->cast<TuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_type); | |||
| @@ -913,11 +915,66 @@ bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) { | |||
| FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto value_node = node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| if (value_node == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto value = value_node->value(); | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| if (value == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto func_graph = value->cast<FuncGraphPtr>(); | |||
| return func_graph; | |||
| } | |||
| std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) { | |||
| if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) { | |||
| MS_LOG(EXCEPTION) << "anf node: " << call_node->DebugString() << "is not a call node."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(call_node); | |||
| auto input1 = call_node->input(1); | |||
| MS_EXCEPTION_IF_NULL(input1); | |||
| if (input1->isa<ValueNode>()) { | |||
| auto value_node = input1->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| auto kernel_graph = value_node->value(); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| return {kernel_graph->cast<KernelGraphPtr>()}; | |||
| } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) { | |||
| auto switch_node = input1->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(switch_node); | |||
| MS_LOG(INFO) << "switch : " << switch_node->DebugString(); | |||
| auto get_switch_kernel_graph = [&](size_t input_index) -> KernelGraphPtr { | |||
| auto partial = switch_node->input(input_index); | |||
| MS_EXCEPTION_IF_NULL(partial); | |||
| auto partial_cnode = partial->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(partial_cnode); | |||
| auto graph_node = partial_cnode->input(1); | |||
| MS_EXCEPTION_IF_NULL(graph_node); | |||
| MS_LOG(INFO) << graph_node->DebugString(); | |||
| auto graph_value_node = graph_node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(graph_value_node); | |||
| auto graph_value = graph_value_node->value(); | |||
| MS_EXCEPTION_IF_NULL(graph_value); | |||
| auto child_graph = graph_value->cast<KernelGraphPtr>(); | |||
| return child_graph; | |||
| }; | |||
| return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)}; | |||
| } | |||
| return {}; | |||
| } | |||
| bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) { | |||
| MS_EXCEPTION_IF_NULL(call_node); | |||
| if (!CheckPrimitiveType(call_node, prim::kPrimCall)) { | |||
| MS_LOG(EXCEPTION) << "call node should be a 'call', but is a " << call_node->DebugString(); | |||
| } | |||
| auto input1 = call_node->input(1); | |||
| if (input1->isa<ValueNode>()) { | |||
| return false; | |||
| } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) { | |||
| return true; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString(); | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -32,6 +32,7 @@ | |||
| #include "kernel/kernel_build_info.h" | |||
| #include "operator/ops.h" | |||
| #include "utils/contract.h" | |||
| #include "session/kernel_graph.h" | |||
| namespace mindspore { | |||
| namespace session { | |||
| @@ -182,6 +183,8 @@ class AnfRuntimeAlgorithm { | |||
| static bool IsCommunicationOp(const AnfNodePtr &node); | |||
| static bool IsGetNext(const NotNull<AnfNodePtr> &node); | |||
| static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); | |||
| static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node); | |||
| static bool IsSwitchCall(const CNodePtr &call_node); | |||
| }; | |||
| } // namespace session | |||
| using AnfAlgo = session::AnfRuntimeAlgorithm; | |||
| @@ -156,6 +156,89 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) { | |||
| } | |||
| } | |||
| } | |||
| std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) { | |||
| std::vector<CNodePtr> cnodes = {}; | |||
| size_t i = 0; | |||
| for (const auto anf : anf_nodes) { | |||
| MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString(); | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| if (anf->isa<CNode>()) { | |||
| cnodes.push_back(anf->cast<CNodePtr>()); | |||
| } | |||
| } | |||
| return std::move(cnodes); | |||
| } | |||
| std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, const std::vector<CNodePtr> &cnodes) { | |||
| size_t after_call_index = 0; | |||
| std::vector<std::vector<CNodePtr>> ret; | |||
| for (size_t i = 0; i < cnodes.size(); i++) { | |||
| if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall) && !AnfAlgo::IsSwitchCall(cnodes[i])) { | |||
| auto call_kernel_graph = AnfAlgo::GetCallNodeKernelGraph(cnodes[i]); | |||
| // if graph is the true branch of while,no need split graph | |||
| if (call_kernel_graph.size() == 1 && call_kernel_graph[0] == cur_graph.parent_graph()) { | |||
| continue; | |||
| } | |||
| auto prev_call_list = std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.begin() + i); | |||
| auto call_list = std::vector<CNodePtr>(1, cnodes[i]); | |||
| after_call_index = i + 1; | |||
| ret.push_back(prev_call_list); | |||
| ret.push_back(call_list); | |||
| } else if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) { | |||
| ret.push_back(std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.end())); | |||
| } | |||
| } | |||
| return std::move(ret); | |||
| } | |||
| void UpdateRealInput(KernelGraph *graph) { | |||
| auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); | |||
| auto bind_call_partial_with_parameter = [&](const std::vector<AnfNodePtr> ¶meters, | |||
| const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void { | |||
| MS_EXCEPTION_IF_NULL(child_graph); | |||
| MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id(); | |||
| if (args.empty()) { | |||
| return; | |||
| } | |||
| if (parameters.size() != args.size()) { | |||
| MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size() | |||
| << " and args size:" << args.size() << " not equal!"; | |||
| } | |||
| for (size_t i = 0; i < parameters.size(); i++) { | |||
| MS_LOG(INFO) << "bind paramreter:" << parameters[i]->DebugString() << " ,arg:" << args[i]->DebugString(); | |||
| child_graph->SetRealInput(parameters[i], args[i]); | |||
| } | |||
| }; | |||
| for (auto &call_node : call_nodes) { | |||
| MS_EXCEPTION_IF_NULL(call_node); | |||
| auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node); | |||
| if (child_graphs.size() == 1) { | |||
| MS_EXCEPTION_IF_NULL(child_graphs[0]); | |||
| bind_call_partial_with_parameter( | |||
| child_graphs[0]->inputs(), std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end()), | |||
| child_graphs[0].get()); | |||
| call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2)); | |||
| } else if (child_graphs.size() == 2) { | |||
| auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> { | |||
| auto switch_node = call_node->input(1); | |||
| MS_EXCEPTION_IF_NULL(switch_node); | |||
| auto switch_cnode = switch_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(switch_cnode); | |||
| auto partial = switch_cnode->input(input_index); | |||
| MS_EXCEPTION_IF_NULL(partial); | |||
| auto partial_cnode = partial->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(partial_cnode); | |||
| auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); | |||
| partial_cnode->set_inputs( | |||
| std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); | |||
| return std::move(ret); | |||
| }; | |||
| bind_call_partial_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); | |||
| bind_call_partial_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get()); | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| @@ -171,7 +254,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| MS_LOG(INFO) << "start"; | |||
| auto graph = ConstructKernelGraph(func_graph); | |||
| // split switch | |||
| SplitSwitch(graph.get()); | |||
| SplitGraph(graph); | |||
| // insert goto labels and label_sets | |||
| LinkChildGraphs(graph.get()); | |||
| // resource initialize | |||
| @@ -1297,5 +1380,107 @@ void AscendSession::SyncInitialTenosrToDevice() { | |||
| } | |||
| } | |||
| } | |||
| KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_graph, | |||
| const std::vector<CNodePtr> &list) { | |||
| MS_EXCEPTION_IF_NULL(new_kernel_graph); | |||
| MS_LOG(INFO) << "start split kernel graph:" << new_kernel_graph->graph_id(); | |||
| // count the output of every anf node | |||
| std::set<AnfNodePtr> has_output_nodes; | |||
| for (auto &anf_node : list) { | |||
| for (auto &input : anf_node->inputs()) { | |||
| (void)has_output_nodes.insert(input); | |||
| } | |||
| if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) { | |||
| new_kernel_graph->set_return(anf_node->cast<CNodePtr>()); | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id(); | |||
| // create new parameter from cnode | |||
| for (auto &anf_node : list) { | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { | |||
| auto input = cnode->inputs()[input_idx]; | |||
| if (!input->isa<CNode>()) { | |||
| cnode->set_input(input_idx, input); | |||
| continue; | |||
| } | |||
| if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) { | |||
| auto new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get()); | |||
| cnode->set_input(input_idx, new_parameter); | |||
| new_kernel_graph->SetRealInput(new_parameter, input); | |||
| } | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id(); | |||
| auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())); | |||
| std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve}; | |||
| int output_idx = 0; | |||
| for (auto &anf_node : list) { | |||
| if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) { | |||
| new_kernel_graph->set_return(anf_node); | |||
| } | |||
| if (has_output_nodes.find(anf_node) == has_output_nodes.end()) { | |||
| MS_LOG(INFO) << "output[" << output_idx++ << "]:" << anf_node->DebugString(); | |||
| make_tuple_inputs.push_back(anf_node); | |||
| } | |||
| } | |||
| if (new_kernel_graph->get_return() == nullptr) { | |||
| new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs)); | |||
| } | |||
| MS_LOG(INFO) << "end"; | |||
| return new_kernel_graph; | |||
| } | |||
| void AscendSession::SplitGraph(const KernelGraphPtr &graph) { | |||
| MS_LOG(INFO) << "start,graph_id:" << graph->graph_id(); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto apply_list = GetCNodes(TopoSort(graph->get_return())); | |||
| // update the root graph child graph order | |||
| graph->UpdateChildGraphOrder(); | |||
| // get child list from current graph | |||
| std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(*graph, apply_list); | |||
| auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr { | |||
| if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) { | |||
| return child_graph_list[0]; | |||
| } | |||
| // create new child graph | |||
| auto child_graph = NewKernelGraph(); | |||
| MS_EXCEPTION_IF_NULL(child_graph); | |||
| // create new value node to bind child graph | |||
| auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph)); | |||
| std::vector<AnfNodePtr> new_call_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())), | |||
| graph_value_node}; | |||
| // set the graph id of all node of child graph | |||
| for (auto &child_graph_node : child_graph_list) { | |||
| AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get()); | |||
| } | |||
| SplitKernelGraph(child_graph, child_graph_list); | |||
| auto new_call = graph->NewCNode(new_call_input); | |||
| AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call); | |||
| return new_call; | |||
| }; | |||
| if (child_graph_lists.size() > 1) { | |||
| for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) { | |||
| auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]); | |||
| if (call_index == 0) { | |||
| auto new_return_primitive = | |||
| graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()))); | |||
| graph->set_return(graph->NewCNode({new_return_primitive, call_node})); | |||
| continue; | |||
| } | |||
| InsertDependToGraph(graph->graph_id(), call_node); | |||
| } | |||
| } | |||
| graph->UpdateChildGraphOrder(); | |||
| UpdateRealInput(graph.get()); | |||
| auto graph_name = std::string("./kernel-graph-").append(std::to_string(graph->graph_id())); | |||
| DumpIR(graph_name, graph); | |||
| MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end"; | |||
| // recurse to split child graph | |||
| for (auto &child_graph : graph->child_graph_order()) { | |||
| SplitGraph(child_graph); | |||
| } | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -95,13 +95,16 @@ class AscendSession : public SessionBasic { | |||
| void SetFinalGraphOutput(const ValuePtr &value); | |||
| void SetFinalGraphOutput(const VectorRef &vec_output); | |||
| void SplitSwitch(KernelGraph *graph) {} | |||
| void SplitGraph(const KernelGraphPtr &graph); | |||
| void LinkChildGraphs(KernelGraph *graph) {} | |||
| void IRFusion(const KernelGraphPtr &graph) {} | |||
| void SelectKernelGraphKernel(const KernelGraph &graph) {} | |||
| void ConvertPredictModel(const KernelGraphPtr graph) {} | |||
| void HardwareOptimizeGraphs(const KernelGraphPtr graph) {} | |||
| void RootGraphExecutorValidate(KernelGraph *graph) {} | |||
| void RecurseUpdateAllChildGraohOrder(KernelGraph *root_graph); | |||
| KernelGraphPtr SplitKernelGraph(const KernelGraphPtr &new_kernel_graph, const std::vector<CNodePtr> &list); | |||
| void ChildGraphCommunicationDecrease(std::vector<std::vector<AnfNodePtr>> *anf_node_lists); | |||
| // merge execution order list of child graphs | |||
| void MergeGraphExecOrder(); | |||
| @@ -16,9 +16,8 @@ | |||
| #include "session/kernel_graph.h" | |||
| #include <algorithm> | |||
| #include <queue> | |||
| #include <stack> | |||
| #include <unordered_set> | |||
| #include "common/utils.h" | |||
| #include <set> | |||
| #include "operator/ops.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| @@ -311,9 +310,10 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) { | |||
| // create kernel_build_info for new value node | |||
| auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| // set the format of value_node to DEFAULT_FORMAT | |||
| kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT}); | |||
| auto output_tensor_num = AnfAlgo::GetOutputTensorNum(value_node); | |||
| kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT)); | |||
| // set value node initial device data type = infer data type | |||
| std::vector<TypeId> types = std::vector<TypeId>(AnfAlgo::GetOutputTensorNum(value_node), kTypeUnknown); | |||
| std::vector<TypeId> types = std::vector<TypeId>(output_tensor_num, kTypeUnknown); | |||
| kernel_build_info_builder->SetOutputsDeviceType(types); | |||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); | |||
| AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); | |||
| @@ -584,7 +584,25 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() { | |||
| } | |||
| } | |||
| void KernelGraph::UpdateChildGraphOrder() {} | |||
| void KernelGraph::UpdateChildGraphOrder() { | |||
| MS_LOG(INFO) << "graph id:" << graph_id_; | |||
| auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name())); | |||
| child_graph_order_.clear(); | |||
| for (auto &call_node : call_nodes) { | |||
| MS_EXCEPTION_IF_NULL(call_node); | |||
| auto call_child_graphs = AnfAlgo ::GetCallNodeKernelGraph(call_node->cast<CNodePtr>()); | |||
| for (const auto &child_graph : call_child_graphs) { | |||
| MS_EXCEPTION_IF_NULL(child_graph); | |||
| if (child_graph != parent_graph()) { | |||
| child_graph->set_parent_graph(shared_from_this()->cast<std::shared_ptr<KernelGraph>>()); | |||
| child_graph_order_.push_back(child_graph); | |||
| } | |||
| } | |||
| } | |||
| for (size_t i = 0; i < child_graph_order_.size(); i++) { | |||
| MS_LOG(INFO) << "child graph[" << i << "][id:" << child_graph_order_[i]->graph_id() << "]"; | |||
| } | |||
| } | |||
| std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() { | |||
| std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order; | |||
| @@ -601,5 +619,36 @@ std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() { | |||
| } | |||
| bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); } | |||
| std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const { | |||
| auto anf_list = TopoSort(get_return()); | |||
| std::vector<CNodePtr> result; | |||
| for (const auto &anf : anf_list) { | |||
| if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) { | |||
| result.push_back(anf->cast<CNodePtr>()); | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr ¶meter) { | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| if (real_inputs_.find(parameter) == real_inputs_.end()) { | |||
| return {}; | |||
| } | |||
| return real_inputs_[parameter]; | |||
| } | |||
| void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg) { | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| if (real_inputs_.find(parameter) == real_inputs_.end()) { | |||
| real_inputs_[parameter] = std::set<AnfNodePtr>(); | |||
| } | |||
| auto &args = real_inputs_[parameter]; | |||
| (void)args.insert(arg); | |||
| } | |||
| std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -23,6 +23,7 @@ | |||
| #include <string> | |||
| #include <queue> | |||
| #include <map> | |||
| #include <set> | |||
| #include <unordered_set> | |||
| #include "ir/func_graph.h" | |||
| #include "ir/anf.h" | |||
| @@ -113,6 +114,17 @@ class KernelGraph : public FuncGraph { | |||
| } | |||
| // get input_tensors pointer of control parameter | |||
| std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors() const { return input_ctrl_tensors_; } | |||
| // get parent kernel graph | |||
| std::shared_ptr<KernelGraph> parent_graph() const { return parent_graph_; } | |||
| // set parent kernel graph | |||
| void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; } | |||
| // find anf node in graph | |||
| std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const; | |||
| // get real inputs | |||
| std::set<AnfNodePtr> GetRealInput(const AnfNodePtr ¶meter); | |||
| void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg); | |||
| // used to dump ir | |||
| std::string ToString() const override; | |||
| private: | |||
| // remove value node form graph | |||
| @@ -158,6 +170,10 @@ class KernelGraph : public FuncGraph { | |||
| std::vector<std::shared_ptr<KernelGraph>> child_graph_order_; | |||
| // input_tensors of control parameter | |||
| std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors_; | |||
| // parameter graph | |||
| std::shared_ptr<KernelGraph> parent_graph_; | |||
| // record real parameters,inputs_ is the formal parameters | |||
| std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_; | |||
| }; | |||
| } // namespace session | |||
| using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | |||
| @@ -247,27 +247,6 @@ std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool va | |||
| return parameters; | |||
| } | |||
| AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| if (!anf->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a cnode"; | |||
| } | |||
| MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; | |||
| auto parameters = CreateParameterFromTuple(anf, valid_input, graph); | |||
| if (parameters.empty()) { | |||
| MS_LOG(EXCEPTION) << "No parameter exist!!"; | |||
| } | |||
| if (parameters.size() == 1) { | |||
| return parameters[0]; | |||
| } | |||
| std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)}; | |||
| (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input)); | |||
| auto make_tuple = graph->NewCNode(make_tuple_input); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters"; | |||
| return make_tuple; | |||
| } | |||
| size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) { | |||
| MS_LOG(INFO) << "Load kInputCtrlTensors"; | |||
| auto inputs_params = graph->input_ctrl_tensors(); | |||
| @@ -390,6 +369,24 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf | |||
| return new_parameter; | |||
| } | |||
| AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; | |||
| auto parameters = CreateParameterFromTuple(anf, valid_input, graph); | |||
| if (parameters.empty()) { | |||
| MS_LOG(EXCEPTION) << "No parameter exist!!"; | |||
| } | |||
| if (parameters.size() == 1) { | |||
| return parameters[0]; | |||
| } | |||
| std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)}; | |||
| (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input)); | |||
| auto make_tuple = graph->NewCNode(make_tuple_input); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters"; | |||
| return make_tuple; | |||
| } | |||
| CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, | |||
| bool *from_other_graph, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) { | |||
| @@ -454,7 +451,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) | |||
| MS_EXCEPTION_IF_NULL(attr_input); | |||
| if (IsValueNode<FuncGraph>(attr_input)) { | |||
| // create primitive of cnode:call | |||
| cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))}; | |||
| cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))}; | |||
| // create a ValueNode<KernelGraph> as input of cnode:call | |||
| if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) { | |||
| cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input)); | |||
| @@ -466,12 +463,10 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) | |||
| } | |||
| } else if (attr_input->isa<CNode>()) { | |||
| // create primitive of cnode:call(switch) | |||
| cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))}; | |||
| cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))}; | |||
| if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) { | |||
| auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); | |||
| auto prim = GetCNodePrimitive(cnode_input); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| if (prim->name() != kSwitchOpName) { | |||
| if (!AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { | |||
| MS_LOG(EXCEPTION) << "CNode input[0] must be switch."; | |||
| } | |||
| cnode_inputs.emplace_back(cnode_input); | |||
| @@ -484,7 +479,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) | |||
| auto prim = AnfAlgo::GetCNodePrimitive(cnode); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| // push attr to inputs[0] of new cnode | |||
| cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))}; | |||
| cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))}; | |||
| } | |||
| for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { | |||
| @@ -545,7 +540,6 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker | |||
| AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get()); | |||
| graph->FrontBackendlMapAdd(anf, new_value_node); | |||
| graph->AddValueNodeToGraph(new_value_node); | |||
| return new_value_node; | |||
| } | |||
| @@ -555,11 +549,11 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph | |||
| if (!anf->isa<Parameter>()) { | |||
| MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; | |||
| } | |||
| auto graph_inputs = graph->MutableInputs(); | |||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||
| TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info())); | |||
| auto new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); | |||
| TraceManager::EndTrace(); | |||
| graph_inputs->push_back(new_parameter); | |||
| graph->FrontBackendlMapAdd(anf, new_parameter); | |||
| @@ -114,6 +114,7 @@ class SessionBasic { | |||
| ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); | |||
| ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph); | |||
| ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph); | |||
| AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); | |||
| std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_; | |||
| std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_; | |||