GitOrigin-RevId: 39cad612cc
tags/v1.7.0
| @@ -32,13 +32,8 @@ namespace megdnn { | |||||
| */ | */ | ||||
| template <class Opr, typename... Args> | template <class Opr, typename... Args> | ||||
| bool has_available_algo(Opr* opr, Args&&... args) { | bool has_available_algo(Opr* opr, Args&&... args) { | ||||
| const typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...); | |||||
| for (auto i : Opr::algo_pack().all_algos) { | |||||
| if (i->is_available(size_args)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| auto&& all_algos = opr->get_all_algorithms_info(std::forward<Args>(args)...); | |||||
| return !all_algos.empty(); | |||||
| } | } | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -157,7 +157,6 @@ struct ConvMaker<opr::BatchConvBiasForward> | |||||
| MakeConvCaller4<megdnn::BatchConvBiasForward>, | MakeConvCaller4<megdnn::BatchConvBiasForward>, | ||||
| megdnn::param::BatchConvBias> {}; | megdnn::param::BatchConvBias> {}; | ||||
| #if 0 | |||||
| #include "../../opr/impl/internal/invoke.h" | #include "../../opr/impl/internal/invoke.h" | ||||
| template <typename Opr> | template <typename Opr> | ||||
| struct MultiAlgoOprTrait; | struct MultiAlgoOprTrait; | ||||
| @@ -202,7 +201,6 @@ INST(ConvolutionBackwardData) | |||||
| INST(PoolingForward) | INST(PoolingForward) | ||||
| #undef APPLY | #undef APPLY | ||||
| #undef INST | #undef INST | ||||
| #endif | |||||
| } // namespace | } // namespace | ||||
| namespace mgb { | namespace mgb { | ||||
| @@ -291,9 +289,7 @@ VarNode* modify_opr_format( | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| #if 0 | |||||
| bool has_available_algo(const VarNodeArray& i, | |||||
| const cg::OperatorNodeBase* opr) { | |||||
| bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr) { | |||||
| #define cb(_Opr) \ | #define cb(_Opr) \ | ||||
| if (opr->dyn_typeinfo() == _Opr::typeinfo()) { \ | if (opr->dyn_typeinfo() == _Opr::typeinfo()) { \ | ||||
| MGB_MARK_USED_VAR(MultiAlgoOprTrait<_Opr>::has_algo); \ | MGB_MARK_USED_VAR(MultiAlgoOprTrait<_Opr>::has_algo); \ | ||||
| @@ -301,13 +297,12 @@ bool has_available_algo(const VarNodeArray& i, | |||||
| _.emplace_back(opr->output(0)); \ | _.emplace_back(opr->output(0)); \ | ||||
| return MultiAlgoOprTrait<_Opr>::has_available_algo(_, opr); \ | return MultiAlgoOprTrait<_Opr>::has_available_algo(_, opr); \ | ||||
| } else | } else | ||||
| cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) | |||||
| cb(PoolingForward) { | |||||
| mgb_throw(InternalError, "invalid multi-algo operator(got:%s)", | |||||
| opr->dyn_typeinfo()->name); | |||||
| cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward) { | |||||
| mgb_throw( | |||||
| InternalError, "invalid multi-algo operator(got:%s)", | |||||
| opr->dyn_typeinfo()->name); | |||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| } // namespace intl | } // namespace intl | ||||
| } // namespace gopt | } // namespace gopt | ||||
| @@ -21,9 +21,7 @@ namespace intl { | |||||
| #define FOREACH_FORMAT_AWARE_OPR(cb) \ | #define FOREACH_FORMAT_AWARE_OPR(cb) \ | ||||
| cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward) \ | cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward) \ | ||||
| cb(WarpPerspective) cb(Resize) | cb(WarpPerspective) cb(Resize) | ||||
| #if 0 | |||||
| bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr); | bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr); | ||||
| #endif | |||||
| VarNode* modify_opr_format( | VarNode* modify_opr_format( | ||||
| opr::ConvBias::Param::Format opr_format, const VarNodeArray& i, | opr::ConvBias::Param::Format opr_format, const VarNodeArray& i, | ||||
| @@ -43,7 +43,8 @@ static inline size_t extra_alignment( | |||||
| size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8; | size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8; | ||||
| size_t extra_alignment = | size_t extra_alignment = | ||||
| alignment_in_bits >= dtype_bits ? alignment_in_bits / dtype_bits : 1; | alignment_in_bits >= dtype_bits ? alignment_in_bits / dtype_bits : 1; | ||||
| if (target_formats == TensorFormats::NHWC) | |||||
| if (target_formats == TensorFormats::NHWC || | |||||
| target_formats == TensorFormats::KRSC) | |||||
| channel_alignment = extra_alignment * channel_alignment / | channel_alignment = extra_alignment * channel_alignment / | ||||
| gcd(channel_alignment, extra_alignment); | gcd(channel_alignment, extra_alignment); | ||||
| return channel_alignment; | return channel_alignment; | ||||
| @@ -60,10 +61,12 @@ static inline std::tuple<size_t, size_t> extra_alignment( | |||||
| size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8; | size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8; | ||||
| size_t extra_alignment = | size_t extra_alignment = | ||||
| alignment_in_bits >= dtype_bits ? alignment_in_bits / dtype_bits : 1; | alignment_in_bits >= dtype_bits ? alignment_in_bits / dtype_bits : 1; | ||||
| if (key.input_format == TensorFormats::NHWC) | |||||
| if (key.input_format == TensorFormats::NHWC || | |||||
| key.input_format == TensorFormats::KRSC) | |||||
| input_channel_alignment = input_channel_alignment * extra_alignment / | input_channel_alignment = input_channel_alignment * extra_alignment / | ||||
| gcd(input_channel_alignment, extra_alignment); | gcd(input_channel_alignment, extra_alignment); | ||||
| if (key.output_format == TensorFormats::NHWC) | |||||
| if (key.output_format == TensorFormats::NHWC || | |||||
| key.output_format == TensorFormats::KRSC) | |||||
| output_channel_alignment = output_channel_alignment * extra_alignment / | output_channel_alignment = output_channel_alignment * extra_alignment / | ||||
| gcd(output_channel_alignment, extra_alignment); | gcd(output_channel_alignment, extra_alignment); | ||||
| return std::make_tuple(input_channel_alignment, output_channel_alignment); | return std::make_tuple(input_channel_alignment, output_channel_alignment); | ||||
| @@ -62,6 +62,16 @@ enum class TensorFormats : uint32_t { | |||||
| KCRS = 24, ///< [K, C, R, S] | KCRS = 24, ///< [K, C, R, S] | ||||
| GKCRS = 25, ///< [G, K, C, R, S] | GKCRS = 25, ///< [G, K, C, R, S] | ||||
| C11RS = 26, ///< [C, 1, 1, R, S] | C11RS = 26, ///< [C, 1, 1, R, S] | ||||
| // NHWC | |||||
| KRSC = 27, ///< [K, R, S, C] | |||||
| // NCHW32 | |||||
| KCRSc32 = 28, ///<[K, C/32, R, S, C%32] | |||||
| // NCHW64 | |||||
| KCRSc64 = 29, ///<[K, C/64, R, S, C%64] | |||||
| // CHWN4 | |||||
| CRSKc4 = 30, ///< [C/4, R, S, K, C%4] | |||||
| }; | }; | ||||
| class ReformatManager : public NonCopyableObj { | class ReformatManager : public NonCopyableObj { | ||||