/**
 * \file dnn/src/common/mesh_indexing.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "megdnn/oprs/general.h"

#include "src/common/utils.h"

namespace megdnn {

/* ============================== MeshIndexing ============================= */

/*!
 * Checks shared by every mesh-indexing operator: \p origin and \p indexed
 * must have the same dtype and the same number of dimensions, and every index
 * tensor in \p desc must have dtype Int32.
 */
void MeshBase::check_exec(const TensorLayout& origin,
                          const TensorLayout& indexed, const IndexDesc& desc) {
    megdnn_assert(origin.dtype == indexed.dtype);
    megdnn_assert(origin.ndim == indexed.ndim);
    for (auto&& index : desc) {
        megdnn_assert(index.vec.layout.dtype == dtype::Int32());
    }
}

/*!
 * Checks for the non-batched variants, on top of MeshBase::check_exec:
 * each index must be a 1-dim vector, its axis must be a valid dimension of
 * \p dst, and the extent of \p dst along that axis must equal the index
 * length.
 */
void NormalMeshBase::check_exec(const TensorLayout& src,
                                const TensorLayout& dst,
                                const IndexDesc& desc) {
    MeshBase::check_exec(src, dst, desc);
    for (auto&& index : desc) {
        size_t ndim = index.vec.layout.ndim;
        megdnn_assert(ndim == 1, "index must be 1-dim vector, while dim %zu",
                      ndim);
        const size_t axis = index.axis;
        // Validate the axis before using it to subscript dst's shape.
        megdnn_assert(axis < dst.ndim,
                      "index axis %zu out of range for %zu-dim tensor", axis,
                      dst.ndim);
        megdnn_assert(dst.shape[axis] == index.vec.layout[0]);
    }
}

/*!
 * Checks for the batched variants, on top of MeshBase::check_exec:
 * - \p src and \p dst must have the same 0-th (batch) extent;
 * - each index must be a 2-dim matrix;
 * - the index axis must be a valid, non-zero dimension of \p dst;
 * - each index shape must equal (dst[0], dst[axis]).
 */
void BatchedMeshBase::check_exec(const TensorLayout& src,
                                 const TensorLayout& dst,
                                 const IndexDesc& desc) {
    MeshBase::check_exec(src, dst, desc);
    megdnn_assert(src[0] == dst[0], "batch mismatch, src %zu, dst %zu", src[0],
                  dst[0]);
    for (auto&& index : desc) {
        size_t ndim = index.vec.layout.ndim;
        megdnn_assert(ndim == 2,
                      "index must be a 2-dim matrix, while ndim %zu", ndim);
        const size_t axis = index.axis;
        // Dimension 0 must match between src and dst (asserted above) and
        // may not be indexed. Check the axis before comparing shapes. That
        // way an axis of 0, or an out-of-range axis, is reported by its own
        // message rather than as a confusing shape mismatch.
        megdnn_assert(axis != 0,
                      "index axis should not be 0-th dim when executing "
                      "BatchedMeshIndexing");
        megdnn_assert(axis < dst.ndim,
                      "index axis %zu out of range for %zu-dim tensor", axis,
                      dst.ndim);
        megdnn_assert(dst[0] == index.vec.layout[0] &&
                              dst[axis] == index.vec.layout[1],
                      "require each index shape equals (%zu, %zu), but got "
                      "(%zu, %zu)",
                      dst[0], dst[axis], index.vec.layout[0],
                      index.vec.layout[1]);
    }
}

/*!
 * Deduces the MeshIndexing output layout. The result is \p inp, with the
 * extent along each indexed axis replaced by the length of the corresponding
 * 1-dim index, and with a freshly initialized contiguous stride.
 */
void MeshIndexing::deduce_layout(const TensorLayout& inp,
                                 const IndexDescLayoutOnly& layouts,
                                 TensorLayout& out_layout) {
    out_layout = inp;
    for (auto&& index : layouts) {
        megdnn_assert(index.layout.ndim == 1,
                      "mesh indexing require index being 1-dim vector");
        const size_t axis = index.axis;
        // Validate the axis before writing through it into out_layout.
        megdnn_assert(axis < inp.ndim,
                      "index axis %zu out of range for %zu-dim tensor", axis,
                      inp.ndim);
        out_layout[axis] = index.layout[0];
    }
    out_layout.init_contiguous_stride();
}

/*!
 * Deduces the BatchedMeshIndexing output layout. The result is \p inp, with
 * the extent along each indexed axis replaced by the second extent
 * (layout[1]) of the corresponding 2-dim index, and with a freshly
 * initialized contiguous stride.
 */
void BatchedMeshIndexing::deduce_layout(const TensorLayout& inp,
                                        const IndexDescLayoutOnly& layouts,
                                        TensorLayout& out_layout) {
    out_layout = inp;
    for (auto&& index : layouts) {
        megdnn_assert(index.layout.ndim == 2,
                      "batch mesh indexing require index being 2-dim matrix");
        const size_t axis = index.axis;
        // Validate the axis before writing through it into out_layout.
        megdnn_assert(axis < inp.ndim,
                      "index axis %zu out of range for %zu-dim tensor", axis,
                      inp.ndim);
        out_layout[axis] = index.layout[1];
    }
    out_layout.init_contiguous_stride();
}

}  // namespace megdnn