From: @liangzelang  Reviewed-by:  Signed-off-by:  tags/v1.1.0
| @@ -45,6 +45,16 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) { | |||
| bool change = (new_node != nullptr); | |||
| if (new_node != nullptr && new_node != node) { | |||
| (void)manager->Replace(node, new_node); | |||
| // if replaced node is end_goto, refresh relative params in kernel graph | |||
| auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); | |||
| if (kernel_graph != nullptr && node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto end_label = kernel_graph->get_end_goto(); | |||
| if (cnode == end_label && AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) { | |||
| kernel_graph->set_end_goto(new_node->cast<CNodePtr>()); | |||
| } | |||
| } | |||
| (void)seen_node.erase(node); | |||
| } else if (new_node == nullptr) { | |||
| new_node = node; | |||
| @@ -739,8 +739,14 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> fr | |||
| << from_graph->ToString(); | |||
| } | |||
| // insert assign between jump_node -1 and jump_node | |||
| if (jump_node_iter != from_graph_exe_order.begin()) { | |||
| InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); | |||
| while (jump_node_iter != from_graph_exe_order.begin()) { | |||
| CNodePtr node = *(jump_node_iter - 1); | |||
| if (AnfAlgo::GetGraphId(node.get()) == from_graph->graph_id()) { | |||
| InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); | |||
| break; | |||
| } else { | |||
| jump_node_iter--; | |||
| } | |||
| } | |||
| InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); | |||
| } | |||
| @@ -23,6 +23,7 @@ | |||
| #include <utility> | |||
| #include <functional> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "backend/session/kernel_graph.h" | |||
| #include "base/base_ref.h" | |||
| #include "utils/contract.h" | |||
| @@ -64,7 +64,7 @@ | |||
| #include "ps/util.h" | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #endif | |||
| static constexpr uint32_t kLabelSwitchLabelId = 2; | |||
| namespace mindspore { | |||
| namespace session { | |||
| const size_t kInvalidIndex = SIZE_MAX; | |||
| @@ -485,6 +485,8 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { | |||
| memo.clear(); | |||
| // insert goto labels and label_sets | |||
| LinkChildGraphs(NOT_NULL(root_graph)); | |||
| // replace labelgoto with labelswitch in subgraph called multiple times | |||
| MultiCallGraphOptimize(NOT_NULL(root_graph)); | |||
| // resource initialize | |||
| InitRuntimeResource(); | |||
| @@ -667,6 +669,10 @@ void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tens | |||
| MS_LOG(INFO) << "No child graph has anf output"; | |||
| return; | |||
| } | |||
| // load data to extra params | |||
| std::set<KernelGraphPtr> memo; | |||
| SyncDataToExtraParams(NOT_NULL(kernel_graph), NOT_NULL(&memo)); | |||
| memo.clear(); | |||
| // load input data from user input | |||
| LoadInputData(kernel_graph, inputs); | |||
| if (debugger_) { | |||
| @@ -1190,6 +1196,110 @@ void AscendSession::BackendOptimization(const std::vector<KernelGraphPtr> &all_g | |||
| void AscendSession::LinkChildGraphs(NotNull<KernelGraphPtr> graph) { AscendControlParser::LinkGraph(graph); } | |||
| bool AscendSession::IsMultiCallGraph(NotNull<KernelGraphPtr> graph, std::vector<GraphId> parent_graphs) { | |||
| std::stack<GraphId> post_graph; | |||
| std::set<GraphId> memo; | |||
| post_graph.push(graph->graph_id()); | |||
| while (!post_graph.empty()) { | |||
| auto graph_id = post_graph.top(); | |||
| post_graph.pop(); | |||
| memo.insert(graph_id); | |||
| for (auto child_graph : graphs_[graph_id]->child_graph_order()) { | |||
| std::shared_ptr<KernelGraph> child_graph_ptr = child_graph.lock(); | |||
| MS_EXCEPTION_IF_NULL(child_graph_ptr); | |||
| if (std::find(parent_graphs.begin(), parent_graphs.end(), child_graph_ptr->graph_id()) != parent_graphs.end()) { | |||
| MS_LOG(DEBUG) << "graph:" << graph->graph_id() << " will call its parent graph:" << child_graph_ptr->graph_id(); | |||
| return false; | |||
| } else if (memo.find(child_graph_ptr->graph_id()) == memo.end()) { | |||
| MS_LOG(DEBUG) << "child graph:" << child_graph_ptr->graph_id() << " into deque, wait for check."; | |||
| post_graph.push(child_graph_ptr->graph_id()); | |||
| } | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| void AscendSession::MultiCallGraphOptimize(NotNull<KernelGraphPtr> root_graph) { | |||
| for (auto current : parent_graphs_) { | |||
| if (current.second.size() < 2) { | |||
| continue; | |||
| } | |||
| auto graph = graphs_[current.first]; | |||
| auto parent_kernel_graphs = current.second; | |||
| if (!IsMultiCallGraph(NOT_NULL(graph), parent_kernel_graphs)) { | |||
| MS_LOG(DEBUG) << "graph:" << graph->graph_id() << " with it's parent graphs make up a cycle"; | |||
| continue; | |||
| } | |||
| MS_LOG(INFO) << "graph: " << graph->graph_id() << " has been called by more than two graphs"; | |||
| int32_t index = 0; | |||
| std::vector<KernelGraphPtr> child_graphs; | |||
| auto start_label = graph->get_start_label(); | |||
| auto end_node = graph->get_end_goto(); | |||
| ParameterPtr post_label_param = graph->AddExtraParamAndTensor("label_param", 0); | |||
| std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), | |||
| post_label_param}; | |||
| for (auto graph_id : parent_kernel_graphs) { | |||
| auto kg = graphs_[graph_id]; | |||
| auto nodes = kg->execution_order(); | |||
| for (uint32_t i = 0; i < nodes.size(); i++) { | |||
| if (AnfAlgo::GetCNodeName(nodes[i]) == kLabelGotoOpName && | |||
| (AnfAlgo::GetNodeAttr<uint32_t>(nodes[i], kAttrLabelIndex) == | |||
| AnfAlgo::GetNodeAttr<uint32_t>(start_label, kAttrLabelIndex))) { | |||
| if (i < (nodes.size() - 1)) { | |||
| new_inputs.push_back(nodes[i + 1]); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "No labelset after labelgoto"; | |||
| } | |||
| ParameterPtr pre_label_param = kg->AddExtraParamAndTensor("label_param", index++); | |||
| AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(kg), nodes[i], NOT_NULL(pre_label_param), | |||
| NOT_NULL(post_label_param)); | |||
| } | |||
| } | |||
| kg->SetExecOrderByDefault(); | |||
| child_graphs.push_back(kg); | |||
| } | |||
| end_node->set_inputs(new_inputs); | |||
| AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>(child_graphs), end_node); | |||
| std::vector<uint32_t> label_list; | |||
| for (size_t i = kLabelSwitchLabelId; i < end_node->size(); ++i) { | |||
| auto input = end_node->input(i); | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| if (!input->isa<CNode>() || AnfAlgo::GetCNodeName(input) != kLabelSetOpName) { | |||
| break; | |||
| } | |||
| uint32_t goto_label_id = AnfAlgo::GetNodeAttr<uint32_t>(input, kAttrLabelIndex); | |||
| label_list.push_back(goto_label_id); | |||
| MS_LOG(INFO) << "Switch " << end_node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " | |||
| << goto_label_id; | |||
| } | |||
| AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), end_node); | |||
| end_node->set_inputs({end_node->input(kAnfPrimitiveIndex), end_node->input(kFirstDataInputIndex)}); | |||
| graph->SetExecOrderByDefault(); | |||
| } | |||
| } | |||
| void AscendSession::SyncDataToExtraParams(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| if (memo->find(graph.get()) != memo->end()) { | |||
| return; | |||
| } | |||
| memo->insert(graph.get()); | |||
| auto extra_param_tensor = graph->GetExtraParamAndTensor(); | |||
| for (uint32_t i = 0; i < extra_param_tensor.size(); i++) { | |||
| auto param = extra_param_tensor[i].first; | |||
| auto tensor = extra_param_tensor[i].second; | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(param, 0); | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| tensor->set_device_address(device_address); | |||
| if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(param, 0), LongToSize(tensor->data().nbytes()), | |||
| tensor->data_type(), tensor->data_c())) { | |||
| MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; | |||
| } | |||
| } | |||
| for (auto &child_graph : graph->child_graph_order()) { | |||
| SyncDataToExtraParams(NOT_NULL(child_graph.lock()), memo); | |||
| } | |||
| } | |||
| void AscendSession::RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph) { | |||
| AscendControlParser::ExecutorValidate(graph); | |||
| } | |||
| @@ -93,6 +93,10 @@ class AscendSession : public SessionBasic { | |||
| static void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs); | |||
| static void LinkChildGraphs(NotNull<KernelGraphPtr> graph); | |||
| // replace labelgoto with labelswitch in subgraph called multiple times | |||
| void MultiCallGraphOptimize(NotNull<KernelGraphPtr> root_graph); | |||
| bool IsMultiCallGraph(NotNull<KernelGraphPtr> graph, std::vector<GraphId> parent_graphs); | |||
| void SyncDataToExtraParams(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo); | |||
| void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph); | |||
| // merge execution order list of child graphs | |||
| void MergeGraphExecOrder(); | |||
| @@ -1213,6 +1213,33 @@ void KernelGraph::RemoveNodeFromGraph(const AnfNodePtr &node) { | |||
| } | |||
| } | |||
| ParameterPtr KernelGraph::AddExtraParamAndTensor(std::string param_name, int32_t value) { | |||
| ParameterPtr param; | |||
| ShapeVector shp = {1}; | |||
| tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | |||
| MS_EXCEPTION_IF_NULL(tensor_ptr); | |||
| mindspore::abstract::AbstractBasePtr paremeter_abstract_ptr = tensor_ptr->ToAbstract(); | |||
| ParameterPtr new_param = std::make_shared<Parameter>(shared_from_this()->cast<KernelGraphPtr>()); | |||
| MS_EXCEPTION_IF_NULL(new_param); | |||
| new_param->set_name(param_name); | |||
| new_param->set_abstract(paremeter_abstract_ptr); | |||
| param = NewParameter(new_param); | |||
| // ensure alloc mem for this param | |||
| std::vector<AnfNodePtr> *mute_inputs = MutableInputs(); | |||
| MS_EXCEPTION_IF_NULL(mute_inputs); | |||
| mute_inputs->push_back(param); | |||
| tensor::TensorPtr data_tensor_ptr = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp); | |||
| MS_EXCEPTION_IF_NULL(data_tensor_ptr); | |||
| int32_t *val = nullptr; | |||
| val = static_cast<int32_t *>(data_tensor_ptr->data_c()); | |||
| *val = value; | |||
| extra_param_tensor_.push_back(std::make_pair(param, data_tensor_ptr)); | |||
| MS_LOG(INFO) << "Create new param: " << param->DebugString(); | |||
| return param; | |||
| } | |||
| void KernelGraph::UpdateGraphDynamicAttr() { | |||
| for (const auto &cnode : execution_order_) { | |||
| if (AnfAlgo::IsDynamicShape(cnode)) { | |||
| @@ -44,6 +44,7 @@ class KernelGraph : public FuncGraph { | |||
| executable_ = true; | |||
| summary_node_exist_ = false; | |||
| stream_distinction_label_ = kInvalidDistincLabel; | |||
| extra_param_tensor_ = {}; | |||
| } | |||
| KernelGraph(const KernelGraph &graph) : FuncGraph(graph) { | |||
| @@ -87,6 +88,7 @@ class KernelGraph : public FuncGraph { | |||
| first_step_ = graph.first_step_; | |||
| has_optimizer_ = graph.has_optimizer_; | |||
| is_dynamic_shape_ = graph.is_dynamic_shape_; | |||
| extra_param_tensor_ = graph.extra_param_tensor_; | |||
| } | |||
| ~KernelGraph() override; | |||
| @@ -220,7 +222,9 @@ class KernelGraph : public FuncGraph { | |||
| } | |||
| } | |||
| void RemoveNodeFromGraph(const AnfNodePtr &node); | |||
| // Add Param which pass callback point | |||
| ParameterPtr AddExtraParamAndTensor(std::string param_name, int32_t value); | |||
| const std::vector<std::pair<ParameterPtr, tensor::TensorPtr>> GetExtraParamAndTensor() { return extra_param_tensor_; } | |||
| void UpdateGraphDynamicAttr(); | |||
| bool is_dynamic_shape() const { return is_dynamic_shape_; } | |||
| void SetOptimizerFlag(); | |||
| @@ -302,6 +306,8 @@ class KernelGraph : public FuncGraph { | |||
| std::shared_ptr<std::vector<AnfNodePtr>> inputs_; | |||
| std::vector<AnfNodePtr> child_graph_result_; | |||
| std::vector<CNodePtr> execution_order_; | |||
| // extra params and tensors for control flow | |||
| std::vector<std::pair<ParameterPtr, tensor::TensorPtr>> extra_param_tensor_; | |||
| uint32_t graph_id_; | |||
| uint32_t stream_distinction_label_; | |||
| @@ -1012,6 +1012,12 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP | |||
| (void)ConstructKernelGraph(child_graph, all_out_graph); | |||
| } | |||
| (void)CreateValueNodeKernelGraph(node, graph.get()); | |||
| auto &parent_graph = parent_graphs_[front_backend_graph_map_[child_graph]->graph_id()]; | |||
| auto parent_graph_it = | |||
| std::find(parent_graph.begin(), parent_graph.end(), front_backend_graph_map_[func_graph]->graph_id()); | |||
| if (parent_graph_it == parent_graph.end()) { | |||
| parent_graph.push_back(front_backend_graph_map_[func_graph]->graph_id()); | |||
| } | |||
| continue; | |||
| } | |||
| // Create cnode | |||
| @@ -1096,10 +1102,10 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap | |||
| input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); | |||
| } | |||
| auto &input_nodes = kernel_graph->input_nodes(); | |||
| if ((inputs.size() + input_ctrl_size) - 3 != input_nodes.size()) { | |||
| auto extra_param_size = kernel_graph->GetExtraParamAndTensor().size(); | |||
| if ((inputs.size() + input_ctrl_size) - 3 != input_nodes.size() - extra_param_size) { | |||
| MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() | |||
| << ", input_ctrl_size:" << input_ctrl_size; | |||
| << ", input_ctrl_size:" << input_ctrl_size << ", extra_param_size:" << extra_param_size; | |||
| } | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| @@ -202,6 +202,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_; | |||
| std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_; | |||
| std::unordered_map<FuncGraphPtr, KernelGraphPtr> front_backend_graph_map_; | |||
| std::unordered_map<GraphId, std::vector<GraphId>> parent_graphs_; | |||
| std::shared_ptr<Context> context_; | |||
| CallBackFunc summary_callback_; | |||
| static GraphId graph_sum_; | |||