You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base_ref.cc 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/base_ref.h"
  17. namespace mindspore {
  18. iterator ConstIteratorCast(std::vector<BaseRef> *v, const const_iterator iter) {
  19. return std::next(v->begin(), std::distance(v->cbegin(), iter));
  20. }
  21. BaseRef::BaseRef(const BaseRef &other) : Base(other), m_ptr(other.m_ptr) {
  22. if (!m_ptr) {
  23. m_ptr = other.copy();
  24. }
  25. }
  26. bool BaseRef::operator==(const BaseRef &other) const {
  27. if (m_ptr == other.m_ptr) {
  28. return true;
  29. }
  30. if (m_ptr == nullptr && other.m_ptr == nullptr) {
  31. return *this == other;
  32. }
  33. if (m_ptr == nullptr || other.m_ptr == nullptr) {
  34. return false;
  35. }
  36. if (type() != other.type()) {
  37. MS_LOG(DEBUG) << "Type mismatch";
  38. return false;
  39. }
  40. if (m_ptr->isa<Value>()) {
  41. return *(m_ptr->cast<ValuePtr>()) == *(other.m_ptr->cast<ValuePtr>());
  42. }
  43. // for noderef equal
  44. if (m_ptr->isa<BaseRef>()) {
  45. return *std::static_pointer_cast<BaseRef>(m_ptr) == *std::static_pointer_cast<BaseRef>(other.m_ptr);
  46. }
  47. // for node equal
  48. return *m_ptr == *other.m_ptr;
  49. }
  50. // left reference
  51. BaseRef &BaseRef::operator=(const BaseRef &other) {
  52. if ((m_ptr != nullptr && m_ptr == other.m_ptr) || this == &other) {
  53. return *this;
  54. }
  55. m_ptr = other.copy();
  56. return *this;
  57. }
  58. // right reference
  59. BaseRef &BaseRef::operator=(BaseRef &&other) {
  60. if ((m_ptr != nullptr && m_ptr == other.m_ptr) || this == &other) {
  61. return *this;
  62. }
  63. m_ptr = other.copy();
  64. other.m_ptr = nullptr;
  65. return *this;
  66. }
  67. std::string BaseRef::ToString() const {
  68. if (m_ptr != nullptr) {
  69. return std::string(m_ptr->type_name()) + std::string(" value:") + m_ptr->ToString();
  70. }
  71. return std::string();
  72. }
  73. uint32_t BaseRef::type() const {
  74. if (m_ptr != nullptr) {
  75. return m_ptr->tid();
  76. }
  77. return tid();
  78. }
  79. // left reference
  80. SetRef &SetRef::operator=(const SetRef &other) {
  81. if (elements_ == other.elements_ || this == &other) {
  82. return *this;
  83. }
  84. elements_ = other.elements_;
  85. return *this;
  86. }
  87. std::string SetRef::ToString() const {
  88. std::ostringstream buffer;
  89. bool begin = true;
  90. buffer << "set[";
  91. for (auto &attr : elements_) {
  92. if (!begin) {
  93. buffer << ", ";
  94. } else {
  95. begin = false;
  96. }
  97. buffer << attr.ToString();
  98. }
  99. buffer << "]";
  100. return buffer.str();
  101. }
  102. // left reference
  103. VectorRef &VectorRef::operator=(const VectorRef &other) {
  104. if (elements_ == other.elements_ || this == &other) {
  105. return *this;
  106. }
  107. elements_ = other.elements_;
  108. return *this;
  109. }
  110. std::string VectorRef::ToString() const {
  111. std::ostringstream buffer;
  112. bool begin = true;
  113. buffer << "vector[";
  114. for (auto &attr : elements_) {
  115. if (!begin) {
  116. buffer << ", ";
  117. } else {
  118. begin = false;
  119. }
  120. buffer << attr.ToString();
  121. }
  122. buffer << "]";
  123. return buffer.str();
  124. }
  125. bool VectorRef::operator==(const BaseRef &other) const {
  126. if (!utils::isa<VectorRef>(other)) {
  127. return false;
  128. }
  129. return *this == utils::cast<VectorRef>(other);
  130. }
  131. bool VectorRef::operator==(const VectorRef &other) const {
  132. if (elements_.size() != other.elements_.size()) {
  133. return false;
  134. }
  135. for (size_t i = 0; i < elements_.size(); ++i) {
  136. if (elements_[i] != other.elements_[i]) {
  137. return false;
  138. }
  139. }
  140. return true;
  141. }
  142. bool SetRef::operator==(const BaseRef &other) const {
  143. if (!utils::isa<SetRef>(other)) {
  144. return false;
  145. }
  146. return *this == utils::cast<SetRef>(other);
  147. }
  148. bool SetRef::operator==(const SetRef &other) const {
  149. if (elements_.size() != other.elements_.size()) {
  150. return false;
  151. }
  152. auto iter = elements_.begin();
  153. auto oth_iter = other.elements_.begin();
  154. for (; iter != elements_.end(); iter++, oth_iter++) {
  155. if (*iter != *oth_iter) {
  156. return false;
  157. }
  158. }
  159. return true;
  160. }
  161. bool RunFunctionRef::operator==(const BaseRef &other) const {
  162. if (!utils::isa<RunFunctionRef>(other)) {
  163. return false;
  164. }
  165. return *this == utils::cast<RunFunctionRef>(other);
  166. }
  167. bool RunFunctionRef::operator==(const RunFunctionRef &other) const { return func_ == other.func_; }
  168. } // namespace mindspore