#include "megdnn/oprs/general.h"

#include "src/common/utils.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

using namespace megdnn;

namespace {

//! Magnitude of \p k as size_t.
//! Widening to int64_t before negating avoids the signed-overflow UB that
//! std::abs(k) has for k == INT_MIN.
size_t abs_k(int k) {
    return k < 0 ? static_cast<size_t>(-static_cast<int64_t>(k))
                 : static_cast<size_t>(k);
}

}  // anonymous namespace

/*!
 * \brief Deduce the output layouts of TopK for a given k and input layout.
 *
 * \param k number of elements to select; must be nonzero. Only |k| affects
 *      the deduced shape. NOTE(review): the meaning of the sign of k is
 *      defined by the do_exec implementations, not in this file.
 * \param data input layout; asserted to be 2-D with stride 1 along dim 1.
 * \param[out] values gets data's dtype and a contiguous layout:
 *      {data[0]} in KTH_ONLY mode, otherwise {data[0], min(|k|, data[1])}.
 * \param[out] indices gets dtype Int32; it is left empty (ndim == 0) in
 *      KTH_ONLY mode, otherwise it gets the same contiguous shape as values.
 *
 * Any other mode raises via megdnn_throw.
 */
void TopK::deduce_layout(
        int k, const TensorLayout& data, TensorLayout& values, TensorLayout& indices) {
    megdnn_assert(
            k && data.ndim == 2 && data.stride[1] == 1, "invalid k=%d data=%s", k,
            data.to_string().c_str());
    values.dtype = data.dtype;
    indices.dtype = dtype::Int32{};
    switch (param().mode) {
        case Param::Mode::KTH_ONLY:
            values.init_contiguous_stride({data[0]});
            indices.ndim = 0;
            break;
        case Param::Mode::VALUE_IDX_NOSORT:
        case Param::Mode::VALUE_IDX_SORTED:
            // Explicit <size_t> is required: |k| and shape[1] must share a type.
            values.init_contiguous_stride(
                    {data[0], std::min<size_t>(abs_k(k), data.shape[1])});
            indices.init_contiguous_stride(values);
            break;
        default:
            megdnn_throw("invalid TopK mode");
    }
}

/*!
 * \brief Validate the output tensors against the deduced layouts, then run
 *      do_exec.
 *
 * values must match the deduced value layout exactly. In KTH_ONLY mode
 * indices must have an empty shape and a null index pointer is forwarded.
 * Otherwise indices must match the deduced index layout and its Int32
 * buffer is forwarded.
 *
 * The workspace size is checked against get_workspace_in_bytes() using the
 * caller's original k. Only afterwards is |k| clamped to data.layout[1],
 * keeping its sign, before k is handed to do_exec.
 */
void TopK::exec(
        int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
        _megdnn_tensor_out indices, _megdnn_workspace workspace) {
    TensorLayout oval, oidx;
    deduce_layout(k, data.layout, oval, oidx);
    megdnn_assert_eq_layout(oval, values.layout);
    int32_t* iptr = nullptr;
    if (param().mode == Param::Mode::KTH_ONLY) {
        megdnn_assert_eq_shape(indices.layout, TensorShape{});
    } else {
        iptr = indices.ptr<int32_t>();
        megdnn_assert_eq_layout(oidx, indices.layout);
    }
    megdnn_assert(
            workspace.size >=
            get_workspace_in_bytes(k, data.layout, values.layout, indices.layout));
    // Clamp |k| to the row length so do_exec never sees a k larger than the
    // number of elements along dim 1; the sign of k is preserved.
    if (abs_k(k) > data.layout[1]) {
        if (k > 0) {
            k = static_cast<int>(data.layout[1]);
        } else {
            k = -static_cast<int>(data.layout[1]);
        }
    }
    do_exec(k, data, values, iptr, workspace);
}

// vim: syntax=cpp.doxygen