You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.cc 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <vector>
  17. #include <utility>
  18. #include "ir/lite/tensor.h"
  19. #include "securec/include/securec.h"
  20. namespace mindspore {
  21. namespace tensor {
  22. #define kMaxMallocSize 1024 * 1024 * 100
  23. Tensor::Tensor(const TypeId data_type, const std::vector<int> &shape) : MetaTensor(data_type, shape) {}
  24. Tensor::Tensor(const TypePtr &type_ptr, const std::vector<int> &shape) : MetaTensor(type_ptr, shape) {}
  25. Tensor::Tensor(const Tensor &tensor) : MetaTensor(tensor) {
  26. this->data_type_ = tensor.data_type_;
  27. this->shape_ = tensor.shape_;
  28. auto ret = CopyTensorData(tensor);
  29. if (0 != ret) {
  30. MS_LOG(EXCEPTION) << "CopyTensorData error";
  31. }
  32. }
  33. int Tensor::CopyTensorData(const Tensor &srcTensor) {
  34. if (srcTensor.data_ == nullptr) {
  35. MS_LOG(ERROR) << "data of srcTensor is nullptr";
  36. return -1;
  37. }
  38. size_t data_size = this->Size();
  39. MS_ASSERT(data_size == tensor.Size());
  40. if (this->data_ == nullptr) {
  41. if (data_size > kMaxMallocSize) {
  42. MS_LOG(ERROR) << "Malloc size is too big while coping data, " << data_size << " bytes";
  43. return -1;
  44. }
  45. this->data_ = malloc(data_size);
  46. }
  47. memcpy_s(this->data_, data_size, tensor.data_, tensor.Size());
  48. return 0;
  49. }
  50. Tensor::~Tensor() {
  51. if (nullptr != this->data_) {
  52. free(this->data_);
  53. }
  54. }
  55. Tensor &Tensor::operator=(const Tensor &tensor) {
  56. if (&tensor == this) {
  57. return *this;
  58. }
  59. this->shape_ = tensor.shape_;
  60. this->data_type_ = tensor.data_type_;
  61. auto ret = CopyTensorData(tensor);
  62. if (0 != ret) {
  63. MS_LOG(EXCEPTION) << "CopyTensorData error";
  64. }
  65. return *this;
  66. }
  67. bool Tensor::operator==(const Tensor &tensor) {
  68. return data_ == tensor.data_ && shape_ == tensor.shape_ && data_type_ == tensor.data_type_;
  69. }
  70. bool Tensor::operator==(const Value &other) const {
  71. if (other.isa<Tensor>()) {
  72. auto other_ = static_cast<const Tensor &>(other);
  73. return *this == other_;
  74. } else {
  75. return false;
  76. }
  77. }
  78. } // namespace tensor
  79. namespace inference {
  80. MSTensor *MSTensor::CreateTensor(TypeId data_type, const std::vector<int> &shape) {
  81. return new Tensor(data_type, shape);
  82. }
  83. Tensor::Tensor() { this->tensor_impl_ = std::make_shared<tensor::Tensor>(); }
  84. Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) {
  85. this->tensor_impl_ = std::make_shared<tensor::Tensor>(data_type, shape);
  86. }
  87. Tensor::Tensor(std::shared_ptr<tensor::Tensor> tensor_ptr) { this->tensor_impl_ = std::move(tensor_ptr); }
  88. TypeId Tensor::data_type() const {
  89. MS_ASSERT(this->tensor_impl_ != nullptr);
  90. return this->tensor_impl_->data_type();
  91. }
  92. TypeId Tensor::set_data_type(TypeId data_type) {
  93. MS_ASSERT(this->tensor_impl_ != nullptr);
  94. return this->tensor_impl_->set_data_type(data_type);
  95. }
  96. std::vector<int> Tensor::shape() const {
  97. MS_ASSERT(this->tensor_impl_ != nullptr);
  98. return this->tensor_impl_->shape();
  99. }
  100. size_t Tensor::set_shape(const std::vector<int> &shape) {
  101. MS_ASSERT(this->tensor_impl_ != nullptr);
  102. return this->tensor_impl_->set_shape(shape);
  103. }
  104. int Tensor::DimensionSize(size_t index) const {
  105. MS_ASSERT(this->tensor_impl_ != nullptr);
  106. return this->tensor_impl_->DimensionSize(index);
  107. }
  108. int Tensor::ElementsNum() const {
  109. MS_ASSERT(this->tensor_impl_ != nullptr);
  110. return this->tensor_impl_->ElementsNum();
  111. }
  112. std::size_t Tensor::hash() const {
  113. MS_ASSERT(this->tensor_impl_ != nullptr);
  114. return this->tensor_impl_->hash();
  115. }
  116. std::shared_ptr<tensor::Tensor> Tensor::tensor() const {
  117. MS_ASSERT(this->tensor_impl_ != nullptr);
  118. return this->tensor_impl_;
  119. }
  120. size_t Tensor::Size() const {
  121. MS_ASSERT(this->tensor_impl_ != nullptr);
  122. return this->tensor_impl_->Size();
  123. }
  124. void *Tensor::MutableData() const {
  125. MS_ASSERT(this->tensor_impl_ != nullptr);
  126. return this->tensor_impl_->data();
  127. }
  128. } // namespace inference
  129. } // namespace mindspore