|
|
|
@@ -911,71 +911,12 @@ void Somas::GenContiguousList(const session::KernelGraph *graph) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void Somas::PreprocessingConflicts() { |
|
|
|
// Compute ancestor streams |
|
|
|
for (auto stream : streams_list_) { |
|
|
|
stream->ComputeAncestorStreams(); |
|
|
|
} |
|
|
|
|
|
|
|
// Preset ancestor streams for node |
|
|
|
for (auto node : nodes_list_) { |
|
|
|
node->PresetAncestorStreams(streams_list_); |
|
|
|
} |
|
|
|
|
|
|
|
// Compute ancestor nodes : needs to be executed in topological order |
|
|
|
for (auto node : nodes_list_) { |
|
|
|
node->ComputeAncestorNodes(); |
|
|
|
} |
|
|
|
|
|
|
|
// Compute MaxDestinationId for between-stream tensors |
|
|
|
for (auto tensor : tensors_list_) { |
|
|
|
if (tensor->IsBetweenStreams()) { |
|
|
|
tensor->ComputeMaxDestinationId(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Preprocessing for stream groups |
|
|
|
for (auto group : streams_groups_) { |
|
|
|
vector<SomasStreamPtr> previous_streams; |
|
|
|
for (int64_t stream_id : group) { |
|
|
|
auto it = std::find_if(streams_list_.begin(), streams_list_.end(), |
|
|
|
[stream_id](const SomasStreamPtr &stream) { return stream->GetId() == stream_id; }); |
|
|
|
if (it != streams_list_.end()) { |
|
|
|
for (auto stream : previous_streams) { |
|
|
|
(*it)->ancestor_streams_group_.insert(stream); |
|
|
|
} |
|
|
|
previous_streams.push_back(*it); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Atomic: fix any issues on saved lifetimes of tensors |
|
|
|
for (auto tensor : tensors_list_) { |
|
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
|
for (auto node : tensor->destinations_) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(tensor->GetSourceNode()); |
|
|
|
if (tensor->GetSourceNode()->GetId() > node->GetId()) { |
|
|
|
tensor->lifetime_.start_ = node->GetId(); |
|
|
|
} |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(tensor->GetSourceNode()); |
|
|
|
if (tensor->GetSourceNode()->GetId() > tensor->lifetime_.end_) { |
|
|
|
tensor->lifetime_.end_ = tensor->GetSourceNode()->GetId(); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void Somas::ComputeConflictPairs() { |
|
|
|
if (tensors_list_.empty()) { |
|
|
|
MS_LOG(INFO) << "No Tensor for Conflict computing"; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << "Start Preprocessing Conflicts"; |
|
|
|
PreprocessingConflicts(); |
|
|
|
MS_LOG(INFO) << "End Preprocessing Conflicts"; |
|
|
|
|
|
|
|
MS_LOG(INFO) << "Start Conflict Computing (Bitset Model)"; |
|
|
|
auto start_conflict = std::chrono::system_clock::now(); |
|
|
|
std::sort(nodes_list_.begin(), nodes_list_.end(), NodeSort); |
|
|
|
@@ -1082,6 +1023,11 @@ void Somas::UpdateTensorDestinations() { |
|
|
|
tensor->destinations_.insert(tensor->GetSourceNode()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Loop to compute max destinations in each stream |
|
|
|
for (const auto &tensor : tensors_list_) { |
|
|
|
tensor->ComputeMaxDestinationId(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void Somas::ComputeMultiTensorConflicts(const std::vector<SomasTensorPtr> &calc_tensors_list, |
|
|
|
@@ -1125,7 +1071,8 @@ void Somas::ComputeOneTensorConflicts(const std::shared_ptr<SomasTensor> &calc_t |
|
|
|
|
|
|
|
bool reuse = true; |
|
|
|
// check calc_tensor's all consumers is target_tensor's source node's dependency or not |
|
|
|
for (const auto &dst_node : calc_tensor->destinations_) { |
|
|
|
for (const auto &dst_map : calc_tensor->max_destinations_) { |
|
|
|
const auto &dst_node = dst_map.second; |
|
|
|
if (nodes_dependency[target_src_node].IsBitTrue(dst_node->GetId()) == false) { |
|
|
|
// calc_tensor's consumer is not in target_tensor's source node's dependency, not sure this consumer is done or |
|
|
|
// not when target_tensor produced |
|
|
|
|