You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_remap.cpp 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. /**
  2. * \file dnn/src/common/tensor_remap.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/oprs.h"
  12. #include "src/common/utils.h"
  13. namespace megdnn {
  14. void IndexingRemapBase::check_layout_fwd(const TensorLayout &src,
  15. const TensorLayout &map,
  16. const TensorLayout &dst)
  17. {
  18. megdnn_assert_non_overlapping_strong(src);
  19. megdnn_assert_contiguous(map);
  20. megdnn_assert_non_overlapping_strong(dst);
  21. auto errmsg = megdnn_layout_msg(src) + ", "
  22. + megdnn_layout_msg(map) + ", "
  23. + megdnn_layout_msg(dst);
  24. auto errmsg_c = errmsg.c_str();
  25. MEGDNN_MARK_USED_VAR(errmsg_c);
  26. megdnn_assert(map.ndim == dst.ndim + 1, "%s", errmsg_c);
  27. for (size_t i = 0_z; i < dst.ndim; ++i) {
  28. megdnn_assert(map.shape[i] == dst.shape[i], "%s", errmsg_c);
  29. }
  30. megdnn_assert(map.shape[dst.ndim] == src.ndim, "%s", errmsg_c);
  31. megdnn_assert(dst.dtype == src.dtype);
  32. megdnn_assert(src.dtype == dtype::Float32() || src.dtype == dtype::Int32(),
  33. "indexing remap only support float32/int32, got %s",
  34. src.dtype.name());
  35. megdnn_assert(map.dtype == dtype::Int32());
  36. }
  37. void IndexingRemapForward::deduce_layout(const TensorLayout &src,
  38. const TensorLayout &map,
  39. TensorLayout &dst)
  40. {
  41. dst = map;
  42. dst.dtype = src.dtype;
  43. --dst.ndim;
  44. dst.init_contiguous_stride();
  45. }
  46. void IndexingRemapForward::check_exec(const TensorLayout &src,
  47. const TensorLayout &map,
  48. const TensorLayout &dst,
  49. size_t workspace_in_bytes)
  50. {
  51. check_layout_fwd(src, map, dst);
  52. auto required_workspace_in_bytes = get_workspace_in_bytes(src, map, dst);
  53. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  54. }
  55. void IndexingRemapBackward::check_exec(const TensorLayout &diff,
  56. const TensorLayout &map,
  57. const TensorLayout &grad,
  58. size_t workspace_in_bytes)
  59. {
  60. check_layout_fwd(grad, map, diff);
  61. auto required_workspace_in_bytes = get_workspace_in_bytes(diff, map, grad);
  62. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  63. }
  64. } // namespace megdnn
  65. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。