You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

allocator.h 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_UTIL_ALLOCATOR_H_
  17. #define DATASET_UTIL_ALLOCATOR_H_
  #include <cstdint>
  #include <cstdlib>
  #include <exception>
  #include <functional>
  #include <limits>
  #include <memory>
  #include <new>
  #include <type_traits>
  #include <utility>
  #include "dataset/util/memory_pool.h"
  24. namespace mindspore {
  25. namespace dataset {
  26. // The following conforms to the requirements of
  27. // std::allocator. Do not rename/change any needed
  28. // requirements, e.g. function names, typedef etc.
  29. template <typename T>
  30. class Allocator {
  31. public:
  32. template <typename U>
  33. friend class Allocator;
  34. using value_type = T;
  35. using pointer = T *;
  36. using const_pointer = const T *;
  37. using reference = T &;
  38. using const_reference = const T &;
  39. using size_type = uint64_t;
  40. template <typename U>
  41. struct rebind {
  42. using other = Allocator<U>;
  43. };
  44. using propagate_on_container_copy_assignment = std::true_type;
  45. using propagate_on_container_move_assignment = std::true_type;
  46. using propagate_on_container_swap = std::true_type;
  47. explicit Allocator(const std::shared_ptr<MemoryPool> &b) : pool_(b) {}
  48. ~Allocator() = default;
  49. template <typename U>
  50. explicit Allocator(Allocator<U> const &rhs) : pool_(rhs.pool_) {}
  51. template <typename U>
  52. bool operator==(Allocator<U> const &rhs) const {
  53. return pool_ == rhs.pool_;
  54. }
  55. template <typename U>
  56. bool operator!=(Allocator<U> const &rhs) const {
  57. return pool_ != rhs.pool_;
  58. }
  59. pointer allocate(std::size_t n) {
  60. void *p;
  61. Status rc = pool_->Allocate(n * sizeof(T), &p);
  62. if (rc.IsOk()) {
  63. return reinterpret_cast<pointer>(p);
  64. } else if (rc.IsOutofMemory()) {
  65. throw std::bad_alloc();
  66. } else {
  67. throw std::exception();
  68. }
  69. }
  70. void deallocate(pointer p, std::size_t n = 0) noexcept { pool_->Deallocate(p); }
  71. size_type max_size() { return pool_->get_max_size(); }
  72. private:
  73. std::shared_ptr<MemoryPool> pool_;
  74. };
  75. /// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory will
  76. /// be released when the object goes out of scope \tparam T The type of object to be allocated \tparam C Allocator.
  77. /// Default to std::allocator
  78. template <typename T, typename C = std::allocator<T>>
  79. class MemGuard {
  80. public:
  81. using allocator = C;
  82. MemGuard() : n_(0) {}
  83. explicit MemGuard(allocator a) : n_(0), alloc_(a) {}
  84. // There is no copy constructor nor assignment operator because the memory is solely owned by this object.
  85. MemGuard(const MemGuard &) = delete;
  86. MemGuard &operator=(const MemGuard &) = delete;
  87. // On the other hand, We can support move constructor
  88. MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {}
  89. MemGuard &operator=(MemGuard &&lhs) noexcept {
  90. if (this != &lhs) {
  91. this->deallocate();
  92. n_ = lhs.n_;
  93. alloc_ = std::move(lhs.alloc_);
  94. ptr_ = std::move(lhs.ptr_);
  95. }
  96. return *this;
  97. }
  98. /// \brief Explicitly deallocate the memory if allocated
  99. void deallocate() {
  100. if (ptr_) {
  101. auto *p = ptr_.release();
  102. if (!std::is_arithmetic<T>::value && std::is_destructible<T>::value) {
  103. for (auto i = 0; i < n_; ++i) {
  104. p[i].~T();
  105. }
  106. }
  107. alloc_.deallocate(p, n_);
  108. n_ = 0;
  109. }
  110. }
  111. /// \brief Allocate memory (with emplace feature). Previous one will be released. If size is 0, no new memory is
  112. /// allocated.
  113. /// \param n Number of objects of type T to be allocated
  114. /// \tparam Args Extra arguments pass to the constructor of T
  115. template <typename... Args>
  116. Status allocate(size_t n, Args &&... args) noexcept {
  117. try {
  118. deallocate();
  119. if (n > 0) {
  120. T *data = alloc_.allocate(n);
  121. if (!std::is_arithmetic<T>::value) {
  122. for (auto i = 0; i < n; i++) {
  123. std::allocator_traits<C>::construct(alloc_, &(data[i]), std::forward<Args>(args)...);
  124. }
  125. }
  126. ptr_ = std::unique_ptr<T[]>(data);
  127. n_ = n;
  128. }
  129. } catch (const std::bad_alloc &e) {
  130. return Status(StatusCode::kOutOfMemory);
  131. } catch (std::exception &e) {
  132. RETURN_STATUS_UNEXPECTED(e.what());
  133. }
  134. return Status::OK();
  135. }
  136. ~MemGuard() noexcept { deallocate(); }
  137. /// \brief Getter function
  138. /// \return The pointer to the memory allocated
  139. T *GetPointer() const { return ptr_.get(); }
  140. /// \brief Getter function
  141. /// \return The pointer to the memory allocated
  142. T *GetMutablePointer() { return ptr_.get(); }
  143. /// \brief Overload [] operator to access a particular element
  144. /// \param x index to the element. Must be less than number of element allocated.
  145. /// \return pointer to the x-th element
  146. T *operator[](size_t x) { return GetMutablePointer() + x; }
  147. /// \brief Overload [] operator to access a particular element
  148. /// \param x index to the element. Must be less than number of element allocated.
  149. /// \return pointer to the x-th element
  150. T *operator[](size_t x) const { return GetPointer() + x; }
  151. /// \brief Return how many bytes are allocated in total
  152. /// \return Number of bytes allocated in total
  153. size_t GetSizeInBytes() const { return n_ * sizeof(T); }
  154. private:
  155. allocator alloc_;
  156. std::unique_ptr<T[], std::function<void(T *)>> ptr_;
  157. size_t n_;
  158. };
  159. } // namespace dataset
  160. } // namespace mindspore
  161. #endif // DATASET_UTIL_ALLOCATOR_H_