#include "megdnn/oprs/general.h"
#include "src/common/utils.h"

#include <limits>

using namespace megdnn;

/*!
 * \brief Validate the layouts handed to a param-pack concat/split operator.
 *
 * Asserts that \p offsets has dtype Int32 and that all three layouts are
 * one-dimensional with unit stride.
 *
 * \param concated layout of the packed (concatenated) tensor
 * \param offsets layout of the offset table
 * \param parts layout checked for the individual parts
 */
void ParamPackConcatSplitBase::check_exec(
        const TensorLayout& concated, const TensorLayout& offsets,
        const TensorLayout& parts) {
    megdnn_assert(
            offsets.dtype == dtype::Int32{}, "bad dtype: %s", offsets.dtype.name());
    megdnn_assert(
            concated.ndim == 1 && offsets.ndim == 1 && parts.ndim == 1 &&
                    concated.stride[0] == 1 && offsets.stride[0] == 1 &&
                    parts.stride[0] == 1,
            "bad layout: concated=%s offsets=%s parts=%s",
            concated.to_string().c_str(), offsets.to_string().c_str(),
            parts.to_string().c_str());
}

/*!
 * \brief Compute the begin/end offsets of every part inside the packed buffer.
 *
 * The result holds two entries per shape: entry `2 * i` is the aligned start
 * of part `i` and entry `2 * i + 1` is its end (start plus the number of
 * elements of `shapes[i]`). Offsets are expressed in elements, not in the
 * unit of \p alignment.
 *
 * \param shapes shapes of the parts, in packing order
 * \param alignment required alignment of each part start; must be a power of
 *      two. It is raised to \p dtype_size when smaller. Expressed in the same
 *      unit as \p dtype_size (presumably bytes -- confirm against callers).
 * \param dtype_size size of one element; must be non-zero and divide the
 *      effective alignment
 * \return offset table with `2 * shapes.size()` entries
 *
 * NOTE(review): the element type `dt_int32` was restored here because the
 * template arguments were missing in this copy of the file; it matches the
 * Int32 requirement in check_exec, but verify against the declaration.
 */
std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets(
        const TensorShapeArray& shapes, size_t alignment, size_t dtype_size) {
    megdnn_assert(
            alignment && (alignment & (alignment - 1)) == 0,
            "alignment must be power of 2: %zu", alignment);
    // Guard the modulo/division below: a zero element size would be undefined
    // behaviour rather than a diagnosable error.
    megdnn_assert(dtype_size, "dtype size must be non-zero");
    if (alignment < dtype_size)
        alignment = dtype_size;

    megdnn_assert(
            alignment % dtype_size == 0,
            "alignment must be multiple of dtype size: %zu vs %zu", alignment,
            dtype_size);
    // From here on the alignment is counted in elements. It is still a power
    // of two (or 1), so the bit-mask rounding below stays valid.
    alignment /= dtype_size;

    // Round v up to the next multiple of the element alignment.
    auto get_aligned = [alignment](size_t v) {
        auto mod = v & (alignment - 1);
        return v + ((alignment - mod) & (alignment - 1));
    };

    constexpr size_t max_offset =
            static_cast<size_t>(std::numeric_limits<dt_int32>::max());

    std::vector<dt_int32> offsets(shapes.size() << 1);
    size_t offset = 0;
    for (size_t i = 0; i < shapes.size(); i++) {
        const size_t begin = get_aligned(offset);
        const size_t end = begin + shapes[i].total_nr_elems();
        // The table is 32-bit; refuse to store a truncated offset.
        megdnn_assert(
                begin <= end && end <= max_offset,
                "param pack offset exceeds int32 range: %zu", end);
        offsets[i << 1] = static_cast<dt_int32>(begin);
        offsets[(i << 1) + 1] = static_cast<dt_int32>(end);
        offset = end;
    }
    return offsets;
}

// vim: syntax=cpp.doxygen