You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.cpp 7.0 kB

[Line-number gutter of the web viewer (lines 1–212); the numbered source listing follows.]
  1. /**
  2. * \file dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./opr_impl.h"
  12. #include "./kern.cuh"
  13. #include "src/cuda/utils.h"
  14. #include "src/common/indexing_multi_axis_vec_kdef.h"
  15. using namespace megdnn;
  16. using namespace cuda;
  17. using namespace indexing_multi_axis_vec;
  18. namespace {
  19. class ExecImplHelper {
  20. template<int nidx>
  21. void dispatch_gen_offset_base_nidx();
  22. void dispatch_gen_offset_base();
  23. protected:
  24. using IndexDesc = IndexingMultiAxisVec::IndexDesc;
  25. using ExecInfo = IndexingMultiAxisVec::ExecInfo;
  26. cudaStream_t m_stream;
  27. const TensorND * const m_data;
  28. const TensorND * const m_value;
  29. const IndexDesc * const m_index;
  30. const ExecInfo* const m_exec_info;
  31. int * const m_offset_base;
  32. TensorLayout m_value_layout_on_data;
  33. size_t m_idx_axis;
  34. int m_value_stride;
  35. public:
  36. ExecImplHelper(const TensorND &data, const TensorND &value,
  37. const IndexDesc &index, const Workspace &workspace,
  38. const ExecInfo &exec_info, cudaStream_t stream);
  39. };
  40. template<class Opr>
  41. class ExecImpl : public ExecImplHelper {
  42. void dispatch_exec();
  43. template<typename ctype>
  44. void dispatch_exec_ctype();
  45. template<typename ctype, int ndim>
  46. void dispatch_exec_ctype_ndim();
  47. public:
  48. using ExecImplHelper::ExecImplHelper;
  49. void operator() () {
  50. dispatch_exec();
  51. after_kernel_launch();
  52. }
  53. };
  54. } // anonymous namespace
  55. ExecImplHelper::ExecImplHelper(const TensorND &data, const TensorND &value,
  56. const IndexDesc &index, const Workspace &workspace,
  57. const ExecInfo &exec_info, cudaStream_t stream):
  58. m_stream{stream}, m_data{&data}, m_value{&value}, m_index{&index},
  59. m_exec_info{&exec_info}, m_offset_base{workspace.ptr<int>()}
  60. {
  61. safe_size_in_kern(data.layout.total_nr_elems());
  62. dispatch_gen_offset_base();
  63. std::tie(m_value_layout_on_data, m_idx_axis) =
  64. IndexingMultiAxisVec::get_value_iter_optimized_layout(
  65. data.layout, value.layout, index, exec_info.idx_axis);
  66. m_value_stride = exec_info.value_stride;
  67. }
  68. template<int nidx>
  69. void ExecImplHelper::dispatch_gen_offset_base_nidx() {
  70. GenOffsetBaseParam<nidx> param;
  71. param.size = m_value->layout.shape[m_exec_info->idx_axis];
  72. param.output = m_offset_base;
  73. param.error_tracker = m_exec_info->error_tracker;
  74. param.error_info = m_exec_info->error_info;
  75. for (int i = 0; i < nidx; ++ i) {
  76. auto &&dst = param.indexer[i];
  77. auto &&src = m_index->operator[](i);
  78. megdnn_assert(src.vec.layout.ndim == 1);
  79. dst.stride = src.vec.layout.stride[0];
  80. if (src.vec.layout.shape[0] == 1) {
  81. dst.stride = 0;
  82. }
  83. dst.ptr = src.vec.ptr<int>();
  84. param.data_shape[i] = m_data->layout.shape[src.axis];
  85. param.data_stride[i] = m_data->layout.stride[src.axis];
  86. }
  87. gen_offset_base(param, m_stream);
  88. }
  89. void ExecImplHelper::dispatch_gen_offset_base() {
  90. switch(m_index->size()) {
  91. #define cb(_n) case _n: return dispatch_gen_offset_base_nidx<_n>();
  92. MEGDNN_FOREACH_TENSOR_NDIM(cb)
  93. #undef cb
  94. }
  95. megdnn_throw("bad index size");
  96. }
  97. template<class Opr>
  98. void ExecImpl<Opr>::dispatch_exec() {
  99. switch (m_data->layout.dtype.enumv()) {
  100. #define cb(_dtype) \
  101. case DTypeTrait<_dtype>::enumv: \
  102. return dispatch_exec_ctype<DTypeTrait<_dtype>::ctype>();
  103. MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
  104. #undef cb
  105. default:
  106. megdnn_throw("bad dtype");
  107. }
  108. }
  109. template<class Opr>
  110. template<typename ctype>
  111. void ExecImpl<Opr>::dispatch_exec_ctype() {
  112. switch (m_value_layout_on_data.ndim) {
  113. #define cb(_n) \
  114. case _n: return dispatch_exec_ctype_ndim<ctype, _n>();
  115. MEGDNN_FOREACH_TENSOR_NDIM(cb)
  116. #undef cb
  117. default:
  118. megdnn_throw("bad data ndim");
  119. }
  120. }
  121. template<class Opr>
  122. template<typename ctype, int ndim>
  123. void ExecImpl<Opr>::dispatch_exec_ctype_ndim() {
  124. ApplyOprParam<ctype, ndim> param;
  125. param.tot_size = safe_size_in_kern(m_value->layout.total_nr_elems());
  126. param.offset_base = m_offset_base;
  127. param.data = m_data->ptr<ctype>();
  128. param.value = m_value->ptr<ctype>();
  129. param.idx_axis = m_idx_axis;
  130. param.value_stride = m_value_stride;
  131. for (int i = 0; i < ndim; ++ i) {
  132. param.value_ly_on_data.stride[i] = m_value_layout_on_data.stride[i];
  133. if (i) {
  134. param.value_ly_on_data.shape[i - 1] =
  135. m_value_layout_on_data.shape[i];
  136. }
  137. }
  138. apply_opr<ctype, ndim, Opr>(param, m_stream);
  139. }
  140. size_t IndexingMultiAxisVecImpl::get_workspace_in_bytes(size_t dst_idx_size) {
  141. return dst_idx_size * sizeof(int);
  142. }
  143. void IndexingMultiAxisVecImpl::exec(
  144. _megdnn_tensor_in src, const IndexDesc &index,
  145. _megdnn_tensor_out dst,
  146. _megdnn_workspace workspace) {
  147. auto info = check_exec(src.layout, index, dst.layout, workspace.size);
  148. info.error_tracker = m_error_tracker;
  149. info.error_info = async_error_info(handle());
  150. ExecImpl<indexing_multi_axis_vec_kdef::OprFwd>{
  151. src, dst, index, workspace, info, cuda_stream(handle())}();
  152. }
  153. size_t IndexingSetMultiAxisVecImpl::get_workspace_in_bytes(
  154. size_t value_idx_size) {
  155. return value_idx_size * sizeof(int);
  156. }
  157. void IndexingSetMultiAxisVecImpl::exec(
  158. _megdnn_tensor_inout data, _megdnn_tensor_in value,
  159. const IndexDesc &index, _megdnn_workspace workspace) {
  160. auto info = check_exec(data.layout, value.layout, index, workspace.size);
  161. info.error_tracker = m_error_tracker;
  162. info.error_info = async_error_info(handle());
  163. ExecImpl<indexing_multi_axis_vec_kdef::OprSet>{
  164. data, value, index, workspace, info, cuda_stream(handle())}();
  165. }
  166. size_t IndexingIncrMultiAxisVecImpl::get_workspace_in_bytes(
  167. size_t value_idx_size) {
  168. return value_idx_size * sizeof(int);
  169. }
  170. void IndexingIncrMultiAxisVecImpl::exec(
  171. _megdnn_tensor_inout data, _megdnn_tensor_in value,
  172. const IndexDesc &index, _megdnn_workspace workspace) {
  173. MEGDNN_INC_FLOAT16(
  174. megdnn_assert(data.layout.dtype != dtype::Float16(),
  175. "float16 incr on cuda currently not supported"));
  176. auto info = check_exec(data.layout, value.layout, index, workspace.size);
  177. info.error_tracker = m_error_tracker;
  178. info.error_info = async_error_info(handle());
  179. ExecImpl<OprAtomicIncr>{data, value, index, workspace, info,
  180. cuda_stream(handle())}();
  181. }
  182. // vim: syntax=cpp.doxygen

MegEngine 安装包中已集成使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。