|
|
|
@@ -35,6 +35,9 @@ TensorFormat TensorFormat::deserialize(const std::string& bin, |
|
|
|
case Type::IMAGE2D_PACK4: |
|
|
|
return Image2DPack4TensorFormat::deserialize( |
|
|
|
handle, type + 1, bin.size() - sizeof(Type)); |
|
|
|
case Type::FOURBITS_ALIGNED_TO_BYTE: |
|
|
|
return FourBitsAlignedToBytesTensorFormat::deserialize( |
|
|
|
handle, type + 1, bin.size() - sizeof(Type)); |
|
|
|
default: |
|
|
|
megdnn_throw("invalid tensor format type in deserialize"); |
|
|
|
} |
|
|
|
@@ -67,7 +70,15 @@ bool TensorFormat::is_default() const { |
|
|
|
} |
|
|
|
|
|
|
|
/* ===================== DefaultFormat ===================== */ |
|
|
|
void DefaultTensorFormat::assert_valid(const TensorLayout& layout) const { |
|
|
|
megdnn_assert( |
|
|
|
!layout.dtype.valid() || !layout.dtype.is_low_bit(), |
|
|
|
"DefaultTensorFormat does not support low-bits tensor(dtype:%s)", |
|
|
|
layout.dtype.name()); |
|
|
|
} |
|
|
|
|
|
|
|
size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const { |
|
|
|
assert_valid(layout); |
|
|
|
if (!layout.ndim) |
|
|
|
return 0; |
|
|
|
megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM); |
|
|
|
@@ -81,11 +92,13 @@ size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const { |
|
|
|
} |
|
|
|
|
|
|
|
bool DefaultTensorFormat::is_contiguous_spec(const TensorLayout& layout) const { |
|
|
|
assert_valid(layout); |
|
|
|
return layout.is_physical_contiguous(); |
|
|
|
} |
|
|
|
|
|
|
|
TensorLayout DefaultTensorFormat::collapse_contiguous_spec( |
|
|
|
const TensorLayout& layout) const { |
|
|
|
assert_valid(layout); |
|
|
|
megdnn_assert(layout.ndim); |
|
|
|
TensorLayout res{layout}; |
|
|
|
|
|
|
|
@@ -126,6 +139,7 @@ TensorLayout DefaultTensorFormat::collapse_contiguous_spec( |
|
|
|
|
|
|
|
TensorLayout::Span DefaultTensorFormat::span_spec( |
|
|
|
const TensorLayout& layout) const { |
|
|
|
assert_valid(layout); |
|
|
|
if (layout.ndim == 0) |
|
|
|
return {0, 0, 0, 0}; |
|
|
|
|
|
|
|
@@ -146,9 +160,6 @@ TensorLayout::Span DefaultTensorFormat::span_spec( |
|
|
|
++high_elem; |
|
|
|
ptrdiff_t low_byte; |
|
|
|
if (low_elem < 0) { |
|
|
|
megdnn_assert(!layout.dtype.is_low_bit(), |
|
|
|
"tensors with low-bit dytes shouldn't have negative " |
|
|
|
"strides"); |
|
|
|
low_byte = low_elem * layout.dtype.size(); |
|
|
|
} else { |
|
|
|
low_byte = 0; |
|
|
|
@@ -422,12 +433,151 @@ TensorLayout Image2DPackedTensorFormatBase<PIXEL_SIZE>::collapse_contiguous_spec |
|
|
|
return res; |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
namespace megdnn { |
|
|
|
namespace detail { |
|
|
|
template class Image2DPackedTensorFormatBase<4>; |
|
|
|
} // namespace detail |
|
|
|
} // namespace megdnn |
|
|
|
|
|
|
|
/* =============== FourBitsAlignedToBytesTensorFormatBase ============== */ |
|
|
|
template <size_t SIZE_NBITS> |
|
|
|
LowbitsTensorFormatBase<SIZE_NBITS>::LowbitsTensorFormatBase( |
|
|
|
Type type, size_t align_size_in_bits) |
|
|
|
: ImplBase(type), m_align_size_in_bits(align_size_in_bits) { |
|
|
|
megdnn_assert(!(m_align_size_in_bits % SIZE_NBITS), |
|
|
|
"align size(%zu) must be a multiple of element size(%zu)", |
|
|
|
m_align_size_in_bits, SIZE_NBITS); |
|
|
|
m_align_size_in_elements = m_align_size_in_bits / SIZE_NBITS; |
|
|
|
} |
|
|
|
|
|
|
|
template <size_t SIZE_NBITS> |
|
|
|
std::string LowbitsTensorFormatBase<SIZE_NBITS>::to_string() const { |
|
|
|
return ssprintf("LOWBITS{%zu,%zu}", SIZE_NBITS, m_align_size_in_bits); |
|
|
|
} |
|
|
|
|
|
|
|
template <size_t SIZE_NBITS> |
|
|
|
void LowbitsTensorFormatBase<SIZE_NBITS>::assert_valid( |
|
|
|
const TensorLayout& layout) const { |
|
|
|
megdnn_assert(layout.dtype.valid() && layout.dtype.is_low_bit() && |
|
|
|
layout.dtype.low_bit() == SIZE_NBITS); |
|
|
|
bool has_dim_unity_stride = false; |
|
|
|
for (int i = layout.ndim - 1; i >= 0; --i) { |
|
|
|
if (!has_dim_unity_stride && layout.stride[i] == 1) |
|
|
|
has_dim_unity_stride = true; |
|
|
|
megdnn_assert( |
|
|
|
layout.stride[i] >= 0 && |
|
|
|
(layout.stride[i] % m_align_size_in_elements == 0 || |
|
|
|
layout.stride[i] == 1), |
|
|
|
"bad stride: %zu", layout.stride[i]); |
|
|
|
} |
|
|
|
megdnn_assert(has_dim_unity_stride, "innermost dim not contiguous"); |
|
|
|
} |
|
|
|
|
|
|
|
template <size_t SIZE_NBITS> |
|
|
|
void LowbitsTensorFormatBase<SIZE_NBITS>::serialize_append( |
|
|
|
std::string& result) const { |
|
|
|
SerializePack pack; |
|
|
|
pack.align_size_in_bits = m_align_size_in_bits; |
|
|
|
megdnn_assert(pack.align_size_in_bits == |
|
|
|
m_align_size_in_bits); // detect overflow; |
|
|
|
result.append(reinterpret_cast<char*>(&pack), sizeof(pack)); |
|
|
|
} |
|
|
|
|
|
|
|
template <size_t SIZE_NBITS> |
|
|
|
TensorLayout::Span LowbitsTensorFormatBase<SIZE_NBITS>::span_spec( |
|
|
|
const TensorLayout& layout) const { |
|
|
|
assert_valid(layout); |
|
|
|
if (layout.ndim == 0) |
|
|
|
return {0, 0, 0, 0}; |
|
|
|
|
|
|
|
size_t high_elem = 0; |
|
|
|
for (size_t i = 0; i < layout.ndim; ++i) { |
|
|
|
auto shape_val = layout.shape[i]; |
|
|
|
if (!shape_val) { |
|
|
|
return {0, 0, 0, 0}; |
|
|
|
} |
|
|
|
auto stride_val = layout.stride[i]; |
|
|
|
megdnn_assert(stride_val >= 0, |
|
|
|
"lowbit tensors shouldn't have negative strides"); |
|
|
|
high_elem += (shape_val - 1) * stride_val; |
|
|
|
} |
|
|
|
++high_elem; |
|
|
|
size_t high_byte = layout.dtype.size(high_elem); |
|
|
|
return TensorLayout::Span(0, 0, high_elem, high_byte); |
|
|
|
} |
|
|
|
|
|
|
|
template <size_t SIZE_NBITS> |
|
|
|
size_t LowbitsTensorFormatBase<SIZE_NBITS>::init_contiguous_stride( |
|
|
|
TensorLayout& layout) const { |
|
|
|
if (!layout.ndim) |
|
|
|
return 0; |
|
|
|
megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM); |
|
|
|
size_t accum = 1; |
|
|
|
SafeMultiplies<size_t> mul; |
|
|
|
for (size_t i = layout.ndim; i; --i) { |
|
|
|
layout.stride[i - 1] = accum; |
|
|
|
auto multiplier = layout.shape[i - 1]; |
|
|
|
if (i == layout.ndim) |
|
|
|
multiplier = round_up(multiplier, m_align_size_in_elements); |
|
|
|
accum = mul(accum, multiplier); |
|
|
|
} |
|
|
|
return accum; |
|
|
|
} |
|
|
|
|
|
|
|
template <size_t SIZE_NBITS> |
|
|
|
bool LowbitsTensorFormatBase<SIZE_NBITS>::is_contiguous_spec( |
|
|
|
const TensorLayout& layout) const { |
|
|
|
assert_valid(layout); |
|
|
|
ptrdiff_t expected = 1; |
|
|
|
for (int i = static_cast<int>(layout.ndim) - 1; i >= 0; --i) { |
|
|
|
if (layout.shape[i] != 1 && layout.stride[i] != expected) |
|
|
|
return false; |
|
|
|
auto multiplier = layout.shape[i]; |
|
|
|
if (i == layout.ndim - 1) |
|
|
|
multiplier = round_up(multiplier, m_align_size_in_elements); |
|
|
|
expected *= multiplier; |
|
|
|
} |
|
|
|
return expected != 0; |
|
|
|
} |
|
|
|
|
|
|
|
template <size_t SIZE_NBITS> |
|
|
|
TensorLayout LowbitsTensorFormatBase<SIZE_NBITS>::collapse_contiguous_spec( |
|
|
|
const TensorLayout& layout) const { |
|
|
|
assert_valid(layout); |
|
|
|
TensorLayout res{layout}; |
|
|
|
for (int i = static_cast<int>(res.ndim) - 1; i >= 0; --i) { |
|
|
|
if (!res.shape[i]) { |
|
|
|
// empty tensor |
|
|
|
res.ndim = 1; |
|
|
|
res.shape[0] = 0; |
|
|
|
res.stride[0] = 1; |
|
|
|
return res; |
|
|
|
} |
|
|
|
if (res.shape[i] == 1) { |
|
|
|
res.remove_axis_inplace(i); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
megdnn_assert(res.ndim && res.shape[res.ndim - 1]); |
|
|
|
for (int i = static_cast<int>(res.ndim) - 2; i >= 0; --i) { |
|
|
|
megdnn_assert(res.shape[i]); |
|
|
|
if (res.stride[i] == |
|
|
|
res.stride[i + 1] * static_cast<ptrdiff_t>(res.shape[i + 1])) { |
|
|
|
res.shape[i] *= res.shape[i + 1]; |
|
|
|
res.stride[i] = res.stride[i + 1]; |
|
|
|
res.remove_axis_inplace(i + 1); |
|
|
|
} |
|
|
|
} |
|
|
|
return res; |
|
|
|
} |
|
|
|
|
|
|
|
namespace megdnn { |
|
|
|
namespace detail { |
|
|
|
template class LowbitsTensorFormatBase<4>; |
|
|
|
} // namespace detail |
|
|
|
} // namespace megdnn |
|
|
|
|
|
|
|
/* ===================== Image2DPack4TensorFormat ===================== */ |
|
|
|
TensorFormat Image2DPack4TensorFormat::make_raw( |
|
|
|
size_t align_axis, size_t align_size_in_elements, |
|
|
|
@@ -466,4 +616,29 @@ TensorFormat Image2DPack4TensorFormat::change_axis(size_t axis) const { |
|
|
|
return make_raw(axis, align_size_in_elements(), vendor()); |
|
|
|
} |
|
|
|
|
|
|
|
/* ===================== FourBitsAlignedToBytesTensorFormat |
|
|
|
* ===================== */ |
|
|
|
TensorFormat FourBitsAlignedToBytesTensorFormat::make( |
|
|
|
size_t align_size_in_bits) { |
|
|
|
static std::mutex mtx; |
|
|
|
static std::unordered_map< |
|
|
|
uint32_t, std::unique_ptr<FourBitsAlignedToBytesTensorFormat>> |
|
|
|
cache; |
|
|
|
megdnn_assert(!(align_size_in_bits % 4)); |
|
|
|
MEGDNN_LOCK_GUARD(mtx); |
|
|
|
auto&& ptr = cache[static_cast<uint32_t>(align_size_in_bits)]; |
|
|
|
if (!ptr) { |
|
|
|
ptr.reset(new FourBitsAlignedToBytesTensorFormat{align_size_in_bits}); |
|
|
|
} |
|
|
|
return impl_to_tensor_format(ptr.get()); |
|
|
|
} |
|
|
|
|
|
|
|
TensorFormat FourBitsAlignedToBytesTensorFormat::deserialize(const Handle*, |
|
|
|
const void* buf, |
|
|
|
size_t size) { |
|
|
|
megdnn_assert(size == sizeof(SerializePack)); |
|
|
|
auto pack = *static_cast<const SerializePack*>(buf); |
|
|
|
return make(pack.align_size_in_bits); |
|
|
|
} |
|
|
|
|
|
|
|
// vim: syntax=cpp.doxygen |