GitOrigin-RevId: e11f3e5408
tags/v1.5.0
| @@ -1001,7 +1001,9 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o | |||
| 'NCHW_NCHW4_WEIGHT', | |||
| 'NCHW_NCHW64', | |||
| 'NCHW64_NCHW', | |||
| ) | |||
| 'NCHW_NHWC', | |||
| 'NHWC_NCHW', | |||
| ) | |||
| ) | |||
| (pdef('RelayoutFormat', 'Change the tensor layout format', version=1). | |||
| @@ -268,6 +268,22 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, | |||
| dst[2] = src[2]; | |||
| dst[3] = src[3]; | |||
| break; | |||
| case Param::Mode::NCHW_NHWC: | |||
| megdnn_assert(src.ndim == 4); | |||
| dst.ndim = 4; | |||
| dst[0] = src[0]; | |||
| dst[1] = src[2]; | |||
| dst[2] = src[3]; | |||
| dst[3] = src[1]; | |||
| break; | |||
| case Param::Mode::NHWC_NCHW: | |||
| megdnn_assert(src.ndim == 4); | |||
| dst.ndim = 4; | |||
| dst[0] = src[0]; | |||
| dst[1] = src[3]; | |||
| dst[2] = src[1]; | |||
| dst[3] = src[2]; | |||
| break; | |||
| default: | |||
| megdnn_assert(0, "Invalid RelayoutFormat Mode"); | |||
| break; | |||
| @@ -375,6 +391,10 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||
| case Param::Mode::NCHW64_NCHW: | |||
| dst = src; | |||
| break; | |||
| case Param::Mode::NCHW_NHWC: | |||
| case Param::Mode::NHWC_NCHW: | |||
| dst = src; | |||
| break; | |||
| default: | |||
| megdnn_throw("Invalid relayout format mode"); | |||
| break; | |||
| @@ -666,6 +686,14 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, | |||
| exec_src = src.dimshuffle({0, 1, 4, 2, 3}); | |||
| exec_dst = dst; | |||
| break; | |||
| case Param::Mode::NCHW_NHWC: | |||
| exec_src = src.dimshuffle({0, 2, 3, 1}); | |||
| exec_dst = dst; | |||
| break; | |||
| case Param::Mode::NHWC_NCHW: | |||
| exec_src = src.dimshuffle({0, 3, 1, 2}); | |||
| exec_dst = dst; | |||
| break; | |||
| default: | |||
| megdnn_assert(0, "Invalid RelayoutFormat Mode"); | |||
| } | |||
| @@ -505,7 +505,7 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, | |||
| void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4( | |||
| const int8_t* d_src, int8_t* d_dst, const Param& param, | |||
| cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) { | |||
| cudaStream_t stream, uint32_t mode, bool /* uint_case */, int zero_point) { | |||
| using Mode = megdnn::param_enumv::Pooling::Mode; | |||
| void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param, | |||
| int zero_point); | |||
| @@ -545,7 +545,7 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4( | |||
| void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32( | |||
| const int8_t* d_src, int8_t* d_dst, const Param& param, | |||
| cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) { | |||
| cudaStream_t stream, uint32_t mode, bool /* uint_case */, int zero_point) { | |||
| using Mode = megdnn::param_enumv::Pooling::Mode; | |||
| void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param, | |||
| int zero_point); | |||
| @@ -33,7 +33,9 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| Param::Mode:: | |||
| NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT || | |||
| param().mode == Param::Mode::NCHW_NCHW64 || | |||
| param().mode == Param::Mode::NCHW64_NCHW, | |||
| param().mode == Param::Mode::NCHW64_NCHW || | |||
| param().mode == Param::Mode::NCHW_NHWC || | |||
| param().mode == Param::Mode::NHWC_NCHW, | |||
| "relayout format of cuda only support NCHW4->CHWN4 or " | |||
| "CHWN4->NCHW4 or NCHW->NCHW4"); | |||
| if ((param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || | |||
| @@ -82,7 +84,9 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| {src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout}); | |||
| } | |||
| bool is_trans_4bits = (param().mode == Param::Mode::NCHW_NCHW64 || | |||
| param().mode == Param::Mode::NCHW64_NCHW) && | |||
| param().mode == Param::Mode::NCHW64_NCHW || | |||
| param().mode == Param::Mode::NCHW_NHWC || | |||
| param().mode == Param::Mode::NHWC_NCHW) && | |||
| (src_dtype.enumv() == DTypeEnum::QuantizedS4 || | |||
| src_dtype.enumv() == DTypeEnum::Quantized4Asymm); | |||
| bool is_nchw_nchw4 = param().mode == Param::Mode::NCHW_NCHW4 || | |||
| @@ -66,6 +66,22 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src, | |||
| return relayout_format_cuda_nchwx_nchw(src, dst, stream, src_scale, | |||
| dst_scale, src_zero_point, | |||
| dst_zero_point); | |||
| } else if (mode == RelayoutFormat::Param::Mode::NCHW_NHWC) { | |||
| #define CHECK(dt) \ | |||
| megdnn_assert(dt.enumv() == DTypeEnum::Quantized4Asymm || \ | |||
| dt.enumv() == DTypeEnum::QuantizedS4) | |||
| CHECK(src.layout.dtype); | |||
| CHECK(dst.layout.dtype); | |||
| return relayout_format_cuda_nchw_nhwc(src, dst, stream, src_scale, | |||
| dst_scale, src_zero_point, | |||
| dst_zero_point); | |||
| } else if (mode == RelayoutFormat::Param::Mode::NHWC_NCHW) { | |||
| CHECK(src.layout.dtype); | |||
| CHECK(dst.layout.dtype); | |||
| return relayout_format_cuda_nhwc_nchw(src, dst, stream, src_scale, | |||
| dst_scale, src_zero_point, | |||
| dst_zero_point); | |||
| #undef CHECK | |||
| } else if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT) { | |||
| return relayout_format_cuda_nchw_nchw4_weight(src, dst, stream); | |||
| } else if (mode == RelayoutFormat::Param::Mode::NCHW4_NCHW) { | |||
| @@ -20,8 +20,17 @@ namespace relayout_format { | |||
| namespace internal { | |||
| using namespace memory; | |||
| struct LayoutType { | |||
| static constexpr uint32_t NCHWx = 0; | |||
| static constexpr uint32_t NHWC = 1; | |||
| }; | |||
| template <typename Type_, int pack_size_, int chan_blk_, int width_, | |||
| int size_nbits_> | |||
| int size_nbits_, uint32_t layout_type_ = LayoutType::NCHWx> | |||
| class TensorIteratorOverChannel; | |||
| template <typename Type_, int pack_size_, int chan_blk_, int width_, | |||
| int size_nbits_, uint32_t layout_type_> | |||
| class TensorIteratorOverChannel { | |||
| public: | |||
| using Type = Type_; | |||
| @@ -116,6 +125,98 @@ private: | |||
| template <typename Type_, int pack_size_, int chan_blk_, int width_, | |||
| int size_nbits_> | |||
| class TensorIteratorOverChannel<Type_, pack_size_, chan_blk_, width_, | |||
| size_nbits_, LayoutType::NHWC> { | |||
| public: | |||
| using Type = Type_; | |||
| static constexpr int pack_size = pack_size_; | |||
| static constexpr int chan_blk = chan_blk_; | |||
| static constexpr int width = width_; | |||
| static constexpr int size_nbits = size_nbits_; | |||
| static constexpr int elements_in_type = | |||
| chan_blk * width * size_nbits / (8 * sizeof(Type)); | |||
| static constexpr int pack_size_in_type = | |||
| pack_size * size_nbits / (8 * sizeof(Type)); | |||
| static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type); | |||
| using AccessType = array_wrapper<Type, pack_size_in_type>; | |||
| using Fragment = array_wrapper<Type, elements_in_type>; | |||
| MEGDNN_HOST TensorIteratorOverChannel() | |||
| : pointer{nullptr}, hw_stride_in_elements{0}, channel{0} {} | |||
| MEGDNN_HOST TensorIteratorOverChannel(Type* pointer_, | |||
| int hw_stride_in_elements_, | |||
| int channel_, int, int) | |||
| : pointer{pointer_}, | |||
| hw_stride_in_elements{hw_stride_in_elements_}, | |||
| channel{channel_} {} | |||
| MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) { | |||
| pointer += c_idx * size_nbits / (8 * sizeof(Type)) + | |||
| hw_idx * hw_stride_in_elements; | |||
| channel -= c_idx; | |||
| } | |||
| MEGDNN_DEVICE __forceinline__ void add_pointer_offset( | |||
| size_t offset_in_type) { | |||
| pointer += offset_in_type; | |||
| } | |||
| MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) { | |||
| AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag); | |||
| Type* pointer_ = pointer; | |||
| #pragma unroll | |||
| for (int i = 0; i < width; ++i) { | |||
| #pragma unroll | |||
| for (int j = 0; j < chan_blk; j += pack_size) { | |||
| int frag_idx = i * (chan_blk / pack_size) + (j / pack_size); | |||
| bool guard = j < channel; | |||
| global_load<AccessType, pack_size_in_byte>( | |||
| frag_ptr[frag_idx], | |||
| reinterpret_cast<void*>( | |||
| pointer_ + j * size_nbits / (8 * sizeof(Type))), | |||
| guard, zero_point); | |||
| } | |||
| pointer_ += hw_stride_in_elements; | |||
| } | |||
| } | |||
| MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) { | |||
| const AccessType* frag_ptr = reinterpret_cast<const AccessType*>(&frag); | |||
| Type* pointer_ = pointer; | |||
| #pragma unroll | |||
| for (int i = 0; i < width; ++i) { | |||
| #pragma unroll | |||
| for (int j = 0; j < chan_blk; j += pack_size) { | |||
| int frag_idx = i * (chan_blk / pack_size) + (j / pack_size); | |||
| bool guard = j < channel; | |||
| global_store<AccessType, pack_size_in_byte>( | |||
| frag_ptr[frag_idx], | |||
| reinterpret_cast<void*>( | |||
| pointer_ + j * size_nbits / (8 * sizeof(Type))), | |||
| guard); | |||
| } | |||
| pointer_ += hw_stride_in_elements; | |||
| } | |||
| } | |||
| MEGDNN_DEVICE __forceinline__ void advance() { | |||
| pointer += chan_blk * size_nbits / (8 * sizeof(Type)); | |||
| channel -= chan_blk; | |||
| } | |||
| private: | |||
| Type* pointer; | |||
| int hw_stride_in_elements; | |||
| int channel; | |||
| }; | |||
| template <typename Type_, int pack_size_, int chan_blk_, int width_, | |||
| int size_nbits_, uint32_t layout_type_ = LayoutType::NCHWx> | |||
| class MaskedTensorIteratorOverChannel; | |||
| template <typename Type_, int pack_size_, int chan_blk_, int width_, | |||
| int size_nbits_, uint32_t layout_type_> | |||
| class MaskedTensorIteratorOverChannel { | |||
| public: | |||
| using Type = Type_; | |||
| @@ -243,24 +344,143 @@ private: | |||
| size_t stride[lane_size_in_type / pack_size_in_type]; | |||
| }; | |||
| template <typename Type_, int pack_size_, int chan_blk_, int width_, | |||
| int size_nbits_> | |||
| class MaskedTensorIteratorOverChannel<Type_, pack_size_, chan_blk_, width_, | |||
| size_nbits_, LayoutType::NHWC> { | |||
| public: | |||
| using Type = Type_; | |||
| static constexpr int pack_size = pack_size_; | |||
| static constexpr int chan_blk = chan_blk_; | |||
| static constexpr int width = width_; | |||
| static constexpr int size_nbits = size_nbits_; | |||
| static constexpr int elements_in_type = | |||
| chan_blk * width * size_nbits / (8 * sizeof(Type)); | |||
| static constexpr int lane_size_in_type = | |||
| (width * pack_size * size_nbits) / (8 * sizeof(Type)); | |||
| static constexpr int pack_size_in_type = | |||
| pack_size * size_nbits / (8 * sizeof(Type)); | |||
| static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type); | |||
| static constexpr int accesses = elements_in_type / pack_size_in_type; | |||
| static constexpr int mask_size = (accesses + 32 - 1) / 32; | |||
| using AccessType = array_wrapper<Type, pack_size_in_type>; | |||
| using Fragment = array_wrapper<Type, elements_in_type>; | |||
| MEGDNN_HOST MaskedTensorIteratorOverChannel() | |||
| : pointer{nullptr}, hw_stride_in_elements{0}, channel{0} {} | |||
| MEGDNN_HOST MaskedTensorIteratorOverChannel(Type* pointer_, | |||
| int hw_stride_in_elements_, | |||
| int channel_, int bound_, | |||
| int div_) | |||
| : pointer{pointer_}, | |||
| hw_stride_in_elements{hw_stride_in_elements_}, | |||
| channel{channel_}, | |||
| bound{bound_}, | |||
| div{uint32_t(div_)} {} | |||
| MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) { | |||
| pointer += c_idx * size_nbits / (8 * sizeof(Type)); | |||
| channel -= c_idx; | |||
| #pragma unroll | |||
| for (int i = 0; i < mask_size; ++i) { | |||
| mask[i] = 0; | |||
| } | |||
| #pragma unroll | |||
| for (int i = 0; i < width; ++i) { | |||
| int offset = hw_idx + i; | |||
| int h = (int)((uint32_t)(offset) / div); | |||
| int w = (int)((uint32_t)(offset) % div); | |||
| stride[i] = (h * bound + w) * hw_stride_in_elements; | |||
| #pragma unroll | |||
| for (int j = 0; j < chan_blk; j += pack_size) { | |||
| bool guard = (j < channel) && (w < bound); | |||
| int index = i * (chan_blk / pack_size) + (j / pack_size); | |||
| int mask_index = (index >> 5); | |||
| int mask_shift = (index & 0x1f); | |||
| mask[mask_index] |= (guard << mask_shift); | |||
| } | |||
| } | |||
| } | |||
| MEGDNN_DEVICE __forceinline__ void add_pointer_offset( | |||
| size_t offset_in_type) { | |||
| pointer += offset_in_type; | |||
| } | |||
| MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) { | |||
| AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag); | |||
| #pragma unroll | |||
| for (int i = 0; i < width; ++i) { | |||
| Type* pointer_ = pointer + stride[i]; | |||
| #pragma unroll | |||
| for (int j = 0; j < chan_blk; j+= pack_size) { | |||
| int frag_idx = i * (chan_blk / pack_size) + (j / pack_size); | |||
| int mask_index = (frag_idx >> 5); | |||
| int mask_shift = (frag_idx & 0x1f); | |||
| bool guard = (mask[mask_index] & (1 << mask_shift)); | |||
| global_load<AccessType, pack_size_in_byte>( | |||
| frag_ptr[frag_idx], | |||
| reinterpret_cast<void*>( | |||
| pointer_ + j * size_nbits / (8 * sizeof(Type))), | |||
| guard, zero_point); | |||
| } | |||
| } | |||
| } | |||
| MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) { | |||
| const AccessType* frag_ptr = reinterpret_cast<const AccessType*>(&frag); | |||
| #pragma unroll | |||
| for (int i = 0; i < width; ++i) { | |||
| Type* pointer_ = pointer + stride[i]; | |||
| #pragma unroll | |||
| for (int j = 0; j < chan_blk; j+= pack_size) { | |||
| int frag_idx = i * (chan_blk / pack_size) + (j / pack_size); | |||
| int mask_index = (frag_idx >> 5); | |||
| int mask_shift = (frag_idx & 0x1f); | |||
| bool guard = (mask[mask_index] & (1 << mask_shift)); | |||
| global_store<AccessType, pack_size_in_byte>( | |||
| frag_ptr[frag_idx], | |||
| reinterpret_cast<void*>( | |||
| pointer_ + j * size_nbits / (8 * sizeof(Type))), | |||
| guard); | |||
| } | |||
| } | |||
| } | |||
| MEGDNN_DEVICE __forceinline__ void advance() { | |||
| pointer += chan_blk * size_nbits / (8 * sizeof(Type)); | |||
| channel -= chan_blk; | |||
| } | |||
| private: | |||
| Type* pointer; | |||
| int hw_stride_in_elements; | |||
| int channel; | |||
| int bound; | |||
| Uint32Fastdiv div; | |||
| uint32_t mask[mask_size]; | |||
| size_t stride[width]; | |||
| }; | |||
| template <bool padding_, typename Type_, int pack_size_, int chan_blk_, | |||
| int width_, int size_nbits_> | |||
| int width_, int size_nbits_, | |||
| uint32_t layout_type_ = LayoutType::NCHWx> | |||
| struct TensorIteratorPolicy; | |||
| template <typename Type_, int pack_size_, int chan_blk_, int width_, | |||
| int size_nbits_> | |||
| int size_nbits_, uint32_t layout_type_> | |||
| struct TensorIteratorPolicy<true, Type_, pack_size_, chan_blk_, width_, | |||
| size_nbits_> { | |||
| size_nbits_, layout_type_> { | |||
| using TensorIterator = | |||
| MaskedTensorIteratorOverChannel<Type_, pack_size_, chan_blk_, | |||
| width_, size_nbits_>; | |||
| width_, size_nbits_, layout_type_>; | |||
| }; | |||
| template <typename Type_, int pack_size_, int chan_blk_, int width_, | |||
| int size_nbits_> | |||
| int size_nbits_, uint32_t layout_type_> | |||
| struct TensorIteratorPolicy<false, Type_, pack_size_, chan_blk_, width_, | |||
| size_nbits_> { | |||
| size_nbits_, layout_type_> { | |||
| using TensorIterator = | |||
| TensorIteratorOverChannel<Type_, pack_size_, chan_blk_, width_, | |||
| size_nbits_>; | |||
| size_nbits_, layout_type_>; | |||
| }; | |||
| template <typename SrcIterator_, typename DstIterator_, typename Transpose_, | |||
| @@ -0,0 +1,211 @@ | |||
| /** | |||
| * \file dnn/src/cuda/relayout_format/relayout_format_nchw_nhwc.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/cuda/query_blocksize.cuh" | |||
| #include "src/cuda/relayout_format/relayout_format_kern.cuh" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace relayout_format; | |||
| using namespace internal; | |||
| namespace { | |||
| template <int pack_w> | |||
| struct rwtype_helper; | |||
| template <> | |||
| struct rwtype_helper<2> { | |||
| using InnerDtype = char; | |||
| }; | |||
| template <> | |||
| struct rwtype_helper<8> { | |||
| using InnerDtype = unsigned; | |||
| }; | |||
| } // namespace | |||
| void relayout_format::relayout_format_cuda_nchw_nhwc( | |||
| const TensorND& src, const TensorND& dst, const cudaStream_t& stream, | |||
| const float src_scale, const float dst_scale, | |||
| const uint8_t src_zero_point, const uint8_t dst_zero_point) { | |||
| auto&& stype = src.layout.dtype; | |||
| auto&& dtype = dst.layout.dtype; | |||
| auto& src_layout = src.layout; | |||
| auto& dst_layout = dst.layout; | |||
| int n = src.layout[0]; | |||
| int ic = src.layout[1]; | |||
| int h = src.layout[2]; | |||
| int w = src.layout[3]; | |||
| int w_pad = DIVUP(w, 2) * 2; | |||
| int hw = h * w_pad; | |||
| int n_stride_src = src_layout.stride[0]; | |||
| int ic_stride = src_layout.stride[1]; | |||
| int n_stride_dst = dst_layout.stride[0]; | |||
| int hw_stride = dst_layout.stride[2]; | |||
| static constexpr int chan_blk = 8; | |||
| static constexpr int pack_oc = 8; | |||
| int problem_size = n * DIVUP(ic, chan_blk) * hw; | |||
| int oc = dst.layout[3]; | |||
| bool same_scale = src_scale == dst_scale; | |||
| bool padding = w % 2 != 0; | |||
| #define DISPATCH_RAW(_padding, _same_scale, _pack_w, _src_type, _dst_type, \ | |||
| _src_c_type, _dst_c_type, _size_nbits) \ | |||
| if (padding == _padding && same_scale == _same_scale && \ | |||
| hw % _pack_w == 0 && stype.enumv().ev == DTypeEnum::Ev::_src_type && \ | |||
| dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \ | |||
| using InnerDtype_ = typename rwtype_helper<_pack_w>::InnerDtype; \ | |||
| using SrcIterator_ = \ | |||
| TensorIteratorOverChannel<InnerDtype_, 1, chan_blk, _pack_w, \ | |||
| _size_nbits>; \ | |||
| using DstIterator_ = typename TensorIteratorPolicy< \ | |||
| _padding, _dst_c_type, pack_oc, chan_blk, _pack_w, \ | |||
| _size_nbits, LayoutType::NHWC>::TensorIterator; \ | |||
| using CudaPostProcess_ = \ | |||
| CudaPostProcess<dtype::_src_type, dtype::_dst_type, \ | |||
| _same_scale>; \ | |||
| using Transpose_ = \ | |||
| Translayout<_pack_w, chan_blk, InnerDtype_, dtype::_src_type, \ | |||
| dtype::_dst_type, _same_scale>; \ | |||
| using RelayoutProblem_ = \ | |||
| RelayoutProblem<SrcIterator_, DstIterator_, Transpose_, \ | |||
| CudaPostProcess_>; \ | |||
| n_stride_src = n_stride_src * _size_nbits / (8 * sizeof(InnerDtype_)); \ | |||
| ic_stride = ic_stride * _size_nbits / (8 * sizeof(InnerDtype_)); \ | |||
| n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(_dst_c_type)); \ | |||
| hw_stride = hw_stride * _size_nbits / (8 * sizeof(_dst_c_type)); \ | |||
| typename RelayoutProblem_::Param param{ \ | |||
| SrcIterator_{(InnerDtype_*)src.raw_ptr, ic_stride, ic, w, \ | |||
| w_pad}, \ | |||
| DstIterator_{(_dst_c_type*)dst.raw_ptr, hw_stride, oc, w, \ | |||
| w_pad}, \ | |||
| CudaPostProcess_{src_scale, src_zero_point, dst_scale, \ | |||
| dst_zero_point}, \ | |||
| n_stride_src, \ | |||
| n_stride_dst, \ | |||
| n, \ | |||
| ic, \ | |||
| hw, \ | |||
| src_zero_point}; \ | |||
| auto kernel = relayout_kern<RelayoutProblem_>; \ | |||
| int nr_threads = query_blocksize_for_kernel(kernel); \ | |||
| nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w)); \ | |||
| const dim3 block_dim(DIVUP(problem_size, nr_threads* _pack_w)); \ | |||
| const dim3 thread_dim(nr_threads); \ | |||
| return kernel<<<block_dim, thread_dim, 0, stream>>>(param); \ | |||
| } | |||
| #define DISPATCH_4BITS(_src_type, _dst_type) \ | |||
| DISPATCH_RAW(true, true, 8, _src_type, _dst_type, char, char, 4); \ | |||
| DISPATCH_RAW(true, false, 8, _src_type, _dst_type, char, char, 4); \ | |||
| DISPATCH_RAW(true, true, 2, _src_type, _dst_type, char, char, 4); \ | |||
| DISPATCH_RAW(true, false, 2, _src_type, _dst_type, char, char, 4); \ | |||
| DISPATCH_RAW(false, true, 8, _src_type, _dst_type, char, char, 4); \ | |||
| DISPATCH_RAW(false, false, 8, _src_type, _dst_type, char, char, 4); \ | |||
| DISPATCH_RAW(false, true, 2, _src_type, _dst_type, char, char, 4); \ | |||
| DISPATCH_RAW(false, false, 2, _src_type, _dst_type, char, char, 4); | |||
| DISPATCH_4BITS(QuantizedS4, QuantizedS4); | |||
| DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm); | |||
| #undef DISPATCH_4BITS | |||
| #undef DISPATCH_RAW | |||
| megdnn_assert(false, | |||
| "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).", | |||
| stype.name(), dtype.name(), h, w); | |||
| } | |||
| void relayout_format::relayout_format_cuda_nhwc_nchw( | |||
| const TensorND& src, const TensorND& dst, const cudaStream_t& stream, | |||
| const float src_scale, const float dst_scale, | |||
| const uint8_t src_zero_point, const uint8_t dst_zero_point) { | |||
| auto&& stype = src.layout.dtype; | |||
| auto&& dtype = dst.layout.dtype; | |||
| auto& src_layout = src.layout; | |||
| auto& dst_layout = dst.layout; | |||
| int n = src.layout[0]; | |||
| int h = src.layout[1]; | |||
| int w = src.layout[2]; | |||
| int ic = src.layout[3]; | |||
| int w_pad = DIVUP(w, 2) * 2; | |||
| int hw = h * w_pad; | |||
| int n_stride_src = src_layout.stride[0]; | |||
| int hw_stride = src_layout.stride[2]; | |||
| int n_stride_dst = dst_layout.stride[0]; | |||
| int oc_stride = dst_layout.stride[1]; | |||
| static constexpr int chan_blk = 8; | |||
| static constexpr int pack_oc = 8; | |||
| int problem_size = n * DIVUP(ic, chan_blk) * hw; | |||
| int oc = dst.layout[1]; | |||
| bool same_scale = src_scale == dst_scale; | |||
| bool padding = w % 2 != 0; | |||
| #define DISPATCH_RAW(_padding, _same_scale, _pack_w, _src_type, _dst_type, \ | |||
| _src_c_type, _dst_c_type, _size_nbits) \ | |||
| if (padding == _padding && same_scale == _same_scale && \ | |||
| hw % _pack_w == 0 && stype.enumv().ev == DTypeEnum::Ev::_src_type && \ | |||
| dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \ | |||
| using SrcIterator_ = typename TensorIteratorPolicy< \ | |||
| _padding, _src_c_type, pack_oc, chan_blk, _pack_w, \ | |||
| _size_nbits, LayoutType::NHWC>::TensorIterator; \ | |||
| using InnerDtype_ = typename rwtype_helper<_pack_w>::InnerDtype; \ | |||
| using DstIterator_ = \ | |||
| TensorIteratorOverChannel<InnerDtype_, 1, chan_blk, _pack_w, \ | |||
| _size_nbits>; \ | |||
| using CudaPostProcess_ = \ | |||
| CudaPostProcess<dtype::_src_type, dtype::_dst_type, \ | |||
| _same_scale>; \ | |||
| using Transpose_ = \ | |||
| Translayout<chan_blk, _pack_w, _src_c_type, dtype::_src_type, \ | |||
| dtype::_dst_type, _same_scale>; \ | |||
| using RelayoutProblem_ = \ | |||
| RelayoutProblem<SrcIterator_, DstIterator_, Transpose_, \ | |||
| CudaPostProcess_>; \ | |||
| n_stride_src = n_stride_src * _size_nbits / (8 * sizeof(_src_c_type)); \ | |||
| hw_stride = hw_stride * _size_nbits / (8 * sizeof(_src_c_type)); \ | |||
| n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(InnerDtype_)); \ | |||
| oc_stride = oc_stride * _size_nbits / (8 * sizeof(InnerDtype_)); \ | |||
| typename RelayoutProblem_::Param param{ \ | |||
| SrcIterator_{(_src_c_type*)src.raw_ptr, hw_stride, ic, w, \ | |||
| w_pad}, \ | |||
| DstIterator_{(InnerDtype_*)dst.raw_ptr, oc_stride, oc, w, \ | |||
| w_pad}, \ | |||
| CudaPostProcess_{src_scale, src_zero_point, dst_scale, \ | |||
| dst_zero_point}, \ | |||
| n_stride_src, \ | |||
| n_stride_dst, \ | |||
| n, \ | |||
| ic, \ | |||
| hw, \ | |||
| src_zero_point}; \ | |||
| auto kernel = relayout_kern<RelayoutProblem_>; \ | |||
| int nr_threads = query_blocksize_for_kernel(kernel); \ | |||
| nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w)); \ | |||
| const dim3 block_dim(DIVUP(problem_size, nr_threads* _pack_w)); \ | |||
| const dim3 thread_dim(nr_threads); \ | |||
| return kernel<<<block_dim, thread_dim, 0, stream>>>(param); \ | |||
| } | |||
| #define DISPATCH_4BITS(_src_type, _dst_type) \ | |||
| DISPATCH_RAW(true, true, 8, _src_type, _dst_type, char, char, 4); \ | |||
| DISPATCH_RAW(true, false, 8, _src_type, _dst_type, char, char, 4); \ | |||
| DISPATCH_RAW(true, true, 2, _src_type, _dst_type, char, char, 4); \ | |||
| DISPATCH_RAW(true, false, 2, _src_type, _dst_type, char, char, 4); \ | |||
| DISPATCH_RAW(false, true, 8, _src_type, _dst_type, char, char, 4); \ | |||
| DISPATCH_RAW(false, false, 8, _src_type, _dst_type, char, char, 4); \ | |||
| DISPATCH_RAW(false, true, 2, _src_type, _dst_type, char, char, 4); \ | |||
| DISPATCH_RAW(false, false, 2, _src_type, _dst_type, char, char, 4); | |||
| DISPATCH_4BITS(QuantizedS4, QuantizedS4); | |||
| DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm); | |||
| #undef DISPATCH_4BITS | |||
| #undef DISPATCH_RAW | |||
| megdnn_assert(false, | |||
| "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).", | |||
| stype.name(), dtype.name(), h, w); | |||
| } | |||
| @@ -42,8 +42,9 @@ struct enable_qtype_b4 { | |||
| static constexpr bool val_dst = | |||
| std::is_same<dt_dst, dtype::QuantizedS4>::value || | |||
| std::is_same<dt_dst, dtype::Quantized4Asymm>::value; | |||
| using type = typename std::enable_if<std::is_same<dt_src, dt_dst>::value && | |||
| val_src && val_dst>::type; | |||
| static constexpr bool value = | |||
| std::is_same<dt_src, dt_dst>::value && val_src && val_dst; | |||
| using type = typename std::enable_if<value>::type; | |||
| }; | |||
| // The input fragment is stored in RowMajor order. The translayout operator | |||
| @@ -393,26 +394,32 @@ struct Translayout<2, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
| using Fragment = array_wrapper<SrcType, elements_in_type>; | |||
| static inline __device__ void trans( | |||
| Fragment& dst, const Fragment& src, | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
| const char zero_point) { | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) { | |||
| int intermediate[8][2]; | |||
| transform_b4x2_to_int8<signedness>(intermediate[0], | |||
| reinterpret_cast<uint8_t&>(src[0])); | |||
| transform_b4x2_to_int8<signedness>(intermediate[1], | |||
| reinterpret_cast<uint8_t&>(src[1])); | |||
| transform_b4x2_to_int8<signedness>(intermediate[2], | |||
| reinterpret_cast<uint8_t&>(src[2])); | |||
| transform_b4x2_to_int8<signedness>(intermediate[3], | |||
| reinterpret_cast<uint8_t&>(src[3])); | |||
| transform_b4x2_to_int8<signedness>(intermediate[4], | |||
| reinterpret_cast<uint8_t&>(src[4])); | |||
| transform_b4x2_to_int8<signedness>(intermediate[5], | |||
| reinterpret_cast<uint8_t&>(src[5])); | |||
| transform_b4x2_to_int8<signedness>(intermediate[6], | |||
| reinterpret_cast<uint8_t&>(src[6])); | |||
| transform_b4x2_to_int8<signedness>(intermediate[7], | |||
| reinterpret_cast<uint8_t&>(src[7])); | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[0], | |||
| reinterpret_cast<const uint8_t&>(src[0 * col_in_type])); | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[1], | |||
| reinterpret_cast<const uint8_t&>(src[1 * col_in_type])); | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[2], | |||
| reinterpret_cast<const uint8_t&>(src[2 * col_in_type])); | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[3], | |||
| reinterpret_cast<const uint8_t&>(src[3 * col_in_type])); | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[4], | |||
| reinterpret_cast<const uint8_t&>(src[4 * col_in_type])); | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[5], | |||
| reinterpret_cast<const uint8_t&>(src[5 * col_in_type])); | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[6], | |||
| reinterpret_cast<const uint8_t&>(src[6 * col_in_type])); | |||
| transform_b4x2_to_int8<signedness>( | |||
| intermediate[7], | |||
| reinterpret_cast<const uint8_t&>(src[7 * col_in_type])); | |||
| int* dst_frag = reinterpret_cast<int*>(&dst); | |||
| auto pack = [&](int idx) -> int { | |||
| return transform_int8_to_b4x8<signedness>( | |||
| @@ -445,25 +452,24 @@ struct Translayout<8, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
| using Fragment = array_wrapper<SrcType, elements_in_type>; | |||
| static inline __device__ void trans( | |||
| Fragment& dst, const Fragment& src, | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
| const char zero_point) { | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) { | |||
| int intermediate[8][8]; | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[0], reinterpret_cast<const int&>(src[0])); | |||
| intermediate[0], reinterpret_cast<const int&>(src[0 * col_in_type])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[1], reinterpret_cast<const int&>(src[1])); | |||
| intermediate[1], reinterpret_cast<const int&>(src[1 * col_in_type])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[2], reinterpret_cast<const int&>(src[2])); | |||
| intermediate[2], reinterpret_cast<const int&>(src[2 * col_in_type])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[3], reinterpret_cast<const int&>(src[3])); | |||
| intermediate[3], reinterpret_cast<const int&>(src[3 * col_in_type])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[4], reinterpret_cast<const int&>(src[4])); | |||
| intermediate[4], reinterpret_cast<const int&>(src[4 * col_in_type])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[5], reinterpret_cast<const int&>(src[5])); | |||
| intermediate[5], reinterpret_cast<const int&>(src[5 * col_in_type])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[6], reinterpret_cast<const int&>(src[6])); | |||
| intermediate[6], reinterpret_cast<const int&>(src[6 * col_in_type])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[7], reinterpret_cast<const int&>(src[7])); | |||
| intermediate[7], reinterpret_cast<const int&>(src[7 * col_in_type])); | |||
| int* dst_frag = reinterpret_cast<int*>(&dst); | |||
| auto pack = [&](int idx) { | |||
| return transform_int8_to_b4x8<signedness>( | |||
| @@ -502,13 +508,12 @@ struct Translayout<8, 2, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
| using Fragment = array_wrapper<SrcType, elements_in_type>; | |||
| static inline __device__ void trans( | |||
| Fragment& dst, const Fragment& src, | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
| const char zero_point) { | |||
| CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) { | |||
| int intermediate[2][8]; | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[0], reinterpret_cast<const int&>(src[0])); | |||
| intermediate[0], reinterpret_cast<const int&>(src[0 * col_in_type])); | |||
| transform_b4x8_to_int8<signedness>( | |||
| intermediate[1], reinterpret_cast<const int&>(src[1])); | |||
| intermediate[1], reinterpret_cast<const int&>(src[1 * col_in_type])); | |||
| int* dst_frag = reinterpret_cast<int*>(&dst); | |||
| dst_frag[0] = transform_int8_to_b4x8<signedness>( | |||
| post_process(intermediate[0][0]), | |||
| @@ -508,7 +508,7 @@ struct KernCoreNHWC<ctype, OutputConverter, 8> { | |||
| "assert qu4 or q4"); | |||
| constexpr bool signedness = std::is_same<ctype, dt_qint4>::value; | |||
| int8_t bval_4 = bval.as_storage() & 0xF; | |||
| const int bval_int = transform_int8_to_bit4x8<signedness>( | |||
| const int bval_int = transform_int8_to_b4x8<signedness>( | |||
| bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4); | |||
| int src_ori[4]; | |||
| src_ori[0] = src0_ok ? *(int*)(src_ptr0 + offset) : bval_int; | |||
| @@ -516,10 +516,10 @@ struct KernCoreNHWC<ctype, OutputConverter, 8> { | |||
| src_ori[2] = src2_ok ? *(int*)(src_ptr2 + offset) : bval_int; | |||
| src_ori[3] = src3_ok ? *(int*)(src_ptr3 + offset) : bval_int; | |||
| int src[4][8]; | |||
| transform_bit4x8_to_int8<signedness>(src[0], src_ori[0]); | |||
| transform_bit4x8_to_int8<signedness>(src[1], src_ori[1]); | |||
| transform_bit4x8_to_int8<signedness>(src[2], src_ori[2]); | |||
| transform_bit4x8_to_int8<signedness>(src[3], src_ori[3]); | |||
| transform_b4x8_to_int8<signedness>(src[0], src_ori[0]); | |||
| transform_b4x8_to_int8<signedness>(src[1], src_ori[1]); | |||
| transform_b4x8_to_int8<signedness>(src[2], src_ori[2]); | |||
| transform_b4x8_to_int8<signedness>(src[3], src_ori[3]); | |||
| int res = pack_output_func<signedness>(output_converter, src[0], src[1], | |||
| src[2], src[3], w00, w01, w10, | |||
| w11); | |||
| @@ -542,7 +542,7 @@ struct KernCoreNHWC<ctype, OutputConverter, 16> { | |||
| "assert qu4 or q4"); | |||
| constexpr bool signedness = std::is_same<ctype, dt_qint4>::value; | |||
| int8_t bval_4 = bval.as_storage() & 0xF; | |||
| const int bval_int_temp = transform_int8_to_bit4x8<signedness>( | |||
| const int bval_int_temp = transform_int8_to_b4x8<signedness>( | |||
| bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4); | |||
| const int2 bval_int{bval_int_temp, bval_int_temp}; | |||
| @@ -552,15 +552,15 @@ struct KernCoreNHWC<ctype, OutputConverter, 16> { | |||
| src_ori[2] = src2_ok ? *(int2*)(src_ptr2 + offset) : bval_int; | |||
| src_ori[3] = src3_ok ? *(int2*)(src_ptr3 + offset) : bval_int; | |||
| int src[8][8]; | |||
| transform_bit4x8_to_int8<signedness>(src[0], src_ori[0].x); | |||
| transform_bit4x8_to_int8<signedness>(src[1], src_ori[1].x); | |||
| transform_bit4x8_to_int8<signedness>(src[2], src_ori[2].x); | |||
| transform_bit4x8_to_int8<signedness>(src[3], src_ori[3].x); | |||
| transform_bit4x8_to_int8<signedness>(src[4], src_ori[0].y); | |||
| transform_bit4x8_to_int8<signedness>(src[5], src_ori[1].y); | |||
| transform_bit4x8_to_int8<signedness>(src[6], src_ori[2].y); | |||
| transform_bit4x8_to_int8<signedness>(src[7], src_ori[3].y); | |||
| transform_b4x8_to_int8<signedness>(src[0], src_ori[0].x); | |||
| transform_b4x8_to_int8<signedness>(src[1], src_ori[1].x); | |||
| transform_b4x8_to_int8<signedness>(src[2], src_ori[2].x); | |||
| transform_b4x8_to_int8<signedness>(src[3], src_ori[3].x); | |||
| transform_b4x8_to_int8<signedness>(src[4], src_ori[0].y); | |||
| transform_b4x8_to_int8<signedness>(src[5], src_ori[1].y); | |||
| transform_b4x8_to_int8<signedness>(src[6], src_ori[2].y); | |||
| transform_b4x8_to_int8<signedness>(src[7], src_ori[3].y); | |||
| int2 res; | |||
| res.x = pack_output_func<signedness>(output_converter, src[0], src[1], | |||
| @@ -325,6 +325,91 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW64_NCHW) { | |||
| } | |||
| } | |||
| TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NHWC) { | |||
| Checker<RelayoutFormat> checker(handle_cuda()); | |||
| UniformIntRNG s4{-8, 7}; | |||
| UniformIntRNG u4{0, 15}; | |||
| param::RelayoutFormat param; | |||
| param.mode = param::RelayoutFormat::Mode::NCHW_NHWC; | |||
| for (size_t n : {1, 3}) { | |||
| for (size_t c : {8, 16}) { | |||
| for (size_t h : {7, 14, 16, 28}) { | |||
| for (size_t w : {2, 3, 7, 8, 16, 31}) { | |||
| checker.set_dtype(0, dtype::QuantizedS4{2.f}) | |||
| .set_dtype(1, dtype::QuantizedS4{2.f}) | |||
| .set_rng(0, &s4) | |||
| .set_param(param) | |||
| .execs({{n, c, h, w}, {}}); | |||
| checker.set_dtype(0, dtype::Quantized4Asymm{1.2f, 8}) | |||
| .set_dtype(1, dtype::Quantized4Asymm{1.2f, 4}) | |||
| .set_rng(0, &u4) | |||
| .set_param(param) | |||
| .execs({{n, c, h, w}, {}}); | |||
| checker.set_dtype(0, dtype::QuantizedS4{1.19990307f}) | |||
| .set_dtype(1, dtype::QuantizedS4{1.f}) | |||
| .set_rng(0, &s4) | |||
| .set_param(param) | |||
| .execs({{n, c, h, w}, {}}); | |||
| checker.set_dtype(0, dtype::Quantized4Asymm{1.19990307f, 8}) | |||
| .set_dtype(1, dtype::Quantized4Asymm{1.f, 4}) | |||
| .set_rng(0, &u4) | |||
| .set_param(param) | |||
| .set_epsilon(1e-3) | |||
| .execs({{n, c, h, w}, {}}); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| checker.execs({{1, 256, 384, 640}, {}}); | |||
| } | |||
| TEST_F(CUDA, RELAYOUT_FORMAT_NHWC_NCHW) { | |||
| Checker<RelayoutFormat> checker(handle_cuda()); | |||
| UniformIntRNG s4{-8, 7}; | |||
| UniformIntRNG u4{0, 15}; | |||
| param::RelayoutFormat param; | |||
| param.mode = param::RelayoutFormat::Mode::NHWC_NCHW; | |||
| for (size_t n : {1, 3}) { | |||
| for (size_t c : {8, 16}) { | |||
| for (size_t h : {7, 14, 16, 28}) { | |||
| for (size_t w : {2, 3, 4, 7, 14, 16, 17}) { | |||
| checker.set_dtype(0, dtype::QuantizedS4{2.f}) | |||
| .set_dtype(1, dtype::QuantizedS4{2.f}) | |||
| .set_rng(0, &s4) | |||
| .set_param(param) | |||
| .set_epsilon(1e-3) | |||
| .execs({{n, h, w, c}, {}}); | |||
| checker.set_dtype(0, dtype::Quantized4Asymm{1.2f, 4}) | |||
| .set_dtype(1, dtype::Quantized4Asymm{1.2f, 8}) | |||
| .set_rng(0, &u4) | |||
| .set_param(param) | |||
| .set_epsilon(1e-3) | |||
| .execs({{n, h, w, c}, {}}); | |||
| checker.set_dtype(0, dtype::QuantizedS4{1.19990307f}) | |||
| .set_dtype(1, dtype::QuantizedS4{1.f}) | |||
| .set_rng(0, &s4) | |||
| .set_param(param) | |||
| .set_epsilon(1e-3) | |||
| .execs({{n, h, w, c}, {}}); | |||
| checker.set_dtype(0, dtype::Quantized4Asymm{1.20211209f, 8}) | |||
| .set_dtype(1, dtype::Quantized4Asymm{1.f, 4}) | |||
| .set_rng(0, &u4) | |||
| .set_param(param) | |||
| .set_epsilon(1e-3) | |||
| .execs({{n, h, w, c}, {}}); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| checker.execs({{1, 384, 640, 256}, {}}); | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT) { | |||
| using Param = RelayoutFormat::Param; | |||
| @@ -393,6 +478,7 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) { | |||
| } | |||
| }; | |||
| printf("nchw -> nchw64\n"); | |||
| { | |||
| TensorShapeArray shapes = { | |||
| {1, 64, 56, 56}, {16, 64, 56, 56}, {64, 64, 56, 56}, | |||
| @@ -403,6 +489,18 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) { | |||
| param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64; | |||
| run(shapes, param); | |||
| } | |||
| printf("nchw -> nhwc\n"); | |||
| { | |||
| TensorShapeArray shapes = { | |||
| {1, 64, 56, 56}, {16, 64, 56, 56}, {64, 64, 56, 56}, | |||
| {1, 64, 56, 55}, {16, 64, 56, 55}, {64, 64, 56, 55}, | |||
| {1, 256, 384, 640}, {16, 16, 384, 640}, | |||
| }; | |||
| Param param; | |||
| param.mode = param::RelayoutFormat::Mode::NCHW_NHWC; | |||
| run(shapes, param); | |||
| } | |||
| printf("nchw64 -> nchw\n"); | |||
| { | |||
| TensorShapeArray shapes = { | |||
| {64, 1, 56, 56, 64}, | |||
| @@ -415,6 +513,19 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) { | |||
| param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW; | |||
| run(shapes, param); | |||
| } | |||
| printf("nhwc -> nchw\n"); | |||
| { | |||
| TensorShapeArray shapes = { | |||
| {64, 56, 56, 64}, | |||
| {1, 7, 7, 64*32}, | |||
| {16, 7, 7, 64*32}, | |||
| {64, 7, 7, 64*32}, | |||
| {1, 384, 640, 64*4}, | |||
| }; | |||
| Param param; | |||
| param.mode = param::RelayoutFormat::Mode::NHWC_NCHW; | |||
| run(shapes, param); | |||
| } | |||
| } | |||
| #endif | |||