You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

bfloat16.cpp 9.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. /**
  2. * \file python_module/src/cpp/bfloat16.cpp
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * \brief numpy dtypes for bfloat16
  7. *
  8. * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  9. *
  10. */
  11. #include "megbrain/common.h"
  12. #include "megbrain/dtype.h"
  13. #include <Python.h>
  14. #include <structmember.h>
  15. #define NO_IMPORT_ARRAY 1
  16. #include "./numpy_incl.h"
  17. #pragma GCC diagnostic ignored "-Wmissing-field-initializers"
  18. namespace {
  19. struct BFloat16Type {
  20. static int npy_typenum;
  21. mgb::dt_bfloat16 value;
  22. struct PyObj;
  23. struct NpyType;
  24. template <typename S, typename T>
  25. struct NpyCast;
  26. };
  27. int BFloat16Type::npy_typenum;
  28. /* ==================== BFloat16Type::NpyCast ==================== */
  29. template <typename S>
  30. struct BFloat16Type::NpyCast<S, BFloat16Type> {
  31. static void apply(void* from_, void* to_, npy_intp n, void* /*fromarr*/,
  32. void* /*toarr*/) {
  33. auto from = static_cast<S*>(from_);
  34. auto to = static_cast<BFloat16Type*>(to_);
  35. for (npy_intp i = 0; i < n; ++i) {
  36. float cur = static_cast<float>(from[i]);
  37. to[i].value = cur;
  38. }
  39. }
  40. };
  41. template <typename T>
  42. struct BFloat16Type::NpyCast<BFloat16Type, T> {
  43. static void apply(void* from_, void* to_, npy_intp n, void* /*fromarr*/,
  44. void* /*toarr*/) {
  45. auto from = static_cast<BFloat16Type*>(from_);
  46. auto to = static_cast<T*>(to_);
  47. for (npy_intp i = 0; i < n; ++i) {
  48. to[i] = from[i].value;
  49. }
  50. }
  51. };
  52. /* ==================== BFloat16Type::PyObj ==================== */
  53. struct BFloat16Type::PyObj {
  54. PyObject_HEAD BFloat16Type obj;
  55. static PyTypeObject py_type;
  56. static PyObject* from_bfloat16(BFloat16Type val) {
  57. auto p = reinterpret_cast<PyObj*>(py_type.tp_alloc(&py_type, 0));
  58. p->obj.value = val.value;
  59. return reinterpret_cast<PyObject*>(p);
  60. }
  61. static PyObject* py_new(PyTypeObject* type, PyObject* args, PyObject* kwds);
  62. static PyObject* py_repr(PyObject* obj);
  63. static PyObject* py_richcompare(PyObject* a, PyObject* b, int op);
  64. };
  65. PyTypeObject BFloat16Type::PyObj::py_type;
  66. PyObject* BFloat16Type::PyObj::py_new(PyTypeObject* type, PyObject* args,
  67. PyObject* kwds) {
  68. PyObj* self;
  69. Py_ssize_t size;
  70. self = (PyObj*)type->tp_alloc(type, 0);
  71. size = PyTuple_GET_SIZE(args);
  72. if (size > 1) {
  73. PyErr_SetString(PyExc_TypeError, "BFloat16Type Only has 1 parameter");
  74. return NULL;
  75. }
  76. PyObject* x = PyTuple_GET_ITEM(args, 0);
  77. if (PyObject_IsInstance(x, (PyObject*)&py_type)) {
  78. Py_INCREF(x);
  79. return x;
  80. }
  81. if (!PyFloat_Check(x)) {
  82. PyErr_SetString(PyExc_TypeError,
  83. "BFloat16Type must be initialized wit float");
  84. return NULL;
  85. }
  86. const float s = PyFloat_AsDouble(x);
  87. self->obj.value = s;
  88. return (PyObject*)self;
  89. }
  90. PyObject* BFloat16Type::PyObj::py_repr(PyObject* obj) {
  91. float fval = static_cast<float>(((PyObj*)obj)->obj.value);
  92. return PyUnicode_FromString(mgb::ssprintf("%f", fval).c_str());
  93. }
  94. PyObject* BFloat16Type::PyObj::py_richcompare(PyObject* a, PyObject* b,
  95. int op) {
  96. mgb_assert(PyObject_IsInstance(a, (PyObject*)&py_type));
  97. auto bval = PyFloat_AsDouble(b);
  98. if (bval == -1 && PyErr_Occurred()) {
  99. return NULL;
  100. }
  101. double aval = ((PyObj*)a)->obj.value;
  102. #define OP(py, op) \
  103. case py: { \
  104. if (aval op bval) { \
  105. Py_RETURN_TRUE; \
  106. } else { \
  107. Py_RETURN_FALSE; \
  108. } \
  109. }
  110. switch (op) {
  111. OP(Py_LT, <)
  112. OP(Py_LE, <=)
  113. OP(Py_EQ, ==)
  114. OP(Py_NE, !=)
  115. OP(Py_GT, >)
  116. OP(Py_GE, >=)
  117. };
  118. #undef OP
  119. return Py_NotImplemented;
  120. }
  121. /* ==================== BFloat16Type<N>::NpyType ==================== */
  122. struct BFloat16Type::NpyType {
  123. static PyArray_ArrFuncs funcs;
  124. static PyArray_Descr descr;
  125. static bool init();
  126. static void copyswap(void* dst, void* src, int swap, void* /*arr*/) {
  127. if (src) {
  128. mgb_assert(!swap);
  129. memcpy(dst, src, sizeof(BFloat16Type));
  130. }
  131. }
  132. static PyObject* getitem(void* data, void* ap) {
  133. return BFloat16Type::PyObj::from_bfloat16(
  134. *static_cast<BFloat16Type*>(data));
  135. }
  136. static int setitem(PyObject* op, void* ov, void* ap);
  137. };
  138. PyArray_ArrFuncs BFloat16Type::NpyType::funcs;
  139. PyArray_Descr BFloat16Type::NpyType::descr;
  140. int BFloat16Type::NpyType::setitem(PyObject* op, void* ov, void* ap) {
  141. if (PyLong_Check(op)) {
  142. int a = PyLong_AsLong(op);
  143. static_cast<BFloat16Type*>(ov)->value = a;
  144. } else if (PyFloat_Check(op)) {
  145. float a = PyFloat_AsDouble(op);
  146. static_cast<BFloat16Type*>(ov)->value = a;
  147. } else if (PyObject_IsInstance(
  148. op, (PyObject*)(&(BFloat16Type::PyObj::py_type)))) {
  149. static_cast<BFloat16Type*>(ov)->value = ((PyObj*)op)->obj.value;
  150. } else {
  151. PyErr_SetString(PyExc_ValueError,
  152. "input type must be int/float/bfloat16");
  153. return -1;
  154. }
  155. return 0;
  156. }
  157. bool BFloat16Type::NpyType::init() {
  158. descr = {PyObject_HEAD_INIT(0) & BFloat16Type::PyObj::py_type,
  159. 'V', // kind
  160. 'f', // type
  161. '=', // byteorder
  162. NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,
  163. 1, // type num
  164. sizeof(BFloat16Type),
  165. alignof(BFloat16Type),
  166. NULL,
  167. NULL,
  168. NULL,
  169. &funcs};
  170. Py_TYPE(&descr) = &PyArrayDescr_Type;
  171. PyArray_InitArrFuncs(&funcs);
  172. funcs.copyswap = copyswap;
  173. funcs.getitem = getitem;
  174. funcs.setitem = setitem;
  175. npy_typenum = PyArray_RegisterDataType(&descr);
  176. #define REGISTER_CAST(From, To, From_descr, To_typenum, safe) \
  177. { \
  178. PyArray_Descr* from_descr = (From_descr); \
  179. if (PyArray_RegisterCastFunc(from_descr, (To_typenum), \
  180. NpyCast<From, To>::apply) < 0) { \
  181. return false; \
  182. } \
  183. if (safe && PyArray_RegisterCanCast(from_descr, (To_typenum), \
  184. NPY_NOSCALAR) < 0) { \
  185. return false; \
  186. } \
  187. }
  188. #define REGISTER_INT_CASTS(bits) \
  189. REGISTER_CAST(npy_int##bits, BFloat16Type, \
  190. PyArray_DescrFromType(NPY_INT##bits), \
  191. BFloat16Type::npy_typenum, 1) \
  192. REGISTER_CAST(BFloat16Type, npy_int##bits, &descr, NPY_INT##bits, 0) \
  193. REGISTER_CAST(npy_uint##bits, BFloat16Type, \
  194. PyArray_DescrFromType(NPY_UINT##bits), \
  195. BFloat16Type::npy_typenum, 1) \
  196. REGISTER_CAST(BFloat16Type, npy_uint##bits, &descr, NPY_UINT##bits, 0)
  197. REGISTER_INT_CASTS(8)
  198. REGISTER_INT_CASTS(16)
  199. REGISTER_INT_CASTS(32)
  200. REGISTER_INT_CASTS(64)
  201. REGISTER_CAST(BFloat16Type, float, &descr, NPY_FLOAT, 0)
  202. REGISTER_CAST(float, BFloat16Type, PyArray_DescrFromType(NPY_FLOAT),
  203. BFloat16Type::npy_typenum, 0)
  204. REGISTER_CAST(BFloat16Type, double, &descr, NPY_DOUBLE, 1)
  205. REGISTER_CAST(double, BFloat16Type, PyArray_DescrFromType(NPY_DOUBLE),
  206. BFloat16Type::npy_typenum, 0)
  207. return true;
  208. }
  209. } // anonymous namespace
  210. bool init_pytype_bfloat16() {
  211. auto& py_type = BFloat16Type::PyObj::py_type;
  212. py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
  213. py_type.tp_name = "megbrain._mgb.pybfloat16";
  214. py_type.tp_basicsize = sizeof(BFloat16Type::PyObj);
  215. py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  216. py_type.tp_doc = "bfloat16 type";
  217. py_type.tp_new = BFloat16Type::PyObj::py_new;
  218. py_type.tp_str = BFloat16Type::PyObj::py_repr;
  219. py_type.tp_repr = BFloat16Type::PyObj::py_repr;
  220. py_type.tp_richcompare = BFloat16Type::PyObj::py_richcompare;
  221. py_type.tp_base = &PyGenericArrType_Type;
  222. return PyType_Ready(&py_type) >= 0;
  223. }
  224. void register_pytype_bfloat16(PyObject* d, PyObject* m) {
  225. Py_INCREF(&BFloat16Type::PyObj::py_type);
  226. PyDict_SetItemString(d, "bfloat16_pytype",
  227. (PyObject*)&BFloat16Type::PyObj::py_type);
  228. PyModule_AddObject(m, "bfloat16_pytype",
  229. (PyObject*)&BFloat16Type::PyObj::py_type);
  230. }
  231. //! called from swig init
  232. void _init_bfloat16_types(PyObject* m) {
  233. if (m == NULL)
  234. return;
  235. PyObject* d = PyModule_GetDict(m);
  236. PyArray_Descr* dtype;
  237. if (!init_pytype_bfloat16())
  238. return;
  239. if (!BFloat16Type::NpyType::init())
  240. return;
  241. dtype = PyArray_DescrFromType(BFloat16Type::npy_typenum);
  242. if (!dtype)
  243. return;
  244. {
  245. PyObject* pytype = (PyObject*)(&BFloat16Type::PyObj::py_type);
  246. Py_INCREF(pytype);
  247. PyDict_SetItemString(d, "pybfloat16", pytype);
  248. }
  249. Py_INCREF(dtype);
  250. PyDict_SetItemString(d, "bfloat16", (PyObject*)dtype);
  251. register_pytype_bfloat16(d, m);
  252. return;
  253. }
  254. int mgb::npy_num_bfloat16() {
  255. return BFloat16Type::npy_typenum;
  256. }
  257. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台