You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.cc 5.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/tensor.h"
  17. #include <securec.h>
  18. #include <functional>
  19. #include <utility>
  20. #include "common/log.h"
  21. namespace mindspore::serving {
  22. Tensor::Tensor() = default;
  23. Tensor::Tensor(DataType type, std::vector<int64_t> shape, const void *data, size_t data_len) {
  24. set_data_type(type);
  25. set_shape(shape);
  26. set_data(data, data_len);
  27. }
  28. const uint8_t *Tensor::data() const {
  29. if (data_size() == 0) {
  30. return nullptr;
  31. }
  32. return data_.data();
  33. }
  34. size_t Tensor::data_size() const { return data_.size(); }
  35. bool Tensor::resize_data(size_t data_len) {
  36. data_.resize(data_len);
  37. return true;
  38. }
  39. uint8_t *Tensor::mutable_data() {
  40. if (data_size() == 0) {
  41. return nullptr;
  42. }
  43. return data_.data();
  44. }
  45. // For kMSI_String and kMSI_Bytes
  46. void Tensor::clear_bytes_data() { bytes_.clear(); }
  47. void Tensor::add_bytes_data(const uint8_t *data, size_t bytes_len) {
  48. std::vector<uint8_t> bytes(bytes_len);
  49. memcpy_s(bytes.data(), bytes.size(), data, bytes_len);
  50. bytes_.push_back(std::move(bytes));
  51. }
  52. size_t Tensor::bytes_data_size() const { return bytes_.size(); }
  53. void Tensor::get_bytes_data(size_t index, const uint8_t **data, size_t *bytes_len) const {
  54. MSI_EXCEPTION_IF_NULL(data);
  55. MSI_EXCEPTION_IF_NULL(bytes_len);
  56. *bytes_len = bytes_[index].size();
  57. if (*bytes_len == 0) {
  58. *data = nullptr;
  59. } else {
  60. *data = bytes_[index].data();
  61. }
  62. }
  63. VectorTensorWrapReply::VectorTensorWrapReply(std::vector<Tensor> *tensor_list) : tensor_list_(tensor_list) {}
  64. VectorTensorWrapReply::~VectorTensorWrapReply() = default;
  65. size_t VectorTensorWrapReply::size() const {
  66. MSI_EXCEPTION_IF_NULL(tensor_list_);
  67. return tensor_list_->size();
  68. }
  69. void VectorTensorWrapReply::clear() {
  70. MSI_EXCEPTION_IF_NULL(tensor_list_);
  71. tensor_list_->clear();
  72. }
  73. TensorBase *VectorTensorWrapReply::operator[](size_t index) {
  74. MSI_EXCEPTION_IF_NULL(tensor_list_);
  75. if (index >= tensor_list_->size()) {
  76. MSI_LOG_EXCEPTION << "visit invalid index " << index << " total size " << tensor_list_->size();
  77. }
  78. return &((*tensor_list_)[index]);
  79. }
  80. const TensorBase *VectorTensorWrapReply::operator[](size_t index) const {
  81. MSI_EXCEPTION_IF_NULL(tensor_list_);
  82. if (index >= tensor_list_->size()) {
  83. MSI_LOG_EXCEPTION << "visit invalid index " << index << " total size " << tensor_list_->size();
  84. }
  85. return &((*tensor_list_)[index]);
  86. }
  87. TensorBase *VectorTensorWrapReply::add() {
  88. MSI_EXCEPTION_IF_NULL(tensor_list_);
  89. tensor_list_->push_back(Tensor());
  90. return &(tensor_list_->back());
  91. }
  92. const TensorBase *VectorTensorWrapRequest::operator[](size_t index) const {
  93. if (index >= tensor_list_.size()) {
  94. MSI_LOG_EXCEPTION << "visit invalid index " << index << " total size " << tensor_list_.size();
  95. }
  96. return &(tensor_list_[index]);
  97. }
  98. VectorTensorPtrWrapReply::VectorTensorPtrWrapReply(std::vector<TensorBasePtr> *tensor_list,
  99. std::function<TensorBasePtr()> create_fun)
  100. : tensor_list_(tensor_list), tensor_create_fun_(create_fun) {}
  101. VectorTensorPtrWrapReply::~VectorTensorPtrWrapReply() = default;
  102. size_t VectorTensorPtrWrapReply::size() const {
  103. MSI_EXCEPTION_IF_NULL(tensor_list_);
  104. return tensor_list_->size();
  105. }
  106. void VectorTensorPtrWrapReply::clear() {
  107. MSI_EXCEPTION_IF_NULL(tensor_list_);
  108. tensor_list_->clear();
  109. }
  110. TensorBase *VectorTensorPtrWrapReply::operator[](size_t index) {
  111. MSI_EXCEPTION_IF_NULL(tensor_list_);
  112. if (index >= tensor_list_->size()) {
  113. MSI_LOG_EXCEPTION << "visit invalid index " << index << " total size " << tensor_list_->size();
  114. }
  115. return (*tensor_list_)[index].get();
  116. }
  117. const TensorBase *VectorTensorPtrWrapReply::operator[](size_t index) const {
  118. MSI_EXCEPTION_IF_NULL(tensor_list_);
  119. if (index >= tensor_list_->size()) {
  120. MSI_LOG_EXCEPTION << "visit invalid index " << index << " total size " << tensor_list_->size();
  121. }
  122. return (*tensor_list_)[index].get();
  123. }
  124. TensorBase *VectorTensorPtrWrapReply::add() {
  125. MSI_EXCEPTION_IF_NULL(tensor_list_);
  126. MSI_EXCEPTION_IF_NULL(tensor_create_fun_);
  127. auto tensor = tensor_create_fun_();
  128. if (tensor == nullptr) {
  129. MSI_LOG_EXCEPTION << "create tensor failed";
  130. }
  131. tensor_list_->push_back(tensor);
  132. return tensor.get();
  133. }
  134. const TensorBase *VectorTensorPtrWrapRequest::operator[](size_t index) const {
  135. if (index >= tensor_list_.size()) {
  136. MSI_LOG_EXCEPTION << "visit invalid index " << index << " total size " << tensor_list_.size();
  137. }
  138. return tensor_list_[index].get();
  139. }
  140. } // namespace mindspore::serving

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.