
indexing.cpp

#include "../dnn_op_helper.h"
#include "megbrain/imperative/ops/autogen.h"

#include "../op_trait.h"

#include "megbrain/opr/indexing.h"
#include "megdnn/oprs/general.h"
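
// Imperative-mode (eager) support for the IndexingOneHot and
// IndexingSetOneHot operators: fallible shape inference, lowering onto
// graph operator nodes, and direct execution through MegDNN kernels.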
namespace mgb {
namespace imperative {
namespace {

namespace indexing_one_hot {
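
// IndexingOneHot gathers one element per output position along `axis`:
//     dst[i_0, ..., i_{axis-1}, 0, i_{axis+1}, ...] =
//             src[..., index[i_0, ..., i_{axis-1}, i_{axis+1}, ...], ...]
// so dst keeps src's shape with the indexed axis collapsed to 1.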
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    auto&& op = def.cast_final_safe<IndexingOneHot>();
    mgb_assert(input_descs.size() == 2, "IndexingOneHot expects two inputs");
    auto comp_node = input_descs[0].comp_node;
    TensorLayout src = input_descs[0].layout, index = input_descs[1].layout;
    mgb_assert(index.dtype == dtype::Int32(), "index dtype must be int32");
    if (!src.ndim) {
        // src shape is not known yet: report an unknown output layout.
        return {{{{{}, src.dtype}, comp_node}}, false};
    }
    mgb_assert(src.ndim >= 2, "src ndim must be at least 2");
    mgb_assert(src.is_contiguous(), "src should be contiguous");
    mgb_assert(
            -static_cast<int>(src.ndim) <= op.axis &&
                    op.axis < static_cast<int>(src.ndim),
            "axis %d does not exist in src", op.axis);
    // Normalize a possibly negative axis into [0, src.ndim).
    int real_axis = static_cast<int>(op.axis);
    if (real_axis < 0) {
        real_axis += static_cast<int>(src.ndim);
    }
    // The output keeps src's shape, with the indexed axis collapsed to 1.
    TensorLayout dst = src;
    dst.shape[real_axis] = 1;
    dst.init_contiguous_stride();
    if (!index.ndim) {
        // index shape is unknown: the layout is deduced but not yet validated.
        return {{{dst, comp_node}}, false};
    }
    mgb_assert(index.is_contiguous(), "index should be contiguous");
    mgb_assert(
            index.eq_shape(src.remove_axis(real_axis)),
            "index shape doesn't match src");
    return {{{dst, comp_node}}, true};
}
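
// Graph-mode lowering: resolve the axis and emit the corresponding graph
// operator node.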
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<IndexingOneHot>();
    mgb_assert(inputs.size() == 2);
    int real_axis = static_cast<int>(op.axis);
    if (real_axis < 0) {
        real_axis += static_cast<int>(op.ndim);
    }
    OperatorNodeConfig config{op.make_name()};
    return opr::IndexingOneHot::make(inputs[0], inputs[1], real_axis, config);
}
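
// Eager-mode execution: deduce the output layout with MegDNN, allocate the
// result tensor and workspace, then run the kernel directly.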
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, SmallVector<TensorPtr> inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<IndexingOneHot>();
    auto&& inp = inputs[0];
    auto&& index = inputs[1];
    TensorLayout layout = inp->layout();
    TensorLayout index_layout = index->layout();
    DnnOprCaller<megdnn::IndexingOneHot> dnn_op(inp->comp_node());
    auto&& indexing_one_hot_param = dnn_op.op->param();
    int real_axis = static_cast<int>(op.axis);
    if (real_axis < 0) {
        real_axis += static_cast<int>(layout.ndim);
    }
    mgb_assert(
            0 <= real_axis && real_axis < static_cast<int>(layout.ndim),
            "Dimension out of range (expected to be in range of [%d, %d], but "
            "got %d)",
            0, static_cast<int>(layout.ndim) - 1, op.axis);
    indexing_one_hot_param = real_axis;
    // Let the MegDNN operator deduce the output layout, then allocate it.
    TensorLayout tlayout;
    dnn_op.op->deduce_layout(layout, index_layout, tlayout);
    TensorPtr out = Tensor::make(tlayout, inp->comp_node());
    megdnn::TensorND in = inp->dnn_tensor();
    megdnn::TensorND ind = index->dnn_tensor();
    // Allocate the scratch space the kernel requests and execute it.
    TensorLayout m_layout(
            {dnn_op.op->get_workspace_in_bytes(layout, index_layout, tlayout)},
            dtype::Byte());
    auto dnn_workspace = dnn_op.create_workspace(m_layout);
    dnn_op.op->exec(in, ind, out->dnn_tensor(), dnn_workspace);
    return {out};
}
OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

}  // namespace indexing_one_hot

namespace indexing_set_one_hot {
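
// IndexingSetOneHot is the writing counterpart: for every position it stores
// sub[..., 0, ...] into the location of src selected by index along `axis`,
// leaving all other elements unchanged.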
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    mgb_assert(input_descs.size() == 3, "IndexingSetOneHot expects three inputs");
    auto comp_node = input_descs[0].comp_node;
    TensorLayout src = input_descs[0].layout, index = input_descs[1].layout;
    mgb_assert(index.dtype == dtype::Int32(), "index dtype must be int32");
    if (!src.ndim) {
        // src shape is not known yet: report an unknown output layout.
        return {{{{{}, src.dtype}, comp_node}}, false};
    }
    mgb_assert(src.is_contiguous(), "src should be contiguous");
    // The output has exactly src's layout: the op writes into a copy of src.
    return {{input_descs[0]}, true};
}
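
// Graph-mode lowering, mirroring indexing_one_hot::apply_on_var_node.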
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<IndexingSetOneHot>();
    mgb_assert(inputs.size() == 3);
    int real_axis = static_cast<int>(op.axis);
    if (real_axis < 0) {
        real_axis += static_cast<int>(op.ndim);
    }
    OperatorNodeConfig config{op.make_name()};
    return opr::IndexingSetOneHot::make(
            inputs[0], inputs[1], inputs[2], real_axis, config);
}
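
// Eager-mode execution: the MegDNN kernel updates its first operand in place,
// so copy src into the freshly allocated output first, then scatter sub.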
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, SmallVector<TensorPtr> inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& op = def.cast_final_safe<IndexingSetOneHot>();
    auto&& inp = inputs[0];
    auto&& index = inputs[1];
    auto&& sub = inputs[2];
    TensorLayout layout = inp->layout();
    TensorLayout index_layout = index->layout();
    TensorLayout tlayout = sub->layout();
    mgb_assert(layout.is_contiguous());
    DnnOprCaller<megdnn::IndexingSetOneHot> dnn_op(inp->comp_node());
    auto&& indexing_one_hot_param = dnn_op.op->param();
    int real_axis = static_cast<int>(op.axis);
    if (real_axis < 0) {
        real_axis += static_cast<int>(layout.ndim);
    }
    indexing_one_hot_param = real_axis;
    // The kernel works in place, so seed the output with a copy of src.
    TensorPtr out = Tensor::make(layout, inp->comp_node());
    out->dev_tensor().copy_from_fixlayout(inp->dev_tensor());
    megdnn::TensorND ind = index->dnn_tensor();
    megdnn::TensorND su = sub->dnn_tensor();
    TensorLayout m_layout(
            {dnn_op.op->get_workspace_in_bytes(layout, index_layout, tlayout)},
            dtype::Byte());
    auto dnn_workspace = dnn_op.create_workspace(m_layout);
    dnn_op.op->exec(out->dnn_tensor(), ind, su, dnn_workspace);
    return {out};
}
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

}  // namespace indexing_set_one_hot

}  // anonymous namespace
}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}