|
|
|
@@ -34,10 +34,12 @@ |
|
|
|
#include "backend/optimizer/common/helper.h" |
|
|
|
#include "utils/ms_context.h" |
|
|
|
#include "debug/common.h" |
|
|
|
#include "common/thread_pool.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace somas { |
|
|
|
constexpr auto kGapSize = 512; |
|
|
|
// Tensor count at or above which tensor-relation computing is split into thread-pool tasks
// instead of running as a single call. Typed as size_t because it is compared against
// tensors_list_.size().
constexpr size_t kParallelComputeSizeThreshold = 2000;
|
|
|
std::map<TensorType, std::string> tensor_type_name_map = {{kCommon, "Common"}, |
|
|
|
{kOutputOnly, "OutputOnly"}, |
|
|
|
{kWorkspace, "Workspace"}, |
|
|
|
@@ -641,7 +643,7 @@ void Somas::ComputeConflictPairs() { |
|
|
|
MS_LOG(INFO) << "End Preprocessing Conflicts"; |
|
|
|
|
|
|
|
MS_LOG(INFO) << "Start Conflict Computing (Bitset Model)"; |
|
|
|
|
|
|
|
auto start_conflict = std::chrono::system_clock::now(); |
|
|
|
std::sort(nodes_list_.begin(), nodes_list_.end(), NodeSort); |
|
|
|
|
|
|
|
// Loop to add edges within each stream (node order within stream) |
|
|
|
@@ -708,76 +710,107 @@ void Somas::ComputeConflictPairs() { |
|
|
|
MS_LOG(INFO) << "Start Tensor Relation Computing"; |
|
|
|
count = tensors_list_.back()->GetId() + 1; |
|
|
|
for (size_t i = 0; i < count; i++) { |
|
|
|
tensor_relation.emplace_back(count); |
|
|
|
reuse_matrix_.emplace_back(count); |
|
|
|
} |
|
|
|
|
|
|
|
if (tensors_list_.size() < kParallelComputeSizeThreshold) { |
|
|
|
ComputeMultiTensorConflicts(tensors_list_, tensors_list_, nodes_dependency, &reuse_matrix_); |
|
|
|
} else { |
|
|
|
MS_LOG(INFO) << "Tensor Num " << tensors_list_.size() << " is larger than " << kParallelComputeSizeThreshold; |
|
|
|
MS_LOG(INFO) << "Enter Multi-Thread Mode..."; |
|
|
|
size_t process_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); |
|
|
|
MS_LOG(INFO) << "Threads Num is " << process_num; |
|
|
|
|
|
|
|
size_t start_index = 0; |
|
|
|
size_t total_size = tensors_list_.size(); |
|
|
|
size_t job_size = total_size / process_num; |
|
|
|
if (job_size == 0) { |
|
|
|
job_size = total_size; |
|
|
|
} |
|
|
|
std::vector<common::Task> tasks; |
|
|
|
while (start_index < total_size) { |
|
|
|
size_t end_index = (start_index + job_size) > total_size ? total_size : start_index + job_size; |
|
|
|
auto jobs = std::vector<SomasTensorPtr>(tensors_list_.begin() + start_index, tensors_list_.begin() + end_index); |
|
|
|
auto task = [this, jobs, &nodes_dependency]() { |
|
|
|
this->ComputeMultiTensorConflicts(jobs, tensors_list_, nodes_dependency, &reuse_matrix_); |
|
|
|
return common::SUCCESS; |
|
|
|
}; |
|
|
|
tasks.emplace_back(task); |
|
|
|
start_index += job_size; |
|
|
|
} |
|
|
|
|
|
|
|
common::ThreadPool::GetInstance().SyncRun(tasks); |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "End Tensor Relation Computing"; |
|
|
|
auto end_conflict = std::chrono::system_clock::now(); |
|
|
|
MS_LOG(INFO) << "End Conflict Computing (Bitset Model)(time taken " |
|
|
|
<< std::chrono::duration_cast<std::chrono::milliseconds>(end_conflict - start_conflict).count() << "ms)"; |
|
|
|
} |
|
|
|
|
|
|
|
for (size_t i = 0; i < tensors_list_.size(); i++) { |
|
|
|
auto t0 = tensors_list_[i]; |
|
|
|
if (t0->IsLifelong() || t0->IsRefOverlap() || t0->GetAlignedSize() == 0) { |
|
|
|
void Somas::ComputeMultiTensorConflicts(const std::vector<SomasTensorPtr> &calc_tensors_list, |
|
|
|
const std::vector<SomasTensorPtr> &all_tensors_list, |
|
|
|
const vector<DynamicBitSet> &nodes_dependency, |
|
|
|
std::vector<DynamicBitSet> *tensor_relation) const { |
|
|
|
auto start = std::chrono::system_clock::now(); |
|
|
|
MS_LOG(INFO) << "Start Computing Conflicts Pairs, tensors list size is " << calc_tensors_list.size(); |
|
|
|
for (size_t i = 0; i < calc_tensors_list.size(); i++) { |
|
|
|
auto calc_tensor = calc_tensors_list[i]; |
|
|
|
if (calc_tensor->IsLifelong() || calc_tensor->IsRefOverlap() || calc_tensor->GetAlignedSize() == 0) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
size_t t0_src_node = t0->GetSourceNode()->GetId(); |
|
|
|
for (size_t j = i + 1; j < tensors_list_.size(); j++) { |
|
|
|
auto t1 = tensors_list_[j]; |
|
|
|
|
|
|
|
if (t0 == t1 || t1->IsLifelong() || t1->IsRefOverlap() || t1->GetAlignedSize() == 0) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
size_t t1_src_node = t1->GetSourceNode()->GetId(); |
|
|
|
if (t0_src_node == t1_src_node) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
ComputeOneTensorConflicts(calc_tensor, all_tensors_list, nodes_dependency, tensor_relation); |
|
|
|
} |
|
|
|
auto end = std::chrono::system_clock::now(); |
|
|
|
MS_LOG(INFO) << "End Computing Conflicts Pairs (time taken " |
|
|
|
<< std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms)"; |
|
|
|
} |
|
|
|
|
|
|
|
bool reuse = true; |
|
|
|
bool all_dst_depend = false; |
|
|
|
// check t0's all consumers is t1's source node's dependency or not |
|
|
|
for (const auto &dst_node : t0->destinations_) { |
|
|
|
if (nodes_dependency[t1_src_node].IsBitTrue(dst_node->GetId()) == false) { |
|
|
|
// t0's consumer is not in t1's source node's dependency, not sure this consumer is done or not when t1 |
|
|
|
// produced |
|
|
|
reuse = false; |
|
|
|
all_dst_depend = false; |
|
|
|
break; |
|
|
|
} else if (t1_src_node == dst_node->GetId()) { |
|
|
|
// t0 is t1's source node's input, can't reuse |
|
|
|
reuse = false; |
|
|
|
all_dst_depend = true; |
|
|
|
break; |
|
|
|
} else { |
|
|
|
// t0's consumer is in t1's source node's dependency, this consumer is done when t1 produced |
|
|
|
reuse = true; |
|
|
|
all_dst_depend = true; |
|
|
|
} |
|
|
|
} |
|
|
|
void Somas::ComputeOneTensorConflicts(const std::shared_ptr<SomasTensor> &calc_tensor, |
|
|
|
const std::vector<SomasTensorPtr> &all_tensors_list, |
|
|
|
const vector<DynamicBitSet> &nodes_dependency, |
|
|
|
std::vector<DynamicBitSet> *tensor_relation) const { |
|
|
|
for (size_t j = 0; j < all_tensors_list.size(); j++) { |
|
|
|
auto target_tensor = all_tensors_list[j]; |
|
|
|
if (calc_tensor == target_tensor || target_tensor->IsLifelong() || target_tensor->IsRefOverlap() || |
|
|
|
target_tensor->GetAlignedSize() == 0) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
size_t calc_src_node = calc_tensor->GetSourceNode()->GetId(); |
|
|
|
size_t target_src_node = target_tensor->GetSourceNode()->GetId(); |
|
|
|
if (calc_src_node == target_src_node) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if ((*tensor_relation)[calc_tensor->GetId()].IsBitTrue(target_tensor->GetId()) || |
|
|
|
(*tensor_relation)[target_tensor->GetId()].IsBitTrue(calc_tensor->GetId())) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (all_dst_depend == false) { |
|
|
|
// check t1's all consumers is t0's source node's dependency or not |
|
|
|
bool reuse = true; |
|
|
|
// check calc_tensor's all consumers is target_tensor's source node's dependency or not |
|
|
|
for (const auto &dst_node : calc_tensor->destinations_) { |
|
|
|
if (nodes_dependency[target_src_node].IsBitTrue(dst_node->GetId()) == false) { |
|
|
|
// calc_tensor's consumer is not in target_tensor's source node's dependency, not sure this consumer is done or |
|
|
|
// not when target_tensor produced |
|
|
|
reuse = false; |
|
|
|
break; |
|
|
|
} else if (target_src_node == dst_node->GetId()) { |
|
|
|
// calc_tensor is target_tensor's source node's input, can't reuse |
|
|
|
reuse = false; |
|
|
|
break; |
|
|
|
} else { |
|
|
|
// calc_tensor's consumer is in target_tensor's source node's dependency, this consumer is done when |
|
|
|
// target_tensor produced |
|
|
|
reuse = true; |
|
|
|
for (const auto &dst_node : t1->destinations_) { |
|
|
|
if (nodes_dependency[t0_src_node].IsBitTrue(dst_node->GetId()) == false) { |
|
|
|
reuse = false; |
|
|
|
all_dst_depend = false; |
|
|
|
break; |
|
|
|
} else if (t0_src_node == dst_node->GetId()) { |
|
|
|
reuse = false; |
|
|
|
all_dst_depend = true; |
|
|
|
break; |
|
|
|
} else { |
|
|
|
reuse = true; |
|
|
|
all_dst_depend = true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (all_dst_depend == true && reuse == true) { |
|
|
|
tensor_relation[t0->GetId()].SetBitTrue(t1->GetId()); |
|
|
|
tensor_relation[t1->GetId()].SetBitTrue(t0->GetId()); |
|
|
|
} |
|
|
|
if (reuse) { |
|
|
|
// calc_tensor and target_tensor have dependencies so they can reuse each other |
|
|
|
(*tensor_relation)[calc_tensor->GetId()].SetBitTrue(target_tensor->GetId()); |
|
|
|
(*tensor_relation)[target_tensor->GetId()].SetBitTrue(calc_tensor->GetId()); |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "End Tensor Relation Computing"; |
|
|
|
MS_LOG(INFO) << "End Conflict Computing (Bitset Model)"; |
|
|
|
} |
|
|
|
|
|
|
|
// Ordering predicate: node1 sorts before node2 when node1's id is the smaller one.
bool Somas::NodeSort(SomasNodePtr node1, SomasNodePtr node2) {
  const auto first_id = node1->GetId();
  const auto second_id = node2->GetId();
  return first_id < second_id;
}
|
|
|
@@ -798,13 +831,13 @@ bool Somas::Assign(const session::KernelGraph *graph) { |
|
|
|
// Keep all constraints for first tensor in list |
|
|
|
size_t tid_0 = ref_node_list[0]; |
|
|
|
for (SomasTensorPtr tensor : tensors_list_) { |
|
|
|
if (tensor_relation[tid_0].IsBitTrue(tensor->GetId()) == false) { |
|
|
|
if (reuse_matrix_[tid_0].IsBitTrue(tensor->GetId()) == false) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
for (size_t tid : ref_node_list) { |
|
|
|
if (tensor_relation[tid].IsBitTrue(tensor->GetId()) == false) { |
|
|
|
tensor_relation[tid_0].SetBitFalse(tensor->GetId()); |
|
|
|
tensor_relation[tensor->GetId()].SetBitFalse(tid_0); |
|
|
|
if (reuse_matrix_[tid].IsBitTrue(tensor->GetId()) == false) { |
|
|
|
reuse_matrix_[tid_0].SetBitFalse(tensor->GetId()); |
|
|
|
reuse_matrix_[tensor->GetId()].SetBitFalse(tid_0); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -924,23 +957,21 @@ bool Somas::Assign(const session::KernelGraph *graph) { |
|
|
|
for (auto ref_overlap_list : ref_overlap_constraints_) { |
|
|
|
for (size_t tid_1 : ref_overlap_list) { |
|
|
|
for (size_t tid_2 : ref_overlap_list) { |
|
|
|
tensor_relation[tid_1].SetBitTrue(tid_2); |
|
|
|
tensor_relation[tid_2].SetBitTrue(tid_1); |
|
|
|
reuse_matrix_[tid_1].SetBitTrue(tid_2); |
|
|
|
reuse_matrix_[tid_2].SetBitTrue(tid_1); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "End Solving Preprocessing for Ref Overlap"; |
|
|
|
|
|
|
|
#ifdef SOMAS_DEBUG |
|
|
|
// Compute number of constraints for each tensor |
|
|
|
auto tensors_num = tensors_list_.size(); |
|
|
|
for (auto tensor1 : tensors_list_) { |
|
|
|
size_t count_constraints = 0; |
|
|
|
for (auto tensor2 : tensors_list_) { |
|
|
|
if (tensor_relation[tensor1->GetId()].IsBitTrue(tensor2->GetId()) == false) { |
|
|
|
count_constraints++; |
|
|
|
} |
|
|
|
} |
|
|
|
tensor1->num_constraints_ = count_constraints; |
|
|
|
auto ones_num = reuse_matrix_[tensor1->GetId()].CountOnesNum(); |
|
|
|
tensor1->num_constraints_ = tensors_num - ones_num; |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
// Prepare solver info |
|
|
|
MS_LOG(INFO) << "Start Loop to create solver info"; |
|
|
|
@@ -960,7 +991,7 @@ bool Somas::Assign(const session::KernelGraph *graph) { |
|
|
|
} |
|
|
|
|
|
|
|
somas_solver_ = std::make_shared<SomasSolverPre>(); |
|
|
|
auto status = somas_solver_->Solving(graph, &solver_tensor_desc_list_, &tensor_relation, |
|
|
|
auto status = somas_solver_->Solving(graph, &solver_tensor_desc_list_, &reuse_matrix_, |
|
|
|
contiguous_tensors_list_removed_ref, false); |
|
|
|
MS_LOG(INFO) << "End Solving"; |
|
|
|
if (status != SUCCESS) { |
|
|
|
|