You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base_ref.h 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_H_
  17. #define MINDSPORE_CCSRC_UTILS_BASE_REF_H_
  18. #include <type_traits>
  19. #include <algorithm>
  20. #include <vector>
  21. #include <set>
  22. #include <string>
  23. #include <memory>
  24. #include <sstream>
  25. #include <utility>
  26. #include <iterator>
  27. #include "ir/value.h"
  28. namespace mindspore {
  29. class BaseRef;
  30. class VectorRef;
  31. class SetRef;
  32. class PyObjectRef;
  33. class RunFunctionRef;
  34. using iterator = std::vector<BaseRef>::iterator;
  35. using const_iterator = std::vector<BaseRef>::const_iterator;
  36. using const_reverse_iterator = std::vector<BaseRef>::const_reverse_iterator;
  37. using RunFunc = std::function<VectorRef(const VectorRef &args)>;
  38. using RunFuncPtr = std::shared_ptr<RunFunc>;
  39. template <typename T>
  40. using remove_reference_t = typename std::remove_reference<T>::type;
  41. template <typename T>
  42. using remove_const_t = typename std::remove_const<T>::type;
  43. template <typename T>
  44. using is_base = std::is_base_of<Base, remove_reference_t<T>>;
  45. template <typename T>
  46. using is_value = std::is_base_of<Value, remove_reference_t<T>>;
  47. template <typename T>
  48. using is_base_ref = std::is_base_of<BaseRef, remove_reference_t<T>>;
  49. iterator ConstIteratorCast(std::vector<BaseRef> *v, const_iterator iter);
  50. inline std::shared_ptr<VectorRef> MakeNode(const std::vector<BaseRef> &elements) {
  51. return std::make_shared<VectorRef>(elements);
  52. }
  53. inline std::shared_ptr<VectorRef> MakeNode(std::initializer_list<BaseRef> elements) {
  54. return std::make_shared<VectorRef>(elements);
  55. }
  56. // Anfnode, Funcgraph and some not value node class
  57. template <typename T,
  58. typename std::enable_if<is_shared_ptr<remove_const_t<T>>::value && is_base<typename T::element_type>::value,
  59. int>::type = 0>
  60. inline BasePtr MakeNode(const T &v) {
  61. return v;
  62. }
  63. template <typename T,
  64. typename std::enable_if<!is_shared_ptr<remove_const_t<T>>::value && !is_base_ref<T>::value, int>::type = 0>
  65. inline BasePtr MakeNode(const T &v) {
  66. return MakeValue(v);
  67. }
  68. inline std::shared_ptr<VectorRef> MakeNode(const VectorRef &a) { return std::make_shared<VectorRef>(std::move(a)); }
  69. inline std::shared_ptr<VectorRef> MakeNode(const AnfNodePtrList &a) {
  70. std::vector<BaseRef> ret;
  71. (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr &v) { return v; });
  72. return std::make_shared<VectorRef>(ret);
  73. }
  74. inline std::shared_ptr<SetRef> MakeNode(const SetRef &a) { return std::make_shared<SetRef>(std::move(a)); }
  75. inline std::shared_ptr<RunFunctionRef> MakeNode(const RunFuncPtr &a) { return std::make_shared<RunFunctionRef>(a); }
  76. inline std::shared_ptr<PyObjectRef> MakeNode(const py::object &a) { return std::make_shared<PyObjectRef>(a); }
  77. inline std::shared_ptr<PyObjectRef> MakeNode(const py::tuple &a) { return std::make_shared<PyObjectRef>(a); }
  78. class BaseRef : public Base {
  79. public:
  80. BaseRef() : m_ptr(nullptr) {}
  81. BaseRef(const BaseRef &other);
  82. virtual std::shared_ptr<Base> copy() const { return m_ptr; }
  83. BaseRef(BaseRef &&other) : Base(other) {
  84. m_ptr = other.m_ptr;
  85. other.m_ptr = nullptr;
  86. }
  87. // right reference constructor
  88. template <class T,
  89. class = typename std::enable_if<!std::is_same<typename std::decay<T>::type, BaseRef>::value, T>::type>
  90. BaseRef(T &&t) { // NOLINT
  91. m_ptr = MakeNode(t);
  92. }
  93. ~BaseRef() override { m_ptr = nullptr; }
  94. MS_DECLARE_PARENT(BaseRef, Base)
  95. bool operator!=(const BaseRef &other) const { return !(operator==(other)); }
  96. virtual bool operator==(const BaseRef &other) const;
  97. // left reference
  98. virtual BaseRef &operator=(const BaseRef &other);
  99. // right reference
  100. virtual BaseRef &operator=(BaseRef &&other);
  101. std::size_t hash() const override {
  102. if (m_ptr == nullptr) {
  103. MS_LOG(ERROR) << "Invalid m_ptr";
  104. return 0;
  105. }
  106. return m_ptr->hash();
  107. }
  108. std::string ToString() const override;
  109. bool is_null() const { return m_ptr == nullptr; }
  110. virtual uint32_t type() const;
  111. BasePtr m_ptr; // point to real data
  112. };
  113. using BaseRefPtr = std::shared_ptr<BaseRef>;
  114. struct BaseRefHash {
  115. std::size_t operator()(const BaseRef &c) const { return c.hash(); }
  116. };
  117. struct BaseRefLess {
  118. bool operator()(const BaseRef &a, const BaseRef &b) const { return a.hash() < b.hash(); }
  119. };
  120. namespace utils {
  121. // judge isa relation
  122. // examples: isa<Int32Imm>(handle), isa<FuncGraph>(handle)
  123. template <typename T, typename std::enable_if<is_base<T>::value && !is_base_ref<T>::value, int>::type = 0>
  124. bool isa(const BaseRef &handle) {
  125. if (!handle.m_ptr) {
  126. return false;
  127. }
  128. return handle.m_ptr->isa<T>();
  129. }
  130. // noderef isa ptr isa<AnfNodePtr>(x) or isa<SeqPtr>()
  131. template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type,
  132. typename std::enable_if<is_base<U>::value || is_base_ref<U>::value, int>::type = 0>
  133. bool isa(const BaseRef &handle) {
  134. if (handle.m_ptr == nullptr) {
  135. return typeid(handle.m_ptr) == typeid(T);
  136. }
  137. if (handle.m_ptr->isa<U>()) {
  138. return true;
  139. }
  140. // constptr isa<anfnodeptr> can be true
  141. return std::dynamic_pointer_cast<U>(handle.m_ptr) != nullptr;
  142. }
  143. // isa<int32>(handle)
  144. template <typename S, typename U = typename ImmTraits<S>::type::element_type>
  145. bool isa(const BaseRef &handle) {
  146. if (handle.m_ptr == nullptr) {
  147. return false;
  148. }
  149. return handle.m_ptr->isa<U>();
  150. }
  151. // isa<BaseRef>(handle), judge reference or ptr
  152. template <typename T, typename std::enable_if<is_base_ref<T>::value, int>::type = 0>
  153. bool isa(const BaseRef &handle) {
  154. static const uint32_t tid = Base::GetTypeId(typeid(T).name());
  155. return handle.IsFromTypeId(tid) || (handle.m_ptr && handle.m_ptr->isa<T>());
  156. }
  157. // valueref -> C++ type
  158. // cast<int>(handle)
  159. template <typename T, typename std::enable_if<!is_base_ref<T>::value && !is_shared_ptr<T>::value, int>::type = 0>
  160. T cast(const BaseRef &handle) {
  161. T ret = GetValue<T>(std::static_pointer_cast<Value>(handle.m_ptr));
  162. return std::move(ret);
  163. }
  164. // valueref -> valueref type
  165. // cast<VectorRef>(handle)
  166. template <typename T, typename std::enable_if<is_base_ref<T>::value, int>::type = 0>
  167. const T &cast(const BaseRef &handle) {
  168. if (handle.m_ptr) {
  169. return static_cast<const T &>(*handle.m_ptr);
  170. }
  171. return std::move(static_cast<const T &>(handle));
  172. }
  173. // valueref -> nodeptr type
  174. // cast<FuncGraphPtr>(handle)
  175. template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type,
  176. typename std::enable_if<is_shared_ptr<T>::value && std::is_base_of<Base, typename T::element_type>::value,
  177. int>::type = 0>
  178. T cast(const BaseRef &handle) {
  179. if (!handle.m_ptr) {
  180. MS_LOG(EXCEPTION) << "Can not cast to " << typeid(T).name() << ", pointer is null";
  181. }
  182. auto m = handle.m_ptr->cast<T>();
  183. if (nullptr != m) {
  184. return m;
  185. }
  186. return std::static_pointer_cast<U>(handle.m_ptr);
  187. }
  188. } // namespace utils
  189. class VectorRef : public BaseRef {
  190. public:
  191. using value_type = BaseRef;
  192. VectorRef() {}
  193. explicit VectorRef(const std::vector<BaseRef> &elements) : elements_(elements) {}
  194. VectorRef(const const_iterator &begin, const const_iterator &end) : elements_(begin, end) {}
  195. // left reference
  196. virtual VectorRef &operator=(const VectorRef &other);
  197. ~VectorRef() override = default;
  198. std::shared_ptr<Base> copy() const override { return std::make_shared<VectorRef>(elements_); }
  199. bool empty() const { return (elements_.size() == 0); }
  200. std::size_t size() const { return elements_.size(); }
  201. MS_DECLARE_PARENT(VectorRef, BaseRef)
  202. const BaseRef &operator[](const std::size_t &dim) const {
  203. if (dim >= size()) {
  204. MS_LOG(EXCEPTION) << "Out of the size of the tuple.";
  205. }
  206. return elements_[dim];
  207. }
  208. BaseRef &operator[](const std::size_t &dim) {
  209. if (dim >= size()) {
  210. MS_LOG(EXCEPTION) << "Out of the size of the tuple.";
  211. }
  212. return elements_[dim];
  213. }
  214. uint32_t type() const override { return tid(); }
  215. std::string ToString() const override;
  216. std::vector<BaseRef> &elements() { return elements_; }
  217. void clear() { elements_.clear(); }
  218. bool operator==(const BaseRef &other) const override;
  219. bool operator==(const VectorRef &other) const;
  220. void push_back(const BaseRef &value) { elements_.push_back(value); }
  221. void push_back(BaseRef &&value) { elements_.push_back(value); }
  222. void emplace_back(const BaseRef &value) { elements_.emplace_back(value); }
  223. void emplace_back(BaseRef &&value) { elements_.emplace_back(value); }
  224. template <class InputIt>
  225. void insert(const iterator pos, const InputIt first, const InputIt last) {
  226. (void)elements_.insert(pos, first, last);
  227. }
  228. template <class InputIt>
  229. void insert(const const_iterator cpos, const InputIt first, const InputIt last) {
  230. auto pos = ConstIteratorCast(&elements_, cpos);
  231. (void)elements_.insert(pos, first, last);
  232. }
  233. const_iterator begin() const { return elements_.begin(); }
  234. const_iterator end() const { return elements_.end(); }
  235. const_reverse_iterator rbegin() const { return elements_.rbegin(); }
  236. const_reverse_iterator rend() const { return elements_.rend(); }
  237. iterator erase(const const_iterator cpos) {
  238. auto pos = ConstIteratorCast(&elements_, cpos);
  239. return elements_.erase(pos);
  240. }
  241. iterator erase(const const_iterator cfirst, const const_iterator clast) {
  242. auto first = ConstIteratorCast(&elements_, cfirst);
  243. auto last = ConstIteratorCast(&elements_, clast);
  244. return elements_.erase(first, last);
  245. }
  246. std::size_t hash() const override {
  247. std::stringstream buffer;
  248. buffer << ToString();
  249. return std::hash<std::string>()(buffer.str());
  250. }
  251. std::vector<BaseRef> elements_;
  252. };
  253. using VectorRefPtr = std::shared_ptr<VectorRef>;
  254. using set_iterator = std::set<BaseRef, BaseRefLess>::iterator;
  255. using const_set_iterator = std::set<BaseRef, BaseRefLess>::const_iterator;
  256. struct VectorRefHash {
  257. std::size_t operator()(const VectorRef &c) const { return c.hash(); }
  258. };
  259. class SetRef : public BaseRef {
  260. public:
  261. SetRef() {}
  262. explicit SetRef(const std::set<BaseRef, BaseRefLess> &elements) : elements_(elements) {}
  263. SetRef(const std::initializer_list<BaseRef> elements) : elements_(elements.begin(), elements.end()) {}
  264. SetRef(const const_set_iterator &begin, const const_set_iterator &end) : elements_(begin, end) {}
  265. // left reference
  266. virtual SetRef &operator=(const SetRef &other);
  267. bool operator==(const BaseRef &other) const override;
  268. bool operator==(const SetRef &other) const;
  269. ~SetRef() override = default;
  270. std::shared_ptr<Base> copy() const override { return std::make_shared<SetRef>(elements_); }
  271. bool empty() const { return (elements_.size() == 0); }
  272. std::size_t size() const { return elements_.size(); }
  273. MS_DECLARE_PARENT(SetRef, BaseRef)
  274. uint32_t type() const override { return tid(); }
  275. std::string ToString() const override;
  276. std::set<BaseRef, BaseRefLess> &elements() { return elements_; }
  277. void clear() { elements_.clear(); }
  278. void insert(const BaseRef &elem) { (void)elements_.insert(elem); }
  279. const_set_iterator begin() const { return elements_.begin(); }
  280. const_set_iterator end() const { return elements_.end(); }
  281. template <class InputIt>
  282. void insert(const InputIt first, const InputIt last) {
  283. (void)elements_.insert(first, last);
  284. }
  285. std::size_t count(const BaseRef &elem) const { return elements_.count(elem); }
  286. const_set_iterator find(const BaseRef &elem) const { return elements_.find(elem); }
  287. std::set<BaseRef, BaseRefLess> elements_;
  288. };
  289. using SetRefPtr = std::shared_ptr<SetRef>;
  290. class PyObjectRef : public BaseRef {
  291. public:
  292. explicit PyObjectRef(const py::object &py_object) : object_(py_object) {}
  293. explicit PyObjectRef(const py::tuple &tuple_obj) : object_(tuple_obj) {}
  294. ~PyObjectRef() override = default;
  295. std::shared_ptr<Base> copy() const override { return std::make_shared<PyObjectRef>(object_); }
  296. MS_DECLARE_PARENT(PyObjectRef, BaseRef)
  297. uint32_t type() const override { return tid(); }
  298. std::string ToString() const override { return py::str(object_); }
  299. bool operator==(const BaseRef &other) const override;
  300. bool operator==(const PyObjectRef &other) const;
  301. py::object object_;
  302. };
  303. class RunFunctionRef : public BaseRef {
  304. public:
  305. RunFunctionRef() {}
  306. explicit RunFunctionRef(const RunFuncPtr &ref_func) : func_(ref_func) {}
  307. ~RunFunctionRef() override = default;
  308. MS_DECLARE_PARENT(RunFunctionRef, BaseRef)
  309. uint32_t type() const override { return tid(); }
  310. std::string ToString() const override { return std::string("RunFunctionRef"); }
  311. bool operator==(const BaseRef &other) const override;
  312. bool operator==(const RunFunctionRef &other) const;
  313. RunFuncPtr func_;
  314. };
  315. } // namespace mindspore
  316. #endif // MINDSPORE_CCSRC_UTILS_BASE_REF_H_