You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

tensor_array.cc 4.8 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/device/tensor_array.h"
  17. namespace mindspore {
  18. namespace device {
  19. bool TensorArray::CheckValue(const TypeId &dtype, const std::vector<size_t> &shape) {
  20. MS_LOG(DEBUG) << "Check the data shape and type for " << name_;
  21. if (dtype != dtype_->type_id()) {
  22. MS_LOG(ERROR) << "Invalid data type " << TypeIdLabel(dtype) << " for " << name_ << ", the origin type is "
  23. << TypeIdLabel(dtype_->type_id());
  24. return false;
  25. }
  26. if (shape != shapes_) {
  27. MS_LOG(ERROR) << "Invalid data shape " << shape << " for " << name_ << ", the origin shape is " << shapes_;
  28. return false;
  29. }
  30. return true;
  31. }
  32. bool TensorArray::CheckReadIndexLogical(const int64_t index) {
  33. if (LongToSize(index) >= valid_size_) {
  34. MS_LOG(ERROR) << "Index " << index << " out of range " << valid_size_ << ", " << name_;
  35. return false;
  36. }
  37. return true;
  38. }
  39. // Function Read() can get the tensors in the scope of tensors_.
  40. mindspore::kernel::AddressPtr TensorArray::Read(const int64_t index) {
  41. if (LongToSize(index) >= tensors_.size()) {
  42. MS_LOG(EXCEPTION) << "Index " << index << " out of range " << tensors_.size() << ", " << name_;
  43. }
  44. MS_LOG(DEBUG) << "Read tensor index = " << index << ", addr = " << tensors_[LongToSize(index)]->addr;
  45. return tensors_[LongToSize(index)];
  46. }
// Add tensor to the TensorArray and increase the size.
// Case 1: is_dynamic = False and index >= max_size_, error.
// Case 2: index > valid_size, fill the rest dev_value with zeros, and set valid_size to index + 1.
// Case 3: index == tensors_.size(), we need to increase both real tensors_ size and valid size, and add
// the new dev_value to tensors_.
// Case 4: tensors_size() > index > valid_size, we can reuse the memory in tensors_[index], so
// only increase the valid_size.
bool TensorArray::Write(const int64_t index, const mindspore::kernel::AddressPtr &dev_value) {
  MS_LOG(DEBUG) << "Write dev_value to " << name_;
  // Fixed-size arrays reject writes at or past the configured capacity.
  if (!is_dynamic_ && (index >= max_size_)) {
    MS_LOG(ERROR) << name_ << " is not in dynamic size, the max_size is " << max_size_ << ", but get index " << index;
    return false;
  }
  if (LongToSize(index) > valid_size_) {
    // Create/reuse (index - valid_size) size dev_value with zeros.
    // 1 create new mem : index > real_size ? index - real_size : 0
    // 2 reuse old mem : index > real_size ? real_size - valid_size : index - valid_size
    // 3 fill zeros : index - valid_size
    size_t create_size = (LongToSize(index) > tensors_.size()) ? (LongToSize(index) - tensors_.size()) : 0;
    // Allocate new device buffers so slots [tensors_.size(), index) exist; each gets
    // the same byte size as the value being written.
    for (size_t i = 0; i < create_size; i++) {
      kernel::AddressPtr create_dev = std::make_shared<kernel::Address>();
      create_dev->addr = CreateMemory(dev_value->size);
      create_dev->size = dev_value->size;
      tensors_.push_back(create_dev);
    }
    // NOTE(review): dev_value is appended at position tensors_.size(). When
    // valid_size_ < index < tensors_.size() (reachable after Clear(), which keeps
    // tensors_ but zeroes valid_size_), this lands at a slot other than `index`
    // while valid_size_ still becomes index + 1 — confirm whether tensors_[index]
    // should receive the value in that path.
    tensors_.push_back(dev_value);
    // Zero-fill the skipped gap [valid_size_, index) so reads of those slots see zeros.
    for (size_t i = valid_size_; i < LongToSize(index); i++) {
      auto tensor_size = tensors_[i]->size;
      ClearMemory(tensors_[i]->addr, tensor_size);
    }
    valid_size_ = LongToSize(index) + 1;
  } else if (LongToSize(index) == tensors_.size()) {
    // Appending exactly at the end: grow both the backing store and the valid size.
    MS_LOG(DEBUG) << "Write to index " << index << ", increase tensors' size to " << (tensors_.size() + 1);
    tensors_.push_back(dev_value);
    valid_size_++;
  } else {
    // index <= valid_size_ with a slot already allocated: the existing device buffer
    // at tensors_[index] is reused as-is — the incoming dev_value is NOT stored here.
    MS_LOG(DEBUG) << "Reuse tensors in position " << index << ", tensors size is " << tensors_.size();
    if (LongToSize(index) == valid_size_) valid_size_++;
  }
  return true;
}
  88. void TensorArray::Clear() {
  89. valid_size_ = 0;
  90. return;
  91. }
  92. void TensorArray::Free() {
  93. MS_LOG(DEBUG) << "Free device memory for " << name_;
  94. for (const auto &addr : tensors_) {
  95. if (addr != nullptr) {
  96. ReleaseMemory(static_cast<DeviceMemPtr>(addr->addr));
  97. }
  98. }
  99. }
// Number of logically valid elements (may be smaller than the backing store after Clear()).
size_t TensorArray::GetValidSize() const { return valid_size_; }
// Number of allocated slots in the backing store, including ones invalidated by Clear().
size_t TensorArray::GetRealSize() const { return tensors_.size(); }
// Raw device address of the tensor stored at `index`.
// NOTE(review): no bounds check here (unlike Read()) — caller must ensure index < GetRealSize().
const void *TensorArray::GetTensorAddr(const size_t &index) const { return tensors_[index]->addr; }
  103. void TensorArray::SetMaxSize(const int64_t size, const bool is_dynamic) {
  104. is_dynamic_ = is_dynamic;
  105. if (!is_dynamic_) {
  106. max_size_ = size;
  107. MS_LOG(DEBUG) << name_ << " use fixed size " << max_size_;
  108. }
  109. return;
  110. }
  111. } // namespace device
  112. } // namespace mindspore