| @@ -47,7 +47,6 @@ using mindspore::profiler::TensorMemory; | |||
| namespace mindspore { | |||
| namespace somas { | |||
| constexpr auto kGapSize = 512; | |||
| constexpr auto kParallelComputeSizeThreshold = 2000; | |||
| constexpr auto kGraphId = "graph_id"; | |||
| constexpr auto kHashId = "hash_id"; | |||
| @@ -1193,21 +1192,20 @@ bool Somas::Assign(const session::KernelGraph *graph) { | |||
| for (auto tensor : tensors_list_) { | |||
| if (tensor->GetSolverTensorDesc() != nullptr) { | |||
| SomasSolverTensorDescPtr pSolverTensor = tensor->GetSolverTensorDesc(); | |||
| solver_tensor_desc_list_.insert( | |||
| std::pair<size_t, SomasSolverTensorDescPtr>(pSolverTensor->index_, pSolverTensor)); | |||
| solver_tensor_desc_map_.insert(std::pair<size_t, SomasSolverTensorDescPtr>(pSolverTensor->index_, pSolverTensor)); | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "End Loop to create solver info"; | |||
| MS_LOG(INFO) << "Start Solving"; | |||
| if (solver_tensor_desc_list_.empty()) { | |||
| if (solver_tensor_desc_map_.empty()) { | |||
| MS_LOG(INFO) << "solver_tensor_desc_list is empty."; | |||
| return true; | |||
| } | |||
| somas_solver_ = std::make_shared<SomasSolverPre>(); | |||
| auto status = somas_solver_->Solving(graph, &solver_tensor_desc_list_, &reuse_matrix_, | |||
| contiguous_tensors_list_removed_ref, false); | |||
| auto status = | |||
| somas_solver_->Solving(graph, &solver_tensor_desc_map_, &reuse_matrix_, contiguous_tensors_list_removed_ref, false); | |||
| MS_LOG(INFO) << "End Solving"; | |||
| if (status != SUCCESS) { | |||
| GenGraphStatisticInfo(); | |||
| @@ -77,7 +77,7 @@ class Somas { | |||
| std::vector<vector<uint32_t>> streams_groups_; | |||
| // Solver | |||
| std::unordered_map<size_t, SomasSolverTensorDescPtr> solver_tensor_desc_list_; | |||
| TensorsDescMap solver_tensor_desc_map_; | |||
| SomasSolverPrePtr somas_solver_; | |||
| // Contiguous list | |||
| @@ -108,14 +108,6 @@ class BlockTensor { | |||
| m_size_ = bt.m_size_; | |||
| return *this; | |||
| } | |||
| void log() { | |||
| SomasSolverTensorDescPtr p = m_start_tensor_; | |||
| MS_LOG(DEBUG) << "Block of Tensors [" << m_start_tensor_->index_ << "]\nsize: " << m_size_ << "Tensors:"; | |||
| while (p) { | |||
| MS_LOG(DEBUG) << "[" << p->index_ << "," << p->size_ << "]"; | |||
| p = p->right_; | |||
| } | |||
| } | |||
| bool Alone() const { return ((NULL == m_start_tensor_->right_) && (NULL == m_start_tensor_->left_)); } | |||
| }; | |||
| @@ -22,7 +22,7 @@ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <map> | |||
| #include "backend/optimizer/somas/somas_solver_alg.h" | |||
| #include "backend/optimizer/somas/somas_solver_core.h" | |||
| #include "backend/optimizer/somas/somas_solver_pre.h" | |||
| @@ -33,6 +33,7 @@ using std::vector; | |||
| namespace mindspore { | |||
| namespace somas { | |||
| Status SomasSolverCore::MemoryAllocationSolver() { | |||
| auto start = std::chrono::system_clock::now(); | |||
| Status retval = SUCCESS; | |||
| @@ -87,21 +88,24 @@ Status SomasSolverCore::MemoryAllocationSolver() { | |||
| MS_LOG(INFO) << "SOMAS SOLVER RESUME:"; | |||
| MS_LOG(INFO) << "Best Solution:[" << 1 + best_sol << "/" << sol_count_ << "] "; | |||
| MS_LOG(INFO) << "Best result:" << best << " Bytes " << (best) / (giga) << " GB (" | |||
| << (best - lifelongmemory_) / (giga) << " GB + " << lifelongmemory_ / (giga) | |||
| << (best - lifelong_memory_) / (giga) << " GB + " << lifelong_memory_ / (giga) | |||
| << " GB from lifelong tensors)"; | |||
| MS_LOG(INFO) << "Best timing:" << best_timing << " ms"; | |||
| MS_LOG(INFO) << "Best algorithm: " << algorithm_type_[best_algorithm].c_str(); | |||
| MS_LOG(INFO) << "Best sorting strategy: " << sorting_[best_sorting].c_str(); | |||
| MS_LOG(INFO) << "Best offset strategy: " << branching_[best_branching].c_str(); | |||
| MS_LOG(INFO) << "Best algorithm: " << algorithmTypeNames[best_algorithm]; | |||
| MS_LOG(INFO) << "Best sorting strategy: " << sortingNames[best_sorting]; | |||
| MS_LOG(INFO) << "Best offset strategy: " << branchingNames[best_branching]; | |||
| MS_LOG(INFO) << "Time elapsed: " << total_time << " ms"; | |||
| MS_LOG(INFO) << "Spread:" << static_cast<double>((worst - best) / static_cast<double>(best * cent)) << " %%"; | |||
| best_sol_ = best_sol; | |||
| SetBestSolution(); | |||
| } else { | |||
| MS_LOG(INFO) << "Algorithm strategy: " << algorithm_type_[algorithm_].c_str(); | |||
| MS_LOG(INFO) << "Sorting strategy: " << sorting_[sort_strategy_].c_str(); | |||
| MS_LOG(INFO) << "Offset strategy: " << branching_[branching_strategy_].c_str(); | |||
| // print only for single heuristic no multi thread | |||
| if (!is_multi_thread_valid_) { | |||
| MS_LOG(INFO) << "Algorithm strategy: " << algorithmTypeNames[algorithm_]; | |||
| MS_LOG(INFO) << "Sorting strategy: " << sortingNames[sort_strategy_]; | |||
| MS_LOG(INFO) << "Offset strategy: " << branchingNames[branching_strategy_]; | |||
| } | |||
| BuildBlocks(); | |||
| SortTensors(); | |||
| upperbound_ = FindSolutions(); | |||
| @@ -167,7 +171,7 @@ bool SomasSolverCore::Verify(const size_t &upperbound) { | |||
| } | |||
| if (upperbound != result) { | |||
| MS_LOG(WARNING) << "ERROR Invalid upperbound result --> Footprint Result: " << upperbound_ | |||
| << " Tensor Result: " << result + lifelongmemory_; | |||
| << " Tensor Result: " << result + lifelong_memory_; | |||
| retval = false; | |||
| } | |||
| MS_LOG(DEBUG) | |||
| @@ -179,13 +183,13 @@ bool SomasSolverCore::Verify(const size_t &upperbound) { | |||
| void SomasSolverCore::BuildBlocks() { | |||
| MS_LOG(DEBUG) << "Building block of tensors"; | |||
| lifelongmemory_ = 0; | |||
| lifelong_memory_ = 0; | |||
| uint64_t tensors_block_count = 0; | |||
| for (auto tensor : tensors_) { | |||
| SomasSolverTensorDescPtr pTensor = tensor.second; | |||
| if (pTensor->blocked_) continue; | |||
| if (pTensor->lifelong_) { | |||
| lifelongmemory_ += pTensor->size_; | |||
| lifelong_memory_ += pTensor->size_; | |||
| continue; | |||
| } | |||
| // move to the left | |||
| @@ -211,9 +215,6 @@ void SomasSolverCore::BuildBlocks() { | |||
| if (tensors_block_count != tensors_.size()) | |||
| MS_LOG(INFO) << static_cast<int>(tensors_.size() - tensors_block_count) << " lifelong tensors found"; | |||
| // for debug | |||
| for (auto &b : block_tensors_) b.log(); | |||
| } | |||
| void SomasSolverCore::Clean() { | |||
| @@ -264,7 +265,7 @@ static bool GreaterSizeGreaterConstraintsGreaterIndex(const BlockTensor &t1, con | |||
| #endif | |||
| void SomasSolverCore::SortTensors() { // need to sort the tensors for Fast Heuristic | |||
| MS_LOG(DEBUG) << "Sorting Blocks of tensor, strategy: " << sorting_[sort_strategy_].c_str(); | |||
| MS_LOG(DEBUG) << "Sorting Blocks of tensor, strategy: " << sortingNames[sort_strategy_]; | |||
| typedef bool (*SortingFunction)(const BlockTensor &, const BlockTensor &); | |||
| std::unordered_map<SortingType, SortingFunction> sort_map; | |||
| sort_map[kGreaterSizeSmallerIndex] = &GreaterSizeSmallerIndex; | |||
| @@ -278,8 +279,6 @@ void SomasSolverCore::SortTensors() { // need to sort the tensors for Fast Heur | |||
| if (sort_strategy_ < kNumSortingTypes) { | |||
| sort(block_tensors_.begin(), block_tensors_.end(), *(sort_map[sort_strategy_])); | |||
| } | |||
| // log for debug purposes | |||
| for (auto &block : block_tensors_) block.log(); | |||
| } | |||
| void SomasSolverCore::RestoreSolution(uint32_t sol_id) { | |||
| @@ -305,12 +304,13 @@ size_t SomasSolverCore::Search(const std::shared_ptr<FootPrint> &pFootprint) { | |||
| result = pFootprint->Result(); | |||
| auto end = std::chrono::system_clock::now(); | |||
| timing_ = std::chrono::duration_cast<std::chrono::milliseconds>((end - start)).count(); | |||
| if (all_) { | |||
| // print for serial all_ or multi thread solver | |||
| if (all_ || is_multi_thread_valid_) { | |||
| const double giga = 1073741824.; | |||
| MS_LOG(INFO) << timing_ << " ms\t" << sol_count_ + 1 << "/" | |||
| << kNumFittingTypes * kNumAlgorithmTypes * kNumSortingTypes << "\t" << result << " Bytes (" | |||
| << result / giga << " GB)\t" << algorithm_type_[algorithm_].c_str() << "\t" | |||
| << sorting_[sort_strategy_].c_str() << "\t" << branching_[branching_strategy_].c_str(); | |||
| << result / giga << " GB)\t" << algorithmTypeNames[algorithm_] << "\t" | |||
| << sortingNames[sort_strategy_] << "\t" << branchingNames[branching_strategy_]; | |||
| } | |||
| } else { | |||
| MS_LOG(INFO) << "FastSolver could not find solution"; | |||
| @@ -319,8 +319,6 @@ size_t SomasSolverCore::Search(const std::shared_ptr<FootPrint> &pFootprint) { | |||
| if (result < upperbound_) { | |||
| upperbound_ = result; | |||
| best_sol_ = pFootprint->m_solId_; | |||
| best_branching_ = branching_strategy_; | |||
| best_sort_ = sort_strategy_; | |||
| } | |||
| return upperbound_; | |||
| @@ -329,19 +327,23 @@ size_t SomasSolverCore::Search(const std::shared_ptr<FootPrint> &pFootprint) { | |||
| void SomasSolverCore::AppendLifelongTensors() { | |||
| MS_LOG(DEBUG) << "Appending lifelong tensors to solution"; | |||
| size_t offset = upperbound_; | |||
| std::map<size_t, SomasSolverTensorDescPtr> lifelongTensors; | |||
| for (auto t_ : tensors_) { | |||
| SomasSolverTensorDescPtr pTensor = t_.second; | |||
| if (pTensor->lifelong_) { | |||
| pTensor->offset_ = offset; | |||
| offset += pTensor->size_; | |||
| if (t_.second->lifelong_) { | |||
| lifelongTensors.insert(t_); | |||
| } | |||
| } | |||
| upperbound_ += lifelongmemory_; | |||
| MS_LOG(DEBUG) << lifelongmemory_ << " bytes from lifelong tensors added to solution"; | |||
| for (auto t_ : lifelongTensors) { | |||
| SomasSolverTensorDescPtr pTensor = t_.second; | |||
| pTensor->offset_ = offset; | |||
| offset += pTensor->size_; | |||
| } | |||
| upperbound_ += lifelong_memory_; | |||
| MS_LOG(DEBUG) << lifelong_memory_ << " bytes from lifelong tensors added to solution"; | |||
| } | |||
| size_t SomasSolverCore::FindSolutions() { | |||
| MS_LOG(DEBUG) << "Start allocating blocks,offset strategy: " << branching_[branching_strategy_].c_str(); | |||
| MS_LOG(DEBUG) << "Start allocating blocks,offset strategy: " << branchingNames[branching_strategy_]; | |||
| std::shared_ptr<FootPrint> pFootprint = std::make_shared<FootPrint>(); | |||
| pFootprint->setBranchingStrategy(branching_strategy_); | |||
| @@ -29,25 +29,23 @@ | |||
| namespace mindspore { | |||
| namespace somas { | |||
| class SomasSolverCore { | |||
| public: | |||
| /// Interface Function: receive parameters, creates the model to solve and then save the result | |||
| SomasSolverCore(const std::unordered_map<size_t, SomasSolverTensorDescPtr> &tensors, | |||
| const std::vector<DynamicBitSet> *constraints) | |||
| : tensors_(tensors), | |||
| SomasSolverCore(const TensorsDescMap &tensors, const std::vector<DynamicBitSet> *constraints, uint32_t sol, | |||
| bool isMultiThreadValid = true) | |||
| : best_sol_(0), | |||
| sort_strategy_(kGreaterSizeSmallerIndex), | |||
| branching_strategy_(kBest), | |||
| sol_count_(sol), | |||
| algorithm_(kManyObjects), | |||
| tensors_(tensors), | |||
| constraints_(*constraints), | |||
| upperbound_(SIZE_MAX), | |||
| timing_(0), | |||
| lifelongmemory_(0), | |||
| verify_(false), | |||
| all_(true), | |||
| best_sol_(0), | |||
| best_sort_(kGreaterSizeSmallerIndex), | |||
| best_branching_(kBest), | |||
| sort_strategy_(kGreaterSizeSmallerIndex), | |||
| branching_strategy_(kBest), | |||
| sol_count_(0), | |||
| algorithm_(kManyObjects) {} | |||
| is_multi_thread_valid_(isMultiThreadValid) {} | |||
| ~SomasSolverCore() = default; | |||
| Status MemoryAllocationSolver(); | |||
| @@ -64,37 +62,29 @@ class SomasSolverCore { | |||
| void SetAlgorithmStrategy(AlgorithmType algorithm_strategy) { algorithm_ = algorithm_strategy; } | |||
| void SetAllStrategies(bool all) { all_ = all; } | |||
| const size_t &GetUpperbound() const { return upperbound_; } | |||
| const size_t &Getlifelongmemory() const { return lifelong_memory_; } | |||
| private: | |||
| std::unordered_map<size_t, SomasSolverTensorDescPtr> tensors_; | |||
| vector<BlockTensor> block_tensors_; | |||
| std::vector<DynamicBitSet> constraints_; | |||
| size_t upperbound_{0}; | |||
| size_t timing_{0}; | |||
| size_t lifelongmemory_{0}; | |||
| bool verify_{false}; | |||
| bool all_{false}; | |||
| uint32_t best_sol_{0}; | |||
| SortingType best_sort_; | |||
| FittingType best_branching_; | |||
| SortingType sort_strategy_; | |||
| FittingType branching_strategy_; | |||
| uint32_t sol_count_{0}; | |||
| AlgorithmType algorithm_; | |||
| size_t timing_{0}; | |||
| private: | |||
| const TensorsDescMap &tensors_; | |||
| vector<BlockTensor> block_tensors_; | |||
| const std::vector<DynamicBitSet> &constraints_; | |||
| size_t upperbound_{0}; | |||
| size_t lifelong_memory_{0}; | |||
| bool verify_{false}; | |||
| bool all_{false}; | |||
| bool is_multi_thread_valid_{true}; | |||
| size_t FindSolutions(); | |||
| size_t Search(const std::shared_ptr<FootPrint> &pFootprint); | |||
| void AppendLifelongTensors(); | |||
| void Destroy(std::shared_ptr<FootPrint> &); | |||
| const std::string sorting_[6] = {"size(>), index(<)", | |||
| "size(>), index(>)", | |||
| "size(>), constraints(<), index(<)", | |||
| "size(>), constraints(<), index(>)", | |||
| "size(>), constraints(>), index(<)", | |||
| "size(>), constraints(>), index(>)"}; | |||
| const std::string branching_[4] = {"bestfit", "smallest", "largest", "worstfit"}; | |||
| const std::string algorithm_type_[2] = {"Shared Objects", "Single Object"}; | |||
| }; | |||
| } // namespace somas | |||
| } // namespace mindspore | |||
| @@ -19,6 +19,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "common/thread_pool.h" | |||
| #include "backend/optimizer/somas/somas_solver_core.h" | |||
| #include "backend/optimizer/somas/somas_solver_pre.h" | |||
| @@ -26,57 +27,161 @@ | |||
| namespace mindspore { | |||
| namespace somas { | |||
| Status SomasSolverPre::Solving(const session::KernelGraph *graph, | |||
| std::unordered_map<size_t, SomasSolverTensorDescPtr> *ptensors, | |||
| constexpr auto kSolNumThresholdMultiThread = 8; | |||
| Status SomasSolverPre::checkTensors(TensorsDescMap *pTensors, uint32_t index1, uint32_t index2) { | |||
| auto &tensors = *pTensors; | |||
| if (nullptr == tensors[index1]) { | |||
| MS_LOG(WARNING) << "NULL tensor received in continuous constraint (tensor index " << index1 << ")"; | |||
| return FAILED; | |||
| } | |||
| if (nullptr == tensors[index2]) { | |||
| MS_LOG(WARNING) << "NULL tensor received in continuous constraint (tensor index " << index2 << ")"; | |||
| return FAILED; | |||
| } | |||
| if (tensors[index1]->right_) | |||
| MS_LOG(WARNING) << "Warning:tensor " << index1 | |||
| << " already has a right tensor (id: " << tensors[index1]->right_->index_; | |||
| if (tensors[index2]->left_) | |||
| MS_LOG(WARNING) << "Warning:tensor " << index2 | |||
| << " already has a left tensor (id: " << tensors[index2]->left_->index_; | |||
| return SUCCESS; | |||
| } | |||
| Status SomasSolverPre::addContiguousInfoInMap(const vector<vector<size_t>> &continuous_v, TensorsDescMap *pTensors) { | |||
| auto &tensors = *pTensors; | |||
| // creating S Lists | |||
| for (auto &aux : continuous_v) { | |||
| for (uint32_t i = 0; i < aux.size() - 1; i++) { | |||
| uint32_t index1 = aux[i]; | |||
| uint32_t index2 = aux[i + 1]; | |||
| if (checkTensors(pTensors, index1, index2) == FAILED) { | |||
| return FAILED; | |||
| } | |||
| tensors[index1]->right_ = tensors[index2]; | |||
| tensors[index2]->left_ = tensors[index1]; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status SomasSolverPre::addContiguousInfoInMultiMaps(const vector<vector<size_t>> &continuous_v, | |||
| vector<TensorsDescMap> *vecTensorsMap, TensorsDescMap *pTensors) { | |||
| // creating S Lists | |||
| for (auto &aux : continuous_v) { | |||
| for (uint32_t i = 0; i < aux.size() - 1; i++) { | |||
| uint32_t index1 = aux[i]; | |||
| uint32_t index2 = aux[i + 1]; | |||
| if (checkTensors(pTensors, index1, index2) == FAILED) { | |||
| return FAILED; | |||
| } | |||
| for (size_t sol = 0; sol < vecTensorsMap->size(); sol++) { | |||
| auto &tensors_sol = (*vecTensorsMap)[sol]; | |||
| tensors_sol[index1]->right_ = tensors_sol[index2]; | |||
| tensors_sol[index2]->left_ = tensors_sol[index1]; | |||
| } | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| vector<TensorsDescMap> SomasSolverPre::createTensorsMaps(const TensorsDescMap &tensors, size_t total_sol) { | |||
| vector<TensorsDescMap> vecTensorsMap(total_sol); | |||
| vecTensorsMap[0] = tensors; | |||
| for (auto &pairT : tensors) { | |||
| for (size_t sol = 1; sol < total_sol; sol++) { | |||
| SomasSolverTensorDesc newDesc = *(pairT.second.get()); | |||
| SomasSolverTensorDescPtr newDescPtr = std::make_shared<SomasSolverTensorDesc>(newDesc); | |||
| vecTensorsMap[sol].insert(std::make_pair(pairT.first, newDescPtr)); | |||
| } | |||
| } | |||
| return std::move(vecTensorsMap); | |||
| } | |||
| Status SomasSolverPre::Solving(const session::KernelGraph *graph, TensorsDescMap *ptensors, | |||
| const std::vector<DynamicBitSet> *pConstraints, | |||
| const vector<vector<size_t>> &continuous_v, bool bVerifySolution, bool ball, | |||
| SortingType sorting, FittingType fitting, AlgorithmType algorithm) { | |||
| Status retval = SUCCESS; | |||
| try { | |||
| std::unordered_map<size_t, SomasSolverTensorDescPtr> &tensors = *ptensors; | |||
| MS_LOG(INFO) << "Filling in constraints matrix.."; | |||
| uint32_t continuous_cnt = 0; | |||
| // creating S Lists | |||
| for (auto &aux : continuous_v) { | |||
| for (uint32_t i = 0; i < aux.size() - 1; i++) { | |||
| uint32_t index1 = aux[i]; | |||
| uint32_t index2 = aux[i + 1]; | |||
| if (NULL == tensors[index1]) { | |||
| MS_LOG(WARNING) << "NULL tensor received in continuous constraint (tensor index " << index1 << ")"; | |||
| return FAILED; | |||
| TensorsDescMap &tensors = *ptensors; | |||
| size_t total_sol = kNumSortingTypes * kNumFittingTypes * kNumAlgorithmTypes; | |||
| size_t process_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); | |||
| bool isMultiThreadPermit = ball && process_num >= total_sol && total_sol > 1; | |||
| bool isMultiThreadValid = isMultiThreadPermit && (total_sol > kSolNumThresholdMultiThread || | |||
| kParallelComputeSizeThreshold <= tensors.size()); | |||
| const double giga = 1024. * 1024. * 1024.; | |||
| if (isMultiThreadValid) { | |||
| vector<std::shared_ptr<SomasSolverCore>> solvers; | |||
| std::vector<common::Task> tasks; | |||
| vector<TensorsDescMap> vecTensorsMap = createTensorsMaps(tensors, total_sol); | |||
| if (addContiguousInfoInMultiMaps(continuous_v, &vecTensorsMap, ptensors) == FAILED) { | |||
| return FAILED; | |||
| } | |||
| auto start = std::chrono::system_clock::now(); | |||
| for (size_t algorithm = 0, sol = 0; algorithm < kNumAlgorithmTypes; algorithm++) { | |||
| for (size_t sort_strategy = 0; sort_strategy < kNumSortingTypes; sort_strategy++) { | |||
| for (size_t branching_strategy = 0; branching_strategy < kNumFittingTypes; branching_strategy++) { | |||
| std::shared_ptr<SomasSolverCore> pSolver = | |||
| std::make_shared<SomasSolverCore>(vecTensorsMap[sol], pConstraints, sol); | |||
| pSolver->SetAlgorithmStrategy(AlgorithmType(algorithm)); | |||
| pSolver->SetSortingStrategy(SortingType(sort_strategy)); | |||
| pSolver->SetFittingStrategy(FittingType(branching_strategy)); | |||
| pSolver->SetAllStrategies(false); | |||
| pSolver->VerifySolution(bVerifySolution); | |||
| auto task = [pSolver]() { | |||
| return pSolver->MemoryAllocationSolver() == SUCCESS ? common::SUCCESS : common::FAIL; | |||
| }; | |||
| tasks.emplace_back(task); | |||
| solvers.emplace_back(pSolver); | |||
| sol++; | |||
| } | |||
| } | |||
| if (NULL == tensors[index2]) { | |||
| MS_LOG(WARNING) << "NULL tensor received in continuous constraint (tensor index " << index2 << ")"; | |||
| return FAILED; | |||
| } | |||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||
| size_t best_sol = 0, worst = 0, best = SIZE_MAX, best_timing = SIZE_MAX; | |||
| for (size_t sol = 0; sol < total_sol; sol++) { | |||
| auto &solver = solvers[sol]; | |||
| auto &upperbound = solver->GetUpperbound(); | |||
| if (upperbound > worst) { | |||
| worst = upperbound; | |||
| } | |||
| if (upperbound <= best) { | |||
| best = upperbound; | |||
| best_sol = sol; | |||
| best_timing = solver->timing_; | |||
| } | |||
| if (tensors[index1]->right_) | |||
| MS_LOG(WARNING) << "Warning:tensor " << index1 | |||
| << " already has a right tensor (id: " << tensors[index1]->right_->index_; | |||
| if (tensors[index2]->left_) | |||
| MS_LOG(WARNING) << "Warning:tensor " << index2 | |||
| << " already has a left tensor (id: " << tensors[index2]->left_->index_; | |||
| tensors[index1]->right_ = tensors[index2]; | |||
| tensors[index2]->left_ = tensors[index1]; | |||
| continuous_cnt++; | |||
| } | |||
| } | |||
| continuous_cnt++; | |||
| std::shared_ptr<SomasSolverCore> pSolver = std::make_shared<SomasSolverCore>(tensors, pConstraints); | |||
| pSolver->SetAlgorithmStrategy(algorithm); | |||
| pSolver->SetSortingStrategy(sorting); | |||
| pSolver->SetFittingStrategy(fitting); | |||
| pSolver->SetAllStrategies(ball); | |||
| pSolver->VerifySolution(bVerifySolution); | |||
| if (SUCCESS == (pSolver->MemoryAllocationSolver())) { | |||
| max_offset_ = pSolver->GetUpperbound(); | |||
| const double giga = 1024. * 1024. * 1024.; | |||
| MS_LOG(INFO) << "SomasSolver::Solving SUCCESS"; | |||
| MS_LOG(INFO) << "SomasSolver::Solving RESULT: " << max_offset_ << " (" << max_offset_ / (giga) << " GB)"; | |||
| auto end = std::chrono::system_clock::now(); | |||
| size_t total_time = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count(); | |||
| auto &best_solver = solvers[best_sol]; | |||
| for (auto &tensor : tensors) { | |||
| *(tensor.second.get()) = *(vecTensorsMap[best_sol][tensor.first]); | |||
| } | |||
| max_offset_ = best_solver->GetUpperbound(); | |||
| MS_LOG(INFO) << "SOMAS SOLVER RESUME:"; | |||
| MS_LOG(INFO) << "Best Solution:[" << 1 + best_sol << "/" << total_sol << "] "; | |||
| MS_LOG(INFO) << "Best result:" << best << " Bytes " << (best) / (giga) << " GB (" | |||
| << (best - best_solver->Getlifelongmemory()) / (giga) << " GB + " | |||
| << best_solver->Getlifelongmemory() / (giga) << " GB from lifelong tensors)"; | |||
| MS_LOG(INFO) << "Best timing:" << best_timing << " ms"; | |||
| MS_LOG(INFO) << "Best algorithm: " << algorithmTypeNames[best_solver->algorithm_]; | |||
| MS_LOG(INFO) << "Best sorting strategy: " << sortingNames[best_solver->sort_strategy_]; | |||
| MS_LOG(INFO) << "Best offset strategy: " << branchingNames[best_solver->branching_strategy_]; | |||
| MS_LOG(INFO) << "Time elapsed: " << total_time << " ms"; | |||
| MS_LOG(INFO) << "Spread:" << static_cast<double>((worst - best) / static_cast<double>(best * 100.0)) << " %%"; | |||
| } else { | |||
| if (addContiguousInfoInMap(continuous_v, ptensors) == FAILED) { | |||
| return FAILED; | |||
| } | |||
| std::shared_ptr<SomasSolverCore> pSolver = std::make_shared<SomasSolverCore>(tensors, pConstraints, 0, false); | |||
| pSolver->SetAlgorithmStrategy(algorithm); | |||
| pSolver->SetSortingStrategy(sorting); | |||
| pSolver->SetFittingStrategy(fitting); | |||
| pSolver->SetAllStrategies(ball); | |||
| pSolver->VerifySolution(bVerifySolution); | |||
| if (SUCCESS == (pSolver->MemoryAllocationSolver())) { | |||
| max_offset_ = pSolver->GetUpperbound(); | |||
| MS_LOG(INFO) << "SomasSolver::Solving SUCCESS"; | |||
| MS_LOG(INFO) << "SomasSolver::Solving RESULT: " << max_offset_ << " (" << max_offset_ / (giga) << " GB)"; | |||
| } | |||
| } | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| @@ -91,18 +196,16 @@ Status SomasSolverPre::Solving(const session::KernelGraph *graph, | |||
| return retval; | |||
| } | |||
| void SomasSolverPre::Log(const session::KernelGraph *graph, | |||
| const unordered_map<size_t, SomasSolverTensorDescPtr> &tensors, | |||
| void SomasSolverPre::Log(const session::KernelGraph *graph, const TensorsDescMap &tensors, | |||
| const std::vector<DynamicBitSet> *pConstraints, const vector<vector<size_t>> &continuous_v) { | |||
| SolverInputLog(graph, tensors, pConstraints, continuous_v); | |||
| SolverOutputLog(graph, tensors); | |||
| } | |||
| void SomasSolverPre::SolverInputLog(const session::KernelGraph *graph, | |||
| const unordered_map<size_t, SomasSolverTensorDescPtr> &tensors, | |||
| void SomasSolverPre::SolverInputLog(const session::KernelGraph *graph, const TensorsDescMap &tensors, | |||
| const std::vector<DynamicBitSet> *pConstraints, | |||
| const vector<vector<size_t>> &continuous_v) { | |||
| MS_LOG(INFO) << "SomasSolver::Log Writing somas-input.txt.."; | |||
| MS_LOG(INFO) << "SomasSolver::Log Writing somas_solver_input.."; | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| @@ -150,9 +253,8 @@ void SomasSolverPre::SolverInputLog(const session::KernelGraph *graph, | |||
| MS_LOG(INFO) << "SomasSolver input Log done"; | |||
| } | |||
| void SomasSolverPre::SolverOutputLog(const session::KernelGraph *graph, | |||
| const unordered_map<size_t, SomasSolverTensorDescPtr> &tensors) const { | |||
| MS_LOG(INFO) << "SomasSolver::Log Writing somas output..."; | |||
| void SomasSolverPre::SolverOutputLog(const session::KernelGraph *graph, const TensorsDescMap &tensors) const { | |||
| MS_LOG(INFO) << "SomasSolver::Log Writing somas_solver_output_.."; | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH); | |||
| @@ -179,11 +281,11 @@ void SomasSolverPre::SolverOutputLog(const session::KernelGraph *graph, | |||
| for (auto &t : tensors) { | |||
| SomasSolverTensorDescPtr tensor = t.second; | |||
| int continuous = 0; | |||
| if (tensor->left_ == NULL && tensor->right_ != NULL) | |||
| if (tensor->left_ == nullptr && tensor->right_ != nullptr) | |||
| continuous = 1; | |||
| else if (tensor->left_ != NULL && tensor->right_ != NULL) | |||
| else if (tensor->left_ != nullptr && tensor->right_ != nullptr) | |||
| continuous = 2; | |||
| else if (tensor->left_ != NULL && tensor->right_ == NULL) | |||
| else if (tensor->left_ != nullptr && tensor->right_ == nullptr) | |||
| continuous = 3; | |||
| const size_t alignment = 512; | |||
| bool size_aligned = tensor->size_ % alignment == 0; | |||
| @@ -34,6 +34,15 @@ using std::vector; | |||
| namespace mindspore { | |||
| namespace somas { | |||
| constexpr char const *sortingNames[6] = {"size(>), index(<)", | |||
| "size(>), index(>)", | |||
| "size(>), constraints(<), index(<)", | |||
| "size(>), constraints(<), index(>)", | |||
| "size(>), constraints(>), index(<)", | |||
| "size(>), constraints(>), index(>)"}; | |||
| constexpr char const *branchingNames[4] = {"bestfit", "smallest", "largest", "worstfit"}; | |||
| constexpr char const *algorithmTypeNames[2] = {"Shared Objects", "Single Object"}; | |||
| constexpr auto kParallelComputeSizeThreshold = 2000; | |||
| enum Status { FAILED, SUCCESS }; | |||
| enum AlgorithmType { kManyObjects = 0, kSingleObject, kNumAlgorithmTypes }; | |||
| enum SortingType { | |||
| @@ -164,7 +173,7 @@ struct SomasSolverTensorDesc { | |||
| } | |||
| }; | |||
| using SomasSolverTensorDescPtr = std::shared_ptr<SomasSolverTensorDesc>; | |||
| typedef std::unordered_map<size_t, SomasSolverTensorDescPtr> TensorsDescMap; | |||
| class SomasSolverPre { | |||
| public: | |||
| SomasSolverPre() = default; | |||
| @@ -175,22 +184,27 @@ class SomasSolverPre { | |||
| size_t GetMaxOffset() { return max_offset_; } | |||
| Status Solving(const session::KernelGraph *graph, std::unordered_map<size_t, SomasSolverTensorDescPtr> *tensors, | |||
| Status Solving(const session::KernelGraph *graph, TensorsDescMap *tensors, | |||
| const std::vector<DynamicBitSet> *pConstraints, const vector<vector<size_t>> &continuous_v, | |||
| bool bVerifySolution, // true -> Check continuous and non overlapping constraints solution | |||
| bool ball = true, // true -> run full set of heuristics, false -> run single heuristic specified | |||
| SortingType sorting = kGreaterSizeSmallerIndex, FittingType fitting = kBest, | |||
| AlgorithmType algorithm = kManyObjects); | |||
| void Log(const session::KernelGraph *graph, const unordered_map<size_t, SomasSolverTensorDescPtr> &tensors, | |||
| void Log(const session::KernelGraph *graph, const TensorsDescMap &tensors, | |||
| const std::vector<DynamicBitSet> *pConstraints_v, const vector<vector<size_t>> &continuous_v); | |||
| Status checkTensors(TensorsDescMap *tensors, uint32_t index1, uint32_t index2); | |||
| Status addContiguousInfoInMap(const vector<vector<size_t>> &continuous_v, TensorsDescMap *tensors); | |||
| Status addContiguousInfoInMultiMaps(const vector<vector<size_t>> &continuous_v, vector<TensorsDescMap> *vecTensorsMap, | |||
| TensorsDescMap *tensors); | |||
| private: | |||
| size_t max_offset_; | |||
| void SolverInputLog(const session::KernelGraph *graph, const unordered_map<size_t, SomasSolverTensorDescPtr> &tensors, | |||
| void SolverInputLog(const session::KernelGraph *graph, const TensorsDescMap &tensors, | |||
| const std::vector<DynamicBitSet> *pConstraints_v, const vector<vector<size_t>> &continuous_v); | |||
| void SolverOutputLog(const session::KernelGraph *graph, | |||
| const unordered_map<size_t, SomasSolverTensorDescPtr> &tensors) const; | |||
| void SolverOutputLog(const session::KernelGraph *graph, const TensorsDescMap &tensors) const; | |||
| vector<TensorsDescMap> createTensorsMaps(const TensorsDescMap &tensors, size_t total_sol); | |||
| }; | |||
| using SomasSolverPrePtr = std::shared_ptr<SomasSolverPre>; | |||
| } // namespace somas | |||