You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base_ref.h 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_H_
  17. #define MINDSPORE_CCSRC_UTILS_BASE_REF_H_
#include <algorithm>
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

#include "ir/value.h"
  27. #include "ir/value.h"
  28. namespace mindspore {
  29. class BaseRef;
  30. class VectorRef;
  31. class SetRef;
  32. class RunFunctionRef;
  33. using iterator = std::vector<BaseRef>::iterator;
  34. using const_iterator = std::vector<BaseRef>::const_iterator;
  35. using const_reverse_iterator = std::vector<BaseRef>::const_reverse_iterator;
  36. using RunFunc = std::function<VectorRef(const VectorRef &args)>;
  37. using RunFuncPtr = std::shared_ptr<RunFunc>;
  38. template <typename T>
  39. using remove_reference_t = typename std::remove_reference<T>::type;
  40. template <typename T>
  41. using remove_const_t = typename std::remove_const<T>::type;
  42. template <typename T>
  43. using is_base = std::is_base_of<Base, remove_reference_t<T>>;
  44. template <typename T>
  45. using is_value = std::is_base_of<Value, remove_reference_t<T>>;
  46. template <typename T>
  47. using is_base_ref = std::is_base_of<BaseRef, remove_reference_t<T>>;
  48. iterator ConstIteratorCast(std::vector<BaseRef> *v, const_iterator iter);
  49. inline std::shared_ptr<VectorRef> MakeNode(const std::vector<BaseRef> &elements) {
  50. return std::make_shared<VectorRef>(elements);
  51. }
  52. inline std::shared_ptr<VectorRef> MakeNode(std::initializer_list<BaseRef> elements) {
  53. return std::make_shared<VectorRef>(elements);
  54. }
  55. // Anfnode, Funcgraph and some not value node class
  56. template <typename T,
  57. typename std::enable_if<is_shared_ptr<remove_const_t<T>>::value && is_base<typename T::element_type>::value,
  58. int>::type = 0>
  59. inline BasePtr MakeNode(const T &v) {
  60. return v;
  61. }
  62. template <typename T,
  63. typename std::enable_if<!is_shared_ptr<remove_const_t<T>>::value && !is_base_ref<T>::value, int>::type = 0>
  64. inline BasePtr MakeNode(const T &v) {
  65. return MakeValue(v);
  66. }
  67. inline std::shared_ptr<VectorRef> MakeNode(const VectorRef &a) { return std::make_shared<VectorRef>(std::move(a)); }
  68. inline std::shared_ptr<VectorRef> MakeNode(const AnfNodePtrList &a) {
  69. std::vector<BaseRef> ret;
  70. (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr &v) { return v; });
  71. return std::make_shared<VectorRef>(ret);
  72. }
  73. inline std::shared_ptr<SetRef> MakeNode(const SetRef &a) { return std::make_shared<SetRef>(std::move(a)); }
  74. inline std::shared_ptr<RunFunctionRef> MakeNode(const RunFuncPtr &a) { return std::make_shared<RunFunctionRef>(a); }
  75. class BaseRef : public Base {
  76. public:
  77. BaseRef() : m_ptr(nullptr) {}
  78. BaseRef(const BaseRef &other);
  79. virtual std::shared_ptr<Base> copy() const { return m_ptr; }
  80. BaseRef(BaseRef &&other) : Base(other) {
  81. m_ptr = other.m_ptr;
  82. other.m_ptr = nullptr;
  83. }
  84. // right reference constructor
  85. template <class T,
  86. class = typename std::enable_if<!std::is_same<typename std::decay<T>::type, BaseRef>::value, T>::type>
  87. BaseRef(T &&t) { // NOLINT
  88. m_ptr = MakeNode(t);
  89. }
  90. ~BaseRef() override { m_ptr = nullptr; }
  91. MS_DECLARE_PARENT(BaseRef, Base)
  92. bool operator!=(const BaseRef &other) const { return !(operator==(other)); }
  93. virtual bool operator==(const BaseRef &other) const;
  94. // left reference
  95. virtual BaseRef &operator=(const BaseRef &other);
  96. // right reference
  97. virtual BaseRef &operator=(BaseRef &&other);
  98. std::size_t hash() const override {
  99. if (m_ptr == nullptr) {
  100. MS_LOG(ERROR) << "Invalid m_ptr";
  101. return 0;
  102. }
  103. return m_ptr->hash();
  104. }
  105. std::string ToString() const override;
  106. bool is_null() const { return m_ptr == nullptr; }
  107. virtual uint32_t type() const;
  108. BasePtr m_ptr; // point to real data
  109. };
  110. using BaseRefPtr = std::shared_ptr<BaseRef>;
  111. struct BaseRefHash {
  112. std::size_t operator()(const BaseRef &c) const { return c.hash(); }
  113. };
  114. struct BaseRefLess {
  115. bool operator()(const BaseRef &a, const BaseRef &b) const { return a.hash() < b.hash(); }
  116. };
  117. namespace utils {
  118. // judge isa relation
  119. // examples: isa<Int32Imm>(handle), isa<FuncGraph>(handle)
  120. template <typename T, typename std::enable_if<is_base<T>::value && !is_base_ref<T>::value, int>::type = 0>
  121. bool isa(const BaseRef &handle) {
  122. if (!handle.m_ptr) {
  123. return false;
  124. }
  125. return handle.m_ptr->isa<T>();
  126. }
  127. // noderef isa ptr isa<AnfNodePtr>(x) or isa<SeqPtr>()
  128. template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type,
  129. typename std::enable_if<is_base<U>::value || is_base_ref<U>::value, int>::type = 0>
  130. bool isa(const BaseRef &handle) {
  131. if (handle.m_ptr == nullptr) {
  132. return typeid(handle.m_ptr) == typeid(T);
  133. }
  134. if (handle.m_ptr->isa<U>()) {
  135. return true;
  136. }
  137. // constptr isa<anfnodeptr> can be true
  138. return std::dynamic_pointer_cast<U>(handle.m_ptr) != nullptr;
  139. }
  140. // isa<int32>(handle)
  141. template <typename S, typename U = typename ImmTraits<S>::type::element_type>
  142. bool isa(const BaseRef &handle) {
  143. if (handle.m_ptr == nullptr) {
  144. return false;
  145. }
  146. return handle.m_ptr->isa<U>();
  147. }
  148. // isa<BaseRef>(handle), judge reference or ptr
  149. template <typename T, typename std::enable_if<is_base_ref<T>::value, int>::type = 0>
  150. bool isa(const BaseRef &handle) {
  151. static const uint32_t tid = Base::GetTypeId(typeid(T).name());
  152. return handle.IsFromTypeId(tid) || (handle.m_ptr && handle.m_ptr->isa<T>());
  153. }
  154. // valueref -> C++ type
  155. // cast<int>(handle)
  156. template <typename T, typename std::enable_if<!is_base_ref<T>::value && !is_shared_ptr<T>::value, int>::type = 0>
  157. T cast(const BaseRef &handle) {
  158. T ret = GetValue<T>(std::static_pointer_cast<Value>(handle.m_ptr));
  159. return std::move(ret);
  160. }
  161. // valueref -> valueref type
  162. // cast<VectorRef>(handle)
  163. template <typename T, typename std::enable_if<is_base_ref<T>::value, int>::type = 0>
  164. const T &cast(const BaseRef &handle) {
  165. if (handle.m_ptr) {
  166. return static_cast<const T &>(*handle.m_ptr);
  167. }
  168. return std::move(static_cast<const T &>(handle));
  169. }
  170. // valueref -> nodeptr type
  171. // cast<FuncGraphPtr>(handle)
  172. template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type,
  173. typename std::enable_if<is_shared_ptr<T>::value && std::is_base_of<Base, typename T::element_type>::value,
  174. int>::type = 0>
  175. T cast(const BaseRef &handle) {
  176. if (!handle.m_ptr) {
  177. MS_LOG(EXCEPTION) << "Can not cast to " << typeid(T).name() << ", pointer is null";
  178. }
  179. auto m = handle.m_ptr->cast<T>();
  180. if (nullptr != m) {
  181. return m;
  182. }
  183. return std::static_pointer_cast<U>(handle.m_ptr);
  184. }
  185. } // namespace utils
  186. class VectorRef : public BaseRef {
  187. public:
  188. using value_type = BaseRef;
  189. VectorRef() {}
  190. explicit VectorRef(const std::vector<BaseRef> &elements) : elements_(elements) {}
  191. VectorRef(const const_iterator &begin, const const_iterator &end) : elements_(begin, end) {}
  192. // left reference
  193. virtual VectorRef &operator=(const VectorRef &other);
  194. ~VectorRef() override = default;
  195. std::shared_ptr<Base> copy() const override { return std::make_shared<VectorRef>(elements_); }
  196. bool empty() const { return (elements_.size() == 0); }
  197. std::size_t size() const { return elements_.size(); }
  198. MS_DECLARE_PARENT(VectorRef, BaseRef)
  199. const BaseRef &operator[](const std::size_t &dim) const {
  200. if (dim >= size()) {
  201. MS_LOG(EXCEPTION) << "Out of the size of the tuple.";
  202. }
  203. return elements_[dim];
  204. }
  205. BaseRef &operator[](const std::size_t &dim) {
  206. if (dim >= size()) {
  207. MS_LOG(EXCEPTION) << "Out of the size of the tuple.";
  208. }
  209. return elements_[dim];
  210. }
  211. uint32_t type() const override { return tid(); }
  212. std::string ToString() const override;
  213. std::vector<BaseRef> &elements() { return elements_; }
  214. void clear() { elements_.clear(); }
  215. bool operator==(const BaseRef &other) const override;
  216. bool operator==(const VectorRef &other) const;
  217. void push_back(const BaseRef &value) { elements_.push_back(value); }
  218. void push_back(BaseRef &&value) { elements_.push_back(value); }
  219. void emplace_back(const BaseRef &value) { elements_.emplace_back(value); }
  220. void emplace_back(BaseRef &&value) { elements_.emplace_back(value); }
  221. template <class InputIt>
  222. void insert(const iterator pos, const InputIt first, const InputIt last) {
  223. (void)elements_.insert(pos, first, last);
  224. }
  225. template <class InputIt>
  226. void insert(const const_iterator cpos, const InputIt first, const InputIt last) {
  227. auto pos = ConstIteratorCast(&elements_, cpos);
  228. (void)elements_.insert(pos, first, last);
  229. }
  230. const_iterator begin() const { return elements_.begin(); }
  231. const_iterator end() const { return elements_.end(); }
  232. const_reverse_iterator rbegin() const { return elements_.rbegin(); }
  233. const_reverse_iterator rend() const { return elements_.rend(); }
  234. iterator erase(const const_iterator cpos) {
  235. auto pos = ConstIteratorCast(&elements_, cpos);
  236. return elements_.erase(pos);
  237. }
  238. iterator erase(const const_iterator cfirst, const const_iterator clast) {
  239. auto first = ConstIteratorCast(&elements_, cfirst);
  240. auto last = ConstIteratorCast(&elements_, clast);
  241. return elements_.erase(first, last);
  242. }
  243. std::size_t hash() const override {
  244. std::stringstream buffer;
  245. buffer << ToString();
  246. return std::hash<std::string>()(buffer.str());
  247. }
  248. std::vector<BaseRef> elements_;
  249. };
  250. using VectorRefPtr = std::shared_ptr<VectorRef>;
  251. using set_iterator = std::set<BaseRef, BaseRefLess>::iterator;
  252. using const_set_iterator = std::set<BaseRef, BaseRefLess>::const_iterator;
  253. struct VectorRefHash {
  254. std::size_t operator()(const VectorRef &c) const { return c.hash(); }
  255. };
  256. class SetRef : public BaseRef {
  257. public:
  258. SetRef() {}
  259. explicit SetRef(const std::set<BaseRef, BaseRefLess> &elements) : elements_(elements) {}
  260. SetRef(const std::initializer_list<BaseRef> elements) : elements_(elements.begin(), elements.end()) {}
  261. SetRef(const const_set_iterator &begin, const const_set_iterator &end) : elements_(begin, end) {}
  262. // left reference
  263. virtual SetRef &operator=(const SetRef &other);
  264. bool operator==(const BaseRef &other) const override;
  265. bool operator==(const SetRef &other) const;
  266. ~SetRef() override = default;
  267. std::shared_ptr<Base> copy() const override { return std::make_shared<SetRef>(elements_); }
  268. bool empty() const { return (elements_.size() == 0); }
  269. std::size_t size() const { return elements_.size(); }
  270. MS_DECLARE_PARENT(SetRef, BaseRef)
  271. uint32_t type() const override { return tid(); }
  272. std::string ToString() const override;
  273. std::set<BaseRef, BaseRefLess> &elements() { return elements_; }
  274. void clear() { elements_.clear(); }
  275. void insert(const BaseRef &elem) { (void)elements_.insert(elem); }
  276. const_set_iterator begin() const { return elements_.begin(); }
  277. const_set_iterator end() const { return elements_.end(); }
  278. template <class InputIt>
  279. void insert(const InputIt first, const InputIt last) {
  280. (void)elements_.insert(first, last);
  281. }
  282. std::size_t count(const BaseRef &elem) const { return elements_.count(elem); }
  283. const_set_iterator find(const BaseRef &elem) const { return elements_.find(elem); }
  284. std::set<BaseRef, BaseRefLess> elements_;
  285. };
  286. using SetRefPtr = std::shared_ptr<SetRef>;
  287. class RunFunctionRef : public BaseRef {
  288. public:
  289. RunFunctionRef() {}
  290. explicit RunFunctionRef(const RunFuncPtr &ref_func) : func_(ref_func) {}
  291. ~RunFunctionRef() override = default;
  292. MS_DECLARE_PARENT(RunFunctionRef, BaseRef)
  293. uint32_t type() const override { return tid(); }
  294. std::string ToString() const override { return std::string("RunFunctionRef"); }
  295. bool operator==(const BaseRef &other) const override;
  296. bool operator==(const RunFunctionRef &other) const;
  297. RunFuncPtr func_;
  298. };
  299. } // namespace mindspore
#endif  // MINDSPORE_CCSRC_UTILS_BASE_REF_H_