
reduce.cpp 9.2 kB

/**
 * \file imperative/src/impl/ops/reduce.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/graph/symbol_var.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megdnn/dtype.h"

#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb {
namespace imperative {
namespace {
namespace reduce {
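
// Graph-mode lowering for Reduce. A second input, when present, supplies the
// target shape. Otherwise a negative axis is normalized, axis == INT_MAX
// (reduce over all elements) gets a scalar target shape via ImmutableTensor,
// and keepdim == false drops the reduced axis with an AxisAddRemove node.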
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& reduce = static_cast<const Reduce&>(def);
    auto comp_node = inputs[0]->comp_node();
    OperatorNodeConfig config{reduce.make_name(), comp_node, inputs[0]->dtype()};
    if (inputs.size() > 1) {
        return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
    }
    using Param = megdnn::param::Reduce;
    auto param = reduce.param();
    if (param.axis < 0) {
        param.axis = inputs[0]->shape().ndim + param.axis;
    }
    SymbolVar target_shape = (cg::VarNode*)nullptr;
    if (param.axis == INT_MAX) {
        DTypeScalar vi{1};
        // auto graph = ComputingGraph::make();
        auto graph = inputs[0]->owner_graph();
        target_shape = opr::ImmutableTensor::make(*graph, vi, config);
    }
    auto res = opr::Reduce::make(inputs[0], param, target_shape, config);
    if (!reduce.keepdim && param.axis != INT_MAX) {
        using Desc = opr::AxisAddRemove::AxisDesc;
        std::vector<Desc> remove_param;
        remove_param.push_back(Desc::make_remove(param.axis));
        OperatorNodeConfig remove_config{
                def.make_name(), comp_node, inputs[0]->dtype()};
        return opr::AxisAddRemove::make(res, remove_param, remove_config);
    }
    return res;
}
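
// Rebuild a Reduce OpDef from an existing graph operator. keepdim is fixed to
// true here, presumably because the graph-level opr::Reduce always keeps the
// reduced axis.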
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Reduce>();
    return Reduce::make(node->param(), true);
}

// TODO: use this for apply_on_physical_tensor
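// Returns true when the reduce is a no-op: an explicit target shape equal to
// the input shape, for any mode except SUM_SQR (which changes values even
// when the shape is unchanged). In that case the input storage can be
// forwarded directly.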
bool memory_forward_success(const OpDef& def, SmallVector<TensorPtr> inputs) {
    auto&& reduce = static_cast<const Reduce&>(def);
    if (reduce.mode != Reduce::Mode::SUM_SQR && inputs.size() == 2) {
        auto shape_tensor = inputs[1]->get_value();
        TensorShape shape;
        cg::copy_tensor_value_to_shape(shape, shape_tensor.proxy_to_default_cpu());
        if (shape.eq_shape(inputs[0]->shape())) {
            return true;
        }
    }
    return false;
}
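
// Eager-mode execution. Fast paths in order: forward the input memory when
// the reduce is a no-op; fall back to the proxy graph when a target-shape
// tensor is supplied; otherwise invoke the megdnn Reduce kernel directly,
// with special handling for empty inputs and keepdim == false.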
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    if (memory_forward_success(def, inputs)) {
        return {Tensor::make(
                inputs[0]->blob(), inputs[0]->offset(), inputs[0]->layout())};
    }
    auto size = inputs.size();
    if (size > 1) {
        return proxy_graph_detail::apply_on_physical_tensor(
                def, inputs, output_descs, validated);
    }
    auto comp_node = inputs[0]->comp_node();
    using TensorND = megdnn::TensorND;
    auto&& op_def = def.cast_final_safe<Reduce>();
    SmallVector<TensorND> inp_tensornds;
    inp_tensornds.reserve(inputs.size());
    auto src = inputs[0]->layout();
    DnnOprCaller<megdnn::Reduce> dnn_op(comp_node);
    dnn_op.op->param() = op_def.param();
    auto axis = op_def.param().axis;
    auto keepdim = op_def.keepdim;
    if (axis < 0) {
        axis = inputs[0]->layout().ndim + axis;
    }
    dnn_op.op->param().axis = axis == INT_MAX ? 0 : axis;
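    // axis == INT_MAX is the sentinel for "reduce over all elements":
    // flatten the source layout to 1-D and reduce along axis 0.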
    if (axis == INT_MAX) {
        src.shape[0] = src.total_nr_elems();
        src.ndim = 1;
        src.init_contiguous_stride();
    }
    TensorLayout layout{src.dtype};
    dnn_op.op->deduce_layout(src, layout);
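    // Empty input: SUM and PRODUCT have well-defined identities (0 and 1),
    // so the output is filled directly; the remaining modes are undefined on
    // an empty tensor and raise an error.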
    if (inputs[0]->layout().is_empty()) {
        inputs[0]->dev_tensor().reset(inputs[0]->dev_tensor().storage(), src);
        auto mode = op_def.param().mode;
        DnnOprCaller<megdnn::Fill> fill_op(comp_node);
        if (!keepdim && src.ndim > 1) {
            layout.remove_axis_inplace(axis);
            layout.init_contiguous_stride();
        }
        DeviceTensorND out =
                BlobManager::inst()->alloc_workspace_with_defrag(comp_node, layout);
        std::string err_msg;
        switch (mode) {
            case Reduce::Mode::SUM:
                if (!out.empty()) {
                    fill_op.op->param() = 0;
                    fill_op.op->exec(out.as_megdnn(), {});
                }
                break;
            case Reduce::Mode::PRODUCT:
                if (!out.empty()) {
                    fill_op.op->param() = 1;
                    fill_op.op->exec(out.as_megdnn(), {});
                }
                break;
            case Reduce::Mode::MEAN:
                err_msg = "mean";
                break;
            case Reduce::Mode::MIN:
                err_msg = "min";
                break;
            case Reduce::Mode::MAX:
                err_msg = "max";
                break;
            case Reduce::Mode::SUM_SQR:
                err_msg = "sum_sqr";
                break;
            default:
                mgb_throw(MegBrainError, "bad reduce mode");
        }
        if (!err_msg.empty()) {
            mgb_throw(
                    MegBrainError, "empty input is not allowed for reduce mode: %s",
                    err_msg.c_str());
        }
        return {Tensor::make(out)};
    }
    auto dnn_ten = inputs[0]->dnn_tensor();
    dnn_ten.layout = src;
    inp_tensornds.push_back(dnn_ten);
    megdnn::Workspace dnn_wk;
    auto wk_size = dnn_op.op->get_workspace_in_bytes(src, layout);
    if (wk_size != 0) {
        auto wk = Blob::make(comp_node, wk_size);
        dnn_wk.raw_ptr = wk->storage().get();
        dnn_wk.size = wk_size;
    }
    DeviceTensorND out =
            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, layout);
    dnn_op.op->exec(inp_tensornds[0], out.as_megdnn(), dnn_wk);
    if (!keepdim && src.ndim > 1) {
        auto out_layout = out.layout();
        out_layout.remove_axis_inplace(axis);
        out_layout.init_contiguous_stride();
        out.resize(out_layout);
    }
    return {Tensor::make(out)};
}
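
// Fallible shape inference. Reports "unknown" (validated = false) while any
// input layout is still undetermined; when a target-shape input is present,
// its value (if already known) overrides the proxy-graph result.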
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<Reduce>();
    auto axis = op_def.param().axis;
    auto keepdim = op_def.keepdim;
    size_t size = inputs.size();
    SmallVector<LogicalTensorDesc> dests(size);
    for (size_t i = 0; i < size; i++) {
        if (inputs[i].layout.ndim == 0) {
            return {{{TensorLayout(inputs[0].layout.dtype), inputs[0].comp_node}},
                    false};
        }
    }
    if (size > 1) {
        auto [output_descs, validated] =
                proxy_graph_detail::infer_output_attrs_fallible(def, inputs);
        if (!inputs[1].value.empty()) {
            cg::copy_tensor_value_to_shape(output_descs[0].layout, inputs[1].value);
            output_descs[0].layout.init_contiguous_stride();
        }
        return {output_descs, validated};
    }
    if (axis < 0) {
        axis = inputs[0].layout.ndim + axis;
    }
    if (axis == INT_MAX || inputs[0].layout.ndim == 1) {
        TensorLayout layout{inputs[0].layout.dtype};
        layout.shape[0] = 1;
        layout.ndim = 1;
        dests[0].layout = layout;
        dests[0].comp_node = inputs[0].comp_node;
    } else {
        for (size_t i = 0; i < size; ++i) {
            dests[i].comp_node = inputs[i].comp_node;
            dests[i].layout = inputs[i].layout;
            if (!keepdim && dests[i].layout.ndim > 1) {
                dests[i].layout.remove_axis_inplace(axis);
            } else {
                dests[i].layout.shape[axis] = 1;
            }
            dests[i].layout.init_contiguous_stride();
        }
    }
    return {dests, true};
}
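
// Only the first input is constrained, to a contiguous layout; presumably the
// megdnn Reduce kernel does not accept arbitrary strides.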
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
    layout_checker[0] = [](const TensorLayout& layout) {
        return layout.is_contiguous();
    };
    return layout_checker;
}
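
// Wire the implementations above into the op-trait dispatch table; entries not
// set here appear to fall back to default implementations via .fallback().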
OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();
}  // namespace reduce
}  // namespace
}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}