
elemwise.cpp
/**
 * \file imperative/src/impl/ops/elemwise.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/opr/utility.h"

#include "../op_trait.h"
#include "../dnn_op_helper.h"

namespace mgb {
namespace imperative {
namespace {
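// Rebuild an imperative OpDef from an existing graph operator node:
// extract the elemwise mode from the opr::Elemwise and wrap it.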
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Elemwise>();
    return Elemwise::make(node->param().mode);
}
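// Expand the imperative op into the computing graph: build an
// opr::Elemwise over the input vars and return its owner operator.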
cg::OperatorNodeBase* apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& elemwise_opr = def.cast_final_safe<Elemwise>();
    return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr();
}
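// Deduce output layout without running the kernel. All inputs must share
// a comp_node and dtype; if any input shape is still unknown, return an
// empty layout with the validity flag set to false so the caller knows
// the inference is incomplete. Otherwise return the broadcasted shape.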
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<Elemwise>();
    auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
    mgb_assert(inputs.size() == trait.arity,
               "%s expects %u inputs; got %zu actually", trait.name,
               trait.arity, inputs.size());
    TensorShapeArray inp_shapes;
    DType out_dt;
    CompNode out_cn;
    for (size_t i = 0; i < inputs.size(); ++i) {
        auto&& t = inputs[i];
        if (!i) {
            out_cn = t.comp_node;
            out_dt = t.layout.dtype;
        } else {
            mgb_assert(t.comp_node == out_cn);
            mgb_assert(t.layout.dtype == out_dt);
        }
        if (t.layout.ndim > 0) {
            inp_shapes.push_back(t.layout);
        } else {
            TensorLayout out_layout;
            out_layout.ndim = 0;
            out_layout.dtype = out_dt;
            return {{{out_layout, out_cn}}, false};
        }
    }
    auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
    return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
}
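// Choose where to evaluate the op: if every input value is already
// available on the host and small enough (at most TensorShape::MAX_NDIM
// elements), compute it eagerly on the CPU; otherwise dispatch a kernel.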
DispatchMode decide_dispatch_mode(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    bool host_computable = true;
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    for (auto&& inp : inputs) {
        if (inp.value.empty() || inp.value.layout().ndim == 0
                || inp.value.layout().total_nr_elems() > size_threshold) {
            host_computable = false;
            break;
        }
    }
    return host_computable ? DEFAULT_CPU : KERNEL;
}
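// Run the elemwise computation on device tensors through a freshly
// created megdnn operator; the output tensor in *outputs is filled
// in place by opr::Elemwise::perform.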
void apply_on_device_tensornd(
        const OpDef& def,
        const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
    auto&& op_def = def.cast_final_safe<Elemwise>();
    auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
    mgb_assert(inputs.size() == trait.arity,
               "%s expects %u inputs; got %zu actually", trait.name,
               trait.arity, inputs.size());
    auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0].comp_node());
    opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr);
}
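// Physical-tensor entry point: unwrap the inputs to DeviceTensorNDs,
// run the computation via apply_on_device_tensornd, and wrap the result
// back into a TensorPtr.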
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs) {
    SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
    for (unsigned i = 0; i < inputs.size(); ++i) {
        inp_tensornds[i] = inputs[i]->dev_tensor();
    }
    SmallVector<DeviceTensorND> oup_tensornds = {{inp_tensornds[0].comp_node(), inp_tensornds[0].dtype()}};
    apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds);
    return {Tensor::make(oup_tensornds[0])};
}
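// Helper operator that runs an elemwise mode while forcing its output to
// reuse the storage of one designated input (param.inplace_index). Used
// below to implement in-place add-update without allocating a new output.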
MGB_DEFINE_OPR_CLASS(ForceInplaceElemwise, cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) //{
public:
    struct Param {
        using Mode = megdnn::Elemwise::Param::Mode;
        Mode mode;
        size_t inplace_index;
    };
    using Mode = Param::Mode;
    ForceInplaceElemwise(const VarNodeArray& inputs, Param param,
                         OperatorNodeConfig config = {})
            : Super(inputs[0]->owner_graph(), config, "device_add_update", inputs),
              m_param{param} {
        for (auto* input : inputs) {
            add_input({input});
        }
        add_output(None)->
                set_fwd_in2out_writable_force(input(param.inplace_index)).
                add_flag(VarNode::Flag::NO_MEM_RECLAIM);
    }
    static SymbolVar make(const VarNodeArray& inputs, Param param) {
        return SymbolVar{inputs[0]}.insert_single_output_opr<ForceInplaceElemwise>(
                inputs, param);
    }
    static cg::OperatorNodeBase* shallow_copy(
            const serialization::OprShallowCopyContext& ctx,
            const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
            const OperatorNodeConfig& config);

protected:
    NodeProp* do_make_node_prop() const override {
        auto ret = Super::do_make_node_prop();
        ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
        return ret;
    }
    void create_megdnn_opr() override {
        auto opr = DnnOprCaller<megdnn::Elemwise>::create_operator(comp_node());
        opr->param().mode = m_param.mode;
        set_megdnn_opr(std::move(opr));
    }
    void scn_do_execute() override {
        auto to_dnnnd = [&](auto* var) { return var->dev_tensor().as_megdnn(); };
        megdnn::TensorNDArray inputs_dnnnd;
        for (auto* input : input()) {
            inputs_dnnnd.push_back(to_dnnnd(input));
        }
        mgb_assert(input(m_param.inplace_index)->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC),
                   "ForceInplaceElemwise cannot be applied to an internal tensor");
        auto* out_dest = output(0);
        auto* opr = static_cast<megdnn::Elemwise*>(megdnn_opr());
        opr->exec(std::move(inputs_dnnnd),
                  to_dnnnd(out_dest));
    }
    void init_output_static_infer_desc() override {
        using namespace cg::static_infer;
        owner_graph()->static_infer_manager().register_shape_infer(
                output(0), ShapeInferDesc::make_identity(input(m_param.inplace_index)));
    }

private:
    Param m_param;
    void record_execute_deps(ExecDependencyArray& deps) override {
        record_megdnn_opr(deps);
    }
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ForceInplaceElemwise);

cg::OperatorNodeBase* ForceInplaceElemwise::shallow_copy(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto&& opr = opr_.cast_final_safe<ForceInplaceElemwise>();
    auto* graph = ctx.owner_graph(opr, inputs);
    return graph->insert_opr(std::make_unique<ForceInplaceElemwise>(inputs, opr.m_param, config));
}
MGB_REG_OPR_SHALLOW_COPY(ForceInplaceElemwise, ForceInplaceElemwise::shallow_copy);
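// Lower inplace_add(dest, delta, alpha, beta) onto graph vars.
// FUSE_MUL_ADD4 computes i0 * i1 + i2 * i3, so ordering the inputs as
// {alpha, dest, beta, delta} with inplace_index = 1 makes the operator
// write alpha * dest + beta * delta back into dest's storage.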
cg::OperatorNodeBase* apply_inplace_add_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto dest = inputs[0], delta = inputs[1],
         alpha = inputs[2], beta = inputs[3];
    auto mode = ForceInplaceElemwise::Param::Mode::FUSE_MUL_ADD4;
    return ForceInplaceElemwise::make({alpha, dest, beta, delta}, {mode, 1}).node()->owner_opr();
}
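// Eager-mode in-place add: read the scalar alpha/beta from host memory,
// run megdnn::AddUpdate (dest = alpha * dest + beta * delta), and return
// a tensor that aliases dest's blob so the update is visible in place.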
SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs) {
    auto dest = inputs[0], delta = inputs[1],
         alpha = inputs[2], beta = inputs[3];
    auto tensor_to_scalar = [](const TensorPtr& tensor) -> float {
        return *tensor->get_value().ptr<float>();
    };
    DnnOprCaller<megdnn::AddUpdate> caller{dest->comp_node()};
    caller.op->param() = {tensor_to_scalar(alpha), tensor_to_scalar(beta)};
    caller.op->exec(dest->dev_tensor().as_megdnn(), delta->dev_tensor().as_megdnn());
    return {std::make_shared<Tensor>(dest->blob(), dest->offset(), dest->layout())};
}
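// Shape/dtype inference for inplace_add: all four inputs must live on the
// same comp_node, delta must match dest's shape, and alpha/beta must be
// float32 scalars. The output descriptor is simply dest's layout.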
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(inputs.size() == 4, "invalid input number for inplace_add");
    CompNode cn;
    for (auto&& input : inputs) {
        if (!cn.valid()) {
            cn = input.comp_node;
        } else {
            mgb_assert(input.comp_node == cn, "inputs should be on the same comp_node");
        }
    }
    auto dest = inputs[0], delta = inputs[1],
         alpha = inputs[2], beta = inputs[3];
    bool succeed = dest.layout.ndim != 0;
    if (succeed) {
        mgb_assert(delta.layout.ndim == 0 || dest.layout.eq_shape(delta.layout),
                   "dest and delta must have the same shape");
        mgb_assert(alpha.layout.ndim == 0 || alpha.layout.eq_shape({1}),
                   "alpha should be a scalar");
        mgb_assert(beta.layout.ndim == 0 || beta.layout.eq_shape({1}),
                   "beta should be a scalar");
    }
    mgb_assert(alpha.layout.dtype == dtype::Float32(), "alpha should be float32");
    mgb_assert(beta.layout.dtype == dtype::Float32(), "beta should be float32");
    // the in-place op changes the value described by dest's descriptor
    return {{{dest.layout, dest.comp_node}}, succeed};
}
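// Register the trait tables that route imperative-mode calls for Elemwise
// and InplaceAdd to the handlers defined above.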
OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
        .make_from_op_node(make_from_op_node)
        .decide_dispatch_mode(decide_dispatch_mode)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_device_tensornd(apply_on_device_tensornd)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

OP_TRAIT_REG(InplaceAdd, InplaceAdd, opr::AddUpdate)
        .apply_on_var_node(apply_inplace_add_on_var_node)
        .apply_on_physical_tensor(apply_inplace_add_on_physical_tensor)
        .infer_output_attrs_fallible(infer_inplace_add_output_attrs_fallible)
        .fallback();
}  // anonymous namespace
}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.