GitOrigin-RevId: db50b33c11
tags/v1.7.0
@@ -820,23 +820,26 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
        _passes need_param_fuse = true; \
    }
    using Target = GraphTuningOptions::Target;
    cb(layout_transform, {
        add_pass<FuseConvBiasNonlinPass>();
        add_pass<FuseConvBiasZPass>();
        if (options.target == Target::CUDA)
            add_pass<FuseConvBiasZPass>();
        add_pass(LayoutTransformPass::make(options.target));
        add_pass<ShuffleShuffleRemovePass>();
        add_pass(FuseNCHW4Int8Preprocess::make());
        add_pass<FuseWarpPerspectiveDimshufflePass>();
        if (options.target == Target::CUDA) {
            add_pass(FuseNCHW4Int8Preprocess::make());
            add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
        add_pass<FoldingConvBiasDimshufflePass>();
        add_pass<FoldingConvBiasTypecvtPass>();
            add_pass<FoldingConvBiasDimshufflePass>();
            add_pass<FoldingConvBiasTypecvtPass>();
#endif
        }
    });
#undef cb
    if (need_param_fuse) {
        add_pass<ParamFusePass>();
        add_pass<ParamMergePass>();
    }
    return *this;
}
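With this hunk the CUDA-only fusion passes (FuseConvBiasZPass, FuseNCHW4Int8Preprocess, FuseWarpPerspectiveDimshufflePass and the FoldingConvBias* passes) are registered only when the tuning target is CUDA, so the same entry point can serve ARM. A minimal usage sketch, assuming the option wiring implied by this hunk (only GraphOptimizer, GraphTuningOptions, options.target and add_passes_for_graph_tuning_options appear in the diff; that Target::ARM is exposed on GraphTuningOptions::Target is an assumption):

// Hedged sketch, not code from the patch: graph construction and applying
// the optimizer are elided.
GraphTuningOptions options;
options.target = GraphTuningOptions::Target::ARM;

GraphOptimizer optimizer;
optimizer.add_passes_for_graph_tuning_options(options);
// For a non-CUDA target the CUDA-only passes above are skipped;
// LayoutTransformPass::make(options.target) selects the matching context.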
@@ -15,6 +15,7 @@
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/tensor_manip.h"

using namespace mgb;
using namespace gopt;

@@ -82,6 +83,44 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
            {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
    return ctx;
}
std::unique_ptr<LayoutTransformContext> make_arm_ctx(
        OprFormat base_opr_format, TensorFormats base_tensor_format) {
    OprList opr_list = {
            opr::ConvBiasForward::typeinfo(),
            opr::ConvolutionForward::typeinfo(),
            opr::ElemwiseMultiType::typeinfo(),
            opr::Elemwise::typeinfo(),
            opr::TypeCvt::typeinfo(),
            opr::PoolingForward::typeinfo(),
            opr::Resize::typeinfo(),
            opr::PowC::typeinfo(),
            opr::Concat::typeinfo(),
    };
    SmallVector<TensorFormats> available_tensor_formats = {
            TensorFormats::NCHW, TensorFormats::NCHWc4,
            DNN_INC_FLOAT16(TensorFormats::NCHWc8)};
    Attribute attribute = {base_opr_format, base_tensor_format, Target::ARM};
    auto ctx = std::make_unique<LayoutTransformContext>(
            std::move(opr_list), std::move(available_tensor_formats),
            attribute);
    ctx->add_opr_config(
               opr::ConvBiasForward::typeinfo(),
               {OprFormat::NCHW, OprFormat::NCHW44,
                DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT})
            .add_opr_config(
                    opr::ConvolutionForward::typeinfo(),
                    {OprFormat::NCHW, OprFormat::NCHW44,
                     DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT})
            .add_opr_config(opr::PoolingForward::typeinfo(),
                            {OprFormat::NCHW, OprFormat::NCHW44,
                             DNN_INC_FLOAT16(OprFormat::NCHW88)})
            .add_opr_config(opr::ResizeForward::typeinfo(),
                            {OprFormat::NCHW, OprFormat::NCHW44,
                             DNN_INC_FLOAT16(OprFormat::NCHW88)});
    return ctx;
}
}  // namespace

/* ================= LayoutTransformContext ==================*/
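For reference, TensorFormats::NCHWc4 and TensorFormats::NCHWc8 are the channel-blocked layouts behind OprFormat::NCHW44/NCHW44_DOT and OprFormat::NCHW88: channels are grouped into blocks of 4 (float32 and quantized int8 on ARM) or 8 (float16), so a logical (N, C, H, W) tensor is stored as (N, C/4, H, W, 4) or (N, C/8, H, W, 8). A minimal sketch of the NCHW → NCHWc4 index mapping, assuming C is a multiple of 4 (the helper below is illustrative only, not part of MegEngine):

#include <cstddef>

// Offset of logical element (n, c, h, w) inside an NCHWc4 buffer of
// logical shape (N, C, H, W), physically stored as (N, C/4, H, W, 4).
static size_t nchwc4_offset(
        size_t n, size_t c, size_t h, size_t w, size_t C, size_t H, size_t W) {
    const size_t block = c / 4;  // which channel block
    const size_t lane = c % 4;   // position inside the block
    return (((n * (C / 4) + block) * H + h) * W + w) * 4 + lane;
}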
@@ -110,6 +149,8 @@ std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make(
    switch (target) {
        case Target::CUDA:
            return make_cuda_ctx(base_opr_format, base_tensor_format);
        case Target::ARM:
            return make_arm_ctx(base_opr_format, base_tensor_format);
        default:
            mgb_assert(false, "unsupported target %s\n", target_to_string(target));
    }
@@ -60,6 +60,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
    auto&& opr_configs = m_ctx->opr_configs();
    auto&& base_fmt = m_ctx->attribute().base_tensor_formats;
    auto&& base_opr_fmt = m_ctx->attribute().base_opr_format;
    auto&& reformat_attribute = m_ctx->attribute().reformat_attribute;
    ThinHashMap<VarNode*, TensorFormats> var2fmts;
    static ThinHashSet<Typeinfo*> format_aware_oprs = {
@@ -68,15 +69,18 @@ void LayoutTransformPass::apply(OptState& opt) const {
#undef cb
    };
    auto rewriter = opt.graph().make_rewriter();
    auto on_opr = [&opr_configs, &base_fmt, &reformat_attribute, &rewriter, &solution,
                   &var2fmts, &endpoint_vars](OperatorNodeBase* opr) {
    auto on_opr = [&opr_configs, &base_fmt, &base_opr_fmt, &reformat_attribute,
                   &rewriter, &solution, &var2fmts,
                   &endpoint_vars](OperatorNodeBase* opr) {
        auto it = solution.find(opr);
        if (it != solution.end()) {
            auto opr_fmt = it->second;
            auto find = opr_configs.find(opr->dyn_typeinfo());
            Maybe<OprTensorFormatsConfiguration> fmtcfg = None;
            Maybe<OprTensorFormatsConfiguration> basecfg = None;
            if (find != opr_configs.end()) {
                fmtcfg = (*find->second.at(opr_fmt))(opr);
                basecfg = (*find->second.at(base_opr_fmt))(opr);
            }
            VarNodeArray new_inp;
            size_t nr_inps = opr->input().size();
@@ -103,6 +107,10 @@ void LayoutTransformPass::apply(OptState& opt) const {
                bool is_parameter =
                        fmtcfg.valid() &&
                        fmtcfg.val().input_tensor_types[i] == TensorType::WEIGHT;
                if (is_parameter) {
                    mgb_assert(basecfg.valid());
                    from = basecfg.val().input_tensor_formats[i];
                }
                // need relayout
                if (from != to && !new_var->shape().is_scalar()) {
                    ReformatManager::ReformatImpl reformat;
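The extra basecfg lookup changes where weight reformats start from: for inputs tagged TensorType::WEIGHT, the source format is now taken from the operator's configuration in the base opr format rather than from base_tensor_formats, because filters use weight layouts (e.g. KCRS, GKCRS) instead of the feature layout NCHW. An illustrative snippet of the resulting plan for a dense ConvBias solved as OprFormat::NCHW44 with base opr format NCHW, using only format names that appear elsewhere in this diff (the snippet is not code from the patch):

// Illustrative only: implied from -> to formats for each input.
struct ReformatPlanEntry {
    TensorFormats from, to;
};
const ReformatPlanEntry dense_nchw44_conv_bias_inputs[] = {
        {TensorFormats::NCHW, TensorFormats::NCHWc4},    // src (FEATURE)
        {TensorFormats::KCRS, TensorFormats::KCRSc4k4},  // filter (WEIGHT), from basecfg
        {TensorFormats::NCHW, TensorFormats::NCHWc4},    // bias (FEATURE)
};
// Before this change the filter's "from" would have been the base tensor
// format (NCHW), which does not describe a filter layout.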
@@ -78,6 +78,48 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW> {
    }
};

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW44> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW44;
        bool available = true;
        available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float32;
        config.input_dtypes = {opr->input(0)->dtype().enumv()};
        config.input_tensor_types = {TensorType::FEATURE};
        config.output_dtypes = {opr->output(0)->dtype().enumv()};
        config.input_tensor_formats = {TensorFormats::NCHWc4};
        config.output_tensor_formats = {TensorFormats::NCHWc4};
        if (!available)
            return None;
        return config;
    }
};

#if !MEGDNN_DISABLE_FLOAT16
template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW88> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW88;
        bool available = true;
        available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float16;
        config.input_dtypes = {opr->input(0)->dtype().enumv()};
        config.input_tensor_types = {TensorType::FEATURE};
        config.output_dtypes = {opr->output(0)->dtype().enumv()};
        config.input_tensor_formats = {TensorFormats::NCHWc8};
        config.output_tensor_formats = {TensorFormats::NCHWc8};
        if (!available)
            return None;
        return config;
    }
};
#endif

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW4> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
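These single-input/single-output dispatchers only gate on dtype and keep the blocked layout end to end, which is what makes oprs such as PoolingForward and ResizeForward format-transparent under NCHW44/NCHW88: only the spatial dimensions change, the channel blocking is preserved. A small illustrative helper, assuming valid pooling with no padding (hypothetical, not part of the patch):

#include <array>
#include <cstddef>

// Output shape of a pooling window applied to an NCHWc4 tensor stored as
// {N, C/4, H, W, 4}: N, C/4 and the lane dimension are untouched.
static std::array<size_t, 5> pooled_nchwc4_shape(
        std::array<size_t, 5> shape, size_t window, size_t stride) {
    shape[2] = (shape[2] - window) / stride + 1;  // H
    shape[3] = (shape[3] - window) / stride + 1;  // W
    return shape;
}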
@@ -200,7 +242,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW> {
        // setup tensor formats
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            config.input_tensor_formats = {
                    TensorFormats::NCHW, TensorFormats::NCHW, TensorFormats::NCHW,
                    TensorFormats::NCHW, TensorFormats::KCRS, TensorFormats::NCHW,
                    TensorFormats::NCHW};
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
@@ -396,6 +438,145 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> {
    }
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW44;
        bool available = true;
        // setup dtypes
        for (size_t i = 0; i < opr->input().size(); ++i) {
            available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32;
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type =
                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        // setup tensor formats
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            config.input_tensor_formats = {
                    TensorFormats::NCHWc4, TensorFormats::KCRSc4k4,
                    TensorFormats::NCHWc4, TensorFormats::NCHWc4};
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                config.input_tensor_formats = {
                        TensorFormats::NCHWc4, TensorFormats::C11RSc4,
                        TensorFormats::NCHWc4, TensorFormats::NCHWc4};
            } else {
                config.input_tensor_formats = {
                        TensorFormats::NCHWc4, TensorFormats::GKCRSc4k4,
                        TensorFormats::NCHWc4, TensorFormats::NCHWc4};
            }
        }
        config.output_tensor_formats = {TensorFormats::NCHWc4};
        if (!available)
            return None;
        return config;
    }
};
#if !MEGDNN_DISABLE_FLOAT16
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW88;
        bool available = true;
        // setup dtypes
        for (size_t i = 0; i < opr->input().size(); ++i) {
            available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16;
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type =
                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        // setup tensor formats
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            config.input_tensor_formats = {
                    TensorFormats::NCHWc8, TensorFormats::KCRSc8k8,
                    TensorFormats::NCHWc8, TensorFormats::NCHWc8};
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                config.input_tensor_formats = {
                        TensorFormats::NCHWc8, TensorFormats::C11RSc8,
                        TensorFormats::NCHWc8, TensorFormats::NCHWc8};
            } else {
                config.input_tensor_formats = {
                        TensorFormats::NCHWc8, TensorFormats::GKCRSc8k8,
                        TensorFormats::NCHWc8, TensorFormats::NCHWc8};
            }
        }
        config.output_tensor_formats = {TensorFormats::NCHWc8};
        if (!available)
            return None;
        return config;
    }
};
#endif
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44_DOT> {
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW44_DOT;
        bool available = true;
        // setup dtypes
        for (size_t i = 0; i < opr->input().size(); ++i) {
            if (i == 2) {
                available &= opr->input(i)->dtype().enumv() ==
                             DTypeEnum::QuantizedS32;
            } else {
                available &= opr->input(i)->dtype().enumv() ==
                                     DTypeEnum::QuantizedS8 ||
                             opr->input(i)->dtype().enumv() ==
                                     DTypeEnum::Quantized8Asymm;
            }
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            TensorType tensor_type =
                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &=
                opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        // setup tensor formats
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            config.input_tensor_formats = {
                    TensorFormats::NCHWc4, TensorFormats::KCRSk4c4,
                    TensorFormats::NCHWc4, TensorFormats::NCHWc4};
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                available = false;
            } else {
                config.input_tensor_formats = {
                        TensorFormats::NCHWc4, TensorFormats::GKCRSk4c4,
                        TensorFormats::NCHWc4, TensorFormats::NCHWc4};
            }
        }
        config.output_tensor_formats = {TensorFormats::NCHWc4};
        if (!available)
            return None;
        return config;
    }
};

template <>
struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::NCHW> {
    using Opr = opr::ConvolutionBackwardData;
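The dense weight formats above appear to differ only in the order of the inner lanes: KCRSc4k4 (NCHW44, float32) keeps the output-channel lane innermost, while KCRSk4c4 (NCHW44_DOT, quantized int8) makes the four input channels contiguous for the dot-product kernels. A hedged sketch of the KCRSc4k4 repacking, assuming a logical (K, C, R, S) filter with K and C divisible by 4 and the lane order implied by the format names (illustrative helper, not MegEngine code):

#include <cstddef>
#include <vector>

// Repack a dense (K, C, R, S) filter into KCRSc4k4 blocks, i.e. shape
// (K/4, C/4, R, S, 4c, 4k). For KCRSk4c4 the two lanes would be swapped:
// ... * 16 + (k % 4) * 4 + (c % 4).
std::vector<float> pack_kcrs_to_kcrs_c4k4(
        const std::vector<float>& w, size_t K, size_t C, size_t R, size_t S) {
    std::vector<float> out(w.size());
    for (size_t k = 0; k < K; ++k)
        for (size_t c = 0; c < C; ++c)
            for (size_t r = 0; r < R; ++r)
                for (size_t s = 0; s < S; ++s) {
                    const size_t src = ((k * C + c) * R + r) * S + s;
                    const size_t dst =
                            ((((k / 4) * (C / 4) + c / 4) * R + r) * S + s) * 16 +
                            (c % 4) * 4 + (k % 4);
                    out[dst] = w[src];
                }
    return out;
}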
@@ -530,9 +711,19 @@ StaticData::StaticData() {
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, CHWN4);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44);
#if !MEGDNN_DISABLE_FLOAT16
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88);
#endif
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT);

    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44);
#if !MEGDNN_DISABLE_FLOAT16
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88);
#endif
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT);

    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW);
    OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC);

@@ -549,6 +740,16 @@ StaticData::StaticData() {
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, CHWN4);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW32);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44);
#if !MEGDNN_DISABLE_FLOAT16
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88);
#endif

    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW);
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44);
#if !MEGDNN_DISABLE_FLOAT16
    OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88);
#endif

#undef OPR_TENSOR_FORMATS_CONFIG_REG
#undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
@@ -35,9 +35,9 @@ OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) {
        case TensorFormats::NCHW:
            return OprFormat::NCHW;
        case TensorFormats::NCHWc4:
            return OprFormat::NCHW4;
            return OprFormat::NCHW44;
        case TensorFormats::NCHWc8:
            return OprFormat::NCHW8;
            return OprFormat::NCHW88;
        case TensorFormats::NCHWc32:
            return OprFormat::NCHW32;
        case TensorFormats::NCHWc64:
@@ -424,11 +424,11 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(const Problem& problem) cons
            skip &= problem.graph_partition().input().count(i) > 0 ||
                    skip_oprs.count(i->owner_opr()) > 0;
        }
        skip &= skip_opr_types.count(opr->dyn_typeinfo());
        auto find = format_aware_input_tensors.find(opr->dyn_typeinfo());
        skip &= find == format_aware_input_tensors.end();
        if (skip)
            skip_oprs.insert(opr);
        oprs.insert(opr);
        auto find = format_aware_input_tensors.find(opr->dyn_typeinfo());
        if (find == format_aware_input_tensors.end()) {
            for (auto&& i : opr->input()) {
                if (!cvprop.is_const(i)) {
@@ -470,9 +470,9 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
            input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) {
            in_channels = orig_var->shape()[i] * input_shape[i].stride();
            input_channel_idx = i;
            // mgb_assert(input_shape[i].stride() == 1,
            //            "unsupport weight format(got:%s)",
            //            input_shape.to_string().c_str());
            mgb_assert(
                    input_shape[i].stride() == 1, "unsupport weight format(got:%s)",
                    input_shape.to_string().c_str());
        } else if (
                (input_shape[i].name() == Dimension::Name::K ||
                 input_shape[i].name() == Dimension::Name::N) &&
@@ -485,13 +485,23 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
                    input_shape.to_string().c_str());
        }
    }
    /* \notes: FIXME this is a hack. The weight layout of a channelwise
     * convolution has no output-channel dimension, so we manually set
     * out_channel_name and output_channel_idx to bypass the assertions below. */
    bool is_channelwise = key.input_format == TensorFormats::C11RS;
    if (is_channelwise) {
        out_channel_name = Dimension::Name::K;
        out_channels = in_channels;
        output_channel_idx = input_channel_idx;
    }
    mgb_assert(
            out_channel_name == Dimension::Name::K ||
                    out_channel_name == Dimension::Name::N,
            "invalid out channel(shp:%s)", input_shape.to_string().c_str());
    mgb_assert(
            input_channel_idx < input_shape.ndim &&
                    output_channel_idx < input_shape.ndim,
            (input_channel_idx < input_shape.ndim &&
             output_channel_idx < input_shape.ndim) ||
                    (is_channelwise && output_channel_idx == input_channel_idx),
            "invalid channel idx(in_channel:%zu, out_channel:%zu, shp:%s)",
            input_channel_idx, output_channel_idx, input_shape.to_string().c_str());
    size_t in_channel_alignment = 0, out_channel_alignment = 0;
@@ -506,6 +516,13 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
            out_channel_alignment = output_shape[i].stride();
        }
    }
    /* \notes: FIXME this is a hack. The weight layout of a channelwise
     * convolution has no output-channel dimension, so we manually set
     * out_channel_alignment to bypass the assertion below. */
    if (is_channelwise) {
        mgb_assert(out_channel_alignment == 0);
        out_channel_alignment = 1;
    }
    mgb_assert(
            in_channel_alignment > 0 && out_channel_alignment > 0,
            "invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)",
@@ -263,20 +263,9 @@ std::vector<GraphPartition> SubGraphExtractor::extract(
    std::vector<GraphPartition> partitions;
    partitions.reserve(topo.size());
    ThinHashMap<OperatorNodeBase*, GraphPartition*> roots;
    /// backward pass
    for (const auto& opr : reverse_adaptor(topo)) {
        if (m_opr_list.count(opr->dyn_typeinfo()) == 0) {
            for (const auto& i : opr->input()) {
                if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) {
                    auto root = union_find(i->owner_opr());
                    GraphPartition* partition;
                    auto find = roots.find(root);
                    if (find != roots.end()) {
                        partition = find->second;
                        partition->output().insert(i);
                    }
                }
            }
        } else {
        if (m_opr_list.count(opr->dyn_typeinfo()) > 0) {
            auto root = union_find(opr);
            auto find = roots.find(root);
            GraphPartition* partition = nullptr;

@@ -304,6 +293,23 @@ std::vector<GraphPartition> SubGraphExtractor::extract(
                partition->input().insert(i);
            }
        }
    }

    /// forward pass
    for (auto&& opr : topo) {
        if (m_opr_list.count(opr->dyn_typeinfo()) == 0) {
            for (const auto& i : opr->input()) {
                if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) {
                    auto root = union_find(i->owner_opr());
                    GraphPartition* partition;
                    auto find = roots.find(root);
                    if (find != roots.end()) {
                        partition = find->second;
                        partition->output().insert(i);
                    }
                }
            }
        }
    }

    for (auto&& partition : partitions) {
        auto& all_oprs = partition.all_oprs();
        std::reverse(all_oprs.begin(), all_oprs.end());
@@ -29,6 +29,9 @@ static inline const char* opr_format_to_string(
        cb(NCHW32);
        cb(NCHW64);
        cb(CHWN4);
        cb(NCHW44);
        cb(NCHW88);
        cb(NCHW44_DOT);
        default:
            mgb_assert(
                    false, "Invalid opr format(got:%u)",
@@ -53,6 +56,10 @@ static inline TensorFormats opr_format_to_tensor_formats(
            return TensorFormats::NCHWc64;
        case OprFormat::CHWN4:
            return TensorFormats::CHWNc4;
        case OprFormat::NCHW88:
            return TensorFormats::NCHWc8;
        case OprFormat::NCHW44:
            return TensorFormats::NCHWc4;
        default:
            mgb_throw(
                    AssertionError, "format(%s) is not supported",