You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

tensor_load.h 8.8 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_DEBUG_TENSOR_LOAD_H_
  17. #define MINDSPORE_CCSRC_DEBUG_TENSOR_LOAD_H_
#include <algorithm>
#include <condition_variable>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "debug/tensor_data.h"
  28. #ifdef ONLINE_DBG_MODE
  29. #include "debug/data_dump/dump_json_parser.h"
  30. namespace mindspore {
  31. #endif
  32. class TensorLoader {
  33. public:
  34. #ifndef __APPLE__
  35. TensorLoader() : iter_num_(-1), mem_total_(0), mem_usage_(0) {}
  36. #else
  37. TensorLoader() : mem_total_(0), mem_usage_(0) {}
  38. #endif
  39. ~TensorLoader() { EmptyTensor(); }
  40. void MoveTensorCurrentToPrev(std::string tensor_name) {
  41. auto handle = tensor_list_map_.extract(tensor_name);
  42. if (!handle.empty()) {
  43. MS_LOG(INFO) << "Moving " << tensor_name << " from current map to previous map";
  44. prev_tensor_list_map_.insert(std::move(handle));
  45. }
  46. }
  47. void SwapCurrentPrev() { tensor_list_map_.swap(prev_tensor_list_map_); }
  48. bool TensorExistsInCurrent(std::string tensor_name) const {
  49. return tensor_list_map_.find(tensor_name) != tensor_list_map_.end();
  50. }
  51. // only parameters will return true
  52. bool PrevTensorExistsInCurrent(std::string tensor_name) const { return TensorExistsInCurrent(tensor_name + ":prev"); }
  53. void MoveParametersCurrentToPrev() {
  54. MS_LOG(INFO) << "Moving parameters from current map to previous map";
  55. auto iter = tensor_list_map_.begin();
  56. while (iter != tensor_list_map_.end()) {
  57. auto key = iter->first;
  58. if (PrevTensorExistsInCurrent(key)) {
  59. // :prev tensor only exists for parameter. Move it to prev
  60. ++iter;
  61. MoveTensorCurrentToPrev(key);
  62. } else {
  63. ++iter;
  64. }
  65. }
  66. }
  67. bool IsPrevTensor(std::string tensor_name) const {
  68. const std::string suffix = ":prev";
  69. if (tensor_name.length() <= suffix.length()) return false;
  70. return std::equal(suffix.rbegin(), suffix.rend(), tensor_name.rbegin());
  71. }
  72. bool LoadNewTensor(std::shared_ptr<TensorData> tensor, bool keep_prev) {
  73. lock_.lock();
  74. auto tensor_name = tensor->GetName();
  75. if (keep_prev) {
  76. // add prev step tensor into current step map with ":prev" suffix
  77. auto handle = prev_tensor_list_map_.extract(tensor_name);
  78. if (!handle.empty()) {
  79. handle.key() = tensor_name + ":prev";
  80. tensor_list_map_.insert(std::move(handle));
  81. }
  82. }
  83. std::string key_name = tensor_name;
  84. #ifdef OFFLINE_DBG_MODE
  85. key_name += (":" + std::to_string(tensor->GetDeviceId()) + ":" + std::to_string(tensor->GetRootGraphId()) + ":" +
  86. std::to_string(tensor->GetIsOutput()) + ":" + std::to_string(tensor->GetSlot()));
  87. if (tensor_list_map_.find(key_name) != tensor_list_map_.end() &&
  88. tensor->GetIteration() == tensor_list_map_[key_name]->GetPrevIteration()) {
  89. key_name += ":prev";
  90. }
  91. #endif
  92. tensor_list_map_[key_name] = tensor; // use [] instead of insert to ensure latest value
  93. lock_.unlock();
  94. return true;
  95. }
  96. std::vector<std::shared_ptr<TensorData>> GetTensor() {
  97. std::vector<std::shared_ptr<TensorData>> tensor_list;
  98. for (auto &it : tensor_list_map_) {
  99. if (!IsPrevTensor(it.first)) tensor_list.push_back(it.second);
  100. }
  101. return tensor_list;
  102. }
  103. std::shared_ptr<TensorData> GetTensor(const std::string &tensor_name) const {
  104. auto iter = tensor_list_map_.find(tensor_name);
  105. if (iter != tensor_list_map_.end()) return iter->second;
  106. return nullptr;
  107. }
  108. std::shared_ptr<TensorData> GetPrevTensor(const std::string &tensor_name) {
  109. if (tensor_list_map_.find(tensor_name + ":prev") != tensor_list_map_.end()) {
  110. return tensor_list_map_[tensor_name + ":prev"];
  111. }
  112. return nullptr;
  113. }
  114. void SearchTensors(const std::vector<std::string> &search_list,
  115. std::vector<std::tuple<std::string, std::shared_ptr<TensorData>>> *result_list) {
  116. for (auto i : search_list) {
  117. std::map<std::string, std::shared_ptr<TensorData>>::iterator iter;
  118. iter = tensor_list_map_.find(i);
  119. if (iter != tensor_list_map_.end()) {
  120. result_list->push_back(std::make_tuple(i, iter->second));
  121. } else {
  122. result_list->push_back(std::make_tuple(i, nullptr));
  123. }
  124. }
  125. }
  126. void EmptyTensor() {
  127. std::lock_guard<std::mutex> lg(lock_);
  128. prev_tensor_list_map_.clear();
  129. tensor_list_map_.swap(prev_tensor_list_map_);
  130. }
  131. void EmptyCurrentTensor() { tensor_list_map_.clear(); }
  132. bool EnableMemoryControl() { return mem_total_ > 0; }
  133. void AppendToCacheEvictQueue(const std::string &tensor_name) {
  134. std::lock_guard<std::mutex> lk(mem_lock_);
  135. if (std::find(cache_evict_queue_.begin(), cache_evict_queue_.end(), tensor_name) == cache_evict_queue_.end()) {
  136. cache_evict_queue_.push_back(tensor_name);
  137. evict_cond.notify_one();
  138. }
  139. }
  140. bool CheckMemoryAvailable(const std::string &backend_name, const uint64_t data_size) {
  141. // 1. Check if the tensor can fit in the entire limit. If not, don't attempt any read or evictions and generate
  142. // warning.
  143. if (data_size > mem_total_) {
  144. MS_LOG(ERROR) << "Failed to load data of tensor " << backend_name << " because the its data size (" << data_size
  145. << ") exceeds the maximum memory limit (" << mem_total_ << ").";
  146. return false;
  147. }
  148. // 2. Check if there's is enough cache space available for current tensor. If not, try evict cache.
  149. bool ret = CheckAndEvictTensorCache(data_size);
  150. return ret;
  151. }
  152. bool CheckAndEvictTensorCache(const uint64_t data_size) {
  153. std::string candidate_name;
  154. uint64_t candidates_size;
  155. std::unique_lock<std::mutex> lk(mem_lock_);
  156. while (data_size > mem_total_ - mem_usage_) {
  157. // wait until there is any not-in-use candidate to be evicted from cache
  158. evict_cond.wait(lk, [&] { return !cache_evict_queue_.empty(); });
  159. candidate_name = cache_evict_queue_.front();
  160. candidates_size = tensor_list_map_[candidate_name]->GetByteSize();
  161. // evict candidate tensor
  162. lock_.lock();
  163. tensor_list_map_.erase(candidate_name);
  164. lock_.unlock();
  165. cache_evict_queue_.pop_front();
  166. mem_usage_ = std::max(uint64_t(0), mem_usage_ - candidates_size);
  167. MS_LOG(INFO) << "Evict tensor: " << candidate_name;
  168. }
  169. // Reserve space for the current target tensor.
  170. mem_usage_ = std::min(mem_total_, mem_usage_ + data_size);
  171. return true;
  172. }
  173. void SetMemTotal(uint64_t total_mem_size) { this->mem_total_ = total_mem_size; }
  174. #ifdef ONLINE_DBG_MODE
  175. bool DumpTensorToFile(const std::string &tensor_name, bool trans_flag, const std::string &filepath,
  176. const std::string &host_fmt, const std::vector<int64_t> &host_shape, TypeId host_type,
  177. TypeId device_type, const std::string &addr_format, size_t slot) {
  178. if (filepath.empty()) {
  179. MS_LOG(ERROR) << "Dump file path is null!";
  180. return false;
  181. }
  182. std::string path = "";
  183. if (trans_flag) {
  184. path = filepath + '.' + host_fmt;
  185. } else {
  186. path = filepath + '.' + addr_format;
  187. }
  188. MS_LOG(INFO) << "Dump path is " << path;
  189. std::string tensor_loader_name = tensor_name + ":" + std::to_string(slot);
  190. auto iter = tensor_list_map_.find(tensor_loader_name);
  191. if (iter != tensor_list_map_.end()) {
  192. std::shared_ptr<TensorData> node = iter->second;
  193. size_t host_size = node->GetByteSize();
  194. return DumpJsonParser::DumpToFile(path, node->GetDataPtr(), host_size, host_shape, host_type);
  195. }
  196. MS_LOG(INFO) << "Tensor name:" << tensor_name << " not found in tensor_list_map_";
  197. return false;
  198. }
  199. #endif
  200. private:
  201. // the pair is (device_id, iteration)
  202. std::map<std::string, std::shared_ptr<TensorData>> tensor_list_map_;
  203. std::map<std::string, std::shared_ptr<TensorData>> prev_tensor_list_map_;
  204. #ifndef __APPLE__
  205. uint32_t iter_num_;
  206. #endif
  207. std::mutex lock_;
  208. std::mutex mem_lock_;
  209. uint64_t mem_total_;
  210. uint64_t mem_usage_;
  211. std::deque<std::string> cache_evict_queue_;
  212. std::condition_variable evict_cond;
  213. };
  214. #ifdef ONLINE_DBG_MODE
  215. } // namespace mindspore
  216. #endif
  217. #endif // MINDSPORE_CCSRC_DEBUG_TENSOR_LOAD_H_