/**
 * \file dnn/src/common/indexing_multi_axis_vec.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs.h"
#include "src/common/utils.h"

using namespace megdnn;

namespace {

/*!
 * \brief length of the axis of \p shp that holds the index dimension
 *
 * The indexed axes must be strictly increasing. When they are consecutive,
 * the index dimension sits at axes[0]; otherwise it sits at axis 0. This
 * mirrors the placement done by deduce_layout_fwd().
 *
 * \param shp shape of the indexing result (dst / value)
 * \param axes indexed axes of the source tensor, \p nr_axes entries
 * \param nr_axes number of indexed axes; must be non-zero
 * \return shp.shape[idx_axis]
 */
size_t get_index_size_for_workspace(
        const TensorShape &shp, const size_t *axes, size_t nr_axes) {
    // validate first: reading axes[0] with nr_axes == 0 would be out of bounds
    megdnn_assert(shp.ndim && nr_axes);
    size_t idx_axis = axes[0];
    for (size_t i = 1; i < nr_axes; ++ i) {
        megdnn_assert(axes[i] > axes[i - 1]);
        if (axes[i] != axes[i - 1] + 1) {
            // non-consecutive axes: index dimension is moved to the front
            idx_axis = 0;
            break;
        }
    }
    megdnn_assert(shp.ndim > idx_axis,
            "index on the %zuth axis; but shape is %s",
            idx_axis, shp.to_string().c_str());
    return shp.shape[idx_axis];
}

} // anonymous namespace

/*!
 * \brief copy the (layout, axis) pairs out of \p index, dropping the data
 *      pointers
 */
IndexingMultiAxisVecBase::IndexDescLayoutOnly
IndexingMultiAxisVecBase::extract_index_layout(const IndexDesc &index) {
    IndexDescLayoutOnly ret(index.size());
    for (size_t i = 0; i < index.size(); ++ i) {
        ret[i].layout = index[i].vec.layout;
        ret[i].axis = index[i].axis;
    }
    return ret;
}

/*!
 * \brief deduce the layout of the indexing result
 *
 * Every index entry must be a 1-dim Int32 layout, and the entries' axes must
 * be strictly increasing and smaller than data.ndim. The index lengths are
 * broadcast to a common length; a length of 1 broadcasts.
 *
 * \p dst gets data's dtype and contiguous strides. Its shape is the broadcast
 * index length, followed by the non-indexed axes of \p data in order. When
 * the indexed axes are consecutive, the index dimension is moved to the
 * position of the first indexed axis instead.
 *
 * \return position of the index dimension in \p dst
 */
size_t IndexingMultiAxisVecBase::deduce_layout_fwd(
        const TensorLayout &data,
        const IndexDescLayoutOnly &index,
        TensorLayout &dst) {
    megdnn_assert(!index.empty());
    megdnn_assert(data.ndim >= index.size());
    dst.ndim = data.ndim - index.size() + 1;
    dst.shape[0] = 1;
    dst.dtype = data.dtype;

    // merge ly's length into dst.shape[0]; false if not broadcastable
    auto brdcast = [&](const TensorLayout &ly) {
        if (ly.ndim != 1)
            return false;
        if (dst.shape[0] == ly.shape[0])
            return true;
        if (dst.shape[0] == 1) {
            dst.shape[0] = ly.shape[0];
            return true;
        }
        return ly.shape[0] == 1;
    };

    size_t dst_axis = 1;
    ptrdiff_t prev_axis = -1;
    for (size_t axis = 0; axis < index.size(); ++ axis) {
        auto &&idx = index[axis];
        megdnn_assert(idx.layout.dtype == dtype::Int32(),
                "invalid index dtype: %s", idx.layout.dtype.name());
        megdnn_assert(idx.axis < data.ndim &&
                static_cast<ptrdiff_t>(idx.axis) > prev_axis,
                "index %zu requests invalid axis %zu", axis, idx.axis);
        auto brd_succ = brdcast(idx.layout);
        megdnn_assert(brd_succ, "invalid layout at index %zu: %s",
                axis, idx.layout.to_string().c_str());

        // copy the non-indexed axes between the previous and this index
        for (size_t i = prev_axis + 1; i < idx.axis; ++ i) {
            dst.shape[dst_axis ++] = data.shape[i];
        }
        prev_axis = idx.axis;
    }
    // trailing non-indexed axes
    for (size_t i = prev_axis + 1; i < data.ndim; ++ i) {
        dst.shape[dst_axis ++] = data.shape[i];
    }
    megdnn_assert(dst_axis == dst.ndim);

    size_t idx_axis = 0;
    {
        // fix idx_axis if index contains consecutive axes
        bool contig_idx = true;
        for (size_t i = 1; i < index.size(); ++ i) {
            if (index[i].axis != index[i - 1].axis + 1) {
                contig_idx = false;
                break;
            }
        }
        if (contig_idx) {
            // rotate the index dimension from position 0 to index[0].axis
            auto shp0 = dst.shape[0];
            idx_axis = index[0].axis;
            for (size_t i = 0; i < idx_axis; ++ i) {
                dst.shape[i] = dst.shape[i + 1];
            }
            dst.shape[idx_axis] = shp0;
        }
    }
    dst.init_contiguous_stride();
    return idx_axis;
}

/*!
 * \brief write into \p out the axes of a \p src_ndim tensor that are not
 *      indexed
 *
 * \p index is walked in order, so its axes are expected to be increasing
 * (as enforced by deduce_layout_fwd()).
 *
 * \param out receives at least src_ndim - index.size() entries
 * \return number of non-indexed axes written
 */
size_t IndexingMultiAxisVecBase::get_nonindex_axes(
        size_t src_ndim, const IndexDesc &index, size_t *out) {
    auto iter = index.begin();
    size_t nr = 0;
    for (size_t i = 0; i < src_ndim; ++ i) {
        if (iter != index.end() && i == iter->axis) {
            ++ iter;
        } else {
            out[nr ++] = i;
        }
    }
    megdnn_assert(nr + index.size() == src_ndim && iter == index.end());
    return nr;
}

/*!
 * \brief validate data/value/index layouts (everything except workspace)
 *
 * \p value must match the deduced result shape, and it must collapse to a
 * single contiguous dimension.
 *
 * \param[out] index_layout layout-only copy of \p index
 * \return idx_axis and the stride of the collapsed \p value
 */
IndexingMultiAxisVecBase::ExecInfo
IndexingMultiAxisVecBase::check_exec_noworkspace(
        const TensorLayout &data, const TensorLayout &value,
        const IndexDesc &index, IndexDescLayoutOnly &index_layout) {
    ExecInfo ret;
    index_layout = extract_index_layout(index);
    TensorLayout value_expect;
    ret.idx_axis = deduce_layout_fwd(data, index_layout, value_expect);
    megdnn_assert_eq_shape(value_expect, value);

    auto value_contig = value.collapse_contiguous();
    megdnn_assert(value_contig.ndim == 1,
            "value layout must be 1-dim contiguous; got %s",
            value.to_string().c_str());

    ret.value_stride = value_contig.stride[0];
    return ret;
}

/*!
 * \brief build a layout over \p data for iterating in the order of \p value
 *
 * The result takes the non-indexed axes of \p data. The axes before and
 * after \p idx_axis are collapsed separately. A dimension of length
 * value.shape[idx_axis] with stride 0 is inserted between the two groups.
 *
 * \return the layout and the position of the inserted stride-0 dimension
 */
std::pair<TensorLayout, size_t>
IndexingMultiAxisVecBase::get_value_iter_optimized_layout(
        const TensorLayout &data, const TensorLayout &value,
        const IndexDesc &index, size_t idx_axis) {
    size_t data_axes[TensorLayout::MAX_NDIM];
    size_t nr_axes = get_nonindex_axes(data.ndim, index, data_axes);

    megdnn_assert(nr_axes == value.ndim - 1 && idx_axis < value.ndim &&
            nr_axes + index.size() == data.ndim);

    TensorLayout ret;
    if (idx_axis) {
        // non-indexed axes before the index dimension
        ret.ndim = idx_axis;
        for (size_t i = 0; i < idx_axis; ++ i) {
            ret.shape[i] = data.shape[data_axes[i]];
            ret.stride[i] = data.stride[data_axes[i]];
        }
        ret = ret.collapse_contiguous();
    }
    // index dimension: stride 0 on data
    ret.shape[ret.ndim] = value.shape[idx_axis];
    ret.stride[ret.ndim] = 0;
    size_t ret_idx_axis = ret.ndim;
    ++ ret.ndim;

    if (idx_axis < nr_axes) {
        // non-indexed axes after the index dimension
        TensorLayout tail;
        tail.ndim = nr_axes - idx_axis;
        for (size_t i = idx_axis; i < nr_axes; ++ i) {
            tail.shape[i - idx_axis] = data.shape[data_axes[i]];
            tail.stride[i - idx_axis] = data.stride[data_axes[i]];
        }
        tail = tail.collapse_contiguous();
        for (size_t i = 0; i < tail.ndim; ++ i) {
            ret.shape[ret.ndim] = tail.shape[i];
            ret.stride[ret.ndim] = tail.stride[i];
            ++ ret.ndim;
        }
    }

    return {ret, ret_idx_axis};
}

/*!
 * \brief workspace size for an indexing whose result shape is \p dst
 */
size_t IndexingMultiAxisVec::get_workspace_in_bytes(
        const TensorShape &dst, const size_t *axes, size_t nr_axes) {
    return get_workspace_in_bytes(
            get_index_size_for_workspace(dst, axes, nr_axes));
}

/*!
 * \brief validate layouts and workspace; \p dst must not have a zero stride
 */
IndexingMultiAxisVec::ExecInfo IndexingMultiAxisVec::check_exec(
        const TensorLayout &src, const IndexDesc &index,
        const TensorLayout &dst, size_t workspace_in_bytes) {
    IndexDescLayoutOnly index_layout;
    auto ret = check_exec_noworkspace(src, dst, index, index_layout);
    megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(
                dst.shape[ret.idx_axis]));
    megdnn_assert(ret.value_stride, "dst must be non-overlapping");
    return ret;
}

/*!
 * \brief workspace size for a modify-indexing whose value shape is \p value
 */
size_t IndexingModifyMultiAxisVecBase::get_workspace_in_bytes(
        const TensorShape &value, const size_t *axes, size_t nr_axes) {
    return get_workspace_in_bytes(
            get_index_size_for_workspace(value, axes, nr_axes));
}
/*!
 * \brief validate the arguments of a modify-indexing operation
 *
 * \p data must have a strongly non-overlapping layout. The shapes are then
 * checked by check_exec_noworkspace(), and finally the workspace size is
 * checked against the index dimension of \p value.
 *
 * \return execution info produced by check_exec_noworkspace()
 */
IndexingModifyMultiAxisVecBase::ExecInfo
IndexingModifyMultiAxisVecBase::check_exec(
        const TensorLayout &data, const TensorLayout &value,
        const IndexDesc &index, size_t workspace_in_bytes) {
    // reject an overlapping data layout before any shape checks run
    megdnn_assert(data.is_non_overlapping_strong(),
            "data layout should not overlap: %s",
            data.to_string().c_str());

    IndexDescLayoutOnly layouts_of_index;
    ExecInfo ret = check_exec_noworkspace(
            data, value, index, layouts_of_index);

    // the supplied workspace must cover the index dimension of value
    megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(
                value.shape[ret.idx_axis]));
    return ret;
}

// vim: syntax=cpp.doxygen