You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tensor_load.h 9.2 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_DEBUG_TENSOR_LOAD_H_
  17. #define MINDSPORE_CCSRC_DEBUG_TENSOR_LOAD_H_
#include <algorithm>
#include <condition_variable>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
  27. #ifdef OFFLINE_DBG_MODE
  28. #include "debugger/offline_debug/offline_logger.h"
  29. #endif
  30. #include "debug/tensor_data.h"
  31. #ifdef ONLINE_DBG_MODE
  32. #include "debug/data_dump/dump_json_parser.h"
  33. namespace mindspore {
  34. #endif
  35. class TensorLoader {
  36. public:
  37. TensorLoader() : iter_num_(-1), mem_total_(0), mem_usage_(0) {}
  38. ~TensorLoader() { EmptyTensor(); }
  39. void MoveTensorCurrentToPrev(std::string tensor_name) {
  40. auto handle = tensor_list_map_.extract(tensor_name);
  41. if (!handle.empty()) {
  42. MS_LOG(INFO) << "Moving " << tensor_name << " from current map to previous map";
  43. prev_tensor_list_map_.insert(std::move(handle));
  44. }
  45. }
  46. void SwapCurrentPrev() { tensor_list_map_.swap(prev_tensor_list_map_); }
  47. bool TensorExistsInCurrent(std::string tensor_name) const {
  48. return tensor_list_map_.find(tensor_name) != tensor_list_map_.end();
  49. }
  50. // only parameters will return true
  51. bool PrevTensorExistsInCurrent(std::string tensor_name) const { return TensorExistsInCurrent(tensor_name + ":prev"); }
  52. void MoveParametersCurrentToPrev() {
  53. MS_LOG(INFO) << "Moving parameters from current map to previous map";
  54. auto iter = tensor_list_map_.begin();
  55. while (iter != tensor_list_map_.end()) {
  56. auto key = iter->first;
  57. if (PrevTensorExistsInCurrent(key)) {
  58. // :prev tensor only exists for parameter. Move it to prev
  59. ++iter;
  60. MoveTensorCurrentToPrev(key);
  61. } else {
  62. ++iter;
  63. }
  64. }
  65. }
  66. bool IsPrevTensor(std::string tensor_name) const {
  67. const std::string suffix = ":prev";
  68. if (tensor_name.length() <= suffix.length()) return false;
  69. return std::equal(suffix.rbegin(), suffix.rend(), tensor_name.rbegin());
  70. }
  71. bool LoadNewTensor(std::shared_ptr<TensorData> tensor, bool keep_prev) {
  72. lock_.lock();
  73. auto tensor_name = tensor->GetName();
  74. if (keep_prev) {
  75. // add prev step tensor into current step map with ":prev" suffix
  76. auto handle = prev_tensor_list_map_.extract(tensor_name);
  77. if (!handle.empty()) {
  78. handle.key() = tensor_name + ":prev";
  79. tensor_list_map_.insert(std::move(handle));
  80. }
  81. }
  82. std::string key_name = tensor_name;
  83. #ifdef OFFLINE_DBG_MODE
  84. key_name += (":" + std::to_string(tensor->GetDeviceId()) + ":" + std::to_string(tensor->GetRootGraphId()) + ":" +
  85. std::to_string(tensor->GetIsOutput()) + ":" + std::to_string(tensor->GetSlot()));
  86. if (tensor_list_map_.find(key_name) != tensor_list_map_.end() &&
  87. tensor->GetIteration() == tensor_list_map_[key_name]->GetIteration() - 1) {
  88. key_name += ":prev";
  89. }
  90. auto iter = tensor_list_map_.find(key_name);
  91. if (iter != tensor_list_map_.end()) {
  92. iter->second->DeleteDataPtr();
  93. }
  94. #endif
  95. tensor_list_map_[key_name] = tensor; // use [] instead of insert to ensure latest value
  96. lock_.unlock();
  97. return true;
  98. }
  99. std::vector<std::shared_ptr<TensorData>> GetTensor() {
  100. std::vector<std::shared_ptr<TensorData>> tensor_list;
  101. for (auto &it : tensor_list_map_) {
  102. if (!IsPrevTensor(it.first)) tensor_list.push_back(it.second);
  103. }
  104. return tensor_list;
  105. }
  106. std::shared_ptr<TensorData> GetTensor(const std::string &tensor_name) const {
  107. auto iter = tensor_list_map_.find(tensor_name);
  108. if (iter != tensor_list_map_.end()) return iter->second;
  109. return nullptr;
  110. }
  111. uint32_t GetIterNum() const { return iter_num_; }
  112. std::map<std::string, std::shared_ptr<TensorData>> GetTensorMap() { return tensor_list_map_; }
  113. std::shared_ptr<TensorData> GetPrevTensor(const std::string &tensor_name) {
  114. if (tensor_list_map_.find(tensor_name + ":prev") != tensor_list_map_.end()) {
  115. return tensor_list_map_[tensor_name + ":prev"];
  116. }
  117. return nullptr;
  118. }
  119. void SearchTensors(const std::vector<std::string> &search_list,
  120. std::vector<std::tuple<std::string, std::shared_ptr<TensorData>>> *result_list) {
  121. for (auto i : search_list) {
  122. std::map<std::string, std::shared_ptr<TensorData>>::iterator iter;
  123. iter = tensor_list_map_.find(i);
  124. if (iter != tensor_list_map_.end()) {
  125. result_list->push_back(std::make_tuple(i, iter->second));
  126. } else {
  127. result_list->push_back(std::make_tuple(i, nullptr));
  128. }
  129. }
  130. }
  131. void EmptyTensor() {
  132. std::lock_guard<std::mutex> lg(lock_);
  133. prev_tensor_list_map_.clear();
  134. tensor_list_map_.swap(prev_tensor_list_map_);
  135. }
  136. void EmptyCurrentTensor() { tensor_list_map_.clear(); }
  137. void set_iter_num(uint32_t iter_num) { this->iter_num_ = iter_num; }
  138. bool EnableMemoryControl() { return mem_total_ > 0; }
  139. void AppendToCacheEvictQueue(const std::string &tensor_name) {
  140. std::lock_guard<std::mutex> lk(mem_lock_);
  141. if (std::find(cache_evict_queue_.begin(), cache_evict_queue_.end(), tensor_name) == cache_evict_queue_.end()) {
  142. cache_evict_queue_.push_back(tensor_name);
  143. evict_cond.notify_one();
  144. }
  145. }
  146. bool CheckMemoryAvailable(const std::string &backend_name, const uint64_t data_size) {
  147. // 1. Check if the tensor can fit in the entire limit. If not, don't attempt any read or evictions and generate
  148. // warning.
  149. if (data_size > mem_total_) {
  150. MS_LOG(ERROR) << "Failed to load data of tensor " << backend_name << " because the its data size (" << data_size
  151. << ") exceeds the maximum memory limit (" << mem_total_ << ").";
  152. return false;
  153. }
  154. // 2. Check if there's is enough cache space available for current tensor. If not, try evict cache.
  155. bool ret = CheckAndEvictTensorCache(data_size);
  156. return ret;
  157. }
  158. bool CheckAndEvictTensorCache(const uint64_t data_size) {
  159. std::string candidate_name;
  160. uint64_t candidates_size;
  161. std::unique_lock<std::mutex> lk(mem_lock_);
  162. while (data_size > mem_total_ - mem_usage_) {
  163. // wait until there is any not-in-use candidate to be evicted from cache
  164. evict_cond.wait(lk, [&] { return !cache_evict_queue_.empty(); });
  165. candidate_name = cache_evict_queue_.front();
  166. candidates_size = tensor_list_map_[candidate_name]->GetByteSize();
  167. // evict candidate tensor
  168. lock_.lock();
  169. tensor_list_map_[candidate_name]->DeleteDataPtr();
  170. tensor_list_map_.erase(candidate_name);
  171. lock_.unlock();
  172. cache_evict_queue_.pop_front();
  173. mem_usage_ = std::max(uint64_t(0), mem_usage_ - candidates_size);
  174. MS_LOG(INFO) << "Evict tensor: " << candidate_name;
  175. }
  176. // Reserve space for the current target tensor.
  177. mem_usage_ = std::min(mem_total_, mem_usage_ + data_size);
  178. return true;
  179. }
  180. void SetMemTotal(uint64_t total_mem_size) { this->mem_total_ = total_mem_size; }
  181. #ifdef ONLINE_DBG_MODE
  182. bool DumpTensorToFile(const std::string &tensor_name, bool trans_flag, const std::string &filepath,
  183. const std::string &host_fmt, const std::vector<int64_t> &host_shape, TypeId host_type,
  184. TypeId device_type, const std::string &addr_format, size_t slot) {
  185. if (filepath.empty()) {
  186. MS_LOG(ERROR) << "Dump file path is null!";
  187. return false;
  188. }
  189. std::string path = "";
  190. if (trans_flag) {
  191. path = filepath + '.' + host_fmt;
  192. } else {
  193. path = filepath + '.' + addr_format;
  194. }
  195. MS_LOG(INFO) << "Dump path is " << path;
  196. std::string tensor_loader_name = tensor_name + ":" + std::to_string(slot);
  197. auto iter = tensor_list_map_.find(tensor_loader_name);
  198. if (iter != tensor_list_map_.end()) {
  199. std::shared_ptr<TensorData> node = iter->second;
  200. size_t host_size = node->GetByteSize();
  201. return DumpJsonParser::DumpToFile(path, node->GetDataPtr(), host_size, host_shape, host_type);
  202. }
  203. MS_LOG(INFO) << "Tensor name:" << tensor_name << " not found in tensor_list_map_";
  204. return true;
  205. }
  206. #endif
  207. private:
  208. // the pair is (device_id, iteration)
  209. std::map<std::string, std::shared_ptr<TensorData>> tensor_list_map_;
  210. std::map<std::string, std::shared_ptr<TensorData>> prev_tensor_list_map_;
  211. uint32_t iter_num_;
  212. std::mutex lock_;
  213. std::mutex mem_lock_;
  214. uint64_t mem_total_;
  215. uint64_t mem_usage_;
  216. std::deque<std::string> cache_evict_queue_;
  217. std::condition_variable evict_cond;
  218. };
  219. #ifdef ONLINE_DBG_MODE
  220. } // namespace mindspore
  221. #endif
  222. #endif // MINDSPORE_CCSRC_DEBUG_TENSOR_LOAD_H_