
elemwise.cpp 9.2 kB

/**
 * \file imperative/src/impl/ops/elemwise.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/opr/utility.h"

#include "../op_trait.h"
#include "../dnn_op_helper.h"

namespace mgb {
namespace imperative {
namespace {

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Elemwise>();
    return Elemwise::make(node->param().mode);
}

cg::OperatorNodeBase* apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& elemwise_opr = def.cast_final_safe<Elemwise>();
    return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr();
}
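
// The two hooks above round-trip between the imperative OpDef and the graph
// operator: make_from_op_node recovers an Elemwise OpDef from an existing
// opr::Elemwise node, and apply_on_var_node lowers an OpDef back into the
// graph, forwarding only the mode. A hypothetical call-site sketch (the
// dispatch entry point is an assumption, not the verbatim API):
//
//     auto add = Elemwise::make(megdnn::Elemwise::Mode::ADD);
//     // lowering then reduces to apply_on_var_node(*add, {x, y}),
//     // which builds an opr::Elemwise with the same mode.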

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<Elemwise>();
    auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
    mgb_assert(inputs.size() == trait.arity,
               "%s expects %u inputs; got %zu actually", trait.name,
               trait.arity, inputs.size());
    TensorShapeArray inp_shapes;
    DType out_dt;
    CompNode out_cn;
    for (size_t i = 0; i < inputs.size(); ++i) {
        auto&& t = inputs[i];
        if (!i) {
            out_cn = t.comp_node;
            out_dt = t.layout.dtype;
        } else {
            mgb_assert(t.comp_node == out_cn);
            mgb_assert(t.layout.dtype == out_dt);
        }
        if (t.layout.ndim > 0) {
            inp_shapes.push_back(t.layout);
        } else {
            TensorLayout out_layout;
            out_layout.ndim = 0;
            out_layout.dtype = out_dt;
            return {{{out_layout, out_cn}}, true};
        }
    }
    auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
    return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
}
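
// A sketch of the fallible-inference contract above: when every input shape
// is known, get_output_var_shape produces the broadcast result, e.g.
//     ADD on {2, 3} and {1, 3}  ->  output shape {2, 3}
// whereas as soon as one input still has ndim == 0 (shape not yet known), a
// 0-dim placeholder layout is returned instead and inference runs again later.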

SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs) {
    auto&& op_def = def.cast_final_safe<Elemwise>();
    auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
    mgb_assert(inputs.size() == trait.arity,
               "%s expects %u inputs; got %zu actually", trait.name,
               trait.arity, inputs.size());
    DeviceTensorND out;
    SmallVector<DeviceTensorND> dt_inputs(inputs.size());
    for (unsigned i = 0; i < inputs.size(); ++i) {
        dt_inputs[i] = inputs[i]->dev_tensor();
    }
    auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0]->comp_node());
    opr::Elemwise::perform(op_def.mode, out, dt_inputs, dnn_opr);
    return {Tensor::make(out)};
}
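
// Eager path: unlike apply_on_var_node, no graph node is built here. A megdnn
// Elemwise operator is created directly on the inputs' comp_node and
// opr::Elemwise::perform computes into a freshly allocated DeviceTensorND.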

MGB_DEFINE_OPR_CLASS(ForceInplaceElemwise,
                     cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) //{
public:
    struct Param {
        using Mode = megdnn::Elemwise::Param::Mode;
        Mode mode;
        size_t inplace_index;
    };
    using Mode = Param::Mode;

    ForceInplaceElemwise(const VarNodeArray& inputs, Param param,
                         OperatorNodeConfig config = {})
            : Super(inputs[0]->owner_graph(), config, "device_add_update", inputs),
              m_param{param} {
        for (auto* input : inputs) {
            add_input({input});
        }
        add_output(None)
                ->set_fwd_in2out_writable_force(input(param.inplace_index))
                .add_flag(VarNode::Flag::NO_MEM_RECLAIM);
    }

    static SymbolVar make(const VarNodeArray& inputs, Param param) {
        return SymbolVar{inputs[0]}.insert_single_output_opr<ForceInplaceElemwise>(
                inputs, param);
    }

    static cg::OperatorNodeBase* shallow_copy(
            const serialization::OprShallowCopyContext& ctx,
            const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
            const OperatorNodeConfig& config);

protected:
    NodeProp* do_make_node_prop() const override {
        auto ret = Super::do_make_node_prop();
        ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
        return ret;
    }

    void create_megdnn_opr() override {
        auto opr = DnnOprCaller<megdnn::Elemwise>::create_operator(comp_node());
        opr->param().mode = m_param.mode;
        set_megdnn_opr(std::move(opr));
    }

    void scn_do_execute() override {
        auto to_dnnnd = [&](auto* var) { return var->dev_tensor().as_megdnn(); };
        megdnn::TensorNDArray inputs_dnnnd;
        for (auto* input : input()) {
            inputs_dnnnd.push_back(to_dnnnd(input));
        }
        mgb_assert(input(m_param.inplace_index)
                           ->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC),
                   "ForceInplaceElemwise cannot be applied to an internal tensor");
        auto* out_dest = output(0);
        auto* opr = static_cast<megdnn::Elemwise*>(megdnn_opr());
        opr->exec(std::move(inputs_dnnnd), to_dnnnd(out_dest));
    }

    void init_output_static_infer_desc() override {
        using namespace cg::static_infer;
        owner_graph()->static_infer_manager().register_shape_infer(
                output(0),
                ShapeInferDesc::make_identity(input(m_param.inplace_index)));
    }

private:
    Param m_param;

    void record_execute_deps(ExecDependencyArray& deps) override {
        record_megdnn_opr(deps);
    }
};
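
// How the forced in-place update above fits together:
//  * set_fwd_in2out_writable_force makes output(0) share the storage of the
//    input selected by param.inplace_index, so writing the output mutates it;
//  * NO_MEM_RECLAIM keeps that storage from being recycled by the memory planner;
//  * FORCE_UPDATE_INPUT_VAR marks the input var as mutated, so later readers
//    are ordered after this operator.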

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ForceInplaceElemwise);

cg::OperatorNodeBase* ForceInplaceElemwise::shallow_copy(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto&& opr = opr_.cast_final_safe<ForceInplaceElemwise>();
    auto* graph = ctx.owner_graph(opr, inputs);
    return graph->insert_opr(
            std::make_unique<ForceInplaceElemwise>(inputs, opr.m_param, config));
}

MGB_REG_OPR_SHALLOW_COPY(ForceInplaceElemwise, ForceInplaceElemwise::shallow_copy);

cg::OperatorNodeBase* apply_inplace_add_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto dest = inputs[0], delta = inputs[1],
         alpha = inputs[2], beta = inputs[3];
    auto mode = ForceInplaceElemwise::Param::Mode::FUSE_MUL_ADD4;
    return ForceInplaceElemwise::make({alpha, dest, beta, delta}, {mode, 1})
            .node()->owner_opr();
}
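
// megdnn's FUSE_MUL_ADD4 computes out = in0 * in1 + in2 * in3 elementwise, so
// the operand order {alpha, dest, beta, delta} with inplace_index = 1 makes
// the forced-inplace output alias dest and evaluate to
//     dest = alpha * dest + beta * delta
// matching the update the physical-tensor path below performs via AddUpdate.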

SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs) {
    auto dest = inputs[0], delta = inputs[1],
         alpha = inputs[2], beta = inputs[3];
    auto tensor_to_scalar = [](const TensorPtr& tensor) -> float {
        return *tensor->get_value().ptr<float>();
    };
    DnnOprCaller<megdnn::AddUpdate> caller{dest->comp_node()};
    caller.op->param() = {tensor_to_scalar(alpha), tensor_to_scalar(beta)};
    caller.op->exec(dest->dev_tensor().as_megdnn(), delta->dev_tensor().as_megdnn());
    return {std::make_shared<Tensor>(dest->blob(), dest->offset(), dest->layout())};
}
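
// Here alpha and beta are expected to be 1-element float32 tensors readable on
// the host (tensor_to_scalar dereferences the first value). AddUpdate writes
// dest in place, and the returned Tensor wraps dest's existing blob, so the
// caller observes the update without any copy being made.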

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(inputs.size() == 4, "invalid input number for inplace_add");
    CompNode cn;
    for (auto&& input : inputs) {
        if (!cn.valid()) {
            cn = input.comp_node;
        } else {
            mgb_assert(input.comp_node == cn, "inputs should be on the same comp_node");
        }
    }
    auto dest = inputs[0], delta = inputs[1],
         alpha = inputs[2], beta = inputs[3];
    bool succeed = dest.layout.ndim != 0;
    if (succeed) {
        mgb_assert(delta.layout.ndim == 0 || dest.layout.eq_shape(delta.layout),
                   "dest and delta must have the same shape");
        mgb_assert(alpha.layout.ndim == 0 || alpha.layout.eq_shape({1}),
                   "alpha should be a scalar");
        mgb_assert(beta.layout.ndim == 0 || beta.layout.eq_shape({1}),
                   "beta should be a scalar");
    }
    mgb_assert(alpha.layout.dtype == dtype::Float32(), "alpha should be float32");
    mgb_assert(beta.layout.dtype == dtype::Float32(), "beta should be float32");
    // the inplace op's result aliases dest, so the desc's value is marked changed
    return {{{dest.layout, dest.comp_node}}, succeed};
}

OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

OP_TRAIT_REG(InplaceAdd, InplaceAdd, opr::AddUpdate)
        .apply_on_var_node(apply_inplace_add_on_var_node)
        .apply_on_physical_tensor(apply_inplace_add_on_physical_tensor)
        .infer_output_attrs_fallible(infer_inplace_add_output_attrs_fallible)
        .fallback();
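
// OP_TRAIT_REG binds the free functions above to the corresponding OpDef kinds
// (Elemwise and InplaceAdd); .fallback() is assumed here to fill the remaining,
// unregistered trait entries with default implementations.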

}  // anonymous namespace
}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.