You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data.h 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MIINDSPORE_CCSRC_DISTRIBUTED_PERSISTENT_DATA_H_
  17. #define MIINDSPORE_CCSRC_DISTRIBUTED_PERSISTENT_DATA_H_
  18. #include <map>
  19. #include <memory>
  20. #include <vector>
  21. #include <string>
  22. #include <thread>
  23. #include <utility>
  24. #include "distributed/persistent/storage/local_file.h"
  25. #include "utils/log_adapter.h"
  26. namespace mindspore {
  27. namespace distributed {
  28. namespace persistent {
  // The Data class keeps a tensor in memory as a shared, contiguous buffer of T together with an
  // optional shape vector, and is the base for the persistence / disaster-recovery implementation.
  //
  // Both the buffer and the shape are held through std::shared_ptr and either may be null (the shape
  // defaults to nullptr), so the accessors below never dereference a null buffer.
  template <typename T>
  class Data {
   public:
    // Stores the given buffer and shape handles as-is; no copy of the underlying vectors is made.
    //   data:  shared buffer holding the elements (may be null).
    //   shape: shared dimension vector (may be null, which is also the default).
    explicit Data(const std::shared_ptr<std::vector<T>> &data, const std::shared_ptr<std::vector<int>> &shape = nullptr)
        : data_(data), shape_(shape) {}
    virtual ~Data() = default;

    // Get the raw pointer to the first element of the buffer, or nullptr when no buffer is attached.
    T *data() const { return data_ != nullptr ? data_->data() : nullptr; }

    // Get the shared handle of the buffer (null when no buffer is attached).
    std::shared_ptr<std::vector<T>> MutableData() const { return data_; }

    // Get the element number of Data; 0 when no buffer is attached.
    size_t size() const { return data_ != nullptr ? data_->size() : 0; }

    // Get the dimension information of Data (null when no shape was supplied).
    std::shared_ptr<std::vector<int>> shape() const { return shape_; }

   protected:
    // Container used to store continuous memory buffer of Data.
    std::shared_ptr<std::vector<T>> data_;

    // Container used to record the dimension information of Data which persists a tensor.
    std::shared_ptr<std::vector<int>> shape_;
  };
  51. // Implementation of the class Data to complete the function of persistence and disaster tolerance.
  52. template <typename T>
  53. class PersistentData : public Data<T> {
  54. public:
  55. explicit PersistentData(const std::shared_ptr<std::vector<T>> &data,
  56. const std::shared_ptr<std::vector<int>> &shape = nullptr)
  57. : Data<T>(data, shape) {}
  58. ~PersistentData() override = default;
  59. // Initialize storage module.
  60. // Custom storage config, you can choose different configurations according to different storage forms,
  61. // such as using file storage by configuring the file storage path,
  62. // and config can be like this: std::map<std::string, std::string> config = {{kFileStoragePath, "real_path_of_dir"}};
  63. void Initialize(const std::map<std::string, std::string> &storage_config);
  64. // In disaster recovery mode, memory of tensor need to be saved into disk file periodically.
  65. void Persist(const storage::DirtyInfo &dirty_info) const;
  66. // In disaster recovery mode, server node or worker node need to restore persistent data when restart.
  67. void Restore() const;
  68. private:
  69. // The following variables are used in disaster recovery mode:
  70. // The threads used to execute persistence task.
  71. std::thread persist_thread_;
  72. // The file storage handle used to persist data.
  73. std::shared_ptr<storage::StorageBase> storage_;
  74. };
  75. template <typename T>
  76. void PersistentData<T>::Initialize(const std::map<std::string, std::string> &storage_config) {
  77. storage_ = std::make_shared<storage::LocalFile>(storage_config);
  78. }
  79. template <typename T>
  80. void PersistentData<T>::Persist(const storage::DirtyInfo &dirty_info) const {
  81. MS_EXCEPTION_IF_NULL(storage_);
  82. storage::InputData input = std::make_tuple(*Data<T>::shape_, Data<T>::data(), Data<T>::size() * sizeof(T));
  83. storage_->Write(input, dirty_info);
  84. }
  85. template <typename T>
  86. void PersistentData<T>::Restore() const {
  87. storage::OutputData output = std::make_pair(Data<T>::data(), Data<T>::size() * sizeof(T));
  88. MS_EXCEPTION_IF_NULL(storage_);
  89. storage_->Read(output);
  90. }
  91. } // namespace persistent
  92. } // namespace distributed
  93. } // namespace mindspore
  94. #endif // MIINDSPORE_CCSRC_DISTRIBUTED_PERSISTENT_DATA_H_