You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tile_repeat_helper.cpp 2.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. #include "src/common/tile_repeat_helper.h"
  2. #include <cstring>
  3. #include "src/common/utils.h"
  4. namespace megdnn {
  5. // Tile (m, n) to (m, n*times) or Repeat (m, n) to (m*times, n)
  6. template <typename T>
  7. void tile_or_repeat_single_axis(
  8. const T* __restrict src, T* __restrict dst, const size_t m, const size_t n,
  9. const size_t times) {
  10. rep(i, m) {
  11. // copy Ts of length n to dst
  12. std::memcpy(dst, src, sizeof(T) * n);
  13. size_t k = 1u;
  14. while (k * 2 <= times) {
  15. std::memcpy(dst + k * n, dst, sizeof(T) * (k * n));
  16. k *= 2;
  17. }
  18. if (k < times) {
  19. std::memcpy(dst + k * n, dst, sizeof(T) * (times - k) * n);
  20. }
  21. src += n;
  22. dst += n * times;
  23. }
  24. }
  25. template <typename T>
  26. void init_tile_repeat_state(
  27. const T* src, T* dst, T* workspace0, T* /* workspace1 */, T*& current, T*& next,
  28. size_t& state, size_t nr_reduces) {
  29. current = const_cast<T*>(src);
  30. if (nr_reduces == 1) {
  31. next = dst;
  32. } else {
  33. next = workspace0;
  34. }
  35. state = 0;
  36. }
  37. template <typename T>
  38. void update_tile_repeat_state(
  39. const T* /* src */, T* dst, T* workspace0, T* workspace1, T*& current, T*& next,
  40. size_t& state, size_t nr_reduces) {
  41. current = next;
  42. if (nr_reduces == 1) {
  43. next = nullptr;
  44. } else if (nr_reduces == 2) {
  45. if (state == 0) {
  46. next = dst;
  47. } else {
  48. next = nullptr;
  49. }
  50. } else {
  51. if (state == 0) {
  52. next = workspace1;
  53. } else if (state + 1 == nr_reduces) {
  54. next = nullptr;
  55. } else if (state + 2 == nr_reduces) {
  56. next = dst;
  57. } else {
  58. megdnn_assert(
  59. current == workspace0 || current == workspace1,
  60. "Impossible happened; internal bug.");
  61. next = (current == workspace0 ? workspace1 : workspace0);
  62. }
  63. }
  64. ++state;
  65. }
  66. #define INST(T) \
  67. template void tile_or_repeat_single_axis<T>( \
  68. const T*, T*, const size_t, const size_t, const size_t); \
  69. template void init_tile_repeat_state<T>( \
  70. const T*, T*, T*, T*, T*&, T*&, size_t&, size_t); \
  71. template void update_tile_repeat_state<T>( \
  72. const T*, T*, T*, T*, T*&, T*&, size_t&, size_t);
  73. #define INST_DT(d) INST(DTypeTrait<d>::ctype)
  74. MEGDNN_FOREACH_COMPUTING_DTYPE(INST_DT)
  75. } // namespace megdnn
  76. // vim: syntax=cpp.doxygen