You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

allocator.h 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ALLOCATOR_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ALLOCATOR_H_
  18. #include <cstdlib>
  19. #include <functional>
  20. #include <memory>
  21. #include <type_traits>
  22. #include <utility>
  23. #include "minddata/dataset/util/memory_pool.h"
  24. namespace mindspore {
  25. namespace dataset {
  26. // The following conforms to the requirements of
  27. // std::allocator. Do not rename/change any needed
  28. // requirements, e.g. function names, typedef etc.
  29. template <typename T>
  30. class Allocator {
  31. public:
  32. template <typename U>
  33. friend class Allocator;
  34. using value_type = T;
  35. using pointer = T *;
  36. using const_pointer = const T *;
  37. using reference = T &;
  38. using const_reference = const T &;
  39. using size_type = uint64_t;
  40. template <typename U>
  41. struct rebind {
  42. using other = Allocator<U>;
  43. };
  44. using propagate_on_container_copy_assignment = std::true_type;
  45. using propagate_on_container_move_assignment = std::true_type;
  46. using propagate_on_container_swap = std::true_type;
  47. explicit Allocator(const std::shared_ptr<MemoryPool> &b) : pool_(b) {}
  48. ~Allocator() = default;
  49. template <typename U>
  50. explicit Allocator(Allocator<U> const &rhs) : pool_(rhs.pool_) {}
  51. template <typename U>
  52. bool operator==(Allocator<U> const &rhs) const {
  53. return pool_ == rhs.pool_;
  54. }
  55. template <typename U>
  56. bool operator!=(Allocator<U> const &rhs) const {
  57. return pool_ != rhs.pool_;
  58. }
  59. pointer allocate(std::size_t n) {
  60. void *p;
  61. Status rc = pool_->Allocate(n * sizeof(T), &p);
  62. if (rc.IsOk()) {
  63. return reinterpret_cast<pointer>(p);
  64. } else if (rc.IsOutofMemory()) {
  65. throw std::bad_alloc();
  66. } else {
  67. throw std::exception();
  68. }
  69. }
  70. void deallocate(pointer p, std::size_t n = 0) noexcept { pool_->Deallocate(p); }
  71. size_type max_size() { return pool_->get_max_size(); }
  72. private:
  73. std::shared_ptr<MemoryPool> pool_;
  74. };
  75. /// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory will
  76. /// be released when the object goes out of scope
  77. /// \tparam T The type of object to be allocated
  78. /// \tparam C Allocator. Default to std::allocator
  79. template <typename T, typename C = std::allocator<T>>
  80. class MemGuard {
  81. public:
  82. using allocator = C;
  83. MemGuard() : n_(0) {}
  84. explicit MemGuard(allocator a) : n_(0), alloc_(a) {}
  85. // There is no copy constructor nor assignment operator because the memory is solely owned by this object.
  86. MemGuard(const MemGuard &) = delete;
  87. MemGuard &operator=(const MemGuard &) = delete;
  88. // On the other hand, We can support move constructor
  89. MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {}
  90. MemGuard &operator=(MemGuard &&lhs) noexcept {
  91. if (this != &lhs) {
  92. this->deallocate();
  93. n_ = lhs.n_;
  94. alloc_ = std::move(lhs.alloc_);
  95. ptr_ = std::move(lhs.ptr_);
  96. }
  97. return *this;
  98. }
  99. /// \brief Explicitly deallocate the memory if allocated
  100. void deallocate() {
  101. if (ptr_) {
  102. auto *p = ptr_.release();
  103. if (!std::is_arithmetic<T>::value && std::is_destructible<T>::value) {
  104. for (auto i = 0; i < n_; ++i) {
  105. p[i].~T();
  106. }
  107. }
  108. alloc_.deallocate(p, n_);
  109. n_ = 0;
  110. }
  111. }
  112. /// \brief Allocate memory (with emplace feature). Previous one will be released. If size is 0, no new memory is
  113. /// allocated.
  114. /// \param n Number of objects of type T to be allocated
  115. /// \tparam Args Extra arguments pass to the constructor of T
  116. template <typename... Args>
  117. Status allocate(size_t n, Args &&... args) noexcept {
  118. try {
  119. deallocate();
  120. if (n > 0) {
  121. T *data = alloc_.allocate(n);
  122. if (!std::is_arithmetic<T>::value) {
  123. for (auto i = 0; i < n; i++) {
  124. std::allocator_traits<C>::construct(alloc_, &(data[i]), std::forward<Args>(args)...);
  125. }
  126. }
  127. ptr_ = std::unique_ptr<T[]>(data);
  128. n_ = n;
  129. }
  130. } catch (const std::bad_alloc &e) {
  131. return Status(StatusCode::kOutOfMemory);
  132. } catch (std::exception &e) {
  133. RETURN_STATUS_UNEXPECTED(e.what());
  134. }
  135. return Status::OK();
  136. }
  137. ~MemGuard() noexcept { deallocate(); }
  138. /// \brief Getter function
  139. /// \return The pointer to the memory allocated
  140. T *GetPointer() const { return ptr_.get(); }
  141. /// \brief Getter function
  142. /// \return The pointer to the memory allocated
  143. T *GetMutablePointer() { return ptr_.get(); }
  144. /// \brief Overload [] operator to access a particular element
  145. /// \param x index to the element. Must be less than number of element allocated.
  146. /// \return pointer to the x-th element
  147. T *operator[](size_t x) { return GetMutablePointer() + x; }
  148. /// \brief Overload [] operator to access a particular element
  149. /// \param x index to the element. Must be less than number of element allocated.
  150. /// \return pointer to the x-th element
  151. T *operator[](size_t x) const { return GetPointer() + x; }
  152. /// \brief Return how many bytes are allocated in total
  153. /// \return Number of bytes allocated in total
  154. size_t GetSizeInBytes() const { return n_ * sizeof(T); }
  155. private:
  156. allocator alloc_;
  157. std::unique_ptr<T[]> ptr_;
  158. size_t n_;
  159. };
  160. } // namespace dataset
  161. } // namespace mindspore
  162. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ALLOCATOR_H_