You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

opr_utility.cpp 8.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. /**
  2. * \file imperative/src/impl/opr_utility.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/imperative/opr_utility.h"
  12. #include "./mgb_cg_impl.h"
  13. // FIXME: setup_config_cn is copied from src/opr/impl/utility.cpp
  14. namespace {
  15. mgb::OperatorNodeConfig setup_config_cn(const mgb::OperatorNodeConfig& config_,
  16. const mgb::CompNode& cn) {
  17. auto prev_cn = config_.get_single_comp_node();
  18. mgb_assert(!prev_cn.valid() || cn == prev_cn);
  19. auto config = config_;
  20. config.comp_node(cn);
  21. return config;
  22. }
  23. } // namespace
  24. namespace mgb {
  25. namespace opr {
  26. /* ================ InputCallback ================== */
  27. MGB_DYN_TYPE_OBJ_FINAL_IMPL(InputCallback);
  28. InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback,
  29. const VarNodeArray& inputs,
  30. const TensorShape& output_shape,
  31. const OperatorNodeConfig& config)
  32. : Super(&graph, config, "input_callback", inputs),
  33. m_output_shape(output_shape), m_callback(callback) {
  34. for (VarNode* i : inputs) {
  35. add_input({i});
  36. }
  37. DType dt = config.output_dtype();
  38. mgb_assert(dt.valid());
  39. add_output(None)->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC).dtype(dt);
  40. add_output(None)
  41. ->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
  42. .add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
  43. .dtype(DType::from_enum(DTypeEnum::Byte));
  44. }
  45. SymbolVarArray InputCallback::make(cg::ComputingGraph& graph,
  46. callback_t callback, CompNode comp_node,
  47. DType dtype, const TensorShape& shape,
  48. const SymbolVarArray& inputs) {
  49. mgb_assert(comp_node.valid());
  50. mgb_assert(dtype.valid());
  51. OperatorNodeConfig config;
  52. config.comp_node(comp_node);
  53. config.output_dtype(dtype);
  54. auto vinputs = to_var_node_array(inputs);
  55. auto opr = graph.insert_opr(
  56. std::make_unique<InputCallback>(graph, callback, vinputs, shape, config));
  57. return to_symbol_var_array(opr->output());
  58. }
//! Publish m_output_shape (when known) directly to the static infer
//! manager instead of registering a regular infer desc.
void InputCallback::init_output_static_infer_desc() {
    if (m_output_shape.ndim) {
        // Write this shape to static infer manager. The effect is
        // that infer_shape_fallible() will return a non-empty shape
        // while get_infer_type() remains NO_DESC. Most places check
        // infer type before relying on inferred shape so things
        // won't break. Memory optimizer however, deliberately omits
        // infer type check so it will be able to use this shape for hint.
        using namespace cg::static_infer;
        auto* var = output(0);
        var->shape(m_output_shape);
        // reach into the graph impl: the public static infer API has no
        // way to seed a shape without registering an infer desc
        auto&& mgr = cg::ComputingGraphImpl::downcast(owner_graph())->static_infer_manager_impl();
        auto* handle = mgr.get_tag_handler_for_shape(var);
        handle->sync_from_var();
    }
}
  75. cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const {
  76. NodeProp* prop = Super::do_make_node_prop();
  77. prop->add_flag(NodeProp::Flag::NO_AUTOMATIC_DUP);
  78. SmallVector<NodeProp::DepType> dep_types(input().size(),
  79. NodeProp::DepType::DEV_COMP_ORDER);
  80. prop->reset_dep_type(input(), dep_types);
  81. return prop;
  82. }
  83. void InputCallback::scn_do_execute() {
  84. auto dev_tensor = m_callback();
  85. output(0)->reset_dev_tensor_from_tensor(dev_tensor);
  86. }
  87. cg::OperatorNodeBase* InputCallback::shallow_copy(
  88. const serialization::OprShallowCopyContext &ctx,
  89. const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
  90. const OperatorNodeConfig &config) {
  91. auto &&opr = opr_.cast_final_safe<InputCallback>();
  92. auto* graph = ctx.owner_graph(opr, inputs);
  93. return graph->insert_opr(std::make_unique<InputCallback>(*graph, opr.m_callback, inputs, opr.m_output_shape, config));
  94. }
  95. MGB_REG_OPR_SHALLOW_COPY(InputCallback, InputCallback::shallow_copy);
  96. /* ================ OutputCallback ================== */
  97. MGB_DYN_TYPE_OBJ_FINAL_IMPL(OutputCallback);
  98. OutputCallback::OutputCallback(Param param, const VarNodeArray& inputs,
  99. const OperatorNodeConfig& config)
  100. : Super(inputs[0]->owner_graph(),
  101. setup_config_cn(config, inputs[0]->comp_node()),
  102. "output_callback", inputs),
  103. m_param(std::move(param)) {
  104. for (VarNode* i : inputs) {
  105. add_input({i});
  106. }
  107. if (!m_param.borrow) {
  108. input(0)->add_flag(VarNode::Flag::NO_SYS_STATIC_MEM_ALLOC);
  109. }
  110. add_output(None)
  111. ->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
  112. .add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
  113. .dtype(DType::from_enum(DTypeEnum::Byte));
  114. add_equivalence_component<ScalarHash<void*>>(this);
  115. }
  116. SymbolVar OutputCallback::make(Param param, const SymbolVarArray& inputs) {
  117. mgb_assert(inputs.size() >= 1);
  118. auto vinputs = to_var_node_array(inputs);
  119. OperatorNodeConfig config;
  120. return inputs[0].insert_single_output_opr<OutputCallback>(std::move(param),
  121. vinputs, config);
  122. }
//! Intentionally empty: the byte output carries no data to infer.
void OutputCallback::init_output_static_infer_desc() {}
  124. cg::OperatorNodeBase::NodeProp* OutputCallback::do_make_node_prop() const {
  125. NodeProp* prop = Super::do_make_node_prop();
  126. prop->add_flag(NodeProp::Flag::NO_AUTOMATIC_DUP);
  127. SmallVector<NodeProp::DepType> dep_types(input().size(),
  128. NodeProp::DepType::DEV_COMP_ORDER);
  129. dep_types[0] = NodeProp::DepType::DEV_VALUE;
  130. prop->reset_dep_type(input(), dep_types);
  131. return prop;
  132. }
  133. void OutputCallback::scn_do_execute() {
  134. m_param.callback(input(0)->dev_tensor());
  135. }
  136. cg::OperatorNodeBase* OutputCallback::shallow_copy(
  137. const serialization::OprShallowCopyContext &ctx,
  138. const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
  139. const OperatorNodeConfig &config) {
  140. auto &&opr = opr_.cast_final_safe<OutputCallback>();
  141. auto* graph = ctx.owner_graph(opr, inputs);
  142. return graph->insert_opr(std::make_unique<OutputCallback>(opr.m_param, inputs, config));
  143. }
  144. MGB_REG_OPR_SHALLOW_COPY(OutputCallback, OutputCallback::shallow_copy);
  145. /* ================ NopCallback ================== */
  146. MGB_DYN_TYPE_OBJ_FINAL_IMPL(NopCallback);
  147. NopCallback::NopCallback(cg::ComputingGraph& graph, callback_t callback,
  148. const VarNodeArray& inputs,
  149. const OperatorNodeConfig& config)
  150. : Super(&graph, config, "nop_callback", inputs), m_callback(callback) {
  151. for (VarNode* i : inputs) {
  152. add_input({i});
  153. }
  154. add_output(None)
  155. ->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
  156. .add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
  157. .dtype(DType::from_enum(DTypeEnum::Byte));
  158. }
  159. SymbolVar NopCallback::make(cg::ComputingGraph& graph, callback_t callback,
  160. CompNode comp_node, const SymbolVarArray& inputs) {
  161. mgb_assert(comp_node.valid());
  162. OperatorNodeConfig config;
  163. config.comp_node(comp_node);
  164. auto vinputs = to_var_node_array(inputs);
  165. auto opr = graph.insert_opr(
  166. std::make_unique<NopCallback>(graph, callback, vinputs, config));
  167. return opr->output(0);
  168. }
//! Intentionally empty: the byte output carries no data to infer.
void NopCallback::init_output_static_infer_desc() {}
//! Intentionally a no-op: nothing here depends on the output stream.
void NopCallback::on_output_comp_node_stream_changed() {}
  171. void NopCallback::init_output_comp_node() {
  172. auto cn = config().get_single_comp_node();
  173. mgb_assert(cn.valid());
  174. output(0)->comp_node(cn);
  175. }
  176. cg::OperatorNodeBase::NodeProp* NopCallback::do_make_node_prop() const {
  177. NodeProp* prop = Super::do_make_node_prop();
  178. SmallVector<NodeProp::DepType> dep_types(input().size(),
  179. NodeProp::DepType::DEV_COMP_ORDER);
  180. prop->reset_dep_type(input(), dep_types);
  181. prop->add_flag(
  182. cg::OperatorNodeBase::NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
  183. return prop;
  184. }
//! Dispatch m_callback on the output's comp node, bracketed by
//! Before/AfterKernel graph events — presumably so profiling/listener
//! code observes it like a regular kernel (TODO confirm with event users).
void NopCallback::do_execute(ExecEnv& env) {
    auto cn = output(0)->comp_node();
    // capture `cn` by value: the lambda may run after this stack frame
    // has returned to the executor
    auto runner = [this, cn] {
        owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(this,
                                                                       cn);
        // make `cn` current before invoking the user callback
        cn.activate();
        m_callback();
        owner_graph()->event().signal_inplace<cg::event::AfterKernel>(this, cn);
    };
    env.dispatch_on_comp_node(cn, runner);
}
  196. } // namespace opr
  197. } // namespace mgb
  198. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台