Browse Source

!10236 Support the if-by-if case by replacing label_goto with label_switch

From: @liangzelang
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
a98ae4129a
9 changed files with 178 additions and 7 deletions
  1. +10
    -0
      mindspore/ccsrc/backend/optimizer/common/node_pass.cc
  2. +8
    -2
      mindspore/ccsrc/backend/session/ascend_control_parser.cc
  3. +1
    -0
      mindspore/ccsrc/backend/session/ascend_control_parser.h
  4. +111
    -1
      mindspore/ccsrc/backend/session/ascend_session.cc
  5. +4
    -0
      mindspore/ccsrc/backend/session/ascend_session.h
  6. +27
    -0
      mindspore/ccsrc/backend/session/kernel_graph.cc
  7. +7
    -1
      mindspore/ccsrc/backend/session/kernel_graph.h
  8. +9
    -3
      mindspore/ccsrc/backend/session/session_basic.cc
  9. +1
    -0
      mindspore/ccsrc/backend/session/session_basic.h

+ 10
- 0
mindspore/ccsrc/backend/optimizer/common/node_pass.cc View File

@@ -45,6 +45,16 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
bool change = (new_node != nullptr);
if (new_node != nullptr && new_node != node) {
(void)manager->Replace(node, new_node);
// if replaced node is end_goto, refresh relative params in kernel graph
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
if (kernel_graph != nullptr && node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto end_label = kernel_graph->get_end_goto();
if (cnode == end_label && AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) {
kernel_graph->set_end_goto(new_node->cast<CNodePtr>());
}
}
(void)seen_node.erase(node);
} else if (new_node == nullptr) {
new_node = node;


+ 8
- 2
mindspore/ccsrc/backend/session/ascend_control_parser.cc View File

@@ -739,8 +739,14 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> fr
<< from_graph->ToString();
}
// insert assign between jump_node -1 and jump_node
if (jump_node_iter != from_graph_exe_order.begin()) {
InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node));
while (jump_node_iter != from_graph_exe_order.begin()) {
CNodePtr node = *(jump_node_iter - 1);
if (AnfAlgo::GetGraphId(node.get()) == from_graph->graph_id()) {
InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node));
break;
} else {
jump_node_iter--;
}
}
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
}


+ 1
- 0
mindspore/ccsrc/backend/session/ascend_control_parser.h View File

@@ -23,6 +23,7 @@
#include <utility>
#include <functional>
#include <memory>
#include <string>
#include "backend/session/kernel_graph.h"
#include "base/base_ref.h"
#include "utils/contract.h"


+ 111
- 1
mindspore/ccsrc/backend/session/ascend_session.cc View File

@@ -64,7 +64,7 @@
#include "ps/util.h"
#include "ps/ps_cache/ps_cache_manager.h"
#endif
static constexpr uint32_t kLabelSwitchLabelId = 2;
namespace mindspore {
namespace session {
const size_t kInvalidIndex = SIZE_MAX;
@@ -485,6 +485,8 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
memo.clear();
// insert goto labels and label_sets
LinkChildGraphs(NOT_NULL(root_graph));
// replace labelgoto with labelswitch in subgraph called multiple times
MultiCallGraphOptimize(NOT_NULL(root_graph));
// resource initialize
InitRuntimeResource();

@@ -667,6 +669,10 @@ void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tens
MS_LOG(INFO) << "No child graph has anf output";
return;
}
// load data to extra params
std::set<KernelGraphPtr> memo;
SyncDataToExtraParams(NOT_NULL(kernel_graph), NOT_NULL(&memo));
memo.clear();
// load input data from user input
LoadInputData(kernel_graph, inputs);
if (debugger_) {
@@ -1190,6 +1196,110 @@ void AscendSession::BackendOptimization(const std::vector<KernelGraphPtr> &all_g

void AscendSession::LinkChildGraphs(NotNull<KernelGraphPtr> graph) { AscendControlParser::LinkGraph(graph); }

// Check whether |graph| is a safe target for the multi-call (labelswitch)
// optimization. Performs an iterative DFS over the child-graph tree of
// |graph|; if any reachable child graph is one of |parent_graphs|, the callee
// and its callers form a call cycle and the rewrite must be skipped (returns
// false). Returns true when no such cycle exists.
bool AscendSession::IsMultiCallGraph(NotNull<KernelGraphPtr> graph, std::vector<GraphId> parent_graphs) {
  std::stack<GraphId> post_graph;
  std::set<GraphId> memo;
  post_graph.push(graph->graph_id());
  (void)memo.insert(graph->graph_id());
  while (!post_graph.empty()) {
    auto graph_id = post_graph.top();
    post_graph.pop();
    // Use at() instead of operator[]: a graph id that is not registered is a
    // logic error and must fail loudly, not default-insert a null graph
    // pointer that would be dereferenced below.
    for (auto &child_graph : graphs_.at(graph_id)->child_graph_order()) {
      std::shared_ptr<KernelGraph> child_graph_ptr = child_graph.lock();
      MS_EXCEPTION_IF_NULL(child_graph_ptr);
      auto child_id = child_graph_ptr->graph_id();
      if (std::find(parent_graphs.begin(), parent_graphs.end(), child_id) != parent_graphs.end()) {
        MS_LOG(DEBUG) << "graph:" << graph->graph_id() << " will call its parent graph:" << child_id;
        return false;
      } else if (memo.find(child_id) == memo.end()) {
        // Mark on push (not on pop) so a graph reachable through several
        // paths is stacked and visited only once.
        (void)memo.insert(child_id);
        MS_LOG(DEBUG) << "child graph:" << child_id << " into deque, wait for check.";
        post_graph.push(child_id);
      }
    }
  }
  return true;
}

// For every subgraph that is called from two or more parent graphs, rewrite
// its terminating label_goto into a label_switch driven by an extra int32
// "label_param": each caller assigns its own case index into the param before
// jumping in, and the callee switches on it to return to the correct caller.
// NOTE(review): |root_graph| is not used in this body — presumably kept for
// interface symmetry with the other per-graph passes; confirm before removal.
void AscendSession::MultiCallGraphOptimize(NotNull<KernelGraphPtr> root_graph) {
// current.first is the callee graph id, current.second the caller graph ids.
for (auto current : parent_graphs_) {
// Only graphs called from at least two places need the rewrite.
if (current.second.size() < 2) {
continue;
}
auto graph = graphs_[current.first];
auto parent_kernel_graphs = current.second;
// Skip callees that (transitively) call back into a parent: the rewrite
// cannot handle such call cycles.
if (!IsMultiCallGraph(NOT_NULL(graph), parent_kernel_graphs)) {
MS_LOG(DEBUG) << "graph:" << graph->graph_id() << " with it's parent graphs make up a cycle";
continue;
}
MS_LOG(INFO) << "graph: " << graph->graph_id() << " has been called by more than two graphs";
int32_t index = 0;
std::vector<KernelGraphPtr> child_graphs;
auto start_label = graph->get_start_label();
auto end_node = graph->get_end_goto();
// The callee-side switch variable; each caller writes its case index here.
ParameterPtr post_label_param = graph->AddExtraParamAndTensor("label_param", 0);
// New inputs for end_node: LabelSwitch primitive, the switch param, then one
// label_set node (the caller-side return point) appended per caller below.
std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
post_label_param};
for (auto graph_id : parent_kernel_graphs) {
auto kg = graphs_[graph_id];
auto nodes = kg->execution_order();
for (uint32_t i = 0; i < nodes.size(); i++) {
// Find the caller's label_goto that jumps to the callee's start label.
if (AnfAlgo::GetCNodeName(nodes[i]) == kLabelGotoOpName &&
(AnfAlgo::GetNodeAttr<uint32_t>(nodes[i], kAttrLabelIndex) ==
AnfAlgo::GetNodeAttr<uint32_t>(start_label, kAttrLabelIndex))) {
// The node right after the goto is the label_set to return to; record it
// as this caller's switch case target.
if (i < (nodes.size() - 1)) {
new_inputs.push_back(nodes[i + 1]);
} else {
MS_LOG(EXCEPTION) << "No labelset after labelgoto";
}
// Caller-side param holding this caller's case index; an assign into the
// callee's post_label_param is inserted just before the jump.
ParameterPtr pre_label_param = kg->AddExtraParamAndTensor("label_param", index++);
AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(kg), nodes[i], NOT_NULL(pre_label_param),
NOT_NULL(post_label_param));
}
}
kg->SetExecOrderByDefault();
child_graphs.push_back(kg);
}
// Temporarily install the full input list on end_node so the label ids of
// the case targets can be harvested below.
end_node->set_inputs(new_inputs);
AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>(child_graphs), end_node);
std::vector<uint32_t> label_list;
// Inputs [kLabelSwitchLabelId, size) are the per-caller label_set nodes.
for (size_t i = kLabelSwitchLabelId; i < end_node->size(); ++i) {
auto input = end_node->input(i);
MS_EXCEPTION_IF_NULL(input);
if (!input->isa<CNode>() || AnfAlgo::GetCNodeName(input) != kLabelSetOpName) {
break;
}
uint32_t goto_label_id = AnfAlgo::GetNodeAttr<uint32_t>(input, kAttrLabelIndex);
label_list.push_back(goto_label_id);
MS_LOG(INFO) << "Switch " << end_node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id "
<< goto_label_id;
}
AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), end_node);
// Shrink inputs back to {primitive, switch param}; the case targets now live
// in the kAttrLabelSwitchList attribute, not in the input list.
end_node->set_inputs({end_node->input(kAnfPrimitiveIndex), end_node->input(kFirstDataInputIndex)});
graph->SetExecOrderByDefault();
}
}

// Copy the host-side tensor value of every extra (control-flow) parameter of
// |graph| to the parameter's device address, then recurse into all child
// graphs. |memo| guards against visiting a graph twice when the call
// topology reaches it through several parents.
void AscendSession::SyncDataToExtraParams(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) {
  if (memo->find(graph.get()) != memo->end()) {
    return;
  }
  (void)memo->insert(graph.get());
  // const-ref structured bindings: avoids the index bookkeeping and the
  // per-iteration copy of the pair (two shared_ptr refcount round-trips).
  for (const auto &[param, tensor] : graph->GetExtraParamAndTensor()) {
    auto device_address = AnfAlgo::GetMutableOutputAddr(param, 0);
    MS_EXCEPTION_IF_NULL(device_address);
    tensor->set_device_address(device_address);
    if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(param, 0), LongToSize(tensor->data().nbytes()),
                                          tensor->data_type(), tensor->data_c())) {
      MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
    }
  }
  for (auto &child_graph : graph->child_graph_order()) {
    SyncDataToExtraParams(NOT_NULL(child_graph.lock()), memo);
  }
}

// Validate the final execution order of the root graph and its linked child
// graphs after control-flow parsing; delegates to AscendControlParser.
void AscendSession::RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph) {
AscendControlParser::ExecutorValidate(graph);
}


+ 4
- 0
mindspore/ccsrc/backend/session/ascend_session.h View File

@@ -93,6 +93,10 @@ class AscendSession : public SessionBasic {

static void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
static void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
// replace labelgoto with labelswitch in subgraph called multiple times
void MultiCallGraphOptimize(NotNull<KernelGraphPtr> root_graph);
bool IsMultiCallGraph(NotNull<KernelGraphPtr> graph, std::vector<GraphId> parent_graphs);
void SyncDataToExtraParams(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
// merge execution order list of child graphs
void MergeGraphExecOrder();


+ 27
- 0
mindspore/ccsrc/backend/session/kernel_graph.cc View File

@@ -1213,6 +1213,33 @@ void KernelGraph::RemoveNodeFromGraph(const AnfNodePtr &node) {
}
}

// Create an extra int32 Parameter (shape {1}) named |param_name| in this
// graph together with a host tensor initialized to |value|, record the pair
// in extra_param_tensor_, and return the parameter. Used by the control-flow
// (labelswitch) optimization to pass a case index between graphs; the tensor
// is synced to the device later (see AscendSession::SyncDataToExtraParams).
ParameterPtr KernelGraph::AddExtraParamAndTensor(std::string param_name, int32_t value) {
  ShapeVector shp = {1};
  // Throw-away tensor used only to derive the parameter's abstract.
  tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  mindspore::abstract::AbstractBasePtr parameter_abstract_ptr = tensor_ptr->ToAbstract();
  ParameterPtr new_param = std::make_shared<Parameter>(shared_from_this()->cast<KernelGraphPtr>());
  MS_EXCEPTION_IF_NULL(new_param);
  new_param->set_name(param_name);
  new_param->set_abstract(parameter_abstract_ptr);
  // Declared at first use instead of at the top of the function.
  ParameterPtr param = NewParameter(new_param);
  // Append to the graph's mutable inputs to ensure memory is allocated for it.
  std::vector<AnfNodePtr> *mute_inputs = MutableInputs();
  MS_EXCEPTION_IF_NULL(mute_inputs);
  mute_inputs->push_back(param);

  // Host tensor carrying the initial case-index value.
  tensor::TensorPtr data_tensor_ptr = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
  MS_EXCEPTION_IF_NULL(data_tensor_ptr);
  auto *val = static_cast<int32_t *>(data_tensor_ptr->data_c());
  MS_EXCEPTION_IF_NULL(val);
  *val = value;

  (void)extra_param_tensor_.emplace_back(param, data_tensor_ptr);
  MS_LOG(INFO) << "Create new param: " << param->DebugString();
  return param;
}

void KernelGraph::UpdateGraphDynamicAttr() {
for (const auto &cnode : execution_order_) {
if (AnfAlgo::IsDynamicShape(cnode)) {


+ 7
- 1
mindspore/ccsrc/backend/session/kernel_graph.h View File

@@ -44,6 +44,7 @@ class KernelGraph : public FuncGraph {
executable_ = true;
summary_node_exist_ = false;
stream_distinction_label_ = kInvalidDistincLabel;
extra_param_tensor_ = {};
}

KernelGraph(const KernelGraph &graph) : FuncGraph(graph) {
@@ -87,6 +88,7 @@ class KernelGraph : public FuncGraph {
first_step_ = graph.first_step_;
has_optimizer_ = graph.has_optimizer_;
is_dynamic_shape_ = graph.is_dynamic_shape_;
extra_param_tensor_ = graph.extra_param_tensor_;
}

~KernelGraph() override;
@@ -220,7 +222,9 @@ class KernelGraph : public FuncGraph {
}
}
void RemoveNodeFromGraph(const AnfNodePtr &node);

// Add Param which pass callback point
ParameterPtr AddExtraParamAndTensor(std::string param_name, int32_t value);
const std::vector<std::pair<ParameterPtr, tensor::TensorPtr>> GetExtraParamAndTensor() { return extra_param_tensor_; }
void UpdateGraphDynamicAttr();
bool is_dynamic_shape() const { return is_dynamic_shape_; }
void SetOptimizerFlag();
@@ -302,6 +306,8 @@ class KernelGraph : public FuncGraph {
std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
std::vector<AnfNodePtr> child_graph_result_;
std::vector<CNodePtr> execution_order_;
// extra params and tensors for control flow
std::vector<std::pair<ParameterPtr, tensor::TensorPtr>> extra_param_tensor_;
uint32_t graph_id_;
uint32_t stream_distinction_label_;



+ 9
- 3
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -1012,6 +1012,12 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
(void)ConstructKernelGraph(child_graph, all_out_graph);
}
(void)CreateValueNodeKernelGraph(node, graph.get());
auto &parent_graph = parent_graphs_[front_backend_graph_map_[child_graph]->graph_id()];
auto parent_graph_it =
std::find(parent_graph.begin(), parent_graph.end(), front_backend_graph_map_[func_graph]->graph_id());
if (parent_graph_it == parent_graph.end()) {
parent_graph.push_back(front_backend_graph_map_[func_graph]->graph_id());
}
continue;
}
// Create cnode
@@ -1096,10 +1102,10 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
}
auto &input_nodes = kernel_graph->input_nodes();
if ((inputs.size() + input_ctrl_size) - 3 != input_nodes.size()) {
auto extra_param_size = kernel_graph->GetExtraParamAndTensor().size();
if ((inputs.size() + input_ctrl_size) - 3 != input_nodes.size() - extra_param_size) {
MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
<< ", input_ctrl_size:" << input_ctrl_size;
<< ", input_ctrl_size:" << input_ctrl_size << ", extra_param_size:" << extra_param_size;
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);


+ 1
- 0
mindspore/ccsrc/backend/session/session_basic.h View File

@@ -202,6 +202,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
std::unordered_map<FuncGraphPtr, KernelGraphPtr> front_backend_graph_map_;
std::unordered_map<GraphId, std::vector<GraphId>> parent_graphs_;
std::shared_ptr<Context> context_;
CallBackFunc summary_callback_;
static GraphId graph_sum_;


Loading…
Cancel
Save