You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tensor.h 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. #pragma once
  2. #pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h"
#include "./pyext17.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/utils/span.h"
  13. namespace mgb::imperative::python {
  14. template <typename T, typename B = pybind11::object>
  15. struct ObjectPtr : B {
  16. using B::B;
  17. T& operator*() { return reinterpret_cast<T&>(*B::ptr()); }
  18. T* operator->() { return reinterpret_cast<T*>(B::ptr()); }
  19. };
  20. } // namespace mgb::imperative::python
  21. namespace mgb::imperative::python {
  22. extern interpreter::Interpreter::Channel* interpreter_for_py;
  23. extern PyTypeObject* py_tensor_type;
  24. extern PyTypeObject* py_varnode_type;
  25. extern pybind11::handle py_device_type;
  26. extern PyObject* cpp_use_symbolic_shape;
  27. extern PyObject* cpp_astensor1d;
  28. struct Tensor {
  29. private:
  30. ValueRef m_data;
  31. std::string m_name;
  32. public:
  33. using Handle = interpreter::Interpreter::Handle;
  34. inline explicit Tensor(ValueRef data) : m_data{data} {}
  35. ~Tensor() = default;
  36. inline Tensor copy() { return *this; }
  37. inline DType dtype() { return *data().dtype(); }
  38. inline CompNode comp_node() { return *data().device(); }
  39. inline std::optional<ValueShape> shape() {
  40. auto shape = data().shape();
  41. if (!shape) {
  42. return {};
  43. }
  44. return *shape;
  45. }
  46. inline HostValue::ref_t numpy() { return data().numpy(); }
  47. inline void reset(ValueRef value) {
  48. m_data = value;
  49. if (!m_name.empty()) {
  50. set_name(m_name);
  51. }
  52. }
  53. inline ValueRef data() const { return m_data.unwrap(); }
  54. bool is_scalar() { return data().is_scalar(); }
  55. inline std::string name() { return m_name; }
  56. inline void set_name(std::string name) {
  57. m_name = name;
  58. if (!name.empty()) {
  59. auto output = imperative::apply(RenameValue(name), m_data)[0];
  60. m_data = output;
  61. }
  62. }
  63. };
  64. struct TensorWrapper {
  65. public:
  66. std::optional<Tensor> m_tensor;
  67. inline TensorWrapper(ValueRef value) { m_tensor.emplace(value); }
  68. TensorWrapper(PyObject* args, PyObject* kwargs);
  69. ~TensorWrapper() = default;
  70. static constexpr auto tp_name = pybind11::detail::_("Tensor");
  71. using wrap_t = pyext17::wrap<TensorWrapper>;
  72. friend wrap_t;
  73. inline static TensorWrapper* cast(PyObject* obj) {
  74. return reinterpret_cast<wrap_t*>(obj)->inst();
  75. }
  76. inline static TensorWrapper* try_cast(PyObject* obj) {
  77. if (!wrap_t::type().isinstance(obj))
  78. return nullptr;
  79. return cast(obj);
  80. }
  81. inline ObjectPtr<TensorWrapper, pybind11::handle> self() {
  82. return wrap_t::pycast(this);
  83. }
  84. template <typename... Args>
  85. static ObjectPtr<Tensor> make(Args&&... args) {
  86. auto* op = wrap_t::cnew(std::forward<Args>(args)...);
  87. return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
  88. }
  89. template <typename... Args>
  90. static ObjectPtr<Tensor> make(PyTypeObject* pytype, Args&&... args) {
  91. auto* op = wrap_t::cnew_with_type(pytype, std::forward<Args>(args)...);
  92. return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
  93. }
  94. PyObject* shape();
  95. PyObject* dtype();
  96. PyObject* device();
  97. PyObject* numpy();
  98. void reset(PyObject*);
  99. PyObject* detach();
  100. PyObject* isscalar();
  101. PyObject* _dev_tensor();
  102. void _drop();
  103. PyObject* varnode();
  104. PyObject* recording();
  105. PyObject* copied();
  106. PyObject* module_trace_info();
  107. void set_module_trace_info(PyObject*);
  108. void _set_name(PyObject*);
  109. PyObject* _detail();
  110. PyObject* _var();
  111. PyObject* _graph();
  112. void _watch();
  113. };
  114. PyObject* py_apply(
  115. PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */);
  116. void init_tensor(pybind11::module);
  117. extern PyObject* cpp_apply_module_trace;
  118. } // namespace mgb::imperative::python
  119. namespace pybind11::detail {
  120. template <>
  121. struct type_caster<mgb::imperative::python::TensorWrapper>
  122. : mgb::imperative::python::TensorWrapper::wrap_t::caster {};
  123. } // namespace pybind11::detail