You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

misc.i 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. /*
  2. * $File: misc.i
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * $Copyright: Copyright (c) 2014-2017 Megvii Inc. All rights reserved.
  7. */
  8. %{
  9. #include "megbrain/utils/persistent_cache.h"
  10. #include "megbrain/serialization/helper.h"
  11. #include "megbrain/gopt/inference.h"
  12. #include "megbrain/plugin/opr_footprint.h"
  13. using _PyStackExtracter = PyStackExtracter;
  14. using _PersistentCache = mgb::PersistentCache;
  15. using _PersistentCacheBlob = _PersistentCache::Blob;
  16. using _MaybePersistentCacheBlob = mgb::Maybe<_PersistentCacheBlob>;
  17. using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
  18. %}
  19. %feature("director") _PyStackExtracter;
  /*!
   * Interface exposed to SWIG for PyStackExtracter (aliased as
   * _PyStackExtracter in the header block of this file).
   *
   * Implementations override extract(); an instance is handed to the C++ side
   * through reg(), whose definition is not part of this file.
   */
  class _PyStackExtracter {
  public:
      virtual ~_PyStackExtracter() = default;

      //! returns a string; presumably a rendering of the current python
      //! stack -- TODO confirm against the C++ declaration
      virtual std::string extract() = 0;

      //! hands the extracter \p p to the C++ side (defined elsewhere)
      static void reg(_PyStackExtracter *p);
  };
  // from python bytes to Blob: the Blob points at the internal buffer of the
  // bytes object (no copy), so it is only valid while $input is alive
  %typemap(in) const _PersistentCacheBlob& (_PersistentCacheBlob blob_storage) {
      mgb_assert(PyBytes_Check($input));
      // for a reference argument $1 is a pointer that SWIG initializes to
      // null, so the fields are filled in typemap-local storage first and $1
      // is then pointed at that storage
      blob_storage.ptr = PyBytes_AsString($input);
      blob_storage.size = PyBytes_Size($input);
      $1 = &blob_storage;
  }
  // from Blob to python bytes: builds a bytes object of $1.size bytes read
  // from $1.ptr, which is what the python-side director method receives
  %typemap(directorin) const _PersistentCacheBlob& {
      const char* blob_bytes = static_cast<const char*>($1.ptr);
      $input = PyBytes_FromStringAndSize(blob_bytes, $1.size);
  }
  // from the python result of a director call (None or bytes) to Maybe<Blob>
  %typemap(directorout) _MaybePersistentCacheBlob {
      // The Blob built below points into the buffer of the bytes object
      // instead of copying it. NOTE(review): the refcount check presumably
      // ensures that something other than this call frame keeps the object
      // alive after the director method returns -- confirm with the python
      // implementation of the cache.
      // $input is the documented name for the python result in a directorout
      // typemap; the cast keeps Py_REFCNT usable when SWIG substitutes a
      // smart-pointer wrapper for it.
      mgb_assert(Py_REFCNT(static_cast<PyObject*>($input)) >= 2,
                 "persistent cache result refcnt too small");
      if ($input == Py_None) {
          $result = mgb::None;
      } else {
          mgb_assert(PyBytes_Check($input));
          _PersistentCacheBlob blob;
          blob.ptr = PyBytes_AsString($input);
          blob.size = PyBytes_Size($input);
          $result = blob;
      }
  }
  48. %feature("director") _PersistentCache;
  /*!
   * Interface exposed to SWIG for mgb::PersistentCache (aliased as
   * _PersistentCache in the header block of this file).
   *
   * Keys and values cross the language boundary as _PersistentCacheBlob, and
   * get() produces a _MaybePersistentCacheBlob.
   */
  class _PersistentCache {
  public:
      virtual ~_PersistentCache() = default;

      //! presumably stores \p value under (\p category, \p key) -- the
      //! contract is defined by mgb::PersistentCache, not by this file
      virtual void put(
              const std::string &category,
              const _PersistentCacheBlob &key,
              const _PersistentCacheBlob &value) = 0;

      //! presumably looks up (\p category, \p key); the result type can also
      //! carry "no value" (see the directorout typemap for its conversion)
      virtual _MaybePersistentCacheBlob get(
              const std::string &category,
              const _PersistentCacheBlob &key) = 0;

      %extend {
          //! passes \p p to set_impl() together with a callback whose body is
          //! empty. NOTE(review): this looks like a no-op deleter, so that the
          //! C++ side never frees an object owned by python -- confirm
          //! against the signature of set_impl()
          static void reg(_PersistentCache *p) {
              auto do_nothing = [](_PersistentCache*) {};
              _PersistentCache::set_impl({p, do_nothing});
          }
      }
  };
  /*!
   * Declaration of mgb::gopt::OptimizeForInferenceOptions as seen by SWIG
   * (aliased as _OptimizeForInferenceOptions in the header block of this
   * file): one parameterless enable_<name>() method per option listed below.
   */
  struct _OptimizeForInferenceOptions {
  #define SET(n) void enable_##n()
      SET(f16_io_f32_comp);
      SET(f16_io_comp);
      SET(fuse_conv_bias_nonlinearity);
      SET(fuse_conv_bias_with_z);
  #undef SET
  // The second argument is accepted but not used by this declaration-only
  // macro. The definition has to end on its own line: a trailing backslash
  // would splice the first SET(...) use below into the macro body, so that
  // method would never be declared.
  #define SET(_trans, _trans_capital) void enable_##_trans()
      SET(nchw2nhwcd4, NCHW2NHWCD4);
      SET(nchw2nchw88, NCHW2NCHW88);
      SET(nchw2nchw44, NCHW2NCHW44);
      SET(nchw2nchw32, NCHW2NCHW32);
  #undef SET
  };
  79. %inline {
  80. static SymbolVarArray _optimize_for_inference(
  81. const SymbolVarArray& dest_vars,
  82. const _OptimizeForInferenceOptions& opt) {
  83. return mgb::gopt::optimize_for_inference(dest_vars, opt);
  84. }
  85. // defined in function_replace.cpp
  86. void _register_logger(PyObject *logger);
  87. void _timed_func_set_fork_exec_path(const char *arg0, const char *arg1);
  88. void _timed_func_exec_cb(const char *user_data);
  89. // defined in megbrain_wrap.cpp
  90. void _mgb_global_finalize();
  91. std::vector<size_t> _get_mgb_version();
  92. SymbolVarArray _grad(SymbolVar target, SymbolVarArray wrts,
  93. bool warn_mid_wrt, int use_virtual_grad,
  94. bool return_zero_for_nodep);
  95. SymbolVar _inter_graph_trans_var(
  96. CompGraph &dest_graph, SymbolVar src);
  97. SymbolVar _get_graph_optimizer_replaced_var(SymbolVar src);
  98. void _add_update_fastpath(SharedND& dest, SharedND& delta,
  99. float alpha, float beta, float bias);
  100. void _add_update_fastpath(SharedND& dest,
  101. CompGraphCallbackValueProxy& delta,
  102. float alpha, float beta, float bias);
  103. static SymbolVar _current_grad_target(CompGraph &graph) {
  104. return mgb::cg::current_grad_target(graph.get());
  105. }
  106. uint32_t _get_dtype_num(PyObject *dtype) {
  107. return static_cast<uint32_t>(npy::dtype_np2mgb(dtype).enumv());
  108. }
  109. PyObject* _get_serialized_dtype(PyObject *dtype) {
  110. std::string sdtype;
  111. auto write = [&sdtype](const void* data, size_t size) {
  112. auto pos = sdtype.size();
  113. sdtype.resize(pos + size);
  114. memcpy(&sdtype[pos], data, size);
  115. };
  116. mgb::serialization::serialize_dtype(npy::dtype_np2mgb(dtype), write);
  117. return PyBytes_FromStringAndSize(sdtype.data(), sdtype.size());
  118. }
  //! largest value representable by size_t
  size_t max_size_t() {
      constexpr size_t largest = std::numeric_limits<size_t>::max();
      return largest;
  }
  122. std::string _get_opr_fp_graph_exec(
  123. CompGraph& cg, const SymbolVarArray& outputs) {
  124. auto json = mgb::OprFootprint::get_opr_fp_graph_exec(cg.get(), outputs);
  125. return json->to_string();
  126. }
  127. }
  128. // vim: ft=swig

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。