
tensor.h 8.6 kB

/**
 * \file imperative/python/src/tensor.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include <variant>

#include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h"

#include <string>

#include "./pyext17.h"

namespace mgb::imperative::python {

template <typename T, typename B = pybind11::object>
struct ObjectPtr : B {
    using B::B;
    T& operator*() { return reinterpret_cast<T&>(*B::ptr()); }
    T* operator->() { return reinterpret_cast<T*>(B::ptr()); }
};

} // namespace mgb::imperative::python

#include "./grad_info.h"   // for struct GradInfo
#include "./trace_info.h"  // for struct TraceInfo

namespace mgb::imperative::python {

extern interpreter::Interpreter::Channel* interpreter_for_py;

class SharedHandle {
    using Handle = interpreter::Interpreter::Handle;
    static_assert(std::is_pointer_v<Handle>);
    std::shared_ptr<std::remove_pointer_t<Handle>> holder;

public:
    inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h) {
        if (h) {
            interpreter_for_py->del(h);
        }
    }) {}
    SharedHandle(const SharedHandle&) = default;
    SharedHandle& operator=(const SharedHandle&) = default;
    SharedHandle(SharedHandle&&) = default;
    SharedHandle& operator=(SharedHandle&&) = default;

    inline Handle get() { return holder.get(); }
};
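// Illustrative sketch (not part of the original header): SharedHandle wraps a
// raw interpreter handle in shared-ownership semantics; the custom deleter
// asks the interpreter to free the value only once the last copy is gone.
// Assuming a handle `h` previously obtained from interpreter_for_py:
//
//     SharedHandle a{h};      // takes ownership of h
//     SharedHandle b = a;     // use_count == 2, no copy of the data
//     // when both a and b are destroyed, interpreter_for_py->del(h) runs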
struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
    using flags_t = uint64_t;

    struct Flags {
        static constexpr flags_t SCALAR = 1;
        static constexpr flags_t GRAD = 1 << 1;
        static constexpr flags_t TRACE = 1 << 2;
    };

    flags_t m_flags = 0;

    GradInfo m_grad_info;
    TraceInfo m_trace_info;
    SharedHandle m_handle;
    std::string user_custom_name;
    std::string automatic_name;
    cg::VarNode* m_var;

    using Handle = interpreter::Interpreter::Handle;

    inline Tensor() : m_handle(nullptr), m_var(nullptr) {}
    inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {}
    inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)), m_var(nullptr) {}
    inline explicit Tensor(cg::VarNode* var) : m_handle(nullptr), m_var(var) {}

    ~Tensor() = default;

    inline std::shared_ptr<Tensor> copy() {
        auto ret = std::make_shared<Tensor>(m_handle);
        ret->m_flags = m_flags;
        ret->m_grad_info = m_grad_info;
        ret->m_trace_info = m_trace_info;
        ret->m_var = m_var;
        return ret;
    }

    inline DType dtype() {
        if (m_var) {
            return m_var->dtype();
        }
        return interpreter_for_py->get_dtype(m_handle.get());
    }
    inline CompNode comp_node() {
        if (m_var) {
            return m_var->comp_node();
        }
        return interpreter_for_py->get_device(m_handle.get());
    }
    inline TensorShape shape() {
        if (m_var) {
            return m_var->shape();
        }
        return interpreter_for_py->get_shape(m_handle.get());
    }
};
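// Illustrative sketch (not part of the original header): m_flags is a plain
// bitmask over Tensor::Flags, and copy() creates a new Tensor that shares the
// same SharedHandle (and thus the same device storage) rather than cloning
// it. Assuming a valid handle `h`:
//
//     auto t = std::make_shared<Tensor>(h);
//     t->m_flags |= Tensor::Flags::SCALAR;                  // mark as 0-dim
//     auto u = t->copy();                                   // shares m_handle
//     bool is_scalar = u->m_flags & Tensor::Flags::SCALAR;  // true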
struct TensorWrapper {
    std::shared_ptr<Tensor> m_tensor;

    inline TensorWrapper(std::shared_ptr<Tensor> tensor = {}) : m_tensor(std::move(tensor)) {}
    TensorWrapper(PyObject* args, PyObject* kwargs);
    ~TensorWrapper() = default;

    static constexpr auto tp_name = pybind11::detail::_("Tensor");

    using wrap_t = pyext17::wrap<TensorWrapper>;
    friend wrap_t;

    inline static TensorWrapper* cast(PyObject* op) { return reinterpret_cast<wrap_t*>(op)->inst(); }
    inline static TensorWrapper* try_cast(PyObject* op) {
        if (!wrap_t::type().isinstance(op)) return nullptr;
        return cast(op);
    }
    inline ObjectPtr<TensorWrapper, pybind11::handle> self() { return wrap_t::pycast(this); }

    template <typename... Args>
    static ObjectPtr<Tensor> make(Args&&... args) {
        auto* op = wrap_t::cnew(std::forward<Args>(args)...);
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

    template <typename... Args>
    static ObjectPtr<Tensor> make(PyTypeObject* pytype, Args&&... args) {
        auto* op = wrap_t::cnew_with_type(pytype, std::forward<Args>(args)...);
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

    PyObject* shape();
    PyObject* dtype();
    PyObject* device();
    PyObject* numpy();
    void reset(PyObject*);
    PyObject* detach();
    PyObject* isscalar();
    void setscalar();
    PyObject* _dev_tensor();
    void _swap_in();
    void _swap_out();
    void _drop();
    PyObject* varnode();
    void reset_varnode();
    PyObject* handle();
    void set_handle(PyObject*);
    PyObject* mixin_handle();
    PyObject* recording();
    PyObject* copied();
    void set_mixin_handle(PyObject*);
    void set_recording(PyObject*);
    PyObject* compiled_info();
    void set_compiled_info(PyObject*);
    PyObject* trace_mixin_info();
    void set_trace_mixin_info(PyObject*);
    PyObject* user_custom_name();
    void set_user_custom_name(PyObject*);
    PyObject* automatic_name();
    void set_automatic_name(PyObject*);
    PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }
};
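// Illustrative sketch (not part of the original header): TensorWrapper is the
// C-level Python object holding a shared_ptr<Tensor>. From C++, a new Python
// tensor object can be created and safely down-cast roughly like this
// (assuming a valid std::shared_ptr<Tensor> `t`):
//
//     auto py_tensor = TensorWrapper::make(t);           // new Python object
//     if (auto* tw = TensorWrapper::try_cast(py_tensor.ptr())) {
//         // try_cast returns nullptr for non-Tensor PyObjects
//         auto use_cnt = tw->m_tensor.use_count();
//     }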
PyObject* py_apply(PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */);

struct ApplyContext {
    static Tensor::flags_t global_disable;

    Tensor::flags_t flags;
    std::shared_ptr<OpDef> op;
    Tensor* const* args;
    size_t nargs;
    PyTypeObject* pytype = nullptr;
    bool backward = false;

    class scoped_disable : NonCopyableObj {
        Tensor::flags_t saved_flags;

    public:
        scoped_disable(Tensor::flags_t flags) : saved_flags(ApplyContext::global_disable) {
            ApplyContext::global_disable |= flags;
        }
        ~scoped_disable() {
            ApplyContext::global_disable = saved_flags;
        }
    };
};
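// Illustrative sketch (not part of the original header): scoped_disable is an
// RAII guard that ORs extra bits into ApplyContext::global_disable for the
// current scope and restores the previous mask on destruction, so dispatch
// handlers such as grad or trace can be suppressed exception-safely:
//
//     {
//         ApplyContext::scoped_disable guard(Tensor::Flags::GRAD);
//         // apply() calls in this scope skip the GRAD handler
//     }   // previous global_disable restored here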
using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;

apply_result_t apply(ApplyContext& ctx);

template <typename T>
decltype(auto) resolve_arrow(T&& p) {
    if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) {
        auto* ret = p;
        return ret;
    } else {
        auto probe = [](auto&& p) -> decltype(p.operator->()) {};
        if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) {
            return resolve_arrow(p.operator->());
        } else {
            return std::forward<T>(p);
        }
    }
}
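// Illustrative sketch (not part of the original header): resolve_arrow
// recursively unwraps operator-> until it reaches a raw pointer, so the
// apply() overloads below accept Tensor*, shared_ptr<Tensor>, or any other
// smart-pointer-like wrapper uniformly:
//
//     void check(Tensor* raw, std::shared_ptr<Tensor>& sp) {
//         static_assert(std::is_same_v<decltype(resolve_arrow(raw)), Tensor*>);
//         static_assert(std::is_same_v<decltype(resolve_arrow(sp)), Tensor*>);
//     }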
template <typename... Args>
constexpr bool is_all_tensor_ptr = (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>);

extern bool is_tracing;  // FIXME: should use ApplyContext::global_enable
extern bool is_compiled;

template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
    ApplyContext ctx;
    Tensor* arg_arr[] = {resolve_arrow(args)...};
    ctx.flags = (0 | ... | args->m_flags);
    ctx.flags |= is_tracing ? Tensor::Flags::TRACE : 0;
    ctx.args = arg_arr;
    ctx.nargs = sizeof...(args);
    ctx.op = std::move(op);
    return apply(ctx);
}

template <typename T>
auto apply(std::shared_ptr<OpDef> op, T&& tensors)
        -> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>,
                            apply_result_t> {
    ApplyContext ctx;
    ctx.op = std::move(op);
    ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
    ctx.nargs = tensors.size();
    Tensor* args[ctx.nargs];
    ctx.args = args;
    for (size_t i = 0; i < ctx.nargs; ++i) {
        args[i] = resolve_arrow(tensors[i]);
        ctx.flags |= args[i]->m_flags;
    }
    return apply(ctx);
}

inline auto apply(std::shared_ptr<OpDef> op, Tensor* const* args, size_t nargs) {
    ApplyContext ctx;
    ctx.op = std::move(op);
    ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
    ctx.nargs = nargs;
    ctx.args = args;
    for (size_t i = 0; i < nargs; ++i) {
        ctx.flags |= args[i]->m_flags;
    }
    return apply(ctx);
}
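// Illustrative sketch (not part of the original header): the three overloads
// above accept (1) a variadic pack of pointer-like tensors, (2) an indexable
// container of them, or (3) a raw Tensor** plus count; each ORs the argument
// flags together so grad/trace handlers see the union. Assuming valid
// std::shared_ptr<OpDef> ops `op1`, `op2` and tensors `a`, `b`:
//
//     apply_result_t r1 = apply(op1, a, b);              // variadic overload
//     SmallVector<std::shared_ptr<Tensor>, 8> v{a, b};
//     apply_result_t r2 = apply(op2, v);                 // container overload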
void init_tensor(pybind11::module);

extern PyObject *cpp_apply_with_tracing, *cpp_apply_compiled_mode;
extern PyObject* cpp_apply_backward_varnode;

} // namespace mgb::imperative::python

namespace pybind11::detail {

template <>
struct type_caster<mgb::imperative::python::TensorWrapper>
        : mgb::imperative::python::TensorWrapper::wrap_t::caster {};

} // namespace pybind11::detail

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep learning development on a cloud GPU compute platform, visit the MegStudio platform.