You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_iter.cpp 3.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. /**
  2. * \file dnn/src/common/tensor_iter.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/tensor_iter.h"
  12. #include "src/common/utils.h"
  13. using namespace megdnn;
  14. ////////////////////////// TypeRef ////////////////////
  15. TypeRef<dt_quint4>::TypeRef(dt_quint4* _ptr, size_t _offset) {
  16. ptr = reinterpret_cast<uint8_t*>(_ptr);
  17. offset = _offset;
  18. uint8_t cur = ptr[offset >> 1];
  19. val = convert<uint8_t, dt_quint4>(cur, dt_quint4(cur), offset & 0x1)
  20. .as_uint8();
  21. }
  22. void TypeRef<dt_quint4>::operator=(const uint8_t _) {
  23. uint8_t cur = ptr[offset >> 1];
  24. ptr[offset >> 1] =
  25. convert<dt_quint4, uint8_t>(dt_quint4(_), cur, offset & 0x1);
  26. }
  27. TypeRef<dt_qint4>::TypeRef(dt_qint4* _ptr, size_t _offset) {
  28. ptr = reinterpret_cast<int8_t*>(_ptr);
  29. offset = _offset;
  30. int8_t cur = ptr[offset >> 1];
  31. val = convert<int8_t, dt_qint4>(cur, dt_qint4(cur), offset & 0x1).as_int8();
  32. }
  33. void TypeRef<dt_qint4>::operator=(const int8_t _) {
  34. int8_t cur = ptr[offset >> 1];
  35. ptr[offset >> 1] =
  36. convert<dt_qint4, int8_t>(dt_qint4(_), cur, offset & 0x1);
  37. }
  38. ////////////////////// TensorIter /////////////////////
  39. template<typename ctype, bool valonly>
  40. typename TensorIter<ctype, valonly>::Iter
  41. TensorIter<ctype, valonly>::Iter::make(
  42. ctype *ptr, const TensorLayout &layout, size_t offset) {
  43. megdnn_assert(layout.ndim);
  44. Iter rst;
  45. rst.m_ptr = ptr;
  46. if (valonly)
  47. rst.m_layout = layout.collapse_contiguous();
  48. else
  49. rst.m_layout = layout;
  50. rst.m_logical_offset = offset;
  51. rst.m_tot_nr_elems = rst.m_layout.total_nr_elems();
  52. rst.m_offset = 0;
  53. megdnn_assert(offset <= rst.m_tot_nr_elems);
  54. for (int i = rst.m_layout.ndim - 1; i >= 0; i --) {
  55. auto shp = rst.m_layout.shape[i];
  56. auto stride = rst.m_layout.stride[i];
  57. if (!shp) {
  58. // empty iter for empty layout
  59. return {};
  60. }
  61. rst.m_axis_reset_stride[i] = stride * (shp - 1);
  62. rst.m_axis_offset[i] = offset % shp;
  63. rst.m_offset += rst.m_axis_offset[i] * stride;
  64. offset /= shp;
  65. }
  66. return rst;
  67. }
  68. template<typename ctype, bool valonly>
  69. void TensorIter<ctype, valonly>::Iter::on_access_idx_valonly_true() const {
  70. megdnn_throw("can not access idx of TensorIter if valonly is true");
  71. }
  72. namespace megdnn {
  73. #define cb(_dt) \
  74. template class TensorIter<DTypeTrait<dtype::_dt>::ctype, false>; \
  75. template class TensorIter<DTypeTrait<dtype::_dt>::ctype, true>;
  76. MEGDNN_FOREACH_DTYPE_NAME(cb)
  77. MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb)
  78. #undef cb
  79. }
  80. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。