Browse Source

Fetch total front node in kernel graph.

tags/v1.6.0
gaoyong10 4 years ago
parent
commit
ca677c0f16
3 changed files with 15 additions and 14 deletions
  1. +1
    -0
      mindspore/ccsrc/backend/session/kernel_graph.h
  2. +9
    -3
      mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.cc
  3. +5
    -11
      mindspore/ccsrc/runtime/framework/control_node_parser.cc

+ 1
- 0
mindspore/ccsrc/backend/session/kernel_graph.h View File

@@ -390,6 +390,7 @@ class KernelGraph : public FuncGraph {
void set_is_executing_sink(bool is_executing_sink) { is_executing_sink_ = is_executing_sink; }
bool is_loop_count_sink() const { return is_loop_count_sink_; }
void set_is_loop_count_sink(bool is_loop_count_sink) { is_loop_count_sink_ = is_loop_count_sink; }
const mindspore::HashMap<AnfNodePtr, AnfNodePtr> &front_backend_anf_map() { return front_backend_anf_map_; }

AnfWithOutIndex GetElementInTupleBackendFrontIndexMap(const AnfNodePtr &back_node) {
auto iter = tuple_backend_front_anf_index_map_.find(back_node);


+ 9
- 3
mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.cc View File

@@ -78,7 +78,15 @@ void EntranceActor::FetchInput(OpContext<DeviceTensor> *const context) {
// There are two kinds of run conditions for entrance actor:
// 1.Data comes from the data source actor, it is in the form of data arrow.
const auto &data_iter = input_op_datas_.find(sequential_num);
if (data_iter != input_op_datas_.end()) {
const auto &control_iter = input_op_controls_.find(sequential_num);
if (data_iter != input_op_datas_.end() || control_iter != input_op_controls_.end()) {
// If the data comes from the data source actor, use the default branch id.
output_branch_id_ = 0;

if (data_iter == input_op_datas_.end()) {
return;
}

for (auto &input_data : data_iter->second) {
MS_EXCEPTION_IF_NULL(input_data);
if (IntToSize(input_data->index_) >= input_device_tensors_.size()) {
@@ -88,8 +96,6 @@ void EntranceActor::FetchInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(input_data->data_);
input_device_tensors_[input_data->index_] = input_data->data_;
}
// If the data comes from the data source actor, use the default branch id.
output_branch_id_ = 0;
} else {
// 2.Data comes from the gather actor, it is in the form of data with branch id.
output_branch_id_ = real_parameters_with_branch_id_[sequential_num].front().branch_id_;


+ 5
- 11
mindspore/ccsrc/runtime/framework/control_node_parser.cc View File

@@ -844,7 +844,7 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def
const auto &cnode = return_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &inputs = cnode->inputs();
const auto output_nodes = FetchAllOutputWithIndex(inputs[kReturnInputPos]);
const auto output_nodes = FetchInputNodeByNode(inputs[kReturnInputPos]);
std::vector<const DeviceContext *> return_device_contexts;

for (const auto &output_node : output_nodes) {
@@ -909,19 +909,13 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def

void ControlNodeParser::FetchFrontNodeToKernelGraph(const std::vector<KernelGraphPtr> &graphs) {
for (const auto &graph : graphs) {
MS_EXCEPTION_IF_NULL(graph);
if (graph->execution_order().empty()) {
continue;
}

for (auto &kernel : graph->execution_order()) {
auto front_node = graph->GetFrontAnfByBackendAnf(kernel);
if (front_node != nullptr) {
front_node_to_kernel_graph_[front_node] = graph;
}
}
const auto &graph_outputs = graph->graph_output_map();
for (const auto &backend_to_front : graph_outputs) {
front_node_to_kernel_graph_[backend_to_front.second.first] = graph;
const auto &front_to_backend_nodes = graph->front_backend_anf_map();
for (const auto &front_to_backend_node : front_to_backend_nodes) {
front_node_to_kernel_graph_[front_to_backend_node.first] = graph;
}
}
}


Loading…
Cancel
Save