You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

craniotome.h 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. /**
  2. * \file python_module/src/cpp/craniotome.h
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * \brief extend megbrain operators in python
  7. *
  8. * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  9. *
  10. */
  11. #pragma once
  12. #include "megbrain/graph/operator_node.h"
  13. #include "./megbrain_wrap.h"
  14. using TensorShapeVec = std::vector<std::vector<size_t>>; // global-scope alias; parameter and return type of CraniotomeDesc::_infer_shape
  15. using SymbolVarArray = mgb::SymbolVarArray; // global-scope alias so the name can be used without the mgb:: qualifier
  16. namespace mgb {
  17. namespace opr {
  18. class Craniotome; // forward declaration; CraniotomeDesc::owner_opr is a pointer to this type
  19. } // namespace opr
  20. } // namespace mgb
  21. class CraniotomeDesc: public mgb::Hashable { // abstract descriptor: every `= 0` member below must be provided by a subclass
  22. MGB_DYN_TYPE_OBJ_FINAL_DECL;
  23. mutable PyObject *m_py_self = nullptr; // presumably a cache for py_self(); mutable so const members can fill it -- TODO confirm in the .cpp
  24. bool is_same_st(const mgb::Hashable &rhs) const override; // presumably forwards to _is_same() -- TODO confirm
  25. size_t hash() const override; // presumably forwards to _hash() -- TODO confirm
  26. public:
  27. struct NodeFlag { // single-bit constants meant to be OR-ed into a uint32_t mask (see _node_flag())
  28. static constexpr uint32_t
  29. DYNAMIC_OUTPUT_SHAPE = 1 << 0, // when not set, _infer_shape() is used to infer output shapes
  30. DISALLOW_DUPLICATE = 1 << 1,
  31. ALLOW_EMPTY_OUTPUT = 1 << 2,
  32. DISABLE_SYS_MEM_ALLOC = 1 << 3; // checked together with DYNAMIC_OUTPUT_SHAPE by Craniotome::output_no_sys_mem_alloc()
  33. };
  34. virtual ~CraniotomeDesc() = default;
  35. mgb::opr::Craniotome* owner_opr = nullptr; // nullptr by default; presumably set to the operator holding this desc -- TODO confirm
  36. //! get the final Python object that implements this interface
  37. PyObject* py_self() const ;
  38. //! store self in \p result, which is a Python list
  39. virtual void _setup_self(PyObject *result) const = 0;
  40. virtual bool _is_same(PyObject *rhs) const = 0;
  41. virtual uint32_t _node_flag() const = 0; // presumably a bitwise OR of NodeFlag constants -- TODO confirm
  42. virtual size_t _hash() const = 0;
  43. virtual std::string _get_opr_type_name() = 0;
  44. virtual size_t _get_nr_outputs() = 0;
  45. virtual void _execute( // inputs are passed read-only; outputs are passed by mutable reference
  46. const std::vector<CompGraphCallbackValueProxy> &inputs,
  47. std::vector<SharedND> &outputs) = 0;
  48. /*!
  49. * \brief infer output shape if DYNAMIC_OUTPUT_SHAPE is not set
  50. */
  51. virtual TensorShapeVec _infer_shape(
  52. const TensorShapeVec &inp_shape) = 0;
  53. virtual SymbolVarArray _grad( // presumably the gradient w.r.t. input number wrt_idx -- TODO confirm
  54. size_t wrt_idx,
  55. const SymbolVarArray &inputs,
  56. const SymbolVarArray &outputs,
  57. const SymbolVarArray &out_grad) = 0;
  58. virtual size_t _get_nr_dev_comp_order_deps() = 0;
  59. mgb::thin_function<SymbolVarArray()> _get_all_io_vars; // callable data member, not a virtual; it is not assigned anywhere in this header
  60. /*!
  61. * \brief get output dtypes from input dtypes
  62. * \param[in] input_dtypes python list of input
  63. * \param[out] result initialized as an empty python list, and should
  64. * be filled with output dtypes
  65. * \return whether user has set the dtype
  66. */
  67. virtual bool _init_output_dtype(
  68. PyObject *input_dtypes, PyObject *result) = 0;
  69. /*!
  70. * \brief get computing graph when no input var is provided
  71. */
  72. virtual CompGraph _get_comp_graph() = 0;
  73. /*!
  74. * \brief copy this CraniotomeDesc
  75. *
  76. * The implementation must call _set_copy_result() to return the result;
  77. * this is used to bypass some swig issues.
  78. */
  79. virtual void _copy() const = 0;
  80. mutable mgb::thin_function<void(CraniotomeDesc*)> _set_copy_result; // receives the copy produced by _copy(); see the comment on _copy()
  81. /*!
  82. * \brief setup params for serialization
  83. * \param output an allocated list. One or two elements should be
  84. * inserted in it after this function returns: the first element
  85. * should be a string, indicating the id to be passed to
  86. * opr_maker_loader; the second element, if exists, must be a byte
  87. * object containing extra param that should be written to file.
  88. */
  89. virtual void _setup_serialize_params(PyObject *output) const = 0;
  90. /*!
  91. * \brief callback invoked when the graph is compiled or when func is
  92. * destructed
  93. *
  94. * If the graph is compiled but not executed, this function might not be
  95. * called
  96. *
  97. * \param used_outputs an array of indices indicating the used output vars;
  98. * this argument being empty means that the previously compiled
  99. * func is destructed
  100. */
  101. virtual void _on_graph_compile_or_func_del(
  102. const std::vector<size_t>& used_outputs) = 0;
  103. };
  104. namespace mgb {
  105. namespace opr {
  106. MGB_DEFINE_OPR_CLASS(Craniotome, cg::SingleCNOutshapePureByInshapeOprBase) // {
  107. class FuncDelCallbackInvoker; // nested class, declared only; its definition is not in this header
  108. using NodeFlag = CraniotomeDesc::NodeFlag;
  109. bool m_on_graph_compile_called = false;
  110. const uint32_t m_node_flag; // tested against NodeFlag bits in output_no_sys_mem_alloc(); const, so it is fixed at construction
  111. //! DEV_COMP_ORDER inputs are at the tail of input array; this is the
  112. //! number of DEV_VALUE inputs, and also the index of the first
  113. //! DEV_COMP_ORDER input
  114. size_t m_nr_dev_value_inp;
  115. std::unique_ptr<CraniotomeDesc> m_desc; // descriptor owned by this operator; exposed read-only through desc()
  116. //! previously inferred shape; used when there is no input and
  117. //! m_is_dynamic_output_shape is set to true (NOTE(review): no such member is declared here; presumably NodeFlag::DYNAMIC_OUTPUT_SHAPE in m_node_flag is meant -- confirm)
  118. Maybe<TensorShapeArray> m_prev_inferred_shape;
  119. void scn_do_execute() override;
  120. void get_output_var_shape(const TensorShapeArray &inp_shape,
  121. TensorShapeArray &out_shape) const override;
  122. void add_input_layout_constraint() override;
  123. void init_output_static_infer_desc() override;
  124. void init_output_dtype() override;
  125. NodeProp* do_make_node_prop() const override;
  126. bool output_no_sys_mem_alloc() const { // true iff DYNAMIC_OUTPUT_SHAPE or DISABLE_SYS_MEM_ALLOC is set in m_node_flag
  127. return m_node_flag & (NodeFlag::DYNAMIC_OUTPUT_SHAPE |
  128. NodeFlag::DISABLE_SYS_MEM_ALLOC);
  129. }
  130. public:
  131. Craniotome(mgb::ComputingGraph *graph, // desc is taken by value as a unique_ptr, so ownership moves into the operator
  132. std::unique_ptr<CraniotomeDesc> desc,
  133. const VarNodeArray &inputs, const OperatorNodeConfig &config);
  134. ~Craniotome() noexcept;
  135. static SymbolVarArray make( // factory taking ownership of desc; config defaults to a default-constructed OperatorNodeConfig
  136. std::unique_ptr<CraniotomeDesc> desc,
  137. const SymbolVarArray &inputs,
  138. const OperatorNodeConfig &config = {});
  139. const CraniotomeDesc& desc() const { // read-only access to the owned descriptor; dereferences m_desc without a null check
  140. return *m_desc;
  141. }
  142. size_t nr_dev_value_inp() const { // number of DEV_VALUE inputs (see the comment on m_nr_dev_value_inp)
  143. return m_nr_dev_value_inp;
  144. }
  145. };
  146. } // namespace opr
  147. } // namespace mgb
  148. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台