You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.cpp 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. /**
  2. * \file dnn/src/naive/argsort/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  #include "src/naive/argsort/opr_impl.h"
  #include "src/naive/handle.h"
  #include <algorithm>
  #include <cstring>
  #include <functional>
  #include <utility>
  #include <vector>
  #include "src/common/utils.h"
  15. using namespace megdnn;
  16. namespace {
  17. template <typename KeyType>
  18. void forward_impl(size_t M, size_t N, const KeyType* sptr, KeyType* dptr,
  19. dt_int32* iptr, bool ascending) {
  20. using KV = std::pair<KeyType, int>;
  21. std::vector<KV> row(N);
  22. rep(m, M) {
  23. rep(n, N) {
  24. row[n].first = sptr[m * N + n];
  25. row[n].second = n;
  26. }
  27. if (ascending) {
  28. std::sort(row.begin(), row.end());
  29. } else {
  30. std::sort(row.begin(), row.end(), std::greater<KV>{});
  31. }
  32. rep(n, N) {
  33. dptr[m * N + n] = row[n].first;
  34. iptr[m * N + n] = row[n].second;
  35. }
  36. }
  37. }
  /*!
   * \brief Scatter row-wise values into the columns named by an index matrix.
   *
   * \p dst is addressed as a row-major dst_h x dst_w matrix; \p src_data and
   * \p src_idx are addressed as row-major dst_h x src_w matrices. For every row
   * i and column j:
   *
   *      dst[i * dst_w + src_idx[i * src_w + j]] = src_data[i * src_w + j]
   *
   * When src_w != dst_w the whole of \p dst is zeroed bytewise first. When the
   * widths are equal no fill is done, so any element of \p dst that no index
   * refers to keeps its previous content.
   *
   * Index values are not range-checked here.
   *
   * \param dst_h number of rows in all three matrices
   * \param dst_w number of columns in \p dst
   * \param src_w number of columns in \p src_data and \p src_idx
   * \param dst output matrix
   * \param src_data values to scatter
   * \param src_idx destination column within the row for each value
   */
  template <typename KeyType>
  void backward_impl(size_t dst_h, size_t dst_w, size_t src_w, KeyType* dst,
                     const KeyType* src_data, const int* src_idx) {
      if (src_w != dst_w) {
          // <cstring> guarantees the std:: name; the global one is optional.
          std::memset(dst, 0, sizeof(KeyType) * dst_h * dst_w);
      }
      for (size_t i = 0; i < dst_h; ++i) {
          const size_t src_row = i * src_w;
          const size_t dst_row = i * dst_w;
          for (size_t j = 0; j < src_w; ++j) {
              dst[dst_row + src_idx[src_row + j]] = src_data[src_row + j];
          }
      }
  }
  50. } // anonymous namespace
  51. namespace megdnn {
  52. namespace naive {
  53. void ArgsortForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  54. _megdnn_tensor_out indices,
  55. _megdnn_workspace workspace) {
  56. check_exec(src.layout, dst.layout, indices.layout, workspace.size);
  57. auto M = src.layout.shape[0], N = src.layout.shape[1];
  58. auto iptr = indices.ptr<dt_int32>();
  59. switch (src.layout.dtype.enumv()) {
  60. #define cb(dt) \
  61. case DTypeTrait<dt>::enumv: { \
  62. using ctype = DTypeTrait<dt>::ctype; \
  63. auto sptr = src.ptr<ctype>(); \
  64. auto dptr = dst.ptr<ctype>(); \
  65. MEGDNN_DISPATCH_CPU_KERN_OPR(forward_impl( \
  66. M, N, sptr, dptr, iptr, param().order == Order::ASCENDING)); \
  67. return; \
  68. }
  69. MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
  70. #undef cb
  71. default:
  72. megdnn_throw("bad dtype");
  73. }
  74. }
  75. void ArgsortBackwardImpl::exec(_megdnn_tensor_in diff,
  76. _megdnn_tensor_in indices,
  77. _megdnn_tensor_out grad,
  78. _megdnn_workspace workspace) {
  79. check_exec(diff.layout, indices.layout, grad.layout, workspace.size);
  80. size_t M = grad.layout.shape[0], N = grad.layout.shape[1],
  81. SRC_W = indices.layout[1];
  82. auto iptr = indices.ptr<dt_int32>();
  83. switch (diff.layout.dtype.enumv()) {
  84. #define cb(dt) \
  85. case DTypeTrait<dt>::enumv: { \
  86. using ctype = DTypeTrait<dt>::ctype; \
  87. auto hptr = diff.ptr<ctype>(); \
  88. auto gptr = grad.ptr<ctype>(); \
  89. MEGDNN_DISPATCH_CPU_KERN_OPR( \
  90. backward_impl(M, N, SRC_W, gptr, hptr, iptr)); \
  91. return; \
  92. }
  93. MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
  94. #undef cb
  95. default:
  96. megdnn_throw("bad dtype");
  97. }
  98. }
  99. } // namespace naive
  100. } // namespace megdnn
  101. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台