You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

value.cpp 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. #include "megbrain/imperative/value.h"
  2. #include "megbrain/imperative/basic_operators.h"
  3. #include "megbrain/imperative/dispatch.h"
  4. #include "megbrain/imperative/utils/map.h"
  5. namespace mgb {
  6. namespace imperative {
  7. namespace {
  8. static /*thread_local*/ size_t nr_watched_values = 0;
  9. static /*thread_local*/ uint64_t nr_values = 0;
  10. static /*thread_local*/ bool recording_values = false;
  11. static /*thread_local*/ std::vector<ValueWeakRef> recorded_values;
  12. static WeakValueMap<uint64_t, ValueWeakRef> registered_values;
  13. } // namespace
  14. ValueRef::storage_t& ValueRef::storage() const {
  15. if (mgb_likely(!m_storage->m_successor.m_storage)) {
  16. return m_storage;
  17. }
  18. while (m_storage->m_successor.m_storage) {
  19. m_storage = m_storage->m_successor.m_storage;
  20. }
  21. return m_storage;
  22. }
  23. const Value* ValueRef::as(size_t typecode) const {
  24. auto&& storage = this->storage();
  25. if (storage->m_typecode != typecode) {
  26. return nullptr;
  27. }
  28. return static_cast<Value*>(storage.get());
  29. }
  30. bool ValueRef::is(size_t typecode) const {
  31. return this->storage()->m_typecode == typecode;
  32. }
  33. TypedValueRef<DeviceValue> ValueRef::dev_tensor() const {
  34. return imperative::apply(GetAttr(GetAttr::Data), *this)[0].cast_ref<DeviceValue>();
  35. }
  36. TypedValueRef<HostValue> ValueRef::numpy() const {
  37. return imperative::apply(GetAttr(GetAttr::Value), *this)[0].cast_ref<HostValue>();
  38. }
  39. TypedValueRef<CompNodeValue> ValueRef::device() const {
  40. return imperative::apply(GetAttr(GetAttr::Device), *this)[0]
  41. .cast_ref<CompNodeValue>();
  42. }
  43. TypedValueRef<ShapeValue> ValueRef::shape() const {
  44. return imperative::apply(GetAttr(GetAttr::Shape), *this)[0].cast_ref<ShapeValue>();
  45. }
  46. TypedValueRef<DTypeValue> ValueRef::dtype() const {
  47. return imperative::apply(GetAttr(GetAttr::DType), *this)[0].cast_ref<DTypeValue>();
  48. }
  49. TypedValueRef<StringValue> ValueRef::name() const {
  50. return imperative::apply(GetName(), *this)[0].cast_ref<StringValue>();
  51. }
  52. bool ValueRef::is_scalar() const {
  53. return imperative::apply(IsScalar(), *this)[0].cast<BoolValue>();
  54. }
  55. void ValueRef::watch() const {
  56. mgb_assert(m_storage);
  57. storage()->m_watching++;
  58. nr_watched_values++;
  59. storage()->on_watch();
  60. // TODO:
  61. // imperative::apply(Watch(), this);
  62. }
  63. void ValueRef::unwatch() const {
  64. mgb_assert(m_storage);
  65. storage()->m_watching--;
  66. nr_watched_values--;
  67. storage()->on_unwatch();
  68. }
  69. ValueRef ValueRef::unwrap() const {
  70. auto& context = Transformation::get_context();
  71. if (mgb_unlikely(context.next_transformation)) {
  72. ValueRef value = *this;
  73. for (size_t i = 0; i < context.next_transformation; ++i) {
  74. value = context.transformations[i]->unwrap(value);
  75. }
  76. return value;
  77. }
  78. return *this;
  79. }
  80. std::string ValueRef::to_string() const {
  81. if (!m_storage) {
  82. return "<empty value>";
  83. }
  84. return ssprintf(
  85. "(%zu:%zu) %s", id(), storage()->m_id, storage()->to_string().c_str());
  86. }
  87. std::string ValueRef::raw_type() const {
  88. if (!m_storage) {
  89. return "null";
  90. }
  91. auto& types = Value::registered_types();
  92. mgb_assert(types.size() > m_storage->m_typecode);
  93. return types[m_storage->m_typecode].name();
  94. }
  95. bool ValueRef::watching() const {
  96. if (!m_storage) {
  97. return false;
  98. }
  99. return this->storage()->m_watching;
  100. }
  101. ValueRef ValueRef::make(ValueRef::storage_t storage) {
  102. if (recording_values) {
  103. recorded_values.push_back({storage});
  104. }
  105. return {storage};
  106. }
  107. bool ValueRef::any_watching() {
  108. return nr_watched_values != 0;
  109. }
  110. ValueRef ValueWeakRef::lock() {
  111. auto strong_storage = m_storage.lock();
  112. if ((!strong_storage) || strong_storage->m_successor) {
  113. return {};
  114. }
  115. return {strong_storage};
  116. }
  117. Value::Value(size_t typecode) : m_typecode{typecode} {
  118. m_id = nr_values++;
  119. }
  120. Value::~Value() {
  121. if (m_watching) {
  122. debug::notify_event("dtor");
  123. }
  124. }
  125. size_t Value::register_type(std::type_index type) {
  126. auto& types = const_cast<std::vector<std::type_index>&>(registered_types());
  127. types.push_back(type);
  128. return types.size() - 1;
  129. }
  130. const std::vector<std::type_index>& Value::registered_types() {
  131. static std::vector<std::type_index> sm_registered_types;
  132. return sm_registered_types;
  133. }
  134. void Value::register_value(ValueRef value) {
  135. registered_values[value.id()] = ValueWeakRef(value);
  136. }
  137. ValueRef Value::get_value_by_id(uint64_t id) {
  138. auto& weak_value = registered_values[id];
  139. if (auto value = weak_value.lock()) {
  140. return value;
  141. }
  142. return {};
  143. }
  144. void Value::begin_record_values() {
  145. mgb_assert(!recording_values);
  146. recording_values = true;
  147. recorded_values.clear();
  148. }
  149. std::vector<ValueRef> Value::end_record_values() {
  150. recording_values = false;
  151. std::vector<ValueRef> recorded_strong_values;
  152. for (auto&& weak_value : recorded_values) {
  153. if (auto value = weak_value.lock()) {
  154. recorded_strong_values.push_back(value);
  155. }
  156. }
  157. return recorded_strong_values;
  158. }
  159. void Value::try_rethrow() {
  160. if (m_typecode == ErrorValue::TYPE_CODE) {
  161. auto message = static_cast<ErrorValue*>(this)->message();
  162. mgb_throw(MegBrainError, "invalid value: %s", message.c_str());
  163. }
  164. }
  165. inline void ValueRefList::init(size_t nr_elems) {
  166. m_size = nr_elems;
  167. if (m_size > 0) {
  168. if (m_size == 1) {
  169. m_data = inline_storage();
  170. } else {
  171. auto& context = Transformation::get_context();
  172. m_data = context.allocator.allocate(m_size);
  173. }
  174. for (size_t i = 0; i < m_size; ++i) {
  175. new (m_data + i) ValueRef();
  176. }
  177. } else {
  178. m_data = nullptr;
  179. }
  180. }
  181. ValueRefList::ValueRefList(size_t nr_elems) {
  182. init(nr_elems);
  183. }
  184. /*ValueRefList::ValueRefList(std::initializer_list<ValueRef> values)
  185. : ValueRefList(values.begin(), values.end()) {}*/
// Copy constructor: delegates to the iterator-range constructor (declared
// elsewhere) with rhs's const iterators, copying each element of rhs.
ValueRefList::ValueRefList(const ValueRefList& rhs)
        : ValueRefList(rhs.cbegin(), rhs.cend()) {}
  188. ValueRefList::ValueRefList(ValueRefList&& rhs) : ValueRefList() {
  189. m_size = rhs.m_size;
  190. if (rhs.m_data == rhs.inline_storage()) {
  191. m_data = inline_storage();
  192. new (m_data) ValueRef();
  193. m_data[0] = std::move(rhs.m_data[0]);
  194. } else {
  195. m_data = rhs.m_data;
  196. rhs.m_data = nullptr;
  197. rhs.m_size = 0;
  198. }
  199. }
  200. ValueRefList& ValueRefList::operator=(const ValueRefList& rhs) {
  201. if (this == &rhs) {
  202. return *this;
  203. }
  204. clear();
  205. init(rhs.m_size);
  206. for (size_t i = 0; i < m_size; ++i) {
  207. m_data[i] = rhs.m_data[i];
  208. }
  209. return *this;
  210. }
  211. ValueRefList& ValueRefList::operator=(ValueRefList&& rhs) {
  212. if (this == &rhs) {
  213. return *this;
  214. }
  215. clear();
  216. if (rhs.m_data == rhs.inline_storage()) {
  217. m_data = inline_storage();
  218. new (m_data) ValueRef();
  219. m_data[0] = rhs.m_data[0];
  220. m_size = 1;
  221. rhs.clear();
  222. } else {
  223. m_data = rhs.m_data;
  224. m_size = rhs.m_size;
  225. rhs.m_data = nullptr;
  226. rhs.m_size = 0;
  227. }
  228. return *this;
  229. }
  230. ValueRefList::~ValueRefList() {
  231. clear();
  232. }
  233. void ValueRefList::clear() {
  234. for (size_t i = 0; i < m_size; ++i) {
  235. m_data[i].~ValueRef();
  236. }
  237. if (m_data) {
  238. if (m_size != 1) {
  239. Transformation::get_context().allocator.deallocate(m_data, m_size);
  240. } else {
  241. mgb_assert(m_data == inline_storage());
  242. }
  243. }
  244. m_data = nullptr;
  245. m_size = 0;
  246. }
  247. } // namespace imperative
  248. } // namespace mgb