You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

common.cpp 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. #include "./common.h"
  2. #include <pybind11/operators.h>
  3. #include <pybind11/pytypes.h>
  4. #include "./helper.h"
  5. #include "./numpy_dtypes.h"
  6. #include "megbrain/comp_node.h"
  7. #include "megbrain/graph.h"
  8. #include "megbrain/imperative/physical_tensor.h"
  9. #include "megbrain/version.h"
  10. #if MGB_ENABLE_OPR_MM
  11. #include "megbrain/opr/mm_handler.h"
  12. #endif
  13. #if MEGDNN_WITH_CUDA
  14. #include "cuda_sm_gen.h"
  15. #endif
  16. namespace py = pybind11;
  17. using namespace mgb;
  18. using namespace imperative;
  19. namespace {
  20. template <typename XTensorND>
  21. auto def_TensorND(py::object parent, const char* name) {
  22. return py::class_<XTensorND>(parent, name)
  23. .def_property_readonly(
  24. "shape", py::overload_cast<>(&XTensorND::shape, py::const_))
  25. .def_property_readonly(
  26. "dtype", py::overload_cast<>(&XTensorND::dtype, py::const_))
  27. .def_property_readonly(
  28. "comp_node", py::overload_cast<>(&XTensorND::comp_node, py::const_))
  29. .def("copy_from", &XTensorND::template copy_from<DeviceTensorStorage>)
  30. .def("copy_from", &XTensorND::template copy_from<HostTensorStorage>)
  31. .def("copy_from_fixlayout",
  32. py::overload_cast<const DeviceTensorND&>(
  33. &XTensorND::template copy_from_fixlayout<DeviceTensorStorage>))
  34. .def("copy_from_fixlayout",
  35. py::overload_cast<const HostTensorND&>(
  36. &XTensorND::template copy_from_fixlayout<HostTensorStorage>));
  37. }
  38. std::string default_device = "xpux";
  39. } // namespace
  40. void set_default_device(const std::string& device) {
  41. default_device = device;
  42. }
  43. void init_nccl_env(const std::string& ip, int port, int nranks, int rank, int root) {
  44. #if MGB_ENABLE_OPR_MM
  45. auto&& help = mgb::opr::BatchSendRecvHelper::getInstance();
  46. bool res = help->init(nranks, rank, ip, port, root);
  47. auto p = help->get(std::string("init_all_cards"));
  48. #else
  49. mgb_throw(
  50. MegBrainError,
  51. "MegEngine compiled without MM opr, doesn't support init_nccl_env");
  52. #endif
  53. }
  54. std::string get_default_device() {
  55. return default_device;
  56. }
// Handle to the Python-side CompNode class object; assigned in init_common()
// with an extra reference taken (PyCompNode.inc_ref()) so the handle stays
// valid for the lifetime of the process.
py::handle py_comp_node_type;
  58. void init_common(py::module m) {
  59. auto PyCompNode =
  60. py::class_<CompNode>(m, "CompNode")
  61. .def(py::init())
  62. .def(py::init(
  63. py::overload_cast<const std::string&>(&CompNode::load)))
  64. .def_property_readonly(
  65. "logical_name",
  66. [](const CompNode& cn) { return cn.to_string_logical(); })
  67. .def_property_readonly(
  68. "physical_name",
  69. [](const CompNode& cn) { return cn.to_string_physical(); })
  70. .def_property_readonly(
  71. "get_mem_status_bytes",
  72. [](const CompNode& cn) {
  73. return cn.get_mem_status_bytes();
  74. })
  75. .def_property_readonly(
  76. "get_used_memory",
  77. [](const CompNode& cn) { return cn.get_used_memory(); })
  78. .def_property_readonly(
  79. "get_max_used_memory",
  80. [](const CompNode& cn) { return cn.get_max_used_memory(); })
  81. .def_property_readonly(
  82. "get_reserved_memory",
  83. [](const CompNode& cn) { return cn.get_reserved_memory(); })
  84. .def_property_readonly(
  85. "get_max_reserved_memory",
  86. [](const CompNode& cn) {
  87. return cn.get_max_reserved_memory();
  88. })
  89. .def_static(
  90. "reset_max_memory_stats",
  91. [](const CompNode& cn) {
  92. cn.reset_max_used_memory();
  93. cn.reset_max_reserved_memory();
  94. })
  95. .def("create_event", &CompNode::create_event,
  96. py::arg("flags") = 0ul)
  97. .def_static("_set_default_device", &set_default_device)
  98. .def_static("_get_default_device", &get_default_device)
  99. .def("__str__", &CompNode::to_string_logical)
  100. .def("__repr__",
  101. [](const CompNode& cn) {
  102. return mgb::ssprintf(
  103. "CompNode(\"%s\" from \"%s\")",
  104. cn.to_string_physical().c_str(),
  105. cn.to_string_logical().c_str());
  106. })
  107. .def("__hash__", [](CompNode cn) { return mgb::hash(cn); })
  108. .def_static("_sync_all", &CompNode::sync_all)
  109. .def(py::self == py::self)
  110. .def_static(
  111. "_get_device_count", &CompNode::get_device_count,
  112. "Get total number of specific devices on this system")
  113. .def(py::pickle(
  114. [](const CompNode& cn) {
  115. return py::str(cn.to_string_logical());
  116. },
  117. [](py::str cn) { return CompNode::load(cn); }));
  118. py_comp_node_type = PyCompNode.inc_ref();
  119. py::class_<CompNode::Event, std::shared_ptr<CompNode::Event>>(PyCompNode, "Event")
  120. .def("record", &CompNode::Event::record)
  121. .def("wait", &CompNode::Event::host_wait);
  122. py::implicitly_convertible<std::string, CompNode>();
  123. py::class_<CompNode::DeviceProperties>(m, "DeviceProperties")
  124. .def(py::init())
  125. .def_property_readonly(
  126. "name",
  127. [](const CompNode::DeviceProperties prop) { return prop.name; })
  128. .def_property_readonly(
  129. "total_memory",
  130. [](const CompNode::DeviceProperties prop) {
  131. return prop.total_memory;
  132. })
  133. .def_property_readonly(
  134. "major",
  135. [](const CompNode::DeviceProperties prop) { return prop.major; })
  136. .def_property_readonly("minor", [](const CompNode::DeviceProperties prop) {
  137. return prop.minor;
  138. });
  139. def_TensorND<DeviceTensorND>(m, "DeviceTensorND")
  140. .def("numpy", [](const DeviceTensorND& self) {
  141. HostTensorND hv;
  142. hv.copy_from(self).sync();
  143. return py::reinterpret_steal<py::object>(
  144. npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
  145. });
  146. def_TensorND<HostTensorND>(m, "HostTensorND")
  147. .def(py::init([](py::array data, CompNode cn, DType dtype) {
  148. if (!cn.valid()) {
  149. throw py::type_error("device must not be None");
  150. }
  151. return npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
  152. }))
  153. .def("numpy", [](const HostTensorND& self) {
  154. return py::reinterpret_steal<py::object>(
  155. npy::ndarray_from_tensor(self, npy::ShareType::TRY_SHARE));
  156. });
  157. py::class_<cg::OperatorNodeConfig>(m, "OperatorNodeConfig")
  158. .def(py::init())
  159. .def_property(
  160. "name",
  161. [](const OperatorNodeConfig& config) -> py::object {
  162. auto name = config.name();
  163. if (name.valid()) {
  164. return py::str(name.val());
  165. } else {
  166. return py::none();
  167. }
  168. },
  169. [](OperatorNodeConfig& config, std::string name) {
  170. config.name(std::move(name));
  171. })
  172. .def_property(
  173. "dtype",
  174. [](const OperatorNodeConfig& config) {
  175. return config.output_dtype();
  176. },
  177. [](OperatorNodeConfig& config, DType dtype) {
  178. config.output_dtype(dtype);
  179. })
  180. .def_property(
  181. "comp_node_arr",
  182. [](const OperatorNodeConfig& config) -> py::tuple {
  183. auto arr = config.comp_node();
  184. std::vector<CompNode> tmp(arr.begin(), arr.end());
  185. return py::cast(tmp);
  186. },
  187. [](OperatorNodeConfig& config, std::vector<CompNode> cns) {
  188. config.comp_node_arr({cns.begin(), cns.end()});
  189. })
  190. .def_property(
  191. "comp_node",
  192. [](const OperatorNodeConfig& config) {
  193. auto arr = config.comp_node();
  194. if (arr.size() != 1) {
  195. throw py::value_error("invalid number of comp_node");
  196. }
  197. return arr[0];
  198. },
  199. [](OperatorNodeConfig& config, CompNode cn) {
  200. OperatorNodeConfig::CompNodeArray arr{cn};
  201. config.comp_node_arr(arr);
  202. });
  203. py::class_<LogicalTensorDesc>(m, "TensorAttr")
  204. .def(py::init())
  205. .def(py::init([](const TensorShape& shape, const DType& dtype,
  206. const CompNode& comp_node) {
  207. return LogicalTensorDesc{TensorLayout{shape, dtype}, comp_node};
  208. }))
  209. .def_property(
  210. "shape",
  211. [](const LogicalTensorDesc& desc) {
  212. return static_cast<TensorShape>(desc.layout);
  213. },
  214. [](LogicalTensorDesc& desc, TensorShape shape) {})
  215. .def_property(
  216. "dtype",
  217. [](const LogicalTensorDesc& desc) { return desc.layout.dtype; },
  218. [](LogicalTensorDesc& desc, DType dtype) {
  219. desc.layout.dtype = dtype;
  220. })
  221. .def_readwrite("comp_node", &LogicalTensorDesc::comp_node);
  222. py::enum_<CompNode::DeviceType>(m, "DeviceType")
  223. .value("UNSPEC", CompNode::DeviceType::UNSPEC)
  224. .value("CUDA", CompNode::DeviceType::CUDA)
  225. .value("ROCM", CompNode::DeviceType::ROCM)
  226. .value("CPU", CompNode::DeviceType::CPU)
  227. .value("CAMBRICON", CompNode::DeviceType::CAMBRICON)
  228. .value("ATLAS", CompNode::DeviceType::ATLAS)
  229. .value("MULTITHREAD", CompNode::DeviceType::MULTITHREAD)
  230. .value("MAX_DEVICE_ID", CompNode::DeviceType::MAX_DEVICE_ID);
  231. m.def("set_prealloc_config", &CompNode::set_prealloc_config,
  232. "specifies how to pre-allocate from raw dev allocator");
  233. m.def("get_device_prop", &CompNode::get_device_prop);
  234. m.def("get_supported_sm_versions", []() {
  235. #if MEGDNN_WITH_CUDA
  236. static const char* mge_gen_code = MGE_CUDA_GENCODE;
  237. #else
  238. static const char* mge_gen_code = "-1";
  239. #endif
  240. return mge_gen_code;
  241. });
  242. m.def("get_cuda_version", []() { return mgb::get_cuda_version(); });
  243. m.def("get_cudnn_version", []() { return mgb::get_cudnn_version(); });
  244. m.def("get_tensorrt_version", []() { return mgb::get_tensorrt_version(); });
  245. m.def("what_is_xpu",
  246. [] { return CompNode::Locator::parse("xpux").to_physical().type; });
  247. m.def("init_nccl_env", &init_nccl_env);
  248. init_npy_num_bfloat16(m);
  249. init_npy_num_intbx(m);
  250. init_dtypes(m);
  251. }