You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

split.cpp 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. /**
  2. * \file dnn/src/naive/split/split.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/naive/split/opr_impl.h"
  12. #include "src/common/utils.h"
  13. #include "src/naive/handle.h"
  14. #include <numeric>
  15. namespace megdnn {
  16. namespace naive {
  17. template <typename T>
  18. void SplitForwardImpl::exec_internal(_megdnn_tensor_in src,
  19. const TensorNDArray &dsts,
  20. _megdnn_workspace workspace)
  21. {
  22. size_t A, B, C;
  23. size_t *Bv = reinterpret_cast<size_t *>(workspace.raw_ptr);
  24. auto dsts_layout = apply_vector<TensorLayout>(m_get_layout, dsts);
  25. check_exec(src.layout, dsts_layout, workspace.size);
  26. auto dsts_shape = apply_vector<TensorShape>(m_get_shape, dsts_layout);
  27. get_ABC(dsts_shape, A, Bv, C);
  28. B = std::accumulate(Bv, Bv + dsts.size(), 0u);
  29. auto sptr = src.ptr<T>();
  30. rep(a, A) {
  31. // dst b index
  32. size_t dbi = 0u;
  33. // dst b offset
  34. size_t dbo = 0u;
  35. rep(sb, B) {
  36. auto dptr = dsts[dbi].ptr<T>();
  37. rep(c, C) {
  38. auto sidx = a*B*C + sb*C + c;
  39. auto didx = a*Bv[dbi]*C + dbo*C + c;
  40. dptr[didx] = sptr[sidx];
  41. }
  42. ++dbo;
  43. if (dbo >= Bv[dbi]) {
  44. dbo = 0u;
  45. ++dbi;
  46. }
  47. }
  48. }
  49. }
  50. void SplitForwardImpl::exec(_megdnn_tensor_in src,
  51. const TensorNDArray &dsts,
  52. _megdnn_workspace workspace)
  53. {
  54. #define cb(DType) \
  55. if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
  56. using ctype = typename DTypeTrait<DType>::ctype; \
  57. MEGDNN_DISPATCH_CPU_KERN_OPR( \
  58. exec_internal<ctype>(src, dsts, workspace)); \
  59. }
  60. MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
  61. MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
  62. #undef cb
  63. }
  64. } // namespace naive
  65. } // namespace megdnn
  66. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。