You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

argmxx_helper.h 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. /**
  2. * \file dnn/src/common/argmxx_helper.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/dtype.h"
  13. #if MEGDNN_CC_HOST
  14. #include "megdnn/basic_types.h"
  15. #endif
  16. namespace megdnn {
  17. namespace argmxx {
  18. template <typename stype_, bool is_max>
  19. struct ArgmxxOp {
  20. struct wtype {
  21. stype_ key;
  22. dt_int32 val;
  23. MEGDNN_HOST MEGDNN_DEVICE wtype()
  24. {}
  25. MEGDNN_HOST MEGDNN_DEVICE wtype(stype_ key, dt_int32 val):
  26. key(key), val(val)
  27. {}
  28. MEGDNN_HOST MEGDNN_DEVICE wtype(wtype &rhs):
  29. key(rhs.key),
  30. val(rhs.val)
  31. {}
  32. MEGDNN_HOST MEGDNN_DEVICE wtype(volatile wtype &rhs):
  33. key(rhs.key),
  34. val(rhs.val)
  35. {}
  36. MEGDNN_HOST MEGDNN_DEVICE wtype(const wtype &rhs):
  37. key(rhs.key),
  38. val(rhs.val)
  39. {}
  40. MEGDNN_HOST MEGDNN_DEVICE wtype(const volatile wtype &rhs):
  41. key(rhs.key),
  42. val(rhs.val)
  43. {}
  44. MEGDNN_HOST MEGDNN_DEVICE volatile wtype &operator=(const wtype &rhs) volatile
  45. {
  46. this->key = rhs.key;
  47. this->val = rhs.val;
  48. return *this;
  49. }
  50. };
  51. MEGDNN_HOST MEGDNN_DEVICE
  52. ArgmxxOp(stype_ *src, dt_int32 *dst, uint32_t A, uint32_t B, uint32_t C):
  53. src(src), dst(dst), A(A), B(B), C(C),
  54. INIT(wtype(is_max ? DTypeTrait<stype_>::min() :
  55. DTypeTrait<stype_>::max(), -1))
  56. {
  57. }
  58. MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx)
  59. {
  60. wtype res;
  61. res.key = src[idx];
  62. res.val = idx / C % B;
  63. return res;
  64. }
  65. MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val)
  66. {
  67. dst[idx] = val.val;
  68. }
  69. static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs)
  70. {
  71. if (is_max) {
  72. if (lhs.key > rhs.key) return lhs; else return rhs;
  73. } else {
  74. if (lhs.key < rhs.key) return lhs; else return rhs;
  75. }
  76. }
  77. stype_ *src;
  78. dt_int32 *dst;
  79. uint32_t A, B, C;
  80. const wtype INIT;
  81. };
  82. } // namespace argmxx
  83. } // namespace megdnn
  84. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。

Contributors (1)