You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base_ref.h 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_H_
  17. #define MINDSPORE_CCSRC_UTILS_BASE_REF_H_
  18. #include <type_traits>
  19. #include <algorithm>
  20. #include <vector>
  21. #include <set>
  22. #include <string>
  23. #include <memory>
  24. #include <sstream>
  25. #include <utility>
  26. #include <iterator>
  27. #include "pybind11/pybind11.h"
  28. #include "ir/value.h"
  29. namespace py = pybind11;
  30. namespace mindspore {
  // Forward declarations of the reference wrappers that are defined later in this header.
  class BaseRef;
  class VectorRef;
  class SetRef;
  class PyObjectRef;
  class RunFunctionRef;

  // Iterator aliases over std::vector<BaseRef>, the storage type used by VectorRef.
  using iterator = std::vector<BaseRef>::iterator;
  using const_iterator = std::vector<BaseRef>::const_iterator;
  using const_reverse_iterator = std::vector<BaseRef>::const_reverse_iterator;

  // A callable that maps a VectorRef of arguments to a VectorRef of results,
  // and the shared pointer type used to pass such a callable around.
  using RunFunc = std::function<VectorRef(const VectorRef &args)>;
  using RunFuncPtr = std::shared_ptr<RunFunc>;
  // Shorthand for std::remove_reference<T>::type.
  template <class T>
  using remove_reference_t = typename std::remove_reference<T>::type;

  // Shorthand for std::remove_const<T>::type (strips top-level const only).
  template <class T>
  using remove_const_t = typename std::remove_const<T>::type;
  45. template <typename T>
  46. using is_base = std::is_base_of<Base, remove_reference_t<T>>;
  47. template <typename T>
  48. using is_value = std::is_base_of<Value, remove_reference_t<T>>;
  49. template <typename T>
  50. using is_base_ref = std::is_base_of<BaseRef, remove_reference_t<T>>;
  51. iterator ConstIteratorCast(std::vector<BaseRef> *v, const_iterator iter);
  52. inline std::shared_ptr<VectorRef> MakeNode(const std::vector<BaseRef> &elements) {
  53. return std::make_shared<VectorRef>(elements);
  54. }
  55. inline std::shared_ptr<VectorRef> MakeNode(std::initializer_list<BaseRef> elements) {
  56. return std::make_shared<VectorRef>(elements);
  57. }
  58. // Anfnode, Funcgraph and some not value node class
  59. template <typename T,
  60. typename std::enable_if<is_shared_ptr<remove_const_t<T>>::value && is_base<typename T::element_type>::value,
  61. int>::type = 0>
  62. inline BasePtr MakeNode(const T &v) {
  63. return v;
  64. }
  65. template <typename T,
  66. typename std::enable_if<!is_shared_ptr<remove_const_t<T>>::value && !is_base_ref<T>::value, int>::type = 0>
  67. inline BasePtr MakeNode(const T &v) {
  68. return MakeValue(v);
  69. }
  70. inline std::shared_ptr<VectorRef> MakeNode(const VectorRef &a) { return std::make_shared<VectorRef>(std::move(a)); }
  71. inline std::shared_ptr<VectorRef> MakeNode(const AnfNodePtrList &a) {
  72. std::vector<BaseRef> ret;
  73. (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr &v) { return v; });
  74. return std::make_shared<VectorRef>(ret);
  75. }
  76. inline std::shared_ptr<SetRef> MakeNode(const SetRef &a) { return std::make_shared<SetRef>(std::move(a)); }
  77. inline std::shared_ptr<RunFunctionRef> MakeNode(const RunFuncPtr &a) { return std::make_shared<RunFunctionRef>(a); }
  78. inline std::shared_ptr<PyObjectRef> MakeNode(const py::object &a) { return std::make_shared<PyObjectRef>(a); }
  79. inline std::shared_ptr<PyObjectRef> MakeNode(const py::tuple &a) { return std::make_shared<PyObjectRef>(a); }
  80. class BaseRef : public Base {
  81. public:
  82. BaseRef() : m_ptr(nullptr) {}
  83. BaseRef(const BaseRef &other);
  84. virtual std::shared_ptr<Base> copy() const { return m_ptr; }
  85. BaseRef(BaseRef &&other) : Base(other) {
  86. m_ptr = other.m_ptr;
  87. other.m_ptr = nullptr;
  88. }
  89. // right reference constructor
  90. template <class T,
  91. class = typename std::enable_if<!std::is_same<typename std::decay<T>::type, BaseRef>::value, T>::type>
  92. BaseRef(T &&t) { // NOLINT
  93. m_ptr = MakeNode(t);
  94. }
  95. ~BaseRef() override { m_ptr = nullptr; }
  96. MS_DECLARE_PARENT(BaseRef, Base)
  97. bool operator!=(const BaseRef &other) const { return !(operator==(other)); }
  98. virtual bool operator==(const BaseRef &other) const;
  99. // left reference
  100. virtual BaseRef &operator=(const BaseRef &other);
  101. // right reference
  102. virtual BaseRef &operator=(BaseRef &&other);
  103. std::size_t hash() const override {
  104. if (m_ptr == nullptr) {
  105. MS_LOG(ERROR) << "Invalid m_ptr";
  106. return 0;
  107. }
  108. return m_ptr->hash();
  109. }
  110. std::string ToString() const override;
  111. bool is_null() const { return m_ptr == nullptr; }
  112. virtual uint32_t type() const;
  113. BasePtr m_ptr; // point to real data
  114. };
  115. using BaseRefPtr = std::shared_ptr<BaseRef>;
  116. struct BaseRefHash {
  117. std::size_t operator()(const BaseRef &c) const { return c.hash(); }
  118. };
  119. struct BaseRefLess {
  120. bool operator()(const BaseRef &a, const BaseRef &b) const { return a.hash() < b.hash(); }
  121. };
  122. namespace utils {
  123. // judge isa relation
  124. // examples: isa<Int32Imm>(handle), isa<FuncGraph>(handle)
  125. template <typename T, typename std::enable_if<is_base<T>::value && !is_base_ref<T>::value, int>::type = 0>
  126. bool isa(const BaseRef &handle) {
  127. if (!handle.m_ptr) {
  128. return false;
  129. }
  130. return handle.m_ptr->isa<T>();
  131. }
  132. // noderef isa ptr isa<AnfNodePtr>(x) or isa<SeqPtr>()
  133. template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type,
  134. typename std::enable_if<is_base<U>::value || is_base_ref<U>::value, int>::type = 0>
  135. bool isa(const BaseRef &handle) {
  136. if (handle.m_ptr == nullptr) {
  137. return typeid(handle.m_ptr) == typeid(T);
  138. }
  139. if (handle.m_ptr->isa<U>()) {
  140. return true;
  141. }
  142. // constptr isa<anfnodeptr> can be true
  143. return std::dynamic_pointer_cast<U>(handle.m_ptr) != nullptr;
  144. }
  145. // isa<int32>(handle)
  146. template <typename S, typename U = typename ImmTraits<S>::type::element_type>
  147. bool isa(const BaseRef &handle) {
  148. if (handle.m_ptr == nullptr) {
  149. return false;
  150. }
  151. return handle.m_ptr->isa<U>();
  152. }
  153. // isa<BaseRef>(handle), judge reference or ptr
  154. template <typename T, typename std::enable_if<is_base_ref<T>::value, int>::type = 0>
  155. bool isa(const BaseRef &handle) {
  156. static const uint32_t tid = Base::GetTypeId(typeid(T).name());
  157. return handle.IsFromTypeId(tid) || (handle.m_ptr && handle.m_ptr->isa<T>());
  158. }
  159. // valueref -> C++ type
  160. // cast<int>(handle)
  161. template <typename T, typename std::enable_if<!is_base_ref<T>::value && !is_shared_ptr<T>::value, int>::type = 0>
  162. T cast(const BaseRef &handle) {
  163. T ret = GetValue<T>(std::static_pointer_cast<Value>(handle.m_ptr));
  164. return std::move(ret);
  165. }
  166. // valueref -> valueref type
  167. // cast<VectorRef>(handle)
  168. template <typename T, typename std::enable_if<is_base_ref<T>::value, int>::type = 0>
  169. const T &cast(const BaseRef &handle) {
  170. if (handle.m_ptr) {
  171. return static_cast<const T &>(*handle.m_ptr);
  172. }
  173. return std::move(static_cast<const T &>(handle));
  174. }
  175. // valueref -> nodeptr type
  176. // cast<FuncGraphPtr>(handle)
  177. template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type,
  178. typename std::enable_if<is_shared_ptr<T>::value && std::is_base_of<Base, typename T::element_type>::value,
  179. int>::type = 0>
  180. T cast(const BaseRef &handle) {
  181. if (!handle.m_ptr) {
  182. MS_LOG(EXCEPTION) << "Can not cast to " << typeid(T).name() << ", pointer is null";
  183. }
  184. auto m = handle.m_ptr->cast<T>();
  185. if (nullptr != m) {
  186. return m;
  187. }
  188. return std::static_pointer_cast<U>(handle.m_ptr);
  189. }
  190. } // namespace utils
  191. class VectorRef : public BaseRef {
  192. public:
  193. using value_type = BaseRef;
  194. VectorRef() {}
  195. explicit VectorRef(const std::vector<BaseRef> &elements) : elements_(elements) {}
  196. VectorRef(const const_iterator &begin, const const_iterator &end) : elements_(begin, end) {}
  197. // left reference
  198. virtual VectorRef &operator=(const VectorRef &other);
  199. ~VectorRef() override = default;
  200. std::shared_ptr<Base> copy() const override { return std::make_shared<VectorRef>(elements_); }
  201. bool empty() const { return (elements_.size() == 0); }
  202. std::size_t size() const { return elements_.size(); }
  203. MS_DECLARE_PARENT(VectorRef, BaseRef)
  204. const BaseRef &operator[](const std::size_t &dim) const {
  205. if (dim >= size()) {
  206. MS_LOG(EXCEPTION) << "Out of the size of the tuple.";
  207. }
  208. return elements_[dim];
  209. }
  210. BaseRef &operator[](const std::size_t &dim) {
  211. if (dim >= size()) {
  212. MS_LOG(EXCEPTION) << "Out of the size of the tuple.";
  213. }
  214. return elements_[dim];
  215. }
  216. uint32_t type() const override { return tid(); }
  217. std::string ToString() const override;
  218. std::vector<BaseRef> &elements() { return elements_; }
  219. void clear() { elements_.clear(); }
  220. bool operator==(const BaseRef &other) const override;
  221. bool operator==(const VectorRef &other) const;
  222. void push_back(const BaseRef &value) { elements_.push_back(value); }
  223. void push_back(BaseRef &&value) { elements_.push_back(value); }
  224. void emplace_back(const BaseRef &value) { elements_.emplace_back(value); }
  225. void emplace_back(BaseRef &&value) { elements_.emplace_back(value); }
  226. template <class InputIt>
  227. void insert(const iterator pos, const InputIt first, const InputIt last) {
  228. (void)elements_.insert(pos, first, last);
  229. }
  230. template <class InputIt>
  231. void insert(const const_iterator cpos, const InputIt first, const InputIt last) {
  232. auto pos = ConstIteratorCast(&elements_, cpos);
  233. (void)elements_.insert(pos, first, last);
  234. }
  235. const_iterator begin() const { return elements_.begin(); }
  236. const_iterator end() const { return elements_.end(); }
  237. const_reverse_iterator rbegin() const { return elements_.rbegin(); }
  238. const_reverse_iterator rend() const { return elements_.rend(); }
  239. iterator erase(const const_iterator cpos) {
  240. auto pos = ConstIteratorCast(&elements_, cpos);
  241. return elements_.erase(pos);
  242. }
  243. iterator erase(const const_iterator cfirst, const const_iterator clast) {
  244. auto first = ConstIteratorCast(&elements_, cfirst);
  245. auto last = ConstIteratorCast(&elements_, clast);
  246. return elements_.erase(first, last);
  247. }
  248. std::size_t hash() const override {
  249. std::stringstream buffer;
  250. buffer << ToString();
  251. return std::hash<std::string>()(buffer.str());
  252. }
  253. std::vector<BaseRef> elements_;
  254. };
  255. using VectorRefPtr = std::shared_ptr<VectorRef>;
  256. using set_iterator = std::set<BaseRef, BaseRefLess>::iterator;
  257. using const_set_iterator = std::set<BaseRef, BaseRefLess>::const_iterator;
  258. struct VectorRefHash {
  259. std::size_t operator()(const VectorRef &c) const { return c.hash(); }
  260. };
  261. class SetRef : public BaseRef {
  262. public:
  263. SetRef() {}
  264. explicit SetRef(const std::set<BaseRef, BaseRefLess> &elements) : elements_(elements) {}
  265. SetRef(const std::initializer_list<BaseRef> elements) : elements_(elements.begin(), elements.end()) {}
  266. SetRef(const const_set_iterator &begin, const const_set_iterator &end) : elements_(begin, end) {}
  267. // left reference
  268. virtual SetRef &operator=(const SetRef &other);
  269. bool operator==(const BaseRef &other) const override;
  270. bool operator==(const SetRef &other) const;
  271. ~SetRef() override = default;
  272. std::shared_ptr<Base> copy() const override { return std::make_shared<SetRef>(elements_); }
  273. bool empty() const { return (elements_.size() == 0); }
  274. std::size_t size() const { return elements_.size(); }
  275. MS_DECLARE_PARENT(SetRef, BaseRef)
  276. uint32_t type() const override { return tid(); }
  277. std::string ToString() const override;
  278. std::set<BaseRef, BaseRefLess> &elements() { return elements_; }
  279. void clear() { elements_.clear(); }
  280. void insert(const BaseRef &elem) { (void)elements_.insert(elem); }
  281. const_set_iterator begin() const { return elements_.begin(); }
  282. const_set_iterator end() const { return elements_.end(); }
  283. template <class InputIt>
  284. void insert(const InputIt first, const InputIt last) {
  285. (void)elements_.insert(first, last);
  286. }
  287. std::size_t count(const BaseRef &elem) const { return elements_.count(elem); }
  288. const_set_iterator find(const BaseRef &elem) const { return elements_.find(elem); }
  289. std::set<BaseRef, BaseRefLess> elements_;
  290. };
  291. using SetRefPtr = std::shared_ptr<SetRef>;
  292. class PyObjectRef : public BaseRef {
  293. public:
  294. explicit PyObjectRef(const py::object &py_object) : object_(py_object) {}
  295. explicit PyObjectRef(const py::tuple &tuple_obj) : object_(tuple_obj) {}
  296. ~PyObjectRef() override = default;
  297. std::shared_ptr<Base> copy() const override { return std::make_shared<PyObjectRef>(object_); }
  298. MS_DECLARE_PARENT(PyObjectRef, BaseRef)
  299. uint32_t type() const override { return tid(); }
  300. std::string ToString() const override { return py::str(object_); }
  301. bool operator==(const BaseRef &other) const override;
  302. bool operator==(const PyObjectRef &other) const;
  303. py::object object_;
  304. };
  305. class RunFunctionRef : public BaseRef {
  306. public:
  307. RunFunctionRef() {}
  308. explicit RunFunctionRef(const RunFuncPtr &ref_func) : func_(ref_func) {}
  309. ~RunFunctionRef() override = default;
  310. MS_DECLARE_PARENT(RunFunctionRef, BaseRef)
  311. uint32_t type() const override { return tid(); }
  312. std::string ToString() const override { return std::string("RunFunctionRef"); }
  313. bool operator==(const BaseRef &other) const override;
  314. bool operator==(const RunFunctionRef &other) const;
  315. RunFuncPtr func_;
  316. };
  317. } // namespace mindspore
  318. #endif // MINDSPORE_CCSRC_UTILS_BASE_REF_H_