| @@ -29,6 +29,7 @@ | |||||
| #include "hccl/hcom.h" | #include "hccl/hcom.h" | ||||
| #include "common/trans.h" | #include "common/trans.h" | ||||
| #include "runtime/context.h" | #include "runtime/context.h" | ||||
| #include "device/ascend/ascend_label_assign.h" | |||||
| #include "device/ascend/ascend_stream_assign.h" | #include "device/ascend/ascend_stream_assign.h" | ||||
| #include "device/ascend/ascend_memory_pool.h" | #include "device/ascend/ascend_memory_pool.h" | ||||
| #include "framework/ge_runtime/model_runner.h" | #include "framework/ge_runtime/model_runner.h" | ||||
| @@ -281,21 +282,24 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); | |||||
| AscendStreamAssign &stream_assign_instance = AscendStreamAssign::GetInstance(); | |||||
| AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance(); | |||||
| // the streams' flag not HEAD_STREAM | // the streams' flag not HEAD_STREAM | ||||
| std::vector<uint32_t> wait_active_stream_list; | std::vector<uint32_t> wait_active_stream_list; | ||||
| assign_instance.GetWaitStreams(&wait_active_stream_list); | |||||
| auto force_copy_stream_list = assign_instance.hcom_streams(); | |||||
| stream_assign_instance.GetWaitStreams(&wait_active_stream_list); | |||||
| auto force_copy_stream_list = stream_assign_instance.hcom_streams(); | |||||
| MS_LOG(INFO) << "call DavinciModel total stream num:" << assign_instance.GetTotalStreamNum() | |||||
| << ", total event num:" << assign_instance.total_event_num() | |||||
| MS_LOG(INFO) << "call DavinciModel total stream num:" << stream_assign_instance.GetTotalStreamNum() | |||||
| << ", total event num:" << stream_assign_instance.total_event_num() | |||||
| << ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph)) | |||||
| << ", wait_active_stream_list size:" << wait_active_stream_list.size() | << ", wait_active_stream_list size:" << wait_active_stream_list.size() | ||||
| << ", force_copy_stream_list size:" << force_copy_stream_list.size(); | << ", force_copy_stream_list size:" << force_copy_stream_list.size(); | ||||
| std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list; | std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list; | ||||
| std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>( | std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>( | ||||
| task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, | task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, | ||||
| 0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.total_event_num(), 0); | |||||
| 0, 0, 0, 0, 0, stream_assign_instance.GetTotalStreamNum(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), | |||||
| stream_assign_instance.total_event_num(), 0); | |||||
| auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); | auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); | ||||
| if (!ret.second) { | if (!ret.second) { | ||||
| @@ -15,6 +15,8 @@ | |||||
| */ | */ | ||||
| #include <vector> | #include <vector> | ||||
| #include <string> | |||||
| #include <set> | |||||
| #include "device/ascend/ascend_label_assign.h" | #include "device/ascend/ascend_label_assign.h" | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| @@ -36,6 +38,7 @@ static void UpdateLabelGoto(NotNull<CNodePtr> node) { | |||||
| uint32_t goto_label_id = GetValue<uint32_t>(value); | uint32_t goto_label_id = GetValue<uint32_t>(value); | ||||
| AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(goto_label_id), node.get()); | AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(goto_label_id), node.get()); | ||||
| MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id; | MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id; | ||||
| node->set_inputs({node->input(0)}); | |||||
| } | } | ||||
| static void UpdateLabelSwitch(NotNull<CNodePtr> node) { | static void UpdateLabelSwitch(NotNull<CNodePtr> node) { | ||||
| @@ -58,29 +61,93 @@ static void UpdateLabelSwitch(NotNull<CNodePtr> node) { | |||||
| MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id; | MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id; | ||||
| } | } | ||||
| AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), node.get()); | AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), node.get()); | ||||
| node->set_inputs({node->input(0), node->input(1)}); | |||||
| } | } | ||||
| void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph) { | |||||
| auto cnode_list = graph->execution_order(); | |||||
| // 1 assign label id to label_set | |||||
| uint32_t cur_label_id = 0; | |||||
| for (auto &node : cnode_list) { | |||||
| if (AnfAlgo::GetCNodeName(node) == kLabelSetOpName) { | |||||
| AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(cur_label_id), node); | |||||
| MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << cur_label_id; | |||||
| ++cur_label_id; | |||||
| static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph>> graph, NotNull<uint32_t *> label_id, | |||||
| NotNull<std::set<std::shared_ptr<session::KernelGraph>> *> memo) { | |||||
| if (memo->find(graph.get()) != memo->end()) { | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "Assign label for " << graph->ToString(); | |||||
| auto nodes = TopoSort(graph->get_return()); | |||||
| for (auto &node : nodes) { | |||||
| if (!node->isa<CNode>()) { | |||||
| continue; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| std::string node_name = AnfAlgo::GetCNodeName(node); | |||||
| if (node_name == kLabelSetOpName && !AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) { | |||||
| AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue<uint32_t>(*label_id), node); | |||||
| MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << *label_id; | |||||
| ++(*label_id); | |||||
| } | } | ||||
| } | } | ||||
| // 2 update label_switch / label_goto | |||||
| for (auto &node : cnode_list) { | |||||
| if (AnfAlgo::GetCNodeName(node) == kLabelGotoOpName) { | |||||
| UpdateLabelGoto(NOT_NULL(node)); | |||||
| for (auto &cg : graph->child_graph_order()) { | |||||
| AssignLabelForLabelSet(NOT_NULL(cg), label_id, memo); | |||||
| } | |||||
| } | |||||
| static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGraph>> graph, | |||||
| NotNull<std::set<std::shared_ptr<session::KernelGraph>> *> memo) { | |||||
| if (memo->find(graph.get()) != memo->end()) { | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString(); | |||||
| auto nodes = TopoSort(graph->get_return()); | |||||
| for (auto &node : nodes) { | |||||
| if (!node->isa<CNode>()) { | |||||
| continue; | |||||
| } | } | ||||
| if (AnfAlgo::GetCNodeName(node) == kLabelSwitchOpName) { | |||||
| UpdateLabelSwitch(NOT_NULL(node)); | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| std::string node_name = AnfAlgo::GetCNodeName(node); | |||||
| if (node_name == kLabelGotoOpName) { | |||||
| UpdateLabelGoto(NOT_NULL(cnode)); | |||||
| cnode->set_abstract(nullptr); | |||||
| } | } | ||||
| if (node_name == kLabelSwitchOpName) { | |||||
| UpdateLabelSwitch(NOT_NULL(cnode)); | |||||
| } | |||||
| } | |||||
| for (auto &cg : graph->child_graph_order()) { | |||||
| AssignLabelForGotoSwitch(NOT_NULL(cg), memo); | |||||
| } | |||||
| } | |||||
| void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) { | |||||
| MS_LOG(INFO) << "Assign label start."; | |||||
| std::set<std::shared_ptr<session::KernelGraph>> memo; | |||||
| uint32_t label_id = 0; | |||||
| AssignLabelForLabelSet(graph, NOT_NULL(&label_id), NOT_NULL(&memo)); | |||||
| memo.clear(); | |||||
| { | |||||
| std::lock_guard<std::mutex> lock(label_num_mutex_); | |||||
| label_num_[graph.get().get()] = label_id; | |||||
| } | } | ||||
| AssignLabelForGotoSwitch(graph, NOT_NULL(&memo)); | |||||
| MS_LOG(INFO) << "Assign label end."; | |||||
| } | |||||
| uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> graph) { | |||||
| std::lock_guard<std::mutex> lock(label_num_mutex_); | |||||
| auto iter = label_num_.find(graph.get()); | |||||
| if (iter == label_num_.end()) { | |||||
| MS_LOG(WARNING) << "Graph " << graph->ToString() << " has not assigned label."; | |||||
| return 1; | |||||
| } | |||||
| return iter->second; | |||||
| } | |||||
| uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) { | |||||
| return GetLabelNum(NOT_NULL(graph.get().get())); | |||||
| } | } | ||||
| } // namespace ascend | } // namespace ascend | ||||
| @@ -18,6 +18,7 @@ | |||||
| #define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ | #define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ | ||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| #include "session/kernel_graph.h" | #include "session/kernel_graph.h" | ||||
| #include "utils/contract.h" | #include "utils/contract.h" | ||||
| @@ -35,11 +36,16 @@ class AscendLabelAssign { | |||||
| AscendLabelAssign(const AscendLabelAssign &) = delete; | AscendLabelAssign(const AscendLabelAssign &) = delete; | ||||
| AscendLabelAssign &operator=(const AscendLabelAssign &) = delete; | AscendLabelAssign &operator=(const AscendLabelAssign &) = delete; | ||||
| void AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &> graph); | |||||
| void AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph); | |||||
| uint32_t GetLabelNum(NotNull<const session::KernelGraph *> graph); | |||||
| uint32_t GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph); | |||||
| private: | private: | ||||
| AscendLabelAssign() = default; | AscendLabelAssign() = default; | ||||
| ~AscendLabelAssign() = default; | ~AscendLabelAssign() = default; | ||||
| std::map<const session::KernelGraph *, uint32_t> label_num_; | |||||
| std::mutex label_num_mutex_; | |||||
| }; | }; | ||||
| } // namespace ascend | } // namespace ascend | ||||
| } // namespace device | } // namespace device | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "kernel/rts/label_switch.h" | #include "kernel/rts/label_switch.h" | ||||
| #include <asm-generic/param.h> | #include <asm-generic/param.h> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | |||||
| #include "runtime/stream.h" | #include "runtime/stream.h" | ||||
| #include "framework/ge_runtime/task_info.h" | #include "framework/ge_runtime/task_info.h" | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| @@ -66,13 +67,33 @@ std::vector<TaskInfoPtr> LabelSwitchKernel::GenTask(const std::vector<AddressPtr | |||||
| MS_LOG(INFO) << "LabelSwitchKernel GenTask label size:" << label_size_ << ", stream id:" << stream_id; | MS_LOG(INFO) << "LabelSwitchKernel GenTask label size:" << label_size_ << ", stream id:" << stream_id; | ||||
| std::vector<TaskInfoPtr> task_info_list; | std::vector<TaskInfoPtr> task_info_list; | ||||
| cond_ = inputs[0]->addr; | cond_ = inputs[0]->addr; | ||||
| // std::shared_ptr<LabelSwitchTaskInfo> task_info_ptr = | |||||
| // std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_, &label_list_, cond_); | |||||
| // need updata ge task info define | |||||
| std::shared_ptr<LabelSwitchTaskInfo> task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_); | |||||
| // todo: need update ge task info define | |||||
| auto task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, 0); | |||||
| // auto task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_, label_list_, cond_); | |||||
| MS_EXCEPTION_IF_NULL(task_info_ptr); | MS_EXCEPTION_IF_NULL(task_info_ptr); | ||||
| task_info_list.emplace_back(task_info_ptr); | task_info_list.emplace_back(task_info_ptr); | ||||
| return task_info_list; | return task_info_list; | ||||
| } | } | ||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernelInfo() { | |||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> label_switch_build_info{}; | |||||
| vector<string> input_format{kOpFormat_DEFAULT, kOpFormat_DEFAULT}; | |||||
| vector<TypeId> input_type{kNumberTypeUInt32, kNumberTypeBool}; | |||||
| if (input_format.size() != input_type.size()) { | |||||
| MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size " | |||||
| << input_type.size(); | |||||
| } | |||||
| for (size_t i = 0; i < input_format.size(); ++i) { | |||||
| auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); | |||||
| builder.SetInputsFormat({input_format[i]}); | |||||
| builder.SetInputsDeviceType({input_type[i]}); | |||||
| builder.SetProcessor(AICORE); | |||||
| builder.SetKernelType(RT_KERNEL); | |||||
| builder.SetFusionType(OPAQUE); | |||||
| label_switch_build_info.emplace_back(builder.Build()); | |||||
| } | |||||
| return label_switch_build_info; | |||||
| } | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -42,6 +42,14 @@ class LabelSwitchKernel : public RtKernel { | |||||
| void *cond_; | void *cond_; | ||||
| }; | }; | ||||
| class LabelSwitchDesc : public RtKerDesc { | |||||
| public: | |||||
| LabelSwitchDesc() = default; | |||||
| ~LabelSwitchDesc() override = default; | |||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetKernelInfo() override; | |||||
| }; | |||||
| MS_REG_RTKERNEL_DESC(labelswitch, LabelSwitchDesc); | |||||
| MS_REG_RTKERNEL(labelswitch, LabelSwitchKernel); | MS_REG_RTKERNEL(labelswitch, LabelSwitchKernel); | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -44,6 +44,12 @@ RtKerDescFactory &RtKerDescFactory::Get() { | |||||
| return _this; | return _this; | ||||
| } | } | ||||
| static bool IsDefaultKernelInfo(const std::string &name) { | |||||
| static const std::set<std::string> white_list = {kStreamSwitchOpName, kStreamActiveOpName, kLabelSetOpName, | |||||
| kLabelGotoOpName}; | |||||
| return white_list.find(name) != white_list.end(); | |||||
| } | |||||
| void GetRtKelInfo(const CNodePtr &kernel_node, | void GetRtKelInfo(const CNodePtr &kernel_node, | ||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | MS_EXCEPTION_IF_NULL(kernel_info_list); | ||||
| @@ -58,7 +64,7 @@ void GetRtKelInfo(const CNodePtr &kernel_node, | |||||
| } | } | ||||
| // if can't find kernel info in kernel info database, use the default kernel info | // if can't find kernel info in kernel info database, use the default kernel info | ||||
| auto node_name = AnfAlgo::GetCNodeName(kernel_node); | auto node_name = AnfAlgo::GetCNodeName(kernel_node); | ||||
| if (node_name == "StreamSwitch" || node_name == "StreamActive") { | |||||
| if (IsDefaultKernelInfo(node_name)) { | |||||
| auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | ||||
| // set input infos | // set input infos | ||||
| auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); | auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| @@ -331,12 +331,14 @@ bool ExecuteAction(const ResourcePtr &res) { | |||||
| } | } | ||||
| auto graph_id = res->results()[kOutput].cast<GraphId>(); | auto graph_id = res->results()[kOutput].cast<GraphId>(); | ||||
| auto bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::MsBackend>>(); | |||||
| std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>(); | |||||
| std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr); | |||||
| MS_EXCEPTION_IF_NULL(msbc_ptr); | |||||
| compile::VmEvalFuncPtr run = | compile::VmEvalFuncPtr run = | ||||
| std::make_shared<compile::VmEvalFunc>([&bc_ptr, graph_id](const VectorRef &args) -> BaseRef { | |||||
| MS_LOG(INFO) << "Execute args size" << args.size(); | |||||
| auto outs = bc_ptr->RunGraph(graph_id, args); | |||||
| MS_LOG(DEBUG) << "out size" << outs.size(); | |||||
| std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef { | |||||
| MS_LOG(INFO) << "Execute args size " << args.size(); | |||||
| auto outs = msbc_ptr->RunGraph(graph_id, args); | |||||
| MS_LOG(DEBUG) << "out size " << outs.size(); | |||||
| return outs[0]; | return outs[0]; | ||||
| }); | }); | ||||
| res->results()[kOutput] = run; | res->results()[kOutput] = run; | ||||
| @@ -6,22 +6,23 @@ file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ) | ) | ||||
| if (ENABLE_GPU) | if (ENABLE_GPU) | ||||
| file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| "gpu_session.cc" | "gpu_session.cc" | ||||
| ) | ) | ||||
| list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST}) | list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST}) | ||||
| endif () | endif () | ||||
| if (ENABLE_CPU) | if (ENABLE_CPU) | ||||
| file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| "cpu_session.cc" | "cpu_session.cc" | ||||
| ) | ) | ||||
| list(APPEND _SESSION_SRC_LIST ${_CPU_SRC_LIST}) | list(APPEND _SESSION_SRC_LIST ${_CPU_SRC_LIST}) | ||||
| endif () | endif () | ||||
| if (ENABLE_D) | if (ENABLE_D) | ||||
| file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| "ascend_session.cc" | "ascend_session.cc" | ||||
| "ascend_control_parser.cc" | |||||
| ) | ) | ||||
| list(APPEND _SESSION_SRC_LIST ${_D_SRC_LIST}) | list(APPEND _SESSION_SRC_LIST ${_D_SRC_LIST}) | ||||
| endif () | endif () | ||||
| @@ -0,0 +1,319 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <utility> | |||||
| #include <memory> | |||||
| #include "session/ascend_control_parser.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| namespace mindspore { | |||||
| namespace session { | |||||
| static VectorRef GetCallArgs(std::vector<AnfNodePtr>::iterator iter_begin, std::vector<AnfNodePtr>::iterator iter_end) { | |||||
| VectorRef call_args; | |||||
| for (auto iter = iter_begin; iter != iter_end; ++iter) { | |||||
| if (utils::isa<ValueNode>(*iter)) { | |||||
| call_args.push_back(GetValueNode(*iter)); | |||||
| } else { | |||||
| call_args.push_back(*iter); | |||||
| } | |||||
| } | |||||
| return call_args; | |||||
| } | |||||
| void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) { | |||||
| std::set<KernelGraphPtr> memo; | |||||
| ProcessKernelGraph(kg, nullptr, nullptr, {}, NOT_NULL(&memo)); | |||||
| } | |||||
| NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, | |||||
| const CNodePtr &last_label, const VectorRef &args, | |||||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString(); | |||||
| // 0. recursive condition | |||||
| if (memo->find(kg) != memo->end()) { | |||||
| MS_LOG(INFO) << "KernelGraph has beed processed: " << kg->ToString(); | |||||
| return NOT_NULL(kg->get_start_label()); | |||||
| } | |||||
| // 2. args replace placeholder | |||||
| LinkParentGraph(kg, last_node, last_label, args); | |||||
| // 3. topological sort | |||||
| std::vector<CNodePtr> nodes = GetCNodes(TopoSort(kg->get_return())); | |||||
| if (nodes.empty()) { | |||||
| MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!"; | |||||
| } | |||||
| // 4. insert first_label | |||||
| auto start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))}); | |||||
| for (auto node : nodes) { | |||||
| if (!IsPrimitiveCNode(node, prim::kPrimPartial)) { | |||||
| InsertControlDependToGraph(kg, NOT_NULL(start_label), NOT_NULL(node)); | |||||
| break; | |||||
| } | |||||
| } | |||||
| kg->set_start_label(start_label); | |||||
| // 5. traverse | |||||
| for (size_t i = 0; i < nodes.size(); ++i) { | |||||
| auto &cnode = nodes[i]; | |||||
| if (cnode->size() < kCNodePrim + 1) { | |||||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||||
| } | |||||
| AnfNodePtr fn = cnode->input(kCNodePrim); | |||||
| if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) { | |||||
| MS_LOG(DEBUG) << "continue node " << cnode->DebugString(); | |||||
| continue; | |||||
| } | |||||
| AnfNodePtr arg = cnode->input(kCNodeCallArg); | |||||
| if (IsValueNode<KernelGraph>(arg)) { | |||||
| RecurseCall(kg, NOT_NULL(cnode), (i + 1 < nodes.size() ? nodes[i + 1] : nullptr), memo); | |||||
| } else if (!arg->isa<CNode>()) { | |||||
| MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString(); | |||||
| } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitch)) { | |||||
| auto arg_cnode = arg->cast<CNodePtr>(); | |||||
| cnode->set_inputs(cnode->inputs()); | |||||
| RecurseSwitch(kg, NOT_NULL(cnode), memo); | |||||
| } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitchLayer)) { | |||||
| auto arg_cnode = arg->cast<CNodePtr>(); | |||||
| cnode->set_inputs(cnode->inputs()); | |||||
| RecurseSwitchLayer(kg, NOT_NULL(cnode), memo); | |||||
| } | |||||
| } | |||||
| MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString(); | |||||
| return NOT_NULL(start_label); | |||||
| } | |||||
| std::vector<CNodePtr> AscendControlParser::GetCNodes(const std::vector<AnfNodePtr> &in) { | |||||
| std::vector<CNodePtr> out; | |||||
| for (auto &node : in) { | |||||
| if (node->isa<CNode>()) { | |||||
| out.push_back(node->cast<CNodePtr>()); | |||||
| } | |||||
| } | |||||
| return out; | |||||
| } | |||||
| void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node) { | |||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))}; | |||||
| auto return_node = kg->get_return(); | |||||
| MS_EXCEPTION_IF_NULL(return_node); | |||||
| inputs.push_back(return_node->input(1)); | |||||
| inputs.push_back(attch_node.get()); | |||||
| auto depend_node = kg->NewCNode(inputs); | |||||
| return_node->set_input(1, depend_node); | |||||
| } | |||||
| void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node, | |||||
| NotNull<AnfNodePtr> second_node) { | |||||
| MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString() | |||||
| << ", the second node is " << second_node->DebugString(); | |||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())), | |||||
| first_node, second_node}; | |||||
| auto control_depend = kg->NewCNode(inputs); | |||||
| InsertDependToGraph(kg, NOT_NULL(control_depend)); | |||||
| } | |||||
| void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, | |||||
| const CNodePtr &last_label, const VectorRef &args) { | |||||
| if (from_graph_call_node != nullptr) { | |||||
| SetSubGraphInput(kg, NOT_NULL(from_graph_call_node), args); | |||||
| } | |||||
| auto origin_return = kg->get_return(); | |||||
| std::vector<AnfNodePtr> origin_return_inputs = origin_return->inputs(); | |||||
| // if entry graph, replace return with make_tuple | |||||
| if (from_graph_call_node == nullptr || last_label == nullptr) { | |||||
| MS_LOG(INFO) << kg->ToString() << " is entry graph."; | |||||
| std::vector<AnfNodePtr> make_tuple_inputs = {std::make_shared<ValueNode>(prim::kPrimMakeTuple)}; | |||||
| make_tuple_inputs.insert(make_tuple_inputs.end(), origin_return_inputs.begin() + 1, origin_return_inputs.end()); | |||||
| auto make_tuple = kg->NewCNode(make_tuple_inputs); | |||||
| origin_return->set_inputs({origin_return->input(kCNodePrim), make_tuple}); | |||||
| } else { | |||||
| // else replace return with label_goto | |||||
| auto label_goto = | |||||
| kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName)), last_label}); | |||||
| InsertDependToGraph(kg, NOT_NULL(label_goto)); | |||||
| } | |||||
| } | |||||
| void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | |||||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| MS_LOG(INFO) << "process call func " << cur_node->DebugString(); | |||||
| // 1 get kernel graph | |||||
| auto origin_inputs = cur_node->inputs(); | |||||
| std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))}; | |||||
| auto call_args = GetCallArgs(origin_inputs.begin() + 1, origin_inputs.end()); | |||||
| if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) { | |||||
| MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode"; | |||||
| return; | |||||
| } | |||||
| // 2 return label | |||||
| auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))}); | |||||
| // 3 add depend relationship | |||||
| InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); | |||||
| if (next_node != nullptr && next_node != kg->get_return()) { | |||||
| InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); | |||||
| } | |||||
| auto call_kg = GetValueNode<KernelGraphPtr>(origin_inputs[kCNodeCallArg]); | |||||
| // 4 modify call op to goto op | |||||
| cur_node->set_input(kCNodePrim, new_inputs[kCNodePrim]); | |||||
| // 5 recurse sub graph | |||||
| CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, call_args, memo); | |||||
| new_inputs.push_back(sub_label); | |||||
| new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end()); | |||||
| cur_node->set_inputs(new_inputs); | |||||
| cur_node->set_abstract(nullptr); | |||||
| MS_LOG(INFO) << "success process call func " << cur_node->DebugString(); | |||||
| } | |||||
| void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | |||||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); | |||||
| if (cur_node->size() < kCNodeSwitchLength) { | |||||
| MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength; | |||||
| } | |||||
| // 1 return label | |||||
| auto back_label = kg->NewCNode({std::make_shared<ValueNode>(prim::kPrimLabelSet)}); | |||||
| // 2 recurse sub graph | |||||
| auto origin_switch_inputs = cur_node->inputs(); | |||||
| std::vector<AnfNodePtr> new_switch_inputs = { | |||||
| std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), | |||||
| origin_switch_inputs[kCNodeSwitchCond]}; | |||||
| for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { | |||||
| // 2.1 branch kernel graph and args | |||||
| CNodePtr partial; | |||||
| KernelGraphPtr branch_fg; | |||||
| VectorRef call_args; | |||||
| std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||||
| // 2.2 add depend relationship | |||||
| InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); | |||||
| // 2.3 recurse sub graph | |||||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo); | |||||
| new_switch_inputs.push_back(branch_label); | |||||
| } | |||||
| std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); | |||||
| new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end()); | |||||
| cur_node->set_inputs(new_switch_inputs); | |||||
| cur_node->set_abstract(nullptr); | |||||
| MS_LOG(INFO) << "success process switch func " << cur_node->DebugString(); | |||||
| } | |||||
| void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | |||||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); | |||||
| if (cur_node->size() < kCNodeSwitchLayerLength) { | |||||
| MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; | |||||
| } | |||||
| auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch); | |||||
| MS_EXCEPTION_IF_NULL(branch_tuple); | |||||
| if (!branch_tuple->isa<CNode>()) { | |||||
| MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; | |||||
| } | |||||
| auto branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs(); | |||||
| // 1 return label | |||||
| auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName))}); | |||||
| // 2 recurse sub graph | |||||
| auto origin_switch_inputs = cur_node->inputs(); | |||||
| std::vector<AnfNodePtr> new_switch_inputs = {std::make_shared<ValueNode>(prim::kPrimLabelSwitch), | |||||
| origin_switch_inputs[kCNodeSwitchCond]}; | |||||
| for (size_t i = 0; i < branch_partial.size(); ++i) { | |||||
| // 2.1 branch kernel graph and args | |||||
| CNodePtr partial; | |||||
| KernelGraphPtr branch_fg; | |||||
| VectorRef call_args; | |||||
| std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||||
| // 2.2 add depend relationship | |||||
| InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); | |||||
| // 2.3 recurse sub graph | |||||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo); | |||||
| new_switch_inputs.push_back(branch_label); | |||||
| } | |||||
| new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); | |||||
| cur_node->set_inputs(new_switch_inputs); | |||||
| cur_node->set_abstract(nullptr); | |||||
| MS_LOG(INFO) << "success process switch layer " << cur_node->DebugString(); | |||||
| } | |||||
| std::tuple<CNodePtr, KernelGraphPtr, VectorRef> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) { | |||||
| if (!node.get()->isa<CNode>()) { | |||||
| MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); | |||||
| } | |||||
| // 2.1 branch kernel graph and args | |||||
| auto partial_cnode = utils::cast<CNodePtr>(node.get()); | |||||
| if (partial_cnode->size() < kCNodePartialLength) { | |||||
| MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength; | |||||
| } | |||||
| auto partial_inputs = partial_cnode->inputs(); | |||||
| auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]); | |||||
| auto call_args = GetCallArgs(partial_inputs.begin() + kCNodePartialFunc + 1, partial_inputs.end()); | |||||
| return {partial_cnode, branch_kg, call_args}; | |||||
| } | |||||
| void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, | |||||
| NotNull<AnfNodePtr> to) { | |||||
| if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && | |||||
| AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { | |||||
| return; | |||||
| } | |||||
| if (from.get() == to.get()) { | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " | |||||
| << to->DebugString(); | |||||
| // config inputs of assign node | |||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("Assign")), to, from}; | |||||
| // generate a new cnode | |||||
| auto assign_node = kg->NewCNode(inputs); | |||||
| MS_EXCEPTION_IF_NULL(assign_node); | |||||
| assign_node->set_abstract(to->abstract()); | |||||
| // append the assign at the end of from graph | |||||
| InsertDependToGraph(kg, NOT_NULL(assign_node)); | |||||
| } | |||||
| size_t AscendControlParser::SetChildGraphInput(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> node, | |||||
| size_t input_index) { | |||||
| auto output_num = AnfAlgo::GetOutputTensorNum(node); | |||||
| if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { | |||||
| return input_index + output_num; | |||||
| } | |||||
| auto &graph_inputs = kg->inputs(); | |||||
| if (input_index >= graph_inputs.size()) { | |||||
| MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size(); | |||||
| } | |||||
| auto backend_parameter = graph_inputs[input_index]; | |||||
| if (node.get()->isa<Parameter>()) { | |||||
| MS_EXCEPTION_IF_NULL(backend_parameter); | |||||
| MS_LOG(INFO) << "Reuse node [" << node->DebugString() << "], old node[" << backend_parameter->DebugString() | |||||
| << "] will be replaced."; | |||||
| kg->ReplaceNode(backend_parameter, node); | |||||
| return input_index; | |||||
| } | |||||
| InsertAssignToGraph(kg, node, NOT_NULL(backend_parameter)); | |||||
| return input_index + 1; | |||||
| } | |||||
| void AscendControlParser::SetSubGraphInput(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> from_graph_call_node, | |||||
| const VectorRef &args) {} | |||||
| } // namespace session | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,73 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H | |||||
| #define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include <tuple> | |||||
| #include "session/kernel_graph.h" | |||||
| #include "utils/base_ref.h" | |||||
| #include "utils/contract.h" | |||||
| namespace mindspore { | |||||
| namespace session { | |||||
| class AscendControlParser { | |||||
| public: | |||||
| static void LinkGraph(NotNull<KernelGraphPtr> kg); | |||||
| static void InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node); | |||||
| static void InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node, | |||||
| NotNull<AnfNodePtr> second_node); | |||||
| private: | |||||
| static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, | |||||
| const CNodePtr &last_label, const VectorRef &args, | |||||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| static void RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | |||||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| static void RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | |||||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| static void RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | |||||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| static std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &in); | |||||
| static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, | |||||
| const CNodePtr &last_label, const VectorRef &args); | |||||
| static void SetSubGraphInput(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> from_graph_call_node, | |||||
| const VectorRef &args); | |||||
| static std::tuple<CNodePtr, KernelGraphPtr, VectorRef> ParsePartial(NotNull<AnfNodePtr> node); | |||||
| static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | |||||
| static size_t SetChildGraphInput(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> node, size_t input_index); | |||||
| static constexpr size_t kCNodePrim = 0; | |||||
| static constexpr size_t kCNodeCallArg = 1; | |||||
| static constexpr size_t kCNodeSwitchCond = 1; | |||||
| static constexpr size_t kCNodeSwitchTrue = 2; | |||||
| static constexpr size_t kCNodeSwitchFalse = 3; | |||||
| static constexpr size_t kCNodeSwitchLength = 4; | |||||
| static constexpr size_t kCNodePartialLength = 2; | |||||
| static constexpr size_t kCNodePartialFunc = 1; | |||||
| static constexpr size_t kCNodeSwitchLayerCond = 1; | |||||
| static constexpr size_t kCNodeSwitchLayerBranch = 2; | |||||
| static constexpr size_t kCNodeSwitchLayerLength = 3; | |||||
| }; | |||||
| } // namespace session | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H | |||||
| @@ -160,14 +160,14 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) { | |||||
| std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) { | std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) { | ||||
| std::vector<CNodePtr> cnodes = {}; | std::vector<CNodePtr> cnodes = {}; | ||||
| size_t i = 0; | size_t i = 0; | ||||
| for (const auto anf : anf_nodes) { | |||||
| for (auto anf : anf_nodes) { | |||||
| MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString(); | MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString(); | ||||
| MS_EXCEPTION_IF_NULL(anf); | MS_EXCEPTION_IF_NULL(anf); | ||||
| if (anf->isa<CNode>()) { | if (anf->isa<CNode>()) { | ||||
| cnodes.push_back(anf->cast<CNodePtr>()); | cnodes.push_back(anf->cast<CNodePtr>()); | ||||
| } | } | ||||
| } | } | ||||
| return std::move(cnodes); | |||||
| return cnodes; | |||||
| } | } | ||||
| std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, const std::vector<CNodePtr> &cnodes) { | std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, const std::vector<CNodePtr> &cnodes) { | ||||
| @@ -189,7 +189,7 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co | |||||
| ret.push_back(std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.end())); | ret.push_back(std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.end())); | ||||
| } | } | ||||
| } | } | ||||
| return std::move(ret); | |||||
| return ret; | |||||
| } | } | ||||
| void UpdateRealInput(KernelGraph *graph) { | void UpdateRealInput(KernelGraph *graph) { | ||||
| @@ -232,7 +232,7 @@ void UpdateRealInput(KernelGraph *graph) { | |||||
| auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); | auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); | ||||
| partial_cnode->set_inputs( | partial_cnode->set_inputs( | ||||
| std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); | std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); | ||||
| return std::move(ret); | |||||
| return ret; | |||||
| }; | }; | ||||
| bind_call_partial_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); | bind_call_partial_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); | ||||
| bind_call_partial_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get()); | bind_call_partial_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get()); | ||||
| @@ -256,27 +256,28 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||||
| // split switch | // split switch | ||||
| SplitGraph(graph); | SplitGraph(graph); | ||||
| // insert goto labels and label_sets | // insert goto labels and label_sets | ||||
| LinkChildGraphs(graph.get()); | |||||
| LinkChildGraphs(NOT_NULL(graph)); | |||||
| // resource initialize | // resource initialize | ||||
| InitRuntimeResource(); | InitRuntimeResource(); | ||||
| // ir fusion | |||||
| IRFusion(graph); | |||||
| // kernel select | |||||
| SelectKernelGraphKernel(*graph); | |||||
| // convert model of predict module | |||||
| ConvertPredictModel(graph); | |||||
| // hardware optimize | |||||
| HardwareOptimizeGraphs(graph); | |||||
| // assign label | |||||
| AssignLabel(NOT_NULL(graph)); | |||||
| if (!graph->executable()) { | |||||
| return graph->graph_id(); | |||||
| } | |||||
| for (auto iter : graphs_) { | |||||
| if (iter.second == graph) { | |||||
| MS_LOG(INFO) << "Entry graph " << graph->ToString() << " graph id " << graph->graph_id(); | |||||
| final_graph_id_ = graph->graph_id(); | |||||
| } | |||||
| MS_LOG(INFO) << "CompileChildGraph " << iter.second->ToString(); | |||||
| CompileChildGraph(iter.second); | |||||
| } | |||||
| // adjust kernel | // adjust kernel | ||||
| AdjustKernel(graph); | AdjustKernel(graph); | ||||
| // root graph validate, including generating the execute order and so on | // root graph validate, including generating the execute order and so on | ||||
| RootGraphExecutorValidate(graph.get()); | RootGraphExecutorValidate(graph.get()); | ||||
| // assign stream | // assign stream | ||||
| AssignStream(graph); | AssignStream(graph); | ||||
| // assign label | |||||
| AssignLabel(NOT_NULL(graph)); | |||||
| // build kernel if node is cnode | |||||
| BuildKernel(graph); | |||||
| // alloc mem | // alloc mem | ||||
| MemoryAlloc(graph.get()); | MemoryAlloc(graph.get()); | ||||
| // task generate | // task generate | ||||
| @@ -556,7 +557,7 @@ void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_grap | |||||
| MS_LOG(INFO) << "Finish!"; | MS_LOG(INFO) << "Finish!"; | ||||
| } | } | ||||
| void AscendSession::AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const { | |||||
| void AscendSession::AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const { | |||||
| MS_LOG(INFO) << "Start!"; | MS_LOG(INFO) << "Start!"; | ||||
| device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph); | device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph); | ||||
| MS_LOG(INFO) << "Finish!"; | MS_LOG(INFO) << "Finish!"; | ||||
| @@ -1305,29 +1306,13 @@ void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived | |||||
| } | } | ||||
| void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) { | void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) { | ||||
| MS_LOG(INFO) << "Insert depend at the end of graph, the attach node is " << attch_node->DebugString(); | |||||
| auto graph = GetGraph(graph_id); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))}; | |||||
| auto return_node = graph->get_return(); | |||||
| MS_EXCEPTION_IF_NULL(return_node); | |||||
| inputs.push_back(return_node->input(1)); | |||||
| inputs.push_back(attch_node); | |||||
| auto depend_node = graph->NewCNode(inputs); | |||||
| return_node->set_input(1, depend_node); | |||||
| AscendControlParser::InsertDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(attch_node)); | |||||
| } | } | ||||
| void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, | void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, | ||||
| const AnfNodePtr &second_node) { | const AnfNodePtr &second_node) { | ||||
| MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString() | |||||
| << ", the second node is " << second_node->DebugString(); | |||||
| auto graph = GetGraph(graph_id); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("ControlDepend"))}; | |||||
| inputs.push_back(first_node); | |||||
| inputs.push_back(second_node); | |||||
| auto control_depend = graph->NewCNode(inputs); | |||||
| InsertDependToGraph(graph_id, control_depend); | |||||
| AscendControlParser::InsertControlDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(first_node), | |||||
| NOT_NULL(second_node)); | |||||
| } | } | ||||
| size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) { | size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) { | ||||
| @@ -1482,5 +1467,8 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { | |||||
| SplitGraph(child_graph); | SplitGraph(child_graph); | ||||
| } | } | ||||
| } | } | ||||
| void AscendSession::LinkChildGraphs(NotNull<KernelGraphPtr> graph) { AscendControlParser::LinkGraph(graph); } | |||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -28,6 +28,7 @@ | |||||
| #include "session/kernel_graph.h" | #include "session/kernel_graph.h" | ||||
| #include "kernel/kernel.h" | #include "kernel/kernel.h" | ||||
| #include "session/session_factory.h" | #include "session/session_factory.h" | ||||
| #include "session/ascend_control_parser.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace session { | namespace session { | ||||
| @@ -74,7 +75,7 @@ class AscendSession : public SessionBasic { | |||||
| void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void AssignLabel(NotNull<const KernelGraphPtr &> kernel_graph) const; | |||||
| void AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const; | |||||
| void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void MemoryAlloc(KernelGraph *kernel_graph) const; | void MemoryAlloc(KernelGraph *kernel_graph) const; | ||||
| void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const; | void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const; | ||||
| @@ -96,7 +97,8 @@ class AscendSession : public SessionBasic { | |||||
| void SetFinalGraphOutput(const VectorRef &vec_output); | void SetFinalGraphOutput(const VectorRef &vec_output); | ||||
| void SplitGraph(const KernelGraphPtr &graph); | void SplitGraph(const KernelGraphPtr &graph); | ||||
| void LinkChildGraphs(KernelGraph *graph) {} | |||||
| void LinkChildGraphs(NotNull<KernelGraphPtr> graph); | |||||
| void IRFusion(const KernelGraphPtr &graph) {} | void IRFusion(const KernelGraphPtr &graph) {} | ||||
| void SelectKernelGraphKernel(const KernelGraph &graph) {} | void SelectKernelGraphKernel(const KernelGraph &graph) {} | ||||
| void ConvertPredictModel(const KernelGraphPtr graph) {} | void ConvertPredictModel(const KernelGraphPtr graph) {} | ||||
| @@ -28,6 +28,7 @@ | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "utils/graph_utils.h" | #include "utils/graph_utils.h" | ||||
| #include "utils/contract.h" | |||||
| #include "device/kernel_info.h" | #include "device/kernel_info.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -108,6 +109,7 @@ class KernelGraph : public FuncGraph { | |||||
| std::vector<std::shared_ptr<KernelGraph>> child_graph_order() const { return child_graph_order_; } | std::vector<std::shared_ptr<KernelGraph>> child_graph_order() const { return child_graph_order_; } | ||||
| // checkout whether current graph is leaf graph | // checkout whether current graph is leaf graph | ||||
| bool IsLeafGraph() const; | bool IsLeafGraph() const; | ||||
| // set input_tensors pointer of control parameter | // set input_tensors pointer of control parameter | ||||
| void set_input_ctrl_tensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &input_tensors_ptr) { | void set_input_ctrl_tensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &input_tensors_ptr) { | ||||
| input_ctrl_tensors_ = input_tensors_ptr; | input_ctrl_tensors_ = input_tensors_ptr; | ||||
| @@ -126,6 +128,9 @@ class KernelGraph : public FuncGraph { | |||||
| // used to dump ir | // used to dump ir | ||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; } | |||||
| CNodePtr get_start_label() { return start_label_; } | |||||
| private: | private: | ||||
| // remove value node from graph | // remove value node from graph | ||||
| bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); | bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); | ||||
| @@ -168,12 +173,16 @@ class KernelGraph : public FuncGraph { | |||||
| std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> node_to_child_graphs_; | std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> node_to_child_graphs_; | ||||
| // child graph execute order in root graph | // child graph execute order in root graph | ||||
| std::vector<std::shared_ptr<KernelGraph>> child_graph_order_; | std::vector<std::shared_ptr<KernelGraph>> child_graph_order_; | ||||
| // input_tensors of control parameter | // input_tensors of control parameter | ||||
| std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors_; | std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors_; | ||||
| // parameter graph | // parameter graph | ||||
| std::shared_ptr<KernelGraph> parent_graph_; | std::shared_ptr<KernelGraph> parent_graph_; | ||||
| // record real parameters,inputs_ is the formal parameters | // record real parameters,inputs_ is the formal parameters | ||||
| std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_; | std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_; | ||||
| CNodePtr start_label_; | |||||
| }; | }; | ||||
| } // namespace session | } // namespace session | ||||
| using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | ||||
| @@ -61,6 +61,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| "../../../mindspore/ccsrc/transform/*.cc" | "../../../mindspore/ccsrc/transform/*.cc" | ||||
| "../../../mindspore/ccsrc/session/anf_runtime_algorithm.cc" | "../../../mindspore/ccsrc/session/anf_runtime_algorithm.cc" | ||||
| "../../../mindspore/ccsrc/session/ascend_session.cc" | "../../../mindspore/ccsrc/session/ascend_session.cc" | ||||
| "../../../mindspore/ccsrc/session/ascend_control_parser.cc" | |||||
| "../../../mindspore/ccsrc/session/kernel_graph.cc" | "../../../mindspore/ccsrc/session/kernel_graph.cc" | ||||
| "../../../mindspore/ccsrc/session/session_basic.cc" | "../../../mindspore/ccsrc/session/session_basic.cc" | ||||
| "../../../mindspore/ccsrc/session/session_factory.cc" | "../../../mindspore/ccsrc/session/session_factory.cc" | ||||
| @@ -22,7 +22,9 @@ namespace mindspore { | |||||
| namespace device { | namespace device { | ||||
| namespace ascend { | namespace ascend { | ||||
| void AscendLabelAssign::AssignLabel(NotNull<const std::shared_ptr<session::KernelGraph> &>) {} | |||||
| void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) {} | |||||
| uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> graph) { return 1; } | |||||
| uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) { return 1; } | |||||
| void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; } | void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; } | ||||
| @@ -39,9 +41,7 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve | |||||
| } // namespace ascend | } // namespace ascend | ||||
| void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; } | void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; } | ||||
| void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; } | void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return; } | ||||
| bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { | |||||
| return true; | |||||
| } | |||||
| bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { return true; } | |||||
| bool KernelAdjust::NeedInsertSwitch() { return true; } | bool KernelAdjust::NeedInsertSwitch() { return true; } | ||||
| void KernelAdjust::Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr) { return; } | void KernelAdjust::Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr) { return; } | ||||
| } // namespace device | } // namespace device | ||||