You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_defs.h 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
/**
 * \file python_module/src/cpp/opr_defs.h
 *
 * This file is part of MegBrain, a deep learning framework developed by
 * Megvii.
 *
 * \brief extra opr definitions
 *
 * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 */
  11. #ifndef SWIG
  12. #pragma once
  13. #include "./megbrain_wrap.h"
  14. #include "./opr_helper.h"
  15. #if MGB_ENABLE_OPR_MM
  16. #include "megbrain/opr/collective_comm.h"
  17. #endif
  18. #include "megbrain/opr/basic_arith.h"
  19. #include "megbrain/opr/tensor_manip.h"
  20. using mgb::SymbolVar;
  21. using mgb::SymbolVarArray;
  22. using mgb::OperatorNodeConfig;
  23. #endif
  24. class _Opr {
  25. public:
  26. // basic arith
  27. static SymbolVar add_update(SymbolVar dest, SymbolVar delta,
  28. const SharedScalar &alpha, const SharedScalar &beta,
  29. const SharedScalar &bias, const SharedScalar &disable,
  30. const OperatorNodeConfig &config) {
  31. return mgb::opr::AddUpdate::make(dest, delta,
  32. {alpha.get_val(), beta.get_val(), bias.get_val(), disable.get_val()},
  33. config);
  34. }
  35. // tensor manip
  36. static SymbolVarArray param_pack_split(
  37. SymbolVar src, const std::vector<std::vector<size_t>>& shapes,
  38. const OperatorNodeConfig& config);
  39. static SymbolVar dimshuffle(SymbolVar src,
  40. const std::vector<int> &pattern, size_t ndim,
  41. const OperatorNodeConfig &config) {
  42. return mgb::opr::Dimshuffle::make(src, pattern, ndim, config);
  43. }
  44. static SymbolVar _axis_add_remove(SymbolVar src,
  45. const std::vector<int>& axis, bool is_add,
  46. const OperatorNodeConfig &config);
  47. static SymbolVar callback_injector(SymbolVar src, _CompGraphCallback &callback,
  48. const OperatorNodeConfig &config) {
  49. return mgb::opr::CallbackInjector::make(src, callback.make_callback());
  50. }
  51. static SymbolVar callback_injector(SymbolVarArray src, _CompGraphCallback &callback,
  52. const OperatorNodeConfig &config) {
  53. return mgb::opr::CallbackInjector::make(src, callback.make_multi_input_callback());
  54. }
  55. static SymbolVar set_grad(SymbolVar src, _SetGradCallback &grad_getter,
  56. const OperatorNodeConfig &config) {
  57. return mgb::opr::SetGrad::make(src, grad_getter.make_callback(), config);
  58. }
  59. // multi machine
  60. static SymbolVar lock_acquire(SymbolVar var, size_t lock_id, size_t group_id,
  61. const OperatorNodeConfig &config);
  62. static SymbolVar lock_release(SymbolVar var, size_t lock_id, size_t group_id,
  63. const OperatorNodeConfig &config);
  64. static SymbolVar remote_send(
  65. const std::string& server_addr, const int port,
  66. const std::string& key, SymbolVar var,
  67. const bool is_grad,
  68. const OperatorNodeConfig& config);
  69. static SymbolVar remote_recv(const std::string& server_addr, const int port,
  70. const std::string& key,
  71. CompGraph& graph,
  72. const std::vector<size_t>& shape, PyObject* dtype,
  73. const OperatorNodeConfig& config);
  74. static SymbolVar collective_comm_with_input(
  75. SymbolVar inpvar, const std::string& key, const size_t nr_devices,
  76. const bool is_root, const int rank, const bool local_grad,
  77. const std::string& server_addr, const int port, PyObject* params,
  78. PyObject* dtype, const std::string& backend, SharedND* output_buf,
  79. const OperatorNodeConfig& config, const SharedScalar& disable);
  80. static SymbolVar collective_comm_without_input(
  81. CompGraph& graph, const std::string& key, const size_t nr_devices,
  82. const bool is_root, const int rank, const bool local_grad,
  83. const std::string& server_addr, const int port, PyObject* params,
  84. PyObject* dtype, const std::string& backend, SharedND* output_buf,
  85. const OperatorNodeConfig& config, const SharedScalar& disable);
  86. // misc
  87. static SymbolVarArray extern_c_opr_placeholder(
  88. const SymbolVarArray& inputs,
  89. const std::vector<std::vector<size_t>>& output_shapes,
  90. PyObject* dtypes,
  91. const char* dump_name, PyObject* data_bytes,
  92. const OperatorNodeConfig& config);
  93. static SymbolVarArray tensor_rt_runtime(const SymbolVarArray& inputs,
  94. PyObject* data_bytes,
  95. const OperatorNodeConfig& config);
  96. static SymbolVar timestamp(SymbolVar input, PyObject* dest, size_t dest_off,
  97. const OperatorNodeConfig& config);
  98. static SymbolVar virtual_loss(const SymbolVarArray& ys,
  99. const SymbolVarArray& y_grads,
  100. const OperatorNodeConfig& config);
  101. static SymbolVar virtual_dep(const SymbolVarArray& symvars,
  102. const OperatorNodeConfig& config);
  103. #ifdef SWIG
  104. %pythoncode {
  105. @classmethod
  106. def _make_axis_vec(cls, axis):
  107. ret = _VectorInt()
  108. if isinstance(axis, collections.Iterable):
  109. for i in axis:
  110. ret.push_back(i)
  111. else:
  112. ret.push_back(axis)
  113. return ret
  114. @classmethod
  115. def add_axis(cls, src, axis, config):
  116. return cls._axis_add_remove(src, cls._make_axis_vec(axis), True, config)
  117. @classmethod
  118. def remove_axis(cls, src, axis, config):
  119. return cls._axis_add_remove(src, cls._make_axis_vec(axis), False, config)
  120. } // %pythoncode
  121. #endif // SWIG
  122. };
  123. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台。