You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data.h 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MIINDSPORE_CCSRC_DISTRIBUTED_PERSISTENT_DATA_H_
  17. #define MIINDSPORE_CCSRC_DISTRIBUTED_PERSISTENT_DATA_H_
  18. #include <map>
  19. #include <memory>
  20. #include <vector>
  21. #include <string>
  22. #include <thread>
  23. #include <utility>
  24. #include "distributed/persistent/storage/local_file.h"
  25. #include "utils/log_adapter.h"
  26. namespace mindspore {
  27. namespace distributed {
  28. namespace persistent {
  // Data keeps a tensor in memory as a shared element buffer plus an optional shared list of dimensions.
  // It is the base class on which persistence and disaster recovery are built.
  template <typename T>
  class Data {
   public:
    // Shares ownership of the element buffer `data` and of the dimension list `shape`.
    // `shape` may be omitted, in which case the stored shape pointer is null.
    explicit Data(const std::shared_ptr<std::vector<T>> &data, const std::shared_ptr<std::vector<int>> &shape = nullptr)
        : data_{data}, shape_{shape} {}

    virtual ~Data() = default;

    // Raw pointer to the first element of the buffer. Dereferences data_, so the buffer pointer must not be null.
    T *data() const {
      std::vector<T> &buffer = *data_;
      return buffer.data();
    }

    // Shared handle to the underlying buffer; changes made through it are visible to this object.
    std::shared_ptr<std::vector<T>> MutableData() const {
      return data_;
    }

    // Number of elements (not bytes) in the buffer. Dereferences data_, so the buffer pointer must not be null.
    size_t size() const {
      const std::vector<T> &buffer = *data_;
      return buffer.size();
    }

    // Shared handle to the dimension list; null when no shape was supplied at construction.
    std::shared_ptr<std::vector<int>> shape() const {
      return shape_;
    }

   protected:
    // Continuous memory buffer holding the elements.
    std::shared_ptr<std::vector<T>> data_;

    // Dimension information of the tensor stored in data_.
    std::shared_ptr<std::vector<int>> shape_;
  };
  51. // Implementation of the class Data to complete the function of persistence and disaster tolerance.
  52. template <typename T>
  53. class PersistentData : public Data<T> {
  54. public:
  55. explicit PersistentData(const std::shared_ptr<std::vector<T>> &data,
  56. const std::shared_ptr<std::vector<int>> &shape = nullptr)
  57. : Data<T>(data, shape) {}
  58. ~PersistentData() override = default;
  59. // Initialize storage module.
  60. void Initialize(const std::map<std::string, std::string> &storage_config);
  61. // In disaster recovery mode, memory of tensor need to be saved into disk file periodically.
  62. void Persist(const storage::DirtyInfo &dirty_info) const;
  63. // In disaster recovery mode, server node or worker node need to restore persistent data when restart.
  64. void Restore() const;
  65. private:
  66. // The following variables are used in disaster recovery mode:
  67. // The threads used to execute persistence task.
  68. std::thread persist_thread_;
  69. // The file storage handle used to persist data.
  70. std::shared_ptr<storage::StorageBase> storage_;
  71. };
  72. template <typename T>
  73. void PersistentData<T>::Initialize(const std::map<std::string, std::string> &storage_config) {
  74. storage_ = std::make_shared<storage::LocalFile>(storage_config);
  75. }
  76. template <typename T>
  77. void PersistentData<T>::Persist(const storage::DirtyInfo &dirty_info) const {
  78. MS_EXCEPTION_IF_NULL(storage_);
  79. storage::InputData input = std::make_tuple(*shape_, data(), size() * sizeof(T));
  80. storage_->Write(input, dirty_info);
  81. }
  82. template <typename T>
  83. void PersistentData<T>::Restore() const {
  84. storage::OutputData output = std::make_pair(data(), size() * sizeof(T));
  85. MS_EXCEPTION_IF_NULL(storage_);
  86. storage_->Read(output);
  87. }
  88. } // namespace persistent
  89. } // namespace distributed
  90. } // namespace mindspore
  91. #endif // MIINDSPORE_CCSRC_DISTRIBUTED_PERSISTENT_DATA_H_