You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

misc.i 4.6 kB

(Line numbers 1–143.)
  1. /*
  2. * $File: misc.i
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * $Copyright: Copyright (c) 2014-2017 Megvii Inc. All rights reserved.
  7. */
  8. %{
  9. #include "megbrain/utils/persistent_cache.h"
  10. #include "megbrain/serialization/helper.h"
  11. #include "megbrain/gopt/inference.h"
  12. #include "megbrain/plugin/opr_footprint.h"
  13. using _PyStackExtracter = PyStackExtracter;
  14. using _PersistentCache = mgb::PersistentCache;
  15. using _PersistentCacheBlob = _PersistentCache::Blob;
  16. using _MaybePersistentCacheBlob = mgb::Maybe<_PersistentCacheBlob>;
  17. using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
  18. %}
  // SWIG-facing declaration of PyStackExtracter (aliased as _PyStackExtracter
  // in the %{ %} block). The director feature lets a Python subclass provide
  // the implementation of the pure-virtual extract().
  19. %feature("director") _PyStackExtracter;
  20. class _PyStackExtracter {
  21. public:
  22. virtual ~_PyStackExtracter() = default;
  // Returns a string produced by the (Python) implementation. The class name
  // suggests it describes the current Python call stack — TODO confirm
  // against the C++ callers.
  23. virtual std::string extract() = 0;
  // Registers p as the extracter to use. The definition is not visible here —
  // NOTE(review): confirm whether ownership is taken or p must outlive the
  // registration.
  24. static void reg(_PyStackExtracter *p);
  25. };
  // Convert an incoming Python bytes object into a _PersistentCacheBlob
  // argument. The blob only borrows the bytes object's internal buffer
  // (PyBytes_AsString does not copy), so it stays valid only while that
  // object is alive — at least for the duration of the wrapped call, since
  // the caller holds the argument.
  //
  // For a `const T&` parameter SWIG declares $1 as a `T*` initialised to
  // null, so the blob is built in the typemap-local `temp` and $1 is pointed
  // at it; writing through $1 directly would dereference a null pointer.
  %typemap(in) const _PersistentCacheBlob& (_PersistentCacheBlob temp) {
      mgb_assert(PyBytes_Check($input),
                 "persistent cache blob argument must be bytes");
      temp.ptr = PyBytes_AsString($input);
      temp.size = PyBytes_Size($input);
      $1 = &temp;
  }
  // Convert a C++ blob argument into a new Python bytes object before it is
  // handed to the Python implementation of a director method. The bytes
  // object owns a copy of the data, so it does not depend on the blob's
  // buffer after this point.
  %typemap(directorin) const _PersistentCacheBlob& {
      const char* blob_bytes = static_cast<const char*>($1.ptr);
      $input = PyBytes_FromStringAndSize(blob_bytes, $1.size);
  }
  // Convert the Python return value of a director get() into a C++
  // mgb::Maybe blob: None maps to mgb::None, and a bytes object maps to a
  // blob that borrows the bytes object's internal buffer (no copy is made).
  //
  // Because the blob only borrows that buffer, the bytes object must outlive
  // this conversion. The director presumably drops its own reference to the
  // result once the call returns, so the Python side is required to hold at
  // least one more reference (e.g. by keeping the bytes in its cache). The
  // refcount check below enforces that.
  %typemap(directorout) _MaybePersistentCacheBlob {
      PyObject* py_result = $input;
      if (py_result == Py_None) {
          $result = mgb::None;
      } else {
          mgb_assert(PyBytes_Check(py_result),
                     "persistent cache get() must return bytes or None");
          // Use the documented accessor rather than touching ob_refcnt.
          mgb_assert(Py_REFCNT(py_result) >= 2,
                     "persistent cache result refcnt too small");
          _PersistentCacheBlob blob;
          blob.ptr = PyBytes_AsString(py_result);
          blob.size = PyBytes_Size(py_result);
          $result = blob;
      }
  }
  // Enable SWIG directors so a Python subclass can implement put()/get().
  48. %feature("director") _PersistentCache;
  // SWIG-facing declaration of mgb::PersistentCache (aliased as
  // _PersistentCache in the %{ %} block): a key/value store of byte blobs,
  // partitioned by a category string.
  49. class _PersistentCache {
  50. public:
  51. virtual ~_PersistentCache() = default;
  // Store value under (category, key). A Python implementation receives key
  // and value as bytes objects.
  52. virtual void put(const std::string &category,
  53. const _PersistentCacheBlob &key,
  54. const _PersistentCacheBlob &value) = 0;
  // Look up (category, key). The result is an mgb::Maybe blob, i.e. it may be
  // empty. A Python implementation returns either bytes or None.
  55. virtual _MaybePersistentCacheBlob get(
  56. const std::string &category,
  57. const _PersistentCacheBlob &key) = 0;
  58. %extend {
  // Install p through PersistentCache::set_impl. The handle is built from p
  // plus a lambda that does nothing — presumably a smart-pointer deleter — so
  // C++ never frees p. The caller must keep the object alive while it is
  // registered.
  59. static void reg(_PersistentCache *p) {
  60. _PersistentCache::set_impl({p, [](_PersistentCache*){}});
  61. }
  62. }
  63. };
  // SWIG-facing declaration of mgb::gopt::OptimizeForInferenceOptions
  // (aliased as _OptimizeForInferenceOptions in the %{ %} block; the real
  // definition comes from megbrain/gopt/inference.h). Only the enable_*
  // switches are exposed to Python. Each presumably turns on the
  // correspondingly named optimization or layout option used by
  // optimize_for_inference.
  struct _OptimizeForInferenceOptions {
      void enable_f16_io_f32_comp();
      void enable_f16_io_comp();
      void enable_fuse_conv_bias_nonlinearity();
      void enable_use_nhwcd4();
      void enable_use_tensor_core();
      void enable_fuse_conv_bias_with_z();
      void enable_use_nchw88();
      void enable_use_nchw44();
  };
  76. %inline {
  77. static SymbolVarArray _optimize_for_inference(
  78. const SymbolVarArray& dest_vars,
  79. const _OptimizeForInferenceOptions& opt) {
  80. return mgb::gopt::optimize_for_inference(dest_vars, opt);
  81. }
  82. // defined in function_replace.cpp
  // Prototypes only: inside %inline, SWIG both emits these declarations into
  // the wrapper source and generates Python wrappers for them. The bodies
  // live in the .cpp files named by the comments in this group.
  83. void _register_logger(PyObject *logger);
  84. void _timed_func_set_fork_exec_path(const char *arg0, const char *arg1);
  85. void _timed_func_exec_cb(const char *user_data);
  86. // defined in megbrain_wrap.cpp
  87. void _mgb_global_finalize();
  88. std::vector<size_t> _get_mgb_version();
  // NOTE(review): use_virtual_grad is an int rather than a bool, so it may
  // encode more than two modes — confirm against megbrain_wrap.cpp.
  89. SymbolVarArray _grad(SymbolVar target, SymbolVarArray wrts,
  90. bool warn_mid_wrt, int use_virtual_grad,
  91. bool return_zero_for_nodep);
  92. SymbolVar _inter_graph_trans_var(
  93. CompGraph &dest_graph, SymbolVar src);
  94. SymbolVar _get_graph_optimizer_replaced_var(SymbolVar src);
  // Two overloads that differ only in where delta comes from: a SharedND or a
  // graph-callback value proxy. How alpha/beta/bias combine dest and delta is
  // defined in megbrain_wrap.cpp — TODO confirm the exact formula there.
  95. void _add_update_fastpath(SharedND& dest, SharedND& delta,
  96. float alpha, float beta, float bias);
  97. void _add_update_fastpath(SharedND& dest,
  98. CompGraphCallbackValueProxy& delta,
  99. float alpha, float beta, float bias);
  100. static SymbolVar _current_grad_target(CompGraph &graph) {
  101. return mgb::cg::current_grad_target(graph.get());
  102. }
  103. uint32_t _get_dtype_num(PyObject *dtype) {
  104. return static_cast<uint32_t>(npy::dtype_np2mgb(dtype).enumv());
  105. }
  106. PyObject* _get_serialized_dtype(PyObject *dtype) {
  107. std::string sdtype;
  108. auto write = [&sdtype](const void* data, size_t size) {
  109. auto pos = sdtype.size();
  110. sdtype.resize(pos + size);
  111. memcpy(&sdtype[pos], data, size);
  112. };
  113. mgb::serialization::serialize_dtype(npy::dtype_np2mgb(dtype), write);
  114. return PyBytes_FromStringAndSize(sdtype.data(), sdtype.size());
  115. }
  // Largest value representable by size_t, exposed to Python. An unsigned
  // value with every bit set is that type's maximum.
  size_t max_size_t() {
      return ~size_t{0};
  }
  119. std::string _get_opr_fp_graph_exec(
  120. CompGraph& cg, const SymbolVarArray& outputs) {
  121. auto json = mgb::OprFootprint::get_opr_fp_graph_exec(cg.get(), outputs);
  122. return json->to_string();
  123. }
  124. }
  125. // vim: ft=swig

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。