|
|
|
@@ -122,6 +122,7 @@ void Somas::InitSomasStreamAndNode(const session::KernelGraph *graph) {
     MS_EXCEPTION_IF_NULL(node);
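+    // Keep the kernel's fully scoped name on the SOMAS node; the consumers of
+    // this field (e.g. statistics and debug logs) are outside this hunk.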
|
|
|
+    node->scope_full_name_ = kernel->fullname_with_scope();
     nodes_list_.push_back(node);
     stream->nodes_.push_back(node);
     auto key = kernel.get();
     nodes_map_[key] = node;
     node_index++;
@@ -565,220 +566,155 @@ void Somas::PreprocessingConflicts() {
   }
 }
 
-static bool LifetimeOverlap(lifetime_t lifetime1, lifetime_t lifetime2) {
-  size_t start1 = std::min(lifetime1.start_, lifetime1.end_);
-  size_t end1 = std::max(lifetime1.start_, lifetime1.end_);
-  size_t start2 = std::min(lifetime2.start_, lifetime2.end_);
-  size_t end2 = std::max(lifetime2.start_, lifetime2.end_);
-  return (std::max(end1, end2) - std::min(start1, start2) <= end2 - start2 + end1 - start1);
-}
-
-static bool Subset(std::set<SomasStreamPtr> streamSet1, std::set<SomasStreamPtr> streamSet2) {
-  for (auto stream : streamSet1) {
-    if (streamSet2.count(stream) == 0) {
-      return false;
-    }
-  }
-  return true;
-}
-
-static void EraseSet(std::set<SomasStreamPtr> *streamSet, std::set<SomasStreamPtr> removeStreamsSet) {
-  for (auto stream : removeStreamsSet) {
-    streamSet->erase(stream);
-  }
-}
-
-static bool ValidSubset(std::set<SomasStreamPtr> destStreams, std::set<SomasStreamPtr> ancestorsAndSelf,
-                        SomasTensorPtr ancestorTensor, SomasTensorPtr tensor) {
-  MS_EXCEPTION_IF_NULL(ancestorTensor);
-  MS_EXCEPTION_IF_NULL(tensor);
-  for (auto stream : destStreams) {
-    if (ancestorsAndSelf.count(stream) == 0) {
-      return false;
-    }
-
-    if (stream != tensor->GetSourceStream()) {
-      MS_EXCEPTION_IF_NULL(tensor->GetSourceStream());
-      if (tensor->GetSourceStream()->ancestor_streams_group_.count(stream) == 0 &&
-          ancestorTensor->max_destination_id_[stream] >
-            tensor->GetSourceNode()->anc_stream_max_order_[stream->GetId()]) {
-        return false;
-      }
-    } else {
-      if (ancestorTensor->max_destination_id_[stream] >= tensor->lifetime_.start_) {
-        return false;
-      }
-    }
-  }
-  return true;
-}
-
-void Somas::ComputeConflictPairs() {
-  if (tensors_list_.empty()) {
-    MS_LOG(INFO) << "No Tensor for Conflict computing";
-    return;
-  }
-
-  MS_LOG(INFO) << "Start Preprocessing Conflicts";
-  PreprocessingConflicts();
-  MS_LOG(INFO) << "End Preprocessing Conflicts";
-
-  MS_LOG(INFO) << "Start Array Initialization";
-  cannot_reuse_ =
-    std::make_shared<Array>(tensors_list_.back()->GetId() + 1,
-                            tensors_list_.back()->GetId() + 1);  // array size is (max_id + 1) x (max_id + 1)
-  MS_LOG(INFO) << "End Array Initialization";
-
-  MS_LOG(INFO) << "Start Conflict Computing";
-
-  size_t count_reuse = 0;
-
-  // Loop for ancestor stream groups reuse
-  for (auto stream : streams_list_) {
-    std::set<SomasStreamPtr> ancestors = stream->ancestor_streams_group_;
-
-    std::set<SomasStreamPtr> ancestors_and_self = ancestors;
-    ancestors_and_self.insert(stream);
-
-    for (auto ancestor_stream : ancestors) {
-      for (auto ancestor_tensor : ancestor_stream->tensors_) {
-        if (ancestor_tensor->GetAlignedSize() == 0) continue;
-        if (ancestor_tensor->IsLifelong()) continue;
-        if (ancestor_tensor->IsSemiLifelongEnd()) continue;
-        if (ancestor_tensor->IsRefOverlap()) continue;
-
-        if (!ancestor_tensor->IsBetweenStreams() || Subset(ancestor_tensor->destinationStreams_, ancestors)) {
-          for (auto tensor : stream->tensors_) {
-            if (tensor->IsGap()) continue;
-            if (tensor->GetAlignedSize() == 0) continue;
-            if (tensor->IsLifelong()) continue;
-            if (tensor->IsSemiLifelongStart()) continue;
-            if (tensor->IsRefOverlap()) continue;
-
-            (*cannot_reuse_)(ancestor_tensor->GetId(), tensor->GetId()) = 0;
-            (*cannot_reuse_)(tensor->GetId(), ancestor_tensor->GetId()) = 0;
-            count_reuse++;
-          }
-        } else {
-          for (auto tensor : stream->tensors_) {
-            if (Subset(ancestor_tensor->destinationStreams_, ancestors_and_self) &&
-                ancestor_tensor->max_destination_id_[tensor->GetSourceStream()] < tensor->lifetime_.start_) {
-              if (tensor->IsGap()) continue;
-              if (tensor->GetAlignedSize() == 0) continue;
-              if (tensor->IsLifelong()) continue;
-              if (tensor->IsSemiLifelongStart()) continue;
-              if (tensor->IsRefOverlap()) continue;
-
-              (*cannot_reuse_)(ancestor_tensor->GetId(), tensor->GetId()) = 0;
-              (*cannot_reuse_)(tensor->GetId(), ancestor_tensor->GetId()) = 0;
-              count_reuse++;
-            }
-          }
-        }
-      }
-    }
-  }
-
-  // Loop for ancestor streams (no groups)
-  for (auto stream : streams_list_) {
-    auto ancestors_no_groups = stream->ancestor_streams_;
-    EraseSet(&ancestors_no_groups, stream->ancestor_streams_group_);
-
-    for (auto ancestor_stream : ancestors_no_groups) {
-      for (auto ancestor_tensor : ancestor_stream->tensors_) {
-        if (ancestor_tensor->GetAlignedSize() == 0) continue;
-        if (ancestor_tensor->IsLifelong()) continue;
-        if (ancestor_tensor->IsSemiLifelongEnd()) continue;
-        if (ancestor_tensor->IsRefOverlap()) continue;
-
-        if (!ancestor_tensor->IsBetweenStreams()) {
-          for (auto tensor : stream->tensors_) {
-            if (tensor->IsGap()) continue;
-            if (tensor->GetAlignedSize() == 0) continue;
-            if (tensor->IsLifelong()) continue;
-            if (tensor->IsSemiLifelongStart()) continue;
-            if (tensor->IsRefOverlap()) continue;
-
-            size_t max_ancestor_order = tensor->GetSourceNode()->anc_stream_max_order_[ancestor_stream->GetId()];
-
-            if (ancestor_tensor->lifetime_.end_ <= max_ancestor_order) {
-              (*cannot_reuse_)(ancestor_tensor->GetId(), tensor->GetId()) = 0;
-              (*cannot_reuse_)(tensor->GetId(), ancestor_tensor->GetId()) = 0;
-              count_reuse++;
-            }
-          }
-        } else {  // ancestor tensor goes to another stream (might go to same stream also)
-          std::set<SomasStreamPtr> dest_streams = ancestor_tensor->destinationStreams_;
-          std::set<SomasStreamPtr> ancestors = stream->ancestor_streams_;
-
-          std::set<SomasStreamPtr> ancestors_and_self = ancestors;
-          ancestors_and_self.insert(stream);
-
-          for (auto tensor : stream->tensors_) {
-            if (tensor->IsGap()) continue;
-            if (tensor->GetAlignedSize() == 0) continue;
-            if (tensor->IsLifelong()) continue;
-            if (tensor->IsSemiLifelongStart()) continue;
-            if (tensor->IsRefOverlap()) continue;
-
-            if (ValidSubset(dest_streams, ancestors_and_self, ancestor_tensor, tensor)) {
-              (*cannot_reuse_)(ancestor_tensor->GetId(), tensor->GetId()) = 0;
-              (*cannot_reuse_)(tensor->GetId(), ancestor_tensor->GetId()) = 0;
-              count_reuse++;
-            }
-          }
-        }
-      }
-    }
-  }
-
-  // Loop for same stream
-  for (auto stream : streams_list_) {
-    MS_EXCEPTION_IF_NULL(stream);
-    for (auto tensor1 : stream->tensors_) {
-      if (tensor1->GetAlignedSize() == 0) continue;
-      if (tensor1->IsLifelong()) continue;
-      if (tensor1->IsRefOverlap()) continue;
-
-      for (auto tensor2 : stream->tensors_) {
-        if (tensor2->GetId() >= tensor1->GetId())
-          break;  // keep only when tensors kept sorted in tensors-vector of each stream, otherwise remove
-
-        if (tensor2->GetAlignedSize() == 0) continue;
-        if (tensor2->IsLifelong()) continue;
-        if (tensor2->IsRefOverlap()) continue;
-
-        // Between streams extra safety
-        if (tensor1->IsBetweenStreams() && tensor2->IsBetweenStreams()) continue;
-
-        // Check lifetime overlap
-        lifetime_t lifetime1 = tensor1->lifetime_;
-        lifetime_t lifetime2 = tensor2->lifetime_;
-
-        if (!LifetimeOverlap(lifetime1, lifetime2)) {
-          // Between-streams extra safety
-          if (tensor1->IsBetweenStreams() && lifetime1.end_ < lifetime2.start_) continue;
-          if (tensor2->IsBetweenStreams() && lifetime2.end_ < lifetime1.start_) continue;
-
-          // Semi-lifelong extra safety
-          if (lifetime1.end_ < lifetime2.start_ && (tensor2->IsSemiLifelongStart() || tensor1->IsSemiLifelongEnd()))
-            continue;
-          if (lifetime2.end_ < lifetime1.start_ && (tensor1->IsSemiLifelongStart() || tensor2->IsSemiLifelongEnd()))
-            continue;
-
-          // If arrived here, allow reuse
-          (*cannot_reuse_)(tensor2->GetId(), tensor1->GetId()) = 0;
-          (*cannot_reuse_)(tensor1->GetId(), tensor2->GetId()) = 0;
-          count_reuse++;
-        }
-      }
-    }
-  }
-
-  MS_LOG(INFO) << "End Conflict Computing";
-  MS_LOG(INFO) << "Found " << count_reuse << " tensor pairs of allowed reusability";
-}
 
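+// The bitset model below replaces the stream/lifetime heuristics deleted
+// above: happens-before reachability between nodes is computed once as
+// DynamicBitSet rows, and a tensor pair is marked reusable only when that
+// reachability proves their lifetimes cannot overlap.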
|
|
|
|
|
|
|
+void Somas::ComputeConflictPairs() {
+  if (tensors_list_.empty()) {
+    MS_LOG(INFO) << "No Tensor for Conflict computing";
+    return;
+  }
+
+  MS_LOG(INFO) << "Start Conflict Computing (Bitset Model)";
+  std::sort(nodes_list_.begin(), nodes_list_.end(), NodeSort);
+
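+  // Kernels on a single stream execute in launch order, so after sorting a
+  // stream's nodes by id each node gains its immediate predecessor as an
+  // ancestor (this assumes node ids follow execution order).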
|
|
|
+  // Loop to add edges within each stream (node order within stream)
+  for (const auto &stream : streams_list_) {
+    auto &nodes = stream->nodes_;
+    std::sort(nodes.begin(), nodes.end(), NodeSort);
+    for (size_t i = 1; i < nodes.size(); i++) {
+      const auto &previous_node = nodes[i - 1];
+      const auto &current_node = nodes[i];
+      current_node->ancestor_nodes_.insert(previous_node);
+    }
+  }
+
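+  // A stream group encodes inter-stream synchronization: the last node of
+  // each stream in the group is known to finish before the first node of the
+  // next stream in the group starts, so an ancestor edge is added across the
+  // two streams.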
|
|
|
+  // Loop to add edges from end to beginning of next group
+  for (const auto &group : streams_groups_) {
+    for (size_t i = 1; i < group.size(); i++) {
+      int64_t previous_stream = group[i - 1];
+      int64_t current_stream = group[i];
+
+      auto it =
+        std::find_if(streams_list_.begin(), streams_list_.end(),
+                     [previous_stream](const SomasStreamPtr &stream) { return stream->GetId() == previous_stream; });
+      if (it == streams_list_.end()) {
+        continue;
+      }
+      auto &last_node_in_prev_stream = (*it)->nodes_.back();
+
+      it = std::find_if(streams_list_.begin(), streams_list_.end(),
+                        [current_stream](const SomasStreamPtr &stream) { return stream->GetId() == current_stream; });
+      if (it == streams_list_.end()) {
+        continue;
+      }
+      auto &first_node_in_cur_stream = (*it)->nodes_.front();
+
+      first_node_in_cur_stream->ancestor_nodes_.insert(last_node_in_prev_stream);
+    }
+  }
+
+  // Loop to avoid tensors with empty destinations (add itself)
+  for (const auto &tensor : tensors_list_) {
+    if (tensor->destinations_.size() == 0) {
+      tensor->destinations_.insert(tensor->GetSourceNode());
+    }
+  }
+
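+  // One DynamicBitSet row per node: bit k set means node k must finish before
+  // this node runs. Visiting nodes in ascending id order and OR-ing (Union)
+  // each ancestor's already-completed row into the current one yields the
+  // transitive closure in a single pass (again assuming ancestor ids precede
+  // descendant ids); e.g. with edges 0 -> 1 -> 2, row 2 ends with bits {0, 1}.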
|
|
|
|
|
|
|
+  MS_LOG(INFO) << "Start Bitset";
+  std::vector<DynamicBitSet> nodes_dependency;
+
+  size_t count = nodes_list_.back()->GetId() + 1;
+  for (size_t i = 0; i < count; i++) {
+    nodes_dependency.emplace_back(count);
+  }
+
+  MS_LOG(INFO) << "Start Path Computing";
+  // Loop to compute ancestor paths via bitset for time dependence
+  for (const auto &node : nodes_list_) {
+    for (const auto &ancestor : node->ancestor_nodes_) {
+      nodes_dependency[node->GetId()].SetBitTrue(ancestor->GetId());
+      Union(&nodes_dependency[node->GetId()], &nodes_dependency[ancestor->GetId()]);
+    }
+  }
+  MS_LOG(INFO) << "End Path Computing";
+
+  MS_LOG(INFO) << "Start Tensor Relation Computing";
+  count = tensors_list_.back()->GetId() + 1;
+  for (size_t i = 0; i < count; i++) {
+    tensor_relation.emplace_back(count);
+  }
+
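+  // Reuse rule: t0 and t1 may share memory only if every consumer of one
+  // tensor is a strict predecessor of the other tensor's producer, so their
+  // lifetimes cannot overlap; a producer that directly consumes the other
+  // tensor never qualifies.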
|
|
|
+  for (size_t i = 0; i < tensors_list_.size(); i++) {
+    for (size_t j = i + 1; j < tensors_list_.size(); j++) {
+      auto t0 = tensors_list_[i];
+      auto t1 = tensors_list_[j];
+
+      if (t0 == t1 || t0->IsGap() || t1->IsGap() || t0->IsLifelong() || t1->IsLifelong() || t0->IsRefOverlap() ||
+          t1->IsRefOverlap() || t0->GetAlignedSize() == 0 || t1->GetAlignedSize() == 0) {
+        continue;
+      }
+
+      size_t t0_src_node = t0->GetSourceNode()->GetId();
+      size_t t1_src_node = t1->GetSourceNode()->GetId();
+      if (t0_src_node == t1_src_node) {
+        continue;
+      }
+
+      bool reuse = true;
+      bool all_dst_depend = false;
+      // Check whether every consumer of t0 is a dependency of t1's source node
+      for (const auto &dst_node : t0->destinations_) {
+        if (nodes_dependency[t1_src_node].IsBitTrue(dst_node->GetId()) == false) {
+          // t0's consumer is not a dependency of t1's source node, so it may
+          // not have finished when t1 is produced
+          reuse = false;
+          all_dst_depend = false;
+          break;
+        } else if (t1_src_node == dst_node->GetId()) {
+          // t0 is an input of t1's source node, so they can't reuse
+          reuse = false;
+          all_dst_depend = true;
+          break;
+        } else {
+          // t0's consumer is a dependency of t1's source node, so it has
+          // finished by the time t1 is produced
+          reuse = true;
+          all_dst_depend = true;
+        }
+      }
+
+      if (all_dst_depend == false) {
+        // Check whether every consumer of t1 is a dependency of t0's source node
+        reuse = true;
+        for (const auto &dst_node : t1->destinations_) {
+          if (nodes_dependency[t0_src_node].IsBitTrue(dst_node->GetId()) == false) {
+            reuse = false;
+            all_dst_depend = false;
+            break;
+          } else if (t0_src_node == dst_node->GetId()) {
+            reuse = false;
+            all_dst_depend = true;
+            break;
+          } else {
+            reuse = true;
+            all_dst_depend = true;
+          }
+        }
+      }
+
+      if (all_dst_depend == true && reuse == true) {
+        tensor_relation[t0->GetId()].SetBitTrue(t1->GetId());
+        tensor_relation[t1->GetId()].SetBitTrue(t0->GetId());
+      }
+    }
+  }
+
+  MS_LOG(INFO) << "End Tensor Relation Computing";
+  MS_LOG(INFO) << "End Conflict Computing (Bitset Model)";
+}
|
|
|
|
|
|
|
+
+bool Somas::NodeSort(SomasNodePtr node1, SomasNodePtr node2) { return node1->GetId() < node2->GetId(); }
 
 bool Somas::Assign(const session::KernelGraph *graph) {
   if (tensors_list_.empty()) {
     MS_LOG(INFO) << "No Tensor for Assigner";
@@ -795,13 +731,13 @@ bool Somas::Assign(const session::KernelGraph *graph) {
     // Keep all constraints for first tensor in list
     size_t tid_0 = ref_node_list[0];
     for (SomasTensorPtr tensor : tensors_list_) {
-      if ((*cannot_reuse_)(tid_0, tensor->GetId()) == 1) {
+      if (tensor_relation[tid_0].IsBitTrue(tensor->GetId()) == false) {
         continue;
       }
       for (size_t tid : ref_node_list) {
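+        // Fold the constraints of the whole ref-node list onto its first
+        // tensor: one cleared (cannot-reuse) bit anywhere clears tid_0's bit.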
|
|
|
-        if ((*cannot_reuse_)(tid, tensor->GetId()) == 1) {
-          (*cannot_reuse_)(tid_0, tensor->GetId()) = 1;
-          (*cannot_reuse_)(tensor->GetId(), tid_0) = 1;
+        if (tensor_relation[tid].IsBitTrue(tensor->GetId()) == false) {
+          tensor_relation[tid_0].SetBitFalse(tensor->GetId());
+          tensor_relation[tensor->GetId()].SetBitFalse(tid_0);
           break;
         }
       }
@@ -921,8 +857,8 @@ bool Somas::Assign(const session::KernelGraph *graph) {
   for (auto ref_overlap_list : ref_overlap_constraints_) {
     for (size_t tid_1 : ref_overlap_list) {
       for (size_t tid_2 : ref_overlap_list) {
-        (*cannot_reuse_)(tid_1, tid_2) = 0;
-        (*cannot_reuse_)(tid_2, tid_1) = 0;
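+        // Ref-overlap lists are exempt from mutual conflict checks: mark
+        // every pair inside the list as reusable.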
|
|
|
+        tensor_relation[tid_1].SetBitTrue(tid_2);
+        tensor_relation[tid_2].SetBitTrue(tid_1);
       }
     }
   }
@@ -932,7 +868,7 @@ bool Somas::Assign(const session::KernelGraph *graph) {
   for (auto tensor1 : tensors_list_) {
     size_t count_constraints = 0;
     for (auto tensor2 : tensors_list_) {
-      if ((*cannot_reuse_)(tensor1->GetId(), tensor2->GetId()) == 1) {
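+      // In the bitset model a cleared bit means "cannot share memory", so
+      // cleared bits are what get counted as constraints.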
|
|
|
+      if (tensor_relation[tensor1->GetId()].IsBitTrue(tensor2->GetId()) == false) {
         count_constraints++;
       }
     }
@@ -943,7 +879,7 @@ bool Somas::Assign(const session::KernelGraph *graph) { |
|
|
|
MS_LOG(INFO) << "Start Contiguous Gaps Preprocessing"; |
|
|
|
for (auto contiguous_list : contiguous_tensors_list_) { |
|
|
|
if (contiguous_list.size() < 3) { |
|
|
|
MS_LOG(ERROR) << "contiguous_list should has at least one input and two gap, now it is " |
|
|
|
MS_LOG(ERROR) << "contiguous_list should have at least one input and two gap, now it is " |
|
|
|
<< contiguous_list.size(); |
|
|
|
} |
|
|
|
size_t front_gap_id = contiguous_list[0]; |
|
|
|
@@ -959,10 +895,20 @@ bool Somas::Assign(const session::KernelGraph *graph) {
     size_t back_neighbour_id = contiguous_list[contiguous_list.size() - 2];
     for (SomasTensorPtr tensor : tensors_list_) {
       MS_EXCEPTION_IF_NULL(tensor);
-      (*cannot_reuse_)(tensor->GetId(), front_gap_id) = (*cannot_reuse_)(tensor->GetId(), front_neighbour_id);
-      (*cannot_reuse_)(front_gap_id, tensor->GetId()) = (*cannot_reuse_)(front_neighbour_id, tensor->GetId());
-      (*cannot_reuse_)(tensor->GetId(), back_gap_id) = (*cannot_reuse_)(tensor->GetId(), back_neighbour_id);
-      (*cannot_reuse_)(back_gap_id, tensor->GetId()) = (*cannot_reuse_)(back_neighbour_id, tensor->GetId());
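+      // A gap tensor simply pads its neighbour, so it inherits the
+      // neighbour's entire reuse relation, bit for bit.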
|
|
|
+      if (tensor_relation[tensor->GetId()].IsBitTrue(front_neighbour_id) == false) {
+        tensor_relation[tensor->GetId()].SetBitFalse(front_gap_id);
+        tensor_relation[front_gap_id].SetBitFalse(tensor->GetId());
+      } else {
+        tensor_relation[tensor->GetId()].SetBitTrue(front_gap_id);
+        tensor_relation[front_gap_id].SetBitTrue(tensor->GetId());
+      }
+      if (tensor_relation[tensor->GetId()].IsBitTrue(back_neighbour_id) == false) {
+        tensor_relation[tensor->GetId()].SetBitFalse(back_gap_id);
+        tensor_relation[back_gap_id].SetBitFalse(tensor->GetId());
+      } else {
+        tensor_relation[tensor->GetId()].SetBitTrue(back_gap_id);
+        tensor_relation[back_gap_id].SetBitTrue(tensor->GetId());
+      }
     }
     SomasTensorPtr front_neighbour = tensors_map_[front_neighbour_id];
     SomasTensorPtr back_neighbour = tensors_map_[back_neighbour_id];
@@ -995,8 +941,8 @@ bool Somas::Assign(const session::KernelGraph *graph) {
   }
 
   somas_solver_ = std::make_shared<SomasSolverPre>();
-  auto status =
-    somas_solver_->Solving(graph, &solver_tensor_desc_list_, cannot_reuse_, contiguous_tensors_list_removed_ref, false);
+  auto status = somas_solver_->Solving(graph, &solver_tensor_desc_list_, &tensor_relation,
+                                       contiguous_tensors_list_removed_ref, false);
   MS_LOG(INFO) << "End Solving";
   if (status != SUCCESS) {
     GenStatisticInfo();