From 3bf73ff16f27d11b1a5ff07762b6cff5d93b8a84 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 1 Dec 2020 17:26:17 +0800 Subject: [PATCH] feat(dnn): add cuda preprocess fusion GitOrigin-RevId: d789c99e59ce713a075061aacf6acdba78af43d3 --- dnn/include/megdnn/dtype.h | 4 + dnn/scripts/opr_param_defs.py | 1 + dnn/src/common/relayout.cpp | 7 +- dnn/src/common/relayout_format.cpp | 35 +- dnn/src/cuda/relayout_format/opr_impl.cpp | 32 +- .../cuda/relayout_format/relayout_format.cpp | 59 +++ .../cuda/relayout_format/relayout_format.cu | 371 +++++++++++++++ .../cuda/relayout_format/relayout_format.cuh | 37 ++ .../cuda/relayout_format/relayout_format.h | 34 ++ dnn/src/naive/relayout_format/opr_impl.cpp | 130 ++++- dnn/test/cuda/relayout_format.cpp | 165 ++++++- sdk/load-and-run/src/mgblar.cpp | 10 +- src/core/include/megbrain/graph/cg.h | 3 + src/gopt/impl/framework.cpp | 3 + src/gopt/impl/fuse_nchw4_int8_preprocess.cpp | 446 ++++++++++++++++++ src/gopt/include/megbrain/gopt/inference.h | 20 + src/gopt/test/inference.cpp | 106 ++++- src/opr/impl/basic_arith.cpp | 2 +- src/opr/impl/tensor_manip.cpp | 17 + test/src/helper.cpp | 18 + test/src/include/megbrain/test/helper.h | 31 ++ 21 files changed, 1499 insertions(+), 32 deletions(-) create mode 100644 dnn/src/cuda/relayout_format/relayout_format.cpp create mode 100644 dnn/src/cuda/relayout_format/relayout_format.cu create mode 100644 dnn/src/cuda/relayout_format/relayout_format.cuh create mode 100644 dnn/src/cuda/relayout_format/relayout_format.h create mode 100644 src/gopt/impl/fuse_nchw4_int8_preprocess.cpp diff --git a/dnn/include/megdnn/dtype.h b/dnn/include/megdnn/dtype.h index 9bd117a4..c3d022a1 100644 --- a/dnn/include/megdnn/dtype.h +++ b/dnn/include/megdnn/dtype.h @@ -201,6 +201,8 @@ class dt_quint8 { #endif bool operator<(const dt_quint8& b) const { return _ < b._; } bool operator>(const dt_quint8& b) const { return _ > b._; } + bool operator==(const dt_quint8& b) const { return _ == b._; } + bool operator!=(const dt_quint8& b) const { return _ != b._; } } MEGDNN_PACKED; class dt_qint32 { @@ -255,6 +257,8 @@ class dt_qint8 { #endif bool operator<(const dt_qint8& b) const { return _ < b._; } bool operator>(const dt_qint8& b) const { return _ > b._; } + bool operator==(const dt_qint8& b) const { return _ == b._; } + bool operator!=(const dt_qint8& b) const { return _ != b._; } } MEGDNN_PACKED; class dt_qint16 { diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 5a667ca2..c8fbeb3a 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -877,6 +877,7 @@ when the ``I`` suffix is present. 'NCHW88_NCHW', 'NCHW_NCHW4_IC_SMALL', 'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT', + 'NCHW_NCHW4', ) ) diff --git a/dnn/src/common/relayout.cpp b/dnn/src/common/relayout.cpp index d29857a9..2e2d8df7 100644 --- a/dnn/src/common/relayout.cpp +++ b/dnn/src/common/relayout.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #include "megdnn/oprs.h" @@ -94,7 +95,9 @@ void RelayoutForward::check_layout_and_canonize(TensorLayout& src, src = src.collapse_contiguous(); dst = dst.collapse_contiguous(); megdnn_assert(src.dtype == dst.dtype && - src.total_nr_elems() == dst.total_nr_elems()); + src.total_nr_elems() == dst.total_nr_elems(), + "check %s == %s and %zu == %zu", src.dtype.name(), + dst.dtype.name(), src.total_nr_elems(), dst.total_nr_elems()); } bool relayout::is_transpose(const TensorLayout& src, const TensorLayout& dst, diff --git a/dnn/src/common/relayout_format.cpp b/dnn/src/common/relayout_format.cpp index af82449d..e30ba58f 100644 --- a/dnn/src/common/relayout_format.cpp +++ b/dnn/src/common/relayout_format.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "megdnn/oprs.h" @@ -207,6 +208,15 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, dst[3] = src[2]; dst[4] = src[4]; break; + case Param::Mode::NCHW_NCHW4: + megdnn_assert(src.ndim == 4); + dst.ndim = 5; + dst[0] = src[0]; + dst[1] = div_ceil(src[1], 4); + dst[2] = src[2]; + dst[3] = src[3]; + dst[4] = 4; + break; default: megdnn_assert(0, "Invalid RelayoutFormat Mode"); break; @@ -214,7 +224,9 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, TensorFormat dst_fmt; deduce_format(src.format, dst_fmt); dst.format = dst_fmt; - dst.dtype = src.dtype; + if (!dst.dtype.valid()) { + dst.dtype = src.dtype; + } dst.init_contiguous_stride(); } @@ -245,6 +257,10 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { CHECK_SRC(DefaultTensorFormat::make()); dst = src; break; + case Param::Mode::NCHW_NCHW4: + CHECK_SRC(DefaultTensorFormat::make()); + dst = src; + break; case Param::Mode::NCHW_NHWCD4I: CHECK_SRC(DefaultTensorFormat::make()); dst = Image2DPack4TensorFormat::make_raw(2, align); @@ -322,6 +338,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { void RelayoutFormat::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) { TensorLayout dst_expected; + dst_expected.dtype = dst.dtype; deduce_layout_fwd(src, dst_expected); megdnn_assert_eq_layout(dst_expected, dst); } @@ -354,6 +371,19 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, exec_dst = dst; } break; + case Param::Mode::NCHW_NCHW4: + // nchw to nchw4 + { + TensorLayout work_space_layout( + {src[0], round_up(src[1], 4_z), src[2], src[3]}, + src.dtype, src.format); + exec_src = work_space_layout + .reshape({src[0], div_ceil(src[1], 4_z), 4, + src[2], src[3]}) + .dimshuffle({0, 1, 3, 4, 2}); + exec_dst = dst; + } + break; case Param::Mode::NCHW88_NCHW: // nchw8c to nchw exec_src = src; @@ -422,7 +452,6 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, } break; - case Param::Mode::NCHW_NHWCD4: case Param::Mode::NCHW_NHWCD4I: // src is {N, C, H, W} diff --git a/dnn/src/cuda/relayout_format/opr_impl.cpp b/dnn/src/cuda/relayout_format/opr_impl.cpp index abfc6a21..448f2fc5 100644 --- a/dnn/src/cuda/relayout_format/opr_impl.cpp +++ b/dnn/src/cuda/relayout_format/opr_impl.cpp @@ -6,11 +6,13 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either 
express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ -#include "src/cuda/relayout_format/opr_impl.h" #include "src/cuda/handle.h" +#include "src/cuda/relayout_format/opr_impl.h" +#include "src/cuda/relayout_format/relayout_format.h" #include "src/cuda/utils.h" using namespace megdnn; @@ -21,6 +23,7 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, auto src_dtype = src.layout.dtype; megdnn_assert( param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || + param().mode == param::RelayoutFormat::Mode::NCHW_NCHW4 || param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4 || param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL || param().mode == @@ -72,12 +75,25 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, return handle()->create_operator()->exec( {src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout}); } - TensorLayout exec_src, exec_dst; - deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst); - TensorND exec_src_nd{src.raw_ptr, exec_src}; - TensorND exec_dst_nd{dst.raw_ptr, exec_dst}; - handle()->create_operator()->exec(exec_src_nd, - exec_dst_nd); + + if (param().mode == Param::Mode::NCHW_NCHW4) { + bool is_usable = relayout_format::RelayoutFormatFast::usable( + src.layout, dst.layout); + megdnn_assert(is_usable, + "RelayoutFormatNCHW_NCHW4 kernel not usable for %s(%s) " + "to %s(%s)", + src.layout.to_string().c_str(), src.layout.dtype.name(), + dst.layout.to_string().c_str(), dst.layout.dtype.name()); + relayout_format::RelayoutFormatFast::exec(src, dst, + cuda_stream(this->handle())); + } else { + TensorLayout exec_src, exec_dst; + deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst); + TensorND exec_src_nd{src.raw_ptr, exec_src}; + TensorND exec_dst_nd{dst.raw_ptr, exec_dst}; + handle()->create_operator()->exec(exec_src_nd, + exec_dst_nd); + } } size_t RelayoutFormatImpl::get_workspace_in_bytes( diff --git a/dnn/src/cuda/relayout_format/relayout_format.cpp b/dnn/src/cuda/relayout_format/relayout_format.cpp new file mode 100644 index 00000000..ac9fb4c7 --- /dev/null +++ b/dnn/src/cuda/relayout_format/relayout_format.cpp @@ -0,0 +1,59 @@ +/** + * \file dnn/src/cuda/relayout_format/relayout_format.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "src/cuda/relayout_format/relayout_format.cuh" +#include "src/cuda/relayout_format/relayout_format.h" +using namespace megdnn; +using namespace cuda; + +namespace { + +inline void get_scale_zeropoint(const DType& tensor_dtype, float& scale, + uint8_t& zero_point) { + if (tensor_dtype.enumv() == DTypeEnum::Quantized8Asymm) { + zero_point = tensor_dtype.param().zero_point; + scale = tensor_dtype.param().scale; + } else if (tensor_dtype.enumv() == DTypeEnum::QuantizedS8) { + scale = tensor_dtype.param().scale; + } +} + +} // namespace + +bool relayout_format::RelayoutFormatFast::usable( + const TensorLayout& src_layout, const TensorLayout& dst_layout) { + return relayout_format_cuda_usable(src_layout, dst_layout); +} + +void relayout_format::RelayoutFormatFast::exec(const TensorND& src, + const TensorND& dst, + cudaStream_t stream) { + size_t ih = src.layout[2]; + size_t iw = src.layout[3]; + size_t hw = ih * iw; + float src_scale = 1.f; + float dst_scale = 1.f; + uint8_t src_zero_point = 0; + uint8_t dst_zero_point = 0; + get_scale_zeropoint(src.layout.dtype, src_scale, src_zero_point); + get_scale_zeropoint(dst.layout.dtype, dst_scale, dst_zero_point); + if (src.layout.dtype.enumv() == DTypeEnum::Uint8) { + src_zero_point = 128; + } + if (hw % 4 == 0) { + relayout_format_cuda_exec<4>(src, dst, stream, src_scale, dst_scale, + src_zero_point, dst_zero_point); + } else { + relayout_format_cuda_exec<1>(src, dst, stream, src_scale, dst_scale, + src_zero_point, dst_zero_point); + } +} diff --git a/dnn/src/cuda/relayout_format/relayout_format.cu b/dnn/src/cuda/relayout_format/relayout_format.cu new file mode 100644 index 00000000..aa068dca --- /dev/null +++ b/dnn/src/cuda/relayout_format/relayout_format.cu @@ -0,0 +1,371 @@ +/** + * \file dnn/src/cuda/relayout_format/relayout_format.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "src/cuda/query_blocksize.cuh" +#include "src/cuda/relayout_format/relayout_format.cuh" +using namespace megdnn; +using namespace cuda; + +namespace { + +template +struct CudaPostProcess; + +template <> +struct CudaPostProcess { + CudaPostProcess(float, uint8_t, float, uint8_t){}; + inline __device__ int8_t operator()(uint8_t val) { return val - 128; } +}; + +template <> +struct CudaPostProcess { + CudaDTypeParamImpl m_dst_type_cvt; + CudaPostProcess(float, uint8_t, float dst_scale, uint8_t) { + m_dst_type_cvt = CudaDTypeParamImpl(dst_scale); + }; + inline __device__ int8_t operator()(uint8_t val) { + return m_dst_type_cvt.quantize((float)val - 128.f).as_int8(); + } +}; + +template <> +struct CudaPostProcess { + CudaDTypeParamImpl m_dst_type_cvt; + CudaDTypeParamImpl m_src_type_cvt; + CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale, + uint8_t) { + m_dst_type_cvt = CudaDTypeParamImpl(dst_scale); + m_src_type_cvt = + CudaDTypeParamImpl(src_scale, src_zero_point); + }; + inline __device__ int8_t operator()(uint8_t val) { + float med_var = m_src_type_cvt.dequantize(dt_quint8(val)); + return m_dst_type_cvt.quantize(med_var).as_int8(); + } +}; + +template <> +struct CudaPostProcess { + uint8_t m_src_zero_point = 0; + CudaPostProcess(float, uint8_t src_zero_point, float, uint8_t) { + m_src_zero_point = src_zero_point; + }; + inline __device__ int8_t operator()(uint8_t val) { + return val - m_src_zero_point; + } +}; + +template <> +struct CudaPostProcess { + CudaDTypeParamImpl m_dst_type_cvt; + CudaDTypeParamImpl m_src_type_cvt; + CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) { + m_dst_type_cvt = CudaDTypeParamImpl(dst_scale); + m_src_type_cvt = CudaDTypeParamImpl(src_scale); + }; + inline __device__ int8_t operator()(int8_t val) { + float med_var = m_src_type_cvt.dequantize(dt_qint8(val)); + return m_dst_type_cvt.quantize(med_var).as_int8(); + } +}; + +template <> +struct CudaPostProcess { + CudaPostProcess(float, uint8_t, float, uint8_t){}; + inline __device__ int8_t operator()(int8_t val) { return val; } +}; + +template +struct DTypeRWHelper; +template <> +struct DTypeRWHelper { + using InnerDtype = char; + using DstDtype = char4; +}; + +template <> +struct DTypeRWHelper { + using InnerDtype = char4; + using DstDtype = char4; +}; + +template +struct Translayout { + using InnerDtype = typename DTypeRWHelper::InnerDtype; + using DstDtype = typename DTypeRWHelper::DstDtype; + static inline __device__ void trans(DstDtype (&dst_width)[pack_w], + InnerDtype (&read_channel)[pack_c], + const char zero_point); +}; + +template +struct Translayout<1, 4, SrcType, DnnSrcType, DnnDstType, same_scale> { + using InnerDtype = typename DTypeRWHelper::InnerDtype; + using DstDtype = typename DTypeRWHelper::DstDtype; + static inline __device__ void trans( + DstDtype (&dst_width)[1], InnerDtype (&read_channel)[4], + CudaPostProcess& post_process, + const char zero_point) { + dst_width[0].x = post_process(read_channel[0]); + dst_width[0].y = post_process(read_channel[1]); + dst_width[0].z = post_process(read_channel[2]); + dst_width[0].w = post_process(read_channel[3]); + } +}; + +template +struct Translayout<4, 4, SrcType, DnnSrcType, DnnDstType, same_scale> { + using InnerDtype = typename DTypeRWHelper::InnerDtype; + using DstDtype = typename DTypeRWHelper::DstDtype; + static inline __device__ void trans( + DstDtype (&dst_width)[4], InnerDtype (&read_channel)[4], + CudaPostProcess& post_process, + const char zero_point) { + dst_width[0].x = 
post_process(read_channel[0].x); + dst_width[0].y = post_process(read_channel[1].x); + dst_width[0].z = post_process(read_channel[2].x); + dst_width[0].w = post_process(read_channel[3].x); + + dst_width[1].x = post_process(read_channel[0].y); + dst_width[1].y = post_process(read_channel[1].y); + dst_width[1].z = post_process(read_channel[2].y); + dst_width[1].w = post_process(read_channel[3].y); + + dst_width[2].x = post_process(read_channel[0].z); + dst_width[2].y = post_process(read_channel[1].z); + dst_width[2].z = post_process(read_channel[2].z); + dst_width[2].w = post_process(read_channel[3].z); + + dst_width[3].x = post_process(read_channel[0].w); + dst_width[3].y = post_process(read_channel[1].w); + dst_width[3].z = post_process(read_channel[2].w); + dst_width[3].w = post_process(read_channel[3].w); + } +}; + +template +inline __device__ DstType make_zero_pad(const char zero_point) { + return zero_point; +} + +template <> +inline __device__ char4 make_zero_pad(const char zero_point) { + return {zero_point, zero_point, zero_point, zero_point}; +} + +template +inline __device__ void write_helper(DstDtype* ptr, DstDtype val) { + *ptr = val; +} + +template <> +inline __device__ void write_helper(char4* ptr, char4 val) { + int32_t* rel_ptr = (int32_t*)ptr; + *rel_ptr = *(int32_t*)(&val); +} + +template +struct RelayoutKern { + using InnerDtype = typename DTypeRWHelper::InnerDtype; + using DstDtype = typename DTypeRWHelper::DstDtype; + static inline __device__ void write(DstType* dst_ptr, + char4 (&dst_width)[pack_w]) { + DstDtype* dst_inner_ptr = (DstDtype*)dst_ptr; +#pragma unroll + for (int iw_idx = 0; iw_idx < pack_w; ++iw_idx) { + write_helper(dst_inner_ptr + iw_idx, dst_width[iw_idx]); + } + } + + static inline __device__ void read(const SrcType* src_ptr, + InnerDtype (&read_channel)[pack_c], + const int ic_stride) { +#pragma unroll + for (int ic_idx = 0; ic_idx < pack_c; ++ic_idx) { + read_channel[ic_idx] = *(InnerDtype*)(src_ptr + ic_idx * ic_stride); + } + } + + static inline __device__ void read_with_pad( + const SrcType* src_ptr, InnerDtype (&read_channel)[pack_c], + const int ic_stride, const int remain_ic, + const InnerDtype zero_point) { +#pragma unroll + for (int ic_idx = 0; ic_idx < pack_c; ++ic_idx) { + read_channel[ic_idx] = + ic_idx < remain_ic + ? 
*(InnerDtype*)(src_ptr + ic_idx * ic_stride) + : zero_point; + } + } + + static inline __device__ void core_relayout_kern( + const SrcType* src, DstType* dst, const int src_offset_base, + const int dst_offset_base, const int ic_offset, const int ic_stride, + const int remain_ic, + CudaPostProcess& post_process, + const char zero_point) { + InnerDtype read_channel[pack_c]; + if (with_pad) { + const InnerDtype zero_pad = make_zero_pad(zero_point); + read_with_pad(src + ic_offset + src_offset_base, read_channel, + ic_stride, remain_ic, zero_pad); + } else { + read(src + ic_offset + src_offset_base, read_channel, ic_stride); + } + DstDtype dst_width[pack_w]; + Translayout::trans(dst_width, read_channel, post_process, + zero_point); + write(dst + ic_offset + dst_offset_base, dst_width); + } +}; + +template +__global__ void kern_nchw_nchw4( + const SrcType* src, DstType* dst, int ic, int ihw, int n_stride_src, + int ic_stride, int n_stride_dst, + CudaPostProcess post_process, + const char zero_point) { + constexpr int pack_c = 4; + const int n_idx = blockIdx.y; + const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x; + const int ihw_offset = ihw_block_idx * pack_w; + + if (ihw_offset < ihw) { + const int ic_block = ic / pack_c; + const int remain_ic = ic % pack_c; + const int src_offset_base = n_idx * n_stride_src + ihw_offset; + const int dst_offset_base = n_idx * n_stride_dst + ihw_offset * pack_c; + + for (int ic_blk_idx = 0; ic_blk_idx < ic_block; ++ic_blk_idx) { + const int ic_offset = ic_blk_idx * pack_c * ic_stride; + RelayoutKern::core_relayout_kern(src, dst, + src_offset_base, + dst_offset_base, + ic_offset, ic_stride, + remain_ic, + post_process, + zero_point); + } + + if (remain_ic > 0) { + const int ic_offset = ic_block * pack_c * ic_stride; + RelayoutKern::core_relayout_kern(src, dst, + src_offset_base, + dst_offset_base, + ic_offset, ic_stride, + remain_ic, + post_process, + zero_point); + } + } +} + +} // namespace + +template +void relayout_format::relayout_format_cuda_exec( + const TensorND& src, const TensorND& dst, const cudaStream_t& stream, + const float src_scale, const float dst_scale, + const uint8_t src_zero_point, const uint8_t dst_zero_point) { + constexpr int pack_oc = 4; + const int n = src.layout[0]; + const int c = src.layout[1]; + const int h = src.layout[2]; + const int w = src.layout[3]; + const int hw = h * w; + const int oc_block = DIVUP(c, pack_oc); + const int n_stride_src = c * hw; + const int ic_stride = hw; + const int n_stride_dst = oc_block * pack_oc * h * w; + + auto& src_layout = src.layout; + auto& dst_layout = dst.layout; + bool same_scale = src_scale == dst_scale; +#define RUN_KERNEL(same_scale, SRC_TYPE, DST_TYPE, SRC_C_TYPE, DST_C_TYPE) \ + if (same_scale) { \ + int nr_threads = query_blocksize_for_kernel( \ + kern_nchw_nchw4); \ + const dim3 block_dim(DIVUP(hw, nr_threads* pack_w), n); \ + const dim3 thread_dim(nr_threads); \ + kern_nchw_nchw4<<>>( \ + (SRC_C_TYPE*)src.raw_ptr, (DST_C_TYPE*)dst.raw_ptr, c, hw, \ + n_stride_src, ic_stride, n_stride_dst, \ + CudaPostProcess( \ + src_scale, src_zero_point, dst_scale, dst_zero_point), \ + src_zero_point); \ + } else { \ + int nr_threads = query_blocksize_for_kernel( \ + kern_nchw_nchw4); \ + const dim3 block_dim(DIVUP(hw, nr_threads* pack_w), n); \ + const dim3 thread_dim(nr_threads); \ + kern_nchw_nchw4<<>>( \ + (SRC_C_TYPE*)src.raw_ptr, (DST_C_TYPE*)dst.raw_ptr, c, hw, \ + n_stride_src, ic_stride, n_stride_dst, \ + CudaPostProcess( \ + src_scale, src_zero_point, dst_scale, 
dst_zero_point), \ + src_zero_point); \ + } + + if (src_layout.dtype.enumv().ev == DTypeEnum::Ev::Uint8 && + dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) { + RUN_KERNEL(same_scale, dtype::Uint8, dtype::QuantizedS8, char, char); + } else if (src_layout.dtype.enumv().ev == DTypeEnum::Ev::Quantized8Asymm && + dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) { + RUN_KERNEL(same_scale, dtype::Quantized8Asymm, dtype::QuantizedS8, char, + char); + } else if (src_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8 && + dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) { + RUN_KERNEL(same_scale, dtype::QuantizedS8, dtype::QuantizedS8, char, + char); + } else { + megdnn_assert(0, "not support dtype %s %s", src_layout.dtype.name(), + dst_layout.dtype.name()); + } +} + +bool relayout_format::relayout_format_cuda_usable( + const TensorLayout& src_layout, const TensorLayout& dst_layout) { + bool is_all_continue = + src_layout.is_contiguous() && dst_layout.is_contiguous(); + bool is_all_int8 = + (src_layout.dtype.enumv().ev == DTypeEnum::Ev::Uint8 && + dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) || + (src_layout.dtype.enumv().ev == DTypeEnum::Ev::Quantized8Asymm && + dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) || + (src_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8 && + dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8); + return is_all_continue && is_all_int8; +} + +template void relayout_format::relayout_format_cuda_exec<1>( + const TensorND& src, const TensorND& dst, const cudaStream_t& stream, + const float src_scale, const float dst_scale, + const uint8_t src_zero_point, const uint8_t dst_zero_point); + +template void relayout_format::relayout_format_cuda_exec<4>( + const TensorND& src, const TensorND& dst, const cudaStream_t& stream, + const float src_scale, const float dst_scale, + const uint8_t src_zero_point, const uint8_t dst_zero_point); diff --git a/dnn/src/cuda/relayout_format/relayout_format.cuh b/dnn/src/cuda/relayout_format/relayout_format.cuh new file mode 100644 index 00000000..7a42cf7d --- /dev/null +++ b/dnn/src/cuda/relayout_format/relayout_format.cuh @@ -0,0 +1,37 @@ +/** + * \file dnn/src/cuda/relayout_format/relayout_format.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#pragma once + +#include "megdnn/basic_types.h" +#include "src/cuda/utils.cuh" + +namespace megdnn { +namespace cuda { +namespace relayout_format { + +template +void relayout_format_cuda_exec(const TensorND& src, const TensorND& dst, + const cudaStream_t& stream, + const float src_scale = 1.f, + const float dst_scale = 1.f, + const uint8_t src_zero_point = 0, + const uint8_t dst_zero_point = 0); + +bool relayout_format_cuda_usable(const TensorLayout& src_layout, + const TensorLayout& dst_layout); + +} // namespace relayout_format +} // namespace cuda +} // namespace megdnn + +// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/relayout_format/relayout_format.h b/dnn/src/cuda/relayout_format/relayout_format.h new file mode 100644 index 00000000..01e74c3d --- /dev/null +++ b/dnn/src/cuda/relayout_format/relayout_format.h @@ -0,0 +1,34 @@ +/** + * \file dnn/src/cuda/relayout_format/relayout_format.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "megdnn/basic_types.h" +#include "src/cuda/utils.cuh" + +namespace megdnn { +namespace cuda { +namespace relayout_format { + +struct RelayoutFormatFast { + static bool usable(const TensorLayout& src_layout, + const TensorLayout& dst_layout); + static void exec(const TensorND& src, const TensorND& dst, + cudaStream_t stream); +}; + +} // namespace relayout_format + +} // namespace cuda +} // namespace megdnn + +// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/naive/relayout_format/opr_impl.cpp b/dnn/src/naive/relayout_format/opr_impl.cpp index cc4e71ac..44725340 100644 --- a/dnn/src/naive/relayout_format/opr_impl.cpp +++ b/dnn/src/naive/relayout_format/opr_impl.cpp @@ -6,11 +6,12 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ -#include "src/naive/relayout_format/opr_impl.h" #include "src/naive/handle.h" +#include "src/naive/relayout_format/opr_impl.h" #include "megdnn/tensor_iter.h" @@ -44,7 +45,7 @@ void padding_src_to_workspace(dtype* dptr, const dtype* sptr, size_t N, template void padding_to_workspace(dtype* dptr, const dtype* sptr, const TensorLayout& src_layout, const size_t pad_axis, - const size_t align_size) { + const size_t align_size, const int pad_val = 0) { megdnn_assert(pad_axis < src_layout.ndim); const size_t axis_dim = src_layout[pad_axis]; const size_t axis_dim_padded = round_up(axis_dim, align_size); @@ -64,14 +65,16 @@ void padding_to_workspace(dtype* dptr, const dtype* sptr, sptr[src_inner_offset + inner_idx_offset]; } else { dptr[dst_outer_offset + inner_idx_offset] = - static_cast(0); + static_cast(pad_val); } } } } } + void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src, - const size_t pad_axis, const size_t align_size) { + const size_t pad_axis, const size_t align_size, + DType exec_dst_dtype) { switch (src.layout.dtype.enumv()) { #define cb(name, ctype) \ case (DTypeEnum::name): { \ @@ -84,8 +87,27 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src, cb(Float32, dt_float32); cb(QuantizedS8, dt_qint8); + + case (DTypeEnum::Quantized8Asymm): { + dt_quint8* sptr = src.compatible_ptr(); + dt_quint8* dptr = dst.compatible_ptr(); + padding_to_workspace( + dptr, sptr, src.layout, pad_axis, align_size, + src.layout.dtype.param() + .zero_point); + break; + } + case (DTypeEnum::Uint8): { + uint8_t* sptr = src.compatible_ptr(); + uint8_t* dptr = dst.compatible_ptr(); + uint8_t zero_point = + exec_dst_dtype.enumv() == DTypeEnum::QuantizedS8 ? 128 : 0; + padding_to_workspace(dptr, sptr, src.layout, pad_axis, + align_size, zero_point); + break; + } default: - megdnn_assert(0); + megdnn_assert(0, "not support dtype %s", src.layout.dtype.name()); #undef cb } } @@ -108,6 +130,57 @@ void padding_filter_to_workspace(dtype* dptr, const dtype* sptr, size_t OC, } } } + +void do_copy_diff_qu8_q8(const TensorND& dst, const TensorND& src) { + auto isrc = + tensor_iter_valonly::ctype>(src) + .begin(); + auto idst = tensor_iter_valonly::ctype>(dst) + .begin(); + auto src_dt_parm = src.layout.dtype.param(); + auto dst_dt_parm = dst.layout.dtype.param(); + for (size_t i = 0, it = dst.layout.total_nr_elems(); i < it; ++i) { + *idst = dst_dt_parm.quantize(src_dt_parm.dequantize(*isrc)); + ++idst; + ++isrc; + } +} + +void do_copy_diff_q8_q8(const TensorND& dst, const TensorND& src) { + auto isrc = tensor_iter_valonly::ctype>(src) + .begin(); + auto idst = tensor_iter_valonly::ctype>(dst) + .begin(); + auto src_dt_parm = src.layout.dtype.param(); + auto dst_dt_parm = dst.layout.dtype.param(); + for (size_t i = 0, it = dst.layout.total_nr_elems(); i < it; ++i) { + *idst = dst_dt_parm.quantize(src_dt_parm.dequantize(*isrc)); + ++idst; + ++isrc; + } +} + +void do_copy_diff_u8_q8(const TensorND& dst, const TensorND& src) { + auto isrc = + tensor_iter_valonly::ctype>(src).begin(); + auto idst = tensor_iter_valonly::ctype>(dst) + .begin(); + auto dst_dt_parm = dst.layout.dtype.param(); + for (size_t i = 0, it = dst.layout.total_nr_elems(); i < it; ++i) { + *idst = dst_dt_parm.quantize((float)(*isrc) - 128.f); + ++idst; + ++isrc; + } +} + +void check_layout_and_canonize(TensorLayout& src, TensorLayout& dst) { + megdnn_assert(dst.is_non_overlapping_strong()); + src = src.collapse_contiguous(); + dst = dst.collapse_contiguous(); + megdnn_assert(dst.dtype.valid() && + 
src.total_nr_elems() == dst.total_nr_elems()); +} + } // anonymous namespace size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, @@ -189,6 +262,13 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, size_t w = src[3]; return n * c * h * w * src.dtype.size(); } + case Param::Mode::NCHW_NCHW4: { + size_t n = src[0]; + size_t c = round_up(src[1], 4_z); + size_t h = src[2]; + size_t w = src[3]; + return n * c * h * w * src.dtype.size(); + } case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: { megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 5"); if (src[1] % 4 == 0) @@ -208,6 +288,8 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { megdnn_assert(src.layout.dtype.category() == DTypeCategory::FLOAT || + (src.layout.dtype.enumv() == DTypeEnum::Uint8 && + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) || src.layout.dtype.category() == DTypeCategory::QUANTIZED); check_exec(src.layout, dst.layout, workspace.size); HandleImpl* m_handle = static_cast(handle()); @@ -284,7 +366,7 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, size_t val = src.layout[_idx]; \ if (val % _pack_size != 0) { \ padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \ - _pack_size); \ + _pack_size, exec_dst.dtype); \ exec_src_nd.raw_ptr = workspace.raw_ptr; \ } \ } \ @@ -301,11 +383,43 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, cb(2, 8, NCHW_NCHW88_CONV_GROUP_WEIGHT); } else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) { cb(1, 4, NCHW_NCHW4_IC_SMALL); + } else if (param().mode == Param::Mode::NCHW_NCHW4) { + cb(1, 4, NCHW_NCHW4); } else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) { cb(1, 4, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT); } - m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); + + if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { + TensorND src0 = exec_src_nd, dst0 = exec_dst_nd; + check_layout_and_canonize(src0.layout, src0.layout); + auto func = [](const TensorND& dst, const TensorND& src) { + do_copy_diff_qu8_q8(dst, src); + }; + MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0)); + return; + } else if (src.layout.dtype.enumv() == DTypeEnum::Uint8 && + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { + TensorND src0 = exec_src_nd, dst0 = exec_dst_nd; + check_layout_and_canonize(src0.layout, src0.layout); + auto func = [](const TensorND& dst, const TensorND& src) { + do_copy_diff_u8_q8(dst, src); + }; + MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0)); + return; + } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 && + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { + TensorND src0 = exec_src_nd, dst0 = exec_dst_nd; + check_layout_and_canonize(src0.layout, src0.layout); + auto func = [](const TensorND& dst, const TensorND& src) { + do_copy_diff_q8_q8(dst, src); + }; + MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0)); + return; + } else { + m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); + } #undef cb } diff --git a/dnn/test/cuda/relayout_format.cpp b/dnn/test/cuda/relayout_format.cpp index c4a99fc0..12fd5d89 100644 --- a/dnn/test/cuda/relayout_format.cpp +++ b/dnn/test/cuda/relayout_format.cpp @@ -6,10 +6,12 @@ * * Unless required by applicable law or agreed to in writing, * software 
distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "megdnn/dtype.h" #include "megdnn/oprs.h" +#include "test/common/benchmarker.h" #include "test/common/checker.h" #include "test/common/rng.h" #include "test/cuda/fixture.h" @@ -24,6 +26,7 @@ TEST_F(CUDA, RELAYOUT_FORMAT) { param.mode = param::RelayoutFormat::Mode::NCHW4_CHWN4; checker.set_dtype(0, dtype::QuantizedS8{0.1f}) + .set_dtype(1, dtype::QuantizedS8{0.1f}) .set_rng(0, &rng) .set_param(param) .execs({{22, 23, 24, 25, 4}, {}}); @@ -31,6 +34,164 @@ TEST_F(CUDA, RELAYOUT_FORMAT) { checker.execs({{22, 23, 24, 25, 4}, {}}); } +TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW4) { + Checker checker(handle_cuda()); + UniformIntRNG rng{0, 50}; + param::RelayoutFormat param; + param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4; + + for (size_t n : {1, 3}) { + for (size_t c : {1, 2, 3, 4, 8, 9, 11, 16}) { + for (size_t h : {3, 7, 12, 16, 22, 59, 83}) { + for (size_t w : {3, 22, 63, 128, 256}) { + checker.set_dtype(0, dtype::QuantizedS8{1.f}) + .set_dtype(1, dtype::QuantizedS8{1.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{n, c, h, w}, {}}); + + checker.set_dtype(0, dtype::QuantizedS8{1.f}) + .set_dtype(1, dtype::QuantizedS8{2.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{n, c, h, w}, {}}); + } + } + } + } + + checker.set_dtype(0, dtype::QuantizedS8{1.f}) + .set_dtype(1, dtype::QuantizedS8{1.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{8, 3, 224, 224}, {}}); + + checker.set_dtype(0, dtype::QuantizedS8{1.f}) + .set_dtype(1, dtype::QuantizedS8{1.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{8, 3, 600, 600}, {}}); + + checker.set_dtype(0, dtype::QuantizedS8{1.f}) + .set_dtype(1, dtype::QuantizedS8{1.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{1, 6, 768, 1280}, {}}); +} + +TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW4_DEFAULT) { + Checker checker(handle_cuda()); + UniformIntRNG rng{0, 50}; + param::RelayoutFormat param; + param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4; + for (size_t n : {1, 3}) { + for (size_t c : {1, 2, 3, 4, 8, 9, 11, 16}) { + for (size_t h : {3, 7, 12, 16, 59, 83}) { + for (size_t w : {3, 63, 128, 256}) { + checker.set_dtype(0, dtype::Quantized8Asymm{1.f, 128}) + .set_dtype(1, dtype::QuantizedS8{1.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{n, c, h, w}, {}}); + } + } + } + } +} + +TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW4_U8) { + Checker checker(handle_cuda()); + UniformIntRNG rng{0, 255}; + param::RelayoutFormat param; + param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4; + for (size_t n : {1, 3}) { + for (size_t c : {1, 2, 3, 4, 8, 9, 11, 16}) { + for (size_t h : {3, 7, 12, 16, 59, 83}) { + for (size_t w : {3, 13, 3 * 4, 63 * 4, 128 * 4, 256 * 4}) { + checker.set_dtype(0, dtype::Uint8()) + .set_dtype(1, dtype::QuantizedS8{1.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{n, c, h, w}, {}}); + + checker.set_dtype(0, dtype::Quantized8Asymm{1.f, 128}) + .set_dtype(1, dtype::QuantizedS8{1.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{n, c, h, w}, {}}); + + checker.set_dtype(0, dtype::Uint8()) + .set_dtype(1, dtype::QuantizedS8{2.5f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{n, c, h, w}, {}}); + } + } + } + } +} + +TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW4_IC_SMALL) { + Checker checker(handle_cuda()); + UniformIntRNG rng{0, 50}; + param::RelayoutFormat param; + 
param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL; + + checker.set_dtype(0, dtype::QuantizedS8{1.f}) + .set_dtype(1, dtype::QuantizedS8{1.f}) + .set_rng(0, &rng) + .set_param(param) + .execs({{8, 3, 768, 1280}, {}}); +} + +#if MEGDNN_WITH_BENCHMARK +TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT) { + using Param = RelayoutFormat::Param; + + auto run = [&](const TensorShapeArray& shapes, Param param, + Param default_param) { + Benchmarker benchmarker(handle_cuda()); + benchmarker.set_param(param); + benchmarker.set_dtype(0, dtype::QuantizedS8{1.f}) + .set_dtype(1, dtype::QuantizedS8{1.f}); + + Benchmarker benchmarker_default(handle_cuda()); + benchmarker_default.set_param(default_param); + benchmarker_default.set_dtype(0, dtype::QuantizedS8{1.f}) + .set_dtype(1, dtype::QuantizedS8{1.f}); + for (auto&& shape : shapes) { + double memaccess = (double(shape.total_nr_elems()) + + double(shape[0]) * ((shape[1] + 3) / 4 * 4) * + shape[2] * shape[3]) * + 1e-6; + auto time_ms = benchmarker.execs({shape, {}}); + if (shape[1] <= 4) { + auto time_default_ms = benchmarker_default.execs({shape, {}}); + printf("execute %s, time %.4f ms, %.4f GB/s, default %.4f " + "GB/s\n", + shape.to_string().c_str(), time_ms, memaccess / time_ms, + memaccess / time_default_ms); + } else { + printf("execute %s, time %.4f ms, %.4f GB/s\n", + shape.to_string().c_str(), time_ms, memaccess / time_ms); + } + } + }; + + TensorShapeArray shapes = { + {8, 1, 768, 1280}, {8, 3, 768, 1280}, {8, 3, 224, 224}, + {8, 4, 768, 1280}, {64, 3, 768, 1280}, + }; + { + Param param; + param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4; + Param default_param; + default_param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL; + run(shapes, param, default_param); + } +} +#endif + TEST_F(CUDA, RELAYOUT_FORMAT_NCHW4) { Checker checker(handle_cuda()); UniformIntRNG rng{-50, 50}; @@ -39,7 +200,7 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW4) { for (DType dtype : std::vector({dtype::QuantizedS8{0.1f}, dtype::Float32{}})) { - checker.set_dtype(0, dtype).set_rng(0, &rng); + checker.set_dtype(0, dtype).set_dtype(1, dtype).set_rng(0, &rng); checker.set_param(param).execs({{2, 4, 35, 36}, {}}); checker.set_param(param).execs({{2, 3, 35, 36}, {}}); diff --git a/sdk/load-and-run/src/mgblar.cpp b/sdk/load-and-run/src/mgblar.cpp index 7aac31c4..9c4f1998 100644 --- a/sdk/load-and-run/src/mgblar.cpp +++ b/sdk/load-and-run/src/mgblar.cpp @@ -219,7 +219,10 @@ R"__usage__( Execute operators with weight preprocess, which can optimize the operator execution time with algo of winograd, im2col ,etc., but it may consume more memory. )__usage__" - +R"__usage__( + --enable-fuse-preprocess + Fusion astype\pad_channel\dimshuffle and etc opr from h2d op +)__usage__" ; struct DataParser { @@ -1141,6 +1144,11 @@ Args Args::from_argv(int argc, char **argv) { graph_opt.graph_opt.enable_nchw44_dot(); continue; } + if (!strcmp(argv[i], "--enable-fuse-preprocess")) { + mgb_log_warn("enable-fuse-preprocess optimization"); + graph_opt.graph_opt.enable_fuse_preprocess(); + continue; + } if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) { mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization"); graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity(); diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h index 127d4084..86b04884 100644 --- a/src/core/include/megbrain/graph/cg.h +++ b/src/core/include/megbrain/graph/cg.h @@ -101,6 +101,8 @@ struct GraphCommonOptimizeOptions { //! 
memory, default disable now, when weight preprocess is enabled, the //! input shape should no change bool weight_preprocess = false; + //! fuse preprocess patten, like astype + pad_channel + dimshuffle + bool fuse_preprocess = false; enum LayoutTransform : uint32_t { DEFAULT, NCHW4, ///< compute using NCHW4 tensor format @@ -130,6 +132,7 @@ struct GraphCommonOptimizeOptions { SET(f16_io_comp); SET(fuse_conv_bias_nonlinearity); SET(fuse_conv_bias_with_z); + SET(fuse_preprocess); SET(weight_winograd_transform); SET(weight_preprocess); #undef SET diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index 87abae82..ad57e59f 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -724,6 +724,8 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( options.disable_##_option(); \ } \ } + + cb(fuse_preprocess, {add_pass(FuseNCHW4Int8Preprocess::make());}); cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); }); cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); }); @@ -761,6 +763,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( add_pass(EnableTensorCorePass::make_tensorcore_converter()); add_pass(); add_pass(); + add_pass(FuseNCHW4Int8Preprocess::make()); }); cb(chwn4, { add_pass(); diff --git a/src/gopt/impl/fuse_nchw4_int8_preprocess.cpp b/src/gopt/impl/fuse_nchw4_int8_preprocess.cpp new file mode 100644 index 00000000..1923b6a1 --- /dev/null +++ b/src/gopt/impl/fuse_nchw4_int8_preprocess.cpp @@ -0,0 +1,446 @@ +/** + * \file src/gopt/impl/fuse_nchw4_int8_preprocess.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain/gopt/inference.h" +#include "megbrain/gopt/misc.h" +#include "megbrain/graph/grad_impl.h" +#include "megbrain/opr/cond.h" +#include "megbrain/opr/io.h" +#include "megbrain/opr/tensor_manip.h" +#include "megbrain/opr/utility.h" +#include "megbrain/serialization/opr_shallow_copy.h" +#include "megbrain/serialization/serializer.h" + +using namespace mgb; +using namespace gopt; +namespace { +#define RETURN_IF_FALSE(ok) \ + { \ + if (!ok) \ + return ok; \ + } + +struct SubGraphMatcher { + struct Node { + using CallBack = std::function; + Node(Typeinfo* in_op_type) : op_type(in_op_type){}; + Node(Typeinfo* in_op_type, CallBack func) + : op_type(in_op_type), cbk(func){}; + Node(Typeinfo* in_op_type, std::vector in_pre_node) + : op_type(in_op_type), pre_node(in_pre_node){}; + Node(Typeinfo* in_op_type, std::vector in_pre_node, CallBack func) + : op_type(in_op_type), pre_node(in_pre_node), cbk(func){}; + + Typeinfo* op_type{nullptr}; + std::vector pre_node; + //! cbk used to check param and gather args for creating fusion op + CallBack cbk; + }; + + bool match(Node& root, OperatorNodeBase* opr) { + if (opr == nullptr) { + return false; + } + //! 
match nullptr node always + if (root.op_type == nullptr || root.op_type == opr->dyn_typeinfo()) { + bool match_ok = true; + if (root.cbk) + match_ok &= root.cbk(opr); + RETURN_IF_FALSE(match_ok); + auto& inp = opr->input(); + for (size_t node_idx = 0; node_idx < root.pre_node.size(); + ++node_idx) { + bool valid_node_idx = node_idx < inp.size(); + RETURN_IF_FALSE(valid_node_idx); + match_ok &= match(root.pre_node[node_idx], + inp[node_idx]->owner_opr()); + RETURN_IF_FALSE(match_ok); + } + return match_ok; + } else { + return false; + } + } +}; +#undef RETURN_IF_FALSE + +struct SubGraphChecker { + using DepType = cg::OperatorNodeProp::DepType; + using ReaderType = + ThinHashMap>>; + SubGraphChecker() {} + + bool check(ThinHashSet used_input, + OperatorNodeBase* start_opr, OperatorNodeBase* stop_opr, + ReaderType& readers, bool ignore_immutable = true) { + bool is_all_inp_used = check_all_inp_used(used_input, start_opr, + stop_opr, ignore_immutable); + bool is_all_dep_inside = + check_all_dep_inside_node(start_opr, stop_opr, readers); + return is_all_inp_used && is_all_dep_inside; + } + + bool check_all_inp_used(ThinHashSet& used_input, + OperatorNodeBase* start_opr, + OperatorNodeBase* stop_opr, + bool ignore_immutable = true) { + ThinHashSet leaf_set; + get_leaf_node(start_opr, stop_opr, leaf_set); + for (auto in_opr : leaf_set) { + bool skip = in_opr->same_type() && + ignore_immutable; + if (used_input.find(in_opr) == used_input.end() && !skip) { + return false; + } + } + return true; + } + + bool check_all_dep_inside_node(OperatorNodeBase* start_opr, + OperatorNodeBase* stop_opr, + ReaderType& readers) { + ThinHashSet mid_set; + get_mid_node(start_opr, start_opr, stop_opr, mid_set); + for (auto inner_opr : mid_set) { + if (readers.find(inner_opr) != readers.end()) { + for (auto& out_node : readers[inner_opr]) { + if (mid_set.find(out_node.first) == mid_set.end() && + out_node.first != start_opr && + out_node.second == + cg::OperatorNodeProp::DepType::DEV_VALUE) { + return false; + } + } + } + } + return true; + } + + void get_mid_node(OperatorNodeBase* opr, OperatorNodeBase* start_opr, + OperatorNodeBase* stop_opr, + ThinHashSet& mid_set) { + if (opr == nullptr) { + return; + } + if (opr != start_opr) { + mid_set.insert(opr); + } + if (opr == stop_opr) { + return; + } + for (auto& tensor : opr->input()) { + auto pre_opr = tensor->owner_opr(); + get_mid_node(pre_opr, start_opr, stop_opr, mid_set); + } + } + + void get_leaf_node(OperatorNodeBase* opr, OperatorNodeBase* stop_opr, + ThinHashSet& leaf_set) { + if (opr == nullptr) { + return; + } + if (opr == stop_opr || opr->input().size() == 0) { + leaf_set.insert(opr); + } + if (opr == stop_opr) { + return; + } + for (auto& tensor : opr->input()) { + auto pre_opr = tensor->owner_opr(); + get_leaf_node(pre_opr, stop_opr, leaf_set); + } + } +}; + +static inline bool is_shape_nchw(const TensorShape& shape) { + return shape.ndim == 4; +} + +static inline bool is_shape_before_nchw4(const TensorShape& shape) { + return shape.ndim == 5 && shape[2] == 4; +} + +static inline bool is_nchw_nchw4_shuffle_vec( + const opr::Dimshuffle::Param param) { + return param.ndim == 5 && param.pattern[0] == 0 && param.pattern[1] == 1 && + param.pattern[2] == 3 && param.pattern[3] == 4 && + param.pattern[4] == 2; +} + +template +static inline bool is_immutable_equal(OperatorNodeBase* opr, T val, + DTypeEnum dtype_enum) { + auto const_opr = opr->try_cast_final(); + if (!const_opr) { + return false; + } + auto& host_value = const_opr->host_value(); + bool ok_value = 
host_value.layout().total_nr_elems() == 1 && + host_value.dtype().enumv() == dtype_enum && + host_value.ptr()[0] == val; + return ok_value; +} + +template +static inline bool is_immutable_all_equal(OperatorNodeBase* opr, + typename DTypeTrait::ctype val) { + auto const_opr = opr->try_cast_final(); + if (!const_opr) { + return false; + } + auto& host_value = const_opr->host_value(); + bool ok_value = host_value.dtype().enumv() == DTypeTrait::enumv; + if (!ok_value) { + return false; + } + size_t nr_elem = host_value.layout().total_nr_elems(); + for (size_t i = 0; i < nr_elem; ++i) { + if (host_value.ptr::ctype>()[i] != val) { + ok_value = false; + break; + } + } + return ok_value; +} + +} // namespace + +const char* FuseNCHW4Int8Preprocess::name() const { + return "fuse_pre_process_pass"; +} + +std::unique_ptr FuseNCHW4Int8Preprocess::make() { + using SGM = SubGraphMatcher; + auto gen_pad_dimshuffle_graph = [&](SGM::Node& in_node, + SGM::Node::CallBack& pad_cbk, + SGM::Node::CallBack& shape_cbk) { + SGM::Node::CallBack check_pad = [&](OperatorNodeBase* opr) { + SGM sub_matcher; + SGM::Node immu_node{opr::ImmutableTensor::typeinfo(), pad_cbk}; + if (opr->same_type()) { + return sub_matcher.match(immu_node, opr); + } else if (opr->same_type()) { + return sub_matcher.match(immu_node, + opr->input()[0]->owner_opr()); + } else { + return false; + } + }; + SGM::Node broadcast_or_immutable{nullptr, check_pad}; + SGM::Node broadcast_concat{ + opr::Concat::typeinfo(), + {in_node, broadcast_or_immutable}, + [](OperatorNodeBase* opr) { + auto concat_pad = opr->try_cast_final(); + return concat_pad->axis() == 1; + }}; + + SGM::Node nchwx_reshape{opr::Reshape::typeinfo(), + {broadcast_concat, SGM::Node(nullptr)}, + [](OperatorNodeBase* opr) { + auto inp0 = opr->input()[0]; + return is_shape_nchw(inp0->shape()); + }}; + SGM::Node shuffle_root{ + opr::Dimshuffle::typeinfo(), + {nchwx_reshape}, + [](OperatorNodeBase* opr) { + auto& shuffle_opr = opr->cast_final(); + auto& input_vec = shuffle_opr.input(); + return is_shape_before_nchw4(input_vec[0]->shape()) && + is_nchw_nchw4_shuffle_vec(shuffle_opr.param()); + }}; + return shuffle_root; + }; + auto replace_shuffle_opr = [&](OperatorNodeBase* opr, + const VarNodeArray& new_inp, + SubGraph::Rewriter& rewriter, + ReaderType& reader) { + SGM matcher; + OperatorNodeBase* src_node = nullptr; + SGM::Node input_data_cp{ + nullptr, [&](OperatorNodeBase* opr) { + auto src_dtype = opr->output()[0]->dtype(); + if (src_dtype.enumv() == DTypeEnum::Quantized8Asymm) { + src_node = opr; + return true; + } else { + return false; + } + }}; + SGM::Node type_cvt{opr::TypeCvt::typeinfo(), {input_data_cp}}; + SGM::Node::CallBack const_pad_cbk = [&](OperatorNodeBase* opr) { + bool is_fp32_pad = is_immutable_all_equal(opr, 0); + bool is_i32_pad = is_immutable_all_equal(opr, 0); + bool is_q8_pad = is_immutable_all_equal( + opr, dt_qint8(0)); + return is_fp32_pad || is_i32_pad || is_q8_pad; + }; + SGM::Node::CallBack const_reshape_cbk = [](OperatorNodeBase* opr) { + return true; + }; + auto&& shuffle_root = gen_pad_dimshuffle_graph(type_cvt, const_pad_cbk, + const_reshape_cbk); + bool match = matcher.match(shuffle_root, opr); + bool check_ok = false; + if (match) { + check_ok = + SubGraphChecker().check({src_node}, opr, src_node, reader); + } + if (match && check_ok) { + opr::RelayoutFormat::Param param; + param.mode = opr::RelayoutFormat::Param::Mode::NCHW_NCHW4; + OperatorNodeConfig config(opr->output()[0]->dtype()); + auto out_node = opr::RelayoutFormat::make( + 
rewriter.get_var(src_node->output()[0]), param.mode, + config); + return out_node.node()->owner_opr(); + } else { + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } + }; + + auto replace_astype_opr = [&](OperatorNodeBase* opr, + const VarNodeArray& new_inp, + SubGraph::Rewriter& rewriter, + ReaderType& reader) { + SGM matcher; + OperatorNodeBase* src_node = nullptr; + OperatorNodeBase* neg_128_immu_node = nullptr; + OperatorNodeBase* pad0_immu_node = nullptr; + OperatorNodeBase* const_reshape_last_dim_node = nullptr; + SGM::Node input_data_cp{nullptr, [&](OperatorNodeBase* opr) { + auto src_dtype = opr->output()[0]->dtype(); + if (src_dtype.enumv() == DTypeEnum::Uint8) { + src_node = opr; + return true; + } else { + return false; + } + }}; + SGM::Node cvt_fp32{opr::TypeCvt::typeinfo(), + {input_data_cp}, + [](OperatorNodeBase* opr) { + auto cvt_op = + opr->try_cast_final(); + bool is_fp32 = cvt_op->param().enumv() == + DTypeEnum::Float32; + return is_fp32; + }}; + SGM::Node sub_128{ + opr::Elemwise::typeinfo(), + {cvt_fp32}, + [&](OperatorNodeBase* opr) { + auto elem_op = opr->try_cast_final(); + bool is_add_op = elem_op->param().mode == + opr::Elemwise::Param::Mode::ADD; + auto neg_128_op = elem_op->input()[1]->owner_opr(); + bool is_neg_128 = is_immutable_equal(neg_128_op, -128.f, + DTypeEnum::Float32); + neg_128_immu_node = is_neg_128 ? neg_128_op : nullptr; + return is_add_op && is_neg_128; + }}; + SGM::Node::CallBack const_pad_cbk = [&](OperatorNodeBase* opr) { + pad0_immu_node = opr; + bool is_fp32_pad = is_immutable_all_equal(opr, 0); + bool is_i32_pad = is_immutable_all_equal(opr, 0); + return is_fp32_pad || is_i32_pad; + }; + SGM::Node::CallBack const_reshape_cbk = [&](OperatorNodeBase* opr) { + const_reshape_last_dim_node = opr; + return true; + }; + auto&& shuffle_root = gen_pad_dimshuffle_graph(sub_128, const_pad_cbk, + const_reshape_cbk); + + SGM::Node astype_root{opr::TypeCvt::typeinfo(), {shuffle_root}}; + bool match = matcher.match(astype_root, opr); + bool check_ok = false; + if (match) { + check_ok = SubGraphChecker().check( + {src_node, neg_128_immu_node, pad0_immu_node, + const_reshape_last_dim_node}, + opr, src_node, reader); + } + if (match && check_ok) { + opr::RelayoutFormat::Param param; + param.mode = opr::RelayoutFormat::Param::Mode::NCHW_NCHW4; + OperatorNodeConfig config(opr->output()[0]->dtype()); + auto out_node = opr::RelayoutFormat::make( + rewriter.get_var(src_node->output()[0]), param.mode, + config); + return out_node.node()->owner_opr(); + } else { + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } + }; + auto ret = std::make_unique(); + auto&& replace_func = ret->m_opr_replace_func; + + MGB_MARK_USED_VAR(replace_astype_opr); + MGB_MARK_USED_VAR(replace_shuffle_opr); + replace_func[opr::Dimshuffle::typeinfo()] = replace_shuffle_opr; + replace_func[opr::TypeCvt::typeinfo()] = replace_astype_opr; + return ret; +} + +void FuseNCHW4Int8Preprocess::apply(OptState& state) const { + state.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_DTYPE | + VarReplaceCheckFlag::CHECK_SHAPE); + auto rewriter = state.graph().make_rewriter(); + VarNodeArray new_inp_cache; + + ReaderType readers; + state.graph().iter([&readers](OperatorNodeBase* opr) { + for (auto&& i : opr->node_prop().dep_map()) { + readers[i.first->owner_opr()].emplace_back(opr, i.second); + } + }); + + auto on_opr = [this, &rewriter, &new_inp_cache, + &readers](OperatorNodeBase* opr) { + auto it = m_opr_replace_func.find(opr->dyn_typeinfo()); + + 
if (it != m_opr_replace_func.end()) { + auto&& new_inp = new_inp_cache; + new_inp.clear(); + new_inp.reserve(opr->input().size()); + for (auto i : opr->input()) { + new_inp.push_back(rewriter.get_var(i)); + } + auto new_opr = (it->second)(opr, new_inp, rewriter, readers); + if (new_opr->try_cast_final()) { + auto &&origin_out = opr->output(), + &&cur_out = new_opr->output(); + rewriter.replace_var(origin_out[0], cur_out[0], nullptr); + } else { + auto &&origin_out = opr->output(), + &&cur_out = new_opr->output(); + mgb_assert(origin_out.size() == cur_out.size(), + "bad opr replace: src=%s{%s} dst=%s{%s}, %zu != %zu", + opr->cname(), opr->dyn_typeinfo()->name, + new_opr->cname(), new_opr->dyn_typeinfo()->name, + origin_out.size(), cur_out.size()); + for (size_t i = 0; i < origin_out.size(); i++) { + rewriter.replace_var(origin_out[i], cur_out[i], nullptr); + } + } + } else { + rewriter.auto_replace_outputs(opr); + } + }; + state.graph().iter(on_opr); + rewriter.apply_inplace(); +} \ No newline at end of file diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h index 34892f22..b309bcb8 100644 --- a/src/gopt/include/megbrain/gopt/inference.h +++ b/src/gopt/include/megbrain/gopt/inference.h @@ -152,6 +152,26 @@ namespace gopt { void apply(OptState& opt) const override; }; + /*! + * \brief fuse preprocess, like pad channel, quint8 to qint8 + */ + class FuseNCHW4Int8Preprocess : public Pass { + public: + const char* name() const override; + void apply(OptState& opt) const override; + static std::unique_ptr make(); + using DepType = cg::OperatorNodeProp::DepType; + using ReaderType = + ThinHashMap>>; + + private: + ThinHashMap> + m_opr_replace_func; + }; + /*! * \brief fuse deconv and typecvt to a deconv opr */ diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 229cd2c1..bf35086d 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -719,15 +719,15 @@ TEST(TestGoptInference, Float16IOFloat32ComputeDeConv) { }; graph->options().graph_opt_level = 0; - auto s0 = mkvar("s0", {5, 5, 3, 3}), - s1 = mkvar("s1", {1, 5, INP_H, INP_W}); + auto s0 = mkvar("s0", {5, 5, 3, 3}), s1 = mkvar("s1", {1, 5, INP_H, INP_W}); auto y = opr::ConvolutionBackwardData::make(s0, s1, {}, {}); SymbolVar y_opt; auto options = gopt::OptimizeForInferenceOptions{}; options.enable_f16_io_f32_comp(); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); - ASSERT_EQ(find_opr(y_opt).param().compute_mode, - opr::ConvBias::Param::ConvBias::ComputeMode::FLOAT32); + ASSERT_EQ( + find_opr(y_opt).param().compute_mode, + opr::ConvBias::Param::ConvBias::ComputeMode::FLOAT32); ASSERT_EQ(y_opt.dtype(), dtype::Float32()); HostTensorND host_y, host_y_opt; @@ -1603,7 +1603,6 @@ TEST(TestGoptInference, ConvBiasNonlinearityFusePass_FullBias) { } } - TEST(TestGoptInference, ParamMerge) { auto cns = load_multiple_xpus(2); HostTensorGenerator<> gen; @@ -3364,14 +3363,14 @@ TEST(TestGoptInference, ConvertFormatNCHW44MultiInput) { auto b = mkvar("b", {1, 1, 16, 16}), elem0 = opr::Elemwise::make({conv1 + b + b}, - opr::Elemwise::Param::Mode::RELU); + opr::Elemwise::Param::Mode::RELU); auto w2 = mkcvar("w2", {8, 8, 3, 3}), conv2 = opr::Convolution::make(elem0, w2, param_conv); auto b1 = mkvar("b1", {1}), y = opr::Elemwise::make({conv2 + b1 + b}, - opr::Elemwise::Param::Mode::RELU); + opr::Elemwise::Param::Mode::RELU); SymbolVar y_opt; auto options = gopt::OptimizeForInferenceOptions{}; @@ -3631,4 +3630,97 @@ TEST(TestGoptInference, 
@@ -3631,4 +3630,97 @@ TEST(TestGoptInference, ConvertFormatCD4GroupOneConv) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
 }
 
+#if MGB_CUDA
+
+TEST(TestGoptInference, PreProcessCase0) {
+    REQUIRE_GPU(1);
+    HostTensorGenerator<dtype::Quantized8Asymm, RandomDistribution::UNIFORM>
+            gen(dt_quint8(0), dt_quint8(50), 1, 128, 1234);
+    auto cn = CompNode::load("gpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+
+    size_t n = 1;
+    size_t c = 3;
+    size_t h = 16;
+    size_t w = 16;
+    auto host_x1 = gen({n, c, h, w}, cn);
+
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
+    auto x_q8 = opr::TypeCvt::make(x, dtype::QuantizedS8(1.f), cn);
+    auto zero = DTypeScalar(dtype::QuantizedS8(1.f));
+    auto zero_tensor = opr::ImmutableTensor::make(*graph, zero, cn);
+    auto pad_channel_tensor =
+            opr::Broadcast::make(zero_tensor, {n, 1, h, w}, cn);
+    auto paded_x = opr::Concat::make({x_q8, pad_channel_tensor}, 1, cn)
+                           .reshape({n, 1, 4, h, w});
+
+    auto result = opr::Dimshuffle::make(paded_x, {0, 1, 3, 4, 2}, 5, cn);
+
+    auto y = result;
+    SymbolVar y_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    options.enable_fuse_preprocess();
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(
+                    output_file("TestGoptInference.PreProcessCase0.json"));
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
+
+    ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::RelayoutFormat>());
+}
+
+TEST(TestGoptInference, PreProcessCase1) {
+    REQUIRE_GPU(1);
+    HostTensorGenerator<dtype::Uint8> gen(0, 255);
+    auto cn = CompNode::load("gpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+
+    size_t n = 1;
+    size_t c = 3;
+    size_t h = 16;
+    size_t w = 16;
+    auto host_x1 = gen({n, c, h, w}, cn);
+
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
+    auto x_u8 = opr::TypeCvt::make(x, dtype::Float32(), cn);
+    auto x_s8 = x_u8 - 128;
+    auto zero = DTypeScalar(dtype::Float32());
+    auto zero_tensor = opr::ImmutableTensor::make(*graph, zero, cn);
+    auto pad_channel_tensor =
+            opr::Broadcast::make(zero_tensor, {n, 1, h, w}, cn);
+    auto paded_x = opr::Concat::make({x_s8, pad_channel_tensor}, 1, cn)
+                           .reshape({n, 1, 4, h, w});
+
+    auto nchw4_out = opr::Dimshuffle::make(paded_x, {0, 1, 3, 4, 2}, 5, cn);
+    auto result = opr::TypeCvt::make(nchw4_out, dtype::QuantizedS8(1.f));
+
+    auto y = result;
+    SymbolVar y_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    options.enable_fuse_preprocess();
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(
+                    output_file("TestGoptInference.PreProcessCase1.json"));
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
+
+    ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::RelayoutFormat>());
+}
+#endif
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp
index c7916909..e176f58d 100644
--- a/src/opr/impl/basic_arith.cpp
+++ b/src/opr/impl/basic_arith.cpp
@@ -198,7 +198,7 @@ Elemwise::Elemwise(
                            param.mode == Param::Mode::MAX ||
                            param.mode == Param::Mode::MIN,
                    "Only ADD, SUB, NEGATE, RELU, MAX and MIN is guaranteed "
-                   "to be supported on Elemwise for quantized DType");
+                   "to be supported on Elemwise for quantized DType, no support %d", (int)param.mode);
     }
 }
 
diff --git a/src/opr/impl/tensor_manip.cpp b/src/opr/impl/tensor_manip.cpp
index a350bf02..af3f5e65 100644
--- a/src/opr/impl/tensor_manip.cpp
+++ b/src/opr/impl/tensor_manip.cpp
@@ -1578,6 +1578,23 @@ MGB_IMPL_OPR_GRAD(ParamPackSplit) {
 
 // f}}}
 
 /* f{{{ ======================= RelayoutFormat ======================= */
+namespace mgb {
+namespace opr {
+namespace intl {
+template <>
+struct MegDNNOprInitPostCtor<RelayoutFormat> {
+    static void apply(cg::OperatorNodeBase& opr) {
+        if (opr.config().output_dtype().valid()) {
+            opr.output(0)->dtype(opr.config().output_dtype());
+        } else {
+            opr.output(0)->dtype(opr.input(0)->dtype());
+        }
+    }
+};
+} // namespace intl
+} // namespace opr
+} // namespace mgb
+
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(RelayoutFormat);
 MEGDNN_OPR_INIT1(RelayoutFormat, "relayout_format")
diff --git a/test/src/helper.cpp b/test/src/helper.cpp
index e3ef4b02..fe111bca 100644
--- a/test/src/helper.cpp
+++ b/test/src/helper.cpp
@@ -190,6 +190,24 @@ namespace mgb {
         }
         return ret;
     }
+
+    std::shared_ptr<HostTensorND>
+    HostTensorGenerator<dtype::Quantized8Asymm, RandomDistribution::UNIFORM>::
+    operator()(const TensorShape& shape, CompNode cn) {
+        if (!cn.valid())
+            cn = CompNode::load("xpu0");
+        auto dtype = dtype::Quantized8Asymm(m_scale, m_zero_point);
+        auto param = dtype.param();
+        std::shared_ptr<HostTensorND> ret =
+                std::make_shared<HostTensorND>(cn, shape, dtype);
+        auto ptr = ret->ptr<dt_quint8>();
+        double scale = (param.dequantize(m_hi) - param.dequantize(m_lo)) /
+                       (m_rng.max() + 1.0);
+        for (size_t i = 0, it = shape.total_nr_elems(); i < it; ++i) {
+            ptr[i] = param.quantize(m_rng() * scale + param.dequantize(m_lo));
+        }
+        return ret;
+    }
 }
 
 ::testing::AssertionResult mgb::__assert_float_equal(
diff --git a/test/src/include/megbrain/test/helper.h b/test/src/include/megbrain/test/helper.h
index 47de1c75..b9b34745 100644
--- a/test/src/include/megbrain/test/helper.h
+++ b/test/src/include/megbrain/test/helper.h
@@ -264,6 +264,10 @@ struct UniformRNGDefaultRange<dtype::QuantizedS8> {
     static const dt_qint8 LO, HI;
 };
+template<>
+struct UniformRNGDefaultRange<dtype::Quantized8Asymm> {
+    static const dt_quint8 LO, HI;
+};
 
 //! gaussian
 template<class dtype>
 class HostTensorGenerator<dtype, RandomDistribution::GAUSSIAN> final:
@@ -404,6 +408,33 @@ class HostTensorGenerator final
     ctype m_lo, m_hi;
 };
 
+template <>
+class HostTensorGenerator<dtype::Quantized8Asymm, RandomDistribution::UNIFORM>
+        final : public HostTensorGeneratorBase {
+public:
+    using ctype = typename DTypeTrait<dtype::Quantized8Asymm>::ctype;
+
+    HostTensorGenerator(
+            ctype lo = UniformRNGDefaultRange<dtype::Quantized8Asymm>::LO,
+            ctype hi = UniformRNGDefaultRange<dtype::Quantized8Asymm>::HI,
+            float scale = 1.f, uint8_t zero_point = 0,
+            uint64_t seed = next_rand_seed())
+            : HostTensorGeneratorBase{seed},
+              m_scale{scale},
+              m_zero_point(zero_point),
+              m_lo{lo},
+              m_hi{hi} {}
+
+    std::shared_ptr<HostTensorND> operator()(const TensorShape& shape,
+                                             CompNode cn = {}) override;
+    using HostTensorGeneratorBase::operator();
+
+private:
+    float m_scale;
+    uint8_t m_zero_point;
+    ctype m_lo, m_hi;
+};
+
 /*!
  * \brief get output file name in test output dir
  * \param check_writable whether to ensure the file is writable