/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ALLOCATOR_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ALLOCATOR_H_ #include #include #include #include #include #include "minddata/dataset/util/memory_pool.h" namespace mindspore { namespace dataset { // The following conforms to the requirements of // std::allocator. Do not rename/change any needed // requirements, e.g. function names, typedef etc. template class Allocator { public: template friend class Allocator; using value_type = T; using pointer = T *; using const_pointer = const T *; using reference = T &; using const_reference = const T &; using size_type = uint64_t; template struct rebind { using other = Allocator; }; using propagate_on_container_copy_assignment = std::true_type; using propagate_on_container_move_assignment = std::true_type; using propagate_on_container_swap = std::true_type; explicit Allocator(const std::shared_ptr &b) : pool_(b) {} ~Allocator() = default; template explicit Allocator(Allocator const &rhs) : pool_(rhs.pool_) {} template bool operator==(Allocator const &rhs) const { return pool_ == rhs.pool_; } template bool operator!=(Allocator const &rhs) const { return pool_ != rhs.pool_; } pointer allocate(std::size_t n) { void *p; Status rc = pool_->Allocate(n * sizeof(T), &p); if (rc.IsOk()) { return reinterpret_cast(p); } else if (rc.IsOutofMemory()) { throw std::bad_alloc(); } else { throw std::exception(); } } void deallocate(pointer p, std::size_t n = 0) noexcept { pool_->Deallocate(p); } size_type max_size() { return pool_->get_max_size(); } private: std::shared_ptr pool_; }; /// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory will /// be released when the object goes out of scope /// \tparam T The type of object to be allocated /// \tparam C Allocator. Default to std::allocator template > class MemGuard { public: using allocator = C; MemGuard() : n_(0) {} explicit MemGuard(allocator a) : n_(0), alloc_(a) {} // There is no copy constructor nor assignment operator because the memory is solely owned by this object. MemGuard(const MemGuard &) = delete; MemGuard &operator=(const MemGuard &) = delete; // On the other hand, We can support move constructor MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {} MemGuard &operator=(MemGuard &&lhs) noexcept { if (this != &lhs) { this->deallocate(); n_ = lhs.n_; alloc_ = std::move(lhs.alloc_); ptr_ = std::move(lhs.ptr_); } return *this; } /// \brief Explicitly deallocate the memory if allocated void deallocate() { if (ptr_) { auto *p = ptr_.release(); if (!std::is_arithmetic::value && std::is_destructible::value) { for (auto i = 0; i < n_; ++i) { p[i].~T(); } } alloc_.deallocate(p, n_); n_ = 0; } } /// \brief Allocate memory (with emplace feature). Previous one will be released. If size is 0, no new memory is /// allocated. /// \param n Number of objects of type T to be allocated /// \tparam Args Extra arguments pass to the constructor of T template Status allocate(size_t n, Args &&... args) noexcept { try { deallocate(); if (n > 0) { T *data = alloc_.allocate(n); if (!std::is_arithmetic::value) { for (auto i = 0; i < n; i++) { std::allocator_traits::construct(alloc_, &(data[i]), std::forward(args)...); } } ptr_ = std::unique_ptr(data); n_ = n; } } catch (const std::bad_alloc &e) { return Status(StatusCode::kOutOfMemory); } catch (std::exception &e) { RETURN_STATUS_UNEXPECTED(e.what()); } return Status::OK(); } ~MemGuard() noexcept { deallocate(); } /// \brief Getter function /// \return The pointer to the memory allocated T *GetPointer() const { return ptr_.get(); } /// \brief Getter function /// \return The pointer to the memory allocated T *GetMutablePointer() { return ptr_.get(); } /// \brief Overload [] operator to access a particular element /// \param x index to the element. Must be less than number of element allocated. /// \return pointer to the x-th element T *operator[](size_t x) { return GetMutablePointer() + x; } /// \brief Overload [] operator to access a particular element /// \param x index to the element. Must be less than number of element allocated. /// \return pointer to the x-th element T *operator[](size_t x) const { return GetPointer() + x; } /// \brief Return how many bytes are allocated in total /// \return Number of bytes allocated in total size_t GetSizeInBytes() const { return n_ * sizeof(T); } private: allocator alloc_; std::unique_ptr ptr_; size_t n_; }; } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ALLOCATOR_H_