You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

plugin.h 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. /**
  2. * \file python_module/src/cpp/plugin.h
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * \brief helpers for debugging
  7. *
  8. * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  9. *
  10. */
  11. #ifndef SWIG
  12. #pragma once
  13. #include "./megbrain_wrap.h"
  14. #include "megbrain/plugin/profiler.h"
  15. #include "megbrain/plugin/infkern_finder.h"
  16. #include "megbrain/plugin/num_range_checker.h"
  17. #include "megbrain/plugin/opr_io_dump.h"
  18. #endif // SWIG
  19. #include <Python.h>
  20. class _CompGraphProfilerImpl {
  21. #ifndef SWIG
  22. std::shared_ptr<mgb::ComputingGraph> m_comp_graph;
  23. mgb::GraphProfiler m_profiler;
  24. #endif
  25. public:
  26. _CompGraphProfilerImpl(CompGraph &cg):
  27. m_comp_graph{cg.get().shared_from_this()},
  28. m_profiler{m_comp_graph.get()}
  29. {
  30. }
  31. std::string _get_result() {
  32. auto json = m_profiler.to_json_full(
  33. m_comp_graph->current_comp_seq());
  34. return json->to_string();
  35. }
  36. };
  37. class _NumRangeCheckerImpl {
  38. #ifndef SWIG
  39. std::shared_ptr<mgb::ComputingGraph> m_comp_graph;
  40. mgb::NumRangeChecker m_checker;
  41. #endif
  42. public:
  43. _NumRangeCheckerImpl(CompGraph &cg, float range):
  44. m_comp_graph{cg.get().shared_from_this()},
  45. m_checker{m_comp_graph.get(), range}
  46. {
  47. }
  48. };
  49. class _TextOprIODumpImpl {
  50. #ifndef SWIG
  51. std::shared_ptr<mgb::ComputingGraph> m_comp_graph;
  52. mgb::TextOprIODump m_dump;
  53. #endif
  54. public:
  55. _TextOprIODumpImpl(CompGraph &cg, const char *fpath):
  56. m_comp_graph{cg.get().shared_from_this()},
  57. m_dump{m_comp_graph.get(), fpath}
  58. {
  59. }
  60. void _print_addr(bool flag) {
  61. m_dump.print_addr(flag);
  62. }
  63. void _max_size(size_t size) {
  64. m_dump.max_size(size);
  65. }
  66. };
  67. class _BinaryOprIODumpImpl {
  68. #ifndef SWIG
  69. std::shared_ptr<mgb::ComputingGraph> m_comp_graph;
  70. mgb::BinaryOprIODump m_dump;
  71. #endif
  72. public:
  73. _BinaryOprIODumpImpl(CompGraph &cg, const char *fpath):
  74. m_comp_graph{cg.get().shared_from_this()},
  75. m_dump{m_comp_graph.get(), fpath}
  76. {
  77. }
  78. };
  79. class _InfkernFinderImpl {
  80. #ifndef SWIG
  81. static size_t sm_id;
  82. const size_t m_id;
  83. std::shared_ptr<mgb::ComputingGraph> m_comp_graph;
  84. mgb::InfkernFinder m_finder;
  85. mgb::InfkernFinder::InputValueRecord::FullRecord m_inp_val;
  86. #endif
  87. public:
  88. _InfkernFinderImpl(CompGraph &cg, bool record_input_value);
  89. size_t _write_to_file(const char *fpath);
  90. size_t _get_input_values_prepare(size_t opr_id);
  91. const char* _get_input_values_var_name(size_t idx);
  92. size_t _get_input_values_var_idx(size_t idx);
  93. size_t _get_input_values_run_id(size_t idx);
  94. CompGraphCallbackValueProxy _get_input_values_val(size_t idx);
  95. std::string __repr__();
  96. };
  97. class _FastSignal {
  98. #ifndef SWIG
  99. class Impl;
  100. static Impl sm_impl;
  101. static void signal_hander(int signum);
  102. #endif
  103. public:
  104. static void register_handler(int signum, PyObject *func);
  105. static void shutdown();
  106. };
  107. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。