You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

meta_tensor.cc 3.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ir/meta_tensor.h"
  17. #include <functional>
  18. #include <numeric>
  19. #include <vector>
  20. #include <sstream>
  21. #include <string>
  22. namespace mindspore {
  23. namespace tensor {
  24. // MetaTensor has default type_id_ which is TypeId::kTypeUnknown.
  25. MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {}
  26. MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int> &shape) : data_type_(data_type), shape_(shape) {}
  27. MetaTensor::MetaTensor(const TypePtr &type_ptr, const std::vector<int> &shape) {
  28. TypeId data_type = TypeId::kTypeUnknown;
  29. if (type_ptr != nullptr) {
  30. data_type = type_ptr->type_id();
  31. }
  32. data_type_ = data_type;
  33. shape_ = shape;
  34. }
  35. MetaTensor::MetaTensor(const MetaTensor &meta_tensor)
  36. : Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {}
  37. MetaTensor &MetaTensor::operator=(const MetaTensor &meta_tensor) {
  38. if (&meta_tensor == this) {
  39. return *this;
  40. }
  41. data_type_ = meta_tensor.data_type();
  42. shape_ = meta_tensor.shape();
  43. device_info_ = meta_tensor.device_info();
  44. return *this;
  45. }
  46. bool MetaTensor::operator==(const MetaTensor &meta_tensor) const {
  47. return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape();
  48. }
  49. // Get the size of a given dimension by its index number.
  50. // The given index number should be in [0, shape_.size()).
  51. // param index Dimension index number.
  52. // return The size of the dimension if succeed, or -1 if failed.
  53. int MetaTensor::DimensionSize(const size_t index) const {
  54. int dim_size = -1;
  55. if (index < shape_.size()) {
  56. dim_size = shape_[index];
  57. } else {
  58. MS_LOG(ERROR) << "Dimension index is wrong: " << index;
  59. }
  60. return dim_size;
  61. }
  62. int MetaTensor::ElementsNum() const {
  63. return std::accumulate(shape_.begin(), shape_.end(), 1LL, std::multiplies<int>());
  64. }
  65. TypePtr MetaTensor::Dtype() const { return TypeIdToType(data_type_); }
  66. TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) {
  67. if (type_ptr == nullptr) {
  68. MS_LOG(ERROR) << "Dtype to be set is nullptr.";
  69. return nullptr;
  70. }
  71. (void)set_data_type(type_ptr->type_id());
  72. return type_ptr;
  73. }
  74. void MetaTensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type) {
  75. DeviceInfo info(format, data_type);
  76. set_device_info(info);
  77. }
  78. std::string MetaTensor::ToString() const {
  79. std::ostringstream buf;
  80. buf << "MetaTensor shape:[" << shape() << "]";
  81. return buf.str();
  82. }
  83. std::string MetaTensor::DumpText() const {
  84. std::ostringstream oss;
  85. oss << type_name() << "(" << SizeToInt(data_type_) << ")[";
  86. for (size_t i = 0; i < shape_.size(); ++i) {
  87. oss << (i > 0 ? ", " : "") << shape_[i];
  88. }
  89. oss << "]";
  90. return oss.str();
  91. }
  92. } // namespace tensor
  93. } // namespace mindspore