You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.h 5.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_IR_BASE_H_
  17. #define MINDSPORE_CCSRC_IR_BASE_H_
  18. #include <atomic>
  19. #include <iostream>
  20. #include <memory>
  21. #include <mutex>
  22. #include <string>
  23. #include <typeinfo>
  24. #include <unordered_map>
  25. #include <vector>
  26. #include <utility>
  27. #include "utils/visible.h"
  28. #include "utils/log_adapter.h"
  29. #include "utils/ordered_set.h"
  30. namespace mindspore {
  /**
   * Compile-time trait: is_shared_ptr<T>::value is true exactly when T is
   * spelled std::shared_ptr<X> for some X. No cv/reference stripping is done,
   * so e.g. `const std::shared_ptr<X>` and `std::shared_ptr<X> &` yield false.
   */
  template <class T>
  struct is_shared_ptr : std::false_type {};

  // Partial specialization chosen for any std::shared_ptr<X>.
  template <class T>
  struct is_shared_ptr<std::shared_ptr<T>> : std::true_type {};
  35. class Base : public std::enable_shared_from_this<Base> {
  36. public:
  37. constexpr Base() = default;
  38. Base(const Base &other) : std::enable_shared_from_this<Base>(other) {}
  39. virtual bool operator==(const Base &rhs) {
  40. if (this == &rhs) {
  41. return true;
  42. }
  43. return false;
  44. }
  45. virtual Base &operator=(const Base &) { return *this; }
  46. virtual ~Base() = default;
  47. virtual std::size_t hash() const { return tid(); }
  48. virtual std::string ToString() const { return type_name(); }
  49. virtual void dump() const { std::cout << ToString() << std::endl; }
  50. virtual std::string DumpText() const { return ToString(); }
  51. virtual const bool IsFromTypeId(uint32_t tid) const;
  52. virtual std::string type_name() const { return "Base"; }
  53. static uint32_t GetTypeId(const char *const type_key);
  54. virtual uint32_t tid() const {
  55. static const uint32_t tid = GetTypeId(typeid(Base).name());
  56. return tid;
  57. }
  58. template <typename T,
  59. typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Base, T>::value, T>::type * = nullptr>
  60. inline bool isa() const {
  61. static const uint32_t tid = GetTypeId(typeid(T).name());
  62. return this->IsFromTypeId(tid);
  63. }
  64. template <typename T, typename U = typename std::enable_if<is_shared_ptr<T>::value, typename T::element_type>::type>
  65. inline T cast() {
  66. if (isa<U>()) {
  67. return std::static_pointer_cast<U>(shared_from_this());
  68. } else {
  69. return nullptr;
  70. }
  71. }
  72. protected:
  73. template <typename Derived>
  74. std::shared_ptr<Derived> shared_from_base() {
  75. return std::static_pointer_cast<Derived>(shared_from_this());
  76. }
  77. };
  78. using BasePtr = std::shared_ptr<Base>;
  79. using BaseWeakPtr = std::weak_ptr<Base>;
  /**
   * Checked downcast on raw pointers.
   *
   * @param source pointer to inspect; may be nullptr. U must provide a member
   *               template isa<T>() (as Base does).
   * @return source converted to T* when it is non-null and source->isa<T>()
   *         holds; nullptr otherwise. Ownership is unaffected.
   */
  template <typename T, typename U>
  inline T *cast(U *source) {
    const bool matches = (source != nullptr) && source->template isa<T>();
    return matches ? static_cast<T *>(source) : nullptr;
  }
  88. template <
  89. typename T, typename U,
  90. typename std::enable_if<std::is_base_of<Base, T>::value && std::is_base_of<Base, U>::value, T>::type * = nullptr>
  91. inline std::shared_ptr<T> dyn_cast(const std::shared_ptr<U> r) {
  92. if (r != nullptr && r->template isa<T>()) {
  93. return std::static_pointer_cast<T>(r);
  94. } else {
  95. return std::shared_ptr<T>();
  96. }
  97. }
  /*
   * MS_DECLARE_PARENT(current_t, parent_t) goes inside the body of class
   * current_t, whose base parent_t provides IsFromTypeId(). It overrides:
   *   tid()          -> GetTypeId(typeid(current_t).name()), cached on first call;
   *   IsFromTypeId() -> true when from_tid is current_t's id, otherwise the
   *                     answer of parent_t::IsFromTypeId (walks up the chain);
   *   type_name()    -> the macro argument current_t spelled as a string.
   */
  #define MS_DECLARE_PARENT(current_t, parent_t)                               \
    uint32_t tid() const override {                                            \
      static const uint32_t own_tid = GetTypeId(typeid(current_t).name());     \
      return own_tid;                                                          \
    }                                                                          \
    const bool IsFromTypeId(uint32_t from_tid) const override {                \
      static const uint32_t own_tid = Base::GetTypeId(typeid(current_t).name()); \
      return (own_tid == from_tid) || parent_t::IsFromTypeId(from_tid);        \
    }                                                                          \
    std::string type_name() const override { return #current_t; }
  111. class Type;
  112. using TypePtr = std::shared_ptr<Type>;
  113. class AnfNode;
  114. using AnfNodePtr = std::shared_ptr<AnfNode>;
  115. using AnfNodePtrList = std::vector<AnfNodePtr>;
  116. using AnfNodeSet = OrderedSet<AnfNodePtr>;
  117. namespace abstract {
  118. class AbstractBase;
  119. using AbstractBasePtr = std::shared_ptr<AbstractBase>;
  120. using AbstractAttribute = std::pair<std::string, AbstractBasePtr>;
  121. class AnalysisContext;
  122. using AnalysisContextPtr = std::shared_ptr<AnalysisContext>;
  123. } // namespace abstract
  124. struct MS_EXPORT TypeIdManager {
  125. std::mutex mutex;
  126. std::atomic<uint32_t> type_counter{0};
  127. std::unordered_map<std::string, uint32_t> map;
  128. static TypeIdManager *Get();
  129. TypeIdManager() : mutex(), type_counter(0), map() {}
  130. };
  131. } // namespace mindspore
  132. #endif // MINDSPORE_CCSRC_IR_BASE_H_