| @@ -1 +1 @@ | |||
| Subproject commit c460176523d039c8995f1d71089753725ebc0792 | |||
| Subproject commit df57a6cf9450e347d1854687d1fe66a420ee3b35 | |||
| @@ -277,10 +277,11 @@ endif () | |||
| if (USE_GLOG) | |||
| target_link_libraries(inference PRIVATE mindspore::glog) | |||
| else() | |||
| if (CMAKE_SYSTEM_NAME MATCHES "Linux") | |||
| target_link_options(inference PRIVATE -Wl,-init,mindspore_log_init) | |||
| elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin") | |||
| set_target_properties(inference PROPERTIES MACOSX_RPATH ON) | |||
| endif () | |||
| endif() | |||
| if (CMAKE_SYSTEM_NAME MATCHES "Linux") | |||
| target_link_options(inference PRIVATE -Wl,-init,common_log_init) | |||
| elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin") | |||
| set_target_properties(inference PROPERTIES MACOSX_RPATH ON) | |||
| endif () | |||
| @@ -30,8 +30,7 @@ Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow | |||
| BOUNDING_BOX_CHECK(input); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of input is abnormal"); | |||
| (*output).push_back(nullptr); // init memory for return vector | |||
| (*output).push_back(nullptr); | |||
| output->resize(2); | |||
| (*output)[1] = std::move(input[1]); // move boxes over to output | |||
| size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor | |||
| @@ -36,8 +36,7 @@ Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) | |||
| int32_t padded_image_h; | |||
| int32_t padded_image_w; | |||
| (*output).push_back(nullptr); | |||
| (*output).push_back(nullptr); | |||
| output->resize(2); | |||
| (*output)[1] = std::move(input[1]); // since some boxes may be removed | |||
| bool crop_further = true; // Whether further cropping will be required or not, true unless required size matches | |||
| @@ -45,8 +45,7 @@ Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow * | |||
| RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 1}, newBoxCorner_y)); | |||
| } | |||
| (*output).push_back(nullptr); | |||
| (*output).push_back(nullptr); | |||
| output->resize(2); | |||
| (*output)[1] = std::move(input[1]); | |||
| return VerticalFlip(input[0], &(*output)[0]); | |||
| @@ -2,6 +2,8 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| add_library(utils OBJECT | |||
| arena.cc | |||
| buddy.cc | |||
| cache_pool.cc | |||
| circular_pool.cc | |||
| memory_pool.cc | |||
| cond_var.cc | |||
| @@ -11,7 +13,11 @@ add_library(utils OBJECT | |||
| service.cc | |||
| services.cc | |||
| lock.cc | |||
| semaphore.cc | |||
| status.cc | |||
| storage_container.cc | |||
| storage_manager.cc | |||
| slice.cc | |||
| path.cc | |||
| wait_post.cc | |||
| sig_handler.cc) | |||
| @@ -17,8 +17,10 @@ | |||
| #define DATASET_UTIL_ALLOCATOR_H_ | |||
| #include <cstdlib> | |||
| #include <functional> | |||
| #include <memory> | |||
| #include <type_traits> | |||
| #include <utility> | |||
| #include "dataset/util/memory_pool.h" | |||
| namespace mindspore { | |||
| @@ -84,6 +86,91 @@ class Allocator { | |||
| private: | |||
| std::shared_ptr<MemoryPool> pool_; | |||
| }; | |||
| /// \brief A wrapper of unique_ptr with a custom allocator that acts like std::lock_guard: the memory is | |||
| /// released automatically when the object goes out of scope. | |||
| /// \tparam T The type of object to be allocated | |||
| /// \tparam C Allocator. Defaults to std::allocator | |||
| template <typename T, typename C = std::allocator<T>> | |||
| class MemGuard { | |||
| public: | |||
| using allocator = C; | |||
| MemGuard() : n_(0) {} | |||
| explicit MemGuard(allocator a) : n_(0), alloc_(a) {} | |||
| // There is no copy constructor nor assignment operator because the memory is solely owned by this object. | |||
| MemGuard(const MemGuard &) = delete; | |||
| MemGuard &operator=(const MemGuard &) = delete; | |||
| // On the other hand, we can support the move constructor. | |||
| MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {} | |||
| MemGuard &operator=(MemGuard &&lhs) noexcept { | |||
| if (this != &lhs) { | |||
| this->deallocate(); | |||
| n_ = lhs.n_; | |||
| alloc_ = std::move(lhs.alloc_); | |||
| ptr_ = std::move(lhs.ptr_); | |||
| } | |||
| return *this; | |||
| } | |||
| /// \brief Explicitly deallocate the memory if allocated | |||
| void deallocate() { | |||
| if (ptr_) { | |||
| auto *p = ptr_.release(); | |||
| if (!std::is_arithmetic<T>::value && std::is_destructible<T>::value) { | |||
| for (auto i = 0; i < n_; ++i) { | |||
| p[i].~T(); | |||
| } | |||
| } | |||
| alloc_.deallocate(p, n_); | |||
| n_ = 0; | |||
| } | |||
| } | |||
| /// \brief Allocate memory (with emplace feature). Previous one will be released. If size is 0, no new memory is | |||
| /// allocated. | |||
| /// \param n Number of objects of type T to be allocated | |||
| /// \tparam Args Extra arguments pass to the constructor of T | |||
| template <typename... Args> | |||
| Status allocate(size_t n, Args &&... args) noexcept { | |||
| try { | |||
| deallocate(); | |||
| if (n > 0) { | |||
| T *data = alloc_.allocate(n); | |||
| if (!std::is_arithmetic<T>::value) { | |||
| for (auto i = 0; i < n; i++) { | |||
| std::allocator_traits<C>::construct(alloc_, &(data[i]), std::forward<Args>(args)...); | |||
| } | |||
| } | |||
| ptr_ = std::unique_ptr<T[]>(data); | |||
| n_ = n; | |||
| } | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } catch (std::exception &e) { | |||
| RETURN_STATUS_UNEXPECTED(e.what()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| ~MemGuard() noexcept { deallocate(); } | |||
| /// \brief Getter function | |||
| /// \return The pointer to the memory allocated | |||
| T *GetPointer() const { return ptr_.get(); } | |||
| /// \brief Getter function | |||
| /// \return The pointer to the memory allocated | |||
| T *GetMutablePointer() { return ptr_.get(); } | |||
| /// \brief Overload [] operator to access a particular element | |||
| /// \param x index to the element. Must be less than number of element allocated. | |||
| /// \return pointer to the x-th element | |||
| T *operator[](size_t x) { return GetMutablePointer() + x; } | |||
| /// \brief Overload [] operator to access a particular element | |||
| /// \param x index to the element. Must be less than number of element allocated. | |||
| /// \return pointer to the x-th element | |||
| T *operator[](size_t x) const { return GetPointer() + x; } | |||
| /// \brief Return how many bytes are allocated in total | |||
| /// \return Number of bytes allocated in total | |||
| size_t GetSizeInBytes() const { return n_ * sizeof(T); } | |||
| private: | |||
| allocator alloc_; | |||
| std::unique_ptr<T[], std::function<void(T *)>> ptr_; | |||
| size_t n_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -91,7 +91,7 @@ class AutoIndexObj : public BPlusTree<int64_t, T, A> { | |||
| } | |||
| private: | |||
| static constexpr key_type kMinKey = 1; | |||
| static constexpr key_type kMinKey = 0; | |||
| std::atomic<key_type> inx_; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -0,0 +1,388 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/util/buddy.h" | |||
| #include <iomanip> | |||
| #include <stdexcept> | |||
| #include "dataset/util/de_error.h" | |||
| #include "dataset/util/memory_pool.h" | |||
| #include "dataset/util/system_pool.h" | |||
| #include "./securec.h" | |||
| inline uint64_t BitLeftShift(uint64_t v, uint64_t n) { return (v << n); } | |||
| inline uint64_t BitRightShift(uint64_t v, uint64_t n) { return (v >> n); } | |||
| inline uint64_t BitOr(uint64_t rhs, uint64_t lhs) { return rhs | lhs; } | |||
| inline uint64_t BitEx(uint64_t rhs, uint64_t lhs) { return rhs ^ lhs; } | |||
| inline uint64_t BitAnd(uint64_t rhs, uint64_t lhs) { return rhs & lhs; } | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status BuddySpace::Init() { | |||
| if (log_min_ < 0) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | |||
| "log_min must be non-negative : " + std::to_string(log_min_)); | |||
| } | |||
| if (num_lvl_ < 3 || num_lvl_ > 18) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | |||
| "num_lvl must be between 3 and 18 : " + std::to_string(num_lvl_)); | |||
| } | |||
| min_ = BitLeftShift(1, log_min_); | |||
| max_ = BitLeftShift(1, log_min_ + num_lvl_ - 1); | |||
| size_t offset_1 = sizeof(rel_addr_t) * num_lvl_; | |||
| size_t offset_2 = sizeof(int) * num_lvl_ + offset_1; | |||
| size_t offset_3 = sizeof(char) * BitLeftShift(1, num_lvl_ - 3) + offset_2; | |||
| RETURN_IF_NOT_OK(DeMalloc(offset_3, &ptr_, true)); | |||
| hint_ = reinterpret_cast<rel_addr_t *>(ptr_); | |||
| count_ = reinterpret_cast<int *>((reinterpret_cast<char *>(ptr_) + offset_1)); | |||
| map_ = reinterpret_cast<char *>(ptr_) + offset_2; | |||
| count_[num_lvl_ - 1] = 1; | |||
| map_[0] = BitOr(MORE_BIT, num_lvl_ - 3); | |||
| return Status::OK(); | |||
| } | |||
| Status BuddySpace::Alloc(const uint64_t sz, BSpaceDescriptor *desc, addr_t *p) noexcept { | |||
| std::lock_guard<std::mutex> lock(mutex_); | |||
| addr_t addr = AllocNoLock(sz, desc); | |||
| if (addr != NOSPACE) { | |||
| *p = addr; | |||
| return Status::OK(); | |||
| } else { | |||
| return Status(StatusCode::kNoSpace, "BuddySpace full. Not an error. Please ignore."); | |||
| } | |||
| } | |||
| addr_t BuddySpace::AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept { | |||
| DS_ASSERT(sz <= max_); | |||
| uint32_t reqSize = SizeToBlock(sz); | |||
| rel_addr_t rel_addr = AllocBuddySeg(reqSize); | |||
| if (rel_addr != static_cast<rel_addr_t>(NOSPACE)) { | |||
| (void)memset_s(desc, sizeof(BSpaceDescriptor), 0, sizeof(BSpaceDescriptor)); | |||
| desc->sig = static_cast<int>(0xDEADBEEF); | |||
| desc->addr = rel_addr; | |||
| desc->req_size = reqSize; | |||
| desc->blk_size = NextPowerOf2(reqSize); | |||
| return static_cast<addr_t>(rel_addr * min_); | |||
| } else { | |||
| return NOSPACE; | |||
| } | |||
| } | |||
| void BuddySpace::FreeNoLock(const BSpaceDescriptor *desc) { | |||
| DS_ASSERT(desc->sig == 0XDEADBEEF); | |||
| rel_addr_t rel_addr = desc->addr; | |||
| size_t blk_size = desc->blk_size; | |||
| size_t req_size = desc->req_size; | |||
| FreeBuddySeg(rel_addr, blk_size, req_size); | |||
| } | |||
| void BuddySpace::Free(const BSpaceDescriptor *desc) { | |||
| std::lock_guard<std::mutex> lock(mutex_); | |||
| return FreeNoLock(desc); | |||
| } | |||
| std::ostream &operator<<(std::ostream &os, const BuddySpace &s) { | |||
| os << "1 unit = " << s.GetMinSize() << "\n" | |||
| << "Size of buddy space = " << s.GetMaxSize() << "\n" | |||
| << "Number of levels = " << s.num_lvl_ << "\n\n" | |||
| << "Percent free = " << s.PercentFree() << "\n" | |||
| << "Dumping count array : " | |||
| << "\n"; | |||
| for (int i = 0; i < s.num_lvl_; i++) { | |||
| os << "[" << i << "] = " << s.count_[i] << " "; | |||
| if (((i + 1) % 4) == 0) { | |||
| os << "\n"; | |||
| } | |||
| } | |||
| os << "\n"; | |||
| os << "Dumping allocation info:" | |||
| << "\n"; | |||
| auto max_addr = static_cast<rel_addr_t>(BitLeftShift(1, s.num_lvl_ - 1)); | |||
| rel_addr_t addr = 0; | |||
| while (addr < max_addr) { | |||
| size_t sz = 0; | |||
| BuddySpace::STATE st; | |||
| s.GetBuddySegState(addr, &sz, &st); | |||
| os << "Address : " << std::left << std::setw(8) << addr << " Size : " << std::setw(8) << sz << " State : " | |||
| << ((st == BuddySpace::STATE::kAlloc) ? "ALLOC" : ((st == BuddySpace::STATE::kFree) ? "FREE" : "Unknown")) | |||
| << "\n"; | |||
| addr += sz; | |||
| } | |||
| return os; | |||
| } | |||
| void BuddySpace::GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const { | |||
| char byte; | |||
| int pos; | |||
| int offset; | |||
| uint64_t val = 0; | |||
| int shift; | |||
| pos = BitRightShift(rel_addr, 2); | |||
| offset = rel_addr % 4; | |||
| shift = offset * 2; | |||
| byte = map_[pos]; | |||
| switch (offset) { | |||
| case 0: | |||
| val = byte; | |||
| break; | |||
| case 1: | |||
| case 3: | |||
| if (offset == 1) { | |||
| val = BitLeftShift(BitAnd(byte, 0x30), shift); | |||
| } else { | |||
| val = BitLeftShift(BitAnd(byte, 0x03), shift); | |||
| } | |||
| break; | |||
| case 2: | |||
| val = BitLeftShift(BitAnd(byte, 0x0F), shift); | |||
| break; | |||
| } | |||
| if (BitAnd(val, ONE_BIT)) { | |||
| *rel_sz = 1; | |||
| } else if (BitAnd(val, TWO_BIT)) { | |||
| *rel_sz = 2; | |||
| } else if (BitAnd(val, MORE_BIT)) { | |||
| log_t lg = BitAnd(val, 0x0F); | |||
| *rel_sz = BitLeftShift(1, lg + 2); | |||
| } else { | |||
| *st = STATE::kEmpty; | |||
| return; | |||
| } | |||
| *st = BitAnd(val, ALLOC_BIT) ? STATE::kAlloc : STATE::kFree; | |||
| } | |||
| void BuddySpace::SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st) { | |||
| int clr; | |||
| int mask; | |||
| int pos; | |||
| int offset; | |||
| int val = 0; | |||
| int shift; | |||
| auto log_sz = static_cast<log_t>(Log2(rel_sz)); | |||
| pos = BitRightShift(rel_addr, 2); | |||
| offset = rel_addr % 4; | |||
| shift = offset * 2; | |||
| if (rel_sz == 1) { | |||
| val = ONE_BIT; | |||
| mask = 0xC0; | |||
| } else if (rel_sz == 2) { | |||
| val = TWO_BIT; | |||
| mask = 0xF0; | |||
| } else { | |||
| val = BitOr(log_sz - 2, MORE_BIT); | |||
| mask = 0xFF; | |||
| } | |||
| if (st == STATE::kAlloc) { | |||
| val = BitOr(val, ALLOC_BIT); | |||
| } else if (st == STATE::kFree) { | |||
| val = BitAnd(val, ~(static_cast<uint64_t>(ALLOC_BIT))); | |||
| } else if (st == STATE::kEmpty) { | |||
| val = 0; | |||
| } | |||
| clr = static_cast<int>(~(BitRightShift(mask, shift))); | |||
| map_[pos] = static_cast<char>(BitAnd(map_[pos], clr)); | |||
| map_[pos] = static_cast<char>(BitOr(map_[pos], BitRightShift(val, shift))); | |||
| if (st == STATE::kAlloc) { | |||
| count_[log_sz]--; | |||
| } else if (st == STATE::kFree) { | |||
| count_[log_sz]++; | |||
| if (rel_addr < hint_[log_sz]) { | |||
| hint_[log_sz] = rel_addr; | |||
| } | |||
| } | |||
| } | |||
| void BuddySpace::JoinBuddySeg(rel_addr_t addr, size_t blk_sz) { | |||
| while (blk_sz < BitLeftShift(1, num_lvl_)) { | |||
| rel_addr_t buddy = BitEx(addr, blk_sz); | |||
| size_t sz = 0; | |||
| STATE st; | |||
| GetBuddySegState(buddy, &sz, &st); | |||
| if (st == STATE::kFree && sz == blk_sz) { | |||
| auto log_sz = static_cast<log_t>(Log2(blk_sz)); | |||
| rel_addr_t left = (buddy < addr) ? buddy : addr; | |||
| rel_addr_t right = left + blk_sz; | |||
| DS_ASSERT(count_[log_sz] >= 2); | |||
| count_[log_sz] -= 2; | |||
| SetBuddySegState(right, blk_sz, STATE::kEmpty); | |||
| SetBuddySegState(left, BitLeftShift(blk_sz, 1), STATE::kFree); | |||
| for (int i = 0; i < log_sz; i++) { | |||
| if (hint_[i] == right) { | |||
| hint_[i] = left; | |||
| } | |||
| } | |||
| addr = left; | |||
| blk_sz <<= 1u; | |||
| } else { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| void BuddySpace::TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) { | |||
| DS_ASSERT(ask_sz < blk_sz); | |||
| uint32_t inx = Log2(blk_sz); | |||
| size_t remaining_sz = ask_sz; | |||
| for (int i = inx; i > 0; i--) { | |||
| size_t b_size = BitLeftShift(1, i); | |||
| size_t half_sz = BitRightShift(b_size, 1); | |||
| count_[i]--; | |||
| SetBuddySegState(addr, half_sz, STATE::kFree); | |||
| SetBuddySegState(addr + half_sz, half_sz, STATE::kFree); | |||
| if (remaining_sz >= half_sz) { | |||
| SetBuddySegState(addr, half_sz, STATE::kAlloc); | |||
| remaining_sz -= half_sz; | |||
| if (remaining_sz == 0) { | |||
| break; | |||
| } | |||
| addr += half_sz; | |||
| } | |||
| } | |||
| } | |||
| void BuddySpace::UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) { | |||
| DS_ASSERT(ask_sz < blk_sz); | |||
| uint32_t inx = Log2(blk_sz); | |||
| size_t remaining_sz = ask_sz; | |||
| for (int i = inx; i > 0; i--) { | |||
| size_t b_size = BitLeftShift(1, i); | |||
| size_t half_sz = BitRightShift(b_size, 1); | |||
| if (remaining_sz >= half_sz) { | |||
| #ifdef DEBUG | |||
| { | |||
| size_t sz = 0; | |||
| STATE st; | |||
| GetBuddySegState(addr, &sz, &st); | |||
| DS_ASSERT(sz == half_sz && st == STATE::kAlloc); | |||
| } | |||
| #endif | |||
| SetBuddySegState(addr, half_sz, STATE::kFree); | |||
| remaining_sz -= half_sz; | |||
| if (remaining_sz == 0) { | |||
| JoinBuddySeg(addr, half_sz); | |||
| break; | |||
| } | |||
| addr += half_sz; | |||
| } | |||
| } | |||
| } | |||
| rel_addr_t BuddySpace::AllocBuddySeg(uint32_t req_size) noexcept { | |||
| uint32_t blk_size = NextPowerOf2(req_size); | |||
| int start_inx = static_cast<int>(Log2(blk_size)); | |||
| bool found = false; | |||
| rel_addr_t ask_addr = 0; | |||
| auto max_addr = static_cast<rel_addr_t>(BitLeftShift(1, num_lvl_ - 1)); | |||
| STATE st; | |||
| size_t sz = 0; | |||
| for (int i = start_inx; !found && i < num_lvl_; i++) { | |||
| DS_ASSERT(count_[i] >= 0); | |||
| if (count_[i] == 0) { | |||
| continue; | |||
| } | |||
| auto blk_sz = static_cast<size_t>(BitLeftShift(1, i)); | |||
| ask_addr = hint_[i]; | |||
| while (ask_addr < max_addr && !found) { | |||
| GetBuddySegState(ask_addr, &sz, &st); | |||
| if (st == STATE::kFree && sz == blk_sz) { | |||
| found = true; | |||
| } else { | |||
| DS_ASSERT(st != STATE::kEmpty); | |||
| ask_addr += ((sz > blk_sz) ? sz : blk_sz); | |||
| } | |||
| } | |||
| } | |||
| if (found) { | |||
| if (sz > req_size) { | |||
| TrimBuddySeg(ask_addr, sz, req_size); | |||
| } else { | |||
| SetBuddySegState(ask_addr, sz, STATE::kAlloc); | |||
| hint_[start_inx] = ask_addr; | |||
| } | |||
| return ask_addr; | |||
| } else { | |||
| return static_cast<rel_addr_t>(NOSPACE); | |||
| } | |||
| } | |||
| void BuddySpace::FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size) { | |||
| if (req_size == blk_size) { | |||
| #ifdef DEBUG | |||
| { | |||
| size_t sz = 0; | |||
| STATE st; | |||
| GetBuddySegState(addr, &sz, &st); | |||
| } | |||
| #endif | |||
| SetBuddySegState(addr, blk_size, STATE::kFree); | |||
| JoinBuddySeg(addr, blk_size); | |||
| } else { | |||
| UnTrimBuddySeg(addr, blk_size, req_size); | |||
| } | |||
| } | |||
| int BuddySpace::PercentFree() const { | |||
| uint64_t total_free_sz = 0; | |||
| uint64_t max_sz_in_unit = BitLeftShift(1, num_lvl_ - 1); | |||
| // Go through the count array without lock | |||
| for (int i = 0; i < num_lvl_; i++) { | |||
| int cnt = count_[i]; | |||
| if (cnt == 0) { | |||
| continue; | |||
| } | |||
| uint64_t blk_sz = BitLeftShift(1, i); | |||
| total_free_sz += (blk_sz * cnt); | |||
| } | |||
| return static_cast<int>(static_cast<float>(total_free_sz) / static_cast<float>(max_sz_in_unit) * 100); | |||
| } | |||
| BuddySpace::BuddySpace(int log_min, int num_lvl) | |||
| : hint_(nullptr), | |||
| count_(nullptr), | |||
| map_(nullptr), | |||
| log_min_(log_min), | |||
| num_lvl_(num_lvl), | |||
| min_(0), | |||
| max_(0), | |||
| ptr_(nullptr) {} | |||
| BuddySpace::~BuddySpace() { | |||
| if (ptr_ != nullptr) { | |||
| free(ptr_); | |||
| } | |||
| hint_ = nullptr; | |||
| count_ = nullptr; | |||
| map_ = nullptr; | |||
| } | |||
| Status BuddySpace::CreateBuddySpace(std::unique_ptr<BuddySpace> *out_bs, int log_min, int num_lvl) { | |||
| Status rc; | |||
| auto bs = new (std::nothrow) BuddySpace(log_min, num_lvl); | |||
| if (bs == nullptr) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| rc = bs->Init(); | |||
| if (rc.IsOk()) { | |||
| (*out_bs).reset(bs); | |||
| } else { | |||
| delete bs; | |||
| } | |||
| return rc; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,133 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_UTIL_BUDDY_H_ | |||
| #define DATASET_UTIL_BUDDY_H_ | |||
| #include <cstddef> | |||
| #include <cstdint> | |||
| #include <cstring> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include "dataset/util/status.h" | |||
| using addr_t = int64_t; | |||
| using rel_addr_t = int32_t; | |||
| using log_t = int; | |||
| #define ALLOC_BIT 0x80 | |||
| #define ONE_BIT 0x40 | |||
| #define TWO_BIT 0x20 | |||
| #define MORE_BIT 0x10 | |||
| #define NOSPACE ((addr_t)(-1)) | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| struct BSpaceDescriptor { | |||
| int32_t sig; | |||
| rel_addr_t addr; | |||
| size_t req_size; | |||
| size_t blk_size; | |||
| }; | |||
| class BuddySpace { | |||
| public: | |||
| // C++11 feature: STATE is a type-safe scoped enumeration (enum class). | |||
| // Don't remove the 'class' keyword. | |||
| enum class STATE { kFree, kAlloc, kEmpty }; | |||
| BuddySpace(const BuddySpace &) = delete; | |||
| BuddySpace &operator=(const BuddySpace &) = delete; | |||
| virtual ~BuddySpace(); | |||
| Status Alloc(uint64_t sz, BSpaceDescriptor *desc, addr_t *) noexcept; | |||
| void Free(const BSpaceDescriptor *desc); | |||
| uint64_t GetMinSize() const { return min_; } | |||
| uint64_t GetMaxSize() const { return max_; } | |||
| int PercentFree() const; | |||
| friend std::ostream &operator<<(std::ostream &os, const BuddySpace &s); | |||
| static uint64_t NextPowerOf2(uint64_t n) { | |||
| if (n <= 1) { | |||
| return 1; | |||
| } | |||
| n = n - 1; | |||
| while (n & (n - 1)) { | |||
| n = n & (n - 1); | |||
| } | |||
| return n << 1; | |||
| } | |||
| static uint32_t Log2(uint64_t n) { | |||
| uint32_t cnt = 0; | |||
| while (n >>= 1) { | |||
| cnt++; | |||
| } | |||
| return cnt; | |||
| } | |||
| static Status CreateBuddySpace(std::unique_ptr<BuddySpace> *out_bs, int log_min = 15, int num_lvl = 18); | |||
| private: | |||
| rel_addr_t *hint_; | |||
| int *count_; | |||
| char *map_; | |||
| int log_min_; | |||
| int num_lvl_; | |||
| uint64_t min_; | |||
| uint64_t max_; | |||
| void *ptr_; | |||
| std::mutex mutex_; | |||
| explicit BuddySpace(int log_min = 15, int num_lvl = 18); | |||
| Status Init(); | |||
| addr_t AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept; | |||
| void FreeNoLock(const BSpaceDescriptor *desc); | |||
| uint32_t SizeToBlock(const uint64_t sz) const { | |||
| uint32_t reqSize = (sz / min_); | |||
| if (sz % min_) { | |||
| reqSize++; | |||
| } | |||
| return reqSize; | |||
| } | |||
| void GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const; | |||
| void SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st); | |||
| void JoinBuddySeg(rel_addr_t addr, size_t blk_sz); | |||
| void TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz); | |||
| void UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz); | |||
| rel_addr_t AllocBuddySeg(uint32_t req_size) noexcept; | |||
| void FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_UTIL_BUDDY_H_ | |||
| @@ -0,0 +1,202 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include "common/utils.h" | |||
| #include "dataset/util/cache_pool.h" | |||
| #include "dataset/util/services.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CachePool::CachePool(const value_allocator &alloc, const std::string &root) | |||
| : alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {} | |||
| Status CachePool::DoServiceStart() { | |||
| tree_ = std::make_shared<data_index>(); | |||
| // If we are given a disk path, set up the StorageManager | |||
| if (!root_.toString().empty()) { | |||
| Path spill = GetSpillPath(); | |||
| RETURN_IF_NOT_OK(spill.CreateDirectories()); | |||
| sm_ = std::make_shared<StorageManager>(spill); | |||
| RETURN_IF_NOT_OK(sm_->ServiceStart()); | |||
| MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status CachePool::DoServiceStop() { | |||
| Status rc; | |||
| Status rc2; | |||
| if (sm_ != nullptr) { | |||
| rc = sm_->ServiceStop(); | |||
| if (rc.IsError()) { | |||
| rc2 = rc; | |||
| } | |||
| } | |||
| sm_.reset(); | |||
| for (auto &bl : *tree_) { | |||
| if (bl.ptr != nullptr) { | |||
| alloc_.deallocate(bl.ptr, bl.sz); | |||
| } | |||
| } | |||
| tree_.reset(); | |||
| if (!root_.toString().empty()) { | |||
| Path spill = GetSpillPath(); | |||
| auto it = Path::DirIterator::OpenDirectory(&spill); | |||
| while (it->hasNext()) { | |||
| rc = it->next().Remove(); | |||
| if (rc.IsError() && rc2.IsOk()) { | |||
| rc2 = rc; | |||
| } | |||
| } | |||
| rc = spill.Remove(); | |||
| if (rc.IsError() && rc2.IsOk()) { | |||
| rc2 = rc; | |||
| } | |||
| } | |||
| return rc2; | |||
| } | |||
| CachePool::~CachePool() noexcept { (void)ServiceStop(); } | |||
| Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_type *key) { | |||
| DataLocator bl; | |||
| Status rc; | |||
| size_t sz = 0; | |||
| // We will consolidate all the slices into one piece. | |||
| for (auto &v : buf) { | |||
| sz += v.GetSize(); | |||
| } | |||
| bl.sz = sz; | |||
| try { | |||
| bl.ptr = alloc_.allocate(sz); | |||
| // We will do a piecewise copy. | |||
| WritableSlice dest(bl.ptr, bl.sz); | |||
| size_t pos = 0; | |||
| for (auto &v : buf) { | |||
| WritableSlice out(dest, pos); | |||
| rc = WritableSlice::Copy(&out, v); | |||
| if (rc.IsError()) { | |||
| break; | |||
| } | |||
| pos += v.GetSize(); | |||
| } | |||
| if (rc.IsError()) { | |||
| alloc_.deallocate(bl.ptr, sz); | |||
| bl.ptr = nullptr; | |||
| return rc; | |||
| } | |||
| } catch (std::bad_alloc &e) { | |||
| if (sm_ != nullptr) { | |||
| RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); | |||
| // By design of AutoIndexObj, 0 is never a valid key. | |||
| // Make sure the returned key is not 0. | |||
| if (bl.storage_key == 0) { | |||
| RETURN_STATUS_UNEXPECTED("Key 0 is returned which is unexpected"); | |||
| } | |||
| } else { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } | |||
| } | |||
| rc = tree_->insert(bl, key); | |||
| if (rc.IsError() && bl.ptr != nullptr) { | |||
| alloc_.deallocate(bl.ptr, sz); | |||
| } | |||
| return rc; | |||
| } | |||
| Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const { | |||
| RETURN_UNEXPECTED_IF_NULL(dest); | |||
| auto r = tree_->Search(key); | |||
| if (r.second) { | |||
| auto &it = r.first; | |||
| if (it->ptr != nullptr) { | |||
| ReadableSlice src(it->ptr, it->sz); | |||
| RETURN_IF_NOT_OK(WritableSlice::Copy(dest, src)); | |||
| } else if (sm_ != nullptr) { | |||
| size_t expectedLength = 0; | |||
| RETURN_IF_NOT_OK(sm_->Read(it->storage_key, dest, &expectedLength)); | |||
| if (expectedLength != it->sz) { | |||
| MS_LOG(ERROR) << "Unexpected length. Read " << expectedLength << ". Expected " << it->sz << "." | |||
| << " Internal key: " << key << "\n"; | |||
| RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); | |||
| } | |||
| } | |||
| if (bytesRead != nullptr) { | |||
| *bytesRead = it->sz; | |||
| } | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Key not found"); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; } | |||
| Path CachePool::GetSpillPath() const { | |||
| auto spill = Path(root_) / subfolder_; | |||
| return spill; | |||
| } | |||
| CachePool::CacheStat CachePool::GetStat() const { | |||
| CacheStat cs{0}; | |||
| for (auto &it : *tree_) { | |||
| if (it.ptr != nullptr) { | |||
| ++cs.num_mem_cached; | |||
| } else { | |||
| ++cs.num_disk_cached; | |||
| } | |||
| } | |||
| return cs; | |||
| } | |||
| Status CachePool::Spill(CachePool::DataLocator *dl) { | |||
| if (sm_ == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("No disk storage to spill"); | |||
| } | |||
| RETURN_UNEXPECTED_IF_NULL(dl); | |||
| RETURN_UNEXPECTED_IF_NULL(dl->ptr); | |||
| if (dl->storage_key == 0) { | |||
| ReadableSlice data(dl->ptr, dl->sz); | |||
| RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data})); | |||
| } | |||
| alloc_.deallocate(dl->ptr, dl->sz); | |||
| dl->ptr = nullptr; | |||
| return Status::OK(); | |||
| } | |||
| Status CachePool::Locate(CachePool::DataLocator *dl) { | |||
| RETURN_UNEXPECTED_IF_NULL(dl); | |||
| if (dl->ptr == nullptr) { | |||
| if (sm_ == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("No disk storage to locate the data"); | |||
| } | |||
| try { | |||
| dl->ptr = alloc_.allocate(dl->sz); | |||
| WritableSlice dest(dl->ptr, dl->sz); | |||
| Status rc = Read(dl->storage_key, &dest); | |||
| if (rc.IsError()) { | |||
| alloc_.deallocate(dl->ptr, dl->sz); | |||
| dl->ptr = nullptr; | |||
| return rc; | |||
| } | |||
| } catch (const std::bad_alloc &e) { | |||
| return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| size_t CachePool::GetSize(CachePool::key_type key) const { | |||
| auto r = tree_->Search(key); | |||
| if (r.second) { | |||
| auto &it = r.first; | |||
| return it->sz; | |||
| } else { | |||
| return 0; | |||
| } | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,139 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_UTIL_CACHE_POOL_H_ | |||
| #define DATASET_UTIL_CACHE_POOL_H_ | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "dataset/util/allocator.h" | |||
| #include "dataset/util/service.h" | |||
| #include "dataset/util/slice.h" | |||
| #include "dataset/util/storage_manager.h" | |||
| #include "dataset/util/auto_index.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
/// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of
/// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to
/// disk (if a disk directory is provided). Every buffer insert will return a generated key which can be used to
/// restore the buffer.
/// \see ReadableSlice
class CachePool : public Service {
 public:
  using base_type = uint8_t;
  using pointer = base_type *;
  using const_pointer = const base_type *;
  using reference = base_type &;
  using const_reference = const base_type &;
  using value_allocator = Allocator<base_type>;
  // An internal class to locate the whereabouts of a backed up buffer which can be either in
  // memory (ptr != nullptr) or spilled to disk (located via storage_key).
  class DataLocator {
   public:
    DataLocator() : ptr(nullptr), sz(0), storage_key(0) {}
    ~DataLocator() = default;
    DataLocator(const DataLocator &other) = default;
    DataLocator &operator=(const DataLocator &other) = default;
    // Move operations transfer the memory block and reset the source so it
    // no longer refers to the data.
    DataLocator(DataLocator &&other) noexcept {
      ptr = other.ptr;
      sz = other.sz;
      storage_key = other.storage_key;
      other.ptr = nullptr;
      other.sz = 0;
      other.storage_key = 0;
    }
    DataLocator &operator=(DataLocator &&other) noexcept {
      if (&other != this) {
        ptr = other.ptr;
        sz = other.sz;
        storage_key = other.storage_key;
        other.ptr = nullptr;
        other.sz = 0;
        other.storage_key = 0;
      }
      return *this;
    }
    pointer ptr;                           // in-memory copy of the buffer, or nullptr if spilled
    size_t sz;                             // size of the buffer in bytes
    StorageManager::key_type storage_key;  // key into on-disk storage when spilled
  };
  using data_index = AutoIndexObj<DataLocator>;
  using key_type = data_index::key_type;
  using bl_alloc_type = typename value_allocator::template rebind<DataLocator>::other;
  /// \brief Simple statistics returned from CachePool like how many elements are cached in memory and
  /// how many elements are spilled to disk.
  struct CacheStat {
    int64_t num_mem_cached;
    int64_t num_disk_cached;
  };
  /// \brief Constructor
  /// \param alloc Allocator to allocate memory from
  /// \param root Optional disk folder to spill
  explicit CachePool(const value_allocator &alloc, const std::string &root = "");
  // Non-copyable and non-movable.
  CachePool(const CachePool &) = delete;
  CachePool(CachePool &&) = delete;
  CachePool &operator=(const CachePool &) = delete;
  CachePool &operator=(CachePool &&) = delete;
  ~CachePool() noexcept;
  Status DoServiceStart() override;
  Status DoServiceStop() override;
  // Directory used when spilling buffers to disk.
  Path GetSpillPath() const;
  /// \brief Insert a sequence of ReadableSlice objects into the pool.
  /// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk.
  /// \param[in] buf A sequence of ReadableSlice objects.
  /// \param[out] key Generated key
  /// \return Error code
  Status Insert(const std::vector<ReadableSlice> &buf, key_type *key);
  /// \brief Restore a cached buffer (from memory or disk)
  /// \param[in] key A previous key returned from Insert
  /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice
  /// \param[out] bytesRead Optional. Number of bytes read.
  /// \return Error code
  Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const;
  // Move the buffer described by <dl> from memory to disk.
  Status Spill(DataLocator *dl);
  // Ensure the buffer described by <dl> is resident in memory (restores it
  // from disk if it was spilled).
  Status Locate(DataLocator *dl);
  // Size in bytes of the element with the given key, or 0 if not found.
  size_t GetSize(key_type key) const;
  /// \brief Get statistics.
  /// \return CacheStat object
  CacheStat GetStat() const;
  const value_allocator &get_allocator() const;
  std::string MyName() const { return subfolder_; }

 private:
  value_allocator alloc_;               // allocator for in-memory buffers
  Path root_;                           // root spill directory
  const std::string subfolder_;         // subfolder under root_ used by this pool (see MyName)
  std::shared_ptr<StorageManager> sm_;  // disk storage manager; null when no spill dir (see Locate)
  std::shared_ptr<data_index> tree_;    // key -> DataLocator index
};
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif | |||
| @@ -106,6 +106,24 @@ struct List { | |||
| ++count; | |||
| } | |||
| // Insert elem2 before elem1 in the list. | |||
| virtual void InsertBefore(pointer elem1, pointer elem2) { | |||
| DS_ASSERT(elem1 != elem2); | |||
| Node<T> &elem1_node = elem1->*node; | |||
| Node<T> &elem2_node = elem2->*node; | |||
| elem2_node.next = elem1; | |||
| elem2_node.prev = elem1_node.prev; | |||
| if (elem1_node.prev != nullptr) { | |||
| Node<T> &prev_node = elem1_node.prev->*node; | |||
| prev_node.next = elem2; | |||
| } | |||
| elem1_node.prev = elem2; | |||
| if (head == elem1) { | |||
| head = elem2; | |||
| } | |||
| ++count; | |||
| } | |||
| // Remove an element in the list | |||
| virtual void Remove(pointer elem) noexcept { | |||
| Node<T> &elem_node = elem->*node; | |||
| @@ -44,20 +44,6 @@ class MemoryPool { | |||
| virtual ~MemoryPool() {} | |||
| }; | |||
| // Used by unique_ptr | |||
| template <typename T> | |||
| class Deleter { | |||
| public: | |||
| explicit Deleter(std::shared_ptr<MemoryPool> &mp) : mp_(mp) {} | |||
| ~Deleter() = default; | |||
| void operator()(T *ptr) const { mp_->Deallocate(ptr); } | |||
| private: | |||
| std::shared_ptr<MemoryPool> mp_; | |||
| }; | |||
| Status DeMalloc(std::size_t s, void **p, bool); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -16,6 +16,8 @@ | |||
| #include "dataset/util/path.h" | |||
| #include <sys/stat.h> | |||
| #include <fcntl.h> | |||
| #include <unistd.h> | |||
| #include <new> | |||
| #include <sstream> | |||
| #include <utility> | |||
| @@ -26,7 +28,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| #ifdef _WIN32 | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| char Path::separator_ = '\\'; | |||
| #else | |||
| char Path::separator_ = '/'; | |||
| @@ -132,7 +134,7 @@ Status Path::CreateDirectory() { | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| int rc = mkdir(common::SafeCStr(path_)); | |||
| #else | |||
| int rc = mkdir(common::SafeCStr(path_), 0700); | |||
| int rc = mkdir(common::SafeCStr(path_), S_IRUSR | S_IWUSR | S_IXUSR); | |||
| #endif | |||
| if (rc) { | |||
| std::ostringstream oss; | |||
| @@ -182,6 +184,111 @@ Status Path::CreateDirectories() { | |||
| return Status::OK(); | |||
| } | |||
| Status Path::Remove() { | |||
| if (Exists()) { | |||
| if (IsDirectory()) { | |||
| errno_t err = rmdir(common::SafeCStr(path_)); | |||
| if (err == -1) { | |||
| std::ostringstream oss; | |||
| oss << "Unable to delete directory " << path_ << ". Errno = " << errno; | |||
| RETURN_STATUS_UNEXPECTED(oss.str()); | |||
| } | |||
| } else { | |||
| errno_t err = unlink(common::SafeCStr(path_)); | |||
| if (err == -1) { | |||
| std::ostringstream oss; | |||
| oss << "Unable to delete file " << path_ << ". Errno = " << errno; | |||
| RETURN_STATUS_UNEXPECTED(oss.str()); | |||
| } | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
// Create (or truncate) the file at this path; thin wrapper over OpenFile(fd, true).
Status Path::CreateFile(int *file_descriptor) { return OpenFile(file_descriptor, true); }
| Status Path::OpenFile(int *file_descriptor, bool create) { | |||
| int fd; | |||
| if (file_descriptor == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("null pointer"); | |||
| } | |||
| if (IsDirectory()) { | |||
| std::ostringstream oss; | |||
| oss << "Unable to create file " << path_ << " which is a directory."; | |||
| RETURN_STATUS_UNEXPECTED(oss.str()); | |||
| } | |||
| // Convert to canonical form. | |||
| if (strlen(common::SafeCStr(path_)) > PATH_MAX) { | |||
| RETURN_STATUS_UNEXPECTED(strerror(errno)); | |||
| } | |||
| char canonical_path[PATH_MAX + 1] = {0x00}; | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| if (_fullpath(canonical_path, common::SafeCStr(path_), PATH_MAX) == nullptr) { | |||
| #else | |||
| if (realpath(common::SafeCStr(path_), canonical_path) == nullptr) { | |||
| #endif | |||
| if (errno == ENOENT && create) { | |||
| // File doesn't exist and we are to create it. Let's break it down. | |||
| auto file_part = Basename(); | |||
| auto parent_part = ParentPath(); | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| if (_fullpath(canonical_path, common::SafeCStr(parent_part), PATH_MAX) == nullptr) { | |||
| #else | |||
| if (realpath(common::SafeCStr(parent_part), canonical_path) == nullptr) { | |||
| #endif | |||
| RETURN_STATUS_UNEXPECTED(strerror(errno)); | |||
| } | |||
| auto cur_inx = strlen(canonical_path); | |||
| if ((cur_inx + file_part.length() + 1) > PATH_MAX) { | |||
| RETURN_STATUS_UNEXPECTED(strerror(errno)); | |||
| } | |||
| canonical_path[cur_inx++] = separator_; | |||
| if (strncpy_s(canonical_path + cur_inx, PATH_MAX - cur_inx, common::SafeCStr(file_part), file_part.length()) != | |||
| EOK) { | |||
| RETURN_STATUS_UNEXPECTED(strerror(errno)); | |||
| } | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED(strerror(errno)); | |||
| } | |||
| } | |||
| if (create) { | |||
| fd = open(canonical_path, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP); | |||
| } else { | |||
| fd = open(canonical_path, O_RDWR); | |||
| } | |||
| if (fd == -1) { | |||
| RETURN_STATUS_UNEXPECTED(strerror(errno)); | |||
| } | |||
| *file_descriptor = fd; | |||
| return Status::OK(); | |||
| } | |||
| Status Path::CloseFile(int fd) const { | |||
| if (close(fd) < 0) { | |||
| RETURN_STATUS_UNEXPECTED(strerror(errno)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status Path::TruncateFile(int fd) const { | |||
| int rc; | |||
| rc = ftruncate(fd, 0); | |||
| if (rc == 0) { | |||
| return Status::OK(); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED(strerror(errno)); | |||
| } | |||
| } | |||
| std::string Path::Basename() { | |||
| std::size_t found = path_.find_last_of(separator_); | |||
| if (found != std::string::npos) { | |||
| return path_.substr(found + 1); | |||
| } else { | |||
| return path_; | |||
| } | |||
| } | |||
| std::shared_ptr<Path::DirIterator> Path::DirIterator::OpenDirectory(Path *f) { | |||
| auto it = new (std::nothrow) DirIterator(f); | |||
| @@ -208,7 +315,7 @@ Path::DirIterator::~DirIterator() { | |||
| Path::DirIterator::DirIterator(Path *f) : dir_(f), dp_(nullptr), entry_(nullptr) { | |||
| MS_LOG(DEBUG) << "Open directory " << f->toString() << "."; | |||
| dp_ = opendir(common::SafeCStr(f->toString())); | |||
| dp_ = opendir(f->toString().c_str()); | |||
| } | |||
| bool Path::DirIterator::hasNext() { | |||
| @@ -225,5 +332,10 @@ bool Path::DirIterator::hasNext() { | |||
| } | |||
| Path Path::DirIterator::next() { return (*(this->dir_) / Path(entry_->d_name)); } | |||
// Stream the raw path string.
std::ostream &operator<<(std::ostream &os, const Path &s) {
  os << s.path_;
  return os;
}
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -90,6 +90,20 @@ class Path { | |||
| std::string ParentPath(); | |||
| Status Remove(); | |||
| Status CreateFile(int *fd); | |||
| Status OpenFile(int *fd, bool create = false); | |||
| Status CloseFile(int fd) const; | |||
| Status TruncateFile(int fd) const; | |||
| std::string Basename(); | |||
| friend std::ostream &operator<<(std::ostream &os, const Path &s); | |||
| private: | |||
| static char separator_; | |||
| std::string path_; | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/util/semaphore.h" | |||
| #include "dataset/util/task_manager.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status Semaphore::P() { | |||
| std::unique_lock<std::mutex> lck(mutex_); | |||
| RETURN_IF_NOT_OK(wait_cond_.Wait(&lck, [this]() { return value_ > 0; })); | |||
| --value_; | |||
| return Status::OK(); | |||
| } | |||
| void Semaphore::V() { | |||
| std::unique_lock<std::mutex> lck(mutex_); | |||
| ++value_; | |||
| wait_cond_.NotifyOne(); | |||
| } | |||
| int Semaphore::Peek() { | |||
| std::unique_lock<std::mutex> lck(mutex_); | |||
| return value_; | |||
| } | |||
// Hook the underlying condition variable up to the task group's interrupt
// service so a blocked P() can be interrupted.
Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); }
// Detach the condition variable from the interrupt service.
Status Semaphore::Deregister() { return (wait_cond_.Deregister()); }
// Clear any recorded interrupt state on the condition variable.
void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); }
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_UTIL_SEMAPHORE_H_ | |||
| #define DATASET_UTIL_SEMAPHORE_H_ | |||
| #include "dataset/util/cond_var.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class TaskGroup; | |||
/// \brief A counting semaphore. There are two external functions P and V. P decrements the internal count and will be
/// blocked if the count is 0 (zero). V increments the internal count and wakes up one of the waiters.
class Semaphore {
 public:
  /// \brief Constructor
  /// \param init Initial value of the internal counter.
  explicit Semaphore(int init) : value_(init) {}
  virtual ~Semaphore() {}
  /// \brief Decrement the internal counter. Will be blocked if the value is 0.
  /// \return Error code. Can get interrupt.
  Status P();
  /// \brief Increment the internal counter. Wake up one of the waiters if any.
  void V();
  /// \brief Peek the internal value
  /// \return The internal value
  int Peek();
  /// \brief Attach to the task group's interrupt service so P() is interruptible.
  Status Register(TaskGroup *vg);
  /// \brief Detach from the interrupt service.
  Status Deregister();
  /// \brief Clear any recorded interrupt state.
  void ResetIntrpState();

 private:
  int value_;          // current count, guarded by mutex_
  std::mutex mutex_;
  CondVar wait_cond_;  // waiters blocked in P()
};
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_UTIL_SEMAPHORE_H_ | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/util/slice.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
// Create a writable sub-slice of <src> covering [offset, offset + len).
WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset, size_t len) : ReadableSlice(src, offset, len) {
  mutable_data_ = static_cast<char *>(src.mutable_data_) + offset;
}
// Create a writable sub-slice of <src> from <offset> to the end of <src>.
WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset)
    : WritableSlice(src, offset, src.GetSize() - offset) {}
| Status WritableSlice::Copy(WritableSlice *dest, const ReadableSlice &src) { | |||
| RETURN_UNEXPECTED_IF_NULL(dest); | |||
| RETURN_UNEXPECTED_IF_NULL(dest->GetMutablePointer()); | |||
| if (dest->GetSize() <= 0) { | |||
| RETURN_STATUS_UNEXPECTED("Destination length is non-positive"); | |||
| } | |||
| auto err = memcpy_s(dest->GetMutablePointer(), dest->GetSize(), src.GetPointer(), src.GetSize()); | |||
| if (err) { | |||
| RETURN_STATUS_UNEXPECTED(std::to_string(err)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,122 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_UTIL_SLICE_H_ | |||
| #define DATASET_UTIL_SLICE_H_ | |||
| #include <unistd.h> | |||
| #include <cstddef> | |||
| #include <utility> | |||
| #include "./securec.h" | |||
| #include "dataset/util/allocator.h" | |||
| #include "dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
/// \brief A ReadableSlice wraps a const pointer in memory and its size. It is a non-owning view:
/// the underlying buffer must outlive the slice.
/// \see WritableSlice for a non-const version
///
class ReadableSlice {
 public:
  /// \brief Default constructor: an empty slice.
  ReadableSlice() : ptr_(nullptr), sz_(0) {}
  /// \brief Wrap an existing buffer of <sz> bytes; ownership is NOT taken.
  ReadableSlice(const void *ptr, size_t sz) : ptr_(ptr), sz_(sz) {}
  /// \brief Sub-slice of <src> covering [offset, offset + len).
  ReadableSlice(const ReadableSlice &src, off64_t offset, size_t len) {
    ptr_ = static_cast<const char *>(src.GetPointer()) + offset;
    sz_ = len;
  }
  /// \brief Sub-slice of <src> from <offset> to the end of <src>.
  ReadableSlice(const ReadableSlice &src, off64_t offset) : ReadableSlice(src, offset, src.sz_ - offset) {}
  // Copying duplicates the view; the compiler-generated members are exactly
  // the hand-written ones the class previously carried.
  ReadableSlice(const ReadableSlice &lhs) = default;
  ReadableSlice &operator=(const ReadableSlice &lhs) = default;
  // Moving transfers the view and empties the source. The original move
  // constructor carried a dead `this != &lhs` self-check (always true in a
  // constructor); std::exchange expresses the transfer directly.
  ReadableSlice(ReadableSlice &&lhs) noexcept
      : ptr_(std::exchange(lhs.ptr_, nullptr)), sz_(std::exchange(lhs.sz_, 0)) {}
  ReadableSlice &operator=(ReadableSlice &&lhs) noexcept {
    if (this != &lhs) {
      ptr_ = std::exchange(lhs.ptr_, nullptr);
      sz_ = std::exchange(lhs.sz_, 0);
    }
    return *this;
  }
  /// \brief Getter function
  /// \return Const version of the pointer
  const void *GetPointer() const { return ptr_; }
  /// \brief Getter function
  /// \return Size of the slice
  size_t GetSize() const { return sz_; }
  /// \brief True when the slice points at nothing (the size is not consulted).
  bool empty() const { return ptr_ == nullptr; }

 private:
  const void *ptr_;  // non-owning view into someone else's buffer
  size_t sz_;        // length of the view in bytes
};
/// \brief A WritableSlice inherits from ReadableSlice to allow
/// one to write to the address pointed to by the pointer.
/// Like its base, it is a non-owning view.
class WritableSlice : public ReadableSlice {
 public:
  // StorageContainer reads via GetMutablePointer(), which is private.
  friend class StorageContainer;
  /// \brief Default constructor
  WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {}
  /// \brief This form of a constructor takes a pointer and its size.
  WritableSlice(void *ptr, size_t sz) : ReadableSlice(ptr, sz), mutable_data_(ptr) {}
  // Sub-slice constructors; defined out of line (slice.cc).
  WritableSlice(const WritableSlice &src, off64_t offset, size_t len);
  WritableSlice(const WritableSlice &src, off64_t offset);
  WritableSlice(const WritableSlice &lhs) : ReadableSlice(lhs) { mutable_data_ = lhs.mutable_data_; }
  WritableSlice &operator=(const WritableSlice &lhs) {
    if (this != &lhs) {
      mutable_data_ = lhs.mutable_data_;
      ReadableSlice::operator=(lhs);
    }
    return *this;
  }
  // Move operations null out the source's mutable pointer; the base move
  // empties the read-only view.
  WritableSlice(WritableSlice &&lhs) noexcept : ReadableSlice(std::move(lhs)) {
    if (this != &lhs) {
      mutable_data_ = lhs.mutable_data_;
      lhs.mutable_data_ = nullptr;
    }
  }
  WritableSlice &operator=(WritableSlice &&lhs) noexcept {
    if (this != &lhs) {
      mutable_data_ = lhs.mutable_data_;
      lhs.mutable_data_ = nullptr;
      ReadableSlice::operator=(std::move(lhs));
    }
    return *this;
  }
  /// \brief Copy the content from one slice onto another.
  static Status Copy(WritableSlice *dest, const ReadableSlice &src);

 private:
  void *mutable_data_;  // writable alias of the base class pointer
  void *GetMutablePointer() { return mutable_data_; }
};
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_UTIL_SLICE_H_ | |||
| @@ -0,0 +1,164 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/util/storage_container.h" | |||
| #include <fcntl.h> | |||
| #include <sys/stat.h> | |||
| #include <unistd.h> | |||
| #include <vector> | |||
| #include "common/utils.h" | |||
| #include "dataset/util/de_error.h" | |||
| #include "dataset/util/path.h" | |||
| #include "dataset/util/status.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
// Initialize a brand-new container: set up the buddy-space offset allocator
// and create (or truncate) the backing file on disk.
Status StorageContainer::Create() {
  RETURN_IF_NOT_OK(BuddySpace::CreateBuddySpace(&bs_));
  RETURN_IF_NOT_OK(cont_.CreateFile(&fd_));
  is_open_ = true;
  MS_LOG(INFO) << "Container " << cont_ << " created";
  return Status::OK();
}
// Open the backing file if it is not already open. Unlike Close(), there is
// no unlocked fast-path check here; the state is always examined under the
// lock.
Status StorageContainer::Open() noexcept {
  std::lock_guard<std::mutex> lck(mutex_);
  // Check the flag under the lock; another thread may have opened the file
  // while we were waiting.
  if (!is_open_) {
    RETURN_IF_NOT_OK(cont_.OpenFile(&fd_));
    is_open_ = true;
  }
  return Status::OK();
}
// Close the backing file if open.
// NOTE(review): the unlocked read of is_open_ is a double-checked pattern on
// a plain bool — confirm no concurrent Open/Close caller relies on stronger
// ordering than this provides.
Status StorageContainer::Close() noexcept {
  if (is_open_) {
    std::lock_guard<std::mutex> lck(mutex_);
    // Check again
    if (is_open_) {
      RETURN_IF_NOT_OK(cont_.CloseFile(fd_));
      is_open_ = false;
      fd_ = -1;  // invalidate the descriptor so stale use is detectable
    }
  }
  return Status::OK();
}
// Read dest->GetSize() bytes from the container at file offset <offset> into
// <dest>. The container must be open.
Status StorageContainer::Read(WritableSlice *dest, off64_t offset) const noexcept {
  DS_ASSERT(is_open_);
  RETURN_UNEXPECTED_IF_NULL(dest);
  auto sz = dest->GetSize();
#if defined(_WIN32) || defined(_WIN64)
  // Doesn't seem there is any pread64 on mingw.
  // So we will do a seek and then a read under
  // a protection of mutex.
  std::lock_guard<std::mutex> lck(mutex_);
  auto seek_err = lseek(fd_, offset, SEEK_SET);
  if (seek_err < 0) {
    RETURN_STATUS_UNEXPECTED(strerror(errno));
  }
  auto r_sz = read(fd_, dest->GetMutablePointer(), sz);
#else
  auto r_sz = pread64(fd_, dest->GetMutablePointer(), sz, offset);
#endif
  // Any short read is treated as a failure (EOF when nothing was read).
  // NOTE(review): r_sz is signed vs unsigned sz — a -1 return still fails
  // the comparison, but errno after a partial (short) read may be stale.
  if (r_sz != sz) {
    errno_t err = (r_sz == 0) ? EOF : errno;
    RETURN_STATUS_UNEXPECTED(strerror(err));
  }
  return Status::OK();
}
// Write the content of <dest> into the container at file offset <offset>.
// The container must be open.
Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const noexcept {
  DS_ASSERT(is_open_);
  auto sz = dest.GetSize();
#if defined(_WIN32) || defined(_WIN64)
  // Doesn't seem there is any pwrite64 on mingw.
  // So we will do a seek and then a write under
  // a protection of mutex.
  std::lock_guard<std::mutex> lck(mutex_);
  auto seek_err = lseek(fd_, offset, SEEK_SET);
  if (seek_err < 0) {
    RETURN_STATUS_UNEXPECTED(strerror(errno));
  }
  auto r_sz = write(fd_, dest.GetPointer(), sz);
#else
  auto r_sz = pwrite64(fd_, dest.GetPointer(), sz, offset);
#endif
  // Any short write is treated as a failure (EOF when nothing was written).
  if (r_sz != sz) {
    errno_t err = (r_sz == 0) ? EOF : errno;
    RETURN_STATUS_UNEXPECTED(strerror(err));
  }
  return Status::OK();
}
// Gather-write: allocate file space for the combined size of all slices from
// the buddy allocator, then write the slices back-to-back. The starting file
// offset is returned through <offset>.
Status StorageContainer::Insert(const std::vector<ReadableSlice> &buf, off64_t *offset) noexcept {
  size_t sz = 0;
  for (auto &v : buf) {
    sz += v.GetSize();
  }
  if (sz == 0) {
    RETURN_STATUS_UNEXPECTED("Unexpected 0 length");
  }
  if (sz > bs_->GetMaxSize()) {
    RETURN_STATUS_UNEXPECTED("Request size too big");
  }
  BSpaceDescriptor bspd{0};
  addr_t addr = 0;
  RETURN_IF_NOT_OK(bs_->Alloc(sz, &bspd, &addr));
  *offset = static_cast<off64_t>(addr);
  // We will do piecewise copy of the data to disk.
  // NOTE(review): if a Write below fails, the space obtained from bs_ is not
  // released — verify whether the buddy space should be freed on error.
  for (auto &v : buf) {
    RETURN_IF_NOT_OK(Write(v, addr));
    addr += v.GetSize();
  }
  return Status::OK();
}
// Discard all data in the container by truncating the backing file.
Status StorageContainer::Truncate() const noexcept {
  if (is_open_) {
    RETURN_IF_NOT_OK(cont_.TruncateFile(fd_));
    MS_LOG(INFO) << "Container " << cont_ << " truncated";
  }
  return Status::OK();
}
// Best-effort cleanup: truncate and close, ignoring errors (we are in a
// destructor and must not throw).
StorageContainer::~StorageContainer() noexcept {
  (void)Truncate();
  (void)Close();
}
// Print the backing file path followed by the buddy-space state.
std::ostream &operator<<(std::ostream &os, const StorageContainer &s) {
  os << "File path : " << s.cont_ << "\n" << *(s.bs_.get());
  return os;
}
| Status StorageContainer::CreateStorageContainer(std::shared_ptr<StorageContainer> *out_sc, const std::string &path) { | |||
| Status rc; | |||
| auto sc = new (std::nothrow) StorageContainer(path); | |||
| if (sc == nullptr) { | |||
| return Status(StatusCode::kOutOfMemory); | |||
| } | |||
| rc = sc->Create(); | |||
| if (rc.IsOk()) { | |||
| (*out_sc).reset(sc); | |||
| } else { | |||
| delete sc; | |||
| } | |||
| return rc; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,79 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_UTIL_STORAGE_CONTAINER_H_ | |||
| #define DATASET_UTIL_STORAGE_CONTAINER_H_ | |||
| #include <limits.h> | |||
| #include <unistd.h> | |||
| #include <memory> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "dataset/util/system_pool.h" | |||
| #include "dataset/util/buddy.h" | |||
| #include "dataset/util/path.h" | |||
| #include "dataset/util/slice.h" | |||
| #include "dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class StorageManager; | |||
// A StorageContainer is a disk file plus a buddy-space allocator that hands
// out offsets within that file. Only StorageManager can construct one (via
// the CreateStorageContainer factory).
class StorageContainer {
 public:
  friend class StorageManager;
  ~StorageContainer() noexcept;
  // Non-copyable: owns a file descriptor and an allocator.
  StorageContainer(const StorageContainer &) = delete;
  StorageContainer &operator=(const StorageContainer &) = delete;
  friend std::ostream &operator<<(std::ostream &os, const StorageContainer &s);
  // Open/close the backing file.
  Status Open() noexcept;
  Status Close() noexcept;
  // Write the gathered slices back-to-back; returns the starting file offset
  // through <offset>.
  Status Insert(const std::vector<ReadableSlice> &buf, off64_t *offset) noexcept;
  // Write/read a single slice at the given file offset.
  Status Write(const ReadableSlice &dest, off64_t offset) const noexcept;
  Status Read(WritableSlice *dest, off64_t offset) const noexcept;
  // Truncate the backing file to zero length.
  Status Truncate() const noexcept;
  bool IsOpen() const { return is_open_; }
  // Factory: create a container whose backing file lives at <path>.
  static Status CreateStorageContainer(std::shared_ptr<StorageContainer> *out_sc, const std::string &path);

 private:
  mutable std::mutex mutex_;        // guards open/close and seek+I/O on Windows
  Path cont_;                       // path of the backing file
  int fd_;                          // file descriptor; -1 when closed
  bool is_open_;
  std::unique_ptr<BuddySpace> bs_;  // allocator for offsets within the file
  // Use the default value of BuddySpace
  // which can map up to 4G of space.
  explicit StorageContainer(const std::string &path) : cont_(path), fd_(-1), is_open_(false), bs_(nullptr) {}
  Status Create();
};
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_UTIL_STORAGE_CONTAINER_H_ | |||
| @@ -0,0 +1,167 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/util/storage_manager.h" | |||
| #include <iomanip> | |||
| #include <sstream> | |||
| #include <stdexcept> | |||
| #include <utility> | |||
| #include "common/utils.h" | |||
| #include "dataset/util/path.h" | |||
| #include "dataset/util/services.h" | |||
| #include "dataset/util//de_error.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| std::string StorageManager::GetBaseName(const std::string &prefix, int32_t file_id) { | |||
| std::ostringstream oss; | |||
| oss << prefix << std::setfill('0') << std::setw(5) << file_id; | |||
| return oss.str(); | |||
| } | |||
| std::string StorageManager::ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix) { | |||
| std::string base_name = GetBaseName(prefix, file_id); | |||
| return (base_name + "." + suffix); | |||
| } | |||
| Status StorageManager::AddOneContainer() { | |||
| const std::string kPrefix = "IMG"; | |||
| const std::string kSuffix = "LB"; | |||
| Path container_name = root_ / ConstructFileName(kPrefix, file_id_, kSuffix); | |||
| std::shared_ptr<StorageContainer> sc; | |||
| RETURN_IF_NOT_OK(StorageContainer::CreateStorageContainer(&sc, container_name.toString())); | |||
| containers_.push_back(sc); | |||
| file_id_++; | |||
| return Status::OK(); | |||
| } | |||
| Status StorageManager::DoServiceStart() { | |||
| containers_.reserve(1000); | |||
| if (root_.IsDirectory()) { | |||
| RETURN_IF_NOT_OK(AddOneContainer()); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Not a directory"); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
// Write a buffer (given as a gather list of slices) into storage and return
// the auto-generated key through *key. Data is appended to the last container;
// when it reports no space, a new container is created under an exclusive
// lock and the insert is retried.
Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &buf) {
  RETURN_UNEXPECTED_IF_NULL(key);
  // Total payload size across all slices.
  size_t sz = 0;
  for (auto &v : buf) {
    sz += v.GetSize();
  }
  if (sz == 0) {
    RETURN_STATUS_UNEXPECTED("Unexpected 0 length");
  }
  std::shared_ptr<StorageContainer> cont;
  key_type out_key;
  value_type out_value;
  bool create_new_container = false;
  do {
    SharedLock lock_s(&rw_lock_);
    size_t num_containers = containers_.size();
    if (create_new_container) {
      // Upgrade to exclusive lock.
      lock_s.Upgrade();
      create_new_container = false;
      // Check again if someone has already added a
      // new container after we got the x lock.
      if (containers_.size() == num_containers) {
        RETURN_IF_NOT_OK(AddOneContainer());
      }
      // Refresh how many containers there are.
      num_containers = containers_.size();
      // Downgrade back to shared lock.
      lock_s.Downgrade();
    }
    if (num_containers == 0) {
      RETURN_STATUS_UNEXPECTED("num_containers is zero");
    }
    // Go to the last container to insert.
    cont = containers_.at(num_containers - 1);
    off64_t offset;
    Status rc = cont->Insert(buf, &offset);
    if (rc.IsNoSpace()) {
      // Container is full: loop around and create a new one.
      create_new_container = true;
    } else if (rc.IsOk()) {
      // Record (container index, (offset, size)) in the index; the index
      // generates the key that is handed back to the caller.
      out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz));
      RETURN_IF_NOT_OK(index_.insert(out_value, &out_key));
      *key = out_key;
      break;
    } else {
      return rc;
    }
  } while (true);
  return Status::OK();
}
| Status StorageManager::Read(StorageManager::key_type key, WritableSlice *dest, size_t *bytesRead) const { | |||
| RETURN_UNEXPECTED_IF_NULL(dest); | |||
| auto r = index_.Search(key); | |||
| if (r.second) { | |||
| auto &it = r.first; | |||
| value_type v = *it; | |||
| int container_inx = v.first; | |||
| off_t offset = v.second.first; | |||
| size_t sz = v.second.second; | |||
| if (dest->GetSize() < sz) { | |||
| std::string errMsg = "Destination buffer too small. Expect at least " + std::to_string(sz) + | |||
| " but length = " + std::to_string(dest->GetSize()); | |||
| RETURN_STATUS_UNEXPECTED(errMsg); | |||
| } | |||
| if (bytesRead != nullptr) { | |||
| *bytesRead = sz; | |||
| } | |||
| auto cont = containers_.at(container_inx); | |||
| RETURN_IF_NOT_OK(cont->Read(dest, offset)); | |||
| } else { | |||
| RETURN_STATUS_UNEXPECTED("Key not found"); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status StorageManager::DoServiceStop() noexcept { | |||
| Status rc; | |||
| Status rc1; | |||
| for (auto const &p : containers_) { | |||
| // The destructor of StorageContainer is not called automatically until the use | |||
| // count drops to 0. But it is not always the case. We will do it ourselves. | |||
| rc = p.get()->Truncate(); | |||
| if (rc.IsError()) { | |||
| rc1 = rc; | |||
| } | |||
| } | |||
| containers_.clear(); | |||
| file_id_ = 0; | |||
| return rc1; | |||
| } | |||
// Construct a manager rooted at the given directory; containers are created
// lazily when the service starts.
StorageManager::StorageManager(const Path &root) : root_(root), file_id_(0), index_() {}
// Best-effort cleanup; the stop status is deliberately discarded because a
// destructor cannot propagate errors.
StorageManager::~StorageManager() { (void)StorageManager::DoServiceStop(); }
| std::ostream &operator<<(std::ostream &os, const StorageManager &s) { | |||
| os << "Dumping all containers ..." | |||
| << "\n"; | |||
| for (auto const &p : s.containers_) { | |||
| os << *(p.get()); | |||
| } | |||
| return os; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,76 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_UTIL_STORAGE_MANAGER_H_ | |||
| #define DATASET_UTIL_STORAGE_MANAGER_H_ | |||
| #include <unistd.h> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "dataset/util/allocator.h" | |||
| #include "dataset/util/auto_index.h" | |||
| #include "dataset/util/lock.h" | |||
| #include "dataset/util/memory_pool.h" | |||
| #include "dataset/util/path.h" | |||
| #include "dataset/util/service.h" | |||
| #include "dataset/util/slice.h" | |||
| #include "dataset/util/storage_container.h" | |||
| using ListOfContainers = std::vector<std::shared_ptr<mindspore::dataset::StorageContainer>>; | |||
| namespace mindspore { | |||
| namespace dataset { | |||
// StorageManager owns a growing set of StorageContainer files under a root
// directory plus an index mapping auto-generated keys to the
// (container, offset, size) location of each written buffer.
class StorageManager : public Service {
 public:
  // Index entry: (container index, (byte offset within container, size)).
  using storage_index = AutoIndexObj<std::pair<int, std::pair<off_t, size_t>>>;
  using key_type = storage_index::key_type;
  using value_type = storage_index::value_type;
  explicit StorageManager(const Path &);
  ~StorageManager() override;
  // Not copyable: owns container handles and a lock.
  StorageManager(const StorageManager &) = delete;
  StorageManager &operator=(const StorageManager &) = delete;
  // Append buf to storage; the generated key is returned via out_key.
  Status Write(key_type *out_key, const std::vector<ReadableSlice> &buf);
  // Copy the bytes stored under key into dest; bytesRead (optional) receives
  // the number of bytes copied.
  Status Read(key_type key, WritableSlice *dest, size_t *bytesRead) const;
  Status DoServiceStart() override;
  Status DoServiceStop() noexcept override;
  friend std::ostream &operator<<(std::ostream &os, const StorageManager &s);
 private:
  Path root_;                    // directory that holds all container files
  ListOfContainers containers_;  // all containers created so far
  int file_id_;                  // id used to name the next container file
  RWLock rw_lock_;               // guards containers_ during concurrent Write
  storage_index index_;          // key -> (container, (offset, size))
  std::string GetBaseName(const std::string &prefix, int32_t file_id);
  std::string ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix);
  Status AddOneContainer();
};
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_UTIL_STORAGE_MANAGER_H_ | |||
| @@ -19,8 +19,10 @@ | |||
| #include <cstddef> | |||
| #include <cstdlib> | |||
| #include <limits> | |||
| #include <memory> | |||
| #include <new> | |||
| #include "./securec.h" | |||
| #include "dataset/util/allocator.h" | |||
| #include "dataset/util/memory_pool.h" | |||
| namespace mindspore { | |||
| @@ -61,6 +63,11 @@ class SystemPool : public MemoryPool { | |||
| uint64_t get_max_size() const override { return std::numeric_limits<uint64_t>::max(); } | |||
| int PercentFree() const override { return 100; } | |||
  // Convenience factory: returns an Allocator<T> backed by a freshly created
  // SystemPool, i.e. plain heap allocation with no size limit.
  template <typename T>
  static Allocator<T> GetAllocator() {
    return Allocator<T>(std::make_shared<SystemPool>());
  }
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -30,6 +30,7 @@ | |||
| #include "kernel/common_utils.h" | |||
| #include "kernel/oplib/oplib.h" | |||
| #include "ir/value.h" | |||
| #include "pre_activate/common/helper.h" | |||
| using mindspore::kernel::Address; | |||
| using mindspore::kernel::AddressPtr; | |||
| @@ -632,7 +633,7 @@ void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) { | |||
| } | |||
| } | |||
| void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, | |||
| void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const mindspore::AnfNodePtr &kernel, | |||
| AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces, | |||
| AddressPtrList *kernel_outputs) { | |||
| MS_EXCEPTION_IF_NULL(kernel); | |||
| @@ -644,9 +645,15 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod | |||
| if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) { | |||
| return GenAddrCleanLaunchArgs(cnode, kernel_inputs); | |||
| } | |||
| auto is_all_nop_node = opt::IsAllNopNode(&graph); | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { | |||
| auto real_input = AnfAlgo::GetRealInputIndex(kernel, i); | |||
| auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input); | |||
| DeviceAddressPtr device_address; | |||
| if (is_all_nop_node) { | |||
| device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, false); | |||
| } else { | |||
| device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, true); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| kernel::AddressPtr input = std::make_shared<kernel::Address>(); | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| @@ -656,8 +663,16 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod | |||
| kernel_inputs->emplace_back(input); | |||
| } | |||
| for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) { | |||
| auto device_address = AnfAlgo::GetOutputAddr(kernel, i); | |||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||
| for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) { | |||
| DeviceAddressPtr device_address; | |||
| if (is_all_nop_node) { | |||
| device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | |||
| } else { | |||
| device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, true); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| kernel::AddressPtr output = std::make_shared<kernel::Address>(); | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| output->addr = device_address->ptr_; | |||
| @@ -666,7 +681,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod | |||
| kernel_outputs->emplace_back(output); | |||
| } | |||
| for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) { | |||
| for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { | |||
| auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i); | |||
| kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); | |||
| MS_EXCEPTION_IF_NULL(workspace); | |||
| @@ -721,7 +736,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { | |||
| AddressPtrList kernel_inputs; | |||
| AddressPtrList kernel_workspaces; | |||
| AddressPtrList kernel_outputs; | |||
| GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); | |||
| GenLaunchArgs(graph, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); | |||
| auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "Launch kernel failed."; | |||
| @@ -96,8 +96,8 @@ class KernelRuntime { | |||
| private: | |||
| void AssignStaticMemoryOutput(const session::KernelGraph *graph); | |||
| void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, | |||
| AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); | |||
| void GenLaunchArgs(const session::KernelGraph &graph, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs, | |||
| AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); | |||
| bool LaunchKernelMod(const session::KernelGraph &graph); | |||
| void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs); | |||
| size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index); | |||
| @@ -17,13 +17,23 @@ | |||
| #ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ | |||
| #define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| #include "optimizer/opt.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class Optimizer; | |||
| using OptimizerPtr = std::shared_ptr<Optimizer>; | |||
| using OptimizerWeakPtr = std::weak_ptr<Optimizer>; | |||
| using PredicateFuncType = std::function<bool(const AnfNodePtr &)>; | |||
| } // namespace opt | |||
| class OptimizerCaller { | |||
| public: | |||
| virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; } | |||
| }; | |||
| using OptimizerCallerPtr = std::shared_ptr<OptimizerCaller>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ | |||
| @@ -23,6 +23,7 @@ | |||
| #include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h" | |||
| #include "kernel/akg/akg_kernel_metadata.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "utils/context/ms_context.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| @@ -97,6 +98,12 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel | |||
| std::string op_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->enable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) { | |||
| kernel_type = KernelType::AKG_KERNEL; | |||
| } | |||
| switch (kernel_type) { | |||
| case KernelType::AKG_KERNEL: | |||
| AkgMetadataInfo(kernel_node, kernel_info_list); | |||
| @@ -89,15 +89,28 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { | |||
| return changed; | |||
| } | |||
| // The op like print, summary, or the op do not has true output, and always as a depend node input. | |||
| static bool HasSideEffect(const AnfNodePtr &node) { | |||
| auto prim = GetCNodePrimitive(node); | |||
| if (prim == nullptr) { | |||
| return false; | |||
| } | |||
| auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT); | |||
| if (side_effect_v != nullptr && side_effect_v->isa<BoolImm>()) { | |||
| return GetValue<bool>(side_effect_v); | |||
| } | |||
| return false; | |||
| } | |||
| // If true do not merge the node. | |||
| bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const { | |||
| bool has_random_effect = false; | |||
| auto prim_main = GetCNodePrimitive(main); | |||
| auto prim_node = GetCNodePrimitive(node); | |||
| if (prim_main == prim_node) { | |||
| return false; | |||
| } | |||
| // if has random effect, when generate by different op (not same object), do not merge. | |||
| if (prim_main != nullptr) { | |||
| if (prim_main == prim_node) { | |||
| return false; | |||
| } | |||
| auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT); | |||
| if (effect_val != nullptr && effect_val->isa<BoolImm>()) { | |||
| has_random_effect = GetValue<bool>(effect_val); | |||
| @@ -106,45 +119,58 @@ bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) cons | |||
| return has_random_effect; | |||
| } | |||
| bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { | |||
| bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const { | |||
| MS_EXCEPTION_IF_NULL(main); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| bool replace = false; | |||
| if (main->isa<ValueNode>() && node->isa<ValueNode>()) { | |||
| auto main_value = GetValueNode(main); | |||
| auto node_value = GetValueNode(node); | |||
| replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); | |||
| return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); | |||
| } else if (main->isa<CNode>() && node->isa<CNode>()) { | |||
| auto c_main = main->cast<CNodePtr>(); | |||
| auto c_node = node->cast<CNodePtr>(); | |||
| // When appsame is true, check if has side effect, do not merge. | |||
| if (check_side_effect && HasSideEffect(main)) { | |||
| return false; | |||
| } | |||
| const auto &inp1 = c_main->inputs(); | |||
| const auto &inp2 = c_node->inputs(); | |||
| if (inp1.size() == inp2.size()) { | |||
| bool appsame = true; | |||
| for (size_t j = 0; j < inp1.size(); j++) { | |||
| MS_EXCEPTION_IF_NULL(inp1[j]); | |||
| MS_EXCEPTION_IF_NULL(inp2[j]); | |||
| if (!(*inp1[j] == *inp2[j])) { | |||
| // Handle the case of two different Tensor, but with the same value | |||
| if (IsValueNode<tensor::Tensor>(inp1[j]) && IsValueNode<tensor::Tensor>(inp2[j])) { | |||
| auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1[j]); | |||
| auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2[j]); | |||
| if (tensor1->ValueEqual(*tensor2)) { | |||
| continue; | |||
| } | |||
| if (inp1.size() != inp2.size()) { | |||
| return false; | |||
| } | |||
| for (size_t j = 0; j < inp1.size(); j++) { | |||
| auto inp1_j = inp1[j]; | |||
| auto inp2_j = inp2[j]; | |||
| MS_EXCEPTION_IF_NULL(inp1_j); | |||
| MS_EXCEPTION_IF_NULL(inp2_j); | |||
| if (!(*inp1_j == *inp2_j)) { | |||
| // Handle the case of two different Tensor, but with the same value | |||
| if (IsValueNode<tensor::Tensor>(inp1_j) && IsValueNode<tensor::Tensor>(inp2_j)) { | |||
| auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1_j); | |||
| auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2_j); | |||
| if (tensor1->ValueEqual(*tensor2)) { | |||
| continue; | |||
| } | |||
| } else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) { | |||
| // When the same side effect node as another two nodes' inputs, we still merge the node. | |||
| // Because the node only can be the inputs of `depend`, when the `depend` is duplicated merge the depend the | |||
| // node. | |||
| if (CheckReplace(inp1_j, inp2_j, false)) { | |||
| continue; | |||
| } | |||
| appsame = false; | |||
| break; | |||
| } | |||
| return false; | |||
| } | |||
| if (CheckRandomEffect(c_main, c_node)) { | |||
| appsame = false; | |||
| } | |||
| replace = appsame; | |||
| } | |||
| // When appsame is true, check if has random effect do not merge | |||
| if (CheckRandomEffect(c_main, c_node)) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| return replace; | |||
| // a parameter node. | |||
| return false; | |||
| } | |||
| bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group, | |||
| @@ -41,7 +41,7 @@ class CSE { | |||
| return chg && report_changes_; | |||
| } | |||
| virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const; | |||
| virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const; | |||
| virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const; | |||
| @@ -14,140 +14,154 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "optimizer/irpass.h" | |||
| #include <string> | |||
| #include "optimizer/irpass/symbol_resolver.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/irpass/arithmetic_simplify.h" | |||
| #include "optimizer/irpass/special_op_eliminate.h" | |||
| #include "optimizer/irpass/item_tuple_eliminate.h" | |||
| #include "optimizer/irpass/env_item_eliminate.h" | |||
| #include "optimizer/irpass/tile_eliminate.h" | |||
| #include "optimizer/irpass/cast_eliminate.h" | |||
| #include "optimizer/irpass/reshape_eliminate.h" | |||
| #include "optimizer/irpass/transpose_eliminate.h" | |||
| #include "optimizer/irpass/reduce_eliminate.h" | |||
| #include "optimizer/irpass/partial_eliminate.h" | |||
| #include "optimizer/irpass/ref_eliminate.h" | |||
| #include "optimizer/irpass/merge_addn.h" | |||
| #include "optimizer/irpass/branch_culling.h" | |||
| #include "optimizer/irpass/cast_eliminate.h" | |||
| #include "optimizer/irpass/convert.h" | |||
| #include "optimizer/irpass/env_item_eliminate.h" | |||
| #include "optimizer/irpass/grad_var_prepare.h" | |||
| #include "optimizer/irpass/gradient_eliminate.h" | |||
| #include "optimizer/irpass/minmax_grad.h" | |||
| #include "optimizer/irpass/inline.h" | |||
| #include "optimizer/irpass/convert.h" | |||
| #include "optimizer/irpass/specialize_transform.h" | |||
| #include "optimizer/irpass/incorporate_getitem.h" | |||
| #include "optimizer/irpass/incorporate_call.h" | |||
| #include "optimizer/irpass/grad_var_prepare.h" | |||
| #include "optimizer/irpass/param_replace.h" | |||
| #include "optimizer/irpass/incorporate_getitem.h" | |||
| #include "optimizer/irpass/item_tuple_eliminate.h" | |||
| #include "optimizer/irpass/mark_interface_fusion.h" | |||
| #include "optimizer/irpass/merge_addn.h" | |||
| #include "optimizer/irpass/minmax_grad.h" | |||
| #include "optimizer/irpass/param_replace.h" | |||
| #include "optimizer/irpass/partial_eliminate.h" | |||
| #include "optimizer/irpass/reduce_eliminate.h" | |||
| #include "optimizer/irpass/ref_eliminate.h" | |||
| #include "optimizer/irpass/reshape_eliminate.h" | |||
| #include "optimizer/irpass/special_op_eliminate.h" | |||
| #include "optimizer/irpass/specialize_transform.h" | |||
| #include "optimizer/irpass/symbol_resolver.h" | |||
| #include "optimizer/irpass/tile_eliminate.h" | |||
| #include "optimizer/irpass/transpose_eliminate.h" | |||
| #include "optimizer/opt.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", | |||
| arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify", | |||
| {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, | |||
| prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); | |||
| arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul}); | |||
| arithmetic_simplify2_ = | |||
| MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul}); | |||
| special_op_eliminate_ = | |||
| MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", | |||
| MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate", | |||
| {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, | |||
| prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); | |||
| zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike); | |||
| adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN); | |||
| zero_like_fill_zero_ = | |||
| MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike); | |||
| adjust_all_reduce_mul_add_ = | |||
| MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN); | |||
| // ops eliminate | |||
| item_tuple_eliminate_ = | |||
| MakeSubstitution(ItemTupleEliminater(), "item_tuple_eliminate", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); | |||
| tile_eliminate_ = MakeSubstitution(TileMultiplyByOne(), "tile_eliminate", prim::kPrimTile); | |||
| cast_eliminate_ = MakeSubstitution(CastEliminater(), "cast_eliminate", prim::kPrimCast); | |||
| reshape_eliminate_ = MakeSubstitution(ReshapeEliminater(), "reshape_eliminate", prim::kPrimReshape); | |||
| transpose_eliminate_ = MakeSubstitution(TransposeSameIOEliminater(), "transpose_eliminate", prim::kPrimTranspose); | |||
| item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate", | |||
| {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); | |||
| tile_eliminate_ = MakeSubstitution(std::make_shared<TileMultiplyByOne>(), "tile_eliminate", prim::kPrimTile); | |||
| cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast); | |||
| reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape); | |||
| transpose_eliminate_ = | |||
| MakeSubstitution(std::make_shared<TransposeSameIOEliminater>(), "transpose_eliminate", prim::kPrimTranspose); | |||
| reduce_eliminate_ = MakeSubstitution( | |||
| ReduceOneEliminater(), "reduce_eliminate", | |||
| std::make_shared<ReduceOneEliminater>(), "reduce_eliminate", | |||
| {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); | |||
| partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup); | |||
| same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape); | |||
| check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop); | |||
| reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>); | |||
| depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend); | |||
| partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup); | |||
| same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape); | |||
| check_bprop_eliminate_ = | |||
| MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop); | |||
| reset_defer_inline_ = | |||
| MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>); | |||
| depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend); | |||
| // Env Item Eliminate | |||
| env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem); | |||
| new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem); | |||
| env_get_item_eliminate_ = | |||
| MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem); | |||
| new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem); | |||
| incorporate_env_getitem_ = | |||
| MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem); | |||
| incorporate_env_getitem_switch_ = | |||
| MakeSubstitution(IncorporateEnvGetitemSwitch(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); | |||
| MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem); | |||
| incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(), | |||
| "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); | |||
| // Ref eliminate | |||
| make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef); | |||
| get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate", | |||
| make_ref_eliminate_ = | |||
| MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef); | |||
| get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate", | |||
| {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); | |||
| get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", | |||
| get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate", | |||
| {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); | |||
| replace_refkey_by_param_ = | |||
| MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM); | |||
| replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); | |||
| replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param", | |||
| IsValueNode<RefKey>, opt::FORCE_RENORM); | |||
| replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam); | |||
| // Gradient transforms | |||
| expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); | |||
| minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem); | |||
| expand_jprim_ = MakeSubstitution(std::make_shared<ExpandJPrim>(), "expand_jprim", prim::kPrimJ); | |||
| minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem); | |||
| // branch culling | |||
| switch_simplify_ = MakeSubstitution(SwitchSimplify(), "switch_simplify", prim::kPrimSwitch); | |||
| float_tuple_getitem_switch_ = | |||
| MakeSubstitution(FloatTupleGetItemSwitch(), "float_tuple_getitem_switch", prim::kPrimTupleGetItem); | |||
| switch_simplify_ = MakeSubstitution(std::make_shared<SwitchSimplify>(), "switch_simplify", prim::kPrimSwitch); | |||
| float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared<FloatTupleGetItemSwitch>(), | |||
| "float_tuple_getitem_switch", prim::kPrimTupleGetItem); | |||
| float_env_getitem_switch_ = | |||
| MakeSubstitution(FloatEnvGetItemSwitch(), "float_env_getitem_switch", prim::kPrimEnvGetItem); | |||
| convert_switch_replacement_ = MakeSubstitution(ConvertSwitchReplacement(), "convert_switch_replacement", IsCNodeDup); | |||
| MakeSubstitution(std::make_shared<FloatEnvGetItemSwitch>(), "float_env_getitem_switch", prim::kPrimEnvGetItem); | |||
| convert_switch_replacement_ = | |||
| MakeSubstitution(std::make_shared<ConvertSwitchReplacement>(), "convert_switch_replacement", IsCNodeDup); | |||
| // Addn | |||
| merge_addn_ = MakeSubstitution(MergeAddN(), "merge_addn", prim::kPrimAddN); | |||
| addn_zero_filter_ = MakeSubstitution(AddNZeroFilter(), "addn_zero_filter", prim::kPrimAddN); | |||
| merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN); | |||
| addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN); | |||
| // inline | |||
| inline_ = MakeSubstitution(Inliner(), "inline", IsCNodeGraph); | |||
| replace_applicator_ = MakeSubstitution(ReplaceApplicator(), "replace_applicator", IsValueNode<FuncGraph>); | |||
| specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph); | |||
| inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph); | |||
| replace_applicator_ = | |||
| MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>); | |||
| specialize_transform_ = | |||
| MakeSubstitution(std::make_shared<SpecializeOnGraphArguments>(), "specialize_transform", IsCNodeGraph); | |||
| // Incorporation | |||
| incorporate_getitem_set_ = | |||
| MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem); | |||
| incorporate_getitem_from_param_ = | |||
| MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel); | |||
| incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup); | |||
| incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup); | |||
| MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem); | |||
| incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared<IncorporateGetitemFromParam>(), | |||
| "incorporate_getitem_from_param", IsCNodeGraphKernel); | |||
| incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup); | |||
| incorporate_call_switch_ = | |||
| MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup); | |||
| // Virtual Dataset | |||
| virtual_dataset_eliminate_ = | |||
| MakeSubstitution(VirtualDatasetEliminater(), "virtual_dataset_eliminate", prim::kPrimVirtualDataset); | |||
| virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(), | |||
| "virtual_dataset_eliminate", prim::kPrimVirtualDataset); | |||
| // Convert | |||
| print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint); | |||
| print_tuple_wrapper_ = | |||
| MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint); | |||
| // Unused parameter eliminate | |||
| unused_parameter_eliminate_ = | |||
| MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel); | |||
| unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel); | |||
| MakeSubstitution(std::make_shared<UnusedParasEliminater>(), "unused_parameter_eliminate", IsCNodeGraphKernel); | |||
| unused_output_eliminate_ = | |||
| MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel); | |||
| // AddN eliminate | |||
| addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel); | |||
| addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel); | |||
| // Mark interface fusion | |||
| mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect); | |||
| mark_interface_fusion_ = | |||
| MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect); | |||
| } | |||
| ResolveIRPassLib::ResolveIRPassLib() { | |||
| resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve); | |||
| resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr); | |||
| resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve); | |||
| resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr); | |||
| } | |||
| InferenceOptPrepareLib::InferenceOptPrepareLib() { | |||
| grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode); | |||
| grad_var_prepare_ = MakeSubstitution(std::make_shared<GradVarPrepare>(), "grad_var_prepare", IsCNode); | |||
| } | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| @@ -17,15 +17,16 @@ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "optimizer/optimizer.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/irpass/prim_eliminate.h" | |||
| #include "ir/optimizer_caller.h" | |||
| #include "ir/visitor.h" | |||
| #include "operator/ops.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/irpass/prim_eliminate.h" | |||
| #include "optimizer/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor { | |||
| FuncGraphPtr all_reduce_fg_{nullptr}; | |||
| }; | |||
| class ArithmeticSimplify { | |||
| class ArithmeticSimplify : public OptimizerCaller { | |||
| public: | |||
| ArithmeticSimplify() | |||
| : multiply_by_zero_or_one_(), | |||
| tensor_multiply_by_one_(), | |||
| add_by_zero_(), | |||
| tensor_add_by_zero_(), | |||
| identity_(prim::kPrimIdentity), | |||
| opt_update_zero_tensor_(), | |||
| constant_duplicate_mul_(), | |||
| power_one_() { | |||
| : multiply_by_zero_or_one_(std::make_shared<MultiplyByZeroOrOne>()), | |||
| tensor_multiply_by_one_(std::make_shared<TensorMultiplyByOne>()), | |||
| add_by_zero_(std::make_shared<AddByZero>()), | |||
| tensor_add_by_zero_(std::make_shared<TensorAddByZero>()), | |||
| identity_(std::make_shared<PrimEliminater>(prim::kPrimIdentity)), | |||
| opt_update_zero_tensor_(std::make_shared<OptUpdateZeroTensor>()), | |||
| constant_duplicate_mul_(std::make_shared<ConstantDuplicateMul>()), | |||
| power_one_(std::make_shared<PowerOneEliminate>()) { | |||
| eliminaters_.emplace_back(multiply_by_zero_or_one_); | |||
| eliminaters_.emplace_back(tensor_multiply_by_one_); | |||
| eliminaters_.emplace_back(add_by_zero_); | |||
| @@ -761,10 +762,10 @@ class ArithmeticSimplify { | |||
| } | |||
| ~ArithmeticSimplify() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| AnfNodePtr new_node; | |||
| for (auto &eliminater : eliminaters_) { | |||
| new_node = eliminater(optimizer, node); | |||
| new_node = (*eliminater)(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| } | |||
| @@ -773,15 +774,9 @@ class ArithmeticSimplify { | |||
| } | |||
| private: | |||
| MultiplyByZeroOrOne multiply_by_zero_or_one_; | |||
| TensorMultiplyByOne tensor_multiply_by_one_; | |||
| AddByZero add_by_zero_; | |||
| TensorAddByZero tensor_add_by_zero_; | |||
| PrimEliminater identity_; | |||
| OptUpdateZeroTensor opt_update_zero_tensor_; | |||
| ConstantDuplicateMul constant_duplicate_mul_; | |||
| PowerOneEliminate power_one_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| OptimizerCallerPtr multiply_by_zero_or_one_, tensor_multiply_by_one_, add_by_zero_, tensor_add_by_zero_, identity_, | |||
| opt_update_zero_tensor_, constant_duplicate_mul_, power_one_; | |||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||
| }; | |||
| // Arithmetic Simplifications should be done after step_parallel. | |||
| @@ -789,15 +784,17 @@ class ArithmeticSimplify { | |||
| // with shape(weight), but after step_parallel, shape of weight may be changed, so the | |||
| // shape of the constant tensor should also be changed. So this pass is seperated from | |||
| // ArithmeticSimplify and deferred until step_parallel. | |||
| class ArithmeticSimplify2 { | |||
| class ArithmeticSimplify2 : public OptimizerCaller { | |||
| public: | |||
| ArithmeticSimplify2() : tensor_multiply_by_zero_() { eliminaters_.emplace_back(tensor_multiply_by_zero_); } | |||
| ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared<TensorMultiplyByZero>()) { | |||
| eliminaters_.emplace_back(tensor_multiply_by_zero_); | |||
| } | |||
| ~ArithmeticSimplify2() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| AnfNodePtr new_node; | |||
| for (auto &eliminater : eliminaters_) { | |||
| new_node = eliminater(optimizer, node); | |||
| new_node = (*eliminater)(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| } | |||
| @@ -806,8 +803,8 @@ class ArithmeticSimplify2 { | |||
| } | |||
| private: | |||
| TensorMultiplyByZero tensor_multiply_by_zero_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| OptimizerCallerPtr tensor_multiply_by_zero_; | |||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| @@ -17,9 +17,9 @@ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ | |||
| #include "ir/visitor.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "ir/visitor.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor { | |||
| AnfNodePtr x_{nullptr}, t_{nullptr}; | |||
| }; | |||
| class CastEliminater { | |||
| class CastEliminater : public OptimizerCaller { | |||
| public: | |||
| CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {} | |||
| ~CastEliminater() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| auto new_node = cast_same_type_eliminater_(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| @@ -17,18 +17,19 @@ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <algorithm> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "ir/visitor.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "ir/optimizer_caller.h" | |||
| #include "ir/visitor.h" | |||
| #include "operator/ops.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "utils/symbolic.h" | |||
| namespace mindspore { | |||
| @@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor { | |||
| bool is_match_{false}; | |||
| }; | |||
| class EnvGetItemEliminater { | |||
| class EnvGetItemEliminater : public OptimizerCaller { | |||
| public: | |||
| EnvGetItemEliminater() : new_env_get_item_(), add_env_get_item_(), env_get_set_item_() { | |||
| EnvGetItemEliminater() | |||
| : new_env_get_item_(std::make_shared<NewEnvGetItem>()), | |||
| add_env_get_item_(std::make_shared<AddEnvGetItem>()), | |||
| env_get_set_item_(std::make_shared<EnvGetSetItem>()) { | |||
| eliminaters_.emplace_back(new_env_get_item_); | |||
| eliminaters_.emplace_back(add_env_get_item_); | |||
| eliminaters_.emplace_back(env_get_set_item_); | |||
| } | |||
| ~EnvGetItemEliminater() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| AnfNodePtr new_node; | |||
| for (auto &eliminater : eliminaters_) { | |||
| new_node = eliminater(optimizer, node); | |||
| new_node = (*eliminater)(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| } | |||
| @@ -246,10 +250,8 @@ class EnvGetItemEliminater { | |||
| } | |||
| private: | |||
| NewEnvGetItem new_env_get_item_; | |||
| AddEnvGetItem add_env_get_item_; | |||
| EnvGetSetItem env_get_set_item_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_; | |||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||
| }; | |||
| // {prim::kPrimEnvGetItem, {G, Xs}, C, Y} | |||
| @@ -17,18 +17,20 @@ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include <vector> | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "ir/visitor.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "ir/optimizer_caller.h" | |||
| #include "ir/visitor.h" | |||
| #include "operator/ops.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| @@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor { | |||
| internal::GetitemTransform getitem_transform_; | |||
| }; | |||
| class IncorporateGetitemSet { | |||
| class IncorporateGetitemSet : public OptimizerCaller { | |||
| public: | |||
| IncorporateGetitemSet() : incorporate_getitem_(), incorporate_getitem_switch_() { | |||
| IncorporateGetitemSet() | |||
| : incorporate_getitem_(std::make_shared<IncorporateGetitem>()), | |||
| incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()) { | |||
| eliminaters_.emplace_back(incorporate_getitem_); | |||
| eliminaters_.emplace_back(incorporate_getitem_switch_); | |||
| } | |||
| ~IncorporateGetitemSet() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| AnfNodePtr new_node; | |||
| for (auto &eliminater : eliminaters_) { | |||
| new_node = eliminater(optimizer, node); | |||
| new_node = (*eliminater)(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| } | |||
| @@ -403,9 +407,8 @@ class IncorporateGetitemSet { | |||
| } | |||
| private: | |||
| IncorporateGetitem incorporate_getitem_; | |||
| IncorporateGetitemSwitch incorporate_getitem_switch_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_; | |||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| @@ -17,13 +17,15 @@ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "ir/optimizer_caller.h" | |||
| #include "ir/visitor.h" | |||
| #include "operator/ops.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor { | |||
| AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; | |||
| }; | |||
| class ItemTupleEliminater { | |||
| class ItemTupleEliminater : public OptimizerCaller { | |||
| public: | |||
| ItemTupleEliminater() | |||
| : get_item_eliminater_(), | |||
| get_item_const_eliminater_(), | |||
| set_item_eliminater_(), | |||
| get_set_item_eliminater_(), | |||
| get_item_depend_reorder_() { | |||
| : get_item_eliminater_(std::make_shared<GetitemEliminater>()), | |||
| get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()), | |||
| set_item_eliminater_(std::make_shared<SetitemEliminater>()), | |||
| get_set_item_eliminater_(std::make_shared<GetSetitemEliminater>()), | |||
| get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()) { | |||
| eliminaters_.emplace_back(get_item_eliminater_); | |||
| eliminaters_.emplace_back(get_item_const_eliminater_); | |||
| eliminaters_.emplace_back(set_item_eliminater_); | |||
| @@ -277,10 +279,10 @@ class ItemTupleEliminater { | |||
| } | |||
| ~ItemTupleEliminater() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| AnfNodePtr new_node; | |||
| for (auto &eliminater : eliminaters_) { | |||
| new_node = eliminater(optimizer, node); | |||
| new_node = (*eliminater)(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| } | |||
| @@ -289,12 +291,9 @@ class ItemTupleEliminater { | |||
| } | |||
| private: | |||
| GetitemEliminater get_item_eliminater_; | |||
| GetitemConstEliminater get_item_const_eliminater_; | |||
| SetitemEliminater set_item_eliminater_; | |||
| GetSetitemEliminater get_set_item_eliminater_; | |||
| GetitemDependReorder get_item_depend_reorder_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_, | |||
| get_item_depend_reorder_; | |||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| @@ -19,9 +19,9 @@ | |||
| #include <memory> | |||
| #include "optimizer/optimizer.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "ir/pattern_matcher.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -19,11 +19,12 @@ | |||
| #include <vector> | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "ir/visitor.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/optimizer_caller.h" | |||
| #include "ir/visitor.h" | |||
| #include "operator/ops.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "pipeline/static_analysis/dshape.h" | |||
| namespace mindspore { | |||
| @@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor { | |||
| AnfNodePtr x_{nullptr}, shape_{nullptr}; | |||
| }; | |||
| class ReshapeEliminater { | |||
| class ReshapeEliminater : public OptimizerCaller { | |||
| public: | |||
| ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {} | |||
| ~ReshapeEliminater() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| auto new_node = reshape_same_shape_eliminater_(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| @@ -18,31 +18,31 @@ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ | |||
| #include <securec.h> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "optimizer/optimizer.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "ir/optimizer_caller.h" | |||
| #include "optimizer/irpass/prim_eliminate.h" | |||
| #include "ir/pattern_matcher.h" | |||
| #include "ir/visitor.h" | |||
| #include "operator/ops.h" | |||
| #include "ir/pattern_matcher.h" | |||
| #include "optimizer/irpass.h" | |||
| #include "optimizer/irpass/prim_eliminate.h" | |||
| #include "optimizer/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| class SpecialOpEliminater { | |||
| class SpecialOpEliminater : public OptimizerCaller { | |||
| public: | |||
| SpecialOpEliminater() | |||
| : insert_gradient_of_(prim::kPrimInsertGradientOf), | |||
| stop_gradient_(prim::kPrimStopGradient), | |||
| hook_backward_(prim::kPrimHookBackward), | |||
| print_shape_type_(prim::kPrimPrintShapeType), | |||
| get_ref_value_(prim::kPrimGetRefValue), | |||
| mirror_(prim::kPrimMirror), | |||
| virtual_div_(prim::kPrimVirtualDiv) { | |||
| : insert_gradient_of_(std::make_shared<PrimEliminater>(prim::kPrimInsertGradientOf)), | |||
| stop_gradient_(std::make_shared<PrimEliminater>(prim::kPrimStopGradient)), | |||
| hook_backward_(std::make_shared<PrimEliminater>(prim::kPrimHookBackward)), | |||
| print_shape_type_(std::make_shared<PrimEliminater>(prim::kPrimPrintShapeType)), | |||
| get_ref_value_(std::make_shared<PrimEliminater>(prim::kPrimGetRefValue)), | |||
| mirror_(std::make_shared<PrimEliminater>(prim::kPrimMirror)), | |||
| virtual_div_(std::make_shared<PrimEliminater>(prim::kPrimVirtualDiv)) { | |||
| eliminaters_.emplace_back(insert_gradient_of_); | |||
| eliminaters_.emplace_back(stop_gradient_); | |||
| eliminaters_.emplace_back(hook_backward_); | |||
| @@ -53,10 +53,10 @@ class SpecialOpEliminater { | |||
| } | |||
| ~SpecialOpEliminater() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| AnfNodePtr new_node; | |||
| for (auto &eliminater : eliminaters_) { | |||
| new_node = eliminater(optimizer, node); | |||
| new_node = (*eliminater)(optimizer, node); | |||
| if (new_node != nullptr) { | |||
| return new_node; | |||
| } | |||
| @@ -65,9 +65,9 @@ class SpecialOpEliminater { | |||
| } | |||
| private: | |||
| PrimEliminater insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, | |||
| OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, | |||
| virtual_div_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||
| }; | |||
| // {PrimVirtualDataset, X} -> X | |||
| @@ -16,28 +16,27 @@ | |||
| #include "optimizer/opt.h" | |||
| #include <algorithm> | |||
| #include <deque> | |||
| #include <memory> | |||
| #include <unordered_set> | |||
| #include <deque> | |||
| #include <algorithm> | |||
| #include "ir/anf.h" | |||
| #include "ir/manager.h" | |||
| #include "utils/ordered_set.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/ordered_set.h" | |||
| namespace mindspore { | |||
| /* namespace to support opt */ | |||
| namespace opt { | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, | |||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, | |||
| const RenormAction &renorm_action) { | |||
| auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; | |||
| return std::make_shared<Substitution>(transform, name, fn, renorm_action); | |||
| } | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, | |||
| const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) { | |||
| auto fn = [prims](const AnfNodePtr &node) -> bool { | |||
| if (!node->isa<CNode>()) { | |||
| @@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std:: | |||
| return std::make_shared<Substitution>(transform, name, fn, renorm_action); | |||
| } | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, | |||
| const PredicateFuncType &predicate, const RenormAction &renorm_action) { | |||
| return std::make_shared<Substitution>(transform, name, predicate, renorm_action); | |||
| } | |||
| AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const { | |||
| AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||
| #ifdef ENABLE_PROFILE | |||
| double t = GetTime(); | |||
| #endif | |||
| AnfNodePtr result = transform_(optimizer, node); | |||
| AnfNodePtr result = (*transform_)(optimizer, node); | |||
| #ifdef ENABLE_PROFILE | |||
| if (optimizer != nullptr) { | |||
| auto time = GetTime(); | |||
| @@ -17,24 +17,18 @@ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/optimizer_caller.h" | |||
| #include "operator/ops.h" | |||
| namespace mindspore { | |||
| /* namespace to support opt */ | |||
| namespace opt { | |||
| class Optimizer; | |||
| using OptimizerPtr = std::shared_ptr<Optimizer>; | |||
| using OptimizerWeakPtr = std::weak_ptr<Optimizer>; | |||
| using PredicateFuncType = std::function<bool(const AnfNodePtr &)>; | |||
| using TransformFuncType = std::function<AnfNodePtr(const OptimizerPtr &, const AnfNodePtr &)>; | |||
| // Define the interaction mode between an Optimize pass and Renormalize pass | |||
| // FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed | |||
| @@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM }; | |||
| class Substitution { | |||
| public: | |||
| TransformFuncType transform_{nullptr}; | |||
| OptimizerCallerPtr transform_; | |||
| std::string name_; | |||
| PredicateFuncType predicate_{nullptr}; | |||
| // an enum to mark this Substitution relation to renormalize pass | |||
| RenormAction renorm_action_; | |||
| Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, | |||
| Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, | |||
| const RenormAction &renorm_action) | |||
| : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} | |||
| ~Substitution() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node); | |||
| }; | |||
| using SubstitutionPtr = std::shared_ptr<Substitution>; | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, | |||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, | |||
| const RenormAction &action_renorm = CHECK_RENORM); | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, | |||
| const std::vector<PrimitivePtr> &prims, | |||
| const RenormAction &action_renorm = CHECK_RENORM); | |||
| SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, | |||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, | |||
| const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); | |||
| class SubstitutionList { | |||
| @@ -465,7 +465,7 @@ double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, co | |||
| CheckGlobalDeviceManager(); | |||
| MS_EXCEPTION_IF_NULL(g_device_manager); | |||
| RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); | |||
| TensorRedistribution tensor_redistribution; | |||
| TensorRedistribution tensor_redistribution(false, true); | |||
| if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { | |||
| MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; | |||
| } | |||
| @@ -503,7 +503,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo> &inp | |||
| CheckGlobalDeviceManager(); | |||
| MS_EXCEPTION_IF_NULL(g_device_manager); | |||
| RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); | |||
| TensorRedistribution tensor_redistribution; | |||
| TensorRedistribution tensor_redistribution(false, true); | |||
| if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { | |||
| MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; | |||
| } | |||
| @@ -62,6 +62,7 @@ void ParallelContext::Reset() { | |||
| enable_all_reduce_fusion_ = false; | |||
| strategy_ckpt_load_file_ = ""; | |||
| strategy_ckpt_save_file_ = ""; | |||
| enable_parallel_optimizer_ = false; | |||
| } | |||
| void ParallelContext::set_device_num(int32_t device_num) { | |||
| @@ -100,6 +100,11 @@ class ParallelContext { | |||
| void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file); | |||
| std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; } | |||
| void set_enable_parallel_optimizer(bool enable_parallel_optimizer) { | |||
| enable_parallel_optimizer_ = enable_parallel_optimizer; | |||
| } | |||
| bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; } | |||
| void Reset(); | |||
| private: | |||
| @@ -123,6 +128,7 @@ class ParallelContext { | |||
| std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_; | |||
| std::string strategy_ckpt_load_file_; | |||
| std::string strategy_ckpt_save_file_; | |||
| bool enable_parallel_optimizer_; | |||
| }; | |||
| void ParallelParameterContextInit(const FuncGraphPtr &func_graph); | |||
| @@ -205,6 +205,10 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.") | |||
| .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.") | |||
| .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.") | |||
| .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer, | |||
| "Set enable/disable parallel optimizer.") | |||
| .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer, | |||
| "Get enable/disable parallel optimizer.") | |||
| .def("reset", &ParallelContext::Reset, "Reset auto parallel context."); | |||
| (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext") | |||
| @@ -35,7 +35,7 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) { | |||
| } | |||
| } // namespace | |||
| bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const { | |||
| bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const { | |||
| MS_EXCEPTION_IF_NULL(main); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| @@ -31,7 +31,7 @@ class BackendCSE : public CSE { | |||
| public: | |||
| BackendCSE() = default; | |||
| ~BackendCSE() override = default; | |||
| bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const override; | |||
| bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -33,5 +33,6 @@ const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll"; | |||
| const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect"; | |||
| const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order"; | |||
| const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect"; | |||
| const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect"; | |||
| } // namespace mindspore | |||
| @@ -34,7 +34,7 @@ extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[]; | |||
| extern const char GRAPH_FLAG_HAS_EFFECT[]; | |||
| extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[]; | |||
| extern const char GRAPH_FLAG_RANDOM_EFFECT[]; | |||
| extern const char GRAPH_FLAG_SIDE_EFFECT[]; | |||
| } // namespace mindspore | |||
| #endif // PYBIND_API_EXPORT_FLAGS_H_ | |||
| @@ -33,6 +33,21 @@ static constexpr size_t kCNodeSwitchLayerLength = 3; | |||
| namespace mindspore { | |||
| namespace session { | |||
| static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) { | |||
| auto &nodes = parent_graph->execution_order(); | |||
| for (auto &node : nodes) { | |||
| if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) { | |||
| return node; | |||
| } else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) && | |||
| (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) || | |||
| child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) { | |||
| return node; | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString(); | |||
| return nullptr; | |||
| } | |||
| static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| if (memo->find(kg.get()) != memo->end()) { | |||
| @@ -200,7 +215,8 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr | |||
| if (target_graph_iter == graph_id_map.end()) { | |||
| MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found."; | |||
| } | |||
| InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter)); | |||
| InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg), | |||
| NOT_NULL(parameter)); | |||
| } | |||
| } | |||
| } | |||
| @@ -263,7 +279,7 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr | |||
| RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); | |||
| } | |||
| } | |||
| kg->SetExecOrderByDefault(); | |||
| MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString(); | |||
| return NOT_NULL(start_label); | |||
| } | |||
| @@ -433,7 +449,8 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A | |||
| return {partial_cnode, branch_kg}; | |||
| } | |||
| void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, | |||
| void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, | |||
| NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> from, | |||
| NotNull<AnfNodePtr> to) { | |||
| std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); | |||
| std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); | |||
| @@ -443,18 +460,24 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg | |||
| << to_outputs.size() << "]"; | |||
| } | |||
| for (size_t i = 0; i < from_outputs.size(); i++) { | |||
| InsertAssignToGraph(kg, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); | |||
| auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); | |||
| if (assign_node != nullptr) { | |||
| auto jump_node = GetJumpNode(from_graph, to_graph); | |||
| if (jump_node != nullptr) { | |||
| InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, | |||
| NotNull<AnfNodePtr> to) { | |||
| AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, | |||
| NotNull<AnfNodePtr> to) { | |||
| if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && | |||
| AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { | |||
| return; | |||
| return nullptr; | |||
| } | |||
| if (from.get() == to.get()) { | |||
| return; | |||
| return nullptr; | |||
| } | |||
| MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " | |||
| << to->DebugString(); | |||
| @@ -466,6 +489,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul | |||
| assign_node->set_abstract(to->abstract()); | |||
| // append the assign at the end of from graph | |||
| InsertDependToGraph(kg, NOT_NULL(assign_node)); | |||
| return assign_node; | |||
| } | |||
| std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph, | |||
| @@ -52,8 +52,9 @@ class AscendControlParser { | |||
| const CNodePtr &last_label); | |||
| static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node); | |||
| static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | |||
| static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | |||
| static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph, | |||
| NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | |||
| static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | |||
| // root graph order | |||
| static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, | |||
| @@ -521,6 +521,47 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) { | |||
| return output_nodes; | |||
| } | |||
| // Find control_depend real input nodes. | |||
| void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, std::set<AnfNodePtr> *visited) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| MS_EXCEPTION_IF_NULL(result); | |||
| MS_EXCEPTION_IF_NULL(visited); | |||
| if (visited->find(anf_node) != visited->end()) { | |||
| MS_LOG(WARNING) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited"; | |||
| return; | |||
| } | |||
| visited->insert(anf_node); | |||
| if (AnfAlgo::IsRealKernel(anf_node)) { | |||
| result->emplace_back(anf_node); | |||
| return; | |||
| } | |||
| if (!anf_node->isa<CNode>()) { | |||
| return; | |||
| } | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->inputs().empty()) { | |||
| MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString(); | |||
| } | |||
| auto input0 = cnode->input(0); | |||
| if (IsPrimitive(input0, prim::kPrimMakeTuple)) { | |||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||
| GetAllFatherRealNode(cnode->input(i), result, visited); | |||
| } | |||
| } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { | |||
| if (cnode->inputs().size() != kTupleGetItemInputSize) { | |||
| MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; | |||
| } | |||
| GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited); | |||
| } else if (IsPrimitive(input0, prim::kPrimDepend)) { | |||
| if (cnode->inputs().size() != kDependInputSize) { | |||
| MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!"; | |||
| } | |||
| GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited); | |||
| GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited); | |||
| } | |||
| } | |||
| // update the depend relations of control depend | |||
| void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) { | |||
| for (const auto &node : depends) { | |||
| @@ -551,11 +592,24 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de | |||
| if (depend_node->isa<Parameter>() && depend_mode == 1) { | |||
| depend_nodes = GetOutputNodes(depend_node); | |||
| } | |||
| for (auto &first_node : prior_nodes) { | |||
| std::vector<AnfNodePtr> real_prior_nodes; | |||
| std::set<AnfNodePtr> prior_visited; | |||
| for (const auto &tmp : prior_nodes) { | |||
| GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); | |||
| } | |||
| std::vector<AnfNodePtr> real_depend_nodes; | |||
| std::set<AnfNodePtr> depend_visited; | |||
| for (const auto &tmp : depend_nodes) { | |||
| GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); | |||
| } | |||
| for (auto &first_node : real_prior_nodes) { | |||
| if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { | |||
| continue; | |||
| } | |||
| for (auto &second_node : depend_nodes) { | |||
| for (auto &second_node : real_depend_nodes) { | |||
| if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { | |||
| continue; | |||
| } | |||
| @@ -33,9 +33,14 @@ | |||
| namespace py = pybind11; | |||
| namespace mindspore::inference { | |||
| std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device) { | |||
| inference::Session::RegAllOp(); | |||
| auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); | |||
| return anf_graph; | |||
| try { | |||
| inference::Session::RegAllOp(); | |||
| auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); | |||
| return anf_graph; | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Inference LoadModel failed"; | |||
| return nullptr; | |||
| } | |||
| } | |||
| void ExitInference() { | |||
| @@ -51,12 +56,17 @@ void ExitInference() { | |||
| } | |||
| std::shared_ptr<MSSession> MSSession::CreateSession(const std::string &device, uint32_t device_id) { | |||
| auto session = std::make_shared<inference::Session>(); | |||
| auto ret = session->Init(device, device_id); | |||
| if (ret != 0) { | |||
| try { | |||
| auto session = std::make_shared<inference::Session>(); | |||
| auto ret = session->Init(device, device_id); | |||
| if (ret != 0) { | |||
| return nullptr; | |||
| } | |||
| return session; | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Inference CreatSession failed"; | |||
| return nullptr; | |||
| } | |||
| return session; | |||
| } | |||
| void Session::RegAllOp() { | |||
| @@ -113,47 +123,71 @@ void Session::RegAllOp() { | |||
| uint32_t Session::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) { | |||
| MS_ASSERT(session_impl_ != nullptr); | |||
| auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); | |||
| py::gil_scoped_release gil_release; | |||
| return graph_id; | |||
| try { | |||
| auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); | |||
| py::gil_scoped_release gil_release; | |||
| return graph_id; | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Inference CompileGraph failed"; | |||
| return static_cast<uint32_t>(-1); | |||
| } | |||
| } | |||
| MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) { | |||
| std::vector<tensor::TensorPtr> inTensors; | |||
| inTensors.resize(inputs.size()); | |||
| bool has_error = false; | |||
| std::transform(inputs.begin(), inputs.end(), inTensors.begin(), | |||
| [&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr { | |||
| if (tensor_ptr == nullptr) { | |||
| MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr"; | |||
| has_error = true; | |||
| return nullptr; | |||
| } | |||
| auto tensor = static_cast<inference::Tensor *>(tensor_ptr.get()); | |||
| if (tensor == nullptr) { | |||
| MS_LOG(ERROR) << "Can not cast input MSTensor to tensor"; | |||
| has_error = true; | |||
| return nullptr; | |||
| } | |||
| return tensor->tensor(); | |||
| }); | |||
| if (has_error) { | |||
| MS_LOG(ERROR) << "Init Tensor failed, returning empty result"; | |||
| std::vector<std::shared_ptr<inference::MSTensor>> multiTensor; | |||
| return multiTensor; | |||
| } | |||
| VectorRef outputs; | |||
| session_impl_->RunGraph(graph_id, inTensors, &outputs); | |||
| try { | |||
| std::vector<tensor::TensorPtr> inTensors; | |||
| inTensors.resize(inputs.size()); | |||
| bool has_error = false; | |||
| std::transform(inputs.begin(), inputs.end(), inTensors.begin(), | |||
| [&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr { | |||
| if (tensor_ptr == nullptr) { | |||
| MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr"; | |||
| has_error = true; | |||
| return nullptr; | |||
| } | |||
| auto tensor = static_cast<inference::Tensor *>(tensor_ptr.get()); | |||
| if (tensor == nullptr) { | |||
| MS_LOG(ERROR) << "Can not cast input MSTensor to tensor"; | |||
| has_error = true; | |||
| return nullptr; | |||
| } | |||
| return tensor->tensor(); | |||
| }); | |||
| if (has_error) { | |||
| MS_LOG(ERROR) << "Init Tensor failed, returning empty result"; | |||
| std::vector<std::shared_ptr<inference::MSTensor>> multiTensor; | |||
| return multiTensor; | |||
| } | |||
| VectorRef outputs; | |||
| session_impl_->RunGraph(graph_id, inTensors, &outputs); | |||
| return TransformVectorRefToMultiTensor(outputs); | |||
| return TransformVectorRefToMultiTensor(outputs); | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "Inference Rungraph failed"; | |||
| return MultiTensor(); | |||
| } | |||
| } | |||
| namespace { | |||
| string AjustTargetName(const std::string &device) { | |||
| if (device == kAscendDevice) { | |||
| return std::string(kAscendDevice) + "Inference"; | |||
| } else { | |||
| MS_LOG(ERROR) << "Only support device Ascend right now"; | |||
| return ""; | |||
| } | |||
| } | |||
| } // namespace | |||
| int Session::Init(const std::string &device, uint32_t device_id) { | |||
| RegAllOp(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| ms_context->set_execution_mode(kGraphMode); | |||
| ms_context->set_device_target(kAscendDevice); | |||
| session_impl_ = session::SessionFactory::Get().Create(device); | |||
| ms_context->set_device_id(device_id); | |||
| auto ajust_device = AjustTargetName(device); | |||
| if (ajust_device == "") { | |||
| return -1; | |||
| } | |||
| ms_context->set_device_target(device); | |||
| session_impl_ = session::SessionFactory::Get().Create(ajust_device); | |||
| if (session_impl_ == nullptr) { | |||
| MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available."; | |||
| return -1; | |||
| @@ -81,7 +81,15 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne | |||
| } | |||
| } | |||
| // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode) | |||
| auto address = AnfAlgo::GetOutputAddr(node, output_index); | |||
| DeviceAddressPtr address; | |||
| auto is_all_nop_node = opt::IsAllNopNode(&graph); | |||
| if (is_all_nop_node) { | |||
| // The graph does not remove the nop node. | |||
| address = AnfAlgo::GetMutableOutputAddr(node, output_index, false); | |||
| } else { | |||
| // The graph removes the nop node. | |||
| address = AnfAlgo::GetMutableOutputAddr(node, output_index, true); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(address); | |||
| auto shape = AnfAlgo::GetOutputInferShape(node, output_index); | |||
| TypeId type_id = kNumberTypeFloat32; | |||
| @@ -93,7 +101,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) { | |||
| tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index)); | |||
| tensor->set_device_address(address); | |||
| tensor->set_dirty(false); | |||
| } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index), | |||
| LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) { | |||
| @@ -1646,7 +1646,7 @@ bool DfGraphConvertor::GetControlDependList(const CNodePtr &node, | |||
| dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); | |||
| } | |||
| if (src_ops_list->empty() || dst_ops_list->empty()) { | |||
| MS_LOG(WARNING) << "Control depend node's src or dest node is not a apply node, ignore it"; | |||
| MS_LOG(DEBUG) << "Control depend node's src or dest node is not a CNode, ignore it"; | |||
| error_ = SUCCESS; | |||
| } | |||
| return true; | |||
| @@ -1690,6 +1690,8 @@ void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) { | |||
| }); | |||
| } else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) { | |||
| control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]}); | |||
| } else if (src_ops_list->empty() || dst_ops_list->empty()) { | |||
| MS_LOG(DEBUG) << "Depend list of src or dst is empty, ignore it"; | |||
| } else { | |||
| MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size() | |||
| << " -> dst:" << dst_ops_list->size(); | |||
| @@ -463,7 +463,7 @@ void InitSubModulesLogLevel() { | |||
| // set submodule's log level | |||
| auto submodule = GetEnv("MS_SUBMODULE_LOG_v"); | |||
| MS_LOG(INFO) << "MS_SUBMODULE_LOG_v=`" << submodule << "`"; | |||
| MS_LOG(DEBUG) << "MS_SUBMODULE_LOG_v=`" << submodule << "`"; | |||
| LogConfigParser parser(submodule); | |||
| auto configs = parser.Parse(); | |||
| for (const auto &cfg : configs) { | |||
| @@ -489,22 +489,14 @@ void InitSubModulesLogLevel() { | |||
| } // namespace mindspore | |||
| extern "C" { | |||
| // shared lib init hook | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| __attribute__((constructor)) void mindspore_log_init(void) { | |||
| __attribute__((constructor)) void common_log_init(void) { | |||
| #else | |||
| void mindspore_log_init(void) { | |||
| void common_log_init(void) { | |||
| #endif | |||
| #ifdef USE_GLOG | |||
| // do not use glog predefined log prefix | |||
| FLAGS_log_prefix = false; | |||
| static bool is_glog_initialzed = false; | |||
| if (!is_glog_initialzed) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| google::InitGoogleLogging("mindspore"); | |||
| #endif | |||
| is_glog_initialzed = true; | |||
| } | |||
| // set default log level to WARNING | |||
| if (mindspore::GetEnv("GLOG_v").empty()) { | |||
| FLAGS_v = mindspore::WARNING; | |||
| @@ -525,4 +517,22 @@ void mindspore_log_init(void) { | |||
| #endif | |||
| mindspore::InitSubModulesLogLevel(); | |||
| } | |||
| // shared lib init hook | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| __attribute__((constructor)) void mindspore_log_init(void) { | |||
| #else | |||
| void mindspore_log_init(void) { | |||
| #endif | |||
| #ifdef USE_GLOG | |||
| static bool is_glog_initialzed = false; | |||
| if (!is_glog_initialzed) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| google::InitGoogleLogging("mindspore"); | |||
| #endif | |||
| is_glog_initialzed = true; | |||
| } | |||
| #endif | |||
| common_log_init(); | |||
| } | |||
| } | |||
| @@ -252,6 +252,7 @@ constexpr auto kControlDependMode = "depend_mode"; | |||
| // index define of depend | |||
| constexpr auto kRealInputIndexInDepend = 1; | |||
| constexpr auto kDependAttachNodeIndex = 2; | |||
| constexpr auto kDependInputSize = 3; | |||
| // format | |||
| constexpr auto kOpFormat_DEFAULT = "DefaultFormat"; | |||
| constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0"; | |||
| @@ -22,6 +22,10 @@ from . import dtype as mstype | |||
| from ._register_for_tensor import tensor_operator_registry | |||
| __all__ = ['Tensor', 'MetaTensor'] | |||
| np_types = (np.int8, np.int16, np.int32, np.int64, | |||
| np.uint8, np.uint16, np.uint32, np.uint64, np.float16, | |||
| np.float32, np.float64, np.bool_) | |||
| class Tensor(Tensor_): | |||
| @@ -54,6 +58,10 @@ class Tensor(Tensor_): | |||
| """ | |||
| def __init__(self, input_data, dtype=None): | |||
| # If input data is numpy number, convert it to np array | |||
| if isinstance(input_data, np_types): | |||
| input_data = np.array(input_data) | |||
| # If input_data is tuple/list/numpy.ndarray, it's support in check_type method. | |||
| check_type('tensor input_data', input_data, (Tensor_, float, int)) | |||
| if dtype is not None: | |||
| @@ -1040,7 +1040,7 @@ class Dataset: | |||
| Args: | |||
| columns (list[str], optional): List of columns to be used to specify the order of columns | |||
| (defaults=None, means all columns). | |||
| (default=None, means all columns). | |||
| Returns: | |||
| Iterator, list of ndarray. | |||
| @@ -3382,7 +3382,7 @@ class ManifestDataset(MappableDataset): | |||
| class_indexing (dict, optional): A str-to-int mapping from label name to index | |||
| (default=None, the folder names will be sorted alphabetically and each | |||
| class will be given a unique index starting from 0). | |||
| decode (bool, optional): decode the images after reading (defaults=False). | |||
| decode (bool, optional): decode the images after reading (default=False). | |||
| num_shards (int, optional): Number of shards that the dataset should be divided | |||
| into (default=None). | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| @@ -4760,7 +4760,7 @@ class _NumpySlicesDataset: | |||
| def process_dict(self, input_data): | |||
| """ | |||
| Convert the dict like data into tuple format, when input is a tuple of dict then compose it into a dict first. | |||
| Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first. | |||
| """ | |||
| # Convert pandas like dict(has "values" column) into General dict | |||
| data_keys = list(input_data.keys()) | |||
| @@ -202,7 +202,7 @@ class RandomHorizontalFlip(cde.RandomHorizontalFlipOp): | |||
| Flip the input image horizontally, randomly with a given probability. | |||
| Args: | |||
| prob (float): Probability of the image being flipped (default=0.5). | |||
| prob (float, optional): Probability of the image being flipped (default=0.5). | |||
| """ | |||
| @check_prob | |||
| @@ -217,7 +217,7 @@ class RandomHorizontalFlipWithBBox(cde.RandomHorizontalFlipWithBBoxOp): | |||
| Maintains data integrity by also flipping bounding boxes in an object detection pipeline. | |||
| Args: | |||
| prob (float): Probability of the image being flipped (default=0.5). | |||
| prob (float, optional): Probability of the image being flipped (default=0.5). | |||
| """ | |||
| @check_prob | |||
| @@ -231,7 +231,7 @@ class RandomVerticalFlip(cde.RandomVerticalFlipOp): | |||
| Flip the input image vertically, randomly with a given probability. | |||
| Args: | |||
| prob (float): Probability of the image being flipped (default=0.5). | |||
| prob (float, optional): Probability of the image being flipped (default=0.5). | |||
| """ | |||
| @check_prob | |||
| @@ -29,8 +29,9 @@ from .optimizer import Optimizer | |||
| _adam_opt = C.MultitypeFuncGraph("adam_opt") | |||
| @_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") | |||
| def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): | |||
| @_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | |||
| "Tensor", "Bool", "Bool") | |||
| def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag, optim_filter): | |||
| """ | |||
| Update parameters. | |||
| @@ -44,38 +45,44 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad | |||
| m (Tensor): m value of parameters. | |||
| v (Tensor): v value of parameters. | |||
| gradient (Tensor): Gradient of parameters. | |||
| decay_flag (bool): Applies weight decay or not. | |||
| optim_filter (bool): Applies parameter update or not. | |||
| Returns: | |||
| Tensor, the new value of v after updating. | |||
| """ | |||
| op_mul = P.Mul() | |||
| op_square = P.Square() | |||
| op_sqrt = P.Sqrt() | |||
| op_cast = P.Cast() | |||
| op_reshape = P.Reshape() | |||
| op_shape = P.Shape() | |||
| if optim_filter: | |||
| op_mul = P.Mul() | |||
| op_square = P.Square() | |||
| op_sqrt = P.Sqrt() | |||
| op_cast = P.Cast() | |||
| op_reshape = P.Reshape() | |||
| op_shape = P.Shape() | |||
| param_fp32 = op_cast(param, mstype.float32) | |||
| m_fp32 = op_cast(m, mstype.float32) | |||
| v_fp32 = op_cast(v, mstype.float32) | |||
| gradient_fp32 = op_cast(gradient, mstype.float32) | |||
| param_fp32 = op_cast(param, mstype.float32) | |||
| m_fp32 = op_cast(m, mstype.float32) | |||
| v_fp32 = op_cast(v, mstype.float32) | |||
| gradient_fp32 = op_cast(gradient, mstype.float32) | |||
| next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32) | |||
| next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) | |||
| - beta1, gradient_fp32) | |||
| next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) | |||
| - beta2, op_square(gradient_fp32)) | |||
| next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) | |||
| - beta2, op_square(gradient_fp32)) | |||
| update = next_m / (eps + op_sqrt(next_v)) | |||
| if decay_flag: | |||
| update = op_mul(weight_decay_tensor, param_fp32) + update | |||
| update_with_lr = op_mul(lr, update) | |||
| next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) | |||
| update = next_m / (eps + op_sqrt(next_v)) | |||
| if decay_flag: | |||
| update = op_mul(weight_decay_tensor, param_fp32) + update | |||
| next_v = F.depend(next_v, F.assign(param, op_cast(next_param, F.dtype(param)))) | |||
| next_v = F.depend(next_v, F.assign(m, op_cast(next_m, F.dtype(m)))) | |||
| next_v = F.depend(next_v, F.assign(v, op_cast(next_v, F.dtype(v)))) | |||
| return next_v | |||
| update_with_lr = op_mul(lr, update) | |||
| next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) | |||
| next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param)))) | |||
| next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m)))) | |||
| next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v)))) | |||
| return next_param | |||
| return gradient | |||
| def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): | |||
| @@ -300,7 +307,7 @@ class AdamWeightDecay(Optimizer): | |||
| - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. | |||
| Outputs: | |||
| tuple[Parameter], the updated velocity value, the shape is the same as `params`. | |||
| tuple[bool], all elements are True. | |||
| Examples: | |||
| >>> net = Net() | |||
| @@ -328,11 +335,13 @@ class AdamWeightDecay(Optimizer): | |||
| def construct(self, gradients): | |||
| lr = self.get_lr() | |||
| updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, | |||
| self.weight_decay_tensor), | |||
| self.params, self.moments1, self.moments2, gradients, self.decay_flag) | |||
| return updated_velocity | |||
| optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, | |||
| self.weight_decay_tensor), | |||
| self.params, self.moments1, self.moments2, gradients, | |||
| self.decay_flag, self.optim_filter) | |||
| if self.use_parallel: | |||
| optim_result = self.broadcast_params(optim_result) | |||
| return optim_result | |||
| class AdamWeightDecayDynamicLR(Optimizer): | |||
| @@ -363,7 +372,7 @@ class AdamWeightDecayDynamicLR(Optimizer): | |||
| - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. | |||
| Outputs: | |||
| tuple[Parameter], the updated velocity value, the shape is the same as `params`. | |||
| tuple[bool], all elements are True. | |||
| Examples: | |||
| >>> net = Net() | |||
| @@ -424,12 +433,14 @@ class AdamWeightDecayDynamicLR(Optimizer): | |||
| warmup_lr = self.start_learning_rate * warmup_percent | |||
| is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32) | |||
| lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr | |||
| updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, | |||
| self.weight_decay_tensor), | |||
| self.params, self.moments1, self.moments2, gradients, self.decay_flag) | |||
| optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, | |||
| self.weight_decay_tensor), | |||
| self.params, self.moments1, self.moments2, gradients, | |||
| self.decay_flag, self.optim_filter) | |||
| if self.use_parallel: | |||
| optim_result = self.broadcast_params(optim_result) | |||
| added_global_step = self.global_step + self.one | |||
| F.control_depend(lr, added_global_step) | |||
| self.global_step = added_global_step | |||
| return updated_velocity | |||
| return optim_result | |||
| @@ -32,11 +32,10 @@ num_one = Tensor(np.ones([1]), mstype.float32) | |||
| _lamb_opt = C.MultitypeFuncGraph("lamb_opt") | |||
| @_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "Bool") | |||
| @_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", | |||
| "Tensor", "Bool", "Bool") | |||
| def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v, | |||
| gradient, decay_flag): | |||
| gradient, decay_flag, optim_filter): | |||
| """ | |||
| Update parameters. | |||
| @@ -52,66 +51,66 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para | |||
| v (Tensor): v value of parameters. | |||
| gradient (Tensor): Gradient of parameters. | |||
| decay_flag (bool): Specifies whether param update with weight decay. | |||
| optim_filter(bool): Applies parameter update or not. | |||
| Returns: | |||
| Tensor, the new value of v after updating. | |||
| """ | |||
| op_mul = P.Mul() | |||
| op_sqrt = P.Sqrt() | |||
| op_rsqrt = P.Rsqrt() | |||
| op_square = P.Square() | |||
| op_cast = P.Cast() | |||
| op_reshape = P.Reshape() | |||
| op_shape = P.Shape() | |||
| op_pow = P.Pow() | |||
| op_norm = layer.Norm() | |||
| op_select = P.Select() | |||
| op_greater = P.Greater() | |||
| op_fill = P.Fill() | |||
| op_dtype = P.DType() | |||
| param_fp32 = op_cast(param, mstype.float32) | |||
| m_fp32 = op_cast(m, mstype.float32) | |||
| v_fp32 = op_cast(v, mstype.float32) | |||
| gradient_fp32 = op_cast(gradient, mstype.float32) | |||
| next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, | |||
| mstype.float32) - beta1, gradient_fp32) | |||
| next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, | |||
| mstype.float32) - beta2, op_square(gradient_fp32)) | |||
| next_mm = next_m / (op_cast(num_one, mstype.float32) | |||
| - op_pow(beta1, op_cast(global_step + num_one, mstype.float32))) | |||
| next_vv = next_v / (op_cast(num_one, mstype.float32) - | |||
| op_pow(beta2, op_cast(global_step + num_one, mstype.float32))) | |||
| w_norm = op_norm(param_fp32) | |||
| g_norm = op_norm(gradient_fp32) | |||
| g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt( | |||
| next_vv + eps)) + weight_decay_tensor * param_fp32) | |||
| zeros = F.zeros_like(w_norm) | |||
| ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0) | |||
| trust_ratio = op_select( | |||
| op_greater(w_norm, zeros), | |||
| op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones), | |||
| ones) | |||
| tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0) | |||
| trust_ratio = C.clip_by_value(trust_ratio, zeros, tens) | |||
| update = next_mm / (op_sqrt(next_vv) + eps) | |||
| if decay_flag: | |||
| update = update + op_mul(weight_decay_tensor, param_fp32) | |||
| update_with_lr = op_mul(op_mul(trust_ratio, lr), update) | |||
| next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) | |||
| next_v = F.depend(next_v, F.assign(param, next_param)) | |||
| next_v = F.depend(next_v, F.assign(m, next_m)) | |||
| next_v = F.depend(next_v, F.assign(v, next_v)) | |||
| return next_v | |||
| if optim_filter: | |||
| op_mul = P.Mul() | |||
| op_sqrt = P.Sqrt() | |||
| op_rsqrt = P.Rsqrt() | |||
| op_square = P.Square() | |||
| op_cast = P.Cast() | |||
| op_reshape = P.Reshape() | |||
| op_shape = P.Shape() | |||
| op_pow = P.Pow() | |||
| op_norm = layer.Norm() | |||
| op_select = P.Select() | |||
| op_greater = P.Greater() | |||
| op_fill = P.Fill() | |||
| op_dtype = P.DType() | |||
| param_fp32 = op_cast(param, mstype.float32) | |||
| m_fp32 = op_cast(m, mstype.float32) | |||
| v_fp32 = op_cast(v, mstype.float32) | |||
| gradient_fp32 = op_cast(gradient, mstype.float32) | |||
| next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32) | |||
| next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32)) | |||
| next_mm = next_m / (op_cast(num_one, mstype.float32) | |||
| - op_pow(beta1, op_cast(global_step + num_one, mstype.float32))) | |||
| next_vv = next_v / (op_cast(num_one, mstype.float32) - | |||
| op_pow(beta2, op_cast(global_step + num_one, mstype.float32))) | |||
| w_norm = op_norm(param_fp32) | |||
| g_norm = op_norm(gradient_fp32) | |||
| g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32) | |||
| zeros = F.zeros_like(w_norm) | |||
| ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0) | |||
| trust_ratio = op_select( | |||
| op_greater(w_norm, zeros), | |||
| op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones), | |||
| ones) | |||
| tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0) | |||
| trust_ratio = C.clip_by_value(trust_ratio, zeros, tens) | |||
| update = next_mm / (op_sqrt(next_vv) + eps) | |||
| if decay_flag: | |||
| update = update + op_mul(weight_decay_tensor, param_fp32) | |||
| update_with_lr = op_mul(op_mul(trust_ratio, lr), update) | |||
| next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) | |||
| next_param = F.depend(next_param, F.assign(param, next_param)) | |||
| next_param = F.depend(next_param, F.assign(m, next_m)) | |||
| next_param = F.depend(next_param, F.assign(v, next_v)) | |||
| return next_param | |||
| return gradient | |||
| lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel") | |||
| @@ -238,7 +237,7 @@ class Lamb(Optimizer): | |||
| - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. | |||
| Outputs: | |||
| tuple[Parameter], the updated velocity value, the shape is the same as `params`. | |||
| tuple[bool], all elements are True. | |||
| Examples: | |||
| >>> net = Net() | |||
| @@ -311,18 +310,21 @@ class Lamb(Optimizer): | |||
| self.warmup_steps, self.global_step), mstype.float32) | |||
| lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr | |||
| if self.enable_graph_kernel: | |||
| updated_velocity = self.hyper_map(F.partial(lamb_opt_graph_kernel, | |||
| self.beta1, self.beta2, self.eps, lr, | |||
| self.weight_decay_tensor, self.global_step), | |||
| self.params, self.moments1, self.moments2, gradients, self.decay_flag) | |||
| optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, | |||
| self.beta1, self.beta2, self.eps, lr, | |||
| self.weight_decay_tensor, self.global_step), | |||
| self.params, self.moments1, self.moments2, gradients, self.decay_flag) | |||
| else: | |||
| updated_velocity = self.hyper_map(F.partial(_lamb_opt, | |||
| self.beta1, self.beta2, self.eps, lr, | |||
| self.weight_decay_tensor, self.global_step), | |||
| self.params, self.moments1, self.moments2, gradients, self.decay_flag) | |||
| optim_result = self.hyper_map(F.partial(_lamb_opt, | |||
| self.beta1, self.beta2, self.eps, lr, | |||
| self.weight_decay_tensor, self.global_step), | |||
| self.params, self.moments1, self.moments2, gradients, | |||
| self.decay_flag, self.optim_filter) | |||
| if self.use_parallel: | |||
| optim_result = self.broadcast_params(optim_result) | |||
| added_global_step = self.global_step + self.one | |||
| F.control_depend(lr, added_global_step) | |||
| self.global_step = added_global_step | |||
| return updated_velocity | |||
| return optim_result | |||
| @@ -22,11 +22,14 @@ from mindspore.ops import functional as F, composite as C, operations as P | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore.common.parameter import Parameter, ParameterTuple | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common.tensor import Tensor | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore._checkparam import Rel | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore import log as logger | |||
| from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode | |||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| from mindspore.train.parallel_utils import ParallelMode | |||
| __all__ = ['Optimizer'] | |||
| @@ -155,6 +158,27 @@ class Optimizer(Cell): | |||
| self.param_length = len(self.parameters) | |||
| self.map_ = C.Map() | |||
| use_parallel = auto_parallel_context().get_enable_parallel_optimizer() | |||
| self.use_parallel = use_parallel | |||
| if use_parallel: | |||
| if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]: | |||
| raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name)) | |||
| if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL, | |||
| ParallelMode.AUTO_PARALLEL]: | |||
| raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format | |||
| (_get_parallel_mode())) | |||
| self.dev_num = _get_device_num() | |||
| if self.dev_num > self.param_length: | |||
| raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is" | |||
| " less than the number of devices {}".format(self.param_length, self.dev_num)) | |||
| self.param_rank = self._get_parameter_group_id() | |||
| self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank)) | |||
| self.param_names = [] | |||
| for param in self.parameters: | |||
| self.param_names.append(param.name) | |||
| else: | |||
| self.optim_filter = (True,) * self.param_length | |||
| def decay_weight(self, gradients): | |||
| """ | |||
| Weight decay. | |||
| @@ -219,8 +243,32 @@ class Optimizer(Cell): | |||
| raise TypeError("Learning rate should be float, Tensor or Iterable.") | |||
| return lr | |||
| def _check_group_params(self, parameters): | |||
| """Check group params.""" | |||
| parse_keys = ['params', 'lr', 'weight_decay', 'order_params'] | |||
| for group_param in parameters: | |||
| invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys())) | |||
| if invalid_key: | |||
| raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.') | |||
| if 'order_params' in group_param.keys(): | |||
| if len(group_param.keys()) > 1: | |||
| raise ValueError("The order params dict in group parameters should " | |||
| "only include the 'order_params' key.") | |||
| if not isinstance(group_param['order_params'], Iterable): | |||
| raise TypeError("The value of 'order_params' should be an Iterable type.") | |||
| continue | |||
| if not group_param['params']: | |||
| raise ValueError("Optimizer got an empty group parameter list.") | |||
| for param in group_param['params']: | |||
| if not isinstance(param, Parameter): | |||
| raise TypeError("The group param should be an iterator of Parameter type.") | |||
| def _parse_group_params(self, parameters, learning_rate): | |||
| """Parse group params.""" | |||
| self._check_group_params(parameters) | |||
| if self.dynamic_lr: | |||
| dynamic_lr_length = learning_rate.size() | |||
| else: | |||
| @@ -250,9 +298,6 @@ class Optimizer(Cell): | |||
| if dynamic_lr_length not in (lr_length, 0): | |||
| raise ValueError("The dynamic learning rate in group should be the same size.") | |||
| if not group_param['params']: | |||
| raise ValueError("Optimizer got an empty group parameter list.") | |||
| dynamic_lr_length = lr_length | |||
| self.dynamic_lr_length = dynamic_lr_length | |||
| @@ -384,6 +429,51 @@ class Optimizer(Cell): | |||
| lr = self.learning_rate | |||
| return lr | |||
| def _get_parameter_group_id(self): | |||
| """ | |||
| Get the parameter partition group id, which is less than the number of devices. | |||
| Returns: | |||
| tuple, the group id tuple of parameters. | |||
| """ | |||
| rank_list = () | |||
| count = 0 | |||
| for _ in range(self.param_length): | |||
| rank_list = rank_list + (count,) | |||
| count = count + 1 | |||
| if count == self.dev_num: | |||
| count = 0 | |||
| return rank_list | |||
    def broadcast_params(self, optim_result):
        """
        Apply Broadcast operations in the sequential order of parameter groups.

        Args:
            optim_result: Per-parameter results produced by the optimizer update,
                indexed the same way as `self.parameters`.

        Returns:
            bool, the status flag.
        """
        param_group = []
        key_group = []
        # One tuple of updated parameters (and one tuple of ref keys) per device/root.
        for _ in range(self.dev_num):
            param_group.append(F.make_tuple())
            key_group.append(F.make_tuple())
        # Bucket each parameter's update result and its ref key by the rank that
        # owns it (self.param_rank, computed by _get_parameter_group_id()).
        for i in range(self.param_length):
            param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],)
            key = P.MakeRefKey(self.param_names[i])()
            key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
        new_param_group = []
        # Broadcast each group from its owning rank, then write the received
        # values back into the parameters through their ref keys.
        for root in range(self.dev_num):
            ops = P.Broadcast(root)
            next_params = ops(param_group[root])
            new_param_group.append(next_params)
            for i in range(F.tuple_len(next_params)):
                F.assign(key_group[root][i], next_params[i])
        status = True
        # Chain the broadcasts with control dependencies so they run in the
        # sequential order of the groups. NOTE(review): assumes graph-mode
        # execution where ops have no implicit ordering — confirm.
        for i in range(self.dev_num - 1):
            status = F.control_depend(new_param_group[i][0], new_param_group[i+1])
        return status
| def construct(self, *hyper_params): | |||
| raise NotImplementedError | |||
| @@ -220,7 +220,9 @@ class DataWrapper(Cell): | |||
| def __init__(self, network, dataset_types, dataset_shapes, queue_name): | |||
| super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags()) | |||
| # Also copy the flag in `network` construct | |||
| flags = getattr(network.__class__.construct, "_mindspore_flags", {}) | |||
| self.add_flags(**flags) | |||
| self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) | |||
| self.network = network | |||
| @@ -47,6 +47,7 @@ from .gather_v2 import _gather_v2_akg | |||
| from .less import _less_akg | |||
| from .log import _log_akg | |||
| from .matmul import _matmul_akg | |||
| from .batchmatmul import _batchmatmul_akg | |||
| from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_akg | |||
| from .max_pool_with_argmax import _max_pool_with_argmax_akg | |||
| from .max import _max_akg | |||
| @@ -0,0 +1,73 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """BatchMatMul op""" | |||
| from mindspore.ops.op_info_register import op_info_register | |||
@op_info_register("""{
    "op_name": "BatchMatMul",
    "imply_type": "AutoDiff",
    "fusion_type": "OPAQUE",
    "attr": [
        {
            "name": "transpose_a",
            "param_type": "optional",
            "type": "bool"
        },
        {
            "name": "transpose_b",
            "param_type": "optional",
            "type": "bool"
        }
    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float16"
            ],
            "format": [
                "FRACTAL_NZ"
            ],
            "name": "x1"
        },
        {
            "index": 1,
            "dtype": [
                "float16"
            ],
            "format": [
                "FRACTAL_NZ"
            ],
            "name": "x2"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float16"
            ],
            "format": [
                "FRACTAL_NZ"
            ],
            "name": "output"
        }
    ]
}""")
def _batchmatmul_akg():
    """BatchMatMul AKG register"""
    # Registration stub: the op_info_register decorator records the op-info
    # JSON with the AKG (AutoDiff) op library; no runtime body is needed.
    return
| @@ -28,26 +28,8 @@ confusion_transpose_d_op_info = TBERegOp("ConfusionTransposeD") \ | |||
| .attr("transpose_first", "required", "bool", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I8_FracNZ, DataType.I8_FracNZ) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_FracNZ, DataType.U8_FracNZ) \ | |||
| .dtype_format(DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.I16_FracNZ, DataType.I16_FracNZ) \ | |||
| .dtype_format(DataType.I16_Default, DataType.I16_Default) \ | |||
| .dtype_format(DataType.U16_FracNZ, DataType.U16_FracNZ) \ | |||
| .dtype_format(DataType.U16_Default, DataType.U16_Default) \ | |||
| .dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ) \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.U32_FracNZ, DataType.U32_FracNZ) \ | |||
| .dtype_format(DataType.U32_Default, DataType.U32_Default) \ | |||
| .dtype_format(DataType.I64_FracNZ, DataType.I64_FracNZ) \ | |||
| .dtype_format(DataType.I64_Default, DataType.I64_Default) \ | |||
| .dtype_format(DataType.U64_FracNZ, DataType.U64_FracNZ) \ | |||
| .dtype_format(DataType.U64_Default, DataType.U64_Default) \ | |||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.None_None, DataType.None_None) \ | |||
| .get_op_info() | |||
| @@ -85,6 +85,22 @@ def _list_setitem_with_List(data, number_index, value): | |||
| return F.list_setitem(data, number_index, value) | |||
@setitem.register("List", "Number", "Tuple")
def _list_setitem_with_Tuple(data, number_index, value):
    """
    Assigns value to list.

    Inputs:
        data (list): Data of type list.
        number_index (Number): Index of data.
        value (tuple): Value given.

    Outputs:
        list, type is same as the element type of data.
    """
    return F.list_setitem(data, number_index, value)
| @setitem.register("Dictionary", "String", "Tensor") | |||
| def _dict_setitem_with_tensor(data, key, value): | |||
| """ | |||
| @@ -98,6 +98,7 @@ class AllReduce(PrimitiveWithInfer): | |||
| self.op = op | |||
| self.add_prim_attr('group', _get_group(group)) | |||
| self.add_prim_attr('fusion', 0) | |||
| self.add_prim_attr('index', 0) | |||
| def vm_impl(self, x): | |||
| """Implement by vm mode.""" | |||
| @@ -309,12 +309,6 @@ class Print(PrimitiveWithInfer): | |||
| Output tensor or string to stdout. | |||
| Note: | |||
| The print operation cannot support the following cases currently. | |||
| 1. The type of tensor is float64 or bool. | |||
| 2. The data of tensor is a scalar type. | |||
| In pynative mode, please use python print function. | |||
| Inputs: | |||
| @@ -334,7 +328,7 @@ class Print(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| pass | |||
| self.add_prim_attr("_side_effect", True) | |||
| def __call__(self, *args): | |||
| for arg in args: | |||
| @@ -888,7 +888,8 @@ class Neg(PrimitiveWithInfer): | |||
| def infer_value(self, input_x): | |||
| if input_x is not None: | |||
| input_x = input_x.asnumpy() | |||
| return Tensor(-input_x) | |||
| out = np.array(-input_x, input_x.dtype) | |||
| return Tensor(out) | |||
| return None | |||
| @@ -1667,7 +1668,8 @@ class Div(_MathBinaryOp): | |||
| if x is not None and y is not None: | |||
| x = x.asnumpy() | |||
| y = y.asnumpy() | |||
| return Tensor(x / y) | |||
| out = np.array(x / y, x.dtype) | |||
| return Tensor(out) | |||
| return None | |||
| @@ -59,8 +59,7 @@ class Assign(PrimitiveWithInfer): | |||
| return variable | |||
| def infer_dtype(self, variable, value): | |||
| args = {"variable": variable, "value": value} | |||
| validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | |||
| # Add a type validation later when we don't have to assign a value to RefKey. | |||
| return variable | |||
| @@ -400,6 +400,23 @@ class _AutoParallelContext: | |||
| self.check_context_handle() | |||
| return self._context_handle.get_global_rank_is_set() | |||
    def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
        """
        Set enable/disable parallel optimizer.

        Args:
            enable_parallel_optimizer (bool): Enable/disable parallel optimizer.

        Raises:
            TypeError: If `enable_parallel_optimizer` is not a bool.
        """
        self.check_context_handle()
        if not isinstance(enable_parallel_optimizer, bool):
            raise TypeError('enable_parallel_optimizer is invalid type')
        self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
    def get_enable_parallel_optimizer(self):
        """Get parallel optimizer flag.

        Returns:
            The flag previously stored via `set_enable_parallel_optimizer`
            (a bool when set through that validated setter).
        """
        self.check_context_handle()
        return self._context_handle.get_enable_parallel_optimizer()
| def reset(self): | |||
| """Reset all settings.""" | |||
| self.check_context_handle() | |||
| @@ -433,7 +450,8 @@ _set_auto_parallel_context_func_map = { | |||
| "parameter_broadcast": auto_parallel_context().set_parameter_broadcast, | |||
| "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file, | |||
| "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, | |||
| "full_batch": auto_parallel_context().set_full_batch} | |||
| "full_batch": auto_parallel_context().set_full_batch, | |||
| "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer} | |||
| _get_auto_parallel_context_func_map = { | |||
| @@ -447,13 +465,15 @@ _get_auto_parallel_context_func_map = { | |||
| "parameter_broadcast": auto_parallel_context().get_parameter_broadcast, | |||
| "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file, | |||
| "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file, | |||
| "full_batch": auto_parallel_context().get_full_batch} | |||
| "full_batch": auto_parallel_context().get_full_batch, | |||
| "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer} | |||
| @args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, | |||
| loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, | |||
| parameter_broadcast=bool, strategy_ckpt_load_file=str, | |||
| strategy_ckpt_save_file=str, full_batch=bool) | |||
| strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool) | |||
| def _set_auto_parallel_context(**kwargs): | |||
| """ | |||
| Set auto parallel context. | |||
| @@ -493,6 +513,7 @@ def _set_auto_parallel_context(**kwargs): | |||
| strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' | |||
| strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' | |||
| full_batch (bool): Whether to load the whole batch on each device. Default: False. | |||
        enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
| Raises: | |||
| ValueError: If input key is not attribute in auto parallel context. | |||
| @@ -535,5 +556,6 @@ def _reset_auto_parallel_context(): | |||
| - parameter_broadcast: False. | |||
| - strategy_ckpt_load_file: "" | |||
| - strategy_ckpt_save_file: "" | |||
| - enable_parallel_optimizer: False | |||
| """ | |||
| auto_parallel_context().reset() | |||
| @@ -166,8 +166,11 @@ class SummaryCollector(Callback): | |||
| self._has_saved_custom_data = False | |||
| self._is_parse_loss_success = True | |||
| self._first_step = True | |||
| self._dataset_sink_mode = True | |||
| def __enter__(self): | |||
| self._first_step = True | |||
| self._dataset_sink_mode = True | |||
| self._record = SummaryRecord(log_dir=self._summary_dir) | |||
| return self | |||
| @@ -279,15 +282,15 @@ class SummaryCollector(Callback): | |||
| def step_end(self, run_context): | |||
| cb_params = run_context.original_args() | |||
| if self._first_step: | |||
| # Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario | |||
| self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num) | |||
| if cb_params.mode == ModeEnum.TRAIN.value: | |||
| # Make sure the first step data is recorded | |||
| if not self._first_step and cb_params.cur_step_num % self._collect_freq: | |||
| if not self._is_collect_this_step(cb_params): | |||
| return | |||
| self._first_step = False | |||
| if not self._has_saved_train_network: | |||
| self._collect_graphs(cb_params) | |||
| @@ -295,6 +298,7 @@ class SummaryCollector(Callback): | |||
| self._collect_metric(cb_params) | |||
| self._collect_histogram(cb_params) | |||
| self._first_step = False | |||
| self._record.record(cb_params.cur_step_num) | |||
| def end(self, run_context): | |||
| @@ -320,6 +324,18 @@ class SummaryCollector(Callback): | |||
| raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list," | |||
| f"but expected only one {self.__class__.__name__} instance.") | |||
| def _is_collect_this_step(self, cb_params): | |||
| """Decide whether to collect data for the current step.""" | |||
| # Make sure the first step data is recorded | |||
| if not self._first_step: | |||
| if self._dataset_sink_mode: | |||
| if cb_params.cur_epoch_num % self._collect_freq: | |||
| return False | |||
| else: | |||
| if cb_params.cur_step_num % self._collect_freq: | |||
| return False | |||
| return True | |||
| @staticmethod | |||
| def _package_custom_lineage_data(custom_lineage_data): | |||
| """ | |||
| @@ -318,10 +318,6 @@ def preprocess_fn(image, box, is_training): | |||
| else: | |||
| input_data = resize_column(*input_data) | |||
| photo = (np.random.rand() < config.photo_ratio) | |||
| if photo: | |||
| input_data = photo_crop_column(*input_data) | |||
| input_data = image_bgr_rgb(*input_data) | |||
| output_data = input_data | |||
| @@ -432,19 +428,19 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fast | |||
| writer.write_raw_data([row]) | |||
| writer.commit() | |||
| def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, device_num=1, rank_id=0, | |||
| is_training=True, num_parallel_workers=8): | |||
| is_training=True, num_parallel_workers=4): | |||
    """Create FasterRcnn dataset with MindDataset."""
| ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id, | |||
| num_parallel_workers=num_parallel_workers, shuffle=is_training) | |||
| num_parallel_workers=1, shuffle=is_training) | |||
| decode = C.Decode() | |||
| ds = ds.map(input_columns=["image"], operations=decode) | |||
| ds = ds.map(input_columns=["image"], operations=decode, num_parallel_workers=1) | |||
| compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) | |||
| hwc_to_chw = C.HWC2CHW() | |||
| normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375)) | |||
| horizontally_op = C.RandomHorizontalFlip(1) | |||
| type_cast0 = CC.TypeCast(mstype.float32) | |||
| type_cast1 = CC.TypeCast(mstype.float16) | |||
| type_cast2 = CC.TypeCast(mstype.int32) | |||
| type_cast3 = CC.TypeCast(mstype.bool_) | |||
| @@ -453,17 +449,18 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi | |||
| ds = ds.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "image_shape", "box", "label", "valid_num"], | |||
| columns_order=["image", "image_shape", "box", "label", "valid_num"], | |||
| operations=compose_map_func, num_parallel_workers=4) | |||
| ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0], | |||
| num_parallel_workers=num_parallel_workers) | |||
| operations=compose_map_func, num_parallel_workers=num_parallel_workers) | |||
| flip = (np.random.rand() < config.flip_ratio) | |||
| if flip: | |||
| ds = ds.map(input_columns=["image"], operations=[horizontally_op], | |||
| num_parallel_workers=num_parallel_workers) | |||
| ds = ds.map(input_columns=["image"], operations=[normalize_op, horizontally_op, hwc_to_chw, type_cast1], | |||
| num_parallel_workers=24) | |||
| ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"], | |||
| operations=flipped_generation, num_parallel_workers=4) | |||
| operations=flipped_generation, num_parallel_workers=num_parallel_workers) | |||
| else: | |||
| ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1], | |||
| num_parallel_workers=24) | |||
| else: | |||
| ds = ds.map(input_columns=["image", "annotation"], | |||
| output_columns=["image", "image_shape", "box", "label", "valid_num"], | |||
| @@ -471,11 +468,10 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi | |||
| operations=compose_map_func, | |||
| num_parallel_workers=num_parallel_workers) | |||
| ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0], | |||
| num_parallel_workers=num_parallel_workers) | |||
| ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1], | |||
| num_parallel_workers=24) | |||
| # transpose_column from python to c | |||
| ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1]) | |||
| ds = ds.map(input_columns=["image_shape"], operations=[type_cast1]) | |||
| ds = ds.map(input_columns=["box"], operations=[type_cast1]) | |||
| ds = ds.map(input_columns=["label"], operations=[type_cast2]) | |||
| @@ -19,7 +19,9 @@ from easydict import EasyDict as edict | |||
| cifar_cfg = edict({ | |||
| 'num_classes': 10, | |||
| 'lr_init': 0.05, | |||
| 'lr_init': 0.01, | |||
| 'lr_max': 0.1, | |||
| 'warmup_epochs': 5, | |||
| 'batch_size': 64, | |||
| 'epoch_size': 70, | |||
| 'momentum': 0.9, | |||
| @@ -38,20 +38,25 @@ random.seed(1) | |||
| np.random.seed(1) | |||
| def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): | |||
| def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch): | |||
| """Set learning rate.""" | |||
| lr_each_step = [] | |||
| total_steps = steps_per_epoch * total_epochs | |||
| decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps] | |||
| warmup_steps = steps_per_epoch * warmup_epochs | |||
| if warmup_steps != 0: | |||
| inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) | |||
| else: | |||
| inc_each_step = 0 | |||
| for i in range(total_steps): | |||
| if i < decay_epoch_index[0]: | |||
| lr_each_step.append(lr_max) | |||
| elif i < decay_epoch_index[1]: | |||
| lr_each_step.append(lr_max * 0.1) | |||
| elif i < decay_epoch_index[2]: | |||
| lr_each_step.append(lr_max * 0.01) | |||
| if i < warmup_steps: | |||
| lr_value = float(lr_init) + inc_each_step * float(i) | |||
| else: | |||
| lr_each_step.append(lr_max * 0.001) | |||
| base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) | |||
| lr_value = float(lr_max) * base * base | |||
| if lr_value < 0.0: | |||
| lr_value = 0.0 | |||
| lr_each_step.append(lr_value) | |||
| current_step = global_step | |||
| lr_each_step = np.array(lr_each_step).astype(np.float32) | |||
| learning_rate = lr_each_step[current_step:] | |||
| @@ -86,7 +91,8 @@ if __name__ == '__main__': | |||
| if args_opt.pre_trained: | |||
| load_param_into_net(net, load_checkpoint(args_opt.pre_trained)) | |||
| lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) | |||
| lr = lr_steps(0, lr_init=cfg.lr_init, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs, | |||
| total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) | |||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, | |||
| weight_decay=cfg.weight_decay) | |||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) | |||
| @@ -22,6 +22,7 @@ | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include <future> | |||
| #include "mindspore/ccsrc/utils/log_adapter.h" | |||
| #include "serving/ms_service.grpc.pb.h" | |||
| @@ -40,7 +41,7 @@ namespace serving { | |||
| using MSTensorPtr = std::shared_ptr<inference::MSTensor>; | |||
| Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) { | |||
| session_ = inference::MSSession::CreateSession(device + "Inference", device_id); | |||
| session_ = inference::MSSession::CreateSession(device, device_id); | |||
| if (session_ == nullptr) { | |||
| MS_LOG(ERROR) << "Creat Session Failed"; | |||
| return FAILED; | |||
| @@ -67,6 +68,7 @@ Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::Multi | |||
| MS_LOG(INFO) << "run Predict"; | |||
| *outputs = session_->RunGraph(graph_id_, inputs); | |||
| MS_LOG(INFO) << "run Predict finished"; | |||
| return SUCCESS; | |||
| } | |||
| @@ -80,12 +82,16 @@ Status Session::Warmup(const MindSporeModelPtr model) { | |||
| std::string file_name = model->GetModelPath() + '/' + model->GetModelName(); | |||
| char *graphBuf = ReadFile(file_name.c_str(), &size); | |||
| if (graphBuf == nullptr) { | |||
| MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); | |||
| MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str(); | |||
| return FAILED; | |||
| } | |||
| last_graph_ = inference::LoadModel(graphBuf, size, device_type_); | |||
| if (last_graph_ == nullptr) { | |||
| MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); | |||
| return FAILED; | |||
| } | |||
| graph_id_ = session_->CompileGraph(last_graph_); | |||
| MS_LOG(INFO) << "Session Warmup"; | |||
| MS_LOG(INFO) << "Session Warmup finished"; | |||
| return SUCCESS; | |||
| } | |||
| @@ -95,6 +101,9 @@ Status Session::Clear() { | |||
| } | |||
| namespace { | |||
| static const uint32_t uint32max = 0x7FFFFFFF; | |||
| std::promise<void> exit_requested; | |||
| const std::map<ms_serving::DataType, TypeId> type2id_map{ | |||
| {ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool}, | |||
| {ms_serving::MS_INT8, TypeId::kNumberTypeInt8}, {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8}, | |||
| @@ -141,7 +150,7 @@ MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) { | |||
| } | |||
| TypeId type = iter->second; | |||
| auto ms_tensor = std::shared_ptr<inference::MSTensor>(inference::MSTensor::CreateTensor(type, shape)); | |||
| memcpy_s(ms_tensor->MutableData(), tensor.data().size(), tensor.data().data(), tensor.data().size()); | |||
| memcpy_s(ms_tensor->MutableData(), ms_tensor->Size(), tensor.data().data(), tensor.data().size()); | |||
| return ms_tensor; | |||
| } | |||
| @@ -166,10 +175,7 @@ void ClearEnv() { | |||
| Session::Instance().Clear(); | |||
| inference::ExitInference(); | |||
| } | |||
| void HandleSignal(int sig) { | |||
| ClearEnv(); | |||
| exit(0); | |||
| } | |||
| void HandleSignal(int sig) { exit_requested.set_value(); } | |||
| #ifdef ENABLE_D | |||
| static rtContext_t g_ctx = nullptr; | |||
| @@ -247,6 +253,7 @@ Status Server::BuildAndStart() { | |||
| rtError_t rt_ret = rtCtxGetCurrent(&ctx); | |||
| if (rt_ret != RT_ERROR_NONE || ctx == nullptr) { | |||
| MS_LOG(ERROR) << "the ascend device context is null"; | |||
| ClearEnv(); | |||
| return FAILED; | |||
| } | |||
| g_ctx = ctx; | |||
| @@ -258,6 +265,7 @@ Status Server::BuildAndStart() { | |||
| auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0); | |||
| grpc::ServerBuilder builder; | |||
| builder.SetOption(std::move(option)); | |||
| builder.SetMaxMessageSize(uint32max); | |||
| // Listen on the given address without any authentication mechanism. | |||
| builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); | |||
| // Register "service" as the instance through which we'll communicate with | |||
| @@ -265,13 +273,20 @@ Status Server::BuildAndStart() { | |||
| builder.RegisterService(&service); | |||
| // Finally assemble the server. | |||
| std::unique_ptr<grpc::Server> server(builder.BuildAndStart()); | |||
| if (server == nullptr) { | |||
| MS_LOG(ERROR) << "The serving server create failed"; | |||
| ClearEnv(); | |||
| return FAILED; | |||
| } | |||
| auto grpc_server_run = [&server]() { server->Wait(); }; | |||
| std::thread serving_thread(grpc_server_run); | |||
| MS_LOG(INFO) << "Server listening on " << server_address << std::endl; | |||
| // Wait for the server to shutdown. Note that some other thread must be | |||
| // responsible for shutting down the server for this call to ever return. | |||
| server->Wait(); | |||
| auto exit_future = exit_requested.get_future(); | |||
| exit_future.wait(); | |||
| ClearEnv(); | |||
| server->Shutdown(); | |||
| serving_thread.join(); | |||
| return SUCCESS; | |||
| } | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -29,7 +29,6 @@ | |||
| namespace mindspore { | |||
| namespace serving { | |||
| char *ReadFile(const char *file, size_t *size) { | |||
| if (file == nullptr) { | |||
| MS_LOG(ERROR) << "file is nullptr"; | |||
| @@ -70,8 +69,8 @@ bool DirOrFileExist(const std::string &file_path) { | |||
| } | |||
| std::vector<std::string> GetAllSubDirs(const std::string &dir_path) { | |||
| DIR *dir; | |||
| struct dirent *ptr; | |||
| DIR *dir = nullptr; | |||
| struct dirent *ptr = nullptr; | |||
| std::vector<std::string> SubDirs; | |||
| if ((dir = opendir(dir_path.c_str())) == NULL) { | |||
| @@ -36,17 +36,16 @@ bool RemovePrefix(std::string *str, const std::string &prefix) { | |||
| bool Option::ParseInt32(std::string *arg) { | |||
| if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) { | |||
| char extra; | |||
| int32_t parsed_value; | |||
| if (sscanf(arg->data(), "%d%c", &parsed_value, &extra) != 1) { | |||
| std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl; | |||
| try { | |||
| parsed_value = std::stoi(arg->data()); | |||
| } catch (std::invalid_argument) { | |||
| std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl; | |||
| return false; | |||
| } else { | |||
| *int32_default_ = parsed_value; | |||
| } | |||
| *int32_default_ = parsed_value; | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| @@ -76,17 +75,16 @@ bool Option::ParseString(std::string *arg) { | |||
| bool Option::ParseFloat(std::string *arg) { | |||
| if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) { | |||
| char extra; | |||
| float parsed_value; | |||
| if (sscanf(arg->data(), "%f%c", &parsed_value, &extra) != 1) { | |||
| std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl; | |||
| try { | |||
| parsed_value = std::stof(arg->data()); | |||
| } catch (std::invalid_argument) { | |||
| std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl; | |||
| return false; | |||
| } else { | |||
| *float_default_ = parsed_value; | |||
| } | |||
| *float_default_ = parsed_value; | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| @@ -159,10 +157,11 @@ Options::Options() : args_(nullptr) { CreateOptions(); } | |||
| void Options::CreateOptions() { | |||
| args_ = std::make_shared<Arguments>(); | |||
| std::vector<Option> options = { | |||
| Option("port", &args_->grpc_port, "Port to listen on for gRPC API, default is 5500"), | |||
| Option("model_name", &args_->model_name, "model name "), | |||
| Option("model_path", &args_->model_path, "the path of the model files"), | |||
| Option("device_id", &args_->device_id, "the device id, default is 0"), | |||
| Option("port", &args_->grpc_port, | |||
| "[Optional] Port to listen on for gRPC API, default is 5500, range from 1 to 65535"), | |||
| Option("model_name", &args_->model_name, "[Required] model name "), | |||
| Option("model_path", &args_->model_path, "[Required] the path of the model files"), | |||
| Option("device_id", &args_->device_id, "[Optional] the device id, default is 0, range from 0 to 7"), | |||
| }; | |||
| options_ = options; | |||
| } | |||
| @@ -176,6 +175,14 @@ bool Options::CheckOptions() { | |||
| std::cout << "device_type only support Ascend right now" << std::endl; | |||
| return false; | |||
| } | |||
| if (args_->device_id > 7) { | |||
| std::cout << "the device_id should be in [0~7]" << std::endl; | |||
| return false; | |||
| } | |||
| if (args_->grpc_port < 1 || args_->grpc_port > 65535) { | |||
| std::cout << "the port should be in [1~65535]" << std::endl; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -238,6 +245,5 @@ void Options::Usage() { | |||
| << option.usage_ << std::endl; | |||
| } | |||
| } | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -22,7 +22,6 @@ | |||
| namespace mindspore { | |||
| namespace serving { | |||
| struct Arguments { | |||
| int32_t grpc_port = 5500; | |||
| std::string grpc_socket_path; | |||
| @@ -40,6 +39,7 @@ class Option { | |||
| Option(const std::string &name, bool *default_point, const std::string &usage); | |||
| Option(const std::string &name, std::string *default_point, const std::string &usage); | |||
| Option(const std::string &name, float *default_point, const std::string &usage); | |||
| ~Option() = default; | |||
| private: | |||
| friend class Options; | |||
| @@ -77,7 +77,6 @@ class Options { | |||
| std::vector<Option> options_; | |||
| std::shared_ptr<Arguments> args_; | |||
| }; | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -19,7 +19,6 @@ | |||
| namespace mindspore { | |||
| namespace serving { | |||
| MindSporeModel::MindSporeModel(const std::string &model_name, const std::string &model_path, | |||
| const std::string &model_version, const time_t &last_update_time) | |||
| : model_name_(model_name), | |||
| @@ -25,7 +25,6 @@ | |||
| namespace mindspore { | |||
| namespace serving { | |||
| volatile bool stop_poll = false; | |||
| std::string GetVersionFromPath(const std::string &path) { | |||
| @@ -102,10 +101,10 @@ Status VersionController::CreateInitModels() { | |||
| } | |||
| std::vector<std::string> SubDirs = GetAllSubDirs(models_path_); | |||
| if (version_control_strategy_ == kLastest) { | |||
| auto path = SubDirs.empty() ? models_path_ : SubDirs.back(); | |||
| std::string model_version = GetVersionFromPath(path); | |||
| time_t last_update_time = GetModifyTime(path); | |||
| MindSporeModelPtr model_ptr = std::make_shared<MindSporeModel>(model_name_, path, model_version, last_update_time); | |||
| std::string model_version = GetVersionFromPath(models_path_); | |||
| time_t last_update_time = GetModifyTime(models_path_); | |||
| MindSporeModelPtr model_ptr = | |||
| std::make_shared<MindSporeModel>(model_name_, models_path_, model_version, last_update_time); | |||
| valid_models_.emplace_back(model_ptr); | |||
| } else { | |||
| for (auto &dir : SubDirs) { | |||
| @@ -119,8 +118,8 @@ Status VersionController::CreateInitModels() { | |||
| MS_LOG(ERROR) << "There is no valid model for serving"; | |||
| return FAILED; | |||
| } | |||
| Session::Instance().Warmup(valid_models_.back()); | |||
| return SUCCESS; | |||
| auto ret = Session::Instance().Warmup(valid_models_.back()); | |||
| return ret; | |||
| } | |||
| void VersionController::StartPollModelPeriodic() { | |||
| @@ -129,6 +128,5 @@ void VersionController::StartPollModelPeriodic() { | |||
| } | |||
| void VersionController::StopPollModelPeriodic() {} | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -64,7 +64,6 @@ class PeriodicFunction { | |||
| VersionController::VersionControllerStrategy version_control_strategy_; | |||
| std::vector<MindSporeModelPtr> valid_models_; | |||
| }; | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -214,6 +214,7 @@ PredictRequest ReadBertInput() { | |||
| class MSClient { | |||
| public: | |||
| explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {} | |||
| ~MSClient() = default; | |||
| std::string Predict(const std::string &type) { | |||
| // Data we are sending to the server. | |||
| @@ -310,7 +311,6 @@ int main(int argc, char **argv) { | |||
| type = "add"; | |||
| } | |||
| } | |||
| } else { | |||
| target_str = "localhost:5500"; | |||
| type = "add"; | |||
| @@ -81,7 +81,7 @@ function checkopts() | |||
| checkopts "$@" | |||
| # switch to project root path, which contains clang-format config file '.clang-format' | |||
| cd "${SCRIPTS_PATH}/.." || exit 1 | |||
| cd "${SCRIPTS_PATH}/../.." || exit 1 | |||
| FMT_FILE_LIST='__format_files_list__' | |||
| @@ -161,6 +161,7 @@ setup( | |||
| description='MindSpore is a new open source deep learning training/inference ' | |||
| 'framework that could be used for mobile, edge and cloud scenarios.', | |||
| long_description="\n\n".join([readme, release]), | |||
| long_description_content_type="text/markdown", | |||
| packages=find_packages(), | |||
| package_data=package_data, | |||
| include_package_data=True, | |||
| @@ -190,9 +190,9 @@ TEST_F(MindDataTestBPlusTree, Test3) { | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| uint64_t min = ai.min_key(); | |||
| uint64_t max = ai.max_key(); | |||
| EXPECT_EQ(min, 1); | |||
| EXPECT_EQ(max, 4); | |||
| auto r = ai.Search(3); | |||
| EXPECT_EQ(min, 0); | |||
| EXPECT_EQ(max, 3); | |||
| auto r = ai.Search(2); | |||
| auto &it = r.first; | |||
| EXPECT_EQ(it.value(), "b"); | |||
| MS_LOG(INFO) << "Dump all the values using [] operator."; | |||
| @@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common { | |||
| }; | |||
| void SetUp() { | |||
| elim_Z = MakeSubstitution(irpass::AddByZero(), "elim_Z", prim::kPrimScalarAdd); | |||
| elim_R = MakeSubstitution(irpass::PrimEliminater(R), "elim_R", R); | |||
| idempotent_P = MakeSubstitution(IdempotentEliminater(), "idempotent_P", P); | |||
| Qct_to_P = MakeSubstitution(QctToP(), "Qct_to_P", Q); | |||
| elim_Z = MakeSubstitution(std::make_shared<irpass::AddByZero>(), "elim_Z", prim::kPrimScalarAdd); | |||
| elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R); | |||
| idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P); | |||
| Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q); | |||
| } | |||
| bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) { | |||
| @@ -327,6 +327,9 @@ TEST_F(TestStepParallel, CreatOpInstance) { | |||
| } else if (name == "instance_name") { | |||
| parse::ConvertData(py::cast<py::object>(item.second), &converted_ret); | |||
| ASSERT_EQ(converted_ret->ToString(), "test"); | |||
| } else if (name == "index") { | |||
| parse::ConvertData(py::cast<py::object>(item.second), &converted_ret); | |||
| ASSERT_EQ(converted_ret->ToString(), "0"); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Test failed"; | |||
| } | |||
| @@ -4,6 +4,7 @@ | |||
| "numParallelWorkers": 4, | |||
| "workerConnectorSize": 16, | |||
| "opConnectorSize": 16, | |||
| "seed": 5489 | |||
| "seed": 5489, | |||
| "monitor_sampling_interval": 15 | |||
| } | |||