Browse Source

max_destinations_ map to optimize conflict computing model

tags/v1.2.0-rc1
Ioannis Lamprou laiyongqiang 5 years ago
parent
commit
fdc5ed63fd
4 changed files with 13 additions and 65 deletions
  1. +7
    -60
      mindspore/ccsrc/backend/optimizer/somas/somas.cc
  2. +0
    -1
      mindspore/ccsrc/backend/optimizer/somas/somas.h
  3. +5
    -4
      mindspore/ccsrc/backend/optimizer/somas/somas_tensor.cc
  4. +1
    -0
      mindspore/ccsrc/backend/optimizer/somas/somas_tensor.h

+ 7
- 60
mindspore/ccsrc/backend/optimizer/somas/somas.cc View File

@@ -911,71 +911,12 @@ void Somas::GenContiguousList(const session::KernelGraph *graph) {
}
}

void Somas::PreprocessingConflicts() {
// Compute ancestor streams
for (auto stream : streams_list_) {
stream->ComputeAncestorStreams();
}

// Preset ancestor streams for node
for (auto node : nodes_list_) {
node->PresetAncestorStreams(streams_list_);
}

// Compute ancestor nodes : needs to be executed in topological order
for (auto node : nodes_list_) {
node->ComputeAncestorNodes();
}

// Compute MaxDestinationId for between-stream tensors
for (auto tensor : tensors_list_) {
if (tensor->IsBetweenStreams()) {
tensor->ComputeMaxDestinationId();
}
}

// Preprocessing for stream groups
for (auto group : streams_groups_) {
vector<SomasStreamPtr> previous_streams;
for (int64_t stream_id : group) {
auto it = std::find_if(streams_list_.begin(), streams_list_.end(),
[stream_id](const SomasStreamPtr &stream) { return stream->GetId() == stream_id; });
if (it != streams_list_.end()) {
for (auto stream : previous_streams) {
(*it)->ancestor_streams_group_.insert(stream);
}
previous_streams.push_back(*it);
}
}
}

// Atomic: fix any issues on saved lifetimes of tensors
for (auto tensor : tensors_list_) {
MS_EXCEPTION_IF_NULL(tensor);
for (auto node : tensor->destinations_) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(tensor->GetSourceNode());
if (tensor->GetSourceNode()->GetId() > node->GetId()) {
tensor->lifetime_.start_ = node->GetId();
}
}
MS_EXCEPTION_IF_NULL(tensor->GetSourceNode());
if (tensor->GetSourceNode()->GetId() > tensor->lifetime_.end_) {
tensor->lifetime_.end_ = tensor->GetSourceNode()->GetId();
}
}
}

void Somas::ComputeConflictPairs() {
if (tensors_list_.empty()) {
MS_LOG(INFO) << "No Tensor for Conflict computing";
return;
}

MS_LOG(INFO) << "Start Preprocessing Conflicts";
PreprocessingConflicts();
MS_LOG(INFO) << "End Preprocessing Conflicts";

MS_LOG(INFO) << "Start Conflict Computing (Bitset Model)";
auto start_conflict = std::chrono::system_clock::now();
std::sort(nodes_list_.begin(), nodes_list_.end(), NodeSort);
@@ -1082,6 +1023,11 @@ void Somas::UpdateTensorDestinations() {
tensor->destinations_.insert(tensor->GetSourceNode());
}
}

// Loop to compute max destinations in each stream
for (const auto &tensor : tensors_list_) {
tensor->ComputeMaxDestinationId();
}
}

void Somas::ComputeMultiTensorConflicts(const std::vector<SomasTensorPtr> &calc_tensors_list,
@@ -1125,7 +1071,8 @@ void Somas::ComputeOneTensorConflicts(const std::shared_ptr<SomasTensor> &calc_t

bool reuse = true;
// check calc_tensor's all consumers is target_tensor's source node's dependency or not
for (const auto &dst_node : calc_tensor->destinations_) {
for (const auto &dst_map : calc_tensor->max_destinations_) {
const auto &dst_node = dst_map.second;
if (nodes_dependency[target_src_node].IsBitTrue(dst_node->GetId()) == false) {
// calc_tensor's consumer is not in target_tensor's source node's dependency, not sure this consumer is done or
// not when target_tensor produced


+ 0
- 1
mindspore/ccsrc/backend/optimizer/somas/somas.h View File

@@ -121,7 +121,6 @@ class Somas {
SomasTensorPtr CreateGapTensor(size_t gap_tensor_id);
void GenContiguousList(const session::KernelGraph *graph);

void PreprocessingConflicts();
void ComputeConflictPairs();

bool Assign(const session::KernelGraph *graph);


+ 5
- 4
mindspore/ccsrc/backend/optimizer/somas/somas_tensor.cc View File

@@ -56,10 +56,11 @@ SomasSolverTensorDescPtr SomasTensor::GetSolverTensorDesc() {
}

void SomasTensor::ComputeMaxDestinationId() {
for (SomasStreamPtr stream : destinationStreams_) max_destination_id_[stream] = 0;

for (SomasNodePtr node : destinations_)
if (node->GetId() > max_destination_id_[node->GetStream()]) max_destination_id_[node->GetStream()] = node->GetId();
for (const auto &node : destinations_)
if (node->GetId() > max_destination_id_[node->GetStream()]) {
max_destination_id_[node->GetStream()] = node->GetId();
max_destinations_[node->GetStream()] = node;
}
}
} // namespace somas
} // namespace mindspore

+ 1
- 0
mindspore/ccsrc/backend/optimizer/somas/somas_tensor.h View File

@@ -82,6 +82,7 @@ class SomasTensor {
std::set<SomasNodePtr> destinations_;
std::set<SomasStreamPtr> destinationStreams_;
unordered_map<SomasStreamPtr, size_t> max_destination_id_;
unordered_map<SomasStreamPtr, SomasNodePtr> max_destinations_;

// Constructors/Destructors
explicit SomasTensor(size_t id, SomasNodePtr source_node, SomasStreamPtr source_stream, size_t real_size,


Loading…
Cancel
Save