| @@ -29,6 +29,7 @@ | |||
| #include "hccl/hcom.h" | |||
| #include "common/trans.h" | |||
| #include "runtime/context.h" | |||
| #include "device/ascend/ascend_label_assign.h" | |||
| #include "device/ascend/ascend_stream_assign.h" | |||
| #include "device/ascend/ascend_memory_pool.h" | |||
| #include "framework/ge_runtime/model_runner.h" | |||
| @@ -281,21 +282,24 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { | |||
| return true; | |||
| } | |||
| AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); | |||
| AscendStreamAssign &stream_assign_instance = AscendStreamAssign::GetInstance(); | |||
| AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance(); | |||
| // the streams' flag not HEAD_STREAM | |||
| std::vector<uint32_t> wait_active_stream_list; | |||
| assign_instance.GetWaitStreams(&wait_active_stream_list); | |||
| auto force_copy_stream_list = assign_instance.hcom_streams(); | |||
| stream_assign_instance.GetWaitStreams(&wait_active_stream_list); | |||
| auto force_copy_stream_list = stream_assign_instance.hcom_streams(); | |||
| MS_LOG(INFO) << "call DavinciModel total stream num:" << assign_instance.GetTotalStreamNum() | |||
| << ", total event num:" << assign_instance.total_event_num() | |||
| MS_LOG(INFO) << "call DavinciModel total stream num:" << stream_assign_instance.GetTotalStreamNum() | |||
| << ", total event num:" << stream_assign_instance.total_event_num() | |||
| << ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph)) | |||
| << ", wait_active_stream_list size:" << wait_active_stream_list.size() | |||
| << ", force_copy_stream_list size:" << force_copy_stream_list.size(); | |||
| std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list; | |||
| std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>( | |||
| task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, | |||
| 0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.total_event_num(), 0); | |||
| 0, 0, 0, 0, 0, stream_assign_instance.GetTotalStreamNum(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), | |||
| stream_assign_instance.total_event_num(), 0); | |||
| auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); | |||
| if (!ret.second) { | |||
| @@ -15,6 +15,8 @@ | |||
| */ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <set> | |||
| #include "device/ascend/ascend_label_assign.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| @@ -36,6 +38,7 @@ static void UpdateLabelGoto(NotNull<CNodePtr> node) { | |||
| uint32_t goto_label_id = GetValue<uint32_t>(value); | |||
| AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(goto_label_id), node.get()); | |||
| MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id; | |||
| node->set_inputs({node->input(0)}); | |||
| } | |||
| static void UpdateLabelSwitch(NotNull<CNodePtr> node) { | |||
| @@ -58,29 +61,93 @@ static void UpdateLabelSwitch(NotNull<CNodePtr> node) { | |||
| MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id; | |||
| } | |||
| AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), node.get()); | |||
| node->set_inputs({node->input(0), node->input(1)}); | |||
| } | |||
| void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph) { | |||
| auto cnode_list = graph->execution_order(); | |||
| // 1 assign label id to label_set | |||
| uint32_t cur_label_id = 0; | |||
| for (auto &node : cnode_list) { | |||
| if (AnfAlgo::GetCNodeName(node) == kLabelSetOpName) { | |||
| AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(cur_label_id), node); | |||
| MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << cur_label_id; | |||
| ++cur_label_id; | |||
| static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph>> graph, NotNull<uint32_t *> label_id, | |||
| NotNull<std::set<std::shared_ptr<session::KernelGraph>> *> memo) { | |||
| if (memo->find(graph.get()) != memo->end()) { | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Assign label for " << graph->ToString(); | |||
| auto nodes = TopoSort(graph->get_return()); | |||
| for (auto &node : nodes) { | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::string node_name = AnfAlgo::GetCNodeName(node); | |||
| if (node_name == kLabelSetOpName && !AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) { | |||
| AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(*label_id), node); | |||
| MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << *label_id; | |||
| ++(*label_id); | |||
| } | |||
| } | |||
| // 2 update label_switch / label_goto | |||
| for (auto &node : cnode_list) { | |||
| if (AnfAlgo::GetCNodeName(node) == kLabelGotoOpName) { | |||
| UpdateLabelGoto(NOT_NULL(node)); | |||
| for (auto &cg : graph->child_graph_order()) { | |||
| AssignLabelForLabelSet(NOT_NULL(cg), label_id, memo); | |||
| } | |||
| } | |||
| static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGraph>> graph, | |||
| NotNull<std::set<std::shared_ptr<session::KernelGraph>> *> memo) { | |||
| if (memo->find(graph.get()) != memo->end()) { | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString(); | |||
| auto nodes = TopoSort(graph->get_return()); | |||
| for (auto &node : nodes) { | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::GetCNodeName(node) == kLabelSwitchOpName) { | |||
| UpdateLabelSwitch(NOT_NULL(node)); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::string node_name = AnfAlgo::GetCNodeName(node); | |||
| if (node_name == kLabelGotoOpName) { | |||
| UpdateLabelGoto(NOT_NULL(cnode)); | |||
| cnode->set_abstract(nullptr); | |||
| } | |||
| if (node_name == kLabelSwitchOpName) { | |||
| UpdateLabelSwitch(NOT_NULL(cnode)); | |||
| } | |||
| } | |||
| for (auto &cg : graph->child_graph_order()) { | |||
| AssignLabelForGotoSwitch(NOT_NULL(cg), memo); | |||
| } | |||
| } | |||
| void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) { | |||
| MS_LOG(INFO) << "Assign label start."; | |||
| std::set<std::shared_ptr<session::KernelGraph>> memo; | |||
| uint32_t label_id = 0; | |||
| AssignLabelForLabelSet(graph, NOT_NULL(&label_id), NOT_NULL(&memo)); | |||
| memo.clear(); | |||
| { | |||
| std::lock_guard<std::mutex> lock(label_num_mutex_); | |||
| label_num_[graph.get().get()] = label_id; | |||
| } | |||
| AssignLabelForGotoSwitch(graph, NOT_NULL(&memo)); | |||
| MS_LOG(INFO) << "Assign label end."; | |||
| } | |||
| uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> graph) { | |||
| std::lock_guard<std::mutex> lock(label_num_mutex_); | |||
| auto iter = label_num_.find(graph.get()); | |||
| if (iter == label_num_.end()) { | |||
| MS_LOG(WARNING) << "Graph " << graph->ToString() << " has not assigned label."; | |||
| return 1; | |||
| } | |||
| return iter->second; | |||
| } | |||
| uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) { | |||
| return GetLabelNum(NOT_NULL(graph.get().get())); | |||
| } | |||
| } // namespace ascend | |||
| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ | |||
| #include <memory> | |||
| #include <map> | |||
| #include "session/kernel_graph.h" | |||
| #include "utils/contract.h" | |||
| @@ -35,11 +36,16 @@ class AscendLabelAssign { | |||
| AscendLabelAssign(const AscendLabelAssign &) = delete; | |||
| AscendLabelAssign &operator=(const AscendLabelAssign &) = delete; | |||
| void AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph); | |||
| void AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph); | |||
| uint32_t GetLabelNum(NotNull<const session::KernelGraph *> graph); | |||
| uint32_t GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph); | |||
| private: | |||
| AscendLabelAssign() = default; | |||
| ~AscendLabelAssign() = default; | |||
| std::map<const session::KernelGraph *, uint32_t> label_num_; | |||
| std::mutex label_num_mutex_; | |||
| }; | |||
| } // namespace ascend | |||
| } // namespace device | |||
| @@ -17,6 +17,7 @@ | |||
| #include "kernel/rts/label_switch.h" | |||
| #include <asm-generic/param.h> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "runtime/stream.h" | |||
| #include "framework/ge_runtime/task_info.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| @@ -66,13 +67,33 @@ std::vector<TaskInfoPtr> LabelSwitchKernel::GenTask(const std::vector<AddressPtr | |||
| MS_LOG(INFO) << "LabelSwitchKernel GenTask label size:" << label_size_ << ", stream id:" << stream_id; | |||
| std::vector<TaskInfoPtr> task_info_list; | |||
| cond_ = inputs[0]->addr; | |||
| // std::shared_ptr<LabelSwitchTaskInfo> task_info_ptr = | |||
| // std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_, &label_list_, cond_); | |||
| // need updata ge task info define | |||
| std::shared_ptr<LabelSwitchTaskInfo> task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_); | |||
| // todo: need update ge task info define | |||
| auto task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, 0); | |||
| // auto task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_, label_list_, cond_); | |||
| MS_EXCEPTION_IF_NULL(task_info_ptr); | |||
| task_info_list.emplace_back(task_info_ptr); | |||
| return task_info_list; | |||
| } | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernelInfo() { | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> label_switch_build_info{}; | |||
| vector<string> input_format{kOpFormat_DEFAULT, kOpFormat_DEFAULT}; | |||
| vector<TypeId> input_type{kNumberTypeUInt32, kNumberTypeBool}; | |||
| if (input_format.size() != input_type.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size " | |||
| << input_type.size(); | |||
| } | |||
| for (size_t i = 0; i < input_format.size(); ++i) { | |||
| auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); | |||
| builder.SetInputsFormat({input_format[i]}); | |||
| builder.SetInputsDeviceType({input_type[i]}); | |||
| builder.SetProcessor(AICORE); | |||
| builder.SetKernelType(RT_KERNEL); | |||
| builder.SetFusionType(OPAQUE); | |||
| label_switch_build_info.emplace_back(builder.Build()); | |||
| } | |||
| return label_switch_build_info; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -42,6 +42,14 @@ class LabelSwitchKernel : public RtKernel { | |||
| void *cond_; | |||
| }; | |||
| class LabelSwitchDesc : public RtKerDesc { | |||
| public: | |||
| LabelSwitchDesc() = default; | |||
| ~LabelSwitchDesc() override = default; | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetKernelInfo() override; | |||
| }; | |||
| MS_REG_RTKERNEL_DESC(labelswitch, LabelSwitchDesc); | |||
| MS_REG_RTKERNEL(labelswitch, LabelSwitchKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -44,6 +44,12 @@ RtKerDescFactory &RtKerDescFactory::Get() { | |||
| return _this; | |||
| } | |||
| static bool IsDefaultKernelInfo(const std::string &name) { | |||
| static const std::set<std::string> white_list = {kStreamSwitchOpName, kStreamActiveOpName, kLabelSetOpName, | |||
| kLabelGotoOpName}; | |||
| return white_list.find(name) != white_list.end(); | |||
| } | |||
| void GetRtKelInfo(const CNodePtr &kernel_node, | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||
| @@ -58,7 +64,7 @@ void GetRtKelInfo(const CNodePtr &kernel_node, | |||
| } | |||
| // if can't find kernel info in kernel info database, use the default kernel info | |||
| auto node_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (node_name == "StreamSwitch" || node_name == "StreamActive") { | |||
| if (IsDefaultKernelInfo(node_name)) { | |||
| auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| // set input infos | |||
| auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| @@ -331,12 +331,14 @@ bool ExecuteAction(const ResourcePtr &res) { | |||
| } | |||
| auto graph_id = res->results()[kOutput].cast<GraphId>(); | |||
| auto bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::MsBackend>>(); | |||
| std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>(); | |||
| std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr); | |||
| MS_EXCEPTION_IF_NULL(msbc_ptr); | |||
| compile::VmEvalFuncPtr run = | |||
| std::make_shared<compile::VmEvalFunc>([&bc_ptr, graph_id](const VectorRef &args) -> BaseRef { | |||
| MS_LOG(INFO) << "Execute args size" << args.size(); | |||
| auto outs = bc_ptr->RunGraph(graph_id, args); | |||
| MS_LOG(DEBUG) << "out size" << outs.size(); | |||
| std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef { | |||
| MS_LOG(INFO) << "Execute args size " << args.size(); | |||
| auto outs = msbc_ptr->RunGraph(graph_id, args); | |||
| MS_LOG(DEBUG) << "out size " << outs.size(); | |||
| return outs[0]; | |||
| }); | |||
| res->results()[kOutput] = run; | |||
| @@ -6,22 +6,23 @@ file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ) | |||
| if (ENABLE_GPU) | |||
| file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "gpu_session.cc" | |||
| ) | |||
| list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST}) | |||
| endif () | |||
| if (ENABLE_CPU) | |||
| file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "cpu_session.cc" | |||
| ) | |||
| list(APPEND _SESSION_SRC_LIST ${_CPU_SRC_LIST}) | |||
| endif () | |||
| if (ENABLE_D) | |||
| file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "ascend_session.cc" | |||
| "ascend_control_parser.cc" | |||
| ) | |||
| list(APPEND _SESSION_SRC_LIST ${_D_SRC_LIST}) | |||
| endif () | |||
| @@ -0,0 +1,319 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "session/ascend_control_parser.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| namespace session { | |||
| static VectorRef GetCallArgs(std::vector<AnfNodePtr>::iterator iter_begin, std::vector<AnfNodePtr>::iterator iter_end) { | |||
| VectorRef call_args; | |||
| for (auto iter = iter_begin; iter != iter_end; ++iter) { | |||
| if (utils::isa<ValueNode>(*iter)) { | |||
| call_args.push_back(GetValueNode(*iter)); | |||
| } else { | |||
| call_args.push_back(*iter); | |||
| } | |||
| } | |||
| return call_args; | |||
| } | |||
| void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) { | |||
| std::set<KernelGraphPtr> memo; | |||
| ProcessKernelGraph(kg, nullptr, nullptr, {}, NOT_NULL(&memo)); | |||
| } | |||
| NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, | |||
| const CNodePtr &last_label, const VectorRef &args, | |||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString(); | |||
| // 0. recursive condition | |||
| if (memo->find(kg) != memo->end()) { | |||
| MS_LOG(INFO) << "KernelGraph has beed processed: " << kg->ToString(); | |||
| return NOT_NULL(kg->get_start_label()); | |||
| } | |||
| // 2. args replace placeholder | |||
| LinkParentGraph(kg, last_node, last_label, args); | |||
| // 3. topological sort | |||
| std::vector<CNodePtr> nodes = GetCNodes(TopoSort(kg->get_return())); | |||
| if (nodes.empty()) { | |||
| MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!"; | |||
| } | |||
| // 4. insert first_label | |||
| auto start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))}); | |||
| for (auto node : nodes) { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimPartial)) { | |||
| InsertControlDependToGraph(kg, NOT_NULL(start_label), NOT_NULL(node)); | |||
| break; | |||
| } | |||
| } | |||
| kg->set_start_label(start_label); | |||
| // 5. traverse | |||
| for (size_t i = 0; i < nodes.size(); ++i) { | |||
| auto &cnode = nodes[i]; | |||
| if (cnode->size() < kCNodePrim + 1) { | |||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||
| } | |||
| AnfNodePtr fn = cnode->input(kCNodePrim); | |||
| if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) { | |||
| MS_LOG(DEBUG) << "continue node " << cnode->DebugString(); | |||
| continue; | |||
| } | |||
| AnfNodePtr arg = cnode->input(kCNodeCallArg); | |||
| if (IsValueNode<KernelGraph>(arg)) { | |||
| RecurseCall(kg, NOT_NULL(cnode), (i + 1 < nodes.size() ? nodes[i + 1] : nullptr), memo); | |||
| } else if (!arg->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString(); | |||
| } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitch)) { | |||
| auto arg_cnode = arg->cast<CNodePtr>(); | |||
| cnode->set_inputs(cnode->inputs()); | |||
| RecurseSwitch(kg, NOT_NULL(cnode), memo); | |||
| } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitchLayer)) { | |||
| auto arg_cnode = arg->cast<CNodePtr>(); | |||
| cnode->set_inputs(cnode->inputs()); | |||
| RecurseSwitchLayer(kg, NOT_NULL(cnode), memo); | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString(); | |||
| return NOT_NULL(start_label); | |||
| } | |||
| std::vector<CNodePtr> AscendControlParser::GetCNodes(const std::vector<AnfNodePtr> &in) { | |||
| std::vector<CNodePtr> out; | |||
| for (auto &node : in) { | |||
| if (node->isa<CNode>()) { | |||
| out.push_back(node->cast<CNodePtr>()); | |||
| } | |||
| } | |||
| return out; | |||
| } | |||
| void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node) { | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))}; | |||
| auto return_node = kg->get_return(); | |||
| MS_EXCEPTION_IF_NULL(return_node); | |||
| inputs.push_back(return_node->input(1)); | |||
| inputs.push_back(attch_node.get()); | |||
| auto depend_node = kg->NewCNode(inputs); | |||
| return_node->set_input(1, depend_node); | |||
| } | |||
| void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node, | |||
| NotNull<AnfNodePtr> second_node) { | |||
| MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString() | |||
| << ", the second node is " << second_node->DebugString(); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())), | |||
| first_node, second_node}; | |||
| auto control_depend = kg->NewCNode(inputs); | |||
| InsertDependToGraph(kg, NOT_NULL(control_depend)); | |||
| } | |||
| void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, | |||
| const CNodePtr &last_label, const VectorRef &args) { | |||
| if (from_graph_call_node != nullptr) { | |||
| SetSubGraphInput(kg, NOT_NULL(from_graph_call_node), args); | |||
| } | |||
| auto origin_return = kg->get_return(); | |||
| std::vector<AnfNodePtr> origin_return_inputs = origin_return->inputs(); | |||
| // if entry graph, replace return with make_tuple | |||
| if (from_graph_call_node == nullptr || last_label == nullptr) { | |||
| MS_LOG(INFO) << kg->ToString() << " is entry graph."; | |||
| std::vector<AnfNodePtr> make_tuple_inputs = {std::make_shared<ValueNode>(prim::kPrimMakeTuple)}; | |||
| make_tuple_inputs.insert(make_tuple_inputs.end(), origin_return_inputs.begin() + 1, origin_return_inputs.end()); | |||
| auto make_tuple = kg->NewCNode(make_tuple_inputs); | |||
| origin_return->set_inputs({origin_return->input(kCNodePrim), make_tuple}); | |||
| } else { | |||
| // else replace return with label_goto | |||
| auto label_goto = | |||
| kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName)), last_label}); | |||
| InsertDependToGraph(kg, NOT_NULL(label_goto)); | |||
| } | |||
| } | |||
| void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | |||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| MS_LOG(INFO) << "process call func " << cur_node->DebugString(); | |||
| // 1 get kernel graph | |||
| auto origin_inputs = cur_node->inputs(); | |||
| std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))}; | |||
| auto call_args = GetCallArgs(origin_inputs.begin() + 1, origin_inputs.end()); | |||
| if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) { | |||
| MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode"; | |||
| return; | |||
| } | |||
| // 2 return label | |||
| auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))}); | |||
| // 3 add depend relationship | |||
| InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); | |||
| if (next_node != nullptr && next_node != kg->get_return()) { | |||
| InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); | |||
| } | |||
| auto call_kg = GetValueNode<KernelGraphPtr>(origin_inputs[kCNodeCallArg]); | |||
| // 4 modify call op to goto op | |||
| cur_node->set_input(kCNodePrim, new_inputs[kCNodePrim]); | |||
| // 5 recurse sub graph | |||
| CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, call_args, memo); | |||
| new_inputs.push_back(sub_label); | |||
| new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end()); | |||
| cur_node->set_inputs(new_inputs); | |||
| cur_node->set_abstract(nullptr); | |||
| MS_LOG(INFO) << "success process call func " << cur_node->DebugString(); | |||
| } | |||
| void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | |||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); | |||
| if (cur_node->size() < kCNodeSwitchLength) { | |||
| MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength; | |||
| } | |||
| // 1 return label | |||
| auto back_label = kg->NewCNode({std::make_shared<ValueNode>(prim::kPrimLabelSet)}); | |||
| // 2 recurse sub graph | |||
| auto origin_switch_inputs = cur_node->inputs(); | |||
| std::vector<AnfNodePtr> new_switch_inputs = { | |||
| std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), | |||
| origin_switch_inputs[kCNodeSwitchCond]}; | |||
| for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { | |||
| // 2.1 branch kernel graph and args | |||
| CNodePtr partial; | |||
| KernelGraphPtr branch_fg; | |||
| VectorRef call_args; | |||
| std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||
| // 2.2 add depend relationship | |||
| InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); | |||
| // 2.3 recurse sub graph | |||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo); | |||
| new_switch_inputs.push_back(branch_label); | |||
| } | |||
| std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); | |||
| new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end()); | |||
| cur_node->set_inputs(new_switch_inputs); | |||
| cur_node->set_abstract(nullptr); | |||
| MS_LOG(INFO) << "success process switch func " << cur_node->DebugString(); | |||
| } | |||
| void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | |||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); | |||
| if (cur_node->size() < kCNodeSwitchLayerLength) { | |||
| MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; | |||
| } | |||
| auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch); | |||
| MS_EXCEPTION_IF_NULL(branch_tuple); | |||
| if (!branch_tuple->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; | |||
| } | |||
| auto branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs(); | |||
| // 1 return label | |||
| auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName))}); | |||
| // 2 recurse sub graph | |||
| auto origin_switch_inputs = cur_node->inputs(); | |||
| std::vector<AnfNodePtr> new_switch_inputs = {std::make_shared<ValueNode>(prim::kPrimLabelSwitch), | |||
| origin_switch_inputs[kCNodeSwitchCond]}; | |||
| for (size_t i = 0; i < branch_partial.size(); ++i) { | |||
| // 2.1 branch kernel graph and args | |||
| CNodePtr partial; | |||
| KernelGraphPtr branch_fg; | |||
| VectorRef call_args; | |||
| std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||
| // 2.2 add depend relationship | |||
| InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); | |||
| // 2.3 recurse sub graph | |||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo); | |||
| new_switch_inputs.push_back(branch_label); | |||
| } | |||
| new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); | |||
| cur_node->set_inputs(new_switch_inputs); | |||
| cur_node->set_abstract(nullptr); | |||
| MS_LOG(INFO) << "success process switch layer " << cur_node->DebugString(); | |||
| } | |||
| std::tuple<CNodePtr, KernelGraphPtr, VectorRef> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) { | |||
| if (!node.get()->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); | |||
| } | |||
| // 2.1 branch kernel graph and args | |||
| auto partial_cnode = utils::cast<CNodePtr>(node.get()); | |||
| if (partial_cnode->size() < kCNodePartialLength) { | |||
| MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength; | |||
| } | |||
| auto partial_inputs = partial_cnode->inputs(); | |||
| auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]); | |||
| auto call_args = GetCallArgs(partial_inputs.begin() + kCNodePartialFunc + 1, partial_inputs.end()); | |||
| return {partial_cnode, branch_kg, call_args}; | |||
| } | |||
| void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, | |||
| NotNull<AnfNodePtr> to) { | |||
| if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && | |||
| AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { | |||
| return; | |||
| } | |||
| if (from.get() == to.get()) { | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " | |||
| << to->DebugString(); | |||
| // config inputs of assign node | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("Assign")), to, from}; | |||
| // generate a new cnode | |||
| auto assign_node = kg->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(assign_node); | |||
| assign_node->set_abstract(to->abstract()); | |||
| // append the assign at the end of from graph | |||
| InsertDependToGraph(kg, NOT_NULL(assign_node)); | |||
| } | |||
| size_t AscendControlParser::SetChildGraphInput(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> node, | |||
| size_t input_index) { | |||
| auto output_num = AnfAlgo::GetOutputTensorNum(node); | |||
| if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { | |||
| return input_index + output_num; | |||
| } | |||
| auto &graph_inputs = kg->inputs(); | |||
| if (input_index >= graph_inputs.size()) { | |||
| MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size(); | |||
| } | |||
| auto backend_parameter = graph_inputs[input_index]; | |||
| if (node.get()->isa<Parameter>()) { | |||
| MS_EXCEPTION_IF_NULL(backend_parameter); | |||
| MS_LOG(INFO) << "Reuse node [" << node->DebugString() << "], old node[" << backend_parameter->DebugString() | |||
| << "] will be replaced."; | |||
| kg->ReplaceNode(backend_parameter, node); | |||
| return input_index; | |||
| } | |||
| InsertAssignToGraph(kg, node, NOT_NULL(backend_parameter)); | |||
| return input_index + 1; | |||
| } | |||
| void AscendControlParser::SetSubGraphInput(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> from_graph_call_node, | |||
| const VectorRef &args) {} | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,73 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H | |||
| #define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H | |||
| #include <set> | |||
| #include <vector> | |||
| #include <tuple> | |||
| #include "session/kernel_graph.h" | |||
| #include "utils/base_ref.h" | |||
| #include "utils/contract.h" | |||
| namespace mindspore { | |||
| namespace session { | |||
// Collection of static helpers that operate on kernel graphs. Every member is static,
// so the class is used without creating an instance. The public entry points are
// LinkGraph, InsertDependToGraph and InsertControlDependToGraph; everything else is
// private and implemented outside this header.
| class AscendControlParser { | |||
| public: | |||
// Entry point for a graph; AscendSession::LinkChildGraphs forwards its argument here.
| static void LinkGraph(NotNull<KernelGraphPtr> kg); | |||
// AscendSession::InsertDependToGraph forwards here with the graph it looked up by id
// and the node to attach.
| static void InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node); | |||
// AscendSession::InsertControlDependToGraph forwards here with the graph it looked up
// by id and the two nodes.
| static void InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node, | |||
| NotNull<AnfNodePtr> second_node); | |||
| private: | |||
// Internal traversal helpers. The memo parameter is a non-null pointer to a set of
// graphs — presumably the graphs already visited; TODO confirm against the .cc file.
| static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, | |||
| const CNodePtr &last_label, const VectorRef &args, | |||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static void RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | |||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static void RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | |||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static void RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | |||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||
// Takes a list of generic nodes and returns a list of CNodes; the selection rule is
// defined in the implementation file.
| static std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &in); | |||
| static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, | |||
| const CNodePtr &last_label, const VectorRef &args); | |||
| static void SetSubGraphInput(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> from_graph_call_node, | |||
| const VectorRef &args); | |||
// Returns a (CNode, kernel graph, argument list) triple for the given node.
| static std::tuple<CNodePtr, KernelGraphPtr, VectorRef> ParsePartial(NotNull<AnfNodePtr> node); | |||
| static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | |||
// Takes an input index and returns a size_t; NOTE(review): presumably the index to use
// for the next input — confirm against the implementation.
| static size_t SetChildGraphInput(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> node, size_t input_index); | |||
// Compile-time size_t constants. NOTE(review): judging by the names these are input
// positions and expected input counts for call, switch, partial and switch-layer
// CNodes — confirm against their uses in the .cc file before relying on that reading.
| static constexpr size_t kCNodePrim = 0; | |||
| static constexpr size_t kCNodeCallArg = 1; | |||
| static constexpr size_t kCNodeSwitchCond = 1; | |||
| static constexpr size_t kCNodeSwitchTrue = 2; | |||
| static constexpr size_t kCNodeSwitchFalse = 3; | |||
| static constexpr size_t kCNodeSwitchLength = 4; | |||
| static constexpr size_t kCNodePartialLength = 2; | |||
| static constexpr size_t kCNodePartialFunc = 1; | |||
| static constexpr size_t kCNodeSwitchLayerCond = 1; | |||
| static constexpr size_t kCNodeSwitchLayerBranch = 2; | |||
| static constexpr size_t kCNodeSwitchLayerLength = 3; | |||
| }; | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H | |||
| @@ -160,14 +160,14 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) { | |||
| std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) { | |||
| std::vector<CNodePtr> cnodes = {}; | |||
| size_t i = 0; | |||
| for (const auto anf : anf_nodes) { | |||
| for (auto anf : anf_nodes) { | |||
| MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString(); | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| if (anf->isa<CNode>()) { | |||
| cnodes.push_back(anf->cast<CNodePtr>()); | |||
| } | |||
| } | |||
| return std::move(cnodes); | |||
| return cnodes; | |||
| } | |||
| std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, const std::vector<CNodePtr> &cnodes) { | |||
| @@ -189,7 +189,7 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co | |||
| ret.push_back(std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.end())); | |||
| } | |||
| } | |||
| return std::move(ret); | |||
| return ret; | |||
| } | |||
| void UpdateRealInput(KernelGraph *graph) { | |||
| @@ -232,7 +232,7 @@ void UpdateRealInput(KernelGraph *graph) { | |||
| auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); | |||
| partial_cnode->set_inputs( | |||
| std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); | |||
| return std::move(ret); | |||
| return ret; | |||
| }; | |||
| bind_call_partial_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); | |||
| bind_call_partial_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get()); | |||
| @@ -256,27 +256,28 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| // split switch | |||
| SplitGraph(graph); | |||
| // insert goto labels and label_sets | |||
| LinkChildGraphs(graph.get()); | |||
| LinkChildGraphs(NOT_NULL(graph)); | |||
| // resource initialize | |||
| InitRuntimeResource(); | |||
| // ir fusion | |||
| IRFusion(graph); | |||
| // kernel select | |||
| SelectKernelGraphKernel(*graph); | |||
| // convert model of predict module | |||
| ConvertPredictModel(graph); | |||
| // hardware optimize | |||
| HardwareOptimizeGraphs(graph); | |||
| // assign label | |||
| AssignLabel(NOT_NULL(graph)); | |||
| if (!graph->executable()) { | |||
| return graph->graph_id(); | |||
| } | |||
| for (auto iter : graphs_) { | |||
| if (iter.second == graph) { | |||
| MS_LOG(INFO) << "Entry graph " << graph->ToString() << " graph id " << graph->graph_id(); | |||
| final_graph_id_ = graph->graph_id(); | |||
| } | |||
| MS_LOG(INFO) << "CompileChildGraph " << iter.second->ToString(); | |||
| CompileChildGraph(iter.second); | |||
| } | |||
| // adjust kernel | |||
| AdjustKernel(graph); | |||
| // root graph validate, including generating the execute order and so on | |||
| RootGraphExecutorValidate(graph.get()); | |||
| // assign stream | |||
| AssignStream(graph); | |||
| // assign label | |||
| AssignLabel(NOT_NULL(graph)); | |||
| // build kernel if node is cnode | |||
| BuildKernel(graph); | |||
| // alloc mem | |||
| MemoryAlloc(graph.get()); | |||
| // task generate | |||
| @@ -556,7 +557,7 @@ void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_grap | |||
| MS_LOG(INFO) << "Finish!"; | |||
| } | |||
| void AscendSession::AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const { | |||
| void AscendSession::AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const { | |||
| MS_LOG(INFO) << "Start!"; | |||
| device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph); | |||
| MS_LOG(INFO) << "Finish!"; | |||
| @@ -1305,29 +1306,13 @@ void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived | |||
| } | |||
// Attaches attch_node to the graph identified by graph_id.
// The rows below contain two bodies side by side:
//  - an inline version that logs attch_node, wraps the return node's input(1) and
//    attch_node in a new "depend" CNode and installs that CNode as the return node's
//    input 1;
//  - a single call that hands the looked-up graph and attch_node to
//    AscendControlParser::InsertDependToGraph.
// NOTE(review): this looks like a diff hunk in which the delegation replaces the inline
// version — confirm which rows are live before relying on either.
| void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) { | |||
| MS_LOG(INFO) << "Insert depend at the end of graph, the attach node is " << attch_node->DebugString(); | |||
| auto graph = GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))}; | |||
| auto return_node = graph->get_return(); | |||
| MS_EXCEPTION_IF_NULL(return_node); | |||
| inputs.push_back(return_node->input(1)); | |||
| inputs.push_back(attch_node); | |||
| auto depend_node = graph->NewCNode(inputs); | |||
| return_node->set_input(1, depend_node); | |||
| AscendControlParser::InsertDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(attch_node)); | |||
| } | |||
// Records an ordering relation between first_node and second_node in the graph
// identified by graph_id.
// The rows below contain two bodies side by side:
//  - an inline version that logs both nodes, builds a "ControlDepend" CNode from
//    first_node and second_node and passes it to InsertDependToGraph(graph_id, ...);
//  - a single call that hands the looked-up graph and both nodes to
//    AscendControlParser::InsertControlDependToGraph.
// NOTE(review): this looks like a diff hunk in which the delegation replaces the inline
// version — confirm which rows are live before relying on either.
| void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, | |||
| const AnfNodePtr &second_node) { | |||
| MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString() | |||
| << ", the second node is " << second_node->DebugString(); | |||
| auto graph = GetGraph(graph_id); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("ControlDepend"))}; | |||
| inputs.push_back(first_node); | |||
| inputs.push_back(second_node); | |||
| auto control_depend = graph->NewCNode(inputs); | |||
| InsertDependToGraph(graph_id, control_depend); | |||
| AscendControlParser::InsertControlDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(first_node), | |||
| NOT_NULL(second_node)); | |||
| } | |||
| size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) { | |||
| @@ -1482,5 +1467,8 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { | |||
| SplitGraph(child_graph); | |||
| } | |||
| } | |||
// Forwards graph unchanged to AscendControlParser::LinkGraph; no other work is done here.
| void AscendSession::LinkChildGraphs(NotNull<KernelGraphPtr> graph) { AscendControlParser::LinkGraph(graph); } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -28,6 +28,7 @@ | |||
| #include "session/kernel_graph.h" | |||
| #include "kernel/kernel.h" | |||
| #include "session/session_factory.h" | |||
| #include "session/ascend_control_parser.h" | |||
| namespace mindspore { | |||
| namespace session { | |||
| @@ -74,7 +75,7 @@ class AscendSession : public SessionBasic { | |||
| void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const; | |||
| void AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const; | |||
| void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void MemoryAlloc(KernelGraph *kernel_graph) const; | |||
| void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const; | |||
| @@ -96,7 +97,8 @@ class AscendSession : public SessionBasic { | |||
| void SetFinalGraphOutput(const VectorRef &vec_output); | |||
| void SplitGraph(const KernelGraphPtr &graph); | |||
| void LinkChildGraphs(KernelGraph *graph) {} | |||
| void LinkChildGraphs(NotNull<KernelGraphPtr> graph); | |||
| void IRFusion(const KernelGraphPtr &graph) {} | |||
| void SelectKernelGraphKernel(const KernelGraph &graph) {} | |||
| void ConvertPredictModel(const KernelGraphPtr graph) {} | |||
| @@ -28,6 +28,7 @@ | |||
| #include "ir/func_graph.h" | |||
| #include "ir/anf.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "utils/contract.h" | |||
| #include "device/kernel_info.h" | |||
| namespace mindspore { | |||
| @@ -108,6 +109,7 @@ class KernelGraph : public FuncGraph { | |||
| std::vector<std::shared_ptr<KernelGraph>> child_graph_order() const { return child_graph_order_; } | |||
| // check whether the current graph is a leaf graph | |||
| bool IsLeafGraph() const; | |||
| // set input_tensors pointer of control parameter | |||
| void set_input_ctrl_tensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &input_tensors_ptr) { | |||
| input_ctrl_tensors_ = input_tensors_ptr; | |||
| @@ -126,6 +128,9 @@ class KernelGraph : public FuncGraph { | |||
| // used to dump ir | |||
| std::string ToString() const override; | |||
| void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; } | |||
| CNodePtr get_start_label() { return start_label_; } | |||
| private: | |||
| // remove value node from graph | |||
| bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); | |||
| @@ -168,12 +173,16 @@ class KernelGraph : public FuncGraph { | |||
| std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> node_to_child_graphs_; | |||
| // child graph execute order in root graph | |||
| std::vector<std::shared_ptr<KernelGraph>> child_graph_order_; | |||
| // input_tensors of control parameter | |||
| std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors_; | |||
| // parameter graph | |||
| std::shared_ptr<KernelGraph> parent_graph_; | |||
| // record real parameters; inputs_ holds the formal parameters | |||
| std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_; | |||
| CNodePtr start_label_; | |||
| }; | |||
| } // namespace session | |||
| using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | |||
| @@ -61,6 +61,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "../../../mindspore/ccsrc/transform/*.cc" | |||
| "../../../mindspore/ccsrc/session/anf_runtime_algorithm.cc" | |||
| "../../../mindspore/ccsrc/session/ascend_session.cc" | |||
| "../../../mindspore/ccsrc/session/ascend_control_parser.cc" | |||
| "../../../mindspore/ccsrc/session/kernel_graph.cc" | |||
| "../../../mindspore/ccsrc/session/session_basic.cc" | |||
| "../../../mindspore/ccsrc/session/session_factory.cc" | |||
| @@ -22,7 +22,9 @@ namespace mindspore { | |||
| namespace device { | |||
| namespace ascend { | |||
| void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &>) {} | |||
// Stub with an empty body: graph is accepted but not used.
| void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) {} | |||
| uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> graph) { return 1; } | |||
// Stub: reports a label count of 1 for any graph; the argument is not inspected.
| uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) { return 1; } | |||
// Stub: returns immediately without using graph.
| void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; } | |||
| @@ -39,9 +41,7 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve | |||
| } // namespace ascend | |||
// Stub: returns immediately; kernel_graph_ptr is not used.
| void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; } | |||
// Stub: returns immediately; kernel_graph_ptr is not used.
| void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; } | |||
| bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | |||
| return true; | |||
| } | |||
// Stub: always returns true; kernel_graph_ptr is not used.
| bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return true; } | |||
// Stub: unconditionally returns true.
| bool KernelAdjust::NeedInsertSwitch() { return true; } | |||
// Stub: returns immediately; kernel_graph_ptr is not used.
| void KernelAdjust::Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr) { return; } | |||
| } // namespace device | |||