GitOrigin-RevId: adc2301203
tags/v1.7.0
@@ -121,7 +121,10 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
     auto&& param = args.opr->param();
     bool is_format_ok = param.format == param::ConvBias::Format::NCHW;
     bool is_version_ok = CUDNN_VERSION >= 7500;
-    bool is_dtype_ok = args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8;
+    bool is_dtype_ok =
+            (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
+             (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 &&
+              args.dst_layout->dtype.enumv() != DTypeEnum::Quantized4Asymm));
     bool is_bias_ok =
             args.bias_layout->ndim == 0 ||
             (args.bias_layout->ndim == 4 && args.bias_layout->shape[0] == 1 &&
@@ -31,6 +31,11 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(
         }
     }
+    if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
+        args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) {
+        return false;
+    }
     // FIXME: cudnn cannot handle the case when the initial value of dst tensor
     // contains nan and beta is zero, because the result of 0.f * nan is still
     // nan
@@ -24,6 +24,11 @@ bool ConvBiasForwardImpl::AlgoMatmul8x8x32::is_available(
     if (!is_compute_capability_required(6, 1))
         return false;
+    if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm ||
+        args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) {
+        return false;
+    }
     auto dst_layout = *args.dst_layout;
     if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
        dst_layout.dtype = DType();
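The three hunks above add the same guard: none of these CUDA conv-bias algorithms (the NCHW QS8 fallback, the cuDNN convolution, and the 8x8x32 matmul) can produce a 4-bit quantized output, so their is_available checks must reject a destination layout whose dtype is QuantizedS4 or Quantized4Asymm. Below is a minimal standalone sketch of that predicate; the toy enum only mirrors the relevant DTypeEnum values for illustration and is not the MegDNN type.

#include <cassert>

enum class ToyDType { QuantizedS8, QuantizedS4, Quantized4Asymm, Float32 };

// Both 4-bit enumerators must be excluded, hence the conjunction: joining the
// two inequalities with "||" would be a tautology that never rejects anything.
constexpr bool dst_dtype_supported(ToyDType dst) {
    return dst != ToyDType::QuantizedS4 && dst != ToyDType::Quantized4Asymm;
}

int main() {
    assert(dst_dtype_supported(ToyDType::QuantizedS8));
    assert(dst_dtype_supported(ToyDType::Float32));
    assert(!dst_dtype_supported(ToyDType::QuantizedS4));
    assert(!dst_dtype_supported(ToyDType::Quantized4Asymm));
    return 0;
}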
@@ -0,0 +1,151 @@
+/**
+ * \file src/gopt/impl/folding_conv_typecvt.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "megbrain/gopt/inference.h"
+#include "megbrain/opr/basic_arith.h"
+#include "megbrain/opr/dnn/convolution.h"
+#include "megbrain/opr/tensor_manip.h"
+#include "megbrain/opr/utility.h"
+#include "megbrain/serialization/opr_shallow_copy.h"
+#include "megdnn/opr_param_defs.h"
+#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+#include "megbrain/utils/hash_ct.h"
+#include "midout.h"
+#include "megbrain/gopt/reformat_manager.h"
+
+#if CUDA_VERSION >= 10020
+
+MIDOUT_DECL(megbrain_folding_conv_typecvt)
+#define MIDOUT_B(tag) \
+    MIDOUT_BEGIN(megbrain_folding_conv_typecvt, midout_iv(MGB_HASH_STR(tag))) {
+#define MIDOUT_E \
+    }            \
+    MIDOUT_END();
+
+using namespace mgb;
+using namespace gopt;
+
+using ReformatKey = ReformatManager::ReformatKey;
+
+/* ==================== FoldingConvBiasTypecvtPass ================= */
+const char* FoldingConvBiasTypecvtPass::name() const {
+    return mgb_cstr_log("folding conv bias typecvt pass");
+}
+
+void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
+    MIDOUT_B("FoldingConvBiasTypecvtPass::apply");
+    using DepType = cg::OperatorNodeProp::DepType;
+    ThinHashMap<OperatorNodeBase*,
+                SmallVector<std::pair<OperatorNodeBase*, DepType>>>
+            readers;
+    static const ThinHashSet<Typeinfo*> opr_type_list = {
+            opr::TypeCvt::typeinfo(), opr::ConvBias::typeinfo()};
+    opt.graph().iter([&readers](OperatorNodeBase* opr) {
+        for (auto&& i : opr->node_prop().dep_map()) {
+            if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) {
+                readers[i.first->owner_opr()].emplace_back(opr, i.second);
+            }
+        }
+    });
+
+    auto rewriter = opt.graph().make_rewriter();
+    auto try_conv_typecvt = [&rewriter, &readers](OperatorNodeBase* opr) {
+        ThinHashSet<OperatorNodeBase*> opr_set;
+        ThinHashSet<OperatorNodeBase*> reader_set;
+        // check typecvt
+        auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
+        if (typecvt == nullptr)
+            return false;
+        auto inp_dtype_typecvt = typecvt->input(0)->dtype(),
+             out_dtype_typecvt = typecvt->output(0)->dtype();
+        bool is_s82f32 = inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8 &&
+                         out_dtype_typecvt.enumv() == DTypeEnum::Float32;
+        bool is_s82s4 =
+                inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8 &&
+                (out_dtype_typecvt.enumv() == DTypeEnum::QuantizedS4 ||
+                 out_dtype_typecvt.enumv() == DTypeEnum::Quantized4Asymm);
+        bool is_s42s8 =
+                (inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS4 ||
+                 inp_dtype_typecvt.enumv() == DTypeEnum::Quantized4Asymm) &&
+                out_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8;
+        if (!(is_s82f32 || is_s82s4 || is_s42s8))
+            return false;
+        opr_set.insert(opr);
+
+        // check conv bias
+        auto conv_bias =
+                try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr());
+        if (conv_bias == nullptr)
+            return false;
+        auto inp_dtype_conv = conv_bias->input(0)->dtype(),
+             out_dtype_conv = conv_bias->output(0)->dtype();
+        bool is_s8nhwc = inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
+                         out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
+                         conv_bias->param().format ==
+                                 megdnn::param::ConvBias::Format::NHWC;
+        bool is_s4nhwc =
+                (inp_dtype_conv.enumv() == DTypeEnum::QuantizedS4 ||
+                 inp_dtype_conv.enumv() == DTypeEnum::Quantized4Asymm) &&
+                out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
+                conv_bias->param().format ==
+                        megdnn::param::ConvBias::Format::NHWC;
+        if (!(is_s8nhwc || is_s4nhwc))
+            return false;
+        if (conv_bias->input().size() != 3)
+            return false;
+        opr_set.insert(conv_bias);
+        // the conv_bias result may only be consumed by the typecvt being
+        // folded; any other reader still needs the original output dtype
+        for (auto&& i : readers[conv_bias]) {
+            if (i.second & DepType::DEV_VALUE) {
+                reader_set.insert(i.first);
+            }
+        }
+        for (auto reader : reader_set) {
+            if (opr_set.count(reader) <= 0) {
+                return false;
+            }
+        }
+
+        // fold the typecvt by letting conv_bias emit its output dtype
+        // directly; a float output additionally needs the bias in float32
+        auto src = rewriter.get_var(conv_bias->input(0)),
+             filter = rewriter.get_var(conv_bias->input(1)),
+             bias = rewriter.get_var(conv_bias->input(2));
+        auto new_bias =
+                (out_dtype_typecvt.enumv() == DTypeEnum::Float32)
+                        ? opr::TypeCvt::make(bias, dtype::Float32()).node()
+                        : bias;
+        auto new_param = conv_bias->param();
+        new_param.format = megdnn::param::ConvBias::Format::NHWC;
+        auto conv_bias_typecvt = opr::ConvBias::make(
+                src, filter, new_bias, new_param, conv_bias->execution_policy(),
+                OperatorNodeConfig{out_dtype_typecvt});
+        rewriter.replace_var(opr->output(0), conv_bias_typecvt.node(),
+                             mgb_cstr_log("replace conv_bias(NHWC) + typecvt "
+                                          "to conv_bias(NHWC)"));
+        return true;
+    };
+
+    auto on_opr = [&try_conv_typecvt, &rewriter](OperatorNodeBase* opr) {
+        if (!try_conv_typecvt(opr)) {
+            rewriter.auto_replace_outputs(opr);
+        }
+    };
+    opt.graph().iter(on_opr);
+    rewriter.apply_inplace();
+
+    MIDOUT_E
+}
+
+#endif
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
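The pass above looks for an NHWC-format ConvBias whose only consumer is a TypeCvt performing one of the supported conversions (int8 to float32, int8 to 4-bit, or 4-bit to int8) and folds the cast into the convolution by rebuilding the ConvBias with the TypeCvt's output dtype, converting the bias to Float32 first when the new output is float. Below is a minimal before/after sketch of the rewritten pattern, using the same factory calls as the pass; x, w and b are assumed to be SymbolVars already set up with suitable NHWC quantized layouts, and the helper names and quantization parameters (scale 1.f, zero point 8) are placeholders, not part of the pass.

#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/convolution.h"

namespace {
using namespace mgb;

// graph as written by the user or earlier passes: an int8 conv_bias followed
// by a standalone TypeCvt that produces the 4-bit result
SymbolVar build_conv_then_cast(SymbolVar x, SymbolVar w, SymbolVar b) {
    opr::ConvBias::Param param;
    param.format = megdnn::param::ConvBias::Format::NHWC;
    auto conv = opr::ConvBias::make(
            x, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8(1.f)});
    return opr::TypeCvt::make(conv, dtype::Quantized4Asymm(1.f, 8));
}

// graph after FoldingConvBiasTypecvtPass: the cast is absorbed into the
// conv_bias output dtype, so the kernel writes the 4-bit tensor directly
SymbolVar build_folded(SymbolVar x, SymbolVar w, SymbolVar b) {
    opr::ConvBias::Param param;
    param.format = megdnn::param::ConvBias::Format::NHWC;
    return opr::ConvBias::make(
            x, w, b, param, {},
            OperatorNodeConfig{dtype::Quantized4Asymm(1.f, 8)});
}
}  // anonymous namespace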
@@ -835,6 +835,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
         add_pass<FuseWarpPerspectiveDimshufflePass>();
 #if CUDA_VERSION >= 10020
         add_pass<FoldingConvBiasDimshufflePass>();
+        add_pass<FoldingConvBiasTypecvtPass>();
 #endif
     });
 #undef cb
@@ -57,7 +57,10 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
             TensorFormats::NCHW, TensorFormats::NHWC,
             TensorFormats::NCHWc4, TensorFormats::NCHWc32,
             TensorFormats::NCHWc64, TensorFormats::CHWNc4};
-    Attribute attribute = {base_opr_format, base_tensor_format, Target::CUDA};
+    Attribute attribute = {
+            base_opr_format, base_tensor_format, Target::CUDA,
+            LayoutTransformContext::ReformatAttribute::AUTO_PADDING_NHWC};
     auto ctx = std::make_unique<LayoutTransformContext>(
             std::move(opr_list), std::move(available_tensor_formats),
             attribute);
@@ -67,8 +70,9 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
                              OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
             .add_opr_config(opr::ConvolutionForward::typeinfo(),
                             {OprFormat::NCHW, OprFormat::NCHW4})
-            .add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
-                            {OprFormat::NCHW, OprFormat::NCHW4})
+            .add_opr_config(
+                    opr::ConvolutionBackwardData::typeinfo(),
+                    {OprFormat::NCHW, OprFormat::NCHW4, OprFormat::NHWC})
             .add_opr_config(
                     opr::PoolingForward::typeinfo(),
                     {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
@@ -512,7 +512,7 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData,
         const auto& conv = opr->cast_final_safe<Opr>();
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
-        config.opr_format = OprFormat::NCHW4;
+        config.opr_format = OprFormat::NHWC;
         bool available = true;
         for (size_t i = 0; i < opr->input().size(); ++i) {
             available &=
@@ -481,6 +481,12 @@ namespace gopt {
     const char* name() const override;
     void apply(OptState& opt) const override;
 };
+
+class FoldingConvBiasTypecvtPass final : public Pass {
+public:
+    const char* name() const override;
+    void apply(OptState& opt) const override;
+};
 #endif
 
 /*!
@@ -585,6 +585,7 @@ TEST(TestLayoutTransform, DetectionHead) {
     using OprFormat = LayoutTransformContext::OprFormat;
     using OprList = LayoutTransformContext::OprList;
     using Attribute = LayoutTransformContext::Attribute;
+    using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
     using Target = LayoutTransformContext::Target;
     OprList opr_list = {
             opr::ConvBiasForward::typeinfo(),
@@ -600,8 +601,8 @@ TEST(TestLayoutTransform, DetectionHead) {
             TensorFormats::NCHW, TensorFormats::NHWC,
             TensorFormats::NCHWc4, TensorFormats::NCHWc32,
             TensorFormats::NCHWc64, TensorFormats::CHWNc4};
-    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
-                           Target::UNSPEC};
+    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
+                           ReformatAttribute::AUTO_PADDING_NHWC};
     auto ctx = std::make_unique<LayoutTransformContext>(
             std::move(opr_list), std::move(available_tensor_formats),
             attribute);
@@ -611,8 +612,9 @@ TEST(TestLayoutTransform, DetectionHead) {
                              OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
             .add_opr_config(opr::ConvolutionForward::typeinfo(),
                             {OprFormat::NCHW, OprFormat::NCHW4})
-            .add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
-                            {OprFormat::NCHW, OprFormat::NCHW4})
+            .add_opr_config(
+                    opr::ConvolutionBackwardData::typeinfo(),
+                    {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4})
             .add_opr_config(
                     opr::PoolingForward::typeinfo(),
                     {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
@@ -630,6 +632,7 @@ TEST(TestLayoutTransform, DetectionHead) {
             .add_pass<ShuffleShuffleRemovePass>()
             .add_pass(FuseNCHW4Int8Preprocess::make())
             .add_pass<FoldingConvBiasDimshufflePass>()
+            .add_pass<FoldingConvBiasTypecvtPass>()
             .add_pass<ParamFusePass>()
             .add_pass<ParamMergePass>()
             .apply(SymbolVarArray{y})
@@ -656,7 +659,8 @@ TEST(TestLayoutTransform, DetectionHead) {
     /// check first conv format
     const auto& first_conv = find_opr<opr::ConvBiasForward>(v);
     const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>();
-    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW4_NHWC);
+    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NHWC);
+    ASSERT_EQ(cast.output()[0]->dtype().enumv(), DTypeEnum::Quantized4Asymm);
 }
 #endif
 #endif