You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

allocator.h 6.4 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ALLOCATOR_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ALLOCATOR_H_
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>
#include "minddata/dataset/util/memory_pool.h"
  24. namespace mindspore {
  25. namespace dataset {
  26. // The following conforms to the requirements of
  27. // std::allocator. Do not rename/change any needed
  28. // requirements, e.g. function names, typedef etc.
  29. template <typename T>
  30. class Allocator {
  31. public:
  32. template <typename U>
  33. friend class Allocator;
  34. using value_type = T;
  35. using pointer = T *;
  36. using const_pointer = const T *;
  37. using reference = T &;
  38. using const_reference = const T &;
  39. using size_type = uint64_t;
  40. using difference_type = std::ptrdiff_t;
  41. template <typename U>
  42. struct rebind {
  43. using other = Allocator<U>;
  44. };
  45. using propagate_on_container_copy_assignment = std::true_type;
  46. using propagate_on_container_move_assignment = std::true_type;
  47. using propagate_on_container_swap = std::true_type;
  48. explicit Allocator(const std::shared_ptr<MemoryPool> &b) : pool_(b) {}
  49. ~Allocator() = default;
  50. template <typename U>
  51. explicit Allocator(Allocator<U> const &rhs) : pool_(rhs.pool_) {}
  52. template <typename U>
  53. bool operator==(Allocator<U> const &rhs) const {
  54. return pool_ == rhs.pool_;
  55. }
  56. template <typename U>
  57. bool operator!=(Allocator<U> const &rhs) const {
  58. return pool_ != rhs.pool_;
  59. }
  60. pointer allocate(std::size_t n) {
  61. void *p;
  62. Status rc = pool_->Allocate(n * sizeof(T), &p);
  63. if (rc.IsOk()) {
  64. return reinterpret_cast<pointer>(p);
  65. } else if (rc.IsOutofMemory()) {
  66. throw std::bad_alloc();
  67. } else {
  68. throw std::exception();
  69. }
  70. }
  71. void deallocate(pointer p, std::size_t n = 0) noexcept { pool_->Deallocate(p); }
  72. size_type max_size() { return pool_->get_max_size(); }
  73. private:
  74. std::shared_ptr<MemoryPool> pool_;
  75. };
  76. /// \brief It is a wrapper of unique_ptr with a custom Allocator class defined above
  77. template <typename T, typename C = std::allocator<T>, typename... Args>
  78. Status MakeUnique(std::unique_ptr<T[], std::function<void(T *)>> *out, C alloc, size_t n, Args &&... args) {
  79. RETURN_UNEXPECTED_IF_NULL(out);
  80. CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "size must be positive");
  81. try {
  82. T *data = alloc.allocate(n);
  83. if (!std::is_arithmetic<T>::value) {
  84. for (auto i = 0; i < n; i++) {
  85. std::allocator_traits<C>::construct(alloc, &(data[i]), std::forward<Args>(args)...);
  86. }
  87. }
  88. auto deleter = [](T *p, C f_alloc, size_t f_n) {
  89. if (!std::is_arithmetic<T>::value && std::is_destructible<T>::value) {
  90. for (auto i = 0; i < f_n; ++i) {
  91. std::allocator_traits<C>::destroy(f_alloc, &p[i]);
  92. }
  93. }
  94. f_alloc.deallocate(p, f_n);
  95. };
  96. *out = std::unique_ptr<T[], std::function<void(T *)>>(data, std::bind(deleter, std::placeholders::_1, alloc, n));
  97. } catch (const std::bad_alloc &e) {
  98. return Status(StatusCode::kOutOfMemory);
  99. } catch (const std::exception &e) {
  100. RETURN_STATUS_UNEXPECTED(e.what());
  101. }
  102. return Status::OK();
  103. }
  104. /// \brief It is a wrapper of the above custom unique_ptr with some additional methods
  105. /// \tparam T The type of object to be allocated
  106. /// \tparam C Allocator. Default to std::allocator
  107. template <typename T, typename C = std::allocator<T>>
  108. class MemGuard {
  109. public:
  110. using allocator = C;
  111. MemGuard() : n_(0) {}
  112. explicit MemGuard(allocator a) : n_(0), alloc_(a) {}
  113. // There is no copy constructor nor assignment operator because the memory is solely owned by this object.
  114. MemGuard(const MemGuard &) = delete;
  115. MemGuard &operator=(const MemGuard &) = delete;
  116. // On the other hand, We can support move constructor
  117. MemGuard(MemGuard &&lhs) noexcept : n_(lhs.n_), alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)) {}
  118. MemGuard &operator=(MemGuard &&lhs) noexcept {
  119. if (this != &lhs) {
  120. this->deallocate();
  121. n_ = lhs.n_;
  122. alloc_ = std::move(lhs.alloc_);
  123. ptr_ = std::move(lhs.ptr_);
  124. }
  125. return *this;
  126. }
  127. /// \brief Explicitly deallocate the memory if allocated
  128. void deallocate() {
  129. if (ptr_) {
  130. ptr_.reset();
  131. }
  132. }
  133. /// \brief Allocate memory (with emplace feature). Previous one will be released. If size is 0, no new memory is
  134. /// allocated.
  135. /// \param n Number of objects of type T to be allocated
  136. /// \tparam Args Extra arguments pass to the constructor of T
  137. template <typename... Args>
  138. Status allocate(size_t n, Args &&... args) noexcept {
  139. deallocate();
  140. n_ = n;
  141. return MakeUnique(&ptr_, alloc_, n, std::forward<Args>(args)...);
  142. }
  143. ~MemGuard() noexcept { deallocate(); }
  144. /// \brief Getter function
  145. /// \return The pointer to the memory allocated
  146. T *GetPointer() const { return ptr_.get(); }
  147. /// \brief Getter function
  148. /// \return The pointer to the memory allocated
  149. T *GetMutablePointer() { return ptr_.get(); }
  150. /// \brief Overload [] operator to access a particular element
  151. /// \param x index to the element. Must be less than number of element allocated.
  152. /// \return pointer to the x-th element
  153. T *operator[](size_t x) { return GetMutablePointer() + x; }
  154. /// \brief Overload [] operator to access a particular element
  155. /// \param x index to the element. Must be less than number of element allocated.
  156. /// \return pointer to the x-th element
  157. T *operator[](size_t x) const { return GetPointer() + x; }
  158. /// \brief Return how many bytes are allocated in total
  159. /// \return Number of bytes allocated in total
  160. size_t GetSizeInBytes() const { return n_ * sizeof(T); }
  161. private:
  162. size_t n_;
  163. allocator alloc_;
  164. std::unique_ptr<T[], std::function<void(T *)>> ptr_;
  165. };
  166. } // namespace dataset
  167. } // namespace mindspore
  168. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ALLOCATOR_H_