You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

svd.cpp 3.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. /**
  2. * \file dnn/src/common/svd.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/oprs/linalg.h"
  12. #include "src/common/utils.h"
  13. using namespace megdnn;
  14. void SVD::deduce_layout(const TensorLayout& src, TensorLayout& u,
  15. TensorLayout& s, TensorLayout& vt) {
  16. Param p = param();
  17. size_t m, n;
  18. canonize_params(src, nullptr, &m, &n);
  19. SmallVector<size_t> shape_prefix;
  20. for (size_t i = 0; i < src.ndim - 2; i++) {
  21. shape_prefix.push_back(src[i]);
  22. }
  23. SmallVector<size_t> shape_s(shape_prefix), shape_u, shape_vt;
  24. shape_s.push_back(std::min(m, n));
  25. if (p.compute_uv) {
  26. shape_u = shape_prefix;
  27. shape_vt = shape_prefix;
  28. size_t ucols = m;
  29. size_t vrows = n;
  30. if (!p.full_matrices) {
  31. ucols = vrows = std::min(m, n);
  32. }
  33. // let P = min(M, N)
  34. // M x M or M x P
  35. shape_u.push_back(m);
  36. shape_u.push_back(ucols);
  37. // N x N or P x N
  38. shape_vt.push_back(vrows);
  39. shape_vt.push_back(n);
  40. } else {
  41. shape_u = {0};
  42. shape_vt = {0};
  43. }
  44. s = {shape_s, src.dtype};
  45. u = {shape_u, src.dtype};
  46. vt = {shape_vt, src.dtype};
  47. }
  48. size_t SVD::get_workspace_in_bytes(const TensorLayout& src,
  49. const TensorLayout& u, const TensorLayout& s,
  50. const TensorLayout& vt) {
  51. MEGDNN_MARK_USED_VAR(u);
  52. MEGDNN_MARK_USED_VAR(s);
  53. MEGDNN_MARK_USED_VAR(vt);
  54. size_t block_cnt, m, n;
  55. canonize_params(src, &block_cnt, &m, &n);
  56. return get_workspace_in_bytes(block_cnt, m, n, src.dtype.size());
  57. }
  58. void SVD::canonize_params(const TensorLayout& layout, size_t* block_cnt,
  59. size_t* m, size_t* n) {
  60. megdnn_assert(layout.is_contiguous() && layout.ndim >= 2,
  61. "invalid SVD layout: %s", layout.to_string().c_str());
  62. megdnn_assert(layout.dtype == dtype::Float32(), "SVD only supports f32");
  63. if (block_cnt) {
  64. *block_cnt = 1;
  65. for (size_t i = 0; i < layout.ndim - 2; ++i) {
  66. *block_cnt *= layout[i];
  67. }
  68. }
  69. if (n) {
  70. *n = layout[layout.ndim - 1];
  71. }
  72. if (m) {
  73. *m = layout[layout.ndim - 2];
  74. }
  75. }
  76. void SVD::check_exec(const TensorLayout& src, const TensorLayout& u,
  77. const TensorLayout& s, const TensorLayout& vt,
  78. size_t workspace_in_bytes) {
  79. size_t m, n;
  80. canonize_params(src, nullptr, &m, &n);
  81. // get_workspace_in_bytes runs the canonize_params, thus runs the check
  82. auto required_workspace_in_bytes = get_workspace_in_bytes(src, u, s, vt);
  83. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  84. }
  85. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。