Browse Source

!15356 Support mem reuse in control flow and multi-call subgraphs

From: @liangzelang
Reviewed-by: @zhoufeng54,@kisnwang
Signed-off-by: @kisnwang
pull/15356/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
78469f6083
6 changed files with 235 additions and 33 deletions
  1. +54
    -32
      mindspore/ccsrc/backend/optimizer/somas/somas.cc
  2. +1
    -1
      mindspore/ccsrc/backend/optimizer/somas/somas.h
  3. +28
    -0
      mindspore/ccsrc/backend/session/ascend_auto_monad.cc
  4. +130
    -0
      mindspore/ccsrc/backend/session/ascend_session.cc
  5. +18
    -0
      mindspore/ccsrc/backend/session/kernel_graph.h
  6. +4
    -0
      mindspore/ccsrc/utils/utils.h

+ 54
- 32
mindspore/ccsrc/backend/optimizer/somas/somas.cc View File

@@ -400,11 +400,17 @@ bool Somas::InitSomasTensors(const session::KernelGraph *graph) {

void Somas::InitSomasStreamAndNode(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<CNodePtr> kernel_cnodes;
streams_list_ = {};
nodes_list_ = {};
size_t node_index = 0;
auto kernel_cnodes = graph->execution_order();
for (const auto &kernel : kernel_cnodes) {
if (graph->subgraph_multi_call()) {
kernel_cnodes = graph->mem_reuse_exec_order();
} else {
kernel_cnodes = graph->execution_order();
}
for (size_t i = 0; i < kernel_cnodes.size(); i++) {
auto kernel = kernel_cnodes[i];
SomasStreamPtr stream;
auto stream_id = AnfAlgo::GetStreamId(kernel);
auto it = find_if(streams_list_.begin(), streams_list_.end(),
@@ -427,7 +433,8 @@ void Somas::InitSomasStreamAndNode(const session::KernelGraph *graph) {
nodes_list_.push_back(node);
stream->nodes_.push_back(node);
auto key = kernel.get();
nodes_map_[key] = node;
auto &nodes = nodes_map_[key];
nodes.push_back(node);
node_index++;
}
}
@@ -438,7 +445,8 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph
size_t tensor_index = 0;
auto kernel_cnodes = graph->execution_order();
for (const auto &kernel : kernel_cnodes) {
auto node = nodes_map_[kernel.get()];
auto nodes = nodes_map_[kernel.get()];
auto node = nodes[0];
MS_EXCEPTION_IF_NULL(node);
auto stream = node->GetStream();
MS_EXCEPTION_IF_NULL(stream);
@@ -454,7 +462,7 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph
// Set all output tensor lifelong to true.
auto tensor = std::make_shared<SomasTensor>(output_tensor_index, node, stream, size, kLifeLongNone);
tensor->lifetime_.start_ = node->GetId();
tensor->lifetime_.end_ = node->GetId();
tensor->lifetime_.end_ = (nodes.size() > 1) ? nodes.back()->GetId() : node->GetId();
tensor->type_ = kOutputOnly;
if (AnfAlgo::OutputAddrExist(kernel, index)) {
tensor->aligned_size_ = 0;
@@ -463,8 +471,10 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph
tensors_list_.push_back(tensor);
tensors_map_[output_tensor_index] = tensor;
stream->tensors_.push_back(tensor);
node->tensors_.insert(tensor);
node->output_tensors_.push_back(tensor);
std::for_each(nodes.begin(), nodes.end(), [tensor](auto &node) {
node->tensors_.insert(tensor);
node->output_tensors_.push_back(tensor);
});
index++;
}

@@ -477,15 +487,17 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph *graph
SomasTensorPtr tensor = std::make_shared<SomasTensor>(workspace_tensor_index, node, stream, size, kLifeLongNone);
tensor->type_ = kWorkspace;
tensor->lifetime_.start_ = node->GetId();
tensor->lifetime_.end_ = node->GetId();
tensor->lifetime_.end_ = (nodes.size() > 1) ? nodes.back()->GetId() : node->GetId();
if (AnfAlgo::WorkspaceAddrExist(kernel, index)) {
tensor->aligned_size_ = 0;
}
tensors_list_.push_back(tensor);
tensors_map_[workspace_tensor_index] = tensor;
stream->tensors_.push_back(tensor);
node->tensors_.insert(tensor);
node->workspace_tensors_.push_back(tensor);
std::for_each(nodes.begin(), nodes.end(), [tensor](auto &node) {
node->tensors_.insert(tensor);
node->workspace_tensors_.push_back(tensor);
});
index++;
}
}
@@ -505,7 +517,8 @@ void Somas::InitSomasInputTensors(const session::KernelGraph *graph) {
}
}
void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) {
auto node = nodes_map_[kernel.get()];
auto nodes = nodes_map_[kernel.get()];
auto node = nodes[0];
MS_EXCEPTION_IF_NULL(node);
auto stream = node->GetStream();
MS_EXCEPTION_IF_NULL(stream);
@@ -543,7 +556,7 @@ void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) {
MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input " << i << " ["
<< prenode_index.first->fullname_with_scope() << "] is not init.";
}
auto pre_somas_node = iter->second;
auto pre_somas_node = iter->second.at(0);
if (prenode_index.second > pre_somas_node->output_tensors_.size()) {
MS_LOG(EXCEPTION) << "Output index " << prenode_index.second << " exceed input node ["
<< prenode_index.first->fullname_with_scope() << "]'s outputs size "
@@ -551,15 +564,18 @@ void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) {
}
auto input_somas_tensor = pre_somas_node->output_tensors_[prenode_index.second];
MS_EXCEPTION_IF_NULL(input_somas_tensor);
node->input_tensors_.push_back(input_somas_tensor);
std::for_each(nodes.begin(), nodes.end(),
[input_somas_tensor](auto &node) { node->input_tensors_.push_back(input_somas_tensor); });
real_input_index++;
if (input_somas_tensor->type_ == kOutputOnly) {
input_somas_tensor->type_ = kCommon;
}
input_somas_tensor->destinations_.insert(node);
input_somas_tensor->destinationStreams_.insert(stream);
if (input_somas_tensor->lifetime_.end_ < node->GetId()) {
input_somas_tensor->lifetime_.end_ = node->GetId();
for (auto &repeat_node : nodes) {
input_somas_tensor->destinations_.insert(repeat_node);
if (input_somas_tensor->lifetime_.end_ < repeat_node->GetId()) {
input_somas_tensor->lifetime_.end_ = repeat_node->GetId();
}
}

if (node != pre_somas_node) {
@@ -574,7 +590,7 @@ void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) {
}

void Somas::InitAtomicCleanInputs(bool enable_fusion_clear, const CNodePtr &kernel) {
auto node = nodes_map_[kernel.get()];
auto node = nodes_map_[kernel.get()].at(0);
MS_EXCEPTION_IF_NULL(node);
auto stream = node->GetStream();
MS_EXCEPTION_IF_NULL(stream);
@@ -588,7 +604,7 @@ void Somas::InitAtomicCleanInputs(bool enable_fusion_clear, const CNodePtr &kern
MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input ["
<< pre_node->fullname_with_scope() << "] is not init.";
}
auto pre_somas_node = iter->second;
auto pre_somas_node = iter->second.at(0);
// set clean output tensors
if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
@@ -698,7 +714,8 @@ void Somas::GetNextOutputProcess(const session::KernelGraph *graph) {
}
auto iter = nodes_map_.find(kernel.get());
if (iter != nodes_map_.end()) {
auto getnext_output_tensors = iter->second->output_tensors_;
auto &node = iter->second.at(0);
auto getnext_output_tensors = node->output_tensors_;
for (auto &tensor : getnext_output_tensors) {
total_size += tensor->GetAlignedSize();
tensor->lifelong_value_ = kLifeLongGraphAll;
@@ -720,7 +737,8 @@ void Somas::IndependentNodeOutputProcess(const session::KernelGraph *graph) {
}
auto iter = nodes_map_.find(kernel.get());
if (iter != nodes_map_.end()) {
auto semi_reuse_output_tensors = iter->second->output_tensors_;
auto &node = iter->second.at(0);
auto semi_reuse_output_tensors = node->output_tensors_;
for (auto &tensor : semi_reuse_output_tensors) {
total_size += tensor->GetAlignedSize();
tensor->lifelong_value_ = kLifeLongGraphAll;
@@ -749,9 +767,9 @@ void Somas::SummaryInputProcess(const session::KernelGraph *graph) {
size_t index = IntToSize(node_item.second.second);
auto iter = nodes_map_.find(node.get());
if (iter != nodes_map_.end()) {
auto input_node = iter->second;
auto input_node = iter->second.at(0);
if (index < input_node->output_tensors_.size()) {
auto tensor = iter->second->output_tensors_[index];
auto tensor = input_node->output_tensors_[index];
tensor->lifelong_value_ = kLifeLongGraphAll;
tensor->type_ = kSummaryInput;
total_summary_size += tensor->GetAlignedSize();
@@ -789,7 +807,8 @@ void Somas::RefNodeProcess(const session::KernelGraph *graph) {
if (graph->IsInRefOutputMap(out_pair)) {
auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
MS_EXCEPTION_IF_NULL(origin_pair.first);
auto output_tensor = nodes_map_[kernel.get()]->output_tensors_[out_index];
auto &node = nodes_map_[kernel.get()].at(0);
auto output_tensor = node->output_tensors_[out_index];
MS_EXCEPTION_IF_NULL(output_tensor);
output_tensor->type_ = kRefNodeOutput;
total_output_size += size;
@@ -797,7 +816,8 @@ void Somas::RefNodeProcess(const session::KernelGraph *graph) {
if (AnfAlgo::IsRealCNodeKernel(origin_pair.first)) {
auto ori_node = origin_pair.first->cast<CNodePtr>();
auto ori_index = origin_pair.second;
auto input_tensor = nodes_map_[ori_node.get()]->output_tensors_[ori_index];
auto &repeat_node = nodes_map_[ori_node.get()].at(0);
auto input_tensor = repeat_node->output_tensors_[ori_index];
MS_EXCEPTION_IF_NULL(input_tensor);
input_tensor->type_ = kRefNodeInput;
total_input_size += input_tensor->aligned_size_;
@@ -821,7 +841,7 @@ void Somas::NonTaskSplitProcess(const session::KernelGraph *graph) {
auto op_name = AnfAlgo::GetCNodeName(kernel);
if ((op_name == kSplitOpName || op_name == kSplitVOpName) && AnfAlgo::HasNodeAttr(kAttrNonTask, kernel)) {
std::vector<size_t> refnode_input_output;
auto node = nodes_map_[kernel.get()];
auto node = nodes_map_[kernel.get()].at(0);
if (node->input_tensors_.size() == 0) {
MS_LOG(EXCEPTION) << op_name << " has no input tensor, can not do split non_task process.";
}
@@ -852,7 +872,7 @@ void Somas::UnReuseNodeProcess(const session::KernelGraph *graph) {
if (iter != full_name_list.end()) {
MS_LOG(INFO) << "Set UnReuse Node in somas, Node:" << full_name;
auto key = kernel.get();
auto somas_node = nodes_map_[key];
auto somas_node = nodes_map_[key].at(0);
// input
auto inputs = somas_node->input_tensors_;
for (auto &input : inputs) {
@@ -1749,11 +1769,12 @@ uint8_t *Somas::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const {
auto iter = nodes_map_.find(key);
uint8_t *ptr = nullptr;
if (iter != nodes_map_.end()) {
if (index >= iter->second->output_tensors_.size()) {
auto &node = iter->second.at(0);
if (index >= node->output_tensors_.size()) {
MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:["
<< iter->second->output_tensors_.size() << "]";
<< node->output_tensors_.size() << "]";
}
auto output_tensor = iter->second->output_tensors_[index];
auto output_tensor = node->output_tensors_[index];
ptr = mem_base_addr_ + output_tensor->offset_;
} else {
MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in nodes_map";
@@ -1766,11 +1787,12 @@ uint8_t *Somas::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const
auto iter = nodes_map_.find(key);
uint8_t *ptr = nullptr;
if (iter != nodes_map_.end()) {
if (index >= iter->second->workspace_tensors_.size()) {
auto &node = iter->second.at(0);
if (index >= node->workspace_tensors_.size()) {
MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:["
<< iter->second->workspace_tensors_.size() << "]";
<< node->workspace_tensors_.size() << "]";
}
auto workspace_tensor = iter->second->workspace_tensors_[index];
auto workspace_tensor = node->workspace_tensors_[index];
ptr = mem_base_addr_ + workspace_tensor->offset_;
}
return ptr;


+ 1
- 1
mindspore/ccsrc/backend/optimizer/somas/somas.h View File

@@ -63,7 +63,7 @@ class Somas {
std::string hash_id_;
// Maps
std::unordered_map<size_t, SomasTensorPtr> tensors_map_;
std::map<void *, SomasNodePtr> nodes_map_;
std::map<void *, std::vector<SomasNodePtr>> nodes_map_;
std::map<void *, vector<SomasParameterPtr>> parameters_map_;

// Vectors


+ 28
- 0
mindspore/ccsrc/backend/session/ascend_auto_monad.cc View File

@@ -296,6 +296,15 @@ class AscendAutoMonadContext : public BaseContext {
// Set flag to indicate whether has already created an stack or not.
void SetInitedStack(bool flag) { inited_stack_ = flag; }

// The graphs has recursion.
bool HasRecursiveCall() const { return has_recursive_call_; }
// The graphs has subgraph multi-call.
bool HasSubgraphMultiCall() const { return has_subgraph_multicall_; }
// set flag to indicate whether has recursion.
void SetRecursiveCall(bool flag) { has_recursive_call_ = flag; }
// set flag to indicate whether has multi-call.
void SetSubGraphMultiCall(bool flag) { has_subgraph_multicall_ = flag; }

// Map kernel_graph to its call info.
OrderedMap<KernelGraphPtr, CallInfo> call_info_map;

@@ -311,6 +320,10 @@ class AscendAutoMonadContext : public BaseContext {

// Create an stack for multi-call and non-tail recursion.
bool inited_stack_ = false;
// The graphs has recursion or not.
bool has_recursive_call_ = false;
// The graphs has subgraph multi-call or not.
bool has_subgraph_multicall_ = false;
};

//
@@ -643,6 +656,11 @@ class AscendAutoMonadConverter {
}
// Handle recursive call.
kernel_graph_->SetExecOrderByDefault();
if (call_info_.recursive) {
const auto &nodes = kernel_graph_->execution_order();
AnfAlgo::SetNodeAttr(kAttrRecursiveStart, prim::kValueOne, *nodes.begin());
AnfAlgo::SetNodeAttr(kAttrRecursiveEnd, prim::kValueOne, *nodes.rbegin());
}
for (auto &call_site : call_info_.call_sites) {
if (need_stackops_ && call_site.recursive) {
MS_LOG(INFO) << "graph:" << kernel_graph_->ToString() << ", loop call_site:" << call_site.cnode->DebugString();
@@ -661,6 +679,7 @@ class AscendAutoMonadConverter {
auto stack_destroy = StackDestroy(top_graph);
AnfAlgo::KeepOrder(top_graph, *exec_order.rbegin(), stack_destroy);
top_graph->SetExecOrderByDefault();
context_.SetRecursiveCall(true);
context_.SetInitedStack(true);
}
}
@@ -812,6 +831,9 @@ class AscendAutoMonadConverter {
// Create LabelGoto or LabelSwitch node.
auto label_goto_switch = MakeLabelGotoSwitch(cnode, graphes, labels);
call_site->conversion_cnode = label_goto_switch;
if (call_site->recursive) {
AnfAlgo::SetNodeAttr(kAttrRecursive, prim::kValueOne, label_goto_switch);
}

// Setup return label and output if required.
if (call_site->return_label != kNoLabel) {
@@ -931,7 +953,11 @@ class AscendAutoMonadConverter {
MS_EXCEPTION_IF_NULL(label_param);
auto return_switch = LabelSwitch(label_param, return_labels);
AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_switch);
if (!call_info_.recursive) {
AnfAlgo::SetNodeAttr(kAttrMultiCallEnd, prim::kValueOne, return_switch);
}
kernel_graph_->set_end_goto(return_switch);
context_.SetSubGraphMultiCall(true);
}

// Assign graph output to the output parameter.
@@ -1650,6 +1676,8 @@ void AscendAutoMonad::Run() {
CallInfoFinder::Run(&context);
AscendAutoMonadConverter::Run(&context);
kernel_graph_->set_label_num(context.CurrentLabel() + 1);
kernel_graph_->set_recursive_call(context.HasRecursiveCall());
kernel_graph_->set_subgraph_multi_call(context.HasSubgraphMultiCall());
MS_LOG(DEBUG) << "Ascend auto-monad finish.";
DumpGraphForDebug(kernel_graph_);
}


+ 130
- 0
mindspore/ccsrc/backend/session/ascend_session.cc View File

@@ -1034,9 +1034,139 @@ void AscendSession::BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kerne
MS_LOG(INFO) << "Finish!";
}

// Return the LabelSet node that directly follows kernel_nodes[index].
// A LabelGoto/LabelSwitch must be immediately followed by a LabelSet;
// raises (MS_LOG(EXCEPTION)) otherwise.
static CNodePtr GetNextLabelSet(const std::vector<CNodePtr> &kernel_nodes, uint32_t index) {
  // Use size_t arithmetic for the bounds check: the previous
  // `index >= node_sizes - 1` form (with size() narrowed to uint32_t)
  // underflowed for an empty vector, skipped the guard, and then read
  // kernel_nodes[index + 1] out of bounds.
  if (static_cast<size_t>(index) + 1 >= kernel_nodes.size()) {
    if (index < kernel_nodes.size()) {
      MS_LOG(EXCEPTION) << "there is no node after this node:" << kernel_nodes[index]->DebugString();
    }
    // index itself is out of range; logging DebugString would also be OOB.
    MS_LOG(EXCEPTION) << "node index " << index << " is out of range, total node size: " << kernel_nodes.size();
  }
  auto kernel = kernel_nodes[index + 1];
  if (AnfAlgo::GetCNodeName(kernel) != kLabelSetOpName) {
    MS_LOG(EXCEPTION) << "the node is not labelset follow labelgoto/labelswitch, node: "
                      << kernel_nodes[index]->DebugString();
  }
  return kernel;
}

// Split the recursive-call region of kernel_cnodes (starting at *index) into a
// "front" sequence (nodes collected before the recursive jump is seen) and a
// "back" sequence (nodes after it), so the caller can lay the two parts out
// separately in the memory-reuse order.
//
// back_label: label index of the current recursion entry; a kAttrRecursive
//   node that jumps to this label closes the current level.
// index: in/out cursor into kernel_cnodes; on return it is left at the node
//   tagged kAttrRecursiveEnd (if one was found).
// back: receives this level's "back" nodes; nested levels' back nodes are
//   buffered in back_temp and appended after them when this level ends.
// Returns the "front" nodes of this level plus nested levels, flattened.
static std::vector<CNodePtr> HandleRecursiveCall(const std::vector<CNodePtr> &kernel_cnodes, const uint32_t &back_label,
uint32_t *index, std::vector<CNodePtr> *back) {
MS_EXCEPTION_IF_NULL(index);
MS_EXCEPTION_IF_NULL(back);
std::vector<CNodePtr> front;
// Back nodes produced by nested recursive calls; only flushed into *back when
// this level terminates, so this level's back nodes come first.
std::vector<CNodePtr> back_temp;
bool back_flag = false;
for (uint32_t i = *index; i < kernel_cnodes.size(); i++) {
// Until the recursive jump is seen, nodes belong to the front part.
if (!back_flag) {
front.emplace_back(kernel_cnodes[i]);
} else {
back->emplace_back(kernel_cnodes[i]);
}
// End of the recursive region: report the stop position back to the caller.
if (AnfAlgo::HasNodeAttr(kAttrRecursiveEnd, kernel_cnodes[i])) {
*index = i;
back->insert(back->end(), back_temp.begin(), back_temp.end());
return front;
}
if (AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) {
back_flag = true;
if (AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], back_label)) {
// Jump targets our own entry label: subsequent nodes are this level's back.
continue;
} else {
// Jump into a deeper recursion level: recurse starting at the next node.
// NOTE(review): &(++i) advances the loop index and passes it by pointer;
// the callee leaves i positioned at that level's kAttrRecursiveEnd node.
auto temp = HandleRecursiveCall(kernel_cnodes, back_label, &(++i), &back_temp);
front.insert(front.end(), temp.begin(), temp.end());
continue;
}
}
}
// No kAttrRecursiveEnd encountered before the end of the node list.
return front;
}

// Rewrite the graph's mem-reuse execution order when it contains recursion:
// every region opened by a kAttrRecursiveStart node is split into its "front"
// and "back" sequences by HandleRecursiveCall, which are then appended one
// after the other; all other nodes are kept in place.
static void UnfoldRecursiveExecOrder(KernelGraph *kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (!kernel_graph->recursive_call()) {
    return;
  }
  const auto exec_nodes = kernel_graph->mem_reuse_exec_order();
  std::vector<CNodePtr> unfolded_order;
  unfolded_order.reserve(exec_nodes.size());
  for (uint32_t pos = 0; pos < exec_nodes.size(); pos++) {
    if (!AnfAlgo::HasNodeAttr(kAttrRecursiveStart, exec_nodes[pos])) {
      unfolded_order.emplace_back(exec_nodes[pos]);
      continue;
    }
    // Recursion entry: unfold the whole recursive region in one call.
    // HandleRecursiveCall advances pos to the region's end node.
    auto entry_label = AnfAlgo::GetNodeAttr<uint32_t>(exec_nodes[pos], kAttrLabelIndex);
    std::vector<CNodePtr> tail_part;
    auto head_part = HandleRecursiveCall(exec_nodes, entry_label, &pos, &tail_part);
    unfolded_order.insert(unfolded_order.end(), head_part.begin(), head_part.end());
    unfolded_order.insert(unfolded_order.end(), tail_part.begin(), tail_part.end());
  }
  kernel_graph->set_mem_reuse_exec_order(unfolded_order);
}

// Append to *mem_reuse_order the kernels of a called subgraph: walk the
// graph's execution order from `index` and stop (inclusively) at the first
// node carrying back_node's label index.
static void GetSubGraphExecOrder(const KernelGraph *kernel_graph, uint32_t index, const CNodePtr &back_node,
                                 std::vector<CNodePtr> *mem_reuse_order) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(mem_reuse_order);
  const auto stop_label = AnfAlgo::GetNodeAttr<uint32_t>(back_node, kAttrLabelIndex);
  const auto exec_nodes = kernel_graph->execution_order();
  for (auto pos = index; pos < exec_nodes.size(); pos++) {
    const auto &current = exec_nodes[pos];
    mem_reuse_order->emplace_back(current);
    if (AnfAlgo::IsLabelIndexInNode(current, stop_label)) {
      return;
    }
  }
}

// Build the memory-reuse execution order for graphs with subgraph multi-call:
// whenever a non-recursive, non-return LabelSwitch/LabelGoto jumps back to an
// already-recorded LabelSet, the called subgraph's kernels are appended again
// so memory reuse sees every activation of that subgraph. Finishes by
// unfolding recursive regions via UnfoldRecursiveExecOrder.
void InitMemReuseExecOrder(KernelGraph *kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (!kernel_graph->subgraph_multi_call()) {
    return;
  }
  // label id -> index (in execution order) of the LabelSet defining it.
  std::unordered_map<uint32_t, uint32_t> label_id_index_map;
  auto kernel_cnodes = kernel_graph->execution_order();
  std::vector<CNodePtr> mem_reuse_order;
  mem_reuse_order.reserve(kernel_cnodes.size());
  for (size_t i = 0; i < kernel_cnodes.size(); i++) {
    mem_reuse_order.emplace_back(kernel_cnodes[i]);
    const auto &node = kernel_cnodes[i];
    // Shared backward-jump handling for LabelSwitch/LabelGoto: if label_id was
    // already defined, re-append the subgraph from its LabelSet up to the
    // LabelSet that follows the current jump node. Single map lookup (the
    // original did find() + operator[]).
    auto append_called_subgraph = [&](uint32_t label_id) {
      auto iter = label_id_index_map.find(label_id);
      if (iter == label_id_index_map.end()) {
        return;
      }
      auto back_node = GetNextLabelSet(kernel_cnodes, static_cast<uint32_t>(i));
      GetSubGraphExecOrder(kernel_graph, iter->second, back_node, &mem_reuse_order);
    };
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch) && !AnfAlgo::HasNodeAttr(kAttrRecursive, node) &&
        !AnfAlgo::HasNodeAttr(kAttrReturn, node)) {
      auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList);
      for (auto label_id : label_list) {
        append_called_subgraph(label_id);
      }
      continue;
    }
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto) && !AnfAlgo::HasNodeAttr(kAttrRecursive, node) &&
        !AnfAlgo::HasNodeAttr(kAttrReturn, node)) {
      append_called_subgraph(AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex));
      continue;
    }
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet) && !AnfAlgo::HasNodeAttr(kAttrRecursive, node)) {
      auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
      if (label_id_index_map.find(label_id) != label_id_index_map.end()) {
        MS_LOG(EXCEPTION) << "Two labelsets with same label id.";
      }
      label_id_index_map[label_id] = static_cast<uint32_t>(i);
      continue;
    }
  }
  kernel_graph->set_mem_reuse_exec_order(mem_reuse_order);
  UnfoldRecursiveExecOrder(kernel_graph);
}

void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
MS_LOG(INFO) << "Start!";
MS_EXCEPTION_IF_NULL(kernel_graph);
InitMemReuseExecOrder(kernel_graph);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->AssignMemory(kernel_graph);


+ 18
- 0
mindspore/ccsrc/backend/session/kernel_graph.h View File

@@ -41,6 +41,7 @@ class KernelGraph : public FuncGraph {
KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), current_epoch_(0), is_dynamic_shape_(false) {
inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
execution_order_ = {};
mem_reuse_exec_order_ = {};
executable_ = true;
summary_node_exist_ = false;
stream_distinction_label_ = kInvalidDistincLabel;
@@ -51,6 +52,7 @@ class KernelGraph : public FuncGraph {
inputs_ = graph.inputs_;
child_graph_result_ = graph.child_graph_result_;
execution_order_ = graph.execution_order_;
mem_reuse_exec_order_ = graph.mem_reuse_exec_order_;
graph_id_ = graph.graph_id_;
stream_distinction_label_ = graph.stream_distinction_label_;
front_backend_anf_map_ = graph.front_backend_anf_map_;
@@ -112,6 +114,9 @@ class KernelGraph : public FuncGraph {
void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; }
void set_execution_order(std::vector<CNodePtr> &&order) { execution_order_ = std::move(order); }
const std::vector<CNodePtr> &execution_order() const { return execution_order_; }
// Set new exec_order for mem_reuse
void set_mem_reuse_exec_order(const std::vector<CNodePtr> &order) { mem_reuse_exec_order_ = order; }
const std::vector<CNodePtr> &mem_reuse_exec_order() const { return mem_reuse_exec_order_; }
void SetExecOrderByDefault();
uint32_t graph_id() const { return graph_id_; }
void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
@@ -278,6 +283,14 @@ class KernelGraph : public FuncGraph {

uint32_t label_num() const { return label_num_; }
void set_label_num(uint32_t num) { label_num_ = num; }
// The graphs has recursion.
bool recursive_call() const { return has_recursive_call_; }
// The graphs has subgraph multi-call.
bool subgraph_multi_call() const { return has_subgraph_multicall_; }
// set flag to indicate whether has recursion.
void set_recursive_call(bool flag) { has_recursive_call_ = flag; }
// set flag to indicate whether has multi-call.
void set_subgraph_multi_call(bool flag) { has_subgraph_multicall_ = flag; }

private:
// remove value node form graph
@@ -307,6 +320,7 @@ class KernelGraph : public FuncGraph {
std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
std::vector<AnfNodePtr> child_graph_result_;
std::vector<CNodePtr> execution_order_;
std::vector<CNodePtr> mem_reuse_exec_order_;
// extra params and tensors for control flow
std::vector<std::pair<ParameterPtr, tensor::TensorPtr>> extra_param_tensor_;
uint32_t graph_id_;
@@ -360,6 +374,10 @@ class KernelGraph : public FuncGraph {
bool has_optimizer_{false};
bool is_dynamic_shape_{false};

// Indicate the graphs has recursion or multi-call or not as the root graph.
bool has_recursive_call_{false};
bool has_subgraph_multicall_{false};

// Number of labels. This is also the 'batch_num' for DavinciModel,
// It should be 1 if no labels used for control flow.
uint32_t label_num_ = 1;


+ 4
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -421,6 +421,10 @@ constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first";
constexpr auto kAttrIgnoreSideEffect = "ignore_side_effect";
constexpr auto kAttrSwitchLayer = "switch_layer";
constexpr auto kAttrReturn = "return";
constexpr auto kAttrRecursiveStart = "recursive_start";
constexpr auto kAttrRecursiveEnd = "recursive_end";
constexpr auto kAttrRecursive = "recursive";
constexpr auto kAttrMultiCallEnd = "multicall_end";

// attr value
constexpr auto kValueTargetSwitch = "target_switch";


Loading…
Cancel
Save