You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

indexing_multi_axis_vec.cpp 7.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. /**
  2. * \file dnn/src/common/indexing_multi_axis_vec.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/oprs.h"
  12. #include "src/common/utils.h"
  13. using namespace megdnn;
  14. namespace {
  15. size_t get_index_size_for_workspace(
  16. const TensorShape &shp, const size_t *axes, size_t nr_axes) {
  17. size_t idx_axis = axes[0];
  18. megdnn_assert(shp.ndim && nr_axes);
  19. for (size_t i = 1; i < nr_axes; ++ i) {
  20. megdnn_assert(axes[i] > axes[i - 1]);
  21. if (axes[i] != axes[i - 1] + 1) {
  22. idx_axis = 0;
  23. break;
  24. }
  25. }
  26. megdnn_assert(shp.ndim > idx_axis,
  27. "index on the %zuth axis; but shape is %s",
  28. idx_axis, shp.to_string().c_str());
  29. return shp.shape[idx_axis];
  30. }
  31. } // anonymous namespace
  32. IndexingMultiAxisVecBase::IndexDescLayoutOnly
  33. IndexingMultiAxisVecBase::extract_index_layout(const IndexDesc &index) {
  34. IndexDescLayoutOnly ret(index.size());
  35. for (size_t i = 0; i < index.size(); ++ i) {
  36. ret[i].layout = index[i].vec.layout;
  37. ret[i].axis = index[i].axis;
  38. }
  39. return ret;
  40. }
  41. size_t IndexingMultiAxisVecBase::deduce_layout_fwd(
  42. const TensorLayout &data,
  43. const IndexDescLayoutOnly &index,
  44. TensorLayout &dst) {
  45. megdnn_assert(!index.empty());
  46. megdnn_assert(data.ndim >= index.size());
  47. dst.ndim = data.ndim - index.size() + 1;
  48. dst.shape[0] = 1;
  49. dst.dtype = data.dtype;
  50. auto brdcast = [&](const TensorLayout &ly) {
  51. if (ly.ndim != 1)
  52. return false;
  53. if (dst.shape[0] == ly.shape[0])
  54. return true;
  55. if (dst.shape[0] == 1) {
  56. dst.shape[0] = ly.shape[0];
  57. return true;
  58. }
  59. return ly.shape[0] == 1;
  60. };
  61. size_t dst_axis = 1;
  62. ptrdiff_t prev_axis = -1;
  63. for (size_t axis = 0; axis < index.size(); ++ axis) {
  64. auto &&idx = index[axis];
  65. megdnn_assert(idx.layout.dtype == dtype::Int32(),
  66. "invalid index dtype: %s", idx.layout.dtype.name());
  67. megdnn_assert(idx.axis < data.ndim &&
  68. static_cast<ptrdiff_t>(idx.axis) > prev_axis,
  69. "index %zu requests invalid axis %zu", axis, idx.axis);
  70. auto brd_succ = brdcast(idx.layout);
  71. megdnn_assert(brd_succ, "invalid layout at index %zu: %s",
  72. axis, idx.layout.to_string().c_str());
  73. for (size_t i = prev_axis + 1; i < idx.axis; ++ i) {
  74. dst.shape[dst_axis ++] = data.shape[i];
  75. }
  76. prev_axis = idx.axis;
  77. }
  78. for (size_t i = prev_axis + 1; i < data.ndim; ++ i) {
  79. dst.shape[dst_axis ++] = data.shape[i];
  80. }
  81. megdnn_assert(dst_axis == dst.ndim);
  82. size_t idx_axis = 0;
  83. {
  84. // fix idx_axis if index contains consecutive axes
  85. bool contig_idx = true;
  86. for (size_t i = 1; i < index.size(); ++ i) {
  87. if (index[i].axis != index[i - 1].axis + 1) {
  88. contig_idx = false;
  89. break;
  90. }
  91. }
  92. if (contig_idx) {
  93. auto shp0 = dst.shape[0];
  94. idx_axis = index[0].axis;
  95. for (size_t i = 0; i < idx_axis; ++ i) {
  96. dst.shape[i] = dst.shape[i + 1];
  97. }
  98. dst.shape[idx_axis] = shp0;
  99. }
  100. }
  101. dst.init_contiguous_stride();
  102. return idx_axis;
  103. }
  104. size_t IndexingMultiAxisVecBase::get_nonindex_axes(
  105. size_t src_ndim, const IndexDesc &index, size_t *out) {
  106. auto iter = index.begin();
  107. size_t nr = 0;
  108. for (size_t i = 0; i < src_ndim; ++ i) {
  109. if (iter != index.end() && i == iter->axis) {
  110. ++ iter;
  111. } else {
  112. out[nr ++] = i;
  113. }
  114. }
  115. megdnn_assert(nr + index.size() == src_ndim && iter == index.end());
  116. return nr;
  117. }
  118. IndexingMultiAxisVecBase::ExecInfo
  119. IndexingMultiAxisVecBase::check_exec_noworkspace(
  120. const TensorLayout &data, const TensorLayout &value,
  121. const IndexDesc &index, IndexDescLayoutOnly &index_layout) {
  122. ExecInfo ret;
  123. index_layout = extract_index_layout(index);
  124. TensorLayout value_expect;
  125. ret.idx_axis = deduce_layout_fwd(data, index_layout, value_expect);
  126. megdnn_assert_eq_shape(value_expect, value);
  127. auto value_contig = value.collapse_contiguous();
  128. megdnn_assert(value_contig.ndim == 1,
  129. "value layout must be 1-dim contiguous; got %s",
  130. value.to_string().c_str());
  131. ret.value_stride = value_contig.stride[0];
  132. return ret;
  133. }
  134. std::pair<TensorLayout, size_t>
  135. IndexingMultiAxisVecBase::get_value_iter_optimized_layout(
  136. const TensorLayout &data, const TensorLayout &value,
  137. const IndexDesc &index, size_t idx_axis) {
  138. size_t data_axes[TensorLayout::MAX_NDIM],
  139. nr_axes = get_nonindex_axes(data.ndim, index, data_axes);
  140. megdnn_assert(nr_axes == value.ndim - 1 && idx_axis < value.ndim &&
  141. nr_axes + index.size() == data.ndim);
  142. TensorLayout ret;
  143. if (idx_axis) {
  144. ret.ndim = idx_axis;
  145. for (size_t i = 0; i < idx_axis; ++ i) {
  146. ret.shape[i] = data.shape[data_axes[i]];
  147. ret.stride[i] = data.stride[data_axes[i]];
  148. }
  149. ret = ret.collapse_contiguous();
  150. }
  151. ret.shape[ret.ndim] = value.shape[idx_axis];
  152. ret.stride[ret.ndim] = 0;
  153. size_t ret_idx_axis = ret.ndim;
  154. ++ ret.ndim;
  155. if (idx_axis < nr_axes) {
  156. TensorLayout tail;
  157. tail.ndim = nr_axes - idx_axis;
  158. for (size_t i = idx_axis; i < nr_axes; ++ i) {
  159. tail.shape[i - idx_axis] = data.shape[data_axes[i]];
  160. tail.stride[i - idx_axis] = data.stride[data_axes[i]];
  161. }
  162. tail = tail.collapse_contiguous();
  163. for (size_t i = 0; i < tail.ndim; ++ i) {
  164. ret.shape[ret.ndim] = tail.shape[i];
  165. ret.stride[ret.ndim] = tail.stride[i];
  166. ++ ret.ndim;
  167. }
  168. }
  169. return {ret, ret_idx_axis};
  170. }
  171. size_t IndexingMultiAxisVec::get_workspace_in_bytes(
  172. const TensorShape &dst, const size_t *axes, size_t nr_axes) {
  173. return get_workspace_in_bytes(
  174. get_index_size_for_workspace(dst, axes, nr_axes));
  175. }
  176. IndexingMultiAxisVec::ExecInfo IndexingMultiAxisVec::check_exec(
  177. const TensorLayout &src, const IndexDesc &index,
  178. const TensorLayout &dst, size_t workspace_in_bytes) {
  179. IndexDescLayoutOnly index_layout;
  180. auto ret = check_exec_noworkspace(src, dst, index, index_layout);
  181. megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(
  182. dst.shape[ret.idx_axis]));
  183. megdnn_assert(ret.value_stride, "dst must be non-overlapping");
  184. return ret;
  185. }
  186. size_t IndexingModifyMultiAxisVecBase::get_workspace_in_bytes(
  187. const TensorShape &value, const size_t *axes, size_t nr_axes) {
  188. return get_workspace_in_bytes(
  189. get_index_size_for_workspace(value, axes, nr_axes));
  190. }
  191. IndexingModifyMultiAxisVecBase::ExecInfo
  192. IndexingModifyMultiAxisVecBase::check_exec(
  193. const TensorLayout &data, const TensorLayout &value,
  194. const IndexDesc &index, size_t workspace_in_bytes) {
  195. megdnn_assert(data.is_non_overlapping_strong(),
  196. "data layout should not overlap: %s", data.to_string().c_str());
  197. IndexDescLayoutOnly index_layout;
  198. auto ret = check_exec_noworkspace(data, value, index, index_layout);
  199. megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(
  200. value.shape[ret.idx_axis]));
  201. return ret;
  202. }
  203. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台