
// elemwise.cpp

#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/utility.h"

#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb {
namespace imperative {
namespace {
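
// Rebuild an imperative Elemwise OpDef from an existing graph operator node.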
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Elemwise>();
    return Elemwise::make(node->param().mode);
}
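
// Graph-mode application: insert an opr::Elemwise node for the given inputs.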
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& elemwise_opr = def.cast_final_safe<Elemwise>();
    OperatorNodeConfig config{elemwise_opr.make_name()};
    return opr::Elemwise::make(inputs, elemwise_opr.mode, config);
}
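
// Deduce the output dtype, comp_node and shape; returns false (not yet inferable)
// as soon as any input shape is still unknown.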
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<Elemwise>();
    auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
    mgb_assert(
            inputs.size() == trait.arity, "%s expects %u inputs; got %zu actually",
            trait.name, trait.arity, inputs.size());
    TensorShapeArray inp_shapes;
    DType out_dt;
    CompNode out_cn;
    for (size_t i = 0; i < inputs.size(); ++i) {
        auto&& t = inputs[i];
        if (!i) {
            out_cn = t.comp_node;
            out_dt = t.layout.dtype;
        } else {
            mgb_assert(t.comp_node == out_cn);
            mgb_assert(t.layout.dtype == out_dt);
        }
        if (t.layout.ndim > 0) {
            inp_shapes.push_back(t.layout);
        } else {
            TensorLayout out_layout;
            out_layout.ndim = 0;
            out_layout.dtype = out_dt;
            return {{{out_layout, out_cn}}, false};
        }
    }

    // copied from megdnn::ElemwiseForward::check_dtype
    switch (out_dt.category()) {
        case DTypeCategory::FLOAT:
            mgb_assert(trait.allow_float, "unsupported mode %s for float\n", trait.name);
            break;
        case DTypeCategory::INT:
            mgb_assert(trait.allow_int, "unsupported mode %s for int\n", trait.name);
            break;
        case DTypeCategory::BOOL:
            mgb_assert(trait.allow_bool, "unsupported mode %s for bool\n", trait.name);
            break;
        default:
            // Quantized dtypes could also be handled by this op,
            // but their scales need to be the same.
            break;
    }

    auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
    return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
}
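
// Small inputs whose values are already available on the host can be computed
// directly on the CPU; everything else goes through the device kernel path.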
DispatchMode decide_dispatch_mode(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    bool host_computable = true;
    constexpr int size_threshold = TensorShape::MAX_NDIM;
    for (auto&& inp : inputs) {
        if (inp.value.empty() || inp.value.layout().ndim == 0 ||
            inp.value.layout().total_nr_elems() > size_threshold) {
            host_computable = false;
            break;
        }
    }
    return host_computable ? DEFAULT_CPU : KERNEL;
}
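
// Evaluate the elemwise mode on pre-allocated DeviceTensorNDs.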
void apply_on_device_tensornd(
        const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
    auto&& op_def = def.cast_final_safe<Elemwise>();
    auto&& trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
    mgb_assert(
            inputs.size() == trait.arity, "%s expects %u inputs; got %zu actually",
            trait.name, trait.arity, inputs.size());
    DnnOprCaller<megdnn::Elemwise> dnn_opr(inputs[0].comp_node());
    opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr.op);
}
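
// Eager execution path: deduce the broadcast output layout, allocate the output
// and run the megdnn kernel (or the opr-level helper for fused / quantized modes).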
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto comp_node = inputs[0]->comp_node();
    using Mode = Elemwise::Mode;
    using TensorND = megdnn::TensorND;
    auto&& op_def = def.cast_final_safe<Elemwise>();
    SmallVector<TensorND> inp_tensornds;
    TensorShapeArray inp_shapes(inputs.size());
    inp_tensornds.reserve(inputs.size());
    TensorLayout layout{inputs[0]->layout().dtype};
    bool is_empty = false;
    for (unsigned i = 0; i < inputs.size(); ++i) {
        if (inputs[i]->layout().is_empty()) {
            is_empty = true;
        }
        inp_tensornds.push_back(inputs[i]->dnn_tensor());
        inp_shapes[i] = inputs[i]->layout();
    }
    megdnn::Elemwise::deduce_shape(inp_shapes, layout);
    layout.init_contiguous_stride();
    DeviceTensorND out =
            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, layout);
    if (is_empty) {
        return {Tensor::make(out)};
    }
    DnnOprCaller<megdnn::Elemwise> dnn_opr(comp_node);
    dnn_opr.op->param() = op_def.param();
    if (dnn_opr.op->param().mode == Mode::FUSE_MUL_ADD3 ||
        dnn_opr.op->param().mode == Mode::FUSE_MUL_ADD4 ||
        (inp_tensornds.size() &&
         inp_tensornds[0].layout.dtype.category() == DTypeCategory::QUANTIZED)) {
        opr::Elemwise::perform_dnn(comp_node, out, inp_tensornds, dnn_opr.op);
    } else {
        dnn_opr.op->exec(inp_tensornds, out.as_megdnn());
    }
    return {Tensor::make(out)};
}
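
// A graph operator whose output is forced to forward (share) the storage of the
// input selected by Param::inplace_index, so the elemwise result is written back
// in place; used below to implement the inplace add-update.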
MGB_DEFINE_OPR_CLASS(
        ForceInplaceElemwise,
        cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) // {
public:
    struct Param {
        using Mode = megdnn::Elemwise::Param::Mode;
        Mode mode;
        size_t inplace_index;
    };
    using Mode = Param::Mode;

    ForceInplaceElemwise(
            const VarNodeArray& inputs, Param param, OperatorNodeConfig config = {})
            : Super(inputs[0]->owner_graph(), config, "device_add_update", inputs),
              m_param{param} {
        for (auto* input : inputs) {
            add_input({input});
        }
        add_output(None)
                ->set_fwd_in2out_writable_force(input(param.inplace_index))
                .add_flag(VarNode::Flag::NO_MEM_RECLAIM);
    }

    static SymbolVar make(const VarNodeArray& inputs, Param param) {
        return SymbolVar{inputs[0]}.insert_single_output_opr<ForceInplaceElemwise>(
                inputs, param);
    }

    static cg::OperatorNodeBase* shallow_copy(
            const serialization::OprShallowCopyContext& ctx,
            const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
            const OperatorNodeConfig& config);

protected:
    NodeProp* do_make_node_prop() const override {
        auto ret = Super::do_make_node_prop();
        ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
        return ret;
    }

    void create_megdnn_opr() override {
        auto opr = DnnOprCaller<megdnn::Elemwise>::create_operator(comp_node());
        opr->param().mode = m_param.mode;
        set_megdnn_opr(std::move(opr));
    }

    void scn_do_execute() override {
        auto to_dnnnd = [&](auto* var) { return var->dev_tensor().as_megdnn(); };
        megdnn::TensorNDArray inputs_dnnnd;
        for (auto* input : input()) {
            inputs_dnnnd.push_back(to_dnnnd(input));
        }
        mgb_assert(
                input(m_param.inplace_index)
                        ->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC),
                "ForceInplaceElemwise cannot be applied to an internal tensor");
        auto* out_dest = output(0);
        auto* opr = static_cast<megdnn::Elemwise*>(megdnn_opr());
        opr->exec(std::move(inputs_dnnnd), to_dnnnd(out_dest));
    }

    void init_output_static_infer_desc() override {
        using namespace cg::static_infer;
        owner_graph()->static_infer_manager().register_shape_infer(
                output(0), ShapeInferDesc::make_identity(input(m_param.inplace_index)));
    }

private:
    Param m_param;

    void record_execute_deps(ExecDependencyArray& deps) override {
        record_megdnn_opr(deps);
    }
};

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ForceInplaceElemwise);

cg::OperatorNodeBase* ForceInplaceElemwise::shallow_copy(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto&& opr = opr_.cast_final_safe<ForceInplaceElemwise>();
    auto* graph = ctx.owner_graph(opr, inputs);
    return graph->insert_opr(
            std::make_unique<ForceInplaceElemwise>(inputs, opr.m_param, config));
}

MGB_REG_OPR_SHALLOW_COPY(ForceInplaceElemwise, ForceInplaceElemwise::shallow_copy);
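
// Graph-mode InplaceAdd: dest = alpha * dest + beta * delta, realized as a
// FUSE_MUL_ADD4 whose output is forced into dest's storage (inplace_index 1).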
cg::OperatorNodeBase* apply_inplace_add_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3];
    auto mode = ForceInplaceElemwise::Param::Mode::FUSE_MUL_ADD4;
    return ForceInplaceElemwise::make({alpha, dest, beta, delta}, {mode, 1})
            .node()
            ->owner_opr();
}
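
// Eager InplaceAdd via megdnn::AddUpdate; if dest shares its storage with other
// tensors, copy it first so that only this tensor is modified.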
SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3];
    if (!inputs[0]->storage_is_unique()) {
        mgb_log_warn(
                "This inplace modification may change the elements of other tensors. "
                "Falling back to a non-inplace update.");
        DeviceTensorStorage storage;
        storage.reset(dest->comp_node(), dest->blob()->size(), dest->blob()->storage());
        storage = storage.sub(dest->offset());
        DeviceTensorND dv;
        dv.reset(storage, dest->layout());

        DeviceTensorND dv_new;
        dv_new.copy_from(dv);
        dest = Tensor::make(dv_new);
    }
    auto tensor_to_scalar = [](const TensorPtr& tensor) -> float {
        return *tensor->get_value().ptr<float>();
    };
    DnnOprCaller<megdnn::AddUpdate> caller{dest->comp_node()};
    caller.op->param() = {tensor_to_scalar(alpha), tensor_to_scalar(beta)};
    caller.op->exec(dest->dev_tensor().as_megdnn(), delta->dev_tensor().as_megdnn());
    return {std::make_shared<Tensor>(dest->blob(), dest->offset(), dest->layout())};
}
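
// Validate the InplaceAdd inputs (dest, delta, alpha, beta) and reuse dest's
// layout and comp_node for the output descriptor.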
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(inputs.size() == 4, "invalid input number for inplace_add");
    CompNode cn;
    for (auto&& input : inputs) {
        if (!cn.valid()) {
            cn = input.comp_node;
        } else {
            mgb_assert(input.comp_node == cn, "inputs should be on the same comp_node");
        }
    }
    auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3];
    bool succeed = dest.layout.ndim != 0;
    if (succeed) {
        mgb_assert(
                delta.layout.ndim == 0 || dest.layout.eq_shape(delta.layout),
                "dest and delta must have the same shape");
        mgb_assert(
                alpha.layout.ndim == 0 || alpha.layout.eq_shape({1}),
                "alpha should be a scalar");
        mgb_assert(
                beta.layout.ndim == 0 || beta.layout.eq_shape({1}),
                "beta should be a scalar");
    }
    mgb_assert(alpha.layout.dtype == dtype::Float32(), "alpha should be float32");
    mgb_assert(beta.layout.dtype == dtype::Float32(), "beta should be float32");
    // the inplace op changes the value behind the result's desc
    return {{{dest.layout, dest.comp_node}}, succeed};
}
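
// Register the trait hooks defined above for Elemwise and InplaceAdd.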
OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
        .make_from_op_node(make_from_op_node)
        .decide_dispatch_mode(decide_dispatch_mode)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_device_tensornd(apply_on_device_tensornd)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

OP_TRAIT_REG(InplaceAdd, InplaceAdd, opr::AddUpdate)
        .apply_on_var_node(apply_inplace_add_on_var_node)
        .apply_on_physical_tensor(apply_inplace_add_on_physical_tensor)
        .infer_output_attrs_fallible(infer_inplace_add_output_attrs_fallible)
        .fallback();

}  // anonymous namespace
}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}