/**
 * \file dnn/src/common/svd.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megdnn/oprs/linalg.h"
#include "src/common/utils.h"

using namespace megdnn;

/**
 * \brief Deduce the layouts of the SVD outputs from the input layout.
 *
 * With M = src[ndim - 2], N = src[ndim - 1], P = min(M, N) and "prefix"
 * denoting all leading dimensions of \p src:
 *   - s  has shape prefix + [P];
 *   - if param().compute_uv is set:
 *       u  has shape prefix + [M, M] (full_matrices) or prefix + [M, P];
 *       vt has shape prefix + [N, N] (full_matrices) or prefix + [P, N];
 *   - otherwise u and vt both get the one-element shape {0}.
 * All outputs use the dtype of \p src.
 *
 * \param[in]  src  input layout; validated by canonize_params()
 * \param[out] u    deduced layout of U
 * \param[out] s    deduced layout of the singular values
 * \param[out] vt   deduced layout of V^T
 */
void SVD::deduce_layout(const TensorLayout& src, TensorLayout& u,
                        TensorLayout& s, TensorLayout& vt) {
    Param p = param();
    size_t m, n;
    // Also asserts src.ndim >= 2, so `src.ndim - 2` below cannot wrap around.
    canonize_params(src, nullptr, &m, &n);
    const size_t min_mn = std::min(m, n);

    // Leading (batch) dimensions shared by every output.
    SmallVector<size_t> shape_prefix;
    for (size_t i = 0; i < src.ndim - 2; i++) {
        shape_prefix.push_back(src[i]);
    }

    SmallVector<size_t> shape_s(shape_prefix), shape_u, shape_vt;
    shape_s.push_back(min_mn);

    if (p.compute_uv) {
        shape_u = shape_prefix;
        shape_vt = shape_prefix;

        size_t ucols = m;
        size_t vrows = n;
        if (!p.full_matrices) {
            ucols = vrows = min_mn;
        }
        // let P = min(M, N)
        // M x M or M x P
        shape_u.push_back(m);
        shape_u.push_back(ucols);
        // N x N or P x N
        shape_vt.push_back(vrows);
        shape_vt.push_back(n);
    } else {
        // U and V^T are not computed: use the placeholder shape {0}.
        shape_u = {0};
        shape_vt = {0};
    }

    s = {shape_s, src.dtype};
    u = {shape_u, src.dtype};
    vt = {shape_vt, src.dtype};
}

/**
 * \brief Workspace size needed to run SVD on \p src.
 *
 * Only \p src is inspected; it is reduced to (block_cnt, m, n) by
 * canonize_params() and forwarded, together with the element size of its
 * dtype, to the (block_cnt, m, n, dtype_size) overload.
 *
 * \param src  input layout; validated by canonize_params()
 * \param u    unused
 * \param s    unused
 * \param vt   unused
 * \return whatever the (block_cnt, m, n, dtype_size) overload returns
 */
size_t SVD::get_workspace_in_bytes(const TensorLayout& src,
                                   const TensorLayout& u,
                                   const TensorLayout& s,
                                   const TensorLayout& vt) {
    MEGDNN_MARK_USED_VAR(u);
    MEGDNN_MARK_USED_VAR(s);
    MEGDNN_MARK_USED_VAR(vt);

    size_t block_cnt, m, n;
    canonize_params(src, &block_cnt, &m, &n);
    return get_workspace_in_bytes(block_cnt, m, n, src.dtype.size());
}

/**
 * \brief Validate an SVD input layout and split it into batch count and
 *      matrix extents.
 *
 * Asserts that \p layout is contiguous, has at least two dimensions and is
 * Float32. Every output pointer is optional; a null pointer is skipped.
 *
 * \param[in]  layout     input layout to validate
 * \param[out] block_cnt  product of all dimensions except the last two
 *                        (1 when ndim == 2); may be nullptr
 * \param[out] m          second-to-last dimension; may be nullptr
 * \param[out] n          last dimension; may be nullptr
 */
void SVD::canonize_params(const TensorLayout& layout, size_t* block_cnt,
                          size_t* m, size_t* n) {
    megdnn_assert(layout.is_contiguous() && layout.ndim >= 2,
                  "invalid SVD layout: %s", layout.to_string().c_str());
    megdnn_assert(layout.dtype == dtype::Float32(), "SVD only supports f32");

    if (block_cnt) {
        *block_cnt = 1;
        for (size_t i = 0; i < layout.ndim - 2; ++i) {
            *block_cnt *= layout[i];
        }
    }
    if (n) {
        *n = layout[layout.ndim - 1];
    }
    if (m) {
        *m = layout[layout.ndim - 2];
    }
}

/**
 * \brief Check the arguments of an SVD execution.
 *
 * Validates \p src through canonize_params() and asserts that the supplied
 * workspace is at least as large as get_workspace_in_bytes() requires.
 * The layouts \p u, \p s and \p vt are only forwarded to
 * get_workspace_in_bytes(); they are not checked here.
 *
 * \param src                 input layout
 * \param u                   layout of U
 * \param s                   layout of the singular values
 * \param vt                  layout of V^T
 * \param workspace_in_bytes  size of the workspace provided by the caller
 */
void SVD::check_exec(const TensorLayout& src, const TensorLayout& u,
                     const TensorLayout& s, const TensorLayout& vt,
                     size_t workspace_in_bytes) {
    // Validation only: every output pointer is optional, so none is requested.
    canonize_params(src, nullptr, nullptr, nullptr);
    // get_workspace_in_bytes runs the canonize_params, thus runs the check
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, u, s, vt);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

// vim: syntax=cpp.doxygen