Browse Source

Fix output actor for control flow.

tags/v1.6.0
gaoyong10 4 years ago
parent
commit
b14df4644f
4 changed files with 16 additions and 1 deletions
  1. +5
    -0
      mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.cc
  2. +7
    -1
      mindspore/ccsrc/runtime/framework/actor/output_actor.cc
  3. +2
    -0
      mindspore/ccsrc/runtime/framework/actor/output_actor.h
  4. +2
    -0
      mindspore/ccsrc/runtime/framework/graph_scheduler.cc

+ 5
- 0
mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.cc View File

@@ -131,6 +131,11 @@ void EntranceActor::EraseInput(const OpContext<DeviceTensor> *const context) {
return;
}

const auto &control_iter = input_op_controls_.find(sequential_num);
if (control_iter != input_op_controls_.end()) {
input_op_controls_.erase(control_iter);
}

const auto &iter = input_op_data_with_branch_id_.find(sequential_num);
if (iter == input_op_data_with_branch_id_.end() || iter->second.empty()) {
MS_LOG(ERROR) << "Cannot find input in batch op result for actor:" << GetAID();


+ 7
- 1
mindspore/ccsrc/runtime/framework/actor/output_actor.cc View File

@@ -89,6 +89,7 @@ void OutputActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<De

// Save the output nodes to clear the device tensor in the running end.
output_nodes_[output_position] = node_with_index;
output_device_tensors_[output_position] = input_data->data_;
}

TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position) {
@@ -152,7 +153,10 @@ void OutputActor::UpdateOutputDeviceAddress() {
MS_EXCEPTION_IF_NULL(tensor);
auto tensor_device_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
MS_EXCEPTION_IF_NULL(tensor_device_address);
auto device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
if (i >= output_device_tensors_.size()) {
MS_LOG(EXCEPTION) << "Invalid index:" << i << " current:" << output_device_tensors_.size();
}
auto device_tensor = output_device_tensors_[i];
MS_EXCEPTION_IF_NULL(device_tensor);

// Update tensor device address by device tensor of output node.
@@ -192,6 +196,8 @@ void OutputActor::UpdateOutputDeviceAddress() {
output_node_to_tensor_device_address_.clear();
output_nodes_.clear();
output_nodes_.resize(outputs_num_);
output_device_tensors_.clear();
output_device_tensors_.resize(outputs_num_);
}
} // namespace runtime
} // namespace mindspore

+ 2
- 0
mindspore/ccsrc/runtime/framework/actor/output_actor.h View File

@@ -50,6 +50,7 @@ class OutputActor : public AbstractActor {
current_outputs_num_(0) {
outputs_.resize(outputs_num);
output_nodes_.resize(outputs_num);
output_device_tensors_.resize(outputs_num);
device_contexts_.resize(outputs_num);
}
~OutputActor() override = default;
@@ -86,6 +87,7 @@ class OutputActor : public AbstractActor {
// The outputs.
std::vector<TensorPtr> outputs_;
std::vector<KernelWithIndex> output_nodes_;
std::vector<DeviceTensor *> output_device_tensors_;
// Record the output nodes whose output address must be persisted and can't be changed.
// For example the output address of output node in the sink graph is persisted.
std::set<AnfNodePtr> output_address_persisted_nodes_;


+ 2
- 0
mindspore/ccsrc/runtime/framework/graph_scheduler.cc View File

@@ -575,6 +575,8 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
iter - host_queue_ds_actor->data_nodes_.begin());
} else {
(void)host_queue_ds_actor->data_node_position_map_.emplace(parameter, host_queue_ds_actor->data_nodes_.size());
(void)host_queue_ds_actor->data_node_position_map_.emplace(backend_iter->second.begin()->first,
host_queue_ds_actor->data_nodes_.size());
(void)host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.begin()->first);
(void)host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.begin()->second);
}


Loading…
Cancel
Save