From c02a707831941dd566dd40bb5b97beaa0bd72e21 Mon Sep 17 00:00:00 2001 From: laiyongqiang Date: Thu, 28 Jan 2021 09:34:19 +0800 Subject: [PATCH] extract functions in SOMAS --- .../ccsrc/backend/optimizer/somas/somas.cc | 475 ++++++++++-------- .../ccsrc/backend/optimizer/somas/somas.h | 10 + 2 files changed, 267 insertions(+), 218 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/somas/somas.cc b/mindspore/ccsrc/backend/optimizer/somas/somas.cc index b21711cdc6..b93c300481 100644 --- a/mindspore/ccsrc/backend/optimizer/somas/somas.cc +++ b/mindspore/ccsrc/backend/optimizer/somas/somas.cc @@ -655,49 +655,7 @@ void Somas::ComputeConflictPairs() { MS_LOG(INFO) << "Start Conflict Computing (Bitset Model)"; auto start_conflict = std::chrono::system_clock::now(); std::sort(nodes_list_.begin(), nodes_list_.end(), NodeSort); - - // Loop to add edges within each stream (node order within stream) - for (const auto &stream : streams_list_) { - auto &nodes = stream->nodes_; - std::sort(nodes.begin(), nodes.end(), NodeSort); - for (size_t i = 1; i < nodes.size(); i++) { - const auto &previous_node = nodes[i - 1]; - const auto ¤t_node = nodes[i]; - current_node->ancestor_nodes_.insert(previous_node); - } - } - - // Loop to add edges from end to beginning of next group - for (const auto &group : streams_groups_) { - for (size_t i = 1; i < group.size(); i++) { - int64_t previous_stream = group[i - 1]; - int64_t current_stream = group[i]; - - auto it = - std::find_if(streams_list_.begin(), streams_list_.end(), - [previous_stream](const SomasStreamPtr &stream) { return stream->GetId() == previous_stream; }); - if (it == streams_list_.end()) { - continue; - } - auto &last_node_in_prev_stream = (*it)->nodes_.back(); - - it = std::find_if(streams_list_.begin(), streams_list_.end(), - [current_stream](const SomasStreamPtr &stream) { return stream->GetId() == current_stream; }); - if (it == streams_list_.end()) { - continue; - } - auto &first_node_in_cur_stream = (*it)->nodes_.front(); - - first_node_in_cur_stream->ancestor_nodes_.insert(last_node_in_prev_stream); - } - } - - // Loop to avoid tensors with empty destinations (add itself) - for (const auto &tensor : tensors_list_) { - if (tensor->destinations_.size() == 0) { - tensor->destinations_.insert(tensor->GetSourceNode()); - } - } + UpdateTensorDestinations(); MS_LOG(INFO) << "Start Bitset"; std::vector nodes_dependency; @@ -757,6 +715,51 @@ void Somas::ComputeConflictPairs() { << std::chrono::duration_cast(end_conflict - start_conflict).count() << "ms)"; } +void Somas::UpdateTensorDestinations() { + // Loop to add edges within each stream (node order within stream) + for (const auto &stream : streams_list_) { + auto &nodes = stream->nodes_; + std::sort(nodes.begin(), nodes.end(), NodeSort); + for (size_t i = 1; i < nodes.size(); i++) { + const auto &previous_node = nodes[i - 1]; + const auto ¤t_node = nodes[i]; + current_node->ancestor_nodes_.insert(previous_node); + } + } + + // Loop to add edges from end to beginning of next group + for (const auto &group : streams_groups_) { + for (size_t i = 1; i < group.size(); i++) { + int64_t previous_stream = group[i - 1]; + int64_t current_stream = group[i]; + + auto it = + std::find_if(streams_list_.begin(), streams_list_.end(), + [previous_stream](const SomasStreamPtr &stream) { return stream->GetId() == previous_stream; }); + if (it == streams_list_.end()) { + continue; + } + auto &last_node_in_prev_stream = (*it)->nodes_.back(); + + it = std::find_if(streams_list_.begin(), streams_list_.end(), + [current_stream](const SomasStreamPtr &stream) { return stream->GetId() == current_stream; }); + if (it == streams_list_.end()) { + continue; + } + auto &first_node_in_cur_stream = (*it)->nodes_.front(); + + first_node_in_cur_stream->ancestor_nodes_.insert(last_node_in_prev_stream); + } + } + + // Loop to avoid tensors with empty destinations (add itself) + for (const auto &tensor : tensors_list_) { + if (tensor->destinations_.size() == 0) { + tensor->destinations_.insert(tensor->GetSourceNode()); + } + } +} + void Somas::ComputeMultiTensorConflicts(const std::vector &calc_tensors_list, const std::vector &all_tensors_list, const vector &nodes_dependency, @@ -832,49 +835,81 @@ bool Somas::Assign(const session::KernelGraph *graph) { } // Ref Node Preprocessing - MS_LOG(INFO) << "Start Solving Preprocessing for Ref Node"; - std::map contiguous_ref_map; - for (auto ref_node_list : ref_node_constraints_) { - // Count contiguous tensors in ref list - size_t contiguous_in_ref_list = std::count_if(ref_node_list.begin(), ref_node_list.end(), - [this](size_t tid) { return tensors_map_[tid]->contiguous_; }); - // Keep all constraints for first tensor in list - size_t tid_0 = ref_node_list[0]; - for (SomasTensorPtr tensor : tensors_list_) { - if (reuse_matrix_[tid_0].IsBitTrue(tensor->GetId()) == false) { - continue; - } - for (size_t tid : ref_node_list) { - if (reuse_matrix_[tid].IsBitTrue(tensor->GetId()) == false) { - reuse_matrix_[tid_0].SetBitFalse(tensor->GetId()); - reuse_matrix_[tensor->GetId()].SetBitFalse(tid_0); - break; - } - } - } - // Set rest to size 0, so that solver ignores them (if not contiguous) - for (size_t i = 1; i < ref_node_list.size(); ++i) { - if (!tensors_map_[ref_node_list[i]]->contiguous_) { - tensors_map_[ref_node_list[i]]->aligned_size_ = 0; - } - } - // Keep info about contiguous and check for errors - if (ref_node_list.size() > 2 && contiguous_in_ref_list > 0) { - MS_LOG(WARNING) << "Ref node of size greater than two with at least one contiguous tensor in"; - } - if (ref_node_list.size() == 2 && contiguous_in_ref_list == 1) { - MS_LOG(WARNING) << "Ref node of size two with only one contiguous tensor" << ref_node_list[0] << ":" - << tensors_map_[ref_node_list[0]]->contiguous_ << ", " << ref_node_list[1] << ":" - << tensors_map_[ref_node_list[1]]->contiguous_; + UpdateRefTensorsConflict(); + std::map contiguous_list_with_ref_index_map = GetContiguousListContainRefTensor(); + vector> contiguous_tensors_list_removed_ref = contiguous_tensors_list_; + std::set> contiguous_tensors_list_to_remove; + for (auto ref_list_pair : contiguous_list_with_ref_index_map) { + contiguous_tensors_list_to_remove.insert(contiguous_tensors_list_[ref_list_pair.second]); + } + + for (auto contiguous_list : contiguous_tensors_list_to_remove) { + auto iterator = std::find(contiguous_tensors_list_removed_ref.begin(), contiguous_tensors_list_removed_ref.end(), + contiguous_list); + if (iterator != contiguous_tensors_list_removed_ref.end()) { + contiguous_tensors_list_removed_ref.erase(iterator); + } else { + MS_LOG(WARNING) << "Could not find contiguous list to remove for ref"; } - if (ref_node_list.size() == 2 && contiguous_in_ref_list == 2) { - contiguous_ref_map[ref_node_list[0]] = ref_node_list[1]; + } + MS_LOG(INFO) << "End Solving Preprocessing for Ref Node"; + UpdateRefOverlapTensorsConflicts(); + +#ifdef SOMAS_DEBUG + // Compute number of constraints for each tensor + auto tensors_num = tensors_list_.size(); + for (auto tensor1 : tensors_list_) { + auto ones_num = reuse_matrix_[tensor1->GetId()].CountOnesNum(); + tensor1->num_constraints_ = tensors_num - ones_num; + } +#endif + + // Prepare solver info + MS_LOG(INFO) << "Start Loop to create solver info"; + for (auto tensor : tensors_list_) { + if (tensor->GetSolverTensorDesc() != nullptr) { + SomasSolverTensorDescPtr pSolverTensor = tensor->GetSolverTensorDesc(); + solver_tensor_desc_list_.insert( + std::pair(pSolverTensor->index_, pSolverTensor)); } } - // Handle contiguous ref node (remove ref from contiguous_tensors_list_) - std::map contiguous_ref_list_map; + MS_LOG(INFO) << "End Loop to create solver info"; + + MS_LOG(INFO) << "Start Solving"; + if (solver_tensor_desc_list_.empty()) { + MS_LOG(INFO) << "solver_tensor_desc_list is empty."; + return true; + } + + somas_solver_ = std::make_shared(); + auto status = somas_solver_->Solving(graph, &solver_tensor_desc_list_, &reuse_matrix_, + contiguous_tensors_list_removed_ref, false); + MS_LOG(INFO) << "End Solving"; + if (status != SUCCESS) { + GenGraphStatisticInfo(); + MS_LOG(EXCEPTION) << "SOMAS Solving Failed."; + } + + // Update solver_tensor_desc offset to tensors list + for (const auto &tensor : tensors_list_) { + tensor->SetOffset(); + } + + UpdateRefTensorsOffset(); + UpdateContiguousTensorsOffset(contiguous_list_with_ref_index_map); + + // Set mem_offset_ value by solver result + mem_offset_ = static_cast(somas_solver_->GetMaxOffset()); + + return true; +} + +std::map Somas::GetContiguousListContainRefTensor() { + // key: contiguous list index with ref node input; value: contiguous list index with ref node output + std::map contiguous_list_with_ref_index_map; + std::map ref_tensors_in_contiguous_map = GetRefTensorsInContiguousList(); std::map>> contiguous_ref_list_error_check_map; - for (auto ref_pair : contiguous_ref_map) { + for (auto ref_pair : ref_tensors_in_contiguous_map) { size_t ref_first = ref_pair.first; size_t ref_second = ref_pair.second; bool found_first = false; @@ -903,15 +938,16 @@ bool Somas::Assign(const session::KernelGraph *graph) { } } } + if (!found_first) { MS_LOG(WARNING) << "Contiguous ref tensor " << ref_first << " not found in any contiguous list"; } if (!found_second) { MS_LOG(WARNING) << "Contiguous ref tensor " << ref_second << " not found in any contiguous list"; } - if (contiguous_ref_list_map.find(index_first) == contiguous_ref_list_map.end() || - contiguous_ref_list_map[index_first] == index_second) { - contiguous_ref_list_map[index_first] = index_second; + if (contiguous_list_with_ref_index_map.find(index_first) == contiguous_list_with_ref_index_map.end() || + contiguous_list_with_ref_index_map[index_first] == index_second) { + contiguous_list_with_ref_index_map[index_first] = index_second; // Checking for error cases if (index_in_list_first != index_in_list_second) { MS_LOG(WARNING) << "Inconsistency in contiguous ref: tensor " << ref_first << " in position " @@ -919,10 +955,9 @@ bool Somas::Assign(const session::KernelGraph *graph) { << " in position " << index_in_list_second << " of contiguous list " << index_second; } contiguous_ref_list_error_check_map[index_first][index_second].insert(index_in_list_first); - } else { // contiguous_ref_list_map.find(index_first) != contiguous_ref_list_map.end() && - // contiguous_ref_list_map[index_first] != index_second + } else { MS_LOG(WARNING) << "Contiguous list " << index_first << " associated (ref node) with two other contiguous lists: " - << contiguous_ref_list_map[index_first] << " and " << index_second; + << contiguous_list_with_ref_index_map[index_first] << " and " << index_second; } } @@ -943,23 +978,61 @@ bool Somas::Assign(const session::KernelGraph *graph) { } } } + return contiguous_list_with_ref_index_map; +} - std::set> contiguous_tensors_list_to_remove; +std::map Somas::GetRefTensorsInContiguousList() { + // key: refnode input value: refnode output + std::map ref_tensors_in_contiguous_map; + for (auto ref_node_list : ref_node_constraints_) { + // Count contiguous tensors in ref list + size_t contiguous_in_ref_list = std::count_if(ref_node_list.begin(), ref_node_list.end(), + [this](size_t tid) { return tensors_map_[tid]->contiguous_; }); + // Keep info about contiguous and check for errors + if (ref_node_list.size() > 2 && contiguous_in_ref_list > 0) { + MS_LOG(WARNING) << "Ref node of size greater than two with at least one contiguous tensor in"; + } + if (ref_node_list.size() == 2 && contiguous_in_ref_list == 1) { + MS_LOG(WARNING) << "Ref node of size two with only one contiguous tensor" << ref_node_list[0] << ":" + << tensors_map_[ref_node_list[0]]->contiguous_ << ", " << ref_node_list[1] << ":" + << tensors_map_[ref_node_list[1]]->contiguous_; + } + if (ref_node_list.size() == 2 && contiguous_in_ref_list == 2) { + ref_tensors_in_contiguous_map[ref_node_list[0]] = ref_node_list[1]; + } + } + return ref_tensors_in_contiguous_map; +} + +void Somas::UpdateContiguousTensorsOffset(const std::map &contiguous_ref_list_map) { + // Handle contiguous ref node for (auto ref_list_pair : contiguous_ref_list_map) { - contiguous_tensors_list_to_remove.insert(contiguous_tensors_list_[ref_list_pair.second]); + size_t index_first = ref_list_pair.first; + size_t index_second = ref_list_pair.second; + for (size_t x = 0; x < contiguous_tensors_list_[index_second].size(); x++) { + tensors_map_[contiguous_tensors_list_[index_second][x]]->offset_ = + tensors_map_[contiguous_tensors_list_[index_first][x]]->offset_; + } } - vector> contiguous_tensors_list_removed_ref = contiguous_tensors_list_; - for (auto contiguous_list : contiguous_tensors_list_to_remove) { - auto iterator = std::find(contiguous_tensors_list_removed_ref.begin(), contiguous_tensors_list_removed_ref.end(), - contiguous_list); - if (iterator != contiguous_tensors_list_removed_ref.end()) { - contiguous_tensors_list_removed_ref.erase(iterator); - } else { - MS_LOG(WARNING) << "Could not find contiguous list to remove for ref"; + + // Contiguous gaps postprocessing + for (auto list : contiguous_tensors_list_) { + tensors_map_[list[0]]->offset_ += kGapSize; + } +} + +void Somas::UpdateRefTensorsOffset() { + // Ref Node Postprocessing + MS_LOG(INFO) << "\nStart Solving Postprocessing for Ref Node"; + // Set offset for rest of ref node list (ignored by solver due to ref node preprocessing) + for (auto ref_node_list : ref_node_constraints_) { + for (size_t i = 1; i < ref_node_list.size(); ++i) { + tensors_map_[ref_node_list[i]]->offset_ = tensors_map_[ref_node_list[0]]->offset_; } } - MS_LOG(INFO) << "End Solving Preprocessing for Ref Node"; +} +void Somas::UpdateRefOverlapTensorsConflicts() { // Ref Overlap Preprocessing MS_LOG(INFO) << "Start Solving Preprocessing for Ref Overlap"; // In ConflictComputing(), by use of ref_overlap_ flag, each tensor in a ref_overlap_list has all entries 1 in @@ -973,75 +1046,31 @@ bool Somas::Assign(const session::KernelGraph *graph) { } } MS_LOG(INFO) << "End Solving Preprocessing for Ref Overlap"; +} -#ifdef SOMAS_DEBUG - // Compute number of constraints for each tensor - auto tensors_num = tensors_list_.size(); - for (auto tensor1 : tensors_list_) { - auto ones_num = reuse_matrix_[tensor1->GetId()].CountOnesNum(); - tensor1->num_constraints_ = tensors_num - ones_num; - } -#endif - - // Prepare solver info - MS_LOG(INFO) << "Start Loop to create solver info"; - for (auto tensor : tensors_list_) { - if (tensor->GetSolverTensorDesc() != nullptr) { - SomasSolverTensorDescPtr pSolverTensor = tensor->GetSolverTensorDesc(); - solver_tensor_desc_list_.insert( - std::pair(pSolverTensor->index_, pSolverTensor)); - } - } - MS_LOG(INFO) << "End Loop to create solver info"; - - MS_LOG(INFO) << "Start Solving"; - if (solver_tensor_desc_list_.empty()) { - MS_LOG(INFO) << "solver_tensor_desc_list is empty."; - return true; - } - - somas_solver_ = std::make_shared(); - auto status = somas_solver_->Solving(graph, &solver_tensor_desc_list_, &reuse_matrix_, - contiguous_tensors_list_removed_ref, false); - MS_LOG(INFO) << "End Solving"; - if (status != SUCCESS) { - GenGraphStatisticInfo(); - MS_LOG(EXCEPTION) << "SOMAS Solving Failed."; - } - - // Update solver_tensor_desc offset to tensors list - for (const auto &tensor : tensors_list_) { - tensor->SetOffset(); - } - - // Ref Node Postprocessing - MS_LOG(INFO) << "\nStart Solving Postprocessing for Ref Node"; - // Set offset for rest of ref node list (ignored by solver due to ref node preprocessing) +void Somas::UpdateRefTensorsConflict() { + // Keep all constraints for first tensor in list for (auto ref_node_list : ref_node_constraints_) { - for (size_t i = 1; i < ref_node_list.size(); ++i) { - tensors_map_[ref_node_list[i]]->offset_ = tensors_map_[ref_node_list[0]]->offset_; + size_t tid_0 = ref_node_list[0]; + for (SomasTensorPtr tensor : tensors_list_) { + if (reuse_matrix_[tid_0].IsBitTrue(tensor->GetId()) == false) { + continue; + } + for (size_t tid : ref_node_list) { + if (reuse_matrix_[tid].IsBitTrue(tensor->GetId()) == false) { + reuse_matrix_[tid_0].SetBitFalse(tensor->GetId()); + reuse_matrix_[tensor->GetId()].SetBitFalse(tid_0); + break; + } + } } - } - // Handle contiguous ref node - for (auto ref_list_pair : contiguous_ref_list_map) { - size_t index_first = ref_list_pair.first; - size_t index_second = ref_list_pair.second; - for (size_t x = 0; x < contiguous_tensors_list_[index_second].size(); x++) { - tensors_map_[contiguous_tensors_list_[index_second][x]]->offset_ = - tensors_map_[contiguous_tensors_list_[index_first][x]]->offset_; + // Set rest to size 0, so that solver ignores them (if not contiguous) + for (size_t i = 1; i < ref_node_list.size(); ++i) { + if (!tensors_map_[ref_node_list[i]]->contiguous_) { + tensors_map_[ref_node_list[i]]->aligned_size_ = 0; + } } } - MS_LOG(INFO) << "\nEnd Solving Postprocessing for Ref Node"; - - // Contiguous gaps postprocessing - for (auto list : contiguous_tensors_list_) { - tensors_map_[list[0]]->offset_ += kGapSize; - } - - // Set mem_offset_ value by solver result - mem_offset_ = static_cast(somas_solver_->GetMaxOffset()); - - return true; } std::string Somas::GetSplitName(const std::string &scope_name) const { @@ -1076,56 +1105,37 @@ void Somas::DumpSomasInfoIR(const string filename) { return; } - ofs << "All Parameters:\n\n"; - ofs << "index:" - << "\tsize:" - << "\tstart_addr:" - << "\tsource node name:" - << "\tnode out index:\n"; + DumpParameters(ofs); + DumpTensors(ofs); + DumpNodes(ofs); - for (const auto ¶m : parameters_list_) { - ofs << "%" << param->id_ << "P" - << "\t" - << "#" << param->size_ << "S" - << "\t" - << "&" << param->addr_ << "\t" << param->source_node_->fullname_with_scope() << "\t" << param->output_index_ - << "\n"; + ofs << "\n\nAll Stream Groups:\n\n"; + for (const auto &stream_group : streams_groups_) { + for (const auto &stream : stream_group) { + ofs << "stm" << stream << " "; + } + ofs << "\n"; } - ofs << "\n\nAll Tensors:\n\n"; - ofs << "index:" - << "\tsize:" - << "\treal_size:" - << "\toffset:" - << "\taddr:" - << "\ttype:" - << "\tlifelong:" - << "\tlife_start:" - << "\tlife_end:" - << "\tsource node name:\n"; - - for (const auto &tensor : tensors_list_) { - auto scope_name = tensor->GetSourceNode()->scope_full_name_; - std::string split_name = GetSplitName(scope_name); - ofs << "%" << tensor->GetId() << "T" - << "\t" - << "#" << tensor->GetAlignedSize() << "S" - << "\t" - << "#" << tensor->GetOriginalSize() << "S" - << "\t" - << "&" << tensor->GetOffset() << "" - << "\t" - << "&" << static_cast(tensor->GetOffset() + mem_base_addr_) << "\t" - << tensor_type_name_map[tensor->type_] << "\t" << tensor->IsLifelong() << "\t" << tensor->lifetime_.start_ - << "\t" << tensor->lifetime_.end_ << "\t" << split_name << "\n"; + if (!ref_node_constraints_.empty()) { + ofs << "\n\nAll Ref Node Info:\n\n"; + for (const auto &ref_in_out : ref_node_constraints_) { + ofs << "refnode input-output:"; + for (const auto &item : ref_in_out) { + ofs << "%" << item << "T "; + } + ofs << "\n"; + } } +} +void Somas::DumpNodes(std::ofstream &ofs) const { ofs << "\n\nAll Nodes:\n\n"; for (const auto &node : nodes_list_) { auto scope_name = node->scope_full_name_; std::string split_name = GetSplitName(scope_name); ofs << "$" << node->GetId() << "\t" << split_name << "\t" << static_cast(node->GetType()) << "\t"; - vector> input_list; + std::vector> input_list; std::transform( node->input_tensors_.begin(), node->input_tensors_.end(), std::back_inserter(input_list), [](SomasTensorPtr in) -> std::pair { return std::make_pair("Tensor", in->GetId()); }); @@ -1161,24 +1171,53 @@ void Somas::DumpSomasInfoIR(const string filename) { ofs << "\tstreamID[" << "@" << node->GetStream()->GetId() << "]\n"; } +} - ofs << "\n\nAll Stream Groups:\n\n"; - for (const auto &stream_group : streams_groups_) { - for (const auto &stream : stream_group) { - ofs << "stm" << stream << " "; - } - ofs << "\n"; +void Somas::DumpTensors(std::ofstream &ofs) const { + ofs << "\n\nAll Tensors:\n\n"; + ofs << "index:" + << "\tsize:" + << "\treal_size:" + << "\toffset:" + << "\taddr:" + << "\ttype:" + << "\tlifelong:" + << "\tlife_start:" + << "\tlife_end:" + << "\tsource node name:\n"; + + for (const auto &tensor : tensors_list_) { + auto scope_name = tensor->GetSourceNode()->scope_full_name_; + std::string split_name = GetSplitName(scope_name); + ofs << "%" << tensor->GetId() << "T" + << "\t" + << "#" << tensor->GetAlignedSize() << "S" + << "\t" + << "#" << tensor->GetOriginalSize() << "S" + << "\t" + << "&" << tensor->GetOffset() << "" + << "\t" + << "&" << static_cast(tensor->GetOffset() + mem_base_addr_) << "\t" + << tensor_type_name_map[tensor->type_] << "\t" << tensor->IsLifelong() << "\t" << tensor->lifetime_.start_ + << "\t" << tensor->lifetime_.end_ << "\t" << split_name << "\n"; } +} - if (!ref_node_constraints_.empty()) { - ofs << "\n\nAll Ref Node Info:\n\n"; - for (const auto &ref_in_out : ref_node_constraints_) { - ofs << "refnode input-output:"; - for (const auto &item : ref_in_out) { - ofs << "%" << item << "T "; - } - ofs << "\n"; - } +void Somas::DumpParameters(std::ofstream &ofs) const { + ofs << "All Parameters:\n\n"; + ofs << "index:" + << "\tsize:" + << "\tstart_addr:" + << "\tsource node name:" + << "\tnode out index:\n"; + + for (const auto ¶m : parameters_list_) { + ofs << "%" << param->id_ << "P" + << "\t" + << "#" << param->size_ << "S" + << "\t" + << "&" << param->addr_ << "\t" << param->source_node_->fullname_with_scope() << "\t" << param->output_index_ + << "\n"; } } diff --git a/mindspore/ccsrc/backend/optimizer/somas/somas.h b/mindspore/ccsrc/backend/optimizer/somas/somas.h index 4cf54c41c8..7e7269e298 100644 --- a/mindspore/ccsrc/backend/optimizer/somas/somas.h +++ b/mindspore/ccsrc/backend/optimizer/somas/somas.h @@ -138,6 +138,16 @@ class Somas { const std::vector &all_tensors_list, const vector &nodes_dependency, std::vector *tensor_relation) const; + void UpdateTensorDestinations(); + void UpdateRefTensorsConflict(); + void UpdateRefOverlapTensorsConflicts(); + void UpdateRefTensorsOffset(); + void UpdateContiguousTensorsOffset(const std::map &contiguous_ref_list_map); + void DumpParameters(std::ofstream &ofs) const; + void DumpTensors(std::ofstream &ofs) const; + void DumpNodes(std::ofstream &ofs) const; + std::map GetContiguousListContainRefTensor(); + std::map GetRefTensorsInContiguousList(); }; using SomasPtr = std::shared_ptr;