GitOrigin-RevId: 60942aca5b
tags/v1.0.0-rc1
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/internal/opr_header_prologue.h" | |||
| @@ -314,8 +315,10 @@ public: | |||
| /** | |||
| * \param[in] src (n, ic, ih, iw) or (n, ih, iw, ic) | |||
| * \param[in] filter (oc, ic, fh, fw) or (oc, fh, fw, ic) or (oc/4, fh, fw, | |||
| * 4*ic) \param[in] bias (1, oc, 1, 1) \param[in] z same as dst \param[out] | |||
| * dst (n, oc, oh, ow) or (n, oh, ow, oc) | |||
| * 4 * ic) | |||
| * \param[in] bias (1, oc, 1, 1) | |||
| * \param[in] z same as dst | |||
| * \param[out] dst (n, oc, oh, ow) or (n, oh, ow, oc) | |||
| * | |||
| * \note if the format is NCHW_WINOGRAD, the filter layout is (alphah, | |||
| * alphaw, oc, ic) | |||
| @@ -407,6 +410,26 @@ public: | |||
| */ | |||
| static WinogradParam parse_winograd_name(const std::string& algo_name); | |||
| /** | |||
| * @brief check whether an nchw_nchwxx conv kernel optimized for the given | |||
| * arguments exists; nchw44 is used for arm, nchw88 is used for x86 | |||
| * | |||
| * @param src_dtype conv feature map data type | |||
| * @param filter_dtype conv filter or weight data type | |||
| * @param dst_dtype output data type | |||
| * @param fm filter meta param | |||
| * @param bias_mode bias mode, no_bias or broadcast or bias | |||
| * @param nonline_mode identity or relu or h_swish or sigmoid | |||
| * @return true if a kernel was found | |||
| * @return false if no kernel could be found | |||
| */ | |||
| static bool is_nchw_nchwxx_optimized( | |||
| const DTypeEnum src_dtype, const DTypeEnum filter_dtype, | |||
| const DTypeEnum dst_dtype, | |||
| const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm, | |||
| const ConvBiasForward::BiasMode bias_mode, | |||
| const param::ConvBias::NonlineMode nonline_mode); | |||
| protected: | |||
| CanonizedFilterMeta check_exec( | |||
| const TensorLayout& src, const TensorLayout& filter, | |||
| @@ -16,10 +16,10 @@ | |||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | |||
| #include "src/arm_common/conv_bias/fp32/strategy.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/common/nchw_nchwxx_valid.h" | |||
| #include "src/common/opr_delegate.h" | |||
| #include "midout.h" | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| using conv_fun = std::function<void( | |||
| @@ -191,22 +191,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||
| bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable( | |||
| const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||
| auto&& fm = param.filter_meta; | |||
| auto fh = fm.spatial[0]; | |||
| int oc = fm.ocpg; | |||
| bool ok_type = ((param.src_type.enumv() == DTypeEnum::Float32 && | |||
| param.filter_type.enumv() == DTypeEnum::Float32 && | |||
| (param.dst_type.enumv() == DTypeEnum::Float32))) && | |||
| (fm.format == param::Convolution::Format::NCHW44); | |||
| bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1; | |||
| bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && | |||
| (fh == 2 || fh == 3 || fh == 5 || fh == 7); | |||
| bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
| fm.stride[0] == fm.stride[1] && | |||
| (fm.stride[0] == 1 || fm.stride[0] == 2); | |||
| bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS; | |||
| bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv; | |||
| return avaible; | |||
| return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_FP32>( | |||
| param.src_type.enumv(), param.filter_type.enumv(), | |||
| param.dst_type.enumv(), param.filter_meta, param.bias_mode, | |||
| param.nonlineMode); | |||
| } | |||
| size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace( | |||
| @@ -15,6 +15,7 @@ | |||
| #include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" | |||
| #include "src/arm_common/conv_bias/int8/strategy.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/common/nchw_nchwxx_valid.h" | |||
| #include "src/common/opr_delegate.h" | |||
| #include "midout.h" | |||
| @@ -214,26 +215,12 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||
| ow, op); | |||
| } | |||
| bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable( | |||
| const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy algo_selection_strategy) const { | |||
| MEGDNN_MARK_USED_VAR(algo_selection_strategy); | |||
| auto&& fm = param.filter_meta; | |||
| auto FH = fm.spatial[0]; | |||
| auto OC = fm.ocpg; | |||
| bool avaible = //! src and filter are qint8, dst is qint8 | |||
| fm.icpg < 4 && // must be nchw input | |||
| ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
| param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||
| (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) && | |||
| (fm.format == param::Convolution::Format::NCHW44) && | |||
| (OC % 4 == 0 && OC >= 4) && !fm.should_flip && fm.group == 1 && | |||
| fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
| fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && | |||
| (fm.stride[0] == 1 || fm.stride[0] == 2) && FH == fm.spatial[1] && | |||
| (FH == 2 || FH == 3 || FH == 5 || FH == 7) && fm.group == 1 && | |||
| param.bias_mode != BiasMode::BIAS; | |||
| return avaible; | |||
| bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable(const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy) const { | |||
| return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8>( | |||
| param.src_type.enumv(), param.filter_type.enumv(), | |||
| param.dst_type.enumv(), param.filter_meta, param.bias_mode, | |||
| param.nonlineMode); | |||
| } | |||
| bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred( | |||
| @@ -16,6 +16,7 @@ | |||
| #include "src/arm_common/conv_bias/int8/algos.h" | |||
| #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/common/nchw_nchwxx_valid.h" | |||
| #include "midout.h" | |||
| @@ -174,23 +175,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||
| bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable( | |||
| const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||
| auto&& fm = param.filter_meta; | |||
| auto fh = fm.spatial[0]; | |||
| int oc = fm.ocpg; | |||
| int ic = fm.icpg; | |||
| bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
| param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||
| (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) && | |||
| (fm.format == param::Convolution::Format::NCHW44_DOT); | |||
| bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4); | |||
| bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && | |||
| (fh == 2 || fh == 3 || fh == 5 || fh == 7); | |||
| bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
| fm.stride[0] == fm.stride[1] && | |||
| (fm.stride[0] == 1 || fm.stride[0] == 2); | |||
| bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS; | |||
| bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv; | |||
| return avaible; | |||
| return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>( | |||
| param.src_type.enumv(), param.filter_type.enumv(), | |||
| param.dst_type.enumv(), param.filter_meta, param.bias_mode, | |||
| param.nonlineMode); | |||
| } | |||
| size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace( | |||
| @@ -16,6 +16,7 @@ | |||
| #include "src/arm_common/conv_bias/int8x8x16/algos.h" | |||
| #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/common/nchw_nchwxx_valid.h" | |||
| #include "src/common/opr_delegate.h" | |||
| #include "midout.h" | |||
| @@ -220,23 +221,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||
| bool ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::usable( | |||
| const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||
| auto&& fm = param.filter_meta; | |||
| auto fh = fm.spatial[0]; | |||
| int oc = fm.ocpg; | |||
| bool ok_type = ((param.src_type.enumv() == DTypeEnum::Int8 && | |||
| param.filter_type.enumv() == DTypeEnum::Int8 && | |||
| (param.dst_type.enumv() == DTypeEnum::Int16))) && | |||
| (fm.format == param::Convolution::Format::NCHW44); | |||
| bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1; | |||
| bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && | |||
| (fh == 2 || fh == 3 || fh == 5 || fh == 7); | |||
| bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
| fm.stride[0] == fm.stride[1] && | |||
| (fm.stride[0] == 2 || fm.stride[0] == 1); | |||
| bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS && | |||
| param.nonlineMode == param::ConvBias::NonlineMode::IDENTITY; | |||
| bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv; | |||
| return avaible; | |||
| return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_INT8_INT16>( | |||
| param.src_type.enumv(), param.filter_type.enumv(), | |||
| param.dst_type.enumv(), param.filter_meta, param.bias_mode, | |||
| param.nonlineMode); | |||
| } | |||
| size_t ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::get_workspace( | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * \file dnn/src/common/nchw_nchwxx_valid.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megdnn/oprs/nn.h" | |||
| #include "src/common/nchw_nchwxx_valid.h" | |||
| using namespace megdnn; | |||
| /** | |||
|  * Report whether any of the known nchw -> nchwxx kernel families accepts the | |||
|  * described convolution. | |||
|  * | |||
|  * The five nchw_nchwxx_valid<> specializations are queried in a fixed order | |||
|  * (NCHW44 fp32, NCHW44 int8, NCHW44 int8x8x16, NCHW44 int8 dot, NCHW88) and | |||
|  * the evaluation stops at the first one that returns true. The checks are | |||
|  * called directly instead of going through a global table of std::function | |||
|  * objects, so no dynamic initialization or type erasure is involved. | |||
|  * | |||
|  * \param src_dtype dtype of the input feature map | |||
|  * \param filter_dtype dtype of the filter | |||
|  * \param dst_dtype dtype of the output | |||
|  * \param fm canonized filter meta describing the convolution | |||
|  * \param bias_mode bias mode of the convolution | |||
|  * \param nonline_mode nonlinearity applied after the convolution | |||
|  * \return true if at least one check accepts the arguments, false otherwise | |||
|  */ | |||
| bool ConvBiasForward::is_nchw_nchwxx_optimized( | |||
|         const DTypeEnum src_dtype, const DTypeEnum filter_dtype, | |||
|         const DTypeEnum dst_dtype, | |||
|         const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm, | |||
|         const ConvBiasForward::BiasMode bias_mode, | |||
|         const param::ConvBias::NonlineMode nonline_mode) { | |||
|     return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_FP32>( | |||
|                    src_dtype, filter_dtype, dst_dtype, fm, bias_mode, | |||
|                    nonline_mode) || | |||
|            nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8>( | |||
|                    src_dtype, filter_dtype, dst_dtype, fm, bias_mode, | |||
|                    nonline_mode) || | |||
|            nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_INT8_INT16>( | |||
|                    src_dtype, filter_dtype, dst_dtype, fm, bias_mode, | |||
|                    nonline_mode) || | |||
|            nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>( | |||
|                    src_dtype, filter_dtype, dst_dtype, fm, bias_mode, | |||
|                    nonline_mode) || | |||
|            nchw_nchwxx_valid<NchwNchwxxType::NCHW88>( | |||
|                    src_dtype, filter_dtype, dst_dtype, fm, bias_mode, | |||
|                    nonline_mode); | |||
| } | |||
| @@ -0,0 +1,161 @@ | |||
| /** | |||
| * \file dnn/src/common/nchw_nchwxx_valid.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| #include "src/fallback/conv_bias/opr_impl.h" | |||
| namespace megdnn { | |||
| namespace { | |||
| //! Selects which nchw_nchwxx_valid<> specialization (i.e. which kernel | |||
| //! family's constraints) is checked. | |||
| enum NchwNchwxxType { | |||
| //! float32 src/filter/dst, filter format NCHW44 | |||
| NCHW44_FP32, | |||
| //! QuantizedS8 src/filter/dst, filter format NCHW44 | |||
| NCHW44_INT8, | |||
| //! Int8 src/filter with Int16 dst, filter format NCHW44 | |||
| NCHW44_INT8_INT8_INT16, | |||
| //! QuantizedS8 src/filter/dst, filter format NCHW44_DOT | |||
| NCHW44_INT8_DOT, | |||
| //! float32 src/filter/dst, filter format NCHW88 | |||
| NCHW88, | |||
| }; | |||
| /** | |||
|  * Primary template of the nchw -> nchwxx validity check. It is only | |||
|  * declared here; the behaviour comes from the explicit specializations, one | |||
|  * per NchwNchwxxType value. | |||
|  * | |||
|  * \tparam T kernel family whose constraints are checked | |||
|  * \param src_dtype dtype of the input feature map | |||
|  * \param filter_dtype dtype of the filter | |||
|  * \param dst_dtype dtype of the output | |||
|  * \param fm canonized filter meta (format, channels, kernel size, stride, | |||
|  * dilation, flip flag) | |||
|  * \param bias_mode bias mode of the convolution | |||
|  * \param nonline_mode nonlinearity applied after the convolution | |||
|  * \return true if the specialization for T accepts the arguments | |||
|  */ | |||
| template <NchwNchwxxType T> | |||
| static inline bool nchw_nchwxx_valid( | |||
| const DTypeEnum src_dtype, const DTypeEnum filter_dtype, | |||
| const DTypeEnum dst_dtype, | |||
| const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm, | |||
| const BiasMode bias_mode, | |||
| const param::ConvBias::NonlineMode nonline_mode); | |||
| /** | |||
|  * Validity check for the float32 NCHW44 kernels. | |||
|  * | |||
|  * Accepts only: float32 src/filter/dst with NCHW44 filter format; identity, | |||
|  * relu or h_swish nonlinearity; a single group with fewer than 4 input | |||
|  * channels and a positive multiple of 4 output channels; a square 2-D filter | |||
|  * of size 2, 3, 5 or 7; no dilation; equal strides of 1 or 2; no filter | |||
|  * flip; any bias mode except BiasMode::BIAS. | |||
|  */ | |||
| template <> | |||
| inline bool nchw_nchwxx_valid<NCHW44_FP32>( | |||
|         const DTypeEnum src_dtype, const DTypeEnum filter_dtype, | |||
|         const DTypeEnum dst_dtype, | |||
|         const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm, | |||
|         const BiasMode bias_mode, | |||
|         const param::ConvBias::NonlineMode nonline_mode) { | |||
|     using Nonline = param::ConvBias::NonlineMode; | |||
|     //! dtypes and layout | |||
|     if (src_dtype != DTypeEnum::Float32 || | |||
|         filter_dtype != DTypeEnum::Float32 || | |||
|         dst_dtype != DTypeEnum::Float32 || | |||
|         fm.format != param::Convolution::Format::NCHW44) { | |||
|         return false; | |||
|     } | |||
|     //! supported nonlinearities | |||
|     if (nonline_mode != Nonline::IDENTITY && nonline_mode != Nonline::RELU && | |||
|         nonline_mode != Nonline::H_SWISH) { | |||
|         return false; | |||
|     } | |||
|     //! channel layout: dense conv, ic < 4, oc a positive multiple of 4 | |||
|     if (fm.group != 1 || !(fm.icpg < 4) || !(fm.ocpg >= 4) || | |||
|         fm.ocpg % 4 != 0) { | |||
|         return false; | |||
|     } | |||
|     //! square 2-D filter of a supported size | |||
|     if (fm.spatial_ndim != 2 || fm.spatial[0] != fm.spatial[1]) { | |||
|         return false; | |||
|     } | |||
|     const auto filter_size = fm.spatial[0]; | |||
|     if (filter_size != 2 && filter_size != 3 && filter_size != 5 && | |||
|         filter_size != 7) { | |||
|         return false; | |||
|     } | |||
|     //! no dilation, equal strides of 1 or 2 | |||
|     if (fm.dilation[0] != 1 || fm.dilation[1] != 1 || | |||
|         fm.stride[0] != fm.stride[1]) { | |||
|         return false; | |||
|     } | |||
|     if (fm.stride[0] != 1 && fm.stride[0] != 2) { | |||
|         return false; | |||
|     } | |||
|     //! cross-correlation only, and no full (per-element) bias | |||
|     return !fm.should_flip && bias_mode != BiasMode::BIAS; | |||
| } | |||
| /** | |||
|  * Validity check for the quantized int8 NCHW44 kernels. | |||
|  * | |||
|  * Accepts only: QuantizedS8 src/filter/dst with NCHW44 filter format; | |||
|  * identity, relu or h_swish nonlinearity; a single group with fewer than 4 | |||
|  * input channels and a positive multiple of 4 output channels; a square 2-D | |||
|  * filter of size 2, 3, 5 or 7; no dilation; equal strides of 1 or 2; no | |||
|  * filter flip; any bias mode except BiasMode::BIAS. | |||
|  */ | |||
| template <> | |||
| inline bool nchw_nchwxx_valid<NCHW44_INT8>( | |||
|         const DTypeEnum src_dtype, const DTypeEnum filter_dtype, | |||
|         const DTypeEnum dst_dtype, | |||
|         const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm, | |||
|         const BiasMode bias_mode, | |||
|         const param::ConvBias::NonlineMode nonline_mode) { | |||
|     using Nonline = param::ConvBias::NonlineMode; | |||
|     //! dtypes and layout | |||
|     if (src_dtype != DTypeEnum::QuantizedS8 || | |||
|         filter_dtype != DTypeEnum::QuantizedS8 || | |||
|         dst_dtype != DTypeEnum::QuantizedS8 || | |||
|         fm.format != param::Convolution::Format::NCHW44) { | |||
|         return false; | |||
|     } | |||
|     //! supported nonlinearities | |||
|     if (nonline_mode != Nonline::IDENTITY && nonline_mode != Nonline::RELU && | |||
|         nonline_mode != Nonline::H_SWISH) { | |||
|         return false; | |||
|     } | |||
|     //! channel layout: dense conv, ic < 4, oc a positive multiple of 4 | |||
|     if (fm.group != 1 || !(fm.icpg < 4) || !(fm.ocpg >= 4) || | |||
|         fm.ocpg % 4 != 0) { | |||
|         return false; | |||
|     } | |||
|     //! square 2-D filter of a supported size | |||
|     if (fm.spatial_ndim != 2 || fm.spatial[0] != fm.spatial[1]) { | |||
|         return false; | |||
|     } | |||
|     const auto filter_size = fm.spatial[0]; | |||
|     if (filter_size != 2 && filter_size != 3 && filter_size != 5 && | |||
|         filter_size != 7) { | |||
|         return false; | |||
|     } | |||
|     //! no dilation, equal strides of 1 or 2 | |||
|     if (fm.dilation[0] != 1 || fm.dilation[1] != 1 || | |||
|         fm.stride[0] != fm.stride[1]) { | |||
|         return false; | |||
|     } | |||
|     if (fm.stride[0] != 1 && fm.stride[0] != 2) { | |||
|         return false; | |||
|     } | |||
|     //! cross-correlation only, and no full (per-element) bias | |||
|     return !fm.should_flip && bias_mode != BiasMode::BIAS; | |||
| } | |||
| /** | |||
|  * Validity check for the int8 x int8 -> int16 NCHW44 kernels. | |||
|  * | |||
|  * Accepts only: Int8 src/filter with Int16 dst and NCHW44 filter format; | |||
|  * identity nonlinearity; a single group with fewer than 4 input channels and | |||
|  * a positive multiple of 4 output channels; a square 2-D filter of size 2, | |||
|  * 3, 5 or 7; no dilation; equal strides of 1 or 2; no filter flip; any bias | |||
|  * mode except BiasMode::BIAS. | |||
|  */ | |||
| template <> | |||
| inline bool nchw_nchwxx_valid<NCHW44_INT8_INT8_INT16>( | |||
|         const DTypeEnum src_dtype, const DTypeEnum filter_dtype, | |||
|         const DTypeEnum dst_dtype, | |||
|         const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm, | |||
|         const BiasMode bias_mode, | |||
|         const param::ConvBias::NonlineMode nonline_mode) { | |||
|     //! dtypes and layout | |||
|     if (src_dtype != DTypeEnum::Int8 || filter_dtype != DTypeEnum::Int8 || | |||
|         dst_dtype != DTypeEnum::Int16 || | |||
|         fm.format != param::Convolution::Format::NCHW44) { | |||
|         return false; | |||
|     } | |||
|     //! only the identity nonlinearity is accepted | |||
|     if (nonline_mode != param::ConvBias::NonlineMode::IDENTITY) { | |||
|         return false; | |||
|     } | |||
|     //! channel layout: dense conv, ic < 4, oc a positive multiple of 4 | |||
|     if (fm.group != 1 || !(fm.icpg < 4) || !(fm.ocpg >= 4) || | |||
|         fm.ocpg % 4 != 0) { | |||
|         return false; | |||
|     } | |||
|     //! square 2-D filter of a supported size | |||
|     if (fm.spatial_ndim != 2 || fm.spatial[0] != fm.spatial[1]) { | |||
|         return false; | |||
|     } | |||
|     const auto filter_size = fm.spatial[0]; | |||
|     if (filter_size != 2 && filter_size != 3 && filter_size != 5 && | |||
|         filter_size != 7) { | |||
|         return false; | |||
|     } | |||
|     //! no dilation, equal strides of 1 or 2 | |||
|     if (fm.dilation[0] != 1 || fm.dilation[1] != 1 || | |||
|         fm.stride[0] != fm.stride[1]) { | |||
|         return false; | |||
|     } | |||
|     if (fm.stride[0] != 2 && fm.stride[0] != 1) { | |||
|         return false; | |||
|     } | |||
|     //! cross-correlation only, and no full (per-element) bias | |||
|     return !fm.should_flip && bias_mode != BiasMode::BIAS; | |||
| } | |||
| /** | |||
|  * Validity check for the quantized int8 NCHW44_DOT kernels. | |||
|  * | |||
|  * Accepts only: QuantizedS8 src/filter/dst with NCHW44_DOT filter format; | |||
|  * identity, relu or h_swish nonlinearity; a single group with fewer than 4 | |||
|  * input channels and a positive multiple of 4 output channels; a square 2-D | |||
|  * filter of size 2, 3, 5 or 7; no dilation; equal strides of 1 or 2; no | |||
|  * filter flip; any bias mode except BiasMode::BIAS. | |||
|  */ | |||
| template <> | |||
| inline bool nchw_nchwxx_valid<NCHW44_INT8_DOT>( | |||
|         const DTypeEnum src_dtype, const DTypeEnum filter_dtype, | |||
|         const DTypeEnum dst_dtype, | |||
|         const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm, | |||
|         const BiasMode bias_mode, | |||
|         const param::ConvBias::NonlineMode nonline_mode) { | |||
|     using Nonline = param::ConvBias::NonlineMode; | |||
|     //! dtypes and layout | |||
|     if (src_dtype != DTypeEnum::QuantizedS8 || | |||
|         filter_dtype != DTypeEnum::QuantizedS8 || | |||
|         dst_dtype != DTypeEnum::QuantizedS8 || | |||
|         fm.format != param::Convolution::Format::NCHW44_DOT) { | |||
|         return false; | |||
|     } | |||
|     //! supported nonlinearities | |||
|     if (nonline_mode != Nonline::IDENTITY && nonline_mode != Nonline::RELU && | |||
|         nonline_mode != Nonline::H_SWISH) { | |||
|         return false; | |||
|     } | |||
|     //! channel layout: dense conv, ic < 4, oc a positive multiple of 4 | |||
|     if (fm.group != 1 || !(fm.icpg < 4) || !(fm.ocpg >= 4) || | |||
|         fm.ocpg % 4 != 0) { | |||
|         return false; | |||
|     } | |||
|     //! square 2-D filter of a supported size | |||
|     if (fm.spatial_ndim != 2 || fm.spatial[0] != fm.spatial[1]) { | |||
|         return false; | |||
|     } | |||
|     const auto filter_size = fm.spatial[0]; | |||
|     if (filter_size != 2 && filter_size != 3 && filter_size != 5 && | |||
|         filter_size != 7) { | |||
|         return false; | |||
|     } | |||
|     //! no dilation, equal strides of 1 or 2 | |||
|     if (fm.dilation[0] != 1 || fm.dilation[1] != 1 || | |||
|         fm.stride[0] != fm.stride[1]) { | |||
|         return false; | |||
|     } | |||
|     if (fm.stride[0] != 1 && fm.stride[0] != 2) { | |||
|         return false; | |||
|     } | |||
|     //! cross-correlation only, and no full (per-element) bias | |||
|     return !fm.should_flip && bias_mode != BiasMode::BIAS; | |||
| } | |||
| /** | |||
|  * Validity check for the float32 NCHW88 kernels. | |||
|  * | |||
|  * Accepts only: float32 src/filter/dst with NCHW88 filter format; a single | |||
|  * group with fewer than 8 input channels and a positive multiple of 8 output | |||
|  * channels; a 2-D filter without dilation; no filter flip; any bias mode | |||
|  * except BiasMode::BIAS. Filter size, stride and nonlinearity are not | |||
|  * restricted here; \p nonline_mode is accepted but not inspected. | |||
|  * | |||
|  * The spatial_ndim == 2 requirement matches the other specializations and | |||
|  * makes sure dilation[1] is only consulted for 2-D filters. | |||
|  */ | |||
| template <> | |||
| inline bool nchw_nchwxx_valid<NCHW88>( | |||
|         const DTypeEnum src_dtype, const DTypeEnum filter_dtype, | |||
|         const DTypeEnum dst_dtype, | |||
|         const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm, | |||
|         const BiasMode bias_mode, | |||
|         const param::ConvBias::NonlineMode nonline_mode) { | |||
|     //! accepted for signature parity with the other specializations | |||
|     (void)nonline_mode; | |||
|     bool ok_type = ((src_dtype == DTypeEnum::Float32 && | |||
|                      filter_dtype == DTypeEnum::Float32 && | |||
|                      (dst_dtype == DTypeEnum::Float32))) && | |||
|                    (fm.format == param::Convolution::Format::NCHW88); | |||
|     bool ok_src_dst = | |||
|             fm.icpg < 8 && (fm.ocpg % 8 == 0 && fm.ocpg >= 8) && fm.group == 1; | |||
|     bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS; | |||
|     //! only 2-D filters are supported; check the rank before the dilation | |||
|     bool ok_slide = fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
|                     fm.dilation[1] == 1; | |||
|     bool available = ok_type && ok_src_dst && ok_slide && ok_conv; | |||
|     return available; | |||
| } | |||
| } // namespace | |||
| } // namespace megdnn | |||
| @@ -11,6 +11,7 @@ | |||
| */ | |||
| #pragma once | |||
| #include "src/common/nchw_nchwxx_valid.h" | |||
| #include "src/x86/conv_bias/opr_impl.h" | |||
| using namespace megdnn; | |||
| @@ -29,6 +30,7 @@ class ConvBiasImpl::AlgoDirect final : public AlgoBase { | |||
| const NCBKernParam& kern_param, | |||
| const NCBKernIndex& ncb_index, | |||
| const CpuNDRange& workspace_ids); | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { | |||
| @@ -61,6 +63,7 @@ class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase { | |||
| const NCBKernParam& kern_param, | |||
| const NCBKernIndex& ncb_index, | |||
| const CpuNDRange& workspace_ids); | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { | |||
| @@ -163,13 +166,19 @@ public: | |||
| AlgoSelectionStrategy) const override { | |||
| auto&& fm = param.filter_meta; | |||
| bool ok = (fm.format == param::ConvBias::Format::NCHW88) && | |||
| fm.spatial_ndim == 2 && | |||
| param.src_type.enumv() == DTypeEnum::Float32 && | |||
| param.filter_type.enumv() == DTypeEnum::Float32 && | |||
| param.dst_type.enumv() == DTypeEnum::Float32 && | |||
| fm.dilation[0] == 1 && fm.dilation[1] == 1; | |||
| return ok; | |||
| bool nchw_nchw88_ok = nchw_nchwxx_valid<NchwNchwxxType::NCHW88>( | |||
| param.src_type.enumv(), param.filter_type.enumv(), | |||
| param.dst_type.enumv(), param.filter_meta, param.bias_mode, | |||
| param.nonlineMode); | |||
| bool normal_conv_ok = (fm.format == param::ConvBias::Format::NCHW88) && | |||
| fm.spatial_ndim == 2 && | |||
| param.src_type.enumv() == DTypeEnum::Float32 && | |||
| param.filter_type.enumv() == DTypeEnum::Float32 && | |||
| param.dst_type.enumv() == DTypeEnum::Float32 && | |||
| fm.dilation[0] == 1 && fm.dilation[1] == 1; | |||
| return nchw_nchw88_ok || normal_conv_ok; | |||
| }; | |||
| size_t get_workspace(const NCBKernSizeParam&) const override { return 0; } | |||
| @@ -1816,155 +1816,67 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var, | |||
| } | |||
| template <typename OprType> | |||
| static inline bool nchw_nchwxx_valid(const OprType& opr, | |||
| const VarNodeArray& new_inp, | |||
| const size_t pack_size, bool is_dense, | |||
| bool is_dot = false); | |||
| template <> | |||
| inline bool nchw_nchwxx_valid<opr::ConvolutionForward>( | |||
| const opr::ConvolutionForward& opr, const VarNodeArray& new_inp, | |||
| const size_t pack_size, bool is_dense, bool is_dot) { | |||
| auto& filter_shape = new_inp[1]->shape(); | |||
| auto filter_dtype = new_inp[1]->dtype(); | |||
| bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 || | |||
| filter_dtype.enumv() == DTypeEnum::Int8; | |||
| const size_t oc = filter_shape[0]; | |||
| const size_t ic = filter_shape[1]; | |||
| bool is_like_nchw_nchwxx = | |||
| is_dense && oc % pack_size == 0 && ic < pack_size; | |||
| if (!is_like_nchw_nchwxx) { | |||
| static inline bool nchw_nchwxx_valid( | |||
| const OprType& opr, const VarNodeArray& new_inp, const size_t pack_size, | |||
| megdnn::param::ConvBias::NonlineMode nonline_mode = | |||
| megdnn::param::ConvBias::NonlineMode::IDENTITY, | |||
| bool is_dot = false) { | |||
| auto& src_node = new_inp[0]; | |||
| auto& filter_node = new_inp[1]; | |||
| auto dst_node = opr.output(0); | |||
| if (filter_node->shape().ndim != 4) { | |||
| return false; | |||
| } | |||
| SmallVector<TensorLayout> layouts; | |||
| //! src | |||
| layouts.push_back( | |||
| {new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()}); | |||
| //! weight | |||
| layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2], | |||
| filter_shape[3], filter_shape[1], pack_size}, | |||
| new_inp[1]->dtype(), | |||
| new_inp[1]->format()}); | |||
| auto out0 = opr.output(0); | |||
| auto& out_shape = out0->shape(); | |||
| //! FIXME: return false if oc is invalid | |||
| layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2], | |||
| out_shape[3], pack_size}, | |||
| out0->dtype(), | |||
| out0->format()}); | |||
| auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node()) | |||
| ->create_operator<megdnn::ConvolutionForward>(); | |||
| megdnn_conv.get()->param() = opr.param(); | |||
| //! set by dtype | |||
| switch (pack_size) { | |||
| case 4: | |||
| if (is_dot && is_int8) { | |||
| megdnn_conv.get()->param().format = | |||
| megdnn::param::Convolution::Format::NCHW44_DOT; | |||
| } else { | |||
| megdnn_conv.get()->param().format = | |||
| megdnn::param::Convolution::Format::NCHW44; | |||
| } | |||
| break; | |||
| case 8: | |||
| megdnn_conv.get()->param().format = | |||
| megdnn::param::Convolution::Format::NCHW88; | |||
| break; | |||
| default: | |||
| break; | |||
| } | |||
| bool find_valid_algo = false; | |||
| auto algos = megdnn_conv.get()->get_all_algorithms(layouts[0], layouts[1], | |||
| layouts[2]); | |||
| for (auto i : algos) { | |||
| if (i->type() != nullptr) { | |||
| find_valid_algo = true; | |||
| megdnn::ConvolutionBase<megdnn::param::Convolution>::CanonizedFilterMeta fm; | |||
| fm.format = megdnn::param::Convolution::Format::NCHW; | |||
| fm.should_flip = | |||
| opr.param().mode == megdnn::ConvBiasForward::Mode::CONVOLUTION; | |||
| fm.group = 1; | |||
| fm.spatial_ndim = 2; | |||
| fm.ocpg = filter_node->shape()[0]; | |||
| fm.icpg = filter_node->shape()[1]; | |||
| fm.spatial[0] = filter_node->shape()[2]; | |||
| fm.spatial[1] = filter_node->shape()[3]; | |||
| fm.stride[0] = opr.param().stride_h; | |||
| fm.stride[1] = opr.param().stride_w; | |||
| fm.padding[0] = opr.param().pad_h; | |||
| fm.padding[1] = opr.param().pad_w; | |||
| fm.dilation[0] = opr.param().dilate_h; | |||
| fm.dilation[1] = opr.param().dilate_w; | |||
| megdnn::ConvBiasForward::BiasMode bias_mode = | |||
| megdnn::ConvBiasForward::BiasMode::NO_BIAS; | |||
| //! Derive the bias mode from the bias input. Only ConvBias has one, and it | |||
| //! is optional: the size guard keeps a bias-less ConvBias (src + filter | |||
| //! only) on NO_BIAS instead of indexing past the end of new_inp. | |||
| if (std::is_same<OprType, opr::ConvBiasForward>::value && | |||
|     new_inp.size() > 2) { | |||
|     auto& bias_shape = new_inp[2]->shape(); | |||
|     if (bias_shape.ndim == 0) { | |||
|         bias_mode = megdnn::ConvBiasForward::BiasMode::NO_BIAS; | |||
|     } else if (bias_shape.eq_shape(dst_node->shape())) { | |||
|         //! bias has the full output shape | |||
|         bias_mode = megdnn::ConvBiasForward::BiasMode::BIAS; | |||
|     } else { | |||
|         //! just check the ndim, the detail shape check is in check_exec | |||
|         mgb_assert(bias_shape.ndim == dst_node->shape().ndim); | |||
|         bias_mode = | |||
|                 megdnn::ConvBiasForward::BiasMode::BROADCAST_CHANNEL_BIAS; | |||
|     } | |||
| } | |||
| return find_valid_algo; | |||
| } | |||
| template <> | |||
| inline bool nchw_nchwxx_valid<opr::ConvBiasForward>( | |||
| const opr::ConvBiasForward& opr, const VarNodeArray& new_inp, | |||
| const size_t pack_size, bool is_dense, bool is_dot) { | |||
| auto& filter_shape = new_inp[1]->shape(); | |||
| auto filter_dtype = new_inp[1]->dtype(); | |||
| bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 || | |||
| filter_dtype.enumv() == DTypeEnum::Int8; | |||
| const size_t oc = filter_shape[0]; | |||
| const size_t ic = filter_shape[1]; | |||
| bool is_like_nchw_nchwxx = | |||
| is_dense && oc % pack_size == 0 && ic < pack_size; | |||
| if (!is_like_nchw_nchwxx) { | |||
| return false; | |||
| } | |||
| SmallVector<TensorLayout> layouts; | |||
| //! src | |||
| layouts.push_back( | |||
| {new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()}); | |||
| //! weight | |||
| layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2], | |||
| filter_shape[3], filter_shape[1], pack_size}, | |||
| new_inp[1]->dtype(), | |||
| new_inp[1]->format()}); | |||
| auto& bias_shape = new_inp[2]->shape(); | |||
| layouts.push_back({{bias_shape[0], bias_shape[1] / pack_size, bias_shape[2], | |||
| bias_shape[3], pack_size}, | |||
| new_inp[2]->dtype(), | |||
| new_inp[2]->format()}); | |||
| auto out0 = opr.output(0); | |||
| auto& out_shape = out0->shape(); | |||
| //! FIXME: return false if oc is invalid | |||
| layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2], | |||
| out_shape[3], pack_size}, | |||
| out0->dtype(), | |||
| out0->format()}); | |||
| // megdnn::ConvolutionForward | |||
| auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node()) | |||
| ->create_operator<megdnn::ConvBiasForward>(); | |||
| megdnn_conv.get()->param() = opr.param(); | |||
| //! FIXME: set by dtype | |||
| switch (pack_size) { | |||
| case 4: | |||
| if (is_dot && is_int8) { | |||
| megdnn_conv.get()->param().format = | |||
| megdnn::param::Convolution::Format::NCHW44_DOT; | |||
| } else { | |||
| megdnn_conv.get()->param().format = | |||
| megdnn::param::Convolution::Format::NCHW44; | |||
| } | |||
| break; | |||
| case 8: | |||
| megdnn_conv.get()->param().format = | |||
| megdnn::param::Convolution::Format::NCHW88; | |||
| break; | |||
| default: | |||
| break; | |||
| } | |||
| bool find_valid_algo = false; | |||
| auto algos = megdnn_conv.get()->get_all_algorithms( | |||
| layouts[0], layouts[1], layouts[2], {}, layouts[3]); | |||
| for (auto i : algos) { | |||
| if (i->type() != nullptr) { | |||
| find_valid_algo = true; | |||
| if (pack_size == 4) { | |||
| if (is_dot && filter_node->dtype().enumv() == DTypeEnum::QuantizedS8) { | |||
| fm.format = megdnn::param::Convolution::Format::NCHW44_DOT; | |||
| } else { | |||
| fm.format = megdnn::param::Convolution::Format::NCHW44; | |||
| } | |||
| } else if (pack_size == 8) { | |||
| fm.format = megdnn::param::Convolution::Format::NCHW88; | |||
| } else { | |||
| mgb_assert(0, "only support nchw44 nchw88"); | |||
| } | |||
| return find_valid_algo; | |||
| return megdnn::ConvBiasForward::is_nchw_nchwxx_optimized( | |||
| src_node->dtype().enumv(), filter_node->dtype().enumv(), | |||
| dst_node->dtype().enumv(), fm, bias_mode, nonline_mode); | |||
| } | |||
| void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
| using RelayoutMode = RelayoutPlaceholder::LayoutType; | |||
| using TestFilterResult = std::pair<TransType, RelayoutMode>; | |||
| @@ -1984,19 +1896,6 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
| megdnn::param::Pooling::Format pooling_format = | |||
| megdnn::param::Pooling::Format::NCHW88; | |||
| std::string convter_pass_name = "conv_format_nchw88"; | |||
| #if MEGDNN_AARCH64 || MEGDNN_ARMv7 | |||
| if (pack_c_size == 8) { | |||
| mgb_log_error( | |||
| "runtime backend is ARM, but nchw88 only support X86, you may " | |||
| "have performance loss\n"); | |||
| } | |||
| #elif MEGDNN_X86 | |||
| if (pack_c_size == 4) { | |||
| mgb_log_error( | |||
| "runtime backend is X86, but nchw44 only support arm, you may " | |||
| "have performance loss\n"); | |||
| } | |||
| #endif | |||
| if (pack_c_size == 4) { | |||
| weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE; | |||
| @@ -2053,10 +1952,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
| mgb_assert(conv_opr.param().format == | |||
| megdnn::param::Convolution::Format::NCHW, | |||
| "ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||
| bool is_dense = conv_opr.param().sparse == | |||
| megdnn::param::Convolution::Sparse::DENSE; | |||
| bool valid_nchw_nchw44 = | |||
| nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense); | |||
| nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size); | |||
| auto is_trans = test_trans_nchwxx( | |||
| conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h, | |||
| conv_opr.param().stride_w, valid_nchw_nchw44); | |||
| @@ -2133,10 +2030,9 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
| mgb_assert(conv_bias_opr.param().format == | |||
| megdnn::param::ConvBias::Format::NCHW, | |||
| "ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||
| bool is_dense = conv_bias_opr.param().sparse == | |||
| megdnn::param::Convolution::Sparse::DENSE; | |||
| bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp, | |||
| pack_c_size, is_dense); | |||
| bool valid_nchw_nchw44 = | |||
| nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size, | |||
| conv_bias_opr.param().nonlineMode); | |||
| auto is_trans = test_trans_nchwxx( | |||
| conv_bias_opr.param().sparse, new_inp[1], | |||
| conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w, | |||
| @@ -2371,13 +2267,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||
| MIDOUT_B("EnableNchw44DotPass::make") | |||
| auto ret = std::make_unique<EnableNchw44DotPass>(); | |||
| ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); | |||
| //! First is whether the conv can trans to nchwxx, second is the filter | |||
| //! trans mode | |||
| #if MEGDNN_X86 | |||
| mgb_log_error( | |||
| "backend is X86, but nchw44_dot only support arm, you may have " | |||
| "performance loss\n"); | |||
| #endif | |||
| //! First is whether the conv can trans to nchwxx, second is the filter | |||
| //! trans mode | |||
| using RelayoutMode = RelayoutPlaceholder::LayoutType; | |||
| struct TestTransResult { | |||
| @@ -2453,14 +2344,12 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||
| megdnn::param::Convolution::Format::NCHW, | |||
| "ConvertFormat Pass only support converting NCHW to " | |||
| "NCHW44_DOT"); | |||
| bool is_dense = conv_opr.param().sparse == | |||
| megdnn::param::Convolution::Sparse::DENSE; | |||
| bool valid_nchw_nchw44 = | |||
| nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense); | |||
| bool valid_nchw_nchw44 = nchw_nchwxx_valid( | |||
| conv_opr, new_inp, pack_c_size, | |||
| megdnn::param::ConvBias::NonlineMode::IDENTITY, true); | |||
| auto is_trans = test_trans_nchw44_dot( | |||
| conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h, | |||
| conv_opr.param().stride_w, valid_nchw_nchw44); | |||
| //! can not trans to nchwxx | |||
| if (is_trans.trans_type == TransType::TRANS_NONE) { | |||
| mgb_assert(new_inp[1]->shape().ndim == 4 || | |||
| @@ -2533,10 +2422,9 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||
| mgb_assert(conv_bias_opr.param().format == | |||
| megdnn::param::ConvBias::Format::NCHW, | |||
| "ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||
| bool is_dense = conv_bias_opr.param().sparse == | |||
| megdnn::param::Convolution::Sparse::DENSE; | |||
| bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp, | |||
| pack_c_size, is_dense); | |||
| bool valid_nchw_nchw44 = | |||
| nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size, | |||
| conv_bias_opr.param().nonlineMode, true); | |||
| auto is_trans = test_trans_nchw44_dot( | |||
| conv_bias_opr.param().sparse, new_inp[1], | |||
| conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w, | |||
| @@ -2913,7 +2913,8 @@ TEST(TestGoptInference, ConvertFormatNCHW88) { | |||
| opr::Convolution::Param param_conv; | |||
| param_conv.pad_h = param_conv.pad_w = 1; | |||
| auto w1 = mkcvar("w1", {8, 3, 3, 3}), | |||
| conv1 = opr::Convolution::make(x, w1, param_conv); | |||
| conv1 = opr::Convolution::make(x, w1, param_conv, {}, | |||
| OperatorNodeConfig("conv1")); | |||
| //! channel wise | |||
| opr::ConvBias::Param param_conv_bias; | |||
| param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | |||
| @@ -2954,7 +2955,8 @@ TEST(TestGoptInference, ConvertFormatNCHW88) { | |||
| options.enable_nchw88(); | |||
| unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
| } | |||
| ASSERT_EQ(opr::ConvBias::Param::Format::NCHW88, | |||
| find_opr<opr::Convolution>(y_opt, "conv1").param().format); | |||
| ASSERT_EQ(opr::ConvBias::Param::Format::NCHW88, | |||
| find_opr<opr::ConvBias>(y_opt).param().format); | |||
| @@ -3084,13 +3086,8 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { | |||
| options.enable_nchw44(); | |||
| unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
| find_opr<opr::Convolution>(y_opt, "conv1").param().format); | |||
| #else | |||
| ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
| find_opr<opr::Convolution>(y_opt, "conv1").param().format); | |||
| #endif | |||
| ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
| find_opr<opr::ConvBias>(y_opt, "conv1_f1").param().format); | |||
| ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
| @@ -3325,17 +3322,10 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { | |||
| options.enable_nchw44_dot(); | |||
| unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
| find_opr<opr::Convolution>(y_opt, "conv1").param().format); | |||
| ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT, | |||
| find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format); | |||
| #else | |||
| ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
| find_opr<opr::Convolution>(y_opt, "conv1").param().format); | |||
| ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
| find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format); | |||
| #endif | |||
| ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
| find_opr<opr::ConvBias>(y_opt, "conv1_f1").param().format); | |||
| ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT, | |||
| @@ -611,11 +611,11 @@ public: | |||
| "%s: input shapes (%s %s, %s %s) -> (%s %s): algo=%s " | |||
| "workspace=%.2fMiB reproducible=%d", | |||
| mgb_opr->dyn_typeinfo()->name, | |||
| layouts[0].TensorShape::to_string().c_str(), | |||
| layouts[0].to_string().c_str(), | |||
| layouts[0].dtype.name(), | |||
| layouts[1].TensorShape::to_string().c_str(), | |||
| layouts[1].to_string().c_str(), | |||
| layouts[1].dtype.name(), | |||
| layouts[layouts.size() - 1].TensorShape::to_string().c_str(), | |||
| layouts[layouts.size() - 1].to_string().c_str(), | |||
| layouts[layouts.size() - 1].dtype.name(), | |||
| algo->name(), | |||
| workspace / (1024 * 1024.0), algo->is_reproducible()); | |||