Browse Source

Merge branch 'master' of gitee.com:mindspore/mindspore

tags/v0.6.0-beta
yanghaitao 5 years ago
parent
commit
5a886794da
100 changed files with 2866 additions and 579 deletions
  1. +1
    -1
      akg
  2. +7
    -6
      mindspore/ccsrc/CMakeLists.txt
  3. +1
    -2
      mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc
  4. +1
    -2
      mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc
  5. +1
    -2
      mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc
  6. +6
    -0
      mindspore/ccsrc/dataset/util/CMakeLists.txt
  7. +87
    -0
      mindspore/ccsrc/dataset/util/allocator.h
  8. +1
    -1
      mindspore/ccsrc/dataset/util/auto_index.h
  9. +388
    -0
      mindspore/ccsrc/dataset/util/buddy.cc
  10. +133
    -0
      mindspore/ccsrc/dataset/util/buddy.h
  11. +202
    -0
      mindspore/ccsrc/dataset/util/cache_pool.cc
  12. +139
    -0
      mindspore/ccsrc/dataset/util/cache_pool.h
  13. +18
    -0
      mindspore/ccsrc/dataset/util/list.h
  14. +0
    -14
      mindspore/ccsrc/dataset/util/memory_pool.h
  15. +115
    -3
      mindspore/ccsrc/dataset/util/path.cc
  16. +14
    -0
      mindspore/ccsrc/dataset/util/path.h
  17. +41
    -0
      mindspore/ccsrc/dataset/util/semaphore.cc
  18. +54
    -0
      mindspore/ccsrc/dataset/util/semaphore.h
  19. +38
    -0
      mindspore/ccsrc/dataset/util/slice.cc
  20. +122
    -0
      mindspore/ccsrc/dataset/util/slice.h
  21. +164
    -0
      mindspore/ccsrc/dataset/util/storage_container.cc
  22. +79
    -0
      mindspore/ccsrc/dataset/util/storage_container.h
  23. +167
    -0
      mindspore/ccsrc/dataset/util/storage_manager.cc
  24. +76
    -0
      mindspore/ccsrc/dataset/util/storage_manager.h
  25. +7
    -0
      mindspore/ccsrc/dataset/util/system_pool.h
  26. +21
    -6
      mindspore/ccsrc/device/kernel_runtime.cc
  27. +2
    -2
      mindspore/ccsrc/device/kernel_runtime.h
  28. +11
    -1
      mindspore/ccsrc/ir/optimizer_caller.h
  29. +7
    -0
      mindspore/ccsrc/kernel/kernel_query.cc
  30. +53
    -27
      mindspore/ccsrc/optimizer/cse.cc
  31. +1
    -1
      mindspore/ccsrc/optimizer/cse.h
  32. +89
    -75
      mindspore/ccsrc/optimizer/irpass.cc
  33. +28
    -31
      mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
  34. +3
    -3
      mindspore/ccsrc/optimizer/irpass/cast_eliminate.h
  35. +16
    -14
      mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h
  36. +15
    -12
      mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h
  37. +16
    -17
      mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h
  38. +2
    -2
      mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
  39. +6
    -5
      mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h
  40. +18
    -18
      mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h
  41. +9
    -10
      mindspore/ccsrc/optimizer/opt.cc
  42. +9
    -15
      mindspore/ccsrc/optimizer/opt.h
  43. +2
    -2
      mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
  44. +1
    -0
      mindspore/ccsrc/parallel/context.cc
  45. +6
    -0
      mindspore/ccsrc/parallel/context.h
  46. +4
    -0
      mindspore/ccsrc/pipeline/init.cc
  47. +1
    -1
      mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc
  48. +1
    -1
      mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h
  49. +1
    -0
      mindspore/ccsrc/pybind_api/export_flags.cc
  50. +1
    -1
      mindspore/ccsrc/pybind_api/export_flags.h
  51. +32
    -8
      mindspore/ccsrc/session/ascend_control_parser.cc
  52. +3
    -2
      mindspore/ccsrc/session/ascend_control_parser.h
  53. +56
    -2
      mindspore/ccsrc/session/kernel_graph.cc
  54. +73
    -39
      mindspore/ccsrc/session/session.cc
  55. +10
    -2
      mindspore/ccsrc/session/session_basic.cc
  56. +3
    -1
      mindspore/ccsrc/transform/convert.cc
  57. +21
    -11
      mindspore/ccsrc/utils/log_adapter.cc
  58. +1
    -0
      mindspore/ccsrc/utils/utils.h
  59. +8
    -0
      mindspore/common/tensor.py
  60. +3
    -3
      mindspore/dataset/engine/datasets.py
  61. +3
    -3
      mindspore/dataset/transforms/vision/c_transforms.py
  62. +47
    -36
      mindspore/nn/optim/adam.py
  63. +72
    -70
      mindspore/nn/optim/lamb.py
  64. +94
    -4
      mindspore/nn/optim/optimizer.py
  65. +3
    -1
      mindspore/nn/wrap/cell_wrapper.py
  66. +1
    -0
      mindspore/ops/_op_impl/akg/__init__.py
  67. +73
    -0
      mindspore/ops/_op_impl/akg/batchmatmul.py
  68. +2
    -20
      mindspore/ops/_op_impl/tbe/confusion_transpose_d.py
  69. +16
    -0
      mindspore/ops/composite/multitype_ops/setitem_impl.py
  70. +1
    -0
      mindspore/ops/operations/comm_ops.py
  71. +1
    -7
      mindspore/ops/operations/debug_ops.py
  72. +4
    -2
      mindspore/ops/operations/math_ops.py
  73. +1
    -2
      mindspore/ops/operations/other_ops.py
  74. +25
    -3
      mindspore/parallel/_auto_parallel_context.py
  75. +20
    -4
      mindspore/train/callback/_summary_collector.py
  76. +14
    -18
      model_zoo/faster_rcnn/src/dataset.py
  77. +3
    -1
      model_zoo/vgg16/src/config.py
  78. +16
    -10
      model_zoo/vgg16/train.py
  79. +28
    -13
      serving/core/server.cc
  80. +2
    -3
      serving/core/util/file_system_operation.cc
  81. +23
    -17
      serving/core/util/option_parser.cc
  82. +1
    -2
      serving/core/util/option_parser.h
  83. +0
    -1
      serving/core/version_control/model.cc
  84. +6
    -8
      serving/core/version_control/version_controller.cc
  85. +0
    -1
      serving/core/version_control/version_controller.h
  86. +1
    -1
      serving/cpp_example/ms_client.cc
  87. +1
    -1
      serving/scripts/format_source_code.sh
  88. +1
    -0
      setup.py
  89. +3
    -3
      tests/ut/cpp/dataset/btree_test.cc
  90. +4
    -4
      tests/ut/cpp/optimizer/opt_test.cc
  91. +3
    -0
      tests/ut/cpp/parallel/step_parallel_test.cc
  92. +2
    -1
      tests/ut/data/dataset/declient.cfg
  93. BIN
      tests/ut/data/dataset/golden/bounding_box_augment_crop_c_result.npz
  94. BIN
      tests/ut/data/dataset/golden/bounding_box_augment_rotation_c_result.npz
  95. BIN
      tests/ut/data/dataset/golden/bounding_box_augment_valid_edge_c_result.npz
  96. BIN
      tests/ut/data/dataset/golden/bounding_box_augment_valid_ratio_c_result.npz
  97. BIN
      tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz
  98. BIN
      tests/ut/data/dataset/golden/random_horizontal_flip_with_bbox_01_c_result.npz
  99. BIN
      tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_result.npz
  100. BIN
      tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz

+ 1
- 1
akg

@@ -1 +1 @@
Subproject commit c460176523d039c8995f1d71089753725ebc0792
Subproject commit df57a6cf9450e347d1854687d1fe66a420ee3b35

+ 7
- 6
mindspore/ccsrc/CMakeLists.txt View File

@@ -277,10 +277,11 @@ endif ()

if (USE_GLOG)
target_link_libraries(inference PRIVATE mindspore::glog)
else()
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
target_link_options(inference PRIVATE -Wl,-init,mindspore_log_init)
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
set_target_properties(inference PROPERTIES MACOSX_RPATH ON)
endif ()
endif()

if (CMAKE_SYSTEM_NAME MATCHES "Linux")
target_link_options(inference PRIVATE -Wl,-init,common_log_init)
elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
set_target_properties(inference PROPERTIES MACOSX_RPATH ON)
endif ()


+ 1
- 2
mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc View File

@@ -30,8 +30,7 @@ Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow
BOUNDING_BOX_CHECK(input);
CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of input is abnormal");

(*output).push_back(nullptr); // init memory for return vector
(*output).push_back(nullptr);
output->resize(2);
(*output)[1] = std::move(input[1]); // move boxes over to output

size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor


+ 1
- 2
mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc View File

@@ -36,8 +36,7 @@ Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output)
int32_t padded_image_h;
int32_t padded_image_w;

(*output).push_back(nullptr);
(*output).push_back(nullptr);
output->resize(2);
(*output)[1] = std::move(input[1]); // since some boxes may be removed

bool crop_further = true; // Whether further cropping will be required or not, true unless required size matches


+ 1
- 2
mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc View File

@@ -45,8 +45,7 @@ Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow *
RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 1}, newBoxCorner_y));
}

(*output).push_back(nullptr);
(*output).push_back(nullptr);
output->resize(2);
(*output)[1] = std::move(input[1]);

return VerticalFlip(input[0], &(*output)[0]);


+ 6
- 0
mindspore/ccsrc/dataset/util/CMakeLists.txt View File

@@ -2,6 +2,8 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(utils OBJECT
arena.cc
buddy.cc
cache_pool.cc
circular_pool.cc
memory_pool.cc
cond_var.cc
@@ -11,7 +13,11 @@ add_library(utils OBJECT
service.cc
services.cc
lock.cc
semaphore.cc
status.cc
storage_container.cc
storage_manager.cc
slice.cc
path.cc
wait_post.cc
sig_handler.cc)

+ 87
- 0
mindspore/ccsrc/dataset/util/allocator.h View File

@@ -17,8 +17,10 @@
#define DATASET_UTIL_ALLOCATOR_H_

#include <cstdlib>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include "dataset/util/memory_pool.h"

namespace mindspore {
@@ -84,6 +86,91 @@ class Allocator {
private:
std::shared_ptr<MemoryPool> pool_;
};
/// \brief RAII wrapper around a dynamically sized array obtained from a custom
/// allocator. Like std::lock_guard, the resource (here: memory) is released
/// automatically when the object goes out of scope.
/// \tparam T Type of object to be allocated
/// \tparam C Allocator type. Defaults to std::allocator<T>
template <typename T, typename C = std::allocator<T>>
class MemGuard {
public:
using allocator = C;
// Starts empty; memory is obtained later via allocate().
MemGuard() : n_(0) {}
explicit MemGuard(allocator a) : n_(0), alloc_(a) {}
// There is no copy constructor nor assignment operator because the memory is solely owned by this object.
MemGuard(const MemGuard &) = delete;
MemGuard &operator=(const MemGuard &) = delete;
// On the other hand, We can support move constructor
MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {}
MemGuard &operator=(MemGuard &&lhs) noexcept {
if (this != &lhs) {
// Release our own buffer before stealing the other object's.
this->deallocate();
n_ = lhs.n_;
alloc_ = std::move(lhs.alloc_);
ptr_ = std::move(lhs.ptr_);
}
return *this;
}
/// \brief Explicitly deallocate the memory if allocated
/// Destroys each element (unless T is arithmetic) and returns the buffer to
/// the allocator. The unique_ptr is release()d first so its own deleter never
/// runs on allocator-owned memory.
/// NOTE(review): ptr_'s deleter type is std::function<void(T*)> but allocate()
/// assigns a unique_ptr with the default delete[] deleter; that deleter would
/// be wrong for allocator memory if it ever fired. It never does in the
/// current code paths (deallocate() always releases first) — worth verifying
/// if this class is extended.
void deallocate() {
if (ptr_) {
auto *p = ptr_.release();
if (!std::is_arithmetic<T>::value && std::is_destructible<T>::value) {
for (auto i = 0; i < n_; ++i) {
p[i].~T();
}
}
alloc_.deallocate(p, n_);
n_ = 0;
}
}
/// \brief Allocate memory (with emplace feature). Previous one will be released. If size is 0, no new memory is
/// allocated.
/// \param n Number of objects of type T to be allocated
/// \tparam Args Extra arguments pass to the constructor of T
/// \return OK on success; kOutOfMemory on std::bad_alloc; unexpected error if
/// a constructor throws anything else.
template <typename... Args>
Status allocate(size_t n, Args &&... args) noexcept {
try {
deallocate();
if (n > 0) {
T *data = alloc_.allocate(n);
// Arithmetic types are left uninitialized deliberately (raw buffer use).
if (!std::is_arithmetic<T>::value) {
for (auto i = 0; i < n; i++) {
std::allocator_traits<C>::construct(alloc_, &(data[i]), std::forward<Args>(args)...);
}
}
ptr_ = std::unique_ptr<T[]>(data);
n_ = n;
}
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
} catch (std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
return Status::OK();
}
~MemGuard() noexcept { deallocate(); }
/// \brief Getter function
/// \return The pointer to the memory allocated
T *GetPointer() const { return ptr_.get(); }
/// \brief Getter function
/// \return The pointer to the memory allocated
T *GetMutablePointer() { return ptr_.get(); }
/// \brief Overload [] operator to access a particular element
/// Note: returns a POINTER to the x-th element, not a reference.
/// \param x index to the element. Must be less than number of element allocated.
/// \return pointer to the x-th element
T *operator[](size_t x) { return GetMutablePointer() + x; }
/// \brief Overload [] operator to access a particular element
/// \param x index to the element. Must be less than number of element allocated.
/// \return pointer to the x-th element
T *operator[](size_t x) const { return GetPointer() + x; }
/// \brief Return how many bytes are allocated in total
/// \return Number of bytes allocated in total
size_t GetSizeInBytes() const { return n_ * sizeof(T); }

private:
allocator alloc_;      // copy of the allocator used for allocate/deallocate
std::unique_ptr<T[], std::function<void(T *)>> ptr_;  // owned buffer (see NOTE in deallocate)
size_t n_;             // number of elements currently allocated
};
} // namespace dataset
} // namespace mindspore



+ 1
- 1
mindspore/ccsrc/dataset/util/auto_index.h View File

@@ -91,7 +91,7 @@ class AutoIndexObj : public BPlusTree<int64_t, T, A> {
}

private:
static constexpr key_type kMinKey = 1;
static constexpr key_type kMinKey = 0;
std::atomic<key_type> inx_;
};
} // namespace dataset


+ 388
- 0
mindspore/ccsrc/dataset/util/buddy.cc View File

@@ -0,0 +1,388 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/buddy.h"
#include <iomanip>
#include <stdexcept>
#include "dataset/util/de_error.h"
#include "dataset/util/memory_pool.h"
#include "dataset/util/system_pool.h"
#include "./securec.h"

// Small named wrappers over the raw bitwise operators. Naming the operation
// makes the buddy-allocator bit manipulation below easier to audit.
// All are pure; constexpr/noexcept lets the compiler fold them at compile time
// (the file already relies on C++11 features, so constexpr is safe here).
inline constexpr uint64_t BitLeftShift(uint64_t v, uint64_t n) noexcept { return (v << n); }

inline constexpr uint64_t BitRightShift(uint64_t v, uint64_t n) noexcept { return (v >> n); }

inline constexpr uint64_t BitOr(uint64_t rhs, uint64_t lhs) noexcept { return rhs | lhs; }

// Exclusive-or ("Ex"); used to derive a block's buddy address.
inline constexpr uint64_t BitEx(uint64_t rhs, uint64_t lhs) noexcept { return rhs ^ lhs; }

inline constexpr uint64_t BitAnd(uint64_t rhs, uint64_t lhs) noexcept { return rhs & lhs; }

namespace mindspore {
namespace dataset {
// Validate configuration and build the bookkeeping state for the buddy space.
// A single heap allocation is carved into three consecutive regions:
//   hint_  : per-level starting address hint for the free-block scan
//   count_ : per-level count of free blocks
//   map_   : packed state map, 2 bits per minimum-size unit
// The whole space starts as a single free block at the top level.
Status BuddySpace::Init() {
if (log_min_ < 0) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"log_min must be positive : " + std::to_string(log_min_));
}
if (num_lvl_ < 3 || num_lvl_ > 18) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"num_lvl must be between 3 and 18 : " + std::to_string(num_lvl_));
}
// Smallest and largest allocatable sizes in bytes.
min_ = BitLeftShift(1, log_min_);
max_ = BitLeftShift(1, log_min_ + num_lvl_ - 1);
size_t offset_1 = sizeof(rel_addr_t) * num_lvl_;
size_t offset_2 = sizeof(int) * num_lvl_ + offset_1;
size_t offset_3 = sizeof(char) * BitLeftShift(1, num_lvl_ - 3) + offset_2;
// One allocation holds all three arrays back to back.
// NOTE(review): the last DeMalloc arg presumably requests zero-fill — verify.
RETURN_IF_NOT_OK(DeMalloc(offset_3, &ptr_, true));
hint_ = reinterpret_cast<rel_addr_t *>(ptr_);
count_ = reinterpret_cast<int *>((reinterpret_cast<char *>(ptr_) + offset_1));
map_ = reinterpret_cast<char *>(ptr_) + offset_2;
// Seed: one free block spanning the entire space at the highest level.
count_[num_lvl_ - 1] = 1;
map_[0] = BitOr(MORE_BIT, num_lvl_ - 3);
return Status::OK();
}

// Allocate sz bytes. Thread-safe. On success fills *desc (required later by
// Free) and returns the starting byte address via *p. kNoSpace is returned
// when the space cannot satisfy the request; per the message, callers should
// treat that as a benign, expected condition rather than an error.
Status BuddySpace::Alloc(const uint64_t sz, BSpaceDescriptor *desc, addr_t *p) noexcept {
std::lock_guard<std::mutex> lock(mutex_);
addr_t addr = AllocNoLock(sz, desc);
if (addr != NOSPACE) {
*p = addr;
return Status::OK();
} else {
return Status(StatusCode::kNoSpace, "BuddySpace full. Not an error. Please ignore.");
}
}

// Core allocation path; caller must hold mutex_. Converts the byte size into
// a count of minimum-size units, grabs a buddy segment, and records the
// request in *desc so FreeNoLock can undo it precisely later.
addr_t BuddySpace::AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept {
DS_ASSERT(sz <= max_);
uint32_t reqSize = SizeToBlock(sz);
rel_addr_t rel_addr = AllocBuddySeg(reqSize);
if (rel_addr != static_cast<rel_addr_t>(NOSPACE)) {
(void)memset_s(desc, sizeof(BSpaceDescriptor), 0, sizeof(BSpaceDescriptor));
// 0xDEADBEEF acts as a signature so FreeNoLock can sanity-check descriptors.
desc->sig = static_cast<int>(0xDEADBEEF);
desc->addr = rel_addr;
desc->req_size = reqSize;
desc->blk_size = NextPowerOf2(reqSize);
// Convert the relative (unit) address back into a byte offset.
return static_cast<addr_t>(rel_addr * min_);
} else {
return NOSPACE;
}
}

// Core free path; caller must hold mutex_. Releases the segment described by
// a descriptor previously filled in by AllocNoLock.
void BuddySpace::FreeNoLock(const BSpaceDescriptor *desc) {
DS_ASSERT(desc->sig == 0XDEADBEEF);
rel_addr_t rel_addr = desc->addr;
size_t blk_size = desc->blk_size;
size_t req_size = desc->req_size;
FreeBuddySeg(rel_addr, blk_size, req_size);
}

// Thread-safe wrapper: serialize against concurrent Alloc/Free, then hand the
// descriptor to the unlocked implementation.
void BuddySpace::Free(const BSpaceDescriptor *desc) {
  std::lock_guard<std::mutex> guard(mutex_);
  FreeNoLock(desc);
}

// Dump a human-readable summary of the buddy space: unit size, capacity,
// per-level free counts, and a walk of the allocation map.
// NOTE: reads shared state without taking s.mutex_ — debugging aid only.
std::ostream &operator<<(std::ostream &os, const BuddySpace &s) {
  os << "1 unit = " << s.GetMinSize() << "\n"
     << "Size of buddy space = " << s.GetMaxSize() << "\n"
     << "Number of levels = " << s.num_lvl_ << "\n\n"
     << "Percent free = " << s.PercentFree() << "\n"
     << "Dumping count array : "
     << "\n";
  for (int i = 0; i < s.num_lvl_; i++) {
    os << "[" << i << "] = " << s.count_[i] << " ";
    if (((i + 1) % 4) == 0) {
      os << "\n";
    }
  }
  os << "\n";
  os << "Dumping allocation info:"
     << "\n";
  auto max_addr = static_cast<rel_addr_t>(BitLeftShift(1, s.num_lvl_ - 1));
  rel_addr_t addr = 0;
  // Walk the map segment by segment, advancing by each segment's size.
  while (addr < max_addr) {
    size_t sz = 0;
    BuddySpace::STATE st;
    s.GetBuddySegState(addr, &sz, &st);
    if (sz == 0) {
      // Defensive: an empty/corrupted segment reports no length; bail out
      // rather than loop forever on a stuck address.
      os << "Address : " << addr << " has no segment length; stopping dump\n";
      break;
    }
    os << "Address : " << std::left << std::setw(8) << addr << " Size : " << std::setw(8) << sz << " State : "
       // Fixed typo in the original output string ("Unkonwn" -> "Unknown").
       << ((st == BuddySpace::STATE::kAlloc) ? "ALLOC" : ((st == BuddySpace::STATE::kFree) ? "FREE" : "Unknown"))
       << "\n";
    addr += sz;
  }
  return os;
}

// Decode the 2-bit-per-unit state map at a given relative (unit) address.
// Outputs the segment size in units (*rel_sz) and its state (*st). The two
// state bits encode: ONE_BIT = size 1, TWO_BIT = size 2, MORE_BIT = size is a
// power of two >= 4 (the low nibble then stores log2(size) - 2). If none of
// the size bits are set the unit is the interior of a larger segment (kEmpty)
// and *rel_sz is left untouched.
void BuddySpace::GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const {
char byte;
int pos;
int offset;
uint64_t val = 0;
int shift;
// Four units per map byte; 2 bits each.
pos = BitRightShift(rel_addr, 2);
offset = rel_addr % 4;
shift = offset * 2;
byte = map_[pos];
// Shift the unit's 2-bit field up to the top nibble positions used below.
switch (offset) {
case 0:
val = byte;
break;
case 1:
case 3:
if (offset == 1) {
val = BitLeftShift(BitAnd(byte, 0x30), shift);
} else {
val = BitLeftShift(BitAnd(byte, 0x03), shift);
}
break;
case 2:
val = BitLeftShift(BitAnd(byte, 0x0F), shift);
break;
}
if (BitAnd(val, ONE_BIT)) {
*rel_sz = 1;
} else if (BitAnd(val, TWO_BIT)) {
*rel_sz = 2;
} else if (BitAnd(val, MORE_BIT)) {
log_t lg = BitAnd(val, 0x0F);
*rel_sz = BitLeftShift(1, lg + 2);
} else {
*st = STATE::kEmpty;
return;
}
*st = BitAnd(val, ALLOC_BIT) ? STATE::kAlloc : STATE::kFree;
}

// Encode a segment's size and state into the packed map, and keep the
// per-level free counters and scan hints in sync.
void BuddySpace::SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st) {
int clr;
int mask;
int pos;
int offset;
int val = 0;
int shift;
auto log_sz = static_cast<log_t>(Log2(rel_sz));
pos = BitRightShift(rel_addr, 2);
offset = rel_addr % 4;
shift = offset * 2;
// Build the bit pattern for this size (see GetBuddySegState for the scheme).
if (rel_sz == 1) {
val = ONE_BIT;
mask = 0xC0;
} else if (rel_sz == 2) {
val = TWO_BIT;
mask = 0xF0;
} else {
val = BitOr(log_sz - 2, MORE_BIT);
mask = 0xFF;
}
if (st == STATE::kAlloc) {
val = BitOr(val, ALLOC_BIT);
} else if (st == STATE::kFree) {
val = BitAnd(val, ~(static_cast<uint64_t>(ALLOC_BIT)));
} else if (st == STATE::kEmpty) {
val = 0;
}
// Clear the unit's old bits, then OR in the new pattern.
clr = static_cast<int>(~(BitRightShift(mask, shift)));
map_[pos] = static_cast<char>(BitAnd(map_[pos], clr));
map_[pos] = static_cast<char>(BitOr(map_[pos], BitRightShift(val, shift)));
// Book-keeping: allocation consumes a free block; freeing adds one back and
// may improve the level's scan hint.
if (st == STATE::kAlloc) {
count_[log_sz]--;
} else if (st == STATE::kFree) {
count_[log_sz]++;
if (rel_addr < hint_[log_sz]) {
hint_[log_sz] = rel_addr;
}
}
}

// Coalesce a freed block with its buddy repeatedly: while the buddy (address
// differs only in the blk_sz bit) is free and the same size, merge the pair
// into one block of twice the size, fixing counters and hints as we go.
void BuddySpace::JoinBuddySeg(rel_addr_t addr, size_t blk_sz) {
while (blk_sz < BitLeftShift(1, num_lvl_)) {
// The buddy of a block toggles the bit equal to the block size.
rel_addr_t buddy = BitEx(addr, blk_sz);
size_t sz = 0;
STATE st;
GetBuddySegState(buddy, &sz, &st);
if (st == STATE::kFree && sz == blk_sz) {
auto log_sz = static_cast<log_t>(Log2(blk_sz));
rel_addr_t left = (buddy < addr) ? buddy : addr;
rel_addr_t right = left + blk_sz;
DS_ASSERT(count_[log_sz] >= 2);
// Both halves leave this level; the merged block joins the next level up.
count_[log_sz] -= 2;
SetBuddySegState(right, blk_sz, STATE::kEmpty);
SetBuddySegState(left, BitLeftShift(blk_sz, 1), STATE::kFree);
// Hints pointing at the vanished right half now refer to the merged block.
for (int i = 0; i < log_sz; i++) {
if (hint_[i] == right) {
hint_[i] = left;
}
}
addr = left;
blk_sz <<= 1u;
} else {
break;
}
}
}

// Split a free block of blk_sz units down to satisfy a request of ask_sz
// units: halve repeatedly, allocating the halves that are covered by the
// request and leaving the rest free.
void BuddySpace::TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) {
DS_ASSERT(ask_sz < blk_sz);
uint32_t inx = Log2(blk_sz);
size_t remaining_sz = ask_sz;
for (int i = inx; i > 0; i--) {
size_t b_size = BitLeftShift(1, i);
size_t half_sz = BitRightShift(b_size, 1);
// The block at this level is consumed and replaced by its two halves.
count_[i]--;
SetBuddySegState(addr, half_sz, STATE::kFree);
SetBuddySegState(addr + half_sz, half_sz, STATE::kFree);
if (remaining_sz >= half_sz) {
SetBuddySegState(addr, half_sz, STATE::kAlloc);
remaining_sz -= half_sz;
if (remaining_sz == 0) {
break;
}
addr += half_sz;
}
}
}

// Inverse of TrimBuddySeg: walk the same halving pattern, freeing the halves
// that were allocated, and coalesce (JoinBuddySeg) once the last piece of the
// original request has been returned.
void BuddySpace::UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) {
DS_ASSERT(ask_sz < blk_sz);
uint32_t inx = Log2(blk_sz);
size_t remaining_sz = ask_sz;
for (int i = inx; i > 0; i--) {
size_t b_size = BitLeftShift(1, i);
size_t half_sz = BitRightShift(b_size, 1);
if (remaining_sz >= half_sz) {
#ifdef DEBUG
{
size_t sz = 0;
STATE st;
GetBuddySegState(addr, &sz, &st);
DS_ASSERT(sz == half_sz && st == STATE::kAlloc);
}
#endif
SetBuddySegState(addr, half_sz, STATE::kFree);
remaining_sz -= half_sz;
if (remaining_sz == 0) {
JoinBuddySeg(addr, half_sz);
break;
}
addr += half_sz;
}
}
}

// Find and claim a segment of req_size units. Scans levels from the request's
// power-of-two size upward, starting each level's scan at its hint. An exact
// match is allocated in place; a larger block is trimmed down (TrimBuddySeg).
// Returns the relative unit address, or NOSPACE when nothing fits.
rel_addr_t BuddySpace::AllocBuddySeg(uint32_t req_size) noexcept {
uint32_t blk_size = NextPowerOf2(req_size);
int start_inx = static_cast<int>(Log2(blk_size));
bool found = false;
rel_addr_t ask_addr = 0;
auto max_addr = static_cast<rel_addr_t>(BitLeftShift(1, num_lvl_ - 1));
STATE st;
size_t sz = 0;
for (int i = start_inx; !found && i < num_lvl_; i++) {
DS_ASSERT(count_[i] >= 0);
// Skip levels with no free blocks at all.
if (count_[i] == 0) {
continue;
}
auto blk_sz = static_cast<size_t>(BitLeftShift(1, i));
ask_addr = hint_[i];
// Linear scan over segments, hopping by each segment's own size.
while (ask_addr < max_addr && !found) {
GetBuddySegState(ask_addr, &sz, &st);
if (st == STATE::kFree && sz == blk_sz) {
found = true;
} else {
DS_ASSERT(st != STATE::kEmpty);
ask_addr += ((sz > blk_sz) ? sz : blk_sz);
}
}
}
if (found) {
if (sz > req_size) {
// Oversized block: split it and allocate only what was asked for.
TrimBuddySeg(ask_addr, sz, req_size);
} else {
SetBuddySegState(ask_addr, sz, STATE::kAlloc);
hint_[start_inx] = ask_addr;
}
return ask_addr;
} else {
return static_cast<rel_addr_t>(NOSPACE);
}
}

// Return a segment. Exact power-of-two allocations are freed and coalesced
// directly; trimmed allocations are unwound through UnTrimBuddySeg.
void BuddySpace::FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size) {
if (req_size == blk_size) {
#ifdef DEBUG
{
size_t sz = 0;
STATE st;
GetBuddySegState(addr, &sz, &st);
}
#endif
SetBuddySegState(addr, blk_size, STATE::kFree);
JoinBuddySeg(addr, blk_size);
} else {
UnTrimBuddySeg(addr, blk_size, req_size);
}
}

// Report how much of the space is free, as an integer percentage.
// The per-level counters are scanned without taking the lock, so the result
// is only a best-effort snapshot.
int BuddySpace::PercentFree() const {
  uint64_t free_units = 0;
  const uint64_t total_units = BitLeftShift(1, num_lvl_ - 1);
  for (int lvl = 0; lvl < num_lvl_; lvl++) {
    const int blocks = count_[lvl];
    if (blocks == 0) {
      continue;
    }
    // A block at level `lvl` spans 2^lvl minimum-size units.
    free_units += BitLeftShift(1, lvl) * blocks;
  }
  return static_cast<int>(static_cast<float>(free_units) / static_cast<float>(total_units) * 100);
}

// Private constructor: records the configuration only. All real setup (and
// the single backing allocation) happens in Init(), called by the factory.
BuddySpace::BuddySpace(int log_min, int num_lvl)
: hint_(nullptr),
count_(nullptr),
map_(nullptr),
log_min_(log_min),
num_lvl_(num_lvl),
min_(0),
max_(0),
ptr_(nullptr) {}

// Release the single backing allocation; hint_/count_/map_ are interior
// pointers into it, so they are just nulled out.
// NOTE(review): freed with free(), so Init's DeMalloc is presumably
// malloc-based — verify if DeMalloc's implementation changes.
BuddySpace::~BuddySpace() {
if (ptr_ != nullptr) {
free(ptr_);
}
hint_ = nullptr;
count_ = nullptr;
map_ = nullptr;
}

// Factory for BuddySpace (its constructor is private). Builds the object,
// runs Init(), and populates *out_bs only on success; a failed Init destroys
// the partially constructed instance and its status is returned to the caller.
Status BuddySpace::CreateBuddySpace(std::unique_ptr<BuddySpace> *out_bs, int log_min, int num_lvl) {
  auto *bs = new (std::nothrow) BuddySpace(log_min, num_lvl);
  if (bs == nullptr) {
    return Status(StatusCode::kOutOfMemory);
  }
  Status rc = bs->Init();
  if (!rc.IsOk()) {
    delete bs;
    return rc;
  }
  out_bs->reset(bs);
  return rc;
}
} // namespace dataset
} // namespace mindspore

+ 133
- 0
mindspore/ccsrc/dataset/util/buddy.h View File

@@ -0,0 +1,133 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_BUDDY_H_
#define DATASET_UTIL_BUDDY_H_

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
#include "dataset/util/status.h"

using addr_t = int64_t;
using rel_addr_t = int32_t;
using log_t = int;
#define ALLOC_BIT 0x80
#define ONE_BIT 0x40
#define TWO_BIT 0x20
#define MORE_BIT 0x10
#define NOSPACE ((addr_t)(-1))
namespace mindspore {
namespace dataset {
// Descriptor filled in by BuddySpace::Alloc and required by BuddySpace::Free.
// It records exactly what was requested so the free path can undo a trimmed
// (non-power-of-two) allocation precisely.
struct BSpaceDescriptor {
int32_t sig;        // signature (0xDEADBEEF) used to validate the descriptor
rel_addr_t addr;    // relative address, in minimum-size units
size_t req_size;    // requested size, in units
size_t blk_size;    // power-of-two block size the request was rounded up to
};

// A classic buddy allocator over a virtual address space of
// 2^(log_min + num_lvl - 1) bytes, with a minimum block size of 2^log_min
// bytes. Alloc/Free are thread-safe (guarded by an internal mutex).
class BuddySpace {
public:
// C++11 feature. Change STATE into a type safe class with
// the keyword. Don't take out the keyword 'class'
enum class STATE { kFree, kAlloc, kEmpty };

// Non-copyable: owns a raw backing allocation and a mutex.
BuddySpace(const BuddySpace &) = delete;

BuddySpace &operator=(const BuddySpace &) = delete;

virtual ~BuddySpace();

// Allocate sz bytes; fills *desc (keep it for Free) and the out address.
Status Alloc(uint64_t sz, BSpaceDescriptor *desc, addr_t *) noexcept;

// Release an allocation previously described by Alloc.
void Free(const BSpaceDescriptor *desc);

uint64_t GetMinSize() const { return min_; }

uint64_t GetMaxSize() const { return max_; }

// Best-effort free percentage (scanned without the lock).
int PercentFree() const;

friend std::ostream &operator<<(std::ostream &os, const BuddySpace &s);

// Smallest power of two >= n (n itself if n is already a power of two).
static uint64_t NextPowerOf2(uint64_t n) {
if (n <= 1) {
return 1;
}
n = n - 1;
// Clear low set bits until only the highest remains, then double it.
while (n & (n - 1)) {
n = n & (n - 1);
}
return n << 1;
}

// Floor of log2(n); returns 0 for n <= 1.
static uint32_t Log2(uint64_t n) {
uint32_t cnt = 0;
while (n >>= 1) {
cnt++;
}
return cnt;
}

// Factory: the only way to build a BuddySpace (constructor is private).
static Status CreateBuddySpace(std::unique_ptr<BuddySpace> *out_bs, int log_min = 15, int num_lvl = 18);

private:
rel_addr_t *hint_;   // per-level scan start hints (into the single allocation)
int *count_;         // per-level free-block counts
char *map_;          // packed 2-bits-per-unit segment state map
int log_min_;        // log2 of the minimum block size in bytes
int num_lvl_;        // number of buddy levels (3..18)
uint64_t min_;       // minimum block size in bytes (2^log_min_)
uint64_t max_;       // total space in bytes
void *ptr_;          // single backing allocation holding hint_/count_/map_
std::mutex mutex_;   // guards Alloc/Free

explicit BuddySpace(int log_min = 15, int num_lvl = 18);

Status Init();

addr_t AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept;

void FreeNoLock(const BSpaceDescriptor *desc);

// Round a byte size up to a whole number of minimum-size units.
uint32_t SizeToBlock(const uint64_t sz) const {
uint32_t reqSize = (sz / min_);
if (sz % min_) {
reqSize++;
}
return reqSize;
}

void GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const;

void SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st);

void JoinBuddySeg(rel_addr_t addr, size_t blk_sz);

void TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz);

void UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz);

rel_addr_t AllocBuddySeg(uint32_t req_size) noexcept;

void FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size);
};
} // namespace dataset
} // namespace mindspore

#endif // DATASET_UTIL_BUDDY_H_

+ 202
- 0
mindspore/ccsrc/dataset/util/cache_pool.cc View File

@@ -0,0 +1,202 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "common/utils.h"
#include "dataset/util/cache_pool.h"
#include "dataset/util/services.h"

namespace mindspore {
namespace dataset {
// Construct a pool that keeps buffers in memory via `alloc` and, when `root`
// is non-empty, may spill them to disk under a per-instance unique subfolder.
CachePool::CachePool(const value_allocator &alloc, const std::string &root)
: alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {}

// Service start hook: create the in-memory index and, when a disk root was
// configured, create the spill directory and start a StorageManager over it.
Status CachePool::DoServiceStart() {
tree_ = std::make_shared<data_index>();
// If we are given a disk path, set up the StorageManager
if (!root_.toString().empty()) {
Path spill = GetSpillPath();
RETURN_IF_NOT_OK(spill.CreateDirectories());
sm_ = std::make_shared<StorageManager>(spill);
RETURN_IF_NOT_OK(sm_->ServiceStart());
MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString());
}
return Status::OK();
}
// Service stop hook: shut down the StorageManager, release every in-memory
// buffer, then remove the spill directory and its contents. Cleanup is
// best-effort: the first error encountered (rc2) is remembered and returned,
// but later steps still run.
Status CachePool::DoServiceStop() {
Status rc;
Status rc2;
if (sm_ != nullptr) {
rc = sm_->ServiceStop();
if (rc.IsError()) {
rc2 = rc;
}
}
sm_.reset();
// Free all buffers still held in memory by the index.
for (auto &bl : *tree_) {
if (bl.ptr != nullptr) {
alloc_.deallocate(bl.ptr, bl.sz);
}
}
tree_.reset();
if (!root_.toString().empty()) {
Path spill = GetSpillPath();
// NOTE(review): assumes OpenDirectory never returns null here — verify;
// a null iterator would crash the hasNext() call below.
auto it = Path::DirIterator::OpenDirectory(&spill);
while (it->hasNext()) {
rc = it->next().Remove();
if (rc.IsError() && rc2.IsOk()) {
rc2 = rc;
}
}
rc = spill.Remove();
if (rc.IsError() && rc2.IsOk()) {
rc2 = rc;
}
}
return rc2;
}
// Stop the service on destruction; any error status is deliberately ignored.
CachePool::~CachePool() noexcept { (void)ServiceStop(); }
// Back up a buffer given as a vector of slices. The slices are consolidated
// into one contiguous in-memory block; if that allocation fails with
// bad_alloc and disk spilling is configured, the buffer is written to disk
// instead. On success *key receives the generated key used for Read().
Status CachePool::Insert(const std::vector<ReadableSlice> &buf, CachePool::key_type *key) {
DataLocator bl;
Status rc;
size_t sz = 0;
// We will consolidate all the slices into one piece.
for (auto &v : buf) {
sz += v.GetSize();
}
bl.sz = sz;
try {
bl.ptr = alloc_.allocate(sz);
// We will do a piecewise copy.
WritableSlice dest(bl.ptr, bl.sz);
size_t pos = 0;
for (auto &v : buf) {
WritableSlice out(dest, pos);
rc = WritableSlice::Copy(&out, v);
if (rc.IsError()) {
break;
}
pos += v.GetSize();
}
// A failed copy releases the block before propagating the error.
if (rc.IsError()) {
alloc_.deallocate(bl.ptr, sz);
bl.ptr = nullptr;
return rc;
}
} catch (std::bad_alloc &e) {
// Memory is full: fall back to disk if we have it, else report OOM.
if (sm_ != nullptr) {
RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf));
// We have an assumption 0 is not a valid key from the design of AutoIndexObj.
// Make sure it is not 0.
if (bl.storage_key == 0) {
RETURN_STATUS_UNEXPECTED("Key 0 is returned which is unexpected");
}
} else {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
// Record the locator in the index; undo the allocation if that fails.
rc = tree_->insert(bl, key);
if (rc.IsError() && bl.ptr != nullptr) {
alloc_.deallocate(bl.ptr, sz);
}
return rc;
}
// Restore a previously inserted buffer into *dest. Reads from memory when the
// buffer is cached there, otherwise from disk via the StorageManager. On
// success *bytesRead (if given) receives the buffer's recorded size.
Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const {
RETURN_UNEXPECTED_IF_NULL(dest);
auto r = tree_->Search(key);
if (r.second) {
auto &it = r.first;
if (it->ptr != nullptr) {
// In-memory copy path.
ReadableSlice src(it->ptr, it->sz);
RETURN_IF_NOT_OK(WritableSlice::Copy(dest, src));
} else if (sm_ != nullptr) {
// Disk path: read back and cross-check the length against the index.
size_t expectedLength = 0;
RETURN_IF_NOT_OK(sm_->Read(it->storage_key, dest, &expectedLength));
if (expectedLength != it->sz) {
MS_LOG(ERROR) << "Unexpected length. Read " << expectedLength << ". Expected " << it->sz << "."
<< " Internal key: " << key << "\n";
RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
}
}
if (bytesRead != nullptr) {
*bytesRead = it->sz;
}
} else {
RETURN_STATUS_UNEXPECTED("Key not found");
}
return Status::OK();
}
// Expose the allocator used for in-memory cache blocks.
const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; }

// The directory this pool spills into: <root>/<unique per-instance subfolder>.
Path CachePool::GetSpillPath() const { return Path(root_) / subfolder_; }
// Tally how many cached buffers currently live in memory versus on disk.
// The index is walked without a lock, so the counts are a snapshot only.
CachePool::CacheStat CachePool::GetStat() const {
  CacheStat stats{0};
  for (const auto &loc : *tree_) {
    if (loc.ptr == nullptr) {
      ++stats.num_disk_cached;
    } else {
      ++stats.num_mem_cached;
    }
  }
  return stats;
}
// Push an in-memory buffer out to disk and release its memory. The disk copy
// is written only once (storage_key == 0 means "not yet on disk", relying on
// 0 never being a valid key). Requires disk spilling to be configured.
Status CachePool::Spill(CachePool::DataLocator *dl) {
if (sm_ == nullptr) {
RETURN_STATUS_UNEXPECTED("No disk storage to spill");
}
RETURN_UNEXPECTED_IF_NULL(dl);
RETURN_UNEXPECTED_IF_NULL(dl->ptr);
if (dl->storage_key == 0) {
ReadableSlice data(dl->ptr, dl->sz);
RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data}));
}
alloc_.deallocate(dl->ptr, dl->sz);
dl->ptr = nullptr;
return Status::OK();
}
// Inverse of Spill: bring a spilled buffer back into memory. No-op when the
// buffer is already in memory. On a failed read the fresh allocation is
// released again so the locator stays consistent.
Status CachePool::Locate(CachePool::DataLocator *dl) {
RETURN_UNEXPECTED_IF_NULL(dl);
if (dl->ptr == nullptr) {
if (sm_ == nullptr) {
RETURN_STATUS_UNEXPECTED("No disk storage to locate the data");
}
try {
dl->ptr = alloc_.allocate(dl->sz);
WritableSlice dest(dl->ptr, dl->sz);
Status rc = Read(dl->storage_key, &dest);
if (rc.IsError()) {
alloc_.deallocate(dl->ptr, dl->sz);
dl->ptr = nullptr;
return rc;
}
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
return Status::OK();
}
// Return the recorded byte size of a cached buffer, or 0 for an unknown key.
size_t CachePool::GetSize(CachePool::key_type key) const {
  auto result = tree_->Search(key);
  return result.second ? result.first->sz : 0;
}
} // namespace dataset
} // namespace mindspore

+ 139
- 0
mindspore/ccsrc/dataset/util/cache_pool.h View File

@@ -0,0 +1,139 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_CACHE_POOL_H_
#define DATASET_UTIL_CACHE_POOL_H_

#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "dataset/util/allocator.h"
#include "dataset/util/service.h"
#include "dataset/util/slice.h"
#include "dataset/util/storage_manager.h"
#include "dataset/util/auto_index.h"

namespace mindspore {
namespace dataset {
/// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of
/// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to
/// disk (if a disk directory is provided). Every buffer insert will return a generated key which can be used to
/// restore the buffer.
/// \see ReadableSlice
class CachePool : public Service {
 public:
 using base_type = uint8_t;
 using pointer = base_type *;
 using const_pointer = const base_type *;
 using reference = base_type &;
 using const_reference = const base_type &;
 using value_allocator = Allocator<base_type>;

 // An internal class to locate the whereabouts of a backed up buffer which can be either in
 // memory (ptr != nullptr) or on disk (ptr == nullptr, addressed by storage_key).
 class DataLocator {
 public:
 DataLocator() : ptr(nullptr), sz(0), storage_key(0) {}
 ~DataLocator() = default;
 DataLocator(const DataLocator &other) = default;
 DataLocator &operator=(const DataLocator &other) = default;
 // Moving empties the source locator.
 DataLocator(DataLocator &&other) noexcept {
 ptr = other.ptr;
 sz = other.sz;
 storage_key = other.storage_key;
 other.ptr = nullptr;
 other.sz = 0;
 other.storage_key = 0;
 }
 DataLocator &operator=(DataLocator &&other) noexcept {
 if (&other != this) {
 ptr = other.ptr;
 sz = other.sz;
 storage_key = other.storage_key;
 other.ptr = nullptr;
 other.sz = 0;
 other.storage_key = 0;
 }
 return *this;
 }
 pointer ptr;  // in-memory copy of the buffer, or nullptr if spilled to disk
 size_t sz;  // length of the buffer in bytes
 StorageManager::key_type storage_key;  // disk key; 0 means never written to disk
 };

 using data_index = AutoIndexObj<DataLocator>;
 using key_type = data_index::key_type;
 using bl_alloc_type = typename value_allocator::template rebind<DataLocator>::other;

 /// \brief Simple statistics returned from CachePool like how many elements are cached in memory and
 /// how many elements are spilled to disk.
 struct CacheStat {
 int64_t num_mem_cached;  // entries that still have an in-memory copy
 int64_t num_disk_cached;  // entries resident only on disk
 };

 /// \brief Constructor
 /// \param alloc Allocator to allocate memory from
 /// \param root Optional disk folder to spill
 explicit CachePool(const value_allocator &alloc, const std::string &root = "");

 CachePool(const CachePool &) = delete;
 CachePool(CachePool &&) = delete;
 CachePool &operator=(const CachePool &) = delete;
 CachePool &operator=(CachePool &&) = delete;
 ~CachePool() noexcept;

 Status DoServiceStart() override;
 Status DoServiceStop() override;

 /// \brief Directory (root/subfolder) used for spilling.
 Path GetSpillPath() const;

 /// \brief Insert a sequence of ReadableSlice objects into the pool.
 /// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk.
 /// \param[in] buf A sequence of ReadableSlice objects.
 /// \param[out] key Generated key
 /// \return Error code
 Status Insert(const std::vector<ReadableSlice> &buf, key_type *key);
 /// \brief Restore a cached buffer (from memory or disk)
 /// \param[in] key A previous key returned from Insert
 /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice
 /// \param[out] bytesRead Optional. Number of bytes read.
 /// \return Error code
 Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const;

 /// \brief Move an entry's data from memory to disk and free the memory copy.
 Status Spill(DataLocator *dl);

 /// \brief Ensure an entry's data is in memory, reading it back from disk if needed.
 Status Locate(DataLocator *dl);

 /// \brief Size in bytes of the buffer cached under key; 0 if the key is unknown.
 size_t GetSize(key_type key) const;

 /// \brief Get statistics.
 /// \return CacheStat object
 CacheStat GetStat() const;

 const value_allocator &get_allocator() const;

 std::string MyName() const { return subfolder_; }

 private:
 value_allocator alloc_;  // allocator for the in-memory copies
 Path root_;  // spill root directory (empty string when constructed without one)
 const std::string subfolder_;  // subfolder under root_ used by this pool (also MyName())
 std::shared_ptr<StorageManager> sm_;  // disk backend; Spill/Locate error out when nullptr
 std::shared_ptr<data_index> tree_;  // key -> DataLocator index
};
} // namespace dataset
} // namespace mindspore
#endif

+ 18
- 0
mindspore/ccsrc/dataset/util/list.h View File

@@ -106,6 +106,24 @@ struct List {
++count;
}

// Insert elem2 before elem1 in the list. elem1 must already be linked in; if elem1 was the
// head, elem2 becomes the new head. (The tail is never affected: elem2 precedes elem1.)
virtual void InsertBefore(pointer elem1, pointer elem2) {
DS_ASSERT(elem1 != elem2);
Node<T> &elem1_node = elem1->*node;
Node<T> &elem2_node = elem2->*node;
// Wire elem2's links first, then repair its neighbours.
elem2_node.next = elem1;
elem2_node.prev = elem1_node.prev;
if (elem1_node.prev != nullptr) {
Node<T> &prev_node = elem1_node.prev->*node;
prev_node.next = elem2;
}
elem1_node.prev = elem2;
if (head == elem1) {
head = elem2;
}
++count;
}

// Remove an element in the list
virtual void Remove(pointer elem) noexcept {
Node<T> &elem_node = elem->*node;


+ 0
- 14
mindspore/ccsrc/dataset/util/memory_pool.h View File

@@ -44,20 +44,6 @@ class MemoryPool {
virtual ~MemoryPool() {}
};

// Used by unique_ptr
template <typename T>
class Deleter {
public:
explicit Deleter(std::shared_ptr<MemoryPool> &mp) : mp_(mp) {}

~Deleter() = default;

void operator()(T *ptr) const { mp_->Deallocate(ptr); }

private:
std::shared_ptr<MemoryPool> mp_;
};

Status DeMalloc(std::size_t s, void **p, bool);
} // namespace dataset
} // namespace mindspore


+ 115
- 3
mindspore/ccsrc/dataset/util/path.cc View File

@@ -16,6 +16,8 @@
#include "dataset/util/path.h"

#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include <new>
#include <sstream>
#include <utility>
@@ -26,7 +28,7 @@

namespace mindspore {
namespace dataset {
#ifdef _WIN32
#if defined(_WIN32) || defined(_WIN64)
char Path::separator_ = '\\';
#else
char Path::separator_ = '/';
@@ -132,7 +134,7 @@ Status Path::CreateDirectory() {
#if defined(_WIN32) || defined(_WIN64)
int rc = mkdir(common::SafeCStr(path_));
#else
int rc = mkdir(common::SafeCStr(path_), 0700);
int rc = mkdir(common::SafeCStr(path_), S_IRUSR | S_IWUSR | S_IXUSR);
#endif
if (rc) {
std::ostringstream oss;
@@ -182,6 +184,111 @@ Status Path::CreateDirectories() {
return Status::OK();
}

// Delete this path from the file system: rmdir for a directory, unlink for a file.
// A non-existent path is treated as success. NOTE: rmdir only removes empty directories.
Status Path::Remove() {
if (Exists()) {
if (IsDirectory()) {
errno_t err = rmdir(common::SafeCStr(path_));
if (err == -1) {
std::ostringstream oss;
oss << "Unable to delete directory " << path_ << ". Errno = " << errno;
RETURN_STATUS_UNEXPECTED(oss.str());
}
} else {
errno_t err = unlink(common::SafeCStr(path_));
if (err == -1) {
std::ostringstream oss;
oss << "Unable to delete file " << path_ << ". Errno = " << errno;
RETURN_STATUS_UNEXPECTED(oss.str());
}
}
}
return Status::OK();
}

Status Path::CreateFile(int *file_descriptor) { return OpenFile(file_descriptor, true); }

// Open the file at this path for read/write and return its descriptor via file_descriptor.
// When create is true the file is created (or truncated) with mode 0640. The path is first
// canonicalized; when creating, a missing final component is tolerated by canonicalizing
// the parent directory and re-appending the file name.
Status Path::OpenFile(int *file_descriptor, bool create) {
int fd;
if (file_descriptor == nullptr) {
RETURN_STATUS_UNEXPECTED("null pointer");
}
// Refuse to open a directory as a file.
if (IsDirectory()) {
std::ostringstream oss;
oss << "Unable to create file " << path_ << " which is a directory.";
RETURN_STATUS_UNEXPECTED(oss.str());
}
// Convert to canonical form.
if (strlen(common::SafeCStr(path_)) > PATH_MAX) {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
char canonical_path[PATH_MAX + 1] = {0x00};
#if defined(_WIN32) || defined(_WIN64)
if (_fullpath(canonical_path, common::SafeCStr(path_), PATH_MAX) == nullptr) {
#else
if (realpath(common::SafeCStr(path_), canonical_path) == nullptr) {
#endif
if (errno == ENOENT && create) {
// File doesn't exist and we are to create it. Let's break it down.
auto file_part = Basename();
auto parent_part = ParentPath();
#if defined(_WIN32) || defined(_WIN64)
if (_fullpath(canonical_path, common::SafeCStr(parent_part), PATH_MAX) == nullptr) {
#else
if (realpath(common::SafeCStr(parent_part), canonical_path) == nullptr) {
#endif
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
auto cur_inx = strlen(canonical_path);
// +1 for the separator we are about to append.
if ((cur_inx + file_part.length() + 1) > PATH_MAX) {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
canonical_path[cur_inx++] = separator_;
if (strncpy_s(canonical_path + cur_inx, PATH_MAX - cur_inx, common::SafeCStr(file_part), file_part.length()) !=
EOK) {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
} else {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
}
// O_TRUNC: creating an already-existing file resets it to zero length.
if (create) {
fd = open(canonical_path, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP);
} else {
fd = open(canonical_path, O_RDWR);
}
if (fd == -1) {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
*file_descriptor = fd;
return Status::OK();
}

// Close a descriptor previously obtained from OpenFile()/CreateFile().
Status Path::CloseFile(int fd) const {
  if (close(fd) >= 0) {
    return Status::OK();
  }
  RETURN_STATUS_UNEXPECTED(strerror(errno));
}

// Truncate the file referred to by fd to zero length.
Status Path::TruncateFile(int fd) const {
  if (ftruncate(fd, 0) != 0) {
    RETURN_STATUS_UNEXPECTED(strerror(errno));
  }
  return Status::OK();
}

// Return the final path component (text after the last separator), or the whole
// path when it contains no separator at all.
std::string Path::Basename() {
  auto pos = path_.find_last_of(separator_);
  return (pos == std::string::npos) ? path_ : path_.substr(pos + 1);
}

std::shared_ptr<Path::DirIterator> Path::DirIterator::OpenDirectory(Path *f) {
auto it = new (std::nothrow) DirIterator(f);

@@ -208,7 +315,7 @@ Path::DirIterator::~DirIterator() {

// Open the underlying directory stream; dp_ remains nullptr if opendir fails.
Path::DirIterator::DirIterator(Path *f) : dir_(f), dp_(nullptr), entry_(nullptr) {
MS_LOG(DEBUG) << "Open directory " << f->toString() << ".";
dp_ = opendir(f->toString().c_str());
}

bool Path::DirIterator::hasNext() {
@@ -225,5 +332,10 @@ bool Path::DirIterator::hasNext() {
}

Path Path::DirIterator::next() { return (*(this->dir_) / Path(entry_->d_name)); }

// Stream insertion: write the raw path string.
std::ostream &operator<<(std::ostream &os, const Path &s) {
os << s.path_;
return os;
}
} // namespace dataset
} // namespace mindspore

+ 14
- 0
mindspore/ccsrc/dataset/util/path.h View File

@@ -90,6 +90,20 @@ class Path {

std::string ParentPath();

Status Remove();

Status CreateFile(int *fd);

Status OpenFile(int *fd, bool create = false);

Status CloseFile(int fd) const;

Status TruncateFile(int fd) const;

std::string Basename();

friend std::ostream &operator<<(std::ostream &os, const Path &s);

private:
static char separator_;
std::string path_;


+ 41
- 0
mindspore/ccsrc/dataset/util/semaphore.cc View File

@@ -0,0 +1,41 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/semaphore.h"
#include "dataset/util/task_manager.h"

namespace mindspore {
namespace dataset {
// Decrement the counter, blocking while it is zero. Returns a non-OK status when the wait
// is interrupted, in which case the counter is left unchanged.
Status Semaphore::P() {
std::unique_lock<std::mutex> lck(mutex_);
RETURN_IF_NOT_OK(wait_cond_.Wait(&lck, [this]() { return value_ > 0; }));
--value_;
return Status::OK();
}
// Increment the counter and wake up one waiter, if any.
void Semaphore::V() {
  std::lock_guard<std::mutex> lck(mutex_);
  ++value_;
  wait_cond_.NotifyOne();
}
// Snapshot of the current counter value, taken under the lock.
int Semaphore::Peek() {
  std::lock_guard<std::mutex> lck(mutex_);
  return value_;
}
Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); }
Status Semaphore::Deregister() { return (wait_cond_.Deregister()); }
void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); }

} // namespace dataset
} // namespace mindspore

+ 54
- 0
mindspore/ccsrc/dataset/util/semaphore.h View File

@@ -0,0 +1,54 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_SEMAPHORE_H_
#define DATASET_UTIL_SEMAPHORE_H_

#include "dataset/util/cond_var.h"

namespace mindspore {
namespace dataset {
class TaskGroup;

/// \brief A counting semaphore. There are two external functions P and V. P decrements the internal count and will be
/// blocked if the count is 0 (zero). V increments the internal count and wake up one of the waiters.
class Semaphore {
public:
/// \brief Constructor
/// \param init Initial value of the internal counter.
explicit Semaphore(int init) : value_(init) {}

virtual ~Semaphore() {}
/// \brief Decrement the internal counter. Will be blocked if the value is 0.
/// \return Error code. Can get interrupt.
Status P();
/// \brief Increment the internal counter. Wake up one of the waiters if any.
void V();
/// \brief Peek the internal value
/// \return The internal value
int Peek();
/// \brief Attach to a task group's interrupt service so a blocked P() can be interrupted.
Status Register(TaskGroup *vg);
/// \brief Detach from the interrupt service.
Status Deregister();
/// \brief Clear any pending interrupt state.
void ResetIntrpState();

private:
int value_;  // current counter; guarded by mutex_

std::mutex mutex_;
CondVar wait_cond_;  // waiters block here until value_ > 0
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_UTIL_SEMAPHORE_H_

+ 38
- 0
mindspore/ccsrc/dataset/util/slice.cc View File

@@ -0,0 +1,38 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/slice.h"

namespace mindspore {
namespace dataset {
// Writable sub-slice covering [offset, offset+len) of src; shares src's buffer.
WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset, size_t len) : ReadableSlice(src, offset, len) {
mutable_data_ = static_cast<char *>(src.mutable_data_) + offset;
}
// Writable sub-slice from offset to the end of src.
WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset)
: WritableSlice(src, offset, src.GetSize() - offset) {}
// Copy the entire content of src into dest. dest must be non-null and backed by writable
// memory; memcpy_s additionally rejects a src larger than dest.
Status WritableSlice::Copy(WritableSlice *dest, const ReadableSlice &src) {
  RETURN_UNEXPECTED_IF_NULL(dest);
  RETURN_UNEXPECTED_IF_NULL(dest->GetMutablePointer());
  if (dest->GetSize() == 0) {
    RETURN_STATUS_UNEXPECTED("Destination length is non-positive");
  }
  auto err = memcpy_s(dest->GetMutablePointer(), dest->GetSize(), src.GetPointer(), src.GetSize());
  if (err != 0) {
    RETURN_STATUS_UNEXPECTED(std::to_string(err));
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 122
- 0
mindspore/ccsrc/dataset/util/slice.h View File

@@ -0,0 +1,122 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_SLICE_H_
#define DATASET_UTIL_SLICE_H_

#include <unistd.h>
#include <cstddef>
#include <utility>
#include "./securec.h"
#include "dataset/util/allocator.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
/// \brief A ReadableSlice wraps a const pointer in memory and its size.
/// \see WritableSlice for a non-const version
///
class ReadableSlice {
 public:
  /// \brief Construct an empty slice (null pointer, zero length).
  ReadableSlice() : ptr_(nullptr), sz_(0) {}
  /// \brief Wrap an existing read-only memory region of sz bytes at ptr.
  ReadableSlice(const void *ptr, size_t sz) : ptr_(ptr), sz_(sz) {}
  /// \brief Sub-slice of src covering [offset, offset + len).
  ReadableSlice(const ReadableSlice &src, off64_t offset, size_t len)
      : ptr_(static_cast<const char *>(src.GetPointer()) + offset), sz_(len) {}
  /// \brief Sub-slice of src from offset to the end of src.
  ReadableSlice(const ReadableSlice &src, off64_t offset) : ReadableSlice(src, offset, src.sz_ - offset) {}
  /// Copying shares the underlying memory; no ownership is involved.
  ReadableSlice(const ReadableSlice &lhs) = default;
  ReadableSlice &operator=(const ReadableSlice &lhs) = default;
  /// Moving empties the source slice.
  ReadableSlice(ReadableSlice &&lhs) noexcept
      : ptr_(std::exchange(lhs.ptr_, nullptr)), sz_(std::exchange(lhs.sz_, 0)) {}
  ReadableSlice &operator=(ReadableSlice &&lhs) noexcept {
    if (this != &lhs) {
      ptr_ = std::exchange(lhs.ptr_, nullptr);
      sz_ = std::exchange(lhs.sz_, 0);
    }
    return *this;
  }
  /// \brief Getter function
  /// \return Const version of the pointer
  const void *GetPointer() const { return ptr_; }
  /// \brief Getter function
  /// \return Size of the slice
  size_t GetSize() const { return sz_; }
  /// \brief True when the slice wraps no memory.
  bool empty() const { return ptr_ == nullptr; }

 private:
  const void *ptr_;
  size_t sz_;
};
/// \brief A WritableSlice inherits from ReadableSlice to allow
/// one to write to the address pointed to by the pointer.
///
class WritableSlice : public ReadableSlice {
 public:
  friend class StorageContainer;
  /// \brief Default constructor: an empty, non-writable slice.
  WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {}
  /// \brief Wrap a writable memory region of sz bytes starting at ptr.
  WritableSlice(void *ptr, size_t sz) : ReadableSlice(ptr, sz), mutable_data_(ptr) {}
  WritableSlice(const WritableSlice &src, off64_t offset, size_t len);
  WritableSlice(const WritableSlice &src, off64_t offset);
  /// Copying shares the underlying buffer.
  WritableSlice(const WritableSlice &lhs) = default;
  WritableSlice &operator=(const WritableSlice &lhs) = default;
  /// Moving empties the source slice.
  WritableSlice(WritableSlice &&lhs) noexcept
      : ReadableSlice(std::move(lhs)), mutable_data_(std::exchange(lhs.mutable_data_, nullptr)) {}
  WritableSlice &operator=(WritableSlice &&lhs) noexcept {
    if (this != &lhs) {
      ReadableSlice::operator=(std::move(lhs));
      mutable_data_ = std::exchange(lhs.mutable_data_, nullptr);
    }
    return *this;
  }
  /// \brief Copy the content from one slice onto another.
  static Status Copy(WritableSlice *dest, const ReadableSlice &src);

 private:
  void *mutable_data_;
  void *GetMutablePointer() { return mutable_data_; }
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_UTIL_SLICE_H_

+ 164
- 0
mindspore/ccsrc/dataset/util/storage_container.cc View File

@@ -0,0 +1,164 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/storage_container.h"

#include <fcntl.h>
#include <sys/stat.h>
#include <unistd.h>
#include <vector>
#include "common/utils.h"
#include "dataset/util/de_error.h"
#include "dataset/util/path.h"
#include "dataset/util/status.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace dataset {
// Create the backing file on disk and the buddy allocator that manages space within it.
Status StorageContainer::Create() {
RETURN_IF_NOT_OK(BuddySpace::CreateBuddySpace(&bs_));
RETURN_IF_NOT_OK(cont_.CreateFile(&fd_));
is_open_ = true;
MS_LOG(INFO) << "Container " << cont_ << " created";
return Status::OK();
}

// Open the backing file if it is not already open. Thread-safe.
Status StorageContainer::Open() noexcept {
std::lock_guard<std::mutex> lck(mutex_);
// Test under the lock in case another thread opened the container first.
if (!is_open_) {
RETURN_IF_NOT_OK(cont_.OpenFile(&fd_));
is_open_ = true;
}
return Status::OK();
}

// Close the backing file if open. Double-checked: the unlocked test skips the lock on the
// common already-closed path; the locked re-test makes the actual close race-free.
Status StorageContainer::Close() noexcept {
if (is_open_) {
std::lock_guard<std::mutex> lck(mutex_);
// Check again
if (is_open_) {
RETURN_IF_NOT_OK(cont_.CloseFile(fd_));
is_open_ = false;
fd_ = -1;
}
}
return Status::OK();
}

// Read dest->GetSize() bytes from the backing file at the given offset into dest.
Status StorageContainer::Read(WritableSlice *dest, off64_t offset) const noexcept {
DS_ASSERT(is_open_);
RETURN_UNEXPECTED_IF_NULL(dest);
auto sz = dest->GetSize();
#if defined(_WIN32) || defined(_WIN64)
// Doesn't seem there is any pread64 on mingw.
// So we will do a seek and then a read under
// a protection of mutex.
std::lock_guard<std::mutex> lck(mutex_);
auto seek_err = lseek(fd_, offset, SEEK_SET);
if (seek_err < 0) {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
auto r_sz = read(fd_, dest->GetMutablePointer(), sz);
#else
auto r_sz = pread64(fd_, dest->GetMutablePointer(), sz, offset);
#endif
// NOTE(review): on a partial read (0 < r_sz < sz) errno may be stale, so the
// reported message can be misleading.
if (r_sz != sz) {
errno_t err = (r_sz == 0) ? EOF : errno;
RETURN_STATUS_UNEXPECTED(strerror(err));
}
return Status::OK();
}

// Write the whole slice to the backing file at the given offset.
Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const noexcept {
DS_ASSERT(is_open_);
auto sz = dest.GetSize();
#if defined(_WIN32) || defined(_WIN64)
// Doesn't seem there is any pwrite64 on mingw.
// So we will do a seek and then a write under
// a protection of mutex.
std::lock_guard<std::mutex> lck(mutex_);
auto seek_err = lseek(fd_, offset, SEEK_SET);
if (seek_err < 0) {
RETURN_STATUS_UNEXPECTED(strerror(errno));
}
auto r_sz = write(fd_, dest.GetPointer(), sz);
#else
auto r_sz = pwrite64(fd_, dest.GetPointer(), sz, offset);
#endif
// NOTE(review): on a partial write (0 < r_sz < sz) errno may be stale.
if (r_sz != sz) {
errno_t err = (r_sz == 0) ? EOF : errno;
RETURN_STATUS_UNEXPECTED(strerror(err));
}
return Status::OK();
}

// Allocate one contiguous region of the container large enough for the concatenation of
// all slices in buf, copy them in back to back, and return the region's file offset.
// Returns an IsNoSpace() status (from the buddy allocator) when the container is full.
Status StorageContainer::Insert(const std::vector<ReadableSlice> &buf, off64_t *offset) noexcept {
  // Guard the output pointer before any allocation happens (consistent with Read()).
  RETURN_UNEXPECTED_IF_NULL(offset);
  size_t sz = 0;
  for (auto &v : buf) {
    sz += v.GetSize();
  }
  if (sz == 0) {
    RETURN_STATUS_UNEXPECTED("Unexpected 0 length");
  }
  if (sz > bs_->GetMaxSize()) {
    RETURN_STATUS_UNEXPECTED("Request size too big");
  }
  BSpaceDescriptor bspd{0};
  addr_t addr = 0;
  RETURN_IF_NOT_OK(bs_->Alloc(sz, &bspd, &addr));
  *offset = static_cast<off64_t>(addr);
  // We will do piecewise copy of the data to disk.
  for (auto &v : buf) {
    RETURN_IF_NOT_OK(Write(v, addr));
    addr += v.GetSize();
  }
  return Status::OK();
}

// Discard the container's contents by truncating the backing file to zero length.
Status StorageContainer::Truncate() const noexcept {
if (is_open_) {
RETURN_IF_NOT_OK(cont_.TruncateFile(fd_));
MS_LOG(INFO) << "Container " << cont_ << " truncated";
}
return Status::OK();
}

// Best-effort cleanup: truncate then close, ignoring errors (a destructor must not throw).
StorageContainer::~StorageContainer() noexcept {
(void)Truncate();
(void)Close();
}

// Print the backing file path followed by the buddy-allocator state.
std::ostream &operator<<(std::ostream &os, const StorageContainer &s) {
os << "File path : " << s.cont_ << "\n" << *(s.bs_.get());
return os;
}

// Factory: allocate a container for the given file path, Create() it, and hand back
// ownership through out_sc on success. On failure the partially built object is deleted.
Status StorageContainer::CreateStorageContainer(std::shared_ptr<StorageContainer> *out_sc, const std::string &path) {
  RETURN_UNEXPECTED_IF_NULL(out_sc);
  Status rc;
  // The constructor is private; allocate with nothrow so OOM is reported as a status.
  auto sc = new (std::nothrow) StorageContainer(path);
  if (sc == nullptr) {
    // Include source location, consistent with other kOutOfMemory sites in the codebase.
    return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
  }
  rc = sc->Create();
  if (rc.IsOk()) {
    (*out_sc).reset(sc);
  } else {
    delete sc;
  }
  return rc;
}
} // namespace dataset
} // namespace mindspore

+ 79
- 0
mindspore/ccsrc/dataset/util/storage_container.h View File

@@ -0,0 +1,79 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_STORAGE_CONTAINER_H_
#define DATASET_UTIL_STORAGE_CONTAINER_H_

#include <limits.h>
#include <unistd.h>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "dataset/util/system_pool.h"
#include "dataset/util/buddy.h"
#include "dataset/util/path.h"
#include "dataset/util/slice.h"
#include "dataset/util/status.h"

namespace mindspore {
namespace dataset {
class StorageManager;

// A StorageContainer is one file on disk plus a buddy allocator that parcels out regions
// of that file. Instances are created through the CreateStorageContainer() factory.
class StorageContainer {
public:
friend class StorageManager;

~StorageContainer() noexcept;

StorageContainer(const StorageContainer &) = delete;

StorageContainer &operator=(const StorageContainer &) = delete;

friend std::ostream &operator<<(std::ostream &os, const StorageContainer &s);

Status Open() noexcept;

Status Close() noexcept;

// Allocate space for the concatenation of buf, write it, and return its file offset.
Status Insert(const std::vector<ReadableSlice> &buf, off64_t *offset) noexcept;

// Write the slice at the given file offset.
Status Write(const ReadableSlice &dest, off64_t offset) const noexcept;

// Read dest->GetSize() bytes starting at the given file offset into dest.
Status Read(WritableSlice *dest, off64_t offset) const noexcept;

// Truncate the backing file to zero length.
Status Truncate() const noexcept;

bool IsOpen() const { return is_open_; }

static Status CreateStorageContainer(std::shared_ptr<StorageContainer> *out_sc, const std::string &path);

private:
mutable std::mutex mutex_;
Path cont_;  // path of the backing file
int fd_;  // file descriptor of the backing file; -1 when closed
bool is_open_;
std::unique_ptr<BuddySpace> bs_;  // allocator for regions inside the file

// Use the default value of BuddySpace
// which can map upto 4G of space.
explicit StorageContainer(const std::string &path) : cont_(path), fd_(-1), is_open_(false), bs_(nullptr) {}

Status Create();
};
} // namespace dataset
} // namespace mindspore

#endif // DATASET_UTIL_STORAGE_CONTAINER_H_

+ 167
- 0
mindspore/ccsrc/dataset/util/storage_manager.cc View File

@@ -0,0 +1,167 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/storage_manager.h"

#include <iomanip>
#include <sstream>
#include <stdexcept>
#include <utility>
#include "common/utils.h"
#include "dataset/util/path.h"
#include "dataset/util/services.h"
#include "dataset/util//de_error.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace dataset {
// Compose a container base name: the prefix followed by the file id, zero-padded to 5 digits.
std::string StorageManager::GetBaseName(const std::string &prefix, int32_t file_id) {
  std::string id_text = std::to_string(file_id);
  if (id_text.size() < 5) {
    id_text.insert(0, 5 - id_text.size(), '0');
  }
  return prefix + id_text;
}

// Full container file name: "<prefix><5-digit id>.<suffix>".
std::string StorageManager::ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix) {
  return GetBaseName(prefix, file_id) + "." + suffix;
}

// Create the next container file (e.g. IMG00000.LB) under root_ and append it to the list.
// NOTE(review): callers appear to serialize access (Write() upgrades to the exclusive lock
// before calling; DoServiceStart() runs before concurrent use) — confirm.
Status StorageManager::AddOneContainer() {
const std::string kPrefix = "IMG";
const std::string kSuffix = "LB";
Path container_name = root_ / ConstructFileName(kPrefix, file_id_, kSuffix);
std::shared_ptr<StorageContainer> sc;
RETURN_IF_NOT_OK(StorageContainer::CreateStorageContainer(&sc, container_name.toString()));
containers_.push_back(sc);
file_id_++;
return Status::OK();
}

// Start the service: root_ must be an existing directory; the first container is created in it.
Status StorageManager::DoServiceStart() {
  containers_.reserve(1000);
  if (!root_.IsDirectory()) {
    RETURN_STATUS_UNEXPECTED("Not a directory");
  }
  return AddOneContainer();
}

// Append the concatenation of buf to the newest container and record its location in the
// index, returning the generated key. Runs under the shared lock; when the container is
// full (IsNoSpace), the lock is upgraded to exclusive to append a fresh container, then
// downgraded, and the insert is retried.
Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &buf) {
RETURN_UNEXPECTED_IF_NULL(key);
size_t sz = 0;
for (auto &v : buf) {
sz += v.GetSize();
}
if (sz == 0) {
RETURN_STATUS_UNEXPECTED("Unexpected 0 length");
}
std::shared_ptr<StorageContainer> cont;
key_type out_key;
value_type out_value;
bool create_new_container = false;
do {
SharedLock lock_s(&rw_lock_);
size_t num_containers = containers_.size();
if (create_new_container) {
// Upgrade to exclusive lock.
lock_s.Upgrade();
create_new_container = false;
// Check again if someone has already added a
// new container after we got the x lock
if (containers_.size() == num_containers) {
RETURN_IF_NOT_OK(AddOneContainer());
}
// Refresh how many containers there are.
num_containers = containers_.size();
// Downgrade back to shared lock
lock_s.Downgrade();
}
if (num_containers == 0) {
RETURN_STATUS_UNEXPECTED("num_containers is zero");
}
// Go to the last container to insert.
cont = containers_.at(num_containers - 1);
off64_t offset;
Status rc = cont->Insert(buf, &offset);
if (rc.IsNoSpace()) {
create_new_container = true;
} else if (rc.IsOk()) {
// Record the location as (container index, (offset within container, length)).
out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz));
RETURN_IF_NOT_OK(index_.insert(out_value, &out_key));
*key = out_key;
break;
} else {
return rc;
}
} while (true);
return Status::OK();
}

// Look up key in the index and copy the stored record from its container into dest.
// dest must be at least as large as the record; the record's size is optionally
// reported through bytesRead.
Status StorageManager::Read(StorageManager::key_type key, WritableSlice *dest, size_t *bytesRead) const {
RETURN_UNEXPECTED_IF_NULL(dest);
auto r = index_.Search(key);
if (r.second) {
auto &it = r.first;
value_type v = *it;
int container_inx = v.first;
off_t offset = v.second.first;
size_t sz = v.second.second;
if (dest->GetSize() < sz) {
std::string errMsg = "Destination buffer too small. Expect at least " + std::to_string(sz) +
" but length = " + std::to_string(dest->GetSize());
RETURN_STATUS_UNEXPECTED(errMsg);
}
if (bytesRead != nullptr) {
*bytesRead = sz;
}
auto cont = containers_.at(container_inx);
RETURN_IF_NOT_OK(cont->Read(dest, offset));
} else {
RETURN_STATUS_UNEXPECTED("Key not found");
}
return Status::OK();
}

// Stop the service: truncate every container, drop all references so their destructors
// run, and reset file_id_. Returns the last truncation error seen, if any.
Status StorageManager::DoServiceStop() noexcept {
Status rc;
Status rc1;
for (auto const &p : containers_) {
// The destructor of StorageContainer is not called automatically until the use
// count drops to 0. But it is not always the case. We will do it ourselves.
rc = p.get()->Truncate();
if (rc.IsError()) {
rc1 = rc;
}
}
containers_.clear();
file_id_ = 0;
return rc1;
}

StorageManager::StorageManager(const Path &root) : root_(root), file_id_(0), index_() {}

StorageManager::~StorageManager() { (void)StorageManager::DoServiceStop(); }

// Dump every container owned by this manager to the stream.
std::ostream &operator<<(std::ostream &os, const StorageManager &s) {
  os << "Dumping all containers ..."
     << "\n";
  for (auto const &p : s.containers_) {
    os << *p;
  }
  return os;
}
} // namespace dataset
} // namespace mindspore

+ 76
- 0
mindspore/ccsrc/dataset/util/storage_manager.h View File

@@ -0,0 +1,76 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_UTIL_STORAGE_MANAGER_H_
#define DATASET_UTIL_STORAGE_MANAGER_H_

#include <unistd.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/util/allocator.h"
#include "dataset/util/auto_index.h"
#include "dataset/util/lock.h"
#include "dataset/util/memory_pool.h"
#include "dataset/util/path.h"
#include "dataset/util/service.h"
#include "dataset/util/slice.h"
#include "dataset/util/storage_container.h"

using ListOfContainers = std::vector<std::shared_ptr<mindspore::dataset::StorageContainer>>;
namespace mindspore {
namespace dataset {
// Manages a growing set of StorageContainer files under one root directory, plus an
// auto-generated index mapping each key to its (container #, (offset, size)) location.
class StorageManager : public Service {
public:
using storage_index = AutoIndexObj<std::pair<int, std::pair<off_t, size_t>>>;
using key_type = storage_index::key_type;
using value_type = storage_index::value_type;

explicit StorageManager(const Path &);

~StorageManager() override;

StorageManager(const StorageManager &) = delete;

StorageManager &operator=(const StorageManager &) = delete;

// Append buf (as one contiguous record) to the last container; returns a new key via out_key.
Status Write(key_type *out_key, const std::vector<ReadableSlice> &buf);

// Copy the record stored under key into dest; optionally reports its size via bytesRead.
Status Read(key_type key, WritableSlice *dest, size_t *bytesRead) const;

Status DoServiceStart() override;

Status DoServiceStop() noexcept override;

friend std::ostream &operator<<(std::ostream &os, const StorageManager &s);

private:
Path root_;  // directory holding all container files
ListOfContainers containers_;  // containers created so far, in creation order
int file_id_;  // id used to name the next container file
RWLock rw_lock_;  // shared for appends, exclusive to add a container
storage_index index_;  // key -> (container #, (offset, size))

std::string GetBaseName(const std::string &prefix, int32_t file_id);

std::string ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix);

// Create a new container file under root_ and append it to containers_.
Status AddOneContainer();
};
} // namespace dataset
} // namespace mindspore

#endif // DATASET_UTIL_STORAGE_MANAGER_H_

+ 7
- 0
mindspore/ccsrc/dataset/util/system_pool.h View File

@@ -19,8 +19,10 @@
#include <cstddef>
#include <cstdlib>
#include <limits>
#include <memory>
#include <new>
#include "./securec.h"
#include "dataset/util/allocator.h"
#include "dataset/util/memory_pool.h"

namespace mindspore {
@@ -61,6 +63,11 @@ class SystemPool : public MemoryPool {
uint64_t get_max_size() const override { return std::numeric_limits<uint64_t>::max(); }

int PercentFree() const override { return 100; }

template <typename T>
static Allocator<T> GetAllocator() {
return Allocator<T>(std::make_shared<SystemPool>());
}
};
} // namespace dataset
} // namespace mindspore


+ 21
- 6
mindspore/ccsrc/device/kernel_runtime.cc View File

@@ -30,6 +30,7 @@
#include "kernel/common_utils.h"
#include "kernel/oplib/oplib.h"
#include "ir/value.h"
#include "pre_activate/common/helper.h"
using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;

@@ -632,7 +633,7 @@ void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) {
}
}

void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
void KernelRuntime::GenLaunchArgs(const session::KernelGraph &graph, const mindspore::AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
AddressPtrList *kernel_outputs) {
MS_EXCEPTION_IF_NULL(kernel);
@@ -644,9 +645,15 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
}
auto is_all_nop_node = opt::IsAllNopNode(&graph);
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input);
DeviceAddressPtr device_address;
if (is_all_nop_node) {
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, false);
} else {
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input, true);
}
MS_EXCEPTION_IF_NULL(device_address);
kernel::AddressPtr input = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(input);
@@ -656,8 +663,16 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
kernel_inputs->emplace_back(input);
}

for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
auto device_address = AnfAlgo::GetOutputAddr(kernel, i);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
DeviceAddressPtr device_address;
if (is_all_nop_node) {
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
} else {
device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, true);
}
MS_EXCEPTION_IF_NULL(device_address);
kernel::AddressPtr output = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(output);
output->addr = device_address->ptr_;
@@ -666,7 +681,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
kernel_outputs->emplace_back(output);
}

for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(workspace);
@@ -721,7 +736,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
GenLaunchArgs(graph, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";


+ 2
- 2
mindspore/ccsrc/device/kernel_runtime.h View File

@@ -96,8 +96,8 @@ class KernelRuntime {

private:
void AssignStaticMemoryOutput(const session::KernelGraph *graph);
void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs);
void GenLaunchArgs(const session::KernelGraph &graph, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs);
bool LaunchKernelMod(const session::KernelGraph &graph);
void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs);
size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index);


+ 11
- 1
mindspore/ccsrc/ir/optimizer_caller.h View File

@@ -17,13 +17,23 @@
#ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_

#include <memory>

#include "ir/anf.h"
#include "optimizer/opt.h"

namespace mindspore {
namespace opt {
class Optimizer;
using OptimizerPtr = std::shared_ptr<Optimizer>;
using OptimizerWeakPtr = std::weak_ptr<Optimizer>;

using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
} // namespace opt

class OptimizerCaller {
public:
virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; }
};
using OptimizerCallerPtr = std::shared_ptr<OptimizerCaller>;
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_

+ 7
- 0
mindspore/ccsrc/kernel/kernel_query.cc View File

@@ -23,6 +23,7 @@
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
#include "kernel/akg/akg_kernel_metadata.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/context/ms_context.h"

namespace mindspore {
namespace kernel {
@@ -97,6 +98,12 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel

std::string op_name = AnfAlgo::GetCNodeName(kernel_node);

auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->enable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) {
kernel_type = KernelType::AKG_KERNEL;
}

switch (kernel_type) {
case KernelType::AKG_KERNEL:
AkgMetadataInfo(kernel_node, kernel_info_list);


+ 53
- 27
mindspore/ccsrc/optimizer/cse.cc View File

@@ -89,15 +89,28 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {

return changed;
}

// The op like print, summary, or the op do not has true output, and always as a depend node input.
static bool HasSideEffect(const AnfNodePtr &node) {
auto prim = GetCNodePrimitive(node);
if (prim == nullptr) {
return false;
}
auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT);
if (side_effect_v != nullptr && side_effect_v->isa<BoolImm>()) {
return GetValue<bool>(side_effect_v);
}
return false;
}
// If true do not merge the node.
bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
bool has_random_effect = false;
auto prim_main = GetCNodePrimitive(main);
auto prim_node = GetCNodePrimitive(node);
if (prim_main == prim_node) {
return false;
}
// if has random effect, when generate by different op (not same object), do not merge.
if (prim_main != nullptr) {
if (prim_main == prim_node) {
return false;
}
auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
has_random_effect = GetValue<bool>(effect_val);
@@ -106,45 +119,58 @@ bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) cons
return has_random_effect;
}

bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const {
MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node);

bool replace = false;
if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
auto main_value = GetValueNode(main);
auto node_value = GetValueNode(node);
replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
} else if (main->isa<CNode>() && node->isa<CNode>()) {
auto c_main = main->cast<CNodePtr>();
auto c_node = node->cast<CNodePtr>();
// When appsame is true, check if has side effect, do not merge.
if (check_side_effect && HasSideEffect(main)) {
return false;
}
const auto &inp1 = c_main->inputs();
const auto &inp2 = c_node->inputs();
if (inp1.size() == inp2.size()) {
bool appsame = true;
for (size_t j = 0; j < inp1.size(); j++) {
MS_EXCEPTION_IF_NULL(inp1[j]);
MS_EXCEPTION_IF_NULL(inp2[j]);
if (!(*inp1[j] == *inp2[j])) {
// Handle the case of two different Tensor, but with the same value
if (IsValueNode<tensor::Tensor>(inp1[j]) && IsValueNode<tensor::Tensor>(inp2[j])) {
auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1[j]);
auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2[j]);
if (tensor1->ValueEqual(*tensor2)) {
continue;
}
if (inp1.size() != inp2.size()) {
return false;
}
for (size_t j = 0; j < inp1.size(); j++) {
auto inp1_j = inp1[j];
auto inp2_j = inp2[j];
MS_EXCEPTION_IF_NULL(inp1_j);
MS_EXCEPTION_IF_NULL(inp2_j);
if (!(*inp1_j == *inp2_j)) {
// Handle the case of two different Tensor, but with the same value
if (IsValueNode<tensor::Tensor>(inp1_j) && IsValueNode<tensor::Tensor>(inp2_j)) {
auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1_j);
auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2_j);
if (tensor1->ValueEqual(*tensor2)) {
continue;
}
} else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) {
// When the same side effect node as another two nodes' inputs, we still merge the node.
// Because the node only can be the inputs of `depend`, when the `depend` is duplicated merge the depend the
// node.
if (CheckReplace(inp1_j, inp2_j, false)) {
continue;
}
appsame = false;
break;
}
return false;
}
if (CheckRandomEffect(c_main, c_node)) {
appsame = false;
}
replace = appsame;
}
// When appsame is true, check if has random effect do not merge
if (CheckRandomEffect(c_main, c_node)) {
return false;
}
return true;
}
return replace;
// a parameter node.
return false;
}

bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group,


+ 1
- 1
mindspore/ccsrc/optimizer/cse.h View File

@@ -41,7 +41,7 @@ class CSE {
return chg && report_changes_;
}

virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const;
virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const;

virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const;



+ 89
- 75
mindspore/ccsrc/optimizer/irpass.cc View File

@@ -14,140 +14,154 @@
* limitations under the License.
*/

#include "optimizer/irpass.h"

#include <string>

#include "optimizer/irpass/symbol_resolver.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/arithmetic_simplify.h"
#include "optimizer/irpass/special_op_eliminate.h"
#include "optimizer/irpass/item_tuple_eliminate.h"
#include "optimizer/irpass/env_item_eliminate.h"
#include "optimizer/irpass/tile_eliminate.h"
#include "optimizer/irpass/cast_eliminate.h"
#include "optimizer/irpass/reshape_eliminate.h"
#include "optimizer/irpass/transpose_eliminate.h"
#include "optimizer/irpass/reduce_eliminate.h"
#include "optimizer/irpass/partial_eliminate.h"
#include "optimizer/irpass/ref_eliminate.h"
#include "optimizer/irpass/merge_addn.h"
#include "optimizer/irpass/branch_culling.h"
#include "optimizer/irpass/cast_eliminate.h"
#include "optimizer/irpass/convert.h"
#include "optimizer/irpass/env_item_eliminate.h"
#include "optimizer/irpass/grad_var_prepare.h"
#include "optimizer/irpass/gradient_eliminate.h"
#include "optimizer/irpass/minmax_grad.h"
#include "optimizer/irpass/inline.h"
#include "optimizer/irpass/convert.h"
#include "optimizer/irpass/specialize_transform.h"
#include "optimizer/irpass/incorporate_getitem.h"
#include "optimizer/irpass/incorporate_call.h"
#include "optimizer/irpass/grad_var_prepare.h"
#include "optimizer/irpass/param_replace.h"
#include "optimizer/irpass/incorporate_getitem.h"
#include "optimizer/irpass/item_tuple_eliminate.h"
#include "optimizer/irpass/mark_interface_fusion.h"
#include "optimizer/irpass/merge_addn.h"
#include "optimizer/irpass/minmax_grad.h"
#include "optimizer/irpass/param_replace.h"
#include "optimizer/irpass/partial_eliminate.h"
#include "optimizer/irpass/reduce_eliminate.h"
#include "optimizer/irpass/ref_eliminate.h"
#include "optimizer/irpass/reshape_eliminate.h"
#include "optimizer/irpass/special_op_eliminate.h"
#include "optimizer/irpass/specialize_transform.h"
#include "optimizer/irpass/symbol_resolver.h"
#include "optimizer/irpass/tile_eliminate.h"
#include "optimizer/irpass/transpose_eliminate.h"
#include "optimizer/opt.h"

namespace mindspore {
namespace opt {
namespace irpass {
OptimizeIRPassLib::OptimizeIRPassLib() {
arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify",
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul});
arithmetic_simplify2_ =
MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul});
special_op_eliminate_ =
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike);
adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
zero_like_fill_zero_ =
MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);
adjust_all_reduce_mul_add_ =
MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN);

// ops eliminate
item_tuple_eliminate_ =
MakeSubstitution(ItemTupleEliminater(), "item_tuple_eliminate", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem});
tile_eliminate_ = MakeSubstitution(TileMultiplyByOne(), "tile_eliminate", prim::kPrimTile);
cast_eliminate_ = MakeSubstitution(CastEliminater(), "cast_eliminate", prim::kPrimCast);
reshape_eliminate_ = MakeSubstitution(ReshapeEliminater(), "reshape_eliminate", prim::kPrimReshape);
transpose_eliminate_ = MakeSubstitution(TransposeSameIOEliminater(), "transpose_eliminate", prim::kPrimTranspose);
item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate",
{prim::kPrimTupleGetItem, prim::kPrimTupleSetItem});
tile_eliminate_ = MakeSubstitution(std::make_shared<TileMultiplyByOne>(), "tile_eliminate", prim::kPrimTile);
cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);
transpose_eliminate_ =
MakeSubstitution(std::make_shared<TransposeSameIOEliminater>(), "transpose_eliminate", prim::kPrimTranspose);
reduce_eliminate_ = MakeSubstitution(
ReduceOneEliminater(), "reduce_eliminate",
std::make_shared<ReduceOneEliminater>(), "reduce_eliminate",
{prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup);
same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape);
check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>);
depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend);
partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup);
same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
check_bprop_eliminate_ =
MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ =
MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>);
depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend);

// Env Item Eliminate
env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem);
env_get_item_eliminate_ =
MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_ =
MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_ =
MakeSubstitution(IncorporateEnvGetitemSwitch(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),
"incorporate_env_getitem_switch", prim::kPrimEnvGetItem);

// Ref eliminate
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate",
make_ref_eliminate_ =
MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);
get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate",
{prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate",
get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate",
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});

replace_refkey_by_param_ =
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam);
replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
IsValueNode<RefKey>, opt::FORCE_RENORM);
replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam);
// Gradient transforms
expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ);
minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem);
expand_jprim_ = MakeSubstitution(std::make_shared<ExpandJPrim>(), "expand_jprim", prim::kPrimJ);
minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem);

// branch culling
switch_simplify_ = MakeSubstitution(SwitchSimplify(), "switch_simplify", prim::kPrimSwitch);
float_tuple_getitem_switch_ =
MakeSubstitution(FloatTupleGetItemSwitch(), "float_tuple_getitem_switch", prim::kPrimTupleGetItem);
switch_simplify_ = MakeSubstitution(std::make_shared<SwitchSimplify>(), "switch_simplify", prim::kPrimSwitch);
float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared<FloatTupleGetItemSwitch>(),
"float_tuple_getitem_switch", prim::kPrimTupleGetItem);
float_env_getitem_switch_ =
MakeSubstitution(FloatEnvGetItemSwitch(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
convert_switch_replacement_ = MakeSubstitution(ConvertSwitchReplacement(), "convert_switch_replacement", IsCNodeDup);
MakeSubstitution(std::make_shared<FloatEnvGetItemSwitch>(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
convert_switch_replacement_ =
MakeSubstitution(std::make_shared<ConvertSwitchReplacement>(), "convert_switch_replacement", IsCNodeDup);

// Addn
merge_addn_ = MakeSubstitution(MergeAddN(), "merge_addn", prim::kPrimAddN);
addn_zero_filter_ = MakeSubstitution(AddNZeroFilter(), "addn_zero_filter", prim::kPrimAddN);
merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);

// inline
inline_ = MakeSubstitution(Inliner(), "inline", IsCNodeGraph);
replace_applicator_ = MakeSubstitution(ReplaceApplicator(), "replace_applicator", IsValueNode<FuncGraph>);
specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph);
inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
replace_applicator_ =
MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>);
specialize_transform_ =
MakeSubstitution(std::make_shared<SpecializeOnGraphArguments>(), "specialize_transform", IsCNodeGraph);

// Incorporation
incorporate_getitem_set_ =
MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
incorporate_getitem_from_param_ =
MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel);
incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup);
incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup);
MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared<IncorporateGetitemFromParam>(),
"incorporate_getitem_from_param", IsCNodeGraphKernel);
incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup);
incorporate_call_switch_ =
MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup);

// Virtual Dataset
virtual_dataset_eliminate_ =
MakeSubstitution(VirtualDatasetEliminater(), "virtual_dataset_eliminate", prim::kPrimVirtualDataset);
virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
"virtual_dataset_eliminate", prim::kPrimVirtualDataset);

// Convert
print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint);
print_tuple_wrapper_ =
MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);

// Unused parameter eliminate
unused_parameter_eliminate_ =
MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel);
unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel);
MakeSubstitution(std::make_shared<UnusedParasEliminater>(), "unused_parameter_eliminate", IsCNodeGraphKernel);
unused_output_eliminate_ =
MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel);

// AddN eliminate
addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel);
addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);

// Mark interface fusion
mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect);
mark_interface_fusion_ =
MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
}

ResolveIRPassLib::ResolveIRPassLib() {
resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve);
resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr);
resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr);
}

InferenceOptPrepareLib::InferenceOptPrepareLib() {
grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode);
grad_var_prepare_ = MakeSubstitution(std::make_shared<GradVarPrepare>(), "grad_var_prepare", IsCNode);
}
} // namespace irpass
} // namespace opt


+ 28
- 31
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h View File

@@ -17,15 +17,16 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_

#include <vector>
#include <memory>
#include <algorithm>
#include <memory>
#include <vector>

#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"

namespace mindspore {
namespace opt {
@@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
FuncGraphPtr all_reduce_fg_{nullptr};
};

class ArithmeticSimplify {
class ArithmeticSimplify : public OptimizerCaller {
public:
ArithmeticSimplify()
: multiply_by_zero_or_one_(),
tensor_multiply_by_one_(),
add_by_zero_(),
tensor_add_by_zero_(),
identity_(prim::kPrimIdentity),
opt_update_zero_tensor_(),
constant_duplicate_mul_(),
power_one_() {
: multiply_by_zero_or_one_(std::make_shared<MultiplyByZeroOrOne>()),
tensor_multiply_by_one_(std::make_shared<TensorMultiplyByOne>()),
add_by_zero_(std::make_shared<AddByZero>()),
tensor_add_by_zero_(std::make_shared<TensorAddByZero>()),
identity_(std::make_shared<PrimEliminater>(prim::kPrimIdentity)),
opt_update_zero_tensor_(std::make_shared<OptUpdateZeroTensor>()),
constant_duplicate_mul_(std::make_shared<ConstantDuplicateMul>()),
power_one_(std::make_shared<PowerOneEliminate>()) {
eliminaters_.emplace_back(multiply_by_zero_or_one_);
eliminaters_.emplace_back(tensor_multiply_by_one_);
eliminaters_.emplace_back(add_by_zero_);
@@ -761,10 +762,10 @@ class ArithmeticSimplify {
}
~ArithmeticSimplify() = default;

AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
@@ -773,15 +774,9 @@ class ArithmeticSimplify {
}

private:
MultiplyByZeroOrOne multiply_by_zero_or_one_;
TensorMultiplyByOne tensor_multiply_by_one_;
AddByZero add_by_zero_;
TensorAddByZero tensor_add_by_zero_;
PrimEliminater identity_;
OptUpdateZeroTensor opt_update_zero_tensor_;
ConstantDuplicateMul constant_duplicate_mul_;
PowerOneEliminate power_one_;
std::vector<TransformFuncType> eliminaters_{};
OptimizerCallerPtr multiply_by_zero_or_one_, tensor_multiply_by_one_, add_by_zero_, tensor_add_by_zero_, identity_,
opt_update_zero_tensor_, constant_duplicate_mul_, power_one_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};

// Arithmetic Simplifications should be done after step_parallel.
@@ -789,15 +784,17 @@ class ArithmeticSimplify {
// with shape(weight), but after step_parallel, shape of weight may be changed, so the
// shape of the constant tensor should also be changed. So this pass is seperated from
// ArithmeticSimplify and deferred until step_parallel.
class ArithmeticSimplify2 {
class ArithmeticSimplify2 : public OptimizerCaller {
public:
ArithmeticSimplify2() : tensor_multiply_by_zero_() { eliminaters_.emplace_back(tensor_multiply_by_zero_); }
ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared<TensorMultiplyByZero>()) {
eliminaters_.emplace_back(tensor_multiply_by_zero_);
}
~ArithmeticSimplify2() = default;

AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
@@ -806,8 +803,8 @@ class ArithmeticSimplify2 {
}

private:
TensorMultiplyByZero tensor_multiply_by_zero_;
std::vector<TransformFuncType> eliminaters_{};
OptimizerCallerPtr tensor_multiply_by_zero_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
} // namespace irpass
} // namespace opt


+ 3
- 3
mindspore/ccsrc/optimizer/irpass/cast_eliminate.h View File

@@ -17,9 +17,9 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_

#include "ir/visitor.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"

namespace mindspore {
namespace opt {
@@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor {
AnfNodePtr x_{nullptr}, t_{nullptr};
};

class CastEliminater {
class CastEliminater : public OptimizerCaller {
public:
CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {}
~CastEliminater() = default;

AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
auto new_node = cast_same_type_eliminater_(optimizer, node);
if (new_node != nullptr) {
return new_node;


+ 16
- 14
mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h View File

@@ -17,18 +17,19 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_

#include <vector>
#include <utility>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "utils/symbolic.h"

namespace mindspore {
@@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor {
bool is_match_{false};
};

class EnvGetItemEliminater {
class EnvGetItemEliminater : public OptimizerCaller {
public:
EnvGetItemEliminater() : new_env_get_item_(), add_env_get_item_(), env_get_set_item_() {
EnvGetItemEliminater()
: new_env_get_item_(std::make_shared<NewEnvGetItem>()),
add_env_get_item_(std::make_shared<AddEnvGetItem>()),
env_get_set_item_(std::make_shared<EnvGetSetItem>()) {
eliminaters_.emplace_back(new_env_get_item_);
eliminaters_.emplace_back(add_env_get_item_);
eliminaters_.emplace_back(env_get_set_item_);
}
~EnvGetItemEliminater() = default;

AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
@@ -246,10 +250,8 @@ class EnvGetItemEliminater {
}

private:
NewEnvGetItem new_env_get_item_;
AddEnvGetItem add_env_get_item_;
EnvGetSetItem env_get_set_item_;
std::vector<TransformFuncType> eliminaters_{};
OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};

// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}


+ 15
- 12
mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h View File

@@ -17,18 +17,20 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_

#include <vector>
#include <algorithm>
#include <unordered_map>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"

namespace mindspore {
namespace opt {
namespace irpass {
@@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor {
internal::GetitemTransform getitem_transform_;
};

class IncorporateGetitemSet {
class IncorporateGetitemSet : public OptimizerCaller {
public:
IncorporateGetitemSet() : incorporate_getitem_(), incorporate_getitem_switch_() {
IncorporateGetitemSet()
: incorporate_getitem_(std::make_shared<IncorporateGetitem>()),
incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()) {
eliminaters_.emplace_back(incorporate_getitem_);
eliminaters_.emplace_back(incorporate_getitem_switch_);
}
~IncorporateGetitemSet() = default;

AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
@@ -403,9 +407,8 @@ class IncorporateGetitemSet {
}

private:
IncorporateGetitem incorporate_getitem_;
IncorporateGetitemSwitch incorporate_getitem_switch_;
std::vector<TransformFuncType> eliminaters_{};
OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
} // namespace irpass
} // namespace opt


+ 16
- 17
mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h View File

@@ -17,13 +17,15 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_

#include <vector>
#include <algorithm>
#include <memory>
#include <vector>

#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"

namespace mindspore {
namespace opt {
@@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor {
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
};

class ItemTupleEliminater {
class ItemTupleEliminater : public OptimizerCaller {
public:
ItemTupleEliminater()
: get_item_eliminater_(),
get_item_const_eliminater_(),
set_item_eliminater_(),
get_set_item_eliminater_(),
get_item_depend_reorder_() {
: get_item_eliminater_(std::make_shared<GetitemEliminater>()),
get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()),
set_item_eliminater_(std::make_shared<SetitemEliminater>()),
get_set_item_eliminater_(std::make_shared<GetSetitemEliminater>()),
get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()) {
eliminaters_.emplace_back(get_item_eliminater_);
eliminaters_.emplace_back(get_item_const_eliminater_);
eliminaters_.emplace_back(set_item_eliminater_);
@@ -277,10 +279,10 @@ class ItemTupleEliminater {
}
~ItemTupleEliminater() = default;

AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
@@ -289,12 +291,9 @@ class ItemTupleEliminater {
}

private:
GetitemEliminater get_item_eliminater_;
GetitemConstEliminater get_item_const_eliminater_;
SetitemEliminater set_item_eliminater_;
GetSetitemEliminater get_set_item_eliminater_;
GetitemDependReorder get_item_depend_reorder_;
std::vector<TransformFuncType> eliminaters_{};
OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_,
get_item_depend_reorder_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
} // namespace irpass
} // namespace opt


+ 2
- 2
mindspore/ccsrc/optimizer/irpass/ref_eliminate.h View File

@@ -19,9 +19,9 @@

#include <memory>

#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/pattern_matcher.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"

namespace mindspore {
namespace opt {


+ 6
- 5
mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h View File

@@ -19,11 +19,12 @@

#include <vector>

#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "pipeline/static_analysis/dshape.h"

namespace mindspore {
@@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor {
AnfNodePtr x_{nullptr}, shape_{nullptr};
};

class ReshapeEliminater {
class ReshapeEliminater : public OptimizerCaller {
public:
ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {}
~ReshapeEliminater() = default;

AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
auto new_node = reshape_same_shape_eliminater_(optimizer, node);
if (new_node != nullptr) {
return new_node;


+ 18
- 18
mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h View File

@@ -18,31 +18,31 @@
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_

#include <securec.h>
#include <vector>
#include <memory>
#include <algorithm>
#include <memory>
#include <vector>

#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/optimizer_caller.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "ir/pattern_matcher.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "ir/pattern_matcher.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"

namespace mindspore {
namespace opt {
namespace irpass {
class SpecialOpEliminater {
class SpecialOpEliminater : public OptimizerCaller {
public:
SpecialOpEliminater()
: insert_gradient_of_(prim::kPrimInsertGradientOf),
stop_gradient_(prim::kPrimStopGradient),
hook_backward_(prim::kPrimHookBackward),
print_shape_type_(prim::kPrimPrintShapeType),
get_ref_value_(prim::kPrimGetRefValue),
mirror_(prim::kPrimMirror),
virtual_div_(prim::kPrimVirtualDiv) {
: insert_gradient_of_(std::make_shared<PrimEliminater>(prim::kPrimInsertGradientOf)),
stop_gradient_(std::make_shared<PrimEliminater>(prim::kPrimStopGradient)),
hook_backward_(std::make_shared<PrimEliminater>(prim::kPrimHookBackward)),
print_shape_type_(std::make_shared<PrimEliminater>(prim::kPrimPrintShapeType)),
get_ref_value_(std::make_shared<PrimEliminater>(prim::kPrimGetRefValue)),
mirror_(std::make_shared<PrimEliminater>(prim::kPrimMirror)),
virtual_div_(std::make_shared<PrimEliminater>(prim::kPrimVirtualDiv)) {
eliminaters_.emplace_back(insert_gradient_of_);
eliminaters_.emplace_back(stop_gradient_);
eliminaters_.emplace_back(hook_backward_);
@@ -53,10 +53,10 @@ class SpecialOpEliminater {
}
~SpecialOpEliminater() = default;

AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
@@ -65,9 +65,9 @@ class SpecialOpEliminater {
}

private:
PrimEliminater insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_,
OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_,
virtual_div_;
std::vector<TransformFuncType> eliminaters_{};
std::vector<OptimizerCallerPtr> eliminaters_{};
};

// {PrimVirtualDataset, X} -> X


+ 9
- 10
mindspore/ccsrc/optimizer/opt.cc View File

@@ -16,28 +16,27 @@

#include "optimizer/opt.h"

#include <algorithm>
#include <deque>
#include <memory>
#include <unordered_set>
#include <deque>
#include <algorithm>

#include "ir/anf.h"
#include "ir/manager.h"
#include "utils/ordered_set.h"

#include "utils/log_adapter.h"
#include "optimizer/optimizer.h"
#include "utils/log_adapter.h"
#include "utils/ordered_set.h"

namespace mindspore {
/* namespace to support opt */
namespace opt {
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &renorm_action) {
auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
}

SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) {
auto fn = [prims](const AnfNodePtr &node) -> bool {
if (!node->isa<CNode>()) {
@@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
}

SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &renorm_action) {
return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
}

AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const {
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
#ifdef ENABLE_PROFILE
double t = GetTime();
#endif
AnfNodePtr result = transform_(optimizer, node);
AnfNodePtr result = (*transform_)(optimizer, node);
#ifdef ENABLE_PROFILE
if (optimizer != nullptr) {
auto time = GetTime();


+ 9
- 15
mindspore/ccsrc/optimizer/opt.h View File

@@ -17,24 +17,18 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
#define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_

#include <vector>
#include <string>
#include <memory>
#include <string>
#include <vector>

#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "operator/ops.h"

namespace mindspore {
/* namespace to support opt */
namespace opt {
class Optimizer;

using OptimizerPtr = std::shared_ptr<Optimizer>;
using OptimizerWeakPtr = std::weak_ptr<Optimizer>;

using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
using TransformFuncType = std::function<AnfNodePtr(const OptimizerPtr &, const AnfNodePtr &)>;

// Define the interaction mode between an Optimize pass and Renormalize pass
// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed
@@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM };

class Substitution {
public:
TransformFuncType transform_{nullptr};
OptimizerCallerPtr transform_;
std::string name_;
PredicateFuncType predicate_{nullptr};
// an enum to mark this Substitution relation to renormalize pass
RenormAction renorm_action_;
Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate,
Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate,
const RenormAction &renorm_action)
: transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {}
~Substitution() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node);
};

using SubstitutionPtr = std::shared_ptr<Substitution>;

SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &action_renorm = CHECK_RENORM);
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims,
const RenormAction &action_renorm = CHECK_RENORM);
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);

class SubstitutionList {


+ 2
- 2
mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc View File

@@ -465,7 +465,7 @@ double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, co
CheckGlobalDeviceManager();
MS_EXCEPTION_IF_NULL(g_device_manager);
RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
TensorRedistribution tensor_redistribution;
TensorRedistribution tensor_redistribution(false, true);
if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) {
MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
}
@@ -503,7 +503,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo> &inp
CheckGlobalDeviceManager();
MS_EXCEPTION_IF_NULL(g_device_manager);
RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
TensorRedistribution tensor_redistribution;
TensorRedistribution tensor_redistribution(false, true);
if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) {
MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
}


+ 1
- 0
mindspore/ccsrc/parallel/context.cc View File

@@ -62,6 +62,7 @@ void ParallelContext::Reset() {
enable_all_reduce_fusion_ = false;
strategy_ckpt_load_file_ = "";
strategy_ckpt_save_file_ = "";
enable_parallel_optimizer_ = false;
}

void ParallelContext::set_device_num(int32_t device_num) {


+ 6
- 0
mindspore/ccsrc/parallel/context.h View File

@@ -100,6 +100,11 @@ class ParallelContext {
void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }

void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
enable_parallel_optimizer_ = enable_parallel_optimizer;
}
bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }

void Reset();

private:
@@ -123,6 +128,7 @@ class ParallelContext {
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
std::string strategy_ckpt_load_file_;
std::string strategy_ckpt_save_file_;
bool enable_parallel_optimizer_;
};

void ParallelParameterContextInit(const FuncGraphPtr &func_graph);


+ 4
- 0
mindspore/ccsrc/pipeline/init.cc View File

@@ -205,6 +205,10 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
.def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
.def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
.def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,
"Set enable/disable parallel optimizer.")
.def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
"Get enable/disable parallel optimizer.")
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");

(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")


+ 1
- 1
mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc View File

@@ -35,7 +35,7 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) {
}
} // namespace
bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const {
MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node);


+ 1
- 1
mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h View File

@@ -31,7 +31,7 @@ class BackendCSE : public CSE {
public:
BackendCSE() = default;
~BackendCSE() override = default;
bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const override;
bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override;
};
} // namespace opt
} // namespace mindspore


+ 1
- 0
mindspore/ccsrc/pybind_api/export_flags.cc View File

@@ -33,5 +33,6 @@ const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect";

} // namespace mindspore

+ 1
- 1
mindspore/ccsrc/pybind_api/export_flags.h View File

@@ -34,7 +34,7 @@ extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
extern const char GRAPH_FLAG_HAS_EFFECT[];
extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
extern const char GRAPH_FLAG_RANDOM_EFFECT[];
extern const char GRAPH_FLAG_SIDE_EFFECT[];
} // namespace mindspore

#endif // PYBIND_API_EXPORT_FLAGS_H_

+ 32
- 8
mindspore/ccsrc/session/ascend_control_parser.cc View File

@@ -33,6 +33,21 @@ static constexpr size_t kCNodeSwitchLayerLength = 3;

namespace mindspore {
namespace session {
static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) {
auto &nodes = parent_graph->execution_order();
for (auto &node : nodes) {
if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) {
return node;
} else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) &&
(child_graph->get_start_label() == node->input(kCNodeSwitchFalse) ||
child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) {
return node;
}
}
MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
return nullptr;
}

static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
const NotNull<std::set<KernelGraphPtr> *> memo) {
if (memo->find(kg.get()) != memo->end()) {
@@ -200,7 +215,8 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
if (target_graph_iter == graph_id_map.end()) {
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
}
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter));
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg),
NOT_NULL(parameter));
}
}
}
@@ -263,7 +279,7 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
}
}
kg->SetExecOrderByDefault();
MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString();
return NOT_NULL(start_label);
}
@@ -433,7 +449,8 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
return {partial_cnode, branch_kg};
}

void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph,
NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> from,
NotNull<AnfNodePtr> to) {
std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem});
std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem});
@@ -443,18 +460,24 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg
<< to_outputs.size() << "]";
}
for (size_t i = 0; i < from_outputs.size(); i++) {
InsertAssignToGraph(kg, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
if (assign_node != nullptr) {
auto jump_node = GetJumpNode(from_graph, to_graph);
if (jump_node != nullptr) {
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
}
}
}
}

void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
NotNull<AnfNodePtr> to) {
AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
NotNull<AnfNodePtr> to) {
if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
return;
return nullptr;
}
if (from.get() == to.get()) {
return;
return nullptr;
}
MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to "
<< to->DebugString();
@@ -466,6 +489,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
assign_node->set_abstract(to->abstract());
// append the assign at the end of from graph
InsertDependToGraph(kg, NOT_NULL(assign_node));
return assign_node;
}

std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,


+ 3
- 2
mindspore/ccsrc/session/ascend_control_parser.h View File

@@ -52,8 +52,9 @@ class AscendControlParser {
const CNodePtr &last_label);
static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node);

static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph,
NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);

// root graph order
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,


+ 56
- 2
mindspore/ccsrc/session/kernel_graph.cc View File

@@ -521,6 +521,47 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
return output_nodes;
}

// Find control_depend real input nodes.
void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, std::set<AnfNodePtr> *visited) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(result);
MS_EXCEPTION_IF_NULL(visited);
if (visited->find(anf_node) != visited->end()) {
MS_LOG(WARNING) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited";
return;
}
visited->insert(anf_node);
if (AnfAlgo::IsRealKernel(anf_node)) {
result->emplace_back(anf_node);
return;
}
if (!anf_node->isa<CNode>()) {
return;
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().empty()) {
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString();
}
auto input0 = cnode->input(0);
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
GetAllFatherRealNode(cnode->input(i), result, visited);
}
} else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
if (cnode->inputs().size() != kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
}
GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
} else if (IsPrimitive(input0, prim::kPrimDepend)) {
if (cnode->inputs().size() != kDependInputSize) {
MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
}
GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
}
}

// update the depend relations of control depend
void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
for (const auto &node : depends) {
@@ -551,11 +592,24 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
if (depend_node->isa<Parameter>() && depend_mode == 1) {
depend_nodes = GetOutputNodes(depend_node);
}
for (auto &first_node : prior_nodes) {

std::vector<AnfNodePtr> real_prior_nodes;
std::set<AnfNodePtr> prior_visited;
for (const auto &tmp : prior_nodes) {
GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
}

std::vector<AnfNodePtr> real_depend_nodes;
std::set<AnfNodePtr> depend_visited;
for (const auto &tmp : depend_nodes) {
GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
}

for (auto &first_node : real_prior_nodes) {
if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) {
continue;
}
for (auto &second_node : depend_nodes) {
for (auto &second_node : real_depend_nodes) {
if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) {
continue;
}


+ 73
- 39
mindspore/ccsrc/session/session.cc View File

@@ -33,9 +33,14 @@
namespace py = pybind11;
namespace mindspore::inference {
std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device) {
inference::Session::RegAllOp();
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
return anf_graph;
try {
inference::Session::RegAllOp();
auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size);
return anf_graph;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference LoadModel failed";
return nullptr;
}
}

void ExitInference() {
@@ -51,12 +56,17 @@ void ExitInference() {
}

std::shared_ptr<MSSession> MSSession::CreateSession(const std::string &device, uint32_t device_id) {
auto session = std::make_shared<inference::Session>();
auto ret = session->Init(device, device_id);
if (ret != 0) {
try {
auto session = std::make_shared<inference::Session>();
auto ret = session->Init(device, device_id);
if (ret != 0) {
return nullptr;
}
return session;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference CreatSession failed";
return nullptr;
}
return session;
}

void Session::RegAllOp() {
@@ -113,47 +123,71 @@ void Session::RegAllOp() {

uint32_t Session::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) {
MS_ASSERT(session_impl_ != nullptr);
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
py::gil_scoped_release gil_release;
return graph_id;
try {
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
py::gil_scoped_release gil_release;
return graph_id;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference CompileGraph failed";
return static_cast<uint32_t>(-1);
}
}

MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) {
std::vector<tensor::TensorPtr> inTensors;
inTensors.resize(inputs.size());
bool has_error = false;
std::transform(inputs.begin(), inputs.end(), inTensors.begin(),
[&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr {
if (tensor_ptr == nullptr) {
MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr";
has_error = true;
return nullptr;
}
auto tensor = static_cast<inference::Tensor *>(tensor_ptr.get());
if (tensor == nullptr) {
MS_LOG(ERROR) << "Can not cast input MSTensor to tensor";
has_error = true;
return nullptr;
}
return tensor->tensor();
});
if (has_error) {
MS_LOG(ERROR) << "Init Tensor failed, returning empty result";
std::vector<std::shared_ptr<inference::MSTensor>> multiTensor;
return multiTensor;
}
VectorRef outputs;
session_impl_->RunGraph(graph_id, inTensors, &outputs);
try {
std::vector<tensor::TensorPtr> inTensors;
inTensors.resize(inputs.size());
bool has_error = false;
std::transform(inputs.begin(), inputs.end(), inTensors.begin(),
[&has_error](const std::shared_ptr<inference::MSTensor> &tensor_ptr) -> tensor::TensorPtr {
if (tensor_ptr == nullptr) {
MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr";
has_error = true;
return nullptr;
}
auto tensor = static_cast<inference::Tensor *>(tensor_ptr.get());
if (tensor == nullptr) {
MS_LOG(ERROR) << "Can not cast input MSTensor to tensor";
has_error = true;
return nullptr;
}
return tensor->tensor();
});
if (has_error) {
MS_LOG(ERROR) << "Init Tensor failed, returning empty result";
std::vector<std::shared_ptr<inference::MSTensor>> multiTensor;
return multiTensor;
}
VectorRef outputs;
session_impl_->RunGraph(graph_id, inTensors, &outputs);

return TransformVectorRefToMultiTensor(outputs);
return TransformVectorRefToMultiTensor(outputs);
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference Rungraph failed";
return MultiTensor();
}
}

namespace {
string AjustTargetName(const std::string &device) {
if (device == kAscendDevice) {
return std::string(kAscendDevice) + "Inference";
} else {
MS_LOG(ERROR) << "Only support device Ascend right now";
return "";
}
}
} // namespace
int Session::Init(const std::string &device, uint32_t device_id) {
RegAllOp();
auto ms_context = MsContext::GetInstance();
ms_context->set_execution_mode(kGraphMode);
ms_context->set_device_target(kAscendDevice);
session_impl_ = session::SessionFactory::Get().Create(device);
ms_context->set_device_id(device_id);
auto ajust_device = AjustTargetName(device);
if (ajust_device == "") {
return -1;
}
ms_context->set_device_target(device);
session_impl_ = session::SessionFactory::Get().Create(ajust_device);
if (session_impl_ == nullptr) {
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
return -1;


+ 10
- 2
mindspore/ccsrc/session/session_basic.cc View File

@@ -81,7 +81,15 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
}
}
// if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
auto address = AnfAlgo::GetOutputAddr(node, output_index);
DeviceAddressPtr address;
auto is_all_nop_node = opt::IsAllNopNode(&graph);
if (is_all_nop_node) {
// The graph does not remove the nop node.
address = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
} else {
// The graph removes the nop node.
address = AnfAlgo::GetMutableOutputAddr(node, output_index, true);
}
MS_EXCEPTION_IF_NULL(address);
auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
TypeId type_id = kNumberTypeFloat32;
@@ -93,7 +101,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) {
tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
tensor->set_device_address(address);
tensor->set_dirty(false);
} else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) {


+ 3
- 1
mindspore/ccsrc/transform/convert.cc View File

@@ -1646,7 +1646,7 @@ bool DfGraphConvertor::GetControlDependList(const CNodePtr &node,
dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end());
}
if (src_ops_list->empty() || dst_ops_list->empty()) {
MS_LOG(WARNING) << "Control depend node's src or dest node is not a apply node, ignore it";
MS_LOG(DEBUG) << "Control depend node's src or dest node is not a CNode, ignore it";
error_ = SUCCESS;
}
return true;
@@ -1690,6 +1690,8 @@ void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) {
});
} else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) {
control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]});
} else if (src_ops_list->empty() || dst_ops_list->empty()) {
MS_LOG(DEBUG) << "Depend list of src or dst is empty, ignore it";
} else {
MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size()
<< " -> dst:" << dst_ops_list->size();


+ 21
- 11
mindspore/ccsrc/utils/log_adapter.cc View File

@@ -463,7 +463,7 @@ void InitSubModulesLogLevel() {

// set submodule's log level
auto submodule = GetEnv("MS_SUBMODULE_LOG_v");
MS_LOG(INFO) << "MS_SUBMODULE_LOG_v=`" << submodule << "`";
MS_LOG(DEBUG) << "MS_SUBMODULE_LOG_v=`" << submodule << "`";
LogConfigParser parser(submodule);
auto configs = parser.Parse();
for (const auto &cfg : configs) {
@@ -489,22 +489,14 @@ void InitSubModulesLogLevel() {
} // namespace mindspore

extern "C" {
// shared lib init hook
#if defined(_WIN32) || defined(_WIN64)
__attribute__((constructor)) void mindspore_log_init(void) {
__attribute__((constructor)) void common_log_init(void) {
#else
void mindspore_log_init(void) {
void common_log_init(void) {
#endif
#ifdef USE_GLOG
// do not use glog predefined log prefix
FLAGS_log_prefix = false;
static bool is_glog_initialzed = false;
if (!is_glog_initialzed) {
#if !defined(_WIN32) && !defined(_WIN64)
google::InitGoogleLogging("mindspore");
#endif
is_glog_initialzed = true;
}
// set default log level to WARNING
if (mindspore::GetEnv("GLOG_v").empty()) {
FLAGS_v = mindspore::WARNING;
@@ -525,4 +517,22 @@ void mindspore_log_init(void) {
#endif
mindspore::InitSubModulesLogLevel();
}

// shared lib init hook
#if defined(_WIN32) || defined(_WIN64)
__attribute__((constructor)) void mindspore_log_init(void) {
#else
void mindspore_log_init(void) {
#endif
#ifdef USE_GLOG
static bool is_glog_initialzed = false;
if (!is_glog_initialzed) {
#if !defined(_WIN32) && !defined(_WIN64)
google::InitGoogleLogging("mindspore");
#endif
is_glog_initialzed = true;
}
#endif
common_log_init();
}
}

+ 1
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -252,6 +252,7 @@ constexpr auto kControlDependMode = "depend_mode";
// index define of depend
constexpr auto kRealInputIndexInDepend = 1;
constexpr auto kDependAttachNodeIndex = 2;
constexpr auto kDependInputSize = 3;
// format
constexpr auto kOpFormat_DEFAULT = "DefaultFormat";
constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0";


+ 8
- 0
mindspore/common/tensor.py View File

@@ -22,6 +22,10 @@ from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry

__all__ = ['Tensor', 'MetaTensor']
np_types = (np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
np.float32, np.float64, np.bool_)



class Tensor(Tensor_):
@@ -54,6 +58,10 @@ class Tensor(Tensor_):
"""

def __init__(self, input_data, dtype=None):
# If input data is numpy number, convert it to np array
if isinstance(input_data, np_types):
input_data = np.array(input_data)

# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
check_type('tensor input_data', input_data, (Tensor_, float, int))
if dtype is not None:


+ 3
- 3
mindspore/dataset/engine/datasets.py View File

@@ -1040,7 +1040,7 @@ class Dataset:

Args:
columns (list[str], optional): List of columns to be used to specify the order of columns
(defaults=None, means all columns).
(default=None, means all columns).

Returns:
Iterator, list of ndarray.
@@ -3382,7 +3382,7 @@ class ManifestDataset(MappableDataset):
class_indexing (dict, optional): A str-to-int mapping from label name to index
(default=None, the folder names will be sorted alphabetically and each
class will be given a unique index starting from 0).
decode (bool, optional): decode the images after reading (defaults=False).
decode (bool, optional): decode the images after reading (default=False).
num_shards (int, optional): Number of shards that the dataset should be divided
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
@@ -4760,7 +4760,7 @@ class _NumpySlicesDataset:

def process_dict(self, input_data):
"""
Convert the dict like data into tuple format, when input is a tuple of dict then compose it into a dict first.
Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first.
"""
# Convert pandas like dict(has "values" column) into General dict
data_keys = list(input_data.keys())


+ 3
- 3
mindspore/dataset/transforms/vision/c_transforms.py View File

@@ -202,7 +202,7 @@ class RandomHorizontalFlip(cde.RandomHorizontalFlipOp):
Flip the input image horizontally, randomly with a given probability.

Args:
prob (float): Probability of the image being flipped (default=0.5).
prob (float, optional): Probability of the image being flipped (default=0.5).
"""

@check_prob
@@ -217,7 +217,7 @@ class RandomHorizontalFlipWithBBox(cde.RandomHorizontalFlipWithBBoxOp):
Maintains data integrity by also flipping bounding boxes in an object detection pipeline.

Args:
prob (float): Probability of the image being flipped (default=0.5).
prob (float, optional): Probability of the image being flipped (default=0.5).
"""

@check_prob
@@ -231,7 +231,7 @@ class RandomVerticalFlip(cde.RandomVerticalFlipOp):
Flip the input image vertically, randomly with a given probability.

Args:
prob (float): Probability of the image being flipped (default=0.5).
prob (float, optional): Probability of the image being flipped (default=0.5).
"""

@check_prob


+ 47
- 36
mindspore/nn/optim/adam.py View File

@@ -29,8 +29,9 @@ from .optimizer import Optimizer
_adam_opt = C.MultitypeFuncGraph("adam_opt")


@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag, optim_filter):
"""
Update parameters.

@@ -44,38 +45,44 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
m (Tensor): m value of parameters.
v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters.
decay_flag (bool): Applies weight decay or not.
optim_filter (bool): Applies parameter update or not.

Returns:
Tensor, the new value of v after updating.
"""
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
if optim_filter:
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()

param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)

next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta1, gradient_fp32)

next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta2, op_square(gradient_fp32))
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta2, op_square(gradient_fp32))

update = next_m / (eps + op_sqrt(next_v))
if decay_flag:
update = op_mul(weight_decay_tensor, param_fp32) + update

update_with_lr = op_mul(lr, update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
update = next_m / (eps + op_sqrt(next_v))
if decay_flag:
update = op_mul(weight_decay_tensor, param_fp32) + update

next_v = F.depend(next_v, F.assign(param, op_cast(next_param, F.dtype(param))))
next_v = F.depend(next_v, F.assign(m, op_cast(next_m, F.dtype(m))))
next_v = F.depend(next_v, F.assign(v, op_cast(next_v, F.dtype(v))))
return next_v
update_with_lr = op_mul(lr, update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
return next_param
return gradient


def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
@@ -300,7 +307,7 @@ class AdamWeightDecay(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
tuple[bool], all elements are True.

Examples:
>>> net = Net()
@@ -328,11 +335,13 @@ class AdamWeightDecay(Optimizer):

def construct(self, gradients):
lr = self.get_lr()
updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)

return updated_velocity
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients,
self.decay_flag, self.optim_filter)
if self.use_parallel:
optim_result = self.broadcast_params(optim_result)
return optim_result


class AdamWeightDecayDynamicLR(Optimizer):
@@ -363,7 +372,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
tuple[bool], all elements are True.

Examples:
>>> net = Net()
@@ -424,12 +433,14 @@ class AdamWeightDecayDynamicLR(Optimizer):
warmup_lr = self.start_learning_rate * warmup_percent
is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)

optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients,
self.decay_flag, self.optim_filter)
if self.use_parallel:
optim_result = self.broadcast_params(optim_result)
added_global_step = self.global_step + self.one
F.control_depend(lr, added_global_step)
self.global_step = added_global_step

return updated_velocity
return optim_result

+ 72
- 70
mindspore/nn/optim/lamb.py View File

@@ -32,11 +32,10 @@ num_one = Tensor(np.ones([1]), mstype.float32)

_lamb_opt = C.MultitypeFuncGraph("lamb_opt")


@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Bool")
@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v,
gradient, decay_flag):
gradient, decay_flag, optim_filter):
"""
Update parameters.

@@ -52,66 +51,66 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters.
decay_flag (bool): Specifies whether param update with weight decay.
optim_filter(bool): Applies parameter update or not.

Returns:
Tensor, the new value of v after updating.
"""
op_mul = P.Mul()
op_sqrt = P.Sqrt()
op_rsqrt = P.Rsqrt()
op_square = P.Square()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
op_pow = P.Pow()
op_norm = layer.Norm()
op_select = P.Select()
op_greater = P.Greater()
op_fill = P.Fill()
op_dtype = P.DType()

param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)

next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one,
mstype.float32) - beta1, gradient_fp32)

next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one,
mstype.float32) - beta2, op_square(gradient_fp32))

next_mm = next_m / (op_cast(num_one, mstype.float32)
- op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
next_vv = next_v / (op_cast(num_one, mstype.float32) -
op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
w_norm = op_norm(param_fp32)
g_norm = op_norm(gradient_fp32)

g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(
next_vv + eps)) + weight_decay_tensor * param_fp32)
zeros = F.zeros_like(w_norm)
ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
trust_ratio = op_select(
op_greater(w_norm, zeros),
op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
ones)
tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
update = next_mm / (op_sqrt(next_vv) + eps)

if decay_flag:
update = update + op_mul(weight_decay_tensor, param_fp32)

update_with_lr = op_mul(op_mul(trust_ratio, lr), update)

next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

next_v = F.depend(next_v, F.assign(param, next_param))
next_v = F.depend(next_v, F.assign(m, next_m))
next_v = F.depend(next_v, F.assign(v, next_v))

return next_v
if optim_filter:
op_mul = P.Mul()
op_sqrt = P.Sqrt()
op_rsqrt = P.Rsqrt()
op_square = P.Square()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
op_pow = P.Pow()
op_norm = layer.Norm()
op_select = P.Select()
op_greater = P.Greater()
op_fill = P.Fill()
op_dtype = P.DType()

param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)

next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)

next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))

next_mm = next_m / (op_cast(num_one, mstype.float32)
- op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
next_vv = next_v / (op_cast(num_one, mstype.float32) -
op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
w_norm = op_norm(param_fp32)
g_norm = op_norm(gradient_fp32)

g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32)
zeros = F.zeros_like(w_norm)
ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
trust_ratio = op_select(
op_greater(w_norm, zeros),
op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
ones)
tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
update = next_mm / (op_sqrt(next_vv) + eps)

if decay_flag:
update = update + op_mul(weight_decay_tensor, param_fp32)

update_with_lr = op_mul(op_mul(trust_ratio, lr), update)

next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

next_param = F.depend(next_param, F.assign(param, next_param))
next_param = F.depend(next_param, F.assign(m, next_m))
next_param = F.depend(next_param, F.assign(v, next_v))

return next_param
return gradient


lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel")
@@ -238,7 +237,7 @@ class Lamb(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
tuple[bool], all elements are True.

Examples:
>>> net = Net()
@@ -311,18 +310,21 @@ class Lamb(Optimizer):
self.warmup_steps, self.global_step), mstype.float32)
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
if self.enable_graph_kernel:
updated_velocity = self.hyper_map(F.partial(lamb_opt_graph_kernel,
self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor, self.global_step),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel,
self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor, self.global_step),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
else:
updated_velocity = self.hyper_map(F.partial(_lamb_opt,
self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor, self.global_step),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
optim_result = self.hyper_map(F.partial(_lamb_opt,
self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor, self.global_step),
self.params, self.moments1, self.moments2, gradients,
self.decay_flag, self.optim_filter)
if self.use_parallel:
optim_result = self.broadcast_params(optim_result)

added_global_step = self.global_step + self.one
F.control_depend(lr, added_global_step)
self.global_step = added_global_step

return updated_velocity
return optim_result

+ 94
- 4
mindspore/nn/optim/optimizer.py View File

@@ -22,11 +22,14 @@ from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.common.tensor import Tensor
from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.parallel_utils import ParallelMode

__all__ = ['Optimizer']

@@ -155,6 +158,27 @@ class Optimizer(Cell):
self.param_length = len(self.parameters)
self.map_ = C.Map()

use_parallel = auto_parallel_context().get_enable_parallel_optimizer()
self.use_parallel = use_parallel
if use_parallel:
if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]:
raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL,
ParallelMode.AUTO_PARALLEL]:
raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format
(_get_parallel_mode()))
self.dev_num = _get_device_num()
if self.dev_num > self.param_length:
raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
" less than the number of devices {}".format(self.param_length, self.dev_num))
self.param_rank = self._get_parameter_group_id()
self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
self.param_names = []
for param in self.parameters:
self.param_names.append(param.name)
else:
self.optim_filter = (True,) * self.param_length

def decay_weight(self, gradients):
"""
Weight decay.
@@ -219,8 +243,32 @@ class Optimizer(Cell):
raise TypeError("Learning rate should be float, Tensor or Iterable.")
return lr

def _check_group_params(self, parameters):
"""Check group params."""
parse_keys = ['params', 'lr', 'weight_decay', 'order_params']
for group_param in parameters:
invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
if invalid_key:
raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.')

if 'order_params' in group_param.keys():
if len(group_param.keys()) > 1:
raise ValueError("The order params dict in group parameters should "
"only include the 'order_params' key.")
if not isinstance(group_param['order_params'], Iterable):
raise TypeError("The value of 'order_params' should be an Iterable type.")
continue

if not group_param['params']:
raise ValueError("Optimizer got an empty group parameter list.")

for param in group_param['params']:
if not isinstance(param, Parameter):
raise TypeError("The group param should be an iterator of Parameter type.")

def _parse_group_params(self, parameters, learning_rate):
"""Parse group params."""
self._check_group_params(parameters)
if self.dynamic_lr:
dynamic_lr_length = learning_rate.size()
else:
@@ -250,9 +298,6 @@ class Optimizer(Cell):
if dynamic_lr_length not in (lr_length, 0):
raise ValueError("The dynamic learning rate in group should be the same size.")

if not group_param['params']:
raise ValueError("Optimizer got an empty group parameter list.")

dynamic_lr_length = lr_length
self.dynamic_lr_length = dynamic_lr_length

@@ -384,6 +429,51 @@ class Optimizer(Cell):
lr = self.learning_rate
return lr

def _get_parameter_group_id(self):
"""
Get the parameter partition group id, which is less than the number of devices.

Returns:
tuple, the group id tuple of parameters.
"""
rank_list = ()
count = 0
for _ in range(self.param_length):
rank_list = rank_list + (count,)
count = count + 1
if count == self.dev_num:
count = 0
return rank_list

def broadcast_params(self, optim_result):
"""
Apply Broadcast operations in the sequential order of parameter groups.

Returns:
bool, the status flag.
"""
param_group = []
key_group = []
for _ in range(self.dev_num):
param_group.append(F.make_tuple())
key_group.append(F.make_tuple())
for i in range(self.param_length):
param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],)
key = P.MakeRefKey(self.param_names[i])()
key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
new_param_group = []
for root in range(self.dev_num):
ops = P.Broadcast(root)
next_params = ops(param_group[root])
new_param_group.append(next_params)
for i in range(F.tuple_len(next_params)):
F.assign(key_group[root][i], next_params[i])
status = True
for i in range(self.dev_num - 1):
status = F.control_depend(new_param_group[i][0], new_param_group[i+1])

return status

def construct(self, *hyper_params):
raise NotImplementedError



+ 3
- 1
mindspore/nn/wrap/cell_wrapper.py View File

@@ -220,7 +220,9 @@ class DataWrapper(Cell):

def __init__(self, network, dataset_types, dataset_shapes, queue_name):
super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())

# Also copy the flag in `network` construct
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
self.add_flags(**flags)
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
self.network = network



+ 1
- 0
mindspore/ops/_op_impl/akg/__init__.py View File

@@ -47,6 +47,7 @@ from .gather_v2 import _gather_v2_akg
from .less import _less_akg
from .log import _log_akg
from .matmul import _matmul_akg
from .batchmatmul import _batchmatmul_akg
from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_akg
from .max_pool_with_argmax import _max_pool_with_argmax_akg
from .max import _max_akg


+ 73
- 0
mindspore/ops/_op_impl/akg/batchmatmul.py View File

@@ -0,0 +1,73 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""BatchMatMul op"""
from mindspore.ops.op_info_register import op_info_register


@op_info_register("""{
"op_name": "BatchMatMul",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"attr": [
{
"name": "transpose_a",
"param_type": "optional",
"type": "bool"
},
{
"name": "transpose_b",
"param_type": "optional",
"type": "bool"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "x1"
},
{
"index": 1,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "x2"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "output"
}
]
}""")
def _batchmatmul_akg():
"""BatchMatMul AKG register"""
return

+ 2
- 20
mindspore/ops/_op_impl/tbe/confusion_transpose_d.py View File

@@ -28,26 +28,8 @@ confusion_transpose_d_op_info = TBERegOp("ConfusionTransposeD") \
.attr("transpose_first", "required", "bool", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_FracNZ, DataType.I8_FracNZ) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_FracNZ, DataType.U8_FracNZ) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I16_FracNZ, DataType.I16_FracNZ) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_FracNZ, DataType.U16_FracNZ) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_FracNZ, DataType.I32_FracNZ) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_FracNZ, DataType.U32_FracNZ) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.I64_FracNZ, DataType.I64_FracNZ) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_FracNZ, DataType.U64_FracNZ) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()




+ 16
- 0
mindspore/ops/composite/multitype_ops/setitem_impl.py View File

@@ -85,6 +85,22 @@ def _list_setitem_with_List(data, number_index, value):
return F.list_setitem(data, number_index, value)


@setitem.register("List", "Number", "Tuple")
def _list_setitem_with_Tuple(data, number_index, value):
"""
Assigns value to list.

Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (list): Value given.

Outputs:
list, type is same as the element type of data.
"""
return F.list_setitem(data, number_index, value)


@setitem.register("Dictionary", "String", "Tensor")
def _dict_setitem_with_tensor(data, key, value):
"""


+ 1
- 0
mindspore/ops/operations/comm_ops.py View File

@@ -98,6 +98,7 @@ class AllReduce(PrimitiveWithInfer):
self.op = op
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0)
self.add_prim_attr('index', 0)

def vm_impl(self, x):
"""Implement by vm mode."""


+ 1
- 7
mindspore/ops/operations/debug_ops.py View File

@@ -309,12 +309,6 @@ class Print(PrimitiveWithInfer):
Output tensor or string to stdout.

Note:
The print operation cannot support the following cases currently.

1. The type of tensor is float64 or bool.

2. The data of tensor is a scalar type.

In pynative mode, please use python print function.

Inputs:
@@ -334,7 +328,7 @@ class Print(PrimitiveWithInfer):

@prim_attr_register
def __init__(self):
pass
self.add_prim_attr("_side_effect", True)

def __call__(self, *args):
for arg in args:


+ 4
- 2
mindspore/ops/operations/math_ops.py View File

@@ -888,7 +888,8 @@ class Neg(PrimitiveWithInfer):
def infer_value(self, input_x):
if input_x is not None:
input_x = input_x.asnumpy()
return Tensor(-input_x)
out = np.array(-input_x, input_x.dtype)
return Tensor(out)

return None

@@ -1667,7 +1668,8 @@ class Div(_MathBinaryOp):
if x is not None and y is not None:
x = x.asnumpy()
y = y.asnumpy()
return Tensor(x / y)
out = np.array(x / y, x.dtype)
return Tensor(out)
return None




+ 1
- 2
mindspore/ops/operations/other_ops.py View File

@@ -59,8 +59,7 @@ class Assign(PrimitiveWithInfer):
return variable

def infer_dtype(self, variable, value):
args = {"variable": variable, "value": value}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
# Add a type validation later when we don't have to assign a value to RefKey.
return variable




+ 25
- 3
mindspore/parallel/_auto_parallel_context.py View File

@@ -400,6 +400,23 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_global_rank_is_set()

def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
"""
Set enable/disable parallel optimizer.

Args:
set_enable_parallel_optimizer (bool): Enable/disable parallel optimizer.
"""
self.check_context_handle()
if not isinstance(enable_parallel_optimizer, bool):
raise TypeError('enable_parallel_optimizer is invalid type')
self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)

def get_enable_parallel_optimizer(self):
"""Get parallel optimizer flag."""
self.check_context_handle()
return self._context_handle.get_enable_parallel_optimizer()

def reset(self):
"""Reset all settings."""
self.check_context_handle()
@@ -433,7 +450,8 @@ _set_auto_parallel_context_func_map = {
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
"full_batch": auto_parallel_context().set_full_batch}
"full_batch": auto_parallel_context().set_full_batch,
"enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer}


_get_auto_parallel_context_func_map = {
@@ -447,13 +465,15 @@ _get_auto_parallel_context_func_map = {
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
"full_batch": auto_parallel_context().get_full_batch}
"full_batch": auto_parallel_context().get_full_batch,
"enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}


@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool)
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)

def _set_auto_parallel_context(**kwargs):
"""
Set auto parallel context.
@@ -493,6 +513,7 @@ def _set_auto_parallel_context(**kwargs):
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
full_batch (bool): Whether to load the whole batch on each device. Default: False.
enable_parallel_optimizer (bool): Enable using optimizer segmentation or noe. Default: False.

Raises:
ValueError: If input key is not attribute in auto parallel context.
@@ -535,5 +556,6 @@ def _reset_auto_parallel_context():
- parameter_broadcast: False.
- strategy_ckpt_load_file: ""
- strategy_ckpt_save_file: ""
- enable_parallel_optimizer: False
"""
auto_parallel_context().reset()

+ 20
- 4
mindspore/train/callback/_summary_collector.py View File

@@ -166,8 +166,11 @@ class SummaryCollector(Callback):
self._has_saved_custom_data = False
self._is_parse_loss_success = True
self._first_step = True
self._dataset_sink_mode = True

def __enter__(self):
self._first_step = True
self._dataset_sink_mode = True
self._record = SummaryRecord(log_dir=self._summary_dir)
return self

@@ -279,15 +282,15 @@ class SummaryCollector(Callback):

def step_end(self, run_context):
cb_params = run_context.original_args()
if self._first_step:
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num)

if cb_params.mode == ModeEnum.TRAIN.value:

# Make sure the first step data is recorded
if not self._first_step and cb_params.cur_step_num % self._collect_freq:
if not self._is_collect_this_step(cb_params):
return

self._first_step = False

if not self._has_saved_train_network:
self._collect_graphs(cb_params)

@@ -295,6 +298,7 @@ class SummaryCollector(Callback):
self._collect_metric(cb_params)
self._collect_histogram(cb_params)

self._first_step = False
self._record.record(cb_params.cur_step_num)

def end(self, run_context):
@@ -320,6 +324,18 @@ class SummaryCollector(Callback):
raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list,"
f"but expected only one {self.__class__.__name__} instance.")

def _is_collect_this_step(self, cb_params):
"""Decide whether to collect data for the current step."""
# Make sure the first step data is recorded
if not self._first_step:
if self._dataset_sink_mode:
if cb_params.cur_epoch_num % self._collect_freq:
return False
else:
if cb_params.cur_step_num % self._collect_freq:
return False
return True

@staticmethod
def _package_custom_lineage_data(custom_lineage_data):
"""


+ 14
- 18
model_zoo/faster_rcnn/src/dataset.py View File

@@ -318,10 +318,6 @@ def preprocess_fn(image, box, is_training):
else:
input_data = resize_column(*input_data)

photo = (np.random.rand() < config.photo_ratio)
if photo:
input_data = photo_crop_column(*input_data)

input_data = image_bgr_rgb(*input_data)

output_data = input_data
@@ -432,19 +428,19 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fast
writer.write_raw_data([row])
writer.commit()


def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, device_num=1, rank_id=0,
is_training=True, num_parallel_workers=8):
is_training=True, num_parallel_workers=4):
"""Creatr FasterRcnn dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,
num_parallel_workers=num_parallel_workers, shuffle=is_training)
num_parallel_workers=1, shuffle=is_training)
decode = C.Decode()
ds = ds.map(input_columns=["image"], operations=decode)
ds = ds.map(input_columns=["image"], operations=decode, num_parallel_workers=1)
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))

hwc_to_chw = C.HWC2CHW()
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
horizontally_op = C.RandomHorizontalFlip(1)
type_cast0 = CC.TypeCast(mstype.float32)
type_cast1 = CC.TypeCast(mstype.float16)
type_cast2 = CC.TypeCast(mstype.int32)
type_cast3 = CC.TypeCast(mstype.bool_)
@@ -453,17 +449,18 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"],
columns_order=["image", "image_shape", "box", "label", "valid_num"],
operations=compose_map_func, num_parallel_workers=4)

ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0],
num_parallel_workers=num_parallel_workers)
operations=compose_map_func, num_parallel_workers=num_parallel_workers)

flip = (np.random.rand() < config.flip_ratio)
if flip:
ds = ds.map(input_columns=["image"], operations=[horizontally_op],
num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=[normalize_op, horizontally_op, hwc_to_chw, type_cast1],
num_parallel_workers=24)
ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"],
operations=flipped_generation, num_parallel_workers=4)
operations=flipped_generation, num_parallel_workers=num_parallel_workers)
else:
ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1],
num_parallel_workers=24)

else:
ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"],
@@ -471,11 +468,10 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
operations=compose_map_func,
num_parallel_workers=num_parallel_workers)

ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0],
num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1],
num_parallel_workers=24)

# transpose_column from python to c
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1])
ds = ds.map(input_columns=["image_shape"], operations=[type_cast1])
ds = ds.map(input_columns=["box"], operations=[type_cast1])
ds = ds.map(input_columns=["label"], operations=[type_cast2])


+ 3
- 1
model_zoo/vgg16/src/config.py View File

@@ -19,7 +19,9 @@ from easydict import EasyDict as edict

cifar_cfg = edict({
'num_classes': 10,
'lr_init': 0.05,
'lr_init': 0.01,
'lr_max': 0.1,
'warmup_epochs': 5,
'batch_size': 64,
'epoch_size': 70,
'momentum': 0.9,


+ 16
- 10
model_zoo/vgg16/train.py View File

@@ -38,20 +38,25 @@ random.seed(1)
np.random.seed(1)


def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""Set learning rate."""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
warmup_steps = steps_per_epoch * warmup_epochs
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < decay_epoch_index[0]:
lr_each_step.append(lr_max)
elif i < decay_epoch_index[1]:
lr_each_step.append(lr_max * 0.1)
elif i < decay_epoch_index[2]:
lr_each_step.append(lr_max * 0.01)
if i < warmup_steps:
lr_value = float(lr_init) + inc_each_step * float(i)
else:
lr_each_step.append(lr_max * 0.001)
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr_value = float(lr_max) * base * base
if lr_value < 0.0:
lr_value = 0.0
lr_each_step.append(lr_value)

current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
@@ -86,7 +91,8 @@ if __name__ == '__main__':
if args_opt.pre_trained:
load_param_into_net(net, load_checkpoint(args_opt.pre_trained))

lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
lr = lr_steps(0, lr_init=cfg.lr_init, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs,
total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum,
weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)


+ 28
- 13
serving/core/server.cc View File

@@ -22,6 +22,7 @@
#include <vector>
#include <utility>
#include <memory>
#include <future>

#include "mindspore/ccsrc/utils/log_adapter.h"
#include "serving/ms_service.grpc.pb.h"
@@ -40,7 +41,7 @@ namespace serving {
using MSTensorPtr = std::shared_ptr<inference::MSTensor>;

Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) {
session_ = inference::MSSession::CreateSession(device + "Inference", device_id);
session_ = inference::MSSession::CreateSession(device, device_id);
if (session_ == nullptr) {
MS_LOG(ERROR) << "Creat Session Failed";
return FAILED;
@@ -67,6 +68,7 @@ Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::Multi
MS_LOG(INFO) << "run Predict";

*outputs = session_->RunGraph(graph_id_, inputs);
MS_LOG(INFO) << "run Predict finished";
return SUCCESS;
}

@@ -80,12 +82,16 @@ Status Session::Warmup(const MindSporeModelPtr model) {
std::string file_name = model->GetModelPath() + '/' + model->GetModelName();
char *graphBuf = ReadFile(file_name.c_str(), &size);
if (graphBuf == nullptr) {
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
return FAILED;
}
last_graph_ = inference::LoadModel(graphBuf, size, device_type_);
if (last_graph_ == nullptr) {
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
return FAILED;
}
graph_id_ = session_->CompileGraph(last_graph_);
MS_LOG(INFO) << "Session Warmup";
MS_LOG(INFO) << "Session Warmup finished";
return SUCCESS;
}

@@ -95,6 +101,9 @@ Status Session::Clear() {
}

namespace {
static const uint32_t uint32max = 0x7FFFFFFF;
std::promise<void> exit_requested;

const std::map<ms_serving::DataType, TypeId> type2id_map{
{ms_serving::MS_UNKNOWN, TypeId::kNumberTypeBegin}, {ms_serving::MS_BOOL, TypeId::kNumberTypeBool},
{ms_serving::MS_INT8, TypeId::kNumberTypeInt8}, {ms_serving::MS_UINT8, TypeId::kNumberTypeUInt8},
@@ -141,7 +150,7 @@ MSTensorPtr ServingTensor2MSTensor(const ms_serving::Tensor &tensor) {
}
TypeId type = iter->second;
auto ms_tensor = std::shared_ptr<inference::MSTensor>(inference::MSTensor::CreateTensor(type, shape));
memcpy_s(ms_tensor->MutableData(), tensor.data().size(), tensor.data().data(), tensor.data().size());
memcpy_s(ms_tensor->MutableData(), ms_tensor->Size(), tensor.data().data(), tensor.data().size());
return ms_tensor;
}

@@ -166,10 +175,7 @@ void ClearEnv() {
Session::Instance().Clear();
inference::ExitInference();
}
void HandleSignal(int sig) {
ClearEnv();
exit(0);
}
void HandleSignal(int sig) { exit_requested.set_value(); }

#ifdef ENABLE_D
static rtContext_t g_ctx = nullptr;
@@ -247,6 +253,7 @@ Status Server::BuildAndStart() {
rtError_t rt_ret = rtCtxGetCurrent(&ctx);
if (rt_ret != RT_ERROR_NONE || ctx == nullptr) {
MS_LOG(ERROR) << "the ascend device context is null";
ClearEnv();
return FAILED;
}
g_ctx = ctx;
@@ -258,6 +265,7 @@ Status Server::BuildAndStart() {
auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
grpc::ServerBuilder builder;
builder.SetOption(std::move(option));
builder.SetMaxMessageSize(uint32max);
// Listen on the given address without any authentication mechanism.
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
// Register "service" as the instance through which we'll communicate with
@@ -265,13 +273,20 @@ Status Server::BuildAndStart() {
builder.RegisterService(&service);
// Finally assemble the server.
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
if (server == nullptr) {
MS_LOG(ERROR) << "The serving server create failed";
ClearEnv();
return FAILED;
}
auto grpc_server_run = [&server]() { server->Wait(); };
std::thread serving_thread(grpc_server_run);
MS_LOG(INFO) << "Server listening on " << server_address << std::endl;

// Wait for the server to shutdown. Note that some other thread must be
// responsible for shutting down the server for this call to ever return.
server->Wait();
auto exit_future = exit_requested.get_future();
exit_future.wait();
ClearEnv();
server->Shutdown();
serving_thread.join();
return SUCCESS;
}

} // namespace serving
} // namespace mindspore

+ 2
- 3
serving/core/util/file_system_operation.cc View File

@@ -29,7 +29,6 @@

namespace mindspore {
namespace serving {

char *ReadFile(const char *file, size_t *size) {
if (file == nullptr) {
MS_LOG(ERROR) << "file is nullptr";
@@ -70,8 +69,8 @@ bool DirOrFileExist(const std::string &file_path) {
}

std::vector<std::string> GetAllSubDirs(const std::string &dir_path) {
DIR *dir;
struct dirent *ptr;
DIR *dir = nullptr;
struct dirent *ptr = nullptr;
std::vector<std::string> SubDirs;

if ((dir = opendir(dir_path.c_str())) == NULL) {


+ 23
- 17
serving/core/util/option_parser.cc View File

@@ -36,17 +36,16 @@ bool RemovePrefix(std::string *str, const std::string &prefix) {

bool Option::ParseInt32(std::string *arg) {
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
char extra;
int32_t parsed_value;
if (sscanf(arg->data(), "%d%c", &parsed_value, &extra) != 1) {
std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl;
try {
parsed_value = std::stoi(arg->data());
} catch (std::invalid_argument) {
std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl;
return false;
} else {
*int32_default_ = parsed_value;
}
*int32_default_ = parsed_value;
return true;
}

return false;
}

@@ -76,17 +75,16 @@ bool Option::ParseString(std::string *arg) {

bool Option::ParseFloat(std::string *arg) {
if (RemovePrefix(arg, "--") && RemovePrefix(arg, name_) && RemovePrefix(arg, "=")) {
char extra;
float parsed_value;
if (sscanf(arg->data(), "%f%c", &parsed_value, &extra) != 1) {
std::cout << "Parse " << name_ << "Error for option " << *arg << std::endl;
try {
parsed_value = std::stof(arg->data());
} catch (std::invalid_argument) {
std::cout << "Parse " << name_ << " Error for option " << *arg << std::endl;
return false;
} else {
*float_default_ = parsed_value;
}
*float_default_ = parsed_value;
return true;
}

return false;
}

@@ -159,10 +157,11 @@ Options::Options() : args_(nullptr) { CreateOptions(); }
void Options::CreateOptions() {
args_ = std::make_shared<Arguments>();
std::vector<Option> options = {
Option("port", &args_->grpc_port, "Port to listen on for gRPC API, default is 5500"),
Option("model_name", &args_->model_name, "model name "),
Option("model_path", &args_->model_path, "the path of the model files"),
Option("device_id", &args_->device_id, "the device id, default is 0"),
Option("port", &args_->grpc_port,
"[Optional] Port to listen on for gRPC API, default is 5500, range from 1 to 65535"),
Option("model_name", &args_->model_name, "[Required] model name "),
Option("model_path", &args_->model_path, "[Required] the path of the model files"),
Option("device_id", &args_->device_id, "[Optional] the device id, default is 0, range from 0 to 7"),
};
options_ = options;
}
@@ -176,6 +175,14 @@ bool Options::CheckOptions() {
std::cout << "device_type only support Ascend right now" << std::endl;
return false;
}
if (args_->device_id > 7) {
std::cout << "the device_id should be in [0~7]" << std::endl;
return false;
}
if (args_->grpc_port < 1 || args_->grpc_port > 65535) {
std::cout << "the port should be in [1~65535]" << std::endl;
return false;
}
return true;
}

@@ -238,6 +245,5 @@ void Options::Usage() {
<< option.usage_ << std::endl;
}
}

} // namespace serving
} // namespace mindspore

+ 1
- 2
serving/core/util/option_parser.h View File

@@ -22,7 +22,6 @@

namespace mindspore {
namespace serving {

struct Arguments {
int32_t grpc_port = 5500;
std::string grpc_socket_path;
@@ -40,6 +39,7 @@ class Option {
Option(const std::string &name, bool *default_point, const std::string &usage);
Option(const std::string &name, std::string *default_point, const std::string &usage);
Option(const std::string &name, float *default_point, const std::string &usage);
~Option() = default;

private:
friend class Options;
@@ -77,7 +77,6 @@ class Options {
std::vector<Option> options_;
std::shared_ptr<Arguments> args_;
};

} // namespace serving
} // namespace mindspore



+ 0
- 1
serving/core/version_control/model.cc View File

@@ -19,7 +19,6 @@

namespace mindspore {
namespace serving {

MindSporeModel::MindSporeModel(const std::string &model_name, const std::string &model_path,
const std::string &model_version, const time_t &last_update_time)
: model_name_(model_name),


+ 6
- 8
serving/core/version_control/version_controller.cc View File

@@ -25,7 +25,6 @@

namespace mindspore {
namespace serving {

volatile bool stop_poll = false;

std::string GetVersionFromPath(const std::string &path) {
@@ -102,10 +101,10 @@ Status VersionController::CreateInitModels() {
}
std::vector<std::string> SubDirs = GetAllSubDirs(models_path_);
if (version_control_strategy_ == kLastest) {
auto path = SubDirs.empty() ? models_path_ : SubDirs.back();
std::string model_version = GetVersionFromPath(path);
time_t last_update_time = GetModifyTime(path);
MindSporeModelPtr model_ptr = std::make_shared<MindSporeModel>(model_name_, path, model_version, last_update_time);
std::string model_version = GetVersionFromPath(models_path_);
time_t last_update_time = GetModifyTime(models_path_);
MindSporeModelPtr model_ptr =
std::make_shared<MindSporeModel>(model_name_, models_path_, model_version, last_update_time);
valid_models_.emplace_back(model_ptr);
} else {
for (auto &dir : SubDirs) {
@@ -119,8 +118,8 @@ Status VersionController::CreateInitModels() {
MS_LOG(ERROR) << "There is no valid model for serving";
return FAILED;
}
Session::Instance().Warmup(valid_models_.back());
return SUCCESS;
auto ret = Session::Instance().Warmup(valid_models_.back());
return ret;
}

void VersionController::StartPollModelPeriodic() {
@@ -129,6 +128,5 @@ void VersionController::StartPollModelPeriodic() {
}

void VersionController::StopPollModelPeriodic() {}

} // namespace serving
} // namespace mindspore

+ 0
- 1
serving/core/version_control/version_controller.h View File

@@ -64,7 +64,6 @@ class PeriodicFunction {
VersionController::VersionControllerStrategy version_control_strategy_;
std::vector<MindSporeModelPtr> valid_models_;
};

} // namespace serving
} // namespace mindspore



+ 1
- 1
serving/cpp_example/ms_client.cc View File

@@ -214,6 +214,7 @@ PredictRequest ReadBertInput() {
class MSClient {
public:
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
~MSClient() = default;

std::string Predict(const std::string &type) {
// Data we are sending to the server.
@@ -310,7 +311,6 @@ int main(int argc, char **argv) {
type = "add";
}
}

} else {
target_str = "localhost:5500";
type = "add";


+ 1
- 1
serving/scripts/format_source_code.sh View File

@@ -81,7 +81,7 @@ function checkopts()
checkopts "$@"

# switch to project root path, which contains clang-format config file '.clang-format'
cd "${SCRIPTS_PATH}/.." || exit 1
cd "${SCRIPTS_PATH}/../.." || exit 1

FMT_FILE_LIST='__format_files_list__'



+ 1
- 0
setup.py View File

@@ -161,6 +161,7 @@ setup(
description='MindSpore is a new open source deep learning training/inference '
'framework that could be used for mobile, edge and cloud scenarios.',
long_description="\n\n".join([readme, release]),
long_description_content_type="text/markdown",
packages=find_packages(),
package_data=package_data,
include_package_data=True,


+ 3
- 3
tests/ut/cpp/dataset/btree_test.cc View File

@@ -190,9 +190,9 @@ TEST_F(MindDataTestBPlusTree, Test3) {
EXPECT_TRUE(rc.IsOk());
uint64_t min = ai.min_key();
uint64_t max = ai.max_key();
EXPECT_EQ(min, 1);
EXPECT_EQ(max, 4);
auto r = ai.Search(3);
EXPECT_EQ(min, 0);
EXPECT_EQ(max, 3);
auto r = ai.Search(2);
auto &it = r.first;
EXPECT_EQ(it.value(), "b");
MS_LOG(INFO) << "Dump all the values using [] operator.";


+ 4
- 4
tests/ut/cpp/optimizer/opt_test.cc View File

@@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common {
};

void SetUp() {
elim_Z = MakeSubstitution(irpass::AddByZero(), "elim_Z", prim::kPrimScalarAdd);
elim_R = MakeSubstitution(irpass::PrimEliminater(R), "elim_R", R);
idempotent_P = MakeSubstitution(IdempotentEliminater(), "idempotent_P", P);
Qct_to_P = MakeSubstitution(QctToP(), "Qct_to_P", Q);
elim_Z = MakeSubstitution(std::make_shared<irpass::AddByZero>(), "elim_Z", prim::kPrimScalarAdd);
elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);
}

bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) {


+ 3
- 0
tests/ut/cpp/parallel/step_parallel_test.cc View File

@@ -327,6 +327,9 @@ TEST_F(TestStepParallel, CreatOpInstance) {
} else if (name == "instance_name") {
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
ASSERT_EQ(converted_ret->ToString(), "test");
} else if (name == "index") {
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
ASSERT_EQ(converted_ret->ToString(), "0");
} else {
MS_LOG(EXCEPTION) << "Test failed";
}


+ 2
- 1
tests/ut/data/dataset/declient.cfg View File

@@ -4,6 +4,7 @@
"numParallelWorkers": 4,
"workerConnectorSize": 16,
"opConnectorSize": 16,
"seed": 5489
"seed": 5489,
"monitor_sampling_interval": 15

}

BIN
tests/ut/data/dataset/golden/bounding_box_augment_crop_c_result.npz View File


BIN
tests/ut/data/dataset/golden/bounding_box_augment_rotation_c_result.npz View File


BIN
tests/ut/data/dataset/golden/bounding_box_augment_valid_edge_c_result.npz View File


BIN
tests/ut/data/dataset/golden/bounding_box_augment_valid_ratio_c_result.npz View File


BIN
tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz View File


BIN
tests/ut/data/dataset/golden/random_horizontal_flip_with_bbox_01_c_result.npz View File


BIN
tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_result.npz View File


BIN
tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz View File


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save