GitOrigin-RevId: adc2301203
@@ -121,7 +121,10 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
     auto&& param = args.opr->param();
     bool is_format_ok = param.format == param::ConvBias::Format::NCHW;
     bool is_version_ok = CUDNN_VERSION >= 7500;
-    bool is_dtype_ok = args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8;
+    bool is_dtype_ok =
+            (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
+             (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 &&
+              args.dst_layout->dtype.enumv() != DTypeEnum::Quantized4Asymm));
     bool is_bias_ok =
             args.bias_layout->ndim == 0 ||
             (args.bias_layout->ndim == 4 && args.bias_layout->shape[0] == 1 &&
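Note on the dtype check above: the two inequalities must be joined with `&&`; with `||` the inner clause is always true, since no dtype can equal both QuantizedS4 and Quantized4Asymm at once, and the 4-bit destinations this fallback cannot handle would never be rejected. A minimal standalone sketch of the intended predicate (the enum and the `dst_dtype_supported` helper are illustrative stand-ins, not MegEngine API):

    #include <cassert>

    // Illustrative stand-in for megdnn's DTypeEnum.
    enum class DTypeEnum { QuantizedS8, QuantizedS4, Quantized4Asymm };

    // The fallback handles QuantizedS8 sources but no 4-bit destinations:
    // by De Morgan, !(dst == S4 || dst == 4Asymm) is (dst != S4 && dst != 4Asymm).
    bool dst_dtype_supported(DTypeEnum src, DTypeEnum dst) {
        return src == DTypeEnum::QuantizedS8 &&
               dst != DTypeEnum::QuantizedS4 &&
               dst != DTypeEnum::Quantized4Asymm;
    }

    int main() {
        assert(dst_dtype_supported(DTypeEnum::QuantizedS8, DTypeEnum::QuantizedS8));
        assert(!dst_dtype_supported(DTypeEnum::QuantizedS8, DTypeEnum::QuantizedS4));
        assert(!dst_dtype_supported(DTypeEnum::QuantizedS8, DTypeEnum::Quantized4Asymm));
    }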
@@ -31,6 +31,11 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(
         }
     }

+    if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
+        args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) {
+        return false;
+    }
+
     // FIXME: cudnn cannot handle the case when the initial value of dst tensor
     // contains nan and beta is zero, because the result of 0.f * nan is still
     // nan
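The FIXME above describes standard IEEE-754 behavior: `0.f * NaN` is NaN, not zero, so cuDNN's `beta * dst` term cannot wipe out garbage NaNs in an uninitialized dst tensor even when beta is zero. A small self-contained demonstration:

    #include <cmath>
    #include <cstdio>

    int main() {
        float dst_garbage = std::nanf("");  // worst-case uninitialized dst element
        float beta = 0.f;
        // IEEE-754 multiplication propagates NaN regardless of the zero factor.
        float r = beta * dst_garbage;
        std::printf("0.f * NaN is NaN: %d\n", std::isnan(r));  // prints 1
    }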
@@ -24,6 +24,11 @@ bool ConvBiasForwardImpl::AlgoMatmul8x8x32::is_available(
     if (!is_compute_capability_required(6, 1))
         return false;

+    if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm ||
+        args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) {
+        return false;
+    }
+
     auto dst_layout = *args.dst_layout;
     if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
         dst_layout.dtype = DType();
@@ -0,0 +1,151 @@
+/**
+ * \file src/gopt/impl/folding_conv_typecvt.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
| #include "megbrain/gopt/inference.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/opr/utility.h" | |||
| #include "megbrain/serialization/opr_shallow_copy.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
| #include "megbrain/utils/hash_ct.h" | |||
| #include "midout.h" | |||
| #include "megbrain/gopt/reformat_manager.h" | |||
+
+#if CUDA_VERSION >= 10020
+
+MIDOUT_DECL(megbrain_folding_conv_typecvt)
+#define MIDOUT_B(tag) \
+    MIDOUT_BEGIN(megbrain_folding_conv_typecvt, midout_iv(MGB_HASH_STR(tag))) {
+#define MIDOUT_E \
+    }            \
+    MIDOUT_END();
+
+using namespace mgb;
+using namespace gopt;
+using ReformatKey = ReformatManager::ReformatKey;
+
+/* ==================== FoldingConvBiasTypecvtPass ================= */
+const char* FoldingConvBiasTypecvtPass::name() const {
+    return mgb_cstr_log("folding conv bias typecvt pass");
+}
+
+void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
+    MIDOUT_B("FoldingConvBiasTypecvtPass::apply");
+    using DepType = cg::OperatorNodeProp::DepType;
+    ThinHashMap<OperatorNodeBase*,
+                SmallVector<std::pair<OperatorNodeBase*, DepType>>>
+            readers;
+    static const ThinHashSet<Typeinfo*> opr_type_list = {
+            opr::TypeCvt::typeinfo(), opr::ConvBias::typeinfo()};
+    opt.graph().iter([&readers](OperatorNodeBase* opr) {
+        for (auto&& i : opr->node_prop().dep_map()) {
+            if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) {
+                readers[i.first->owner_opr()].emplace_back(opr, i.second);
+            }
+        }
+    });
+
+    auto rewriter = opt.graph().make_rewriter();
+    auto try_conv_typecvt = [&rewriter, &readers](OperatorNodeBase* opr) {
+        ThinHashSet<OperatorNodeBase*> opr_set;
+        ThinHashSet<OperatorNodeBase*> reader_set;
+        // check typecvt
+        auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
+        if (typecvt == nullptr)
+            return false;
+        auto inp_dtype_typecvt = typecvt->input(0)->dtype(),
+             out_dtype_typecvt = typecvt->output(0)->dtype();
+        bool is_s82f32 = inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8 &&
+                         out_dtype_typecvt.enumv() == DTypeEnum::Float32;
+        bool is_s82s4 =
+                inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8 &&
+                (out_dtype_typecvt.enumv() == DTypeEnum::QuantizedS4 ||
+                 out_dtype_typecvt.enumv() == DTypeEnum::Quantized4Asymm);
+        bool is_s42s8 =
+                (inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS4 ||
+                 inp_dtype_typecvt.enumv() == DTypeEnum::Quantized4Asymm) &&
+                out_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8;
+        if (!(is_s82f32 || is_s82s4 || is_s42s8))
+            return false;
+        opr_set.insert(opr);
+
+        // check conv bias
+        auto conv_bias =
+                try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr());
+        if (conv_bias == nullptr)
+            return false;
+        auto inp_dtype_conv = conv_bias->input(0)->dtype(),
+             out_dtype_conv = conv_bias->output(0)->dtype();
+        bool is_s8nhwc = inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
+                         out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
+                         conv_bias->param().format ==
+                                 megdnn::param::ConvBias::Format::NHWC;
+        bool is_s4nhwc =
+                (inp_dtype_conv.enumv() == DTypeEnum::QuantizedS4 ||
+                 inp_dtype_conv.enumv() == DTypeEnum::Quantized4Asymm) &&
+                out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
+                conv_bias->param().format ==
+                        megdnn::param::ConvBias::Format::NHWC;
+        if (!(is_s8nhwc || is_s4nhwc))
+            return false;
+        if (conv_bias->input().size() != 3)
+            return false;
+        opr_set.insert(conv_bias);
+        for (auto&& i : readers[conv_bias]) {
+            if (i.second & DepType::DEV_VALUE) {
+                reader_set.insert(i.first);
+            }
+        }
+        for (auto reader : reader_set) {
+            if (opr_set.count(reader) <= 0) {
+                return false;
+            }
+        }
+
+        auto src = rewriter.get_var(conv_bias->input(0)),
+             filter = rewriter.get_var(conv_bias->input(1)),
+             bias = rewriter.get_var(conv_bias->input(2));
+        auto new_bias =
+                (out_dtype_typecvt.enumv() == DTypeEnum::Float32)
+                        ? opr::TypeCvt::make(bias, dtype::Float32()).node()
+                        : bias;
+        auto new_param = conv_bias->param();
+        new_param.format = megdnn::param::ConvBias::Format::NHWC;
+        auto conv_bias_typecvt = opr::ConvBias::make(
+                src, filter, new_bias, new_param, conv_bias->execution_policy(),
+                OperatorNodeConfig{out_dtype_typecvt});
+        rewriter.replace_var(opr->output(0), conv_bias_typecvt.node(),
+                             mgb_cstr_log("replace conv_bias(NHWC) + typecvt "
+                                          "to conv_bias(NHWC)"));
+        return true;
+    };
+
+    auto on_opr = [&try_conv_typecvt, &rewriter](OperatorNodeBase* opr) {
+        if (!try_conv_typecvt(opr)) {
+            rewriter.auto_replace_outputs(opr);
+        }
+    };
+    opt.graph().iter(on_opr);
+    rewriter.apply_inplace();
+
+    MIDOUT_E
+}
+
+#endif
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
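For orientation, a hedged sketch of how the new pass is driven, following the GraphOptimizer usage in the test further below (the SymbolVar `y` and the surrounding network are assumed here, not part of this patch):

    // Assuming `y` is produced by conv_bias(NHWC, QuantizedS8) -> TypeCvt;
    // after the pass the TypeCvt is folded away and the conv_bias output
    // dtype becomes the TypeCvt's output dtype.
    SymbolVar y_opt = gopt::GraphOptimizer{}
                              .add_pass<gopt::FoldingConvBiasTypecvtPass>()
                              .apply(SymbolVarArray{y})
                              .endpoint_vars()[0];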
@@ -835,6 +835,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
         add_pass<FuseWarpPerspectiveDimshufflePass>();
 #if CUDA_VERSION >= 10020
         add_pass<FoldingConvBiasDimshufflePass>();
+        add_pass<FoldingConvBiasTypecvtPass>();
 #endif
     });
 #undef cb
@@ -57,7 +57,10 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
             TensorFormats::NCHW, TensorFormats::NHWC,
             TensorFormats::NCHWc4, TensorFormats::NCHWc32,
             TensorFormats::NCHWc64, TensorFormats::CHWNc4};
-    Attribute attribute = {base_opr_format, base_tensor_format, Target::CUDA};
+    Attribute attribute = {
+            base_opr_format, base_tensor_format, Target::CUDA,
+            LayoutTransformContext::ReformatAttribute::AUTO_PADDING_NHWC};
     auto ctx = std::make_unique<LayoutTransformContext>(
             std::move(opr_list), std::move(available_tensor_formats),
             attribute);
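The new `AUTO_PADDING_NHWC` attribute is meant to let the reformat manager pad the channel dimension when emitting NHWC layouts, so tensor-core kernels see an aligned channel count. A hedged sketch of that round-up, with illustrative alignment values that are assumptions rather than values taken from this patch:

    #include <cstddef>

    // Round a channel count up to a multiple of `alignment` (e.g. what an
    // NHWC int8/int4 kernel might require; the exact values are assumptions).
    constexpr std::size_t pad_channels(std::size_t c, std::size_t alignment) {
        return (c + alignment - 1) / alignment * alignment;
    }

    static_assert(pad_channels(3, 4) == 4, "3 input channels padded to 4");
    static_assert(pad_channels(24, 64) == 64, "int4 channels padded to 64");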
@@ -67,8 +70,9 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
                         OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
             .add_opr_config(opr::ConvolutionForward::typeinfo(),
                             {OprFormat::NCHW, OprFormat::NCHW4})
-            .add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
-                            {OprFormat::NCHW, OprFormat::NCHW4})
+            .add_opr_config(
+                    opr::ConvolutionBackwardData::typeinfo(),
+                    {OprFormat::NCHW, OprFormat::NCHW4, OprFormat::NHWC})
             .add_opr_config(
                     opr::PoolingForward::typeinfo(),
                     {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
@@ -512,7 +512,7 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData,
         const auto& conv = opr->cast_final_safe<Opr>();
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
-        config.opr_format = OprFormat::NCHW4;
+        config.opr_format = OprFormat::NHWC;
         bool available = true;
         for (size_t i = 0; i < opr->input().size(); ++i) {
             available &=
@@ -481,6 +481,12 @@ namespace gopt {
     const char* name() const override;
     void apply(OptState& opt) const override;
 };

+class FoldingConvBiasTypecvtPass final : public Pass {
+public:
+    const char* name() const override;
+    void apply(OptState& opt) const override;
+};
+
 #endif

 /*!
@@ -585,6 +585,7 @@ TEST(TestLayoutTransform, DetectionHead) {
     using OprFormat = LayoutTransformContext::OprFormat;
     using OprList = LayoutTransformContext::OprList;
     using Attribute = LayoutTransformContext::Attribute;
+    using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
     using Target = LayoutTransformContext::Target;
     OprList opr_list = {
             opr::ConvBiasForward::typeinfo(),
@@ -600,8 +601,8 @@ TEST(TestLayoutTransform, DetectionHead) {
             TensorFormats::NCHW, TensorFormats::NHWC,
             TensorFormats::NCHWc4, TensorFormats::NCHWc32,
             TensorFormats::NCHWc64, TensorFormats::CHWNc4};
-    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
-                           Target::UNSPEC};
+    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
+                           ReformatAttribute::AUTO_PADDING_NHWC};
     auto ctx = std::make_unique<LayoutTransformContext>(
             std::move(opr_list), std::move(available_tensor_formats),
             attribute);
@@ -611,8 +612,9 @@ TEST(TestLayoutTransform, DetectionHead) {
                         OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
             .add_opr_config(opr::ConvolutionForward::typeinfo(),
                             {OprFormat::NCHW, OprFormat::NCHW4})
-            .add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
-                            {OprFormat::NCHW, OprFormat::NCHW4})
+            .add_opr_config(
+                    opr::ConvolutionBackwardData::typeinfo(),
+                    {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4})
             .add_opr_config(
                     opr::PoolingForward::typeinfo(),
                     {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
@@ -630,6 +632,7 @@ TEST(TestLayoutTransform, DetectionHead) {
             .add_pass<ShuffleShuffleRemovePass>()
             .add_pass(FuseNCHW4Int8Preprocess::make())
             .add_pass<FoldingConvBiasDimshufflePass>()
+            .add_pass<FoldingConvBiasTypecvtPass>()
             .add_pass<ParamFusePass>()
             .add_pass<ParamMergePass>()
             .apply(SymbolVarArray{y})
@@ -656,7 +659,8 @@ TEST(TestLayoutTransform, DetectionHead) {
     /// check first conv format
     const auto& first_conv = find_opr<opr::ConvBiasForward>(v);
     const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>();
-    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW4_NHWC);
+    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NHWC);
+    ASSERT_EQ(cast.output()[0]->dtype().enumv(), DTypeEnum::Quantized4Asymm);
 }
 #endif
 #endif