GitOrigin-RevId: db50b33c11
tags/v1.7.0
| @@ -820,23 +820,26 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options( | |||||
| _passes need_param_fuse = true; \ | _passes need_param_fuse = true; \ | ||||
| } | } | ||||
| using Target = GraphTuningOptions::Target; | |||||
| cb(layout_transform, { | cb(layout_transform, { | ||||
| add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
| add_pass<FuseConvBiasZPass>(); | |||||
| if (options.target == Target::CUDA) | |||||
| add_pass<FuseConvBiasZPass>(); | |||||
| add_pass(LayoutTransformPass::make(options.target)); | add_pass(LayoutTransformPass::make(options.target)); | ||||
| add_pass<ShuffleShuffleRemovePass>(); | add_pass<ShuffleShuffleRemovePass>(); | ||||
| add_pass(FuseNCHW4Int8Preprocess::make()); | |||||
| add_pass<FuseWarpPerspectiveDimshufflePass>(); | |||||
| if (options.target == Target::CUDA) { | |||||
| add_pass(FuseNCHW4Int8Preprocess::make()); | |||||
| add_pass<FuseWarpPerspectiveDimshufflePass>(); | |||||
| #if CUDA_VERSION >= 10020 | #if CUDA_VERSION >= 10020 | ||||
| add_pass<FoldingConvBiasDimshufflePass>(); | |||||
| add_pass<FoldingConvBiasTypecvtPass>(); | |||||
| add_pass<FoldingConvBiasDimshufflePass>(); | |||||
| add_pass<FoldingConvBiasTypecvtPass>(); | |||||
| #endif | #endif | ||||
| } | |||||
| }); | }); | ||||
| #undef cb | #undef cb | ||||
| if (need_param_fuse) { | if (need_param_fuse) { | ||||
| add_pass<ParamFusePass>(); | add_pass<ParamFusePass>(); | ||||
| add_pass<ParamMergePass>(); | |||||
| } | } | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -15,6 +15,7 @@ | |||||
| #include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
| #include "megbrain/opr/imgproc.h" | #include "megbrain/opr/imgproc.h" | ||||
| #include "megbrain/opr/nn_int.h" | #include "megbrain/opr/nn_int.h" | ||||
| #include "megbrain/opr/tensor_manip.h" | |||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace gopt; | using namespace gopt; | ||||
| @@ -82,6 +83,44 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx( | |||||
| {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64}); | {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64}); | ||||
| return ctx; | return ctx; | ||||
| } | } | ||||
// Builds the LayoutTransformContext for Target::ARM.
//
// base_opr_format / base_tensor_format are stored, together with Target::ARM,
// in the context's Attribute. The returned context carries an operator list,
// the candidate tensor formats, and per-operator lists of OprFormats for
// ConvBiasForward, ConvolutionForward, PoolingForward and ResizeForward.
//
// NOTE(review): DNN_INC_FLOAT16 is defined outside this view. If it expands to
// nothing when float16 support is disabled, the uses that sit in the middle of
// a braced list (before OprFormat::NCHW44_DOT) would leave an empty list
// element — TODO confirm against the macro definition.
| std::unique_ptr<LayoutTransformContext> make_arm_ctx( | |||||
| OprFormat base_opr_format, TensorFormats base_tensor_format) { | |||||
// Operator types handed to the LayoutTransformContext constructor below.
| OprList opr_list = { | |||||
| opr::ConvBiasForward::typeinfo(), | |||||
| opr::ConvolutionForward::typeinfo(), | |||||
| opr::ElemwiseMultiType::typeinfo(), | |||||
| opr::Elemwise::typeinfo(), | |||||
| opr::TypeCvt::typeinfo(), | |||||
| opr::PoolingForward::typeinfo(), | |||||
| opr::Resize::typeinfo(), | |||||
| opr::PowC::typeinfo(), | |||||
| opr::Concat::typeinfo(), | |||||
| }; | |||||
// Candidate tensor formats: NCHW, NCHWc4, plus NCHWc8 wrapped in
// DNN_INC_FLOAT16.
| SmallVector<TensorFormats> available_tensor_formats = { | |||||
| TensorFormats::NCHW, TensorFormats::NCHWc4, | |||||
| DNN_INC_FLOAT16(TensorFormats::NCHWc8)}; | |||||
| Attribute attribute = {base_opr_format, base_tensor_format, Target::ARM}; | |||||
| auto ctx = std::make_unique<LayoutTransformContext>( | |||||
| std::move(opr_list), std::move(available_tensor_formats), | |||||
| attribute); | |||||
// Only the two convolution operators list NCHW44_DOT; pooling and resize are
// limited to NCHW, NCHW44 and (float16-guarded) NCHW88. The other operator
// types in opr_list get no explicit OprFormat list here.
| ctx->add_opr_config( | |||||
| opr::ConvBiasForward::typeinfo(), | |||||
| {OprFormat::NCHW, OprFormat::NCHW44, | |||||
| DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT}) | |||||
| .add_opr_config( | |||||
| opr::ConvolutionForward::typeinfo(), | |||||
| {OprFormat::NCHW, OprFormat::NCHW44, | |||||
| DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT}) | |||||
| .add_opr_config(opr::PoolingForward::typeinfo(), | |||||
| {OprFormat::NCHW, OprFormat::NCHW44, | |||||
| DNN_INC_FLOAT16(OprFormat::NCHW88)}) | |||||
| .add_opr_config(opr::ResizeForward::typeinfo(), | |||||
| {OprFormat::NCHW, OprFormat::NCHW44, | |||||
| DNN_INC_FLOAT16(OprFormat::NCHW88)}); | |||||
| return ctx; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| /* ================= LayoutTransformContext ==================*/ | /* ================= LayoutTransformContext ==================*/ | ||||
| @@ -110,6 +149,8 @@ std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make( | |||||
| switch (target) { | switch (target) { | ||||
| case Target::CUDA: | case Target::CUDA: | ||||
| return make_cuda_ctx(base_opr_format, base_tensor_format); | return make_cuda_ctx(base_opr_format, base_tensor_format); | ||||
| case Target::ARM: | |||||
| return make_arm_ctx(base_opr_format, base_tensor_format); | |||||
| default: | default: | ||||
| mgb_assert(false, "unsupported target %s\n", target_to_string(target)); | mgb_assert(false, "unsupported target %s\n", target_to_string(target)); | ||||
| } | } | ||||
| @@ -60,6 +60,7 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||||
| auto&& opr_configs = m_ctx->opr_configs(); | auto&& opr_configs = m_ctx->opr_configs(); | ||||
| auto&& base_fmt = m_ctx->attribute().base_tensor_formats; | auto&& base_fmt = m_ctx->attribute().base_tensor_formats; | ||||
| auto&& base_opr_fmt = m_ctx->attribute().base_opr_format; | |||||
| auto&& reformat_attribute = m_ctx->attribute().reformat_attribute; | auto&& reformat_attribute = m_ctx->attribute().reformat_attribute; | ||||
| ThinHashMap<VarNode*, TensorFormats> var2fmts; | ThinHashMap<VarNode*, TensorFormats> var2fmts; | ||||
| static ThinHashSet<Typeinfo*> format_aware_oprs = { | static ThinHashSet<Typeinfo*> format_aware_oprs = { | ||||
| @@ -68,15 +69,18 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||||
| #undef cb | #undef cb | ||||
| }; | }; | ||||
| auto rewriter = opt.graph().make_rewriter(); | auto rewriter = opt.graph().make_rewriter(); | ||||
| auto on_opr = [&opr_configs, &base_fmt, &reformat_attribute, &rewriter, &solution, | |||||
| &var2fmts, &endpoint_vars](OperatorNodeBase* opr) { | |||||
| auto on_opr = [&opr_configs, &base_fmt, &base_opr_fmt, &reformat_attribute, | |||||
| &rewriter, &solution, &var2fmts, | |||||
| &endpoint_vars](OperatorNodeBase* opr) { | |||||
| auto it = solution.find(opr); | auto it = solution.find(opr); | ||||
| if (it != solution.end()) { | if (it != solution.end()) { | ||||
| auto opr_fmt = it->second; | auto opr_fmt = it->second; | ||||
| auto find = opr_configs.find(opr->dyn_typeinfo()); | auto find = opr_configs.find(opr->dyn_typeinfo()); | ||||
| Maybe<OprTensorFormatsConfiguration> fmtcfg = None; | Maybe<OprTensorFormatsConfiguration> fmtcfg = None; | ||||
| Maybe<OprTensorFormatsConfiguration> basecfg = None; | |||||
| if (find != opr_configs.end()) { | if (find != opr_configs.end()) { | ||||
| fmtcfg = (*find->second.at(opr_fmt))(opr); | fmtcfg = (*find->second.at(opr_fmt))(opr); | ||||
| basecfg = (*find->second.at(base_opr_fmt))(opr); | |||||
| } | } | ||||
| VarNodeArray new_inp; | VarNodeArray new_inp; | ||||
| size_t nr_inps = opr->input().size(); | size_t nr_inps = opr->input().size(); | ||||
| @@ -103,6 +107,10 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||||
| bool is_parameter = | bool is_parameter = | ||||
| fmtcfg.valid() && | fmtcfg.valid() && | ||||
| fmtcfg.val().input_tensor_types[i] == TensorType::WEIGHT; | fmtcfg.val().input_tensor_types[i] == TensorType::WEIGHT; | ||||
| if (is_parameter) { | |||||
| mgb_assert(basecfg.valid()); | |||||
| from = basecfg.val().input_tensor_formats[i]; | |||||
| } | |||||
| // need relayout | // need relayout | ||||
| if (from != to && !new_var->shape().is_scalar()) { | if (from != to && !new_var->shape().is_scalar()) { | ||||
| ReformatManager::ReformatImpl reformat; | ReformatManager::ReformatImpl reformat; | ||||
| @@ -78,6 +78,48 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW> { | |||||
| } | } | ||||
| }; | }; | ||||
// NCHW44 specialization for operators described by a single input and a
// single output.
//
// dispatch() returns a configuration whose input and output tensor formats are
// both TensorFormats::NCHWc4, or None when input(0) is not Float32. Only
// input(0) and output(0) of the operator are inspected.
| template <> | |||||
| struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW44> { | |||||
| static Maybe<OprTensorFormatsConfiguration> dispatch( | |||||
| const OperatorNodeBase* opr) { | |||||
| OprTensorFormatsConfiguration config; | |||||
| config.typeinfo = opr->dyn_typeinfo(); | |||||
| config.opr_format = OprFormat::NCHW44; | |||||
| bool available = true; | |||||
// This format is offered only for Float32 input.
| available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float32; | |||||
// The dtypes recorded are the operator's actual ones; the single input is
// tagged as a FEATURE tensor.
| config.input_dtypes = {opr->input(0)->dtype().enumv()}; | |||||
| config.input_tensor_types = {TensorType::FEATURE}; | |||||
| config.output_dtypes = {opr->output(0)->dtype().enumv()}; | |||||
| config.input_tensor_formats = {TensorFormats::NCHWc4}; | |||||
| config.output_tensor_formats = {TensorFormats::NCHWc4}; | |||||
// The config is fully populated before the availability check; it is simply
// discarded when the dtype requirement is not met.
| if (!available) | |||||
| return None; | |||||
| return config; | |||||
| } | |||||
| }; | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
// NCHW88 specialization for operators described by a single input and a
// single output (compiled only when MEGDNN_DISABLE_FLOAT16 is false, per the
// surrounding #if).
//
// dispatch() returns a configuration whose input and output tensor formats are
// both TensorFormats::NCHWc8, or None when input(0) is not Float16. Only
// input(0) and output(0) of the operator are inspected.
| template <> | |||||
| struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW88> { | |||||
| static Maybe<OprTensorFormatsConfiguration> dispatch( | |||||
| const OperatorNodeBase* opr) { | |||||
| OprTensorFormatsConfiguration config; | |||||
| config.typeinfo = opr->dyn_typeinfo(); | |||||
| config.opr_format = OprFormat::NCHW88; | |||||
| bool available = true; | |||||
// This format is offered only for Float16 input.
| available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float16; | |||||
// The dtypes recorded are the operator's actual ones; the single input is
// tagged as a FEATURE tensor.
| config.input_dtypes = {opr->input(0)->dtype().enumv()}; | |||||
| config.input_tensor_types = {TensorType::FEATURE}; | |||||
| config.output_dtypes = {opr->output(0)->dtype().enumv()}; | |||||
| config.input_tensor_formats = {TensorFormats::NCHWc8}; | |||||
| config.output_tensor_formats = {TensorFormats::NCHWc8}; | |||||
// The config is fully populated before the availability check; it is simply
// discarded when the dtype requirement is not met.
| if (!available) | |||||
| return None; | |||||
| return config; | |||||
| } | |||||
| }; | |||||
| #endif | |||||
| template <> | template <> | ||||
| struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW4> { | struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW4> { | ||||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | ||||
| @@ -200,7 +242,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW> { | |||||
| // setup tensor formats | // setup tensor formats | ||||
| if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | ||||
| config.input_tensor_formats = { | config.input_tensor_formats = { | ||||
| TensorFormats::NCHW, TensorFormats::NCHW, TensorFormats::NCHW, | |||||
| TensorFormats::NCHW, TensorFormats::KCRS, TensorFormats::NCHW, | |||||
| TensorFormats::NCHW}; | TensorFormats::NCHW}; | ||||
| } else { | } else { | ||||
| mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | ||||
| @@ -396,6 +438,145 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> { | |||||
| } | } | ||||
| }; | }; | ||||
// NCHW44 tensor-format configuration for a convolution-like operator Opr.
//
// dispatch() returns the configuration when every input and output(0) are
// Float32, and None otherwise. Input index 1 is tagged as WEIGHT, all other
// inputs as FEATURE. The weight format depends on the convolution's sparse
// mode: KCRSc4k4 for DENSE, C11RSc4 for channel-wise GROUP, GKCRSc4k4 for any
// other GROUP convolution. Feature inputs and the output use NCHWc4.
| template <typename Opr> | |||||
| struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> { | |||||
| static Maybe<OprTensorFormatsConfiguration> dispatch( | |||||
| const OperatorNodeBase* opr) { | |||||
| const auto& conv = opr->cast_final_safe<Opr>(); | |||||
| OprTensorFormatsConfiguration config; | |||||
| config.typeinfo = opr->dyn_typeinfo(); | |||||
| config.opr_format = OprFormat::NCHW44; | |||||
| bool available = true; | |||||
| // setup dtypes | |||||
| for (size_t i = 0; i < opr->input().size(); ++i) { | |||||
| available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32; | |||||
| config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); | |||||
| TensorType tensor_type = | |||||
| i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; | |||||
| config.input_tensor_types.emplace_back(tensor_type); | |||||
| } | |||||
| available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32; | |||||
| config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | |||||
| // setup tensor formats | |||||
// Each list below always has four entries, independent of how many inputs
// the operator actually has.
| if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NCHWc4, TensorFormats::KCRSc4k4, | |||||
| TensorFormats::NCHWc4, TensorFormats::NCHWc4}; | |||||
| } else { | |||||
| mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | |||||
| if (is_channel_wise_conv<Opr>(opr)) { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NCHWc4, TensorFormats::C11RSc4, | |||||
| TensorFormats::NCHWc4, TensorFormats::NCHWc4}; | |||||
| } else { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NCHWc4, TensorFormats::GKCRSc4k4, | |||||
| TensorFormats::NCHWc4, TensorFormats::NCHWc4}; | |||||
| } | |||||
| } | |||||
| config.output_tensor_formats = {TensorFormats::NCHWc4}; | |||||
// The dtype verdict is applied last; an unavailable config is discarded.
| if (!available) | |||||
| return None; | |||||
| return config; | |||||
| } | |||||
| }; | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
// NCHW88 tensor-format configuration for a convolution-like operator Opr
// (compiled only when MEGDNN_DISABLE_FLOAT16 is false, per the surrounding
// #if).
//
// dispatch() returns the configuration when every input and output(0) are
// Float16, and None otherwise. Input index 1 is tagged as WEIGHT, all other
// inputs as FEATURE. The weight format depends on the convolution's sparse
// mode: KCRSc8k8 for DENSE, C11RSc8 for channel-wise GROUP, GKCRSc8k8 for any
// other GROUP convolution. Feature inputs and the output use NCHWc8.
| template <typename Opr> | |||||
| struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> { | |||||
| static Maybe<OprTensorFormatsConfiguration> dispatch( | |||||
| const OperatorNodeBase* opr) { | |||||
| const auto& conv = opr->cast_final_safe<Opr>(); | |||||
| OprTensorFormatsConfiguration config; | |||||
| config.typeinfo = opr->dyn_typeinfo(); | |||||
| config.opr_format = OprFormat::NCHW88; | |||||
| bool available = true; | |||||
| // setup dtypes | |||||
| for (size_t i = 0; i < opr->input().size(); ++i) { | |||||
| available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16; | |||||
| config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); | |||||
| TensorType tensor_type = | |||||
| i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; | |||||
| config.input_tensor_types.emplace_back(tensor_type); | |||||
| } | |||||
| available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16; | |||||
| config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | |||||
| // setup tensor formats | |||||
// Each list below always has four entries, independent of how many inputs
// the operator actually has.
| if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NCHWc8, TensorFormats::KCRSc8k8, | |||||
| TensorFormats::NCHWc8, TensorFormats::NCHWc8}; | |||||
| } else { | |||||
| mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | |||||
| if (is_channel_wise_conv<Opr>(opr)) { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NCHWc8, TensorFormats::C11RSc8, | |||||
| TensorFormats::NCHWc8, TensorFormats::NCHWc8}; | |||||
| } else { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NCHWc8, TensorFormats::GKCRSc8k8, | |||||
| TensorFormats::NCHWc8, TensorFormats::NCHWc8}; | |||||
| } | |||||
| } | |||||
| config.output_tensor_formats = {TensorFormats::NCHWc8}; | |||||
// The dtype verdict is applied last; an unavailable config is discarded.
| if (!available) | |||||
| return None; | |||||
| return config; | |||||
| } | |||||
| }; | |||||
| #endif | |||||
// NCHW44_DOT tensor-format configuration for a convolution-like operator Opr.
//
// dispatch() returns the configuration only for quantized operators: input
// index 2 must be QuantizedS32, every other input and output(0) must be
// QuantizedS8 or Quantized8Asymm. Input index 1 is tagged as WEIGHT, all other
// inputs as FEATURE. The weight format is KCRSk4c4 for DENSE and GKCRSk4c4 for
// non-channel-wise GROUP convolutions; channel-wise convolutions are reported
// as unavailable. Feature inputs and the output use NCHWc4. None is returned
// whenever any requirement fails.
| template <typename Opr> | |||||
| struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44_DOT> { | |||||
| static Maybe<OprTensorFormatsConfiguration> dispatch( | |||||
| const OperatorNodeBase* opr) { | |||||
| const auto& conv = opr->cast_final_safe<Opr>(); | |||||
| OprTensorFormatsConfiguration config; | |||||
| config.typeinfo = opr->dyn_typeinfo(); | |||||
| config.opr_format = OprFormat::NCHW44_DOT; | |||||
| bool available = true; | |||||
| // setup dtypes | |||||
// Note: `&=` binds more loosely than `||`, so in the two-dtype checks below
// the whole disjunction is and-ed into `available`.
| for (size_t i = 0; i < opr->input().size(); ++i) { | |||||
| if (i == 2) { | |||||
| available &= opr->input(i)->dtype().enumv() == | |||||
| DTypeEnum::QuantizedS32; | |||||
| } else { | |||||
| available &= opr->input(i)->dtype().enumv() == | |||||
| DTypeEnum::QuantizedS8 || | |||||
| opr->input(i)->dtype().enumv() == | |||||
| DTypeEnum::Quantized8Asymm; | |||||
| } | |||||
| config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); | |||||
| TensorType tensor_type = | |||||
| i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; | |||||
| config.input_tensor_types.emplace_back(tensor_type); | |||||
| } | |||||
| available &= | |||||
| opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||||
| opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm; | |||||
| config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | |||||
| // setup tensor formats | |||||
| if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NCHWc4, TensorFormats::KCRSk4c4, | |||||
| TensorFormats::NCHWc4, TensorFormats::NCHWc4}; | |||||
| } else { | |||||
| mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | |||||
// Channel-wise convolution has no weight format here: the config is marked
// unavailable and input_tensor_formats is left unset (None is returned below).
| if (is_channel_wise_conv<Opr>(opr)) { | |||||
| available = false; | |||||
| } else { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NCHWc4, TensorFormats::GKCRSk4c4, | |||||
| TensorFormats::NCHWc4, TensorFormats::NCHWc4}; | |||||
| } | |||||
| } | |||||
| config.output_tensor_formats = {TensorFormats::NCHWc4}; | |||||
| if (!available) | |||||
| return None; | |||||
| return config; | |||||
| } | |||||
| }; | |||||
| template <> | template <> | ||||
| struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::NCHW> { | struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::NCHW> { | ||||
| using Opr = opr::ConvolutionBackwardData; | using Opr = opr::ConvolutionBackwardData; | ||||
| @@ -530,9 +711,19 @@ StaticData::StaticData() { | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, CHWN4); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, CHWN4); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44); | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88); | |||||
| #endif | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT); | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44); | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88); | |||||
| #endif | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT); | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC); | ||||
| @@ -549,6 +740,16 @@ StaticData::StaticData() { | |||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, CHWN4); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, CHWN4); | ||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW32); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW32); | ||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64); | ||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44); | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88); | |||||
| #endif | |||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW); | |||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44); | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88); | |||||
| #endif | |||||
| #undef OPR_TENSOR_FORMATS_CONFIG_REG | #undef OPR_TENSOR_FORMATS_CONFIG_REG | ||||
| #undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG | #undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG | ||||
| @@ -35,9 +35,9 @@ OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) { | |||||
| case TensorFormats::NCHW: | case TensorFormats::NCHW: | ||||
| return OprFormat::NCHW; | return OprFormat::NCHW; | ||||
| case TensorFormats::NCHWc4: | case TensorFormats::NCHWc4: | ||||
| return OprFormat::NCHW4; | |||||
| return OprFormat::NCHW44; | |||||
| case TensorFormats::NCHWc8: | case TensorFormats::NCHWc8: | ||||
| return OprFormat::NCHW8; | |||||
| return OprFormat::NCHW88; | |||||
| case TensorFormats::NCHWc32: | case TensorFormats::NCHWc32: | ||||
| return OprFormat::NCHW32; | return OprFormat::NCHW32; | ||||
| case TensorFormats::NCHWc64: | case TensorFormats::NCHWc64: | ||||
| @@ -424,11 +424,11 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(const Problem& problem) cons | |||||
| skip &= problem.graph_partition().input().count(i) > 0 || | skip &= problem.graph_partition().input().count(i) > 0 || | ||||
| skip_oprs.count(i->owner_opr()) > 0; | skip_oprs.count(i->owner_opr()) > 0; | ||||
| } | } | ||||
| skip &= skip_opr_types.count(opr->dyn_typeinfo()); | |||||
| auto find = format_aware_input_tensors.find(opr->dyn_typeinfo()); | |||||
| skip &= find == format_aware_input_tensors.end(); | |||||
| if (skip) | if (skip) | ||||
| skip_oprs.insert(opr); | skip_oprs.insert(opr); | ||||
| oprs.insert(opr); | oprs.insert(opr); | ||||
| auto find = format_aware_input_tensors.find(opr->dyn_typeinfo()); | |||||
| if (find == format_aware_input_tensors.end()) { | if (find == format_aware_input_tensors.end()) { | ||||
| for (auto&& i : opr->input()) { | for (auto&& i : opr->input()) { | ||||
| if (!cvprop.is_const(i)) { | if (!cvprop.is_const(i)) { | ||||
| @@ -470,9 +470,9 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( | |||||
| input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) { | input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) { | ||||
| in_channels = orig_var->shape()[i] * input_shape[i].stride(); | in_channels = orig_var->shape()[i] * input_shape[i].stride(); | ||||
| input_channel_idx = i; | input_channel_idx = i; | ||||
| // mgb_assert(input_shape[i].stride() == 1, | |||||
| // "unsupport weight format(got:%s)", | |||||
| // input_shape.to_string().c_str()); | |||||
| mgb_assert( | |||||
| input_shape[i].stride() == 1, "unsupport weight format(got:%s)", | |||||
| input_shape.to_string().c_str()); | |||||
| } else if ( | } else if ( | ||||
| (input_shape[i].name() == Dimension::Name::K || | (input_shape[i].name() == Dimension::Name::K || | ||||
| input_shape[i].name() == Dimension::Name::N) && | input_shape[i].name() == Dimension::Name::N) && | ||||
| @@ -485,13 +485,23 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( | |||||
| input_shape.to_string().c_str()); | input_shape.to_string().c_str()); | ||||
| } | } | ||||
| } | } | ||||
| /* \notes: FIXME this is a hack. The layout of the weight in channelwise | |||||
| * convolution does not have an output channel dimension, so we manually set | |||||
| * out_channel_name, out_channels and output_channel_idx to bypass the following assertion statements. */ | |||||
| bool is_channelwise = key.input_format == TensorFormats::C11RS; | |||||
| if (is_channelwise) { | |||||
| out_channel_name = Dimension::Name::K; | |||||
| out_channels = in_channels; | |||||
| output_channel_idx = input_channel_idx; | |||||
| } | |||||
| mgb_assert( | mgb_assert( | ||||
| out_channel_name == Dimension::Name::K || | out_channel_name == Dimension::Name::K || | ||||
| out_channel_name == Dimension::Name::N, | out_channel_name == Dimension::Name::N, | ||||
| "invalid out channel(shp:%s)", input_shape.to_string().c_str()); | "invalid out channel(shp:%s)", input_shape.to_string().c_str()); | ||||
| mgb_assert( | mgb_assert( | ||||
| input_channel_idx < input_shape.ndim && | |||||
| output_channel_idx < input_shape.ndim, | |||||
| (input_channel_idx < input_shape.ndim && | |||||
| output_channel_idx < input_shape.ndim) || | |||||
| (is_channelwise && output_channel_idx == input_channel_idx), | |||||
| "invalid channel idx(in_channel:%zu, out_channel:%zu, shp:%s)", | "invalid channel idx(in_channel:%zu, out_channel:%zu, shp:%s)", | ||||
| input_channel_idx, output_channel_idx, input_shape.to_string().c_str()); | input_channel_idx, output_channel_idx, input_shape.to_string().c_str()); | ||||
| size_t in_channel_alignment = 0, out_channel_alignment = 0; | size_t in_channel_alignment = 0, out_channel_alignment = 0; | ||||
| @@ -506,6 +516,13 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( | |||||
| out_channel_alignment = output_shape[i].stride(); | out_channel_alignment = output_shape[i].stride(); | ||||
| } | } | ||||
| } | } | ||||
| /* \notes: FIXME this is a hack. The layout of the weight in channelwise | |||||
| * convolution does not have an output channel dimension, so we manually set | |||||
| * out_channel_alignment to bypass the following assertion statements. */ | |||||
| if (is_channelwise) { | |||||
| mgb_assert(out_channel_alignment == 0); | |||||
| out_channel_alignment = 1; | |||||
| } | |||||
| mgb_assert( | mgb_assert( | ||||
| in_channel_alignment > 0 && out_channel_alignment > 0, | in_channel_alignment > 0 && out_channel_alignment > 0, | ||||
| "invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)", | "invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)", | ||||
| @@ -263,20 +263,9 @@ std::vector<GraphPartition> SubGraphExtractor::extract( | |||||
| std::vector<GraphPartition> partitions; | std::vector<GraphPartition> partitions; | ||||
| partitions.reserve(topo.size()); | partitions.reserve(topo.size()); | ||||
| ThinHashMap<OperatorNodeBase*, GraphPartition*> roots; | ThinHashMap<OperatorNodeBase*, GraphPartition*> roots; | ||||
| /// backward pass | |||||
| for (const auto& opr : reverse_adaptor(topo)) { | for (const auto& opr : reverse_adaptor(topo)) { | ||||
| if (m_opr_list.count(opr->dyn_typeinfo()) == 0) { | |||||
| for (const auto& i : opr->input()) { | |||||
| if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) { | |||||
| auto root = union_find(i->owner_opr()); | |||||
| GraphPartition* partition; | |||||
| auto find = roots.find(root); | |||||
| if (find != roots.end()) { | |||||
| partition = find->second; | |||||
| partition->output().insert(i); | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | |||||
| if (m_opr_list.count(opr->dyn_typeinfo()) > 0) { | |||||
| auto root = union_find(opr); | auto root = union_find(opr); | ||||
| auto find = roots.find(root); | auto find = roots.find(root); | ||||
| GraphPartition* partition = nullptr; | GraphPartition* partition = nullptr; | ||||
| @@ -304,6 +293,23 @@ std::vector<GraphPartition> SubGraphExtractor::extract( | |||||
| partition->input().insert(i); | partition->input().insert(i); | ||||
| } | } | ||||
| } | } | ||||
| /// forward pass | |||||
| for (auto&& opr : topo) { | |||||
| if (m_opr_list.count(opr->dyn_typeinfo()) == 0) { | |||||
| for (const auto& i : opr->input()) { | |||||
| if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) { | |||||
| auto root = union_find(i->owner_opr()); | |||||
| GraphPartition* partition; | |||||
| auto find = roots.find(root); | |||||
| if (find != roots.end()) { | |||||
| partition = find->second; | |||||
| partition->output().insert(i); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| for (auto&& partition : partitions) { | for (auto&& partition : partitions) { | ||||
| auto& all_oprs = partition.all_oprs(); | auto& all_oprs = partition.all_oprs(); | ||||
| std::reverse(all_oprs.begin(), all_oprs.end()); | std::reverse(all_oprs.begin(), all_oprs.end()); | ||||
| @@ -29,6 +29,9 @@ static inline const char* opr_format_to_string( | |||||
| cb(NCHW32); | cb(NCHW32); | ||||
| cb(NCHW64); | cb(NCHW64); | ||||
| cb(CHWN4); | cb(CHWN4); | ||||
| cb(NCHW44); | |||||
| cb(NCHW88); | |||||
| cb(NCHW44_DOT); | |||||
| default: | default: | ||||
| mgb_assert( | mgb_assert( | ||||
| false, "Invalid opr format(got:%u)", | false, "Invalid opr format(got:%u)", | ||||
| @@ -53,6 +56,10 @@ static inline TensorFormats opr_format_to_tensor_formats( | |||||
| return TensorFormats::NCHWc64; | return TensorFormats::NCHWc64; | ||||
| case OprFormat::CHWN4: | case OprFormat::CHWN4: | ||||
| return TensorFormats::CHWNc4; | return TensorFormats::CHWNc4; | ||||
| case OprFormat::NCHW88: | |||||
| return TensorFormats::NCHWc8; | |||||
| case OprFormat::NCHW44: | |||||
| return TensorFormats::NCHWc4; | |||||
| default: | default: | ||||
| mgb_throw( | mgb_throw( | ||||
| AssertionError, "format(%s) is not supported", | AssertionError, "format(%s) is not supported", | ||||