GitOrigin-RevId: 6d5b55d7fc
tags/v1.7.0
@@ -830,9 +830,9 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
                 src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
         dst[4] = 32;
     } else if (param().format == Param::Format::NCHW88) {
-        megdnn_assert(
-                src.ndim == 5 || (src.ndim == 4 && src[1] <= 8),
-                "invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim);
+        megdnn_assert(src.ndim == 5 || src.ndim == 4,
+                      "invalid src ndim for NCHW88, expected=5 or 4, got=%zu",
+                      src.ndim);
         dst.ndim = 5;
         dst[0] = src[0];
         auto oc = cflt.ocpg * cflt.group;
@@ -850,12 +850,11 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
                 "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
         }
-    } else if (
-            param().format == Param::Format::NCHW44 ||
-            param().format == Param::Format::NCHW44_DOT) {
-        megdnn_assert(
-                src.ndim == 5 || (src.ndim == 4 && src[1] <= 4),
-                "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim);
+    } else if (param().format == Param::Format::NCHW44 ||
+               param().format == Param::Format::NCHW44_DOT) {
+        megdnn_assert(src.ndim == 5 || src.ndim == 4,
+                      "invalid src ndim for NCHW44, expected=5 or 4, got=%zu",
+                      src.ndim);
         dst.ndim = 5;
         dst[0] = src[0];
         auto oc = cflt.ocpg * cflt.group;
@@ -47,7 +47,7 @@ private:
     struct Value {
         OperatorNodeBase* opr;
         const State* prev;
-        OprFormat opr_fmt;
+        OprFormatConfigID cfg_id;
         float time;
         ///! index in the topo order of the corresponding operator
         size_t opr_idx;
@@ -87,14 +87,15 @@ private:
     };
     /*!
     * \brief get the tensor formats configuration for the operator with
-     * particular op format \param[out] var2fmts hashmap that maps varnode to
-     * actual tensor formats of the op format configuration \param[in] opr given
-     * operator \param[in] opr_fmt given op format, an enum type argument which
-     * indicates the op format configuration. \param[in] ctx context
+     * particular op format
+     * \param[out] var2fmts hashmap that maps varnode to actual tensor formats of the op
+     * format configuration \param[in] opr given operator \param[in] opr_fmt given op
+     * format, an enum type argument which indicates the op format configuration.
+     * \param[in] ctx context
     */
     TensorFormats get_io_formats(
             ThinHashMap<VarNode*, TensorFormats>& var2fmts, const OperatorNodeBase* opr,
-            OprFormat opr_fmt, const Context& ctx);
+            OprFormatConfigID config_id, const Context& ctx);
     /*!
     * \brief compute the distance of two states of the given varnode
     * \param[in] from the source state
@@ -140,28 +141,35 @@ private:
 TensorFormats DynamicProgrammingSolver::Impl::get_io_formats(
         ThinHashMap<VarNode*, TensorFormats>& var2fmts, const OperatorNodeBase* opr,
-        OprFormat opr_fmt, const Context& ctx) {
+        OprFormatConfigID config_id, const Context& ctx) {
     auto&& rst = ctx.rst;
     auto&& opr_configs = ctx.opr_configs;
     auto iter = opr_configs.find(opr->dyn_typeinfo());
     Maybe<OprTensorFormatsConfiguration> fmtcfg = None;
+    Maybe<OprFormat> opr_fmt = None;
     if (iter != opr_configs.end()) {
-        fmtcfg = (*iter->second.at(opr_fmt))(opr);
+        fmtcfg = (*iter->second.at(config_id))(opr);
+    } else {
+        opr_fmt = OprTensorFormatsConfiguration::safe_cast_to_opr_format(config_id);
     }
     TensorFormats out_fmt;
     if (fmtcfg.valid())
         out_fmt = fmtcfg.val().output_tensor_formats[0];
-    else
-        out_fmt = opr_format_to_tensor_formats(opr_fmt);
+    else {
+        mgb_assert(opr_fmt.valid());
+        out_fmt = opr_format_to_tensor_formats(opr_fmt.val());
+    }
     for (size_t i = 0; i < opr->input().size(); ++i) {
         auto&& var = opr->input(i);
         auto iter = rst.var_record.find(var);
         if (iter != rst.var_record.end()) {
             if (fmtcfg.valid())
                 var2fmts[var] = fmtcfg.val().input_tensor_formats[i];
-            else
-                var2fmts[var] = opr_format_to_tensor_formats(opr_fmt);
+            else {
+                mgb_assert(opr_fmt.valid());
+                var2fmts[var] = opr_format_to_tensor_formats(opr_fmt.val());
+            }
         }
     }
     return out_fmt;
@@ -342,13 +350,13 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
         cuts.emplace_back(Cut{});
         auto& states = cuts.back().states;
         for (const auto& record : records) {
-            auto opr_fmt = record.first;
+            auto cfg_id = record.first;
             float opr_time = record.second;
             ThinHashMap<VarNode*, TensorFormats> ivar2fmts;
-            auto out_fmt = get_io_formats(ivar2fmts, opr, opr_fmt, ctx);
+            auto out_fmt = get_io_formats(ivar2fmts, opr, cfg_id, ctx);
             const auto& edge = edges[cur];
             State state(edge.size(), 0);
-            Value value{opr, nullptr, opr_fmt, 0.f, cur};
+            Value value{opr, nullptr, cfg_id, 0.f, cur};
             float ovar_time = 0.f;
             for (size_t i = 0; i < edge.size(); ++i) {
                 auto&& var = edge[i];
@@ -396,16 +404,16 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
         const auto& records = it->second.costs;
         StateTable states;
         for (const auto& record : records) {
-            auto opr_fmt = record.first;
+            auto cfg_id = record.first;
             float opr_time = record.second;
             ThinHashMap<VarNode*, TensorFormats> ivar2fmts;
-            auto out_fmt = get_io_formats(ivar2fmts, opr, opr_fmt, ctx);
+            auto out_fmt = get_io_formats(ivar2fmts, opr, cfg_id, ctx);
             for (const auto& kv : cuts.back().states) {
                 auto&& prev_state = kv.first;
                 float prev_time = kv.second.time;
                 const auto& edge = edges[cur];
                 State state(edge.size(), 0);
-                Value value{opr, &prev_state, opr_fmt, 0.f, cur};
+                Value value{opr, &prev_state, cfg_id, 0.f, cur};
                 float ovar_time = 0.f;
                 for (size_t i = 0; i < edge.size(); ++i) {
                     auto&& var = edge[i];
@@ -482,7 +490,7 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
     /// backward pass to generate the solution
     float min_time = std::numeric_limits<float>::max();
     OperatorNodeBase* cur_opr = nullptr;
-    OprFormat min_fmt = OprFormat::NCHW;
+    OprFormatConfigID min_cfg = OprFormatConfigID::NCHW;
     const State* pstate = nullptr;
     for (auto&& kv : cuts.back().states) {
         auto&& v = kv.second;
@@ -490,7 +498,7 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
             cur_opr = v.opr;
             pstate = v.prev;
             min_time = v.time;
-            min_fmt = v.opr_fmt;
+            min_cfg = v.cfg_id;
             ///! just to check the tensor formats of the output varnode
             auto&& k = kv.first;
             size_t opr_idx = v.opr_idx;
@@ -505,10 +513,10 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
     }
     mgb_assert(cur_opr != nullptr);
     mgb_log_debug(
-            "opr:%s;format:%s;time:%f", cur_opr->cname(), opr_format_to_string(min_fmt),
+            "opr:%s;config:%s;time:%f", cur_opr->cname(), config_id_to_string(min_cfg),
             min_time);
-    solution.insert({cur_opr, min_fmt});
+    solution.insert({cur_opr, min_cfg});
     cur = cuts.size() - 2;
     while (pstate) {
         auto val = cuts[cur].states[*pstate];
@@ -522,9 +530,9 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
             }
         }
         mgb_log_debug(
-                "opr:%s;format:%s;time:%f", val.opr->cname(),
-                opr_format_to_string(val.opr_fmt), val.time);
-        solution.insert({val.opr, val.opr_fmt});
+                "opr:%s;config:%s;time:%f", val.opr->cname(),
+                config_id_to_string(val.cfg_id), val.time);
+        solution.insert({val.opr, val.cfg_id});
         pstate = val.prev;
         cur--;
     }
@@ -22,6 +22,7 @@ using namespace gopt;
 namespace {
 using OprFormat = LayoutTransformContext::OprFormat;
+using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
 using OprList = LayoutTransformContext::OprList;
 using Attribute = LayoutTransformContext::Attribute;
 using Target = LayoutTransformContext::Target;
@@ -43,7 +44,7 @@ const char* target_to_string(Target target) {
 }
 std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
-        OprFormat base_opr_format, TensorFormats base_tensor_format) {
+        OprFormatConfigID base_config_id, TensorFormats base_tensor_format) {
     OprList opr_list = {
             opr::ConvBiasForward::typeinfo(),
             opr::ConvolutionForward::typeinfo(),
@@ -58,34 +59,38 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
     SmallVector<TensorFormats> available_tensor_formats = {
             TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4,
             TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4};
     Attribute attribute = {
-            base_opr_format, base_tensor_format, Target::CUDA,
+            base_config_id, base_tensor_format, Target::CUDA,
             LayoutTransformContext::ReformatAttribute::AUTO_PADDING_NHWC};
     auto ctx = std::make_unique<LayoutTransformContext>(
             std::move(opr_list), std::move(available_tensor_formats), attribute);
     ctx->add_opr_config(
               opr::ConvBiasForward::typeinfo(),
-              {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW32,
-               OprFormat::NCHW64, OprFormat::CHWN4})
+              {OprFormatConfigID::NCHW, OprFormatConfigID::NHWC,
+               OprFormatConfigID::NCHW4_NCHW32, OprFormatConfigID::NCHW32_NCHW4,
+               OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
+               OprFormatConfigID::NCHW64, OprFormatConfigID::CHWN4})
             .add_opr_config(
                     opr::ConvolutionForward::typeinfo(),
-                    {OprFormat::NCHW, OprFormat::NCHW4})
+                    {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4})
            .add_opr_config(
                     opr::ConvolutionBackwardData::typeinfo(),
-                    {OprFormat::NCHW, OprFormat::NCHW4, OprFormat::NHWC})
+                    {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4,
+                     OprFormatConfigID::NHWC})
            .add_opr_config(
                     opr::PoolingForward::typeinfo(),
-                    {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
-                     OprFormat::NCHW64, OprFormat::CHWN4})
+                    {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
+                     OprFormatConfigID::NHWC, OprFormatConfigID::NCHW64,
+                     OprFormatConfigID::CHWN4})
            .add_opr_config(
                     opr::WarpPerspectiveForward::typeinfo(),
-                    {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
+                    {OprFormatConfigID::NHWC, OprFormatConfigID::NCHW4,
+                     OprFormatConfigID::NCHW64});
     return ctx;
 }
 std::unique_ptr<LayoutTransformContext> make_arm_ctx(
-        OprFormat base_opr_format, TensorFormats base_tensor_format) {
+        OprFormatConfigID base_config_id, TensorFormats base_tensor_format) {
     OprList opr_list = {
             opr::ConvBiasForward::typeinfo(),
             opr::ConvolutionForward::typeinfo(),
@@ -101,57 +106,64 @@ std::unique_ptr<LayoutTransformContext> make_arm_ctx(
     SmallVector<TensorFormats> available_tensor_formats = {
             TensorFormats::NCHW, TensorFormats::NCHWc4,
             DNN_INC_FLOAT16(TensorFormats::NCHWc8)};
-    Attribute attribute = {base_opr_format, base_tensor_format, Target::ARM};
+    Attribute attribute = {base_config_id, base_tensor_format, Target::ARM};
     auto ctx = std::make_unique<LayoutTransformContext>(
             std::move(opr_list), std::move(available_tensor_formats), attribute);
     ctx->add_opr_config(
               opr::ConvBiasForward::typeinfo(),
-              {OprFormat::NCHW, OprFormat::NCHW44, DNN_INC_FLOAT16(OprFormat::NCHW88),
-               OprFormat::NCHW44_DOT})
+              {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
+               OprFormatConfigID::NCHW44_HYBRID,
+               DNN_INC_FLOAT16(OprFormatConfigID::NCHW88),
+               DNN_INC_FLOAT16(OprFormatConfigID::NCHW88_HYBRID),
+               OprFormatConfigID::NCHW44_DOT, OprFormatConfigID::NCHW44_DOT_HYBRID})
            .add_opr_config(
                     opr::ConvolutionForward::typeinfo(),
-                    {OprFormat::NCHW, OprFormat::NCHW44,
-                     DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT})
+                    {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
+                     OprFormatConfigID::NCHW44_HYBRID,
+                     DNN_INC_FLOAT16(OprFormatConfigID::NCHW88),
+                     DNN_INC_FLOAT16(OprFormatConfigID::NCHW88_HYBRID),
+                     OprFormatConfigID::NCHW44_DOT,
+                     OprFormatConfigID::NCHW44_DOT_HYBRID})
            .add_opr_config(
                     opr::PoolingForward::typeinfo(),
-                    {OprFormat::NCHW, OprFormat::NCHW44,
-                     DNN_INC_FLOAT16(OprFormat::NCHW88)})
+                    {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
+                     DNN_INC_FLOAT16(OprFormatConfigID::NCHW88)})
            .add_opr_config(
                     opr::ResizeForward::typeinfo(),
-                    {OprFormat::NCHW, OprFormat::NCHW44,
-                     DNN_INC_FLOAT16(OprFormat::NCHW88)});
+                    {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
+                     DNN_INC_FLOAT16(OprFormatConfigID::NCHW88)});
     return ctx;
 }
 } // namespace
 /* ================= LayoutTransformContext ==================*/
 LayoutTransformContext& LayoutTransformContext::add_opr_config(
-        Typeinfo* opr, OprFormat opr_format) {
+        Typeinfo* opr, OprFormatConfigID config_id) {
     auto& dispatchers = m_opr_configs[opr];
-    dispatchers[opr_format] =
+    dispatchers[config_id] =
             OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
-                    opr, opr_format);
+                    opr, config_id);
     return *this;
 }
 LayoutTransformContext& LayoutTransformContext::add_opr_config(
-        Typeinfo* opr, SmallVector<OprFormat> opr_formats) {
+        Typeinfo* opr, SmallVector<OprFormatConfigID> config_ids) {
     auto& dispatchers = m_opr_configs[opr];
-    for (auto opr_fmt : opr_formats) {
-        dispatchers[opr_fmt] =
-                OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
-                        opr, opr_fmt);
+    for (auto cfg : config_ids) {
+        dispatchers[cfg] =
+                OprTensorFormatsConfiguration::find_dispatcher_by_type_format(opr, cfg);
     }
     return *this;
 }
 std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make(
-        Target target, OprFormat base_opr_format, TensorFormats base_tensor_format) {
+        Target target, OprFormatConfigID base_config_id,
+        TensorFormats base_tensor_format) {
     switch (target) {
         case Target::CUDA:
-            return make_cuda_ctx(base_opr_format, base_tensor_format);
+            return make_cuda_ctx(base_config_id, base_tensor_format);
         case Target::ARM:
-            return make_arm_ctx(base_opr_format, base_tensor_format);
+            return make_arm_ctx(base_config_id, base_tensor_format);
         default:
             mgb_assert(false, "unsupported target %s\n", target_to_string(target));
     }
@@ -43,6 +43,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
     auto partitions = extractor.extract(opt.graph().endpoint_vars());
     using Solution = SolverBase::Solution;
+    using OprFormat = SolverBase::OprFormat;
     Solution solution;
     ThinHashSet<VarNode*> endpoint_vars;
     for (auto&& partition : partitions) {
@@ -60,7 +61,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
     auto&& opr_configs = m_ctx->opr_configs();
     auto&& base_fmt = m_ctx->attribute().base_tensor_formats;
-    auto&& base_opr_fmt = m_ctx->attribute().base_opr_format;
+    auto&& base_cfg_id = m_ctx->attribute().base_config_id;
     auto&& reformat_attribute = m_ctx->attribute().reformat_attribute;
     ThinHashMap<VarNode*, TensorFormats> var2fmts;
     static ThinHashSet<Typeinfo*> format_aware_oprs = {
@@ -69,18 +70,25 @@ void LayoutTransformPass::apply(OptState& opt) const {
 #undef cb
     };
     auto rewriter = opt.graph().make_rewriter();
-    auto on_opr = [&opr_configs, &base_fmt, &base_opr_fmt, &reformat_attribute,
+    auto on_opr = [&opr_configs, &base_fmt, &base_cfg_id, &reformat_attribute,
                    &rewriter, &solution, &var2fmts,
                    &endpoint_vars](OperatorNodeBase* opr) {
         auto it = solution.find(opr);
         if (it != solution.end()) {
-            auto opr_fmt = it->second;
+            auto cfg_id = it->second;
             auto find = opr_configs.find(opr->dyn_typeinfo());
             Maybe<OprTensorFormatsConfiguration> fmtcfg = None;
             Maybe<OprTensorFormatsConfiguration> basecfg = None;
+            Maybe<OprFormat> opr_fmt = None;
             if (find != opr_configs.end()) {
-                fmtcfg = (*find->second.at(opr_fmt))(opr);
-                basecfg = (*find->second.at(base_opr_fmt))(opr);
+                fmtcfg = (*find->second.at(cfg_id))(opr);
+                auto _ = OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
+                        opr->dyn_typeinfo(), base_cfg_id);
+                basecfg = (*_)(opr);
+                opr_fmt = fmtcfg.val().opr_format;
+            } else {
+                opr_fmt =
+                        OprTensorFormatsConfiguration::safe_cast_to_opr_format(cfg_id);
             }
             VarNodeArray new_inp;
             size_t nr_inps = opr->input().size();
@@ -89,7 +97,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
                 nr_inps = std::min(fmtcfg.val().input_tensor_formats.size(), nr_inps);
                 out_fmt = fmtcfg.val().output_tensor_formats[0];
             } else {
-                out_fmt = opr_format_to_tensor_formats(opr_fmt);
+                out_fmt = opr_format_to_tensor_formats(opr_fmt.val());
             }
             new_inp.resize(nr_inps);
             for (size_t i = 0; i < nr_inps; ++i) {
@@ -103,7 +111,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
                     from = find->second;
                 }
                 auto to = fmtcfg.valid() ? fmtcfg.val().input_tensor_formats[i]
-                                         : opr_format_to_tensor_formats(opr_fmt);
+                                         : opr_format_to_tensor_formats(opr_fmt.val());
                 bool is_parameter =
                         fmtcfg.valid() &&
                         fmtcfg.val().input_tensor_types[i] == TensorType::WEIGHT;
@@ -119,7 +127,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
                         var->dtype().enumv()};
                 if (is_parameter) {
                     auto aligned_desc =
-                            ReformatManager::make_aligned_desc(base_fmt, out_fmt);
+                            ReformatManager::make_aligned_desc(from, out_fmt);
                     reformat = ReformatManager::instance()
                                        .auto_aligned_reformat_weight(
                                                var, key, aligned_desc);
@@ -134,7 +142,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
             }
             VarNode* new_out;
             if (format_aware_oprs.count(opr->dyn_typeinfo()) > 0) {
-                new_out = intl::modify_opr_format(opr_fmt, new_inp, opr);
+                new_out = intl::modify_opr_format(opr_fmt.val(), new_inp, opr);
             } else {
                 new_out = serialization::copy_opr_shallow(*opr, new_inp, opr->config())
                                   ->output(0);
@@ -170,9 +178,8 @@ void LayoutTransformPass::apply(OptState& opt) const {
                         ovar, new_ovar,
                         mgb_cstr_log(ssprintf(
                                              "replace opr(%s) to new opr "
-                                             "format(%s)",
-                                             opr->cname(),
-                                             opr_format_to_string(opr_fmt))
+                                             "format config(%s)",
+                                             opr->cname(), config_id_to_string(cfg_id))
                                              .c_str()));
             }
         } else {
@@ -24,7 +24,7 @@ namespace intl {
 bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr);
 VarNode* modify_opr_format(
-        opr::ConvBias::Param::Format opr_format, const VarNodeArray& i,
+        opr::Convolution::Param::Format opr_format, const VarNodeArray& i,
         const cg::OperatorNodeBase* opr);
 } // namespace intl
@@ -25,7 +25,8 @@ MIDOUT_DECL(megbrain_opr_tensor_formats_config)
 using namespace mgb;
 using namespace cg;
 using namespace gopt;
-using OprFormat = opr::ConvBias::Param::Format;
+using OprFormat = OprTensorFormatsConfiguration::OprFormat;
+using OprFormatConfigID = OprTensorFormatsConfiguration::OprFormatConfigID;
 namespace {
 template <typename Opr>
@@ -56,19 +57,22 @@ static bool is_channel_wise_conv(const OperatorNodeBase* opr) {
     if (format == Opr::Param::Format::NCHW) {
         ocpg = weight_shp[1], icpg = weight_shp[2];
         return ocpg == 1 && icpg == 1;
+    } else {
+        mgb_assert(false, "invalid opr format(%s)", opr_format_to_string(format));
     }
     return false;
 }
-template <OprFormat opr_format_>
+template <OprFormatConfigID config_id>
 struct OprSingleInOutTensorFormatsDispatcherImpl;
 template <>
-struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW> {
+struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW;
+        config.config_id = OprFormatConfigID::NCHW;
         config.input_dtypes = {opr->input(0)->dtype().enumv()};
         config.input_tensor_types = {TensorType::FEATURE};
         config.output_dtypes = {opr->output(0)->dtype().enumv()};
@@ -79,11 +83,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW> {
 };
 template <>
-struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW44> {
+struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW44> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW44;
+        config.config_id = OprFormatConfigID::NCHW44;
         bool available = true;
         available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float32;
         config.input_dtypes = {opr->input(0)->dtype().enumv()};
@@ -99,11 +104,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW44> {
 #if !MEGDNN_DISABLE_FLOAT16
 template <>
-struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW88> {
+struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW88> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW88;
+        config.config_id = OprFormatConfigID::NCHW88;
         bool available = true;
         available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float16;
         config.input_dtypes = {opr->input(0)->dtype().enumv()};
@@ -119,11 +125,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW88> {
 #endif
 template <>
-struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW4> {
+struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW4> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW4;
+        config.config_id = OprFormatConfigID::NCHW4;
         bool available = true;
         available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
         config.input_dtypes = {opr->input(0)->dtype().enumv()};
@@ -139,11 +146,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW4> {
 };
 template <>
-struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::CHWN4> {
+struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::CHWN4> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::CHWN4;
+        config.config_id = OprFormatConfigID::CHWN4;
         bool available = true;
         available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
         config.input_dtypes = {opr->input(0)->dtype().enumv()};
@@ -159,11 +167,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::CHWN4> {
 };
 template <>
-struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW32> {
+struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW32> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW32;
+        config.config_id = OprFormatConfigID::NCHW32;
         bool available = true;
         available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
         config.input_dtypes = {opr->input(0)->dtype().enumv()};
@@ -179,11 +188,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW32> {
 };
 template <>
-struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NHWC> {
+struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NHWC> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NHWC;
+        config.config_id = OprFormatConfigID::NHWC;
         bool available = true;
         available &= opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
                      opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
@@ -200,11 +210,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NHWC> {
 };
 template <>
-struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW64> {
+struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW64> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW64;
+        config.config_id = OprFormatConfigID::NCHW64;
         bool available = true;
         available &= opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
                      opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
@@ -220,16 +231,17 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW64> {
     }
 };
-template <typename Opr, OprFormat opr_format_>
+template <typename Opr, OprFormatConfigID config_id>
 struct ConvTensorFormatsDispatcherImpl;
 template <typename Opr>
-struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW> {
+struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         const auto& conv = opr->cast_final_safe<Opr>();
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW;
+        config.config_id = OprFormatConfigID::NCHW;
         // setup dtypes
         for (size_t i = 0; i < opr->input().size(); ++i) {
             config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
@@ -260,37 +272,35 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW> {
 };
 template <typename Opr>
-struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NHWC> {
+struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NHWC> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         const auto& conv = opr->cast_final_safe<Opr>();
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NHWC;
+        config.config_id = OprFormatConfigID::NHWC;
+        auto check_dtype = [](const DType& dt) {
+            bool i4_config = dt.enumv() == DTypeEnum::Quantized4Asymm ||
+                             dt.enumv() == DTypeEnum::QuantizedS4;
+            bool i8_config = dt.enumv() == DTypeEnum::QuantizedS8;
+            return i4_config || i8_config;
+        };
         bool available = true;
         for (size_t i = 0; i < opr->input().size(); ++i) {
             if (i == 2)
                 available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
             else {
-                bool i4_config =
-                        opr->input(i)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
-                        opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS4;
-                bool i8_config =
-                        opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
-                available &= (i4_config || i8_config);
+                available &= check_dtype(opr->input(i)->dtype());
             }
             config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
             TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
             config.input_tensor_types.emplace_back(tensor_type);
         }
-        bool i4_config =
-                opr->output(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
-                opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
-        bool i8_config = opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
-        available &= (i4_config || i8_config);
+        available &= check_dtype(opr->output(0)->dtype());
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
         available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
         config.input_tensor_formats = {
-                TensorFormats::NHWC, TensorFormats::NHWC, TensorFormats::NHWC,
+                TensorFormats::NHWC, TensorFormats::KRSC, TensorFormats::NHWC,
                 TensorFormats::NHWC};
         config.output_tensor_formats = {TensorFormats::NHWC};
         if (available)
@@ -300,12 +310,13 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NHWC> {
 };
 template <typename Opr>
-struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW4> {
+struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW4> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         const auto& conv = opr->cast_final_safe<Opr>();
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW4;
+        config.config_id = OprFormatConfigID::NCHW4;
         bool available = true;
         // setup dtypes
         for (size_t i = 0; i < opr->input().size(); ++i) {
@@ -322,7 +333,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW4> {
         // setup tensor formats
         if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
             config.input_tensor_formats = {
-                    TensorFormats::NCHWc4, TensorFormats::NCHWc4, TensorFormats::NCHWc4,
+                    TensorFormats::NCHWc4, TensorFormats::KCRSc4, TensorFormats::NCHWc4,
                     TensorFormats::NCHWc4};
         } else {
             mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
@@ -344,12 +355,75 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW4> {
 };
 template <typename Opr>
-struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW32> {
+struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW4_NCHW32> {
+    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
+        const auto& conv = opr->cast_final_safe<Opr>();
+        OprTensorFormatsConfiguration config;
+        config.typeinfo = opr->dyn_typeinfo();
+        config.opr_format = OprFormat::NCHW4_NCHW32;
+        config.config_id = OprFormatConfigID::NCHW4_NCHW32;
+        bool available = true;
+        for (size_t i = 0; i < opr->input().size(); ++i) {
+            if (i == 2)
+                available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
+            else
+                available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
+            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
+            TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
+            config.input_tensor_types.emplace_back(tensor_type);
+        }
+        available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
+        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
+        available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
+        config.input_tensor_formats = {
+                TensorFormats::NCHWc4, TensorFormats::KCRSc4, TensorFormats::NCHWc32,
+                TensorFormats::NCHWc32};
+        config.output_tensor_formats = {TensorFormats::NCHWc32};
+        if (available)
+            return config;
+        return None;
+    }
+};
+template <typename Opr>
+struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW4_NCHW> {
+    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
+        const auto& conv = opr->cast_final_safe<Opr>();
+        OprTensorFormatsConfiguration config;
+        config.typeinfo = opr->dyn_typeinfo();
+        config.opr_format = OprFormat::NCHW4_NCHW;
+        config.config_id = OprFormatConfigID::NCHW4_NCHW;
+        bool available = true;
+        for (size_t i = 0; i < opr->input().size(); ++i) {
+            if (i >= 2)
+                available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32;
+            else
+                available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
+            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
+            TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
+            config.input_tensor_types.emplace_back(tensor_type);
+        }
+        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
+        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
+        available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
+        config.input_tensor_formats = {
+                TensorFormats::NCHWc4, TensorFormats::KCRSc4, TensorFormats::NCHW,
+                TensorFormats::NCHW};
+        config.output_tensor_formats = {TensorFormats::NCHW};
+        if (available)
+            return config;
+        return None;
+    }
+};
+template <typename Opr>
+struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW32> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         const auto& conv = opr->cast_final_safe<Opr>();
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW32;
+        config.config_id = OprFormatConfigID::NCHW32;
         bool available = true;
         for (size_t i = 0; i < opr->input().size(); ++i) {
             if (i == 2)
@@ -364,7 +438,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW32> {
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
         available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
         config.input_tensor_formats = {
-                TensorFormats::NCHWc32, TensorFormats::NCHWc32, TensorFormats::NCHWc32,
+                TensorFormats::NCHWc32, TensorFormats::KCRSc32, TensorFormats::NCHWc32,
                 TensorFormats::NCHWc32};
         config.output_tensor_formats = {TensorFormats::NCHWc32};
         if (available)
@@ -374,12 +448,44 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW32> {
 };
 template <typename Opr>
-struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW64> {
+struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW32_NCHW4> {
+    static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
+        const auto& conv = opr->cast_final_safe<Opr>();
+        OprTensorFormatsConfiguration config;
+        config.typeinfo = opr->dyn_typeinfo();
+        config.opr_format = OprFormat::NCHW32_NCHW4;
+        config.config_id = OprFormatConfigID::NCHW32_NCHW4;
+        bool available = true;
+        for (size_t i = 0; i < opr->input().size(); ++i) {
+            if (i == 2)
+                available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
+            else
+                available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
+            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
+            TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
+            config.input_tensor_types.emplace_back(tensor_type);
+        }
+        available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
+        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
+        available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
+        config.input_tensor_formats = {
+                TensorFormats::NCHWc32, TensorFormats::KCRSc32, TensorFormats::NCHWc4,
+                TensorFormats::NCHWc4};
+        config.output_tensor_formats = {TensorFormats::NCHWc4};
+        if (available)
+            return config;
+        return None;
+    }
+};
+template <typename Opr>
+struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW64> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         const auto& conv = opr->cast_final_safe<Opr>();
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW64;
+        config.config_id = OprFormatConfigID::NCHW64;
         bool available = true;
         for (size_t i = 0; i < opr->input().size(); ++i) {
             if (i == 2)
@@ -397,7 +503,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW64> {
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
         available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
         config.input_tensor_formats = {
-                TensorFormats::NCHWc64, TensorFormats::NCHWc64, TensorFormats::NCHWc64,
+                TensorFormats::NCHWc64, TensorFormats::KCRSc64, TensorFormats::NCHWc64,
                 TensorFormats::NCHWc64};
         config.output_tensor_formats = {TensorFormats::NCHWc64};
         if (available)
@@ -407,12 +513,13 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW64> {
 };
 template <typename Opr>
-struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> {
+struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::CHWN4> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
         const auto& conv = opr->cast_final_safe<Opr>();
         OprTensorFormatsConfiguration config;
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::CHWN4;
+        config.config_id = OprFormatConfigID::CHWN4;
         bool available = true;
         for (size_t i = 0; i < opr->input().size(); ++i) {
             if (i == 2)
@@ -427,7 +534,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> {
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
         available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
         config.input_tensor_formats = {
-                TensorFormats::CHWNc4, TensorFormats::CHWNc4, TensorFormats::CHWNc4,
+                TensorFormats::CHWNc4, TensorFormats::CRSKc4, TensorFormats::CHWNc4,
                 TensorFormats::CHWNc4};
        config.output_tensor_formats = {TensorFormats::CHWNc4};
        if (available)
| @@ -437,12 +544,13 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> { | |||||
| }; | }; | ||||
| template <typename Opr> | template <typename Opr> | ||||
| struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> { | |||||
| struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44> { | |||||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | ||||
| const auto& conv = opr->cast_final_safe<Opr>(); | const auto& conv = opr->cast_final_safe<Opr>(); | ||||
| OprTensorFormatsConfiguration config; | OprTensorFormatsConfiguration config; | ||||
| config.typeinfo = opr->dyn_typeinfo(); | config.typeinfo = opr->dyn_typeinfo(); | ||||
| config.opr_format = OprFormat::NCHW44; | config.opr_format = OprFormat::NCHW44; | ||||
| config.config_id = OprFormatConfigID::NCHW44; | |||||
| bool available = true; | bool available = true; | ||||
| // setup dtypes | // setup dtypes | ||||
| for (size_t i = 0; i < opr->input().size(); ++i) { | for (size_t i = 0; i < opr->input().size(); ++i) { | ||||
| @@ -477,14 +585,44 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> { | |||||
| } | } | ||||
| }; | }; | ||||
| template <typename Opr> | |||||
| struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_HYBRID> { | |||||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | |||||
| const auto& conv = opr->cast_final_safe<Opr>(); | |||||
| OprTensorFormatsConfiguration config; | |||||
| config.typeinfo = opr->dyn_typeinfo(); | |||||
| config.opr_format = OprFormat::NCHW44; | |||||
| config.config_id = OprFormatConfigID::NCHW44_HYBRID; | |||||
| bool available = true; | |||||
| // setup dtypes | |||||
| for (size_t i = 0; i < opr->input().size(); ++i) { | |||||
| available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32; | |||||
| config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); | |||||
| TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; | |||||
| config.input_tensor_types.emplace_back(tensor_type); | |||||
| } | |||||
| available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32; | |||||
| config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | |||||
| available &= conv.param().sparse == Opr::Param::Sparse::DENSE; | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NCHW, TensorFormats::KRSCk4, TensorFormats::NCHWc4, | |||||
| TensorFormats::NCHWc4}; | |||||
| config.output_tensor_formats = {TensorFormats::NCHWc4}; | |||||
| if (!available) | |||||
| return None; | |||||
| return config; | |||||
| } | |||||
| }; | |||||
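Note on the *_HYBRID configurations introduced above: the feature map stays in plain NCHW (the first convolution of a network usually has too few input channels to block by 4 or 8), the weight is repacked into KRSCk4/KRSCk8, and the output is already emitted channel-blocked. Below is a minimal standalone sketch of what the K-blocked weight addressing means, assuming KRSCk4 follows the same naming convention as the KCRSc64/CRSKc4 entries added later in this patch (outer K//4, then R, S, C, inner K%4); the helpers are illustrative and not MegBrain code.

    #include <cstddef>

    // Offset of weight element (k, c, r, s) in a dense KCRS buffer.
    inline std::size_t kcrs_offset(std::size_t k, std::size_t c, std::size_t r,
                                   std::size_t s, std::size_t C, std::size_t R,
                                   std::size_t S) {
        return ((k * C + c) * R + r) * S + s;
    }

    // Offset of the same element after repacking into KRSCk4: the outer
    // dimensions are (K/4, R, S, C) and the innermost dimension holds K%4.
    // K is assumed to be a multiple of 4 (otherwise it would be padded).
    inline std::size_t krsck4_offset(std::size_t k, std::size_t c, std::size_t r,
                                     std::size_t s, std::size_t C, std::size_t R,
                                     std::size_t S) {
        constexpr std::size_t pack = 4;
        return (((k / pack * R + r) * S + s) * C + c) * pack + k % pack;
    }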
| #if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
| template <typename Opr> | template <typename Opr> | ||||
| struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> { | |||||
| struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88> { | |||||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | ||||
| const auto& conv = opr->cast_final_safe<Opr>(); | const auto& conv = opr->cast_final_safe<Opr>(); | ||||
| OprTensorFormatsConfiguration config; | OprTensorFormatsConfiguration config; | ||||
| config.typeinfo = opr->dyn_typeinfo(); | config.typeinfo = opr->dyn_typeinfo(); | ||||
| config.opr_format = OprFormat::NCHW88; | config.opr_format = OprFormat::NCHW88; | ||||
| config.config_id = OprFormatConfigID::NCHW88; | |||||
| bool available = true; | bool available = true; | ||||
| // setup dtypes | // setup dtypes | ||||
| for (size_t i = 0; i < opr->input().size(); ++i) { | for (size_t i = 0; i < opr->input().size(); ++i) { | ||||
| @@ -518,15 +656,46 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> { | |||||
| return config; | return config; | ||||
| } | } | ||||
| }; | }; | ||||
| template <typename Opr> | |||||
| struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88_HYBRID> { | |||||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | |||||
| const auto& conv = opr->cast_final_safe<Opr>(); | |||||
| OprTensorFormatsConfiguration config; | |||||
| config.typeinfo = opr->dyn_typeinfo(); | |||||
| config.opr_format = OprFormat::NCHW88; | |||||
| config.config_id = OprFormatConfigID::NCHW88_HYBRID; | |||||
| bool available = true; | |||||
| // setup dtypes | |||||
| for (size_t i = 0; i < opr->input().size(); ++i) { | |||||
| available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16; | |||||
| config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); | |||||
| TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; | |||||
| config.input_tensor_types.emplace_back(tensor_type); | |||||
| } | |||||
| available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16; | |||||
| config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | |||||
| available &= conv.param().sparse == Opr::Param::Sparse::DENSE; | |||||
| // setup tensor formats | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NCHW, TensorFormats::KRSCk8, TensorFormats::NCHWc8, | |||||
| TensorFormats::NCHWc8}; | |||||
| config.output_tensor_formats = {TensorFormats::NCHWc8}; | |||||
| if (!available) | |||||
| return None; | |||||
| return config; | |||||
| } | |||||
| }; | |||||
| #endif | #endif | ||||
| template <typename Opr> | template <typename Opr> | ||||
| struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44_DOT> { | |||||
| struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_DOT> { | |||||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | ||||
| const auto& conv = opr->cast_final_safe<Opr>(); | const auto& conv = opr->cast_final_safe<Opr>(); | ||||
| OprTensorFormatsConfiguration config; | OprTensorFormatsConfiguration config; | ||||
| config.typeinfo = opr->dyn_typeinfo(); | config.typeinfo = opr->dyn_typeinfo(); | ||||
| config.opr_format = OprFormat::NCHW44_DOT; | config.opr_format = OprFormat::NCHW44_DOT; | ||||
| config.config_id = OprFormatConfigID::NCHW44_DOT; | |||||
| bool available = true; | bool available = true; | ||||
| // setup dtypes | // setup dtypes | ||||
| for (size_t i = 0; i < opr->input().size(); ++i) { | for (size_t i = 0; i < opr->input().size(); ++i) { | ||||
| @@ -566,14 +735,53 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44_DOT> { | |||||
| } | } | ||||
| }; | }; | ||||
| template <typename Opr> | |||||
| struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_DOT_HYBRID> { | |||||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | |||||
| const auto& conv = opr->cast_final_safe<Opr>(); | |||||
| OprTensorFormatsConfiguration config; | |||||
| config.typeinfo = opr->dyn_typeinfo(); | |||||
| config.opr_format = OprFormat::NCHW44_DOT; | |||||
| config.config_id = OprFormatConfigID::NCHW44_DOT_HYBRID; | |||||
| bool available = true; | |||||
| // setup dtypes | |||||
| for (size_t i = 0; i < opr->input().size(); ++i) { | |||||
| if (i == 2) { | |||||
| available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32; | |||||
| } else { | |||||
| available &= | |||||
| opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||||
| opr->input(i)->dtype().enumv() == DTypeEnum::Quantized8Asymm; | |||||
| } | |||||
| config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); | |||||
| TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; | |||||
| config.input_tensor_types.emplace_back(tensor_type); | |||||
| } | |||||
| available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||||
| opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm; | |||||
| config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | |||||
| available &= conv.param().sparse == Opr::Param::Sparse::DENSE; | |||||
| // setup tensor formats | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NCHW, TensorFormats::KRSCk4, TensorFormats::NCHWc4, | |||||
| TensorFormats::NCHWc4}; | |||||
| config.output_tensor_formats = {TensorFormats::NCHWc4}; | |||||
| if (!available) | |||||
| return None; | |||||
| return config; | |||||
| } | |||||
| }; | |||||
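The NCHW44_DOT_HYBRID dispatcher above encodes a dtype rule worth spelling out: every input except the bias (input index 2) must be QuantizedS8 or Quantized8Asymm, the bias must be QuantizedS32, and only dense convolutions qualify. A standalone sketch of that predicate follows; the enum is a reduced stand-in, not megdnn's DTypeEnum.

    #include <cstddef>
    #include <cstdint>

    enum class DType : uint8_t { QuantizedS8, Quantized8Asymm, QuantizedS32 };

    // Mirrors the per-input dtype check of the NCHW44_DOT_HYBRID config:
    // bias (index 2) must be qint32, every other input qint8 or quint8.
    inline bool dot_hybrid_input_ok(std::size_t input_idx, DType dt) {
        if (input_idx == 2)
            return dt == DType::QuantizedS32;
        return dt == DType::QuantizedS8 || dt == DType::Quantized8Asymm;
    }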
| template <> | template <> | ||||
| struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::NCHW> { | |||||
| struct ConvTensorFormatsDispatcherImpl< | |||||
| opr::ConvolutionBackwardData, OprFormatConfigID::NCHW> { | |||||
| using Opr = opr::ConvolutionBackwardData; | using Opr = opr::ConvolutionBackwardData; | ||||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | ||||
| const auto& conv = opr->cast_final_safe<Opr>(); | const auto& conv = opr->cast_final_safe<Opr>(); | ||||
| OprTensorFormatsConfiguration config; | OprTensorFormatsConfiguration config; | ||||
| config.typeinfo = opr->dyn_typeinfo(); | config.typeinfo = opr->dyn_typeinfo(); | ||||
| config.opr_format = OprFormat::NCHW; | config.opr_format = OprFormat::NCHW; | ||||
| config.config_id = OprFormatConfigID::NCHW; | |||||
| // setup dtypes | // setup dtypes | ||||
| for (size_t i = 0; i < opr->input().size(); ++i) { | for (size_t i = 0; i < opr->input().size(); ++i) { | ||||
| config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); | config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); | ||||
| @@ -584,7 +792,7 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat:: | |||||
| // setup tensor formats | // setup tensor formats | ||||
| if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | ||||
| config.input_tensor_formats = { | config.input_tensor_formats = { | ||||
| TensorFormats::NCHW, TensorFormats::NCHW, TensorFormats::NCHW, | |||||
| TensorFormats::KCRS, TensorFormats::NCHW, TensorFormats::NCHW, | |||||
| TensorFormats::NCHW}; | TensorFormats::NCHW}; | ||||
| } else { | } else { | ||||
| mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | ||||
| @@ -604,13 +812,15 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat:: | |||||
| }; | }; | ||||
| template <> | template <> | ||||
| struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::NCHW4> { | |||||
| struct ConvTensorFormatsDispatcherImpl< | |||||
| opr::ConvolutionBackwardData, OprFormatConfigID::NCHW4> { | |||||
| using Opr = opr::ConvolutionBackwardData; | using Opr = opr::ConvolutionBackwardData; | ||||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | ||||
| const auto& conv = opr->cast_final_safe<Opr>(); | const auto& conv = opr->cast_final_safe<Opr>(); | ||||
| OprTensorFormatsConfiguration config; | OprTensorFormatsConfiguration config; | ||||
| config.typeinfo = opr->dyn_typeinfo(); | config.typeinfo = opr->dyn_typeinfo(); | ||||
| config.opr_format = OprFormat::NCHW4; | config.opr_format = OprFormat::NCHW4; | ||||
| config.config_id = OprFormatConfigID::NCHW4; | |||||
| bool available = true; | bool available = true; | ||||
| for (size_t i = 0; i < opr->input().size(); ++i) { | for (size_t i = 0; i < opr->input().size(); ++i) { | ||||
| available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8; | available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8; | ||||
| @@ -622,7 +832,7 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat:: | |||||
| config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | ||||
| available &= conv.param().sparse == opr::ConvBias::Param::Sparse::DENSE; | available &= conv.param().sparse == opr::ConvBias::Param::Sparse::DENSE; | ||||
| config.input_tensor_formats = { | config.input_tensor_formats = { | ||||
| TensorFormats::NCHWc4, TensorFormats::NCHWc4, TensorFormats::NCHWc4, | |||||
| TensorFormats::KCRSc4, TensorFormats::NCHWc4, TensorFormats::NCHWc4, | |||||
| TensorFormats::NCHWc4}; | TensorFormats::NCHWc4}; | ||||
| config.output_tensor_formats = {TensorFormats::NCHWc4}; | config.output_tensor_formats = {TensorFormats::NCHWc4}; | ||||
| if (available) | if (available) | ||||
| @@ -632,13 +842,15 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat:: | |||||
| }; | }; | ||||
| template <> | template <> | ||||
| struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::NHWC> { | |||||
| struct ConvTensorFormatsDispatcherImpl< | |||||
| opr::ConvolutionBackwardData, OprFormatConfigID::NHWC> { | |||||
| using Opr = opr::ConvolutionBackwardData; | using Opr = opr::ConvolutionBackwardData; | ||||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | ||||
| const auto& conv = opr->cast_final_safe<Opr>(); | const auto& conv = opr->cast_final_safe<Opr>(); | ||||
| OprTensorFormatsConfiguration config; | OprTensorFormatsConfiguration config; | ||||
| config.typeinfo = opr->dyn_typeinfo(); | config.typeinfo = opr->dyn_typeinfo(); | ||||
| config.opr_format = OprFormat::NHWC; | config.opr_format = OprFormat::NHWC; | ||||
| config.config_id = OprFormatConfigID::NHWC; | |||||
| bool available = true; | bool available = true; | ||||
| for (size_t i = 0; i < opr->input().size(); ++i) { | for (size_t i = 0; i < opr->input().size(); ++i) { | ||||
| available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8; | available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8; | ||||
| @@ -650,7 +862,7 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat:: | |||||
| config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | ||||
| available &= conv.param().sparse == opr::ConvBias::Param::Sparse::DENSE; | available &= conv.param().sparse == opr::ConvBias::Param::Sparse::DENSE; | ||||
| config.input_tensor_formats = { | config.input_tensor_formats = { | ||||
| TensorFormats::NHWC, TensorFormats::NHWC, TensorFormats::NHWC, | |||||
| TensorFormats::KRSC, TensorFormats::NHWC, TensorFormats::NHWC, | |||||
| TensorFormats::NHWC}; | TensorFormats::NHWC}; | ||||
| config.output_tensor_formats = {TensorFormats::NHWC}; | config.output_tensor_formats = {TensorFormats::NHWC}; | ||||
| if (available) | if (available) | ||||
| @@ -661,7 +873,7 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat:: | |||||
| struct StaticData { | struct StaticData { | ||||
| struct KeyHash { | struct KeyHash { | ||||
| size_t operator()(const std::pair<Typeinfo*, OprFormat>& val) const { | |||||
| size_t operator()(const std::pair<Typeinfo*, OprFormatConfigID>& val) const { | |||||
| size_t h1 = mgb::hash<Typeinfo*>(val.first); | size_t h1 = mgb::hash<Typeinfo*>(val.first); | ||||
| size_t h2 = std::hash<uint32_t>()(static_cast<uint32_t>(val.second)); | size_t h2 = std::hash<uint32_t>()(static_cast<uint32_t>(val.second)); | ||||
| return mgb::hash_pair_combine(h1, h2); | return mgb::hash_pair_combine(h1, h2); | ||||
| @@ -670,28 +882,29 @@ struct StaticData { | |||||
| using OprTensorFormatsDispatcher = | using OprTensorFormatsDispatcher = | ||||
| OprTensorFormatsConfiguration::OprTensorFormatsDispatcher; | OprTensorFormatsConfiguration::OprTensorFormatsDispatcher; | ||||
| std::unordered_map< | std::unordered_map< | ||||
| std::pair<Typeinfo*, OprFormat>, OprTensorFormatsDispatcher, KeyHash> | |||||
| std::pair<Typeinfo*, OprFormatConfigID>, OprTensorFormatsDispatcher, | |||||
| KeyHash> | |||||
| typefmt2dispatcher; | typefmt2dispatcher; | ||||
| StaticData(); | StaticData(); | ||||
| }; | }; | ||||
| StaticData::StaticData() { | StaticData::StaticData() { | ||||
| #define OPR_TENSOR_FORMATS_CONFIG_REG(_Opr, _fmt) \ | |||||
| typefmt2dispatcher[{opr::_Opr::typeinfo(), OprFormat::_fmt}] = \ | |||||
| [](const OperatorNodeBase* opr) { \ | |||||
| MIDOUT_B(opr::_Opr, midout_iv(OprFormat::_fmt)) \ | |||||
| return ConvTensorFormatsDispatcherImpl< \ | |||||
| opr::_Opr, OprFormat::_fmt>::dispatch(opr); \ | |||||
| MIDOUT_E \ | |||||
| #define OPR_TENSOR_FORMATS_CONFIG_REG(_Opr, _fmt) \ | |||||
| typefmt2dispatcher[{opr::_Opr::typeinfo(), OprFormatConfigID::_fmt}] = \ | |||||
| [](const OperatorNodeBase* opr) { \ | |||||
| MIDOUT_B(opr::_Opr, midout_iv(OprFormatConfigID::_fmt)) \ | |||||
| return ConvTensorFormatsDispatcherImpl< \ | |||||
| opr::_Opr, OprFormatConfigID::_fmt>::dispatch(opr); \ | |||||
| MIDOUT_E \ | |||||
| } | } | ||||
| #define OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(_Opr, _fmt) \ | |||||
| typefmt2dispatcher[{opr::_Opr::typeinfo(), OprFormat::_fmt}] = \ | |||||
| [](const OperatorNodeBase* opr) { \ | |||||
| MIDOUT_B(opr::_Opr, midout_iv(OprFormat::_fmt)) \ | |||||
| return OprSingleInOutTensorFormatsDispatcherImpl< \ | |||||
| OprFormat::_fmt>::dispatch(opr); \ | |||||
| MIDOUT_E \ | |||||
| #define OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(_Opr, _fmt) \ | |||||
| typefmt2dispatcher[{opr::_Opr::typeinfo(), OprFormatConfigID::_fmt}] = \ | |||||
| [](const OperatorNodeBase* opr) { \ | |||||
| MIDOUT_B(opr::_Opr, midout_iv(OprFormatConfigID::_fmt)) \ | |||||
| return OprSingleInOutTensorFormatsDispatcherImpl< \ | |||||
| OprFormatConfigID::_fmt>::dispatch(opr); \ | |||||
| MIDOUT_E \ | |||||
| } | } | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW); | ||||
| @@ -703,16 +916,22 @@ StaticData::StaticData() { | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44); | ||||
| #if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88_HYBRID); | |||||
| #endif | #endif | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_HYBRID); | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT_HYBRID); | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44); | ||||
| #if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88_HYBRID); | |||||
| #endif | #endif | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_HYBRID); | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT_HYBRID); | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC); | ||||
| @@ -752,14 +971,14 @@ StaticData& static_data() { | |||||
| OprTensorFormatsConfiguration::OprTensorFormatsDispatcher* | OprTensorFormatsConfiguration::OprTensorFormatsDispatcher* | ||||
| OprTensorFormatsConfiguration::find_dispatcher_by_type_format( | OprTensorFormatsConfiguration::find_dispatcher_by_type_format( | ||||
| Typeinfo* type, OprFormat opr_format) { | |||||
| Typeinfo* type, OprFormatConfigID config_id) { | |||||
| auto&& typefmt2dispatcher = static_data().typefmt2dispatcher; | auto&& typefmt2dispatcher = static_data().typefmt2dispatcher; | ||||
| auto iter = typefmt2dispatcher.find(std::make_pair(type, opr_format)); | |||||
| auto iter = typefmt2dispatcher.find(std::make_pair(type, config_id)); | |||||
| mgb_assert( | mgb_assert( | ||||
| iter != typefmt2dispatcher.end(), | iter != typefmt2dispatcher.end(), | ||||
| "cannot find OprTensorFormatsDispatcher for opr type(%s) and " | "cannot find OprTensorFormatsDispatcher for opr type(%s) and " | ||||
| "opr format(%s)", | |||||
| type->name, opr_format_to_string(opr_format)); | |||||
| "opr format configuration id(%s)", | |||||
| type->name, config_id_to_string(config_id)); | |||||
| return &iter->second; | return &iter->second; | ||||
| } | } | ||||
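With the registry keyed by (Typeinfo*, OprFormatConfigID), querying a configuration goes through find_dispatcher_by_type_format and then invokes the returned dispatcher. The fragment below is a usage sketch against the declarations in this patch, not a self-contained program: `opr` is assumed to be an existing ConvBias operator node and the mgb/gopt namespaces are assumed to be open.

    // Returns None if the operator's dtypes/sparsity do not fit the config;
    // find_dispatcher_by_type_format asserts if the pair was never registered.
    Maybe<OprTensorFormatsConfiguration> query_nchw44_hybrid_config(
            const cg::OperatorNodeBase* opr) {
        using ConfigID = OprTensorFormatsConfiguration::OprFormatConfigID;
        auto* dispatcher =
                OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
                        opr->dyn_typeinfo(), ConfigID::NCHW44_HYBRID);
        return (*dispatcher)(opr);
    }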
| @@ -64,7 +64,7 @@ void ProfilerCache::Key::build_blob_from_opr() { | |||||
| // serialize opr_format | // serialize opr_format | ||||
| m_blob_storage.append( | m_blob_storage.append( | ||||
| std::to_string(static_cast<uint32_t>(m_key_impl.opr_key.opr_format))); | |||||
| std::to_string(static_cast<uint32_t>(m_key_impl.opr_key.config_id))); | |||||
| // serialize extra_attribute | // serialize extra_attribute | ||||
| m_blob_storage.append( | m_blob_storage.append( | ||||
| @@ -29,30 +29,6 @@ using namespace gopt; | |||||
| using ReformatKey = ReformatManager::ReformatKey; | using ReformatKey = ReformatManager::ReformatKey; | ||||
| namespace { | namespace { | ||||
| using OprFormat = Problem::OprFormat; | |||||
| OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) { | |||||
| switch (tensor_format) { | |||||
| case TensorFormats::NCHW: | |||||
| return OprFormat::NCHW; | |||||
| case TensorFormats::NCHWc4: | |||||
| return OprFormat::NCHW44; | |||||
| case TensorFormats::NCHWc8: | |||||
| return OprFormat::NCHW88; | |||||
| case TensorFormats::NCHWc32: | |||||
| return OprFormat::NCHW32; | |||||
| case TensorFormats::NCHWc64: | |||||
| return OprFormat::NCHW64; | |||||
| case TensorFormats::NHWC: | |||||
| return OprFormat::NHWC; | |||||
| case TensorFormats::CHWNc4: | |||||
| return OprFormat::CHWN4; | |||||
| default: | |||||
| mgb_throw( | |||||
| MegBrainError, "tensor format(%u) is not supported", | |||||
| static_cast<uint32_t>(tensor_format)); | |||||
| } | |||||
| } | |||||
| class GraphPartitionProfiler final : public PluginBase { | class GraphPartitionProfiler final : public PluginBase { | ||||
| using CompNodeEventPtr = std::unique_ptr<CompNode::Event>; | using CompNodeEventPtr = std::unique_ptr<CompNode::Event>; | ||||
| @@ -214,8 +190,8 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( | |||||
| record.opr = opr; | record.opr = opr; | ||||
| auto& costs = record.costs; | auto& costs = record.costs; | ||||
| for (auto&& f : available_tensor_formats) { | for (auto&& f : available_tensor_formats) { | ||||
| auto opr_format = tensor_formats_to_opr_format(f); | |||||
| costs[opr_format] = profile_operator(opr, base_format, f, extra_attribute); | |||||
| auto config_id = tensor_formats_to_config_id(f); | |||||
| costs[config_id] = profile_operator(opr, base_format, f, extra_attribute); | |||||
| } | } | ||||
| return record; | return record; | ||||
| } | } | ||||
| @@ -261,7 +237,7 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( | |||||
| record.opr = opr; | record.opr = opr; | ||||
| auto& costs = record.costs; | auto& costs = record.costs; | ||||
| for (auto&& i : available_configs) { | for (auto&& i : available_configs) { | ||||
| costs[i.opr_format] = profile_operator(opr, base_config, i, extra_attribute); | |||||
| costs[i.config_id] = profile_operator(opr, base_config, i, extra_attribute); | |||||
| } | } | ||||
| return record; | return record; | ||||
| } | } | ||||
| @@ -316,7 +292,6 @@ float ProfilerImpl::profile_operator( | |||||
| new_inps[i] = imm.node(); | new_inps[i] = imm.node(); | ||||
| } | } | ||||
| VarNode* y = mgb::gopt::intl::modify_opr_format(config.opr_format, new_inps, opr); | VarNode* y = mgb::gopt::intl::modify_opr_format(config.opr_format, new_inps, opr); | ||||
| #if 0 | |||||
| static const ThinHashSet<Typeinfo*> multi_algo_oprs = { | static const ThinHashSet<Typeinfo*> multi_algo_oprs = { | ||||
| opr::Convolution::typeinfo(), | opr::Convolution::typeinfo(), | ||||
| opr::ConvBiasForward::typeinfo(), | opr::ConvBiasForward::typeinfo(), | ||||
| @@ -326,7 +301,6 @@ float ProfilerImpl::profile_operator( | |||||
| if (multi_algo_oprs.count(opr->dyn_typeinfo()) && | if (multi_algo_oprs.count(opr->dyn_typeinfo()) && | ||||
| !mgb::gopt::intl::has_available_algo(new_inps, y->owner_opr())) | !mgb::gopt::intl::has_available_algo(new_inps, y->owner_opr())) | ||||
| return PROFILE_TIME_OUT; | return PROFILE_TIME_OUT; | ||||
| #endif | |||||
| if (!m_opr_filter(opr, y->owner_opr())) | if (!m_opr_filter(opr, y->owner_opr())) | ||||
| return PROFILE_TIME_OUT; | return PROFILE_TIME_OUT; | ||||
| auto mark = MarkInputContiguous::make(SymbolVar(y)); | auto mark = MarkInputContiguous::make(SymbolVar(y)); | ||||
| @@ -494,6 +468,30 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(const Problem& problem) cons | |||||
| return profiling_result; | return profiling_result; | ||||
| } | } | ||||
| ProfilerImpl::OprFormatConfigID ProfilerImpl::tensor_formats_to_config_id( | |||||
| TensorFormats tensor_format) const { | |||||
| switch (tensor_format) { | |||||
| case TensorFormats::NCHW: | |||||
| return OprFormatConfigID::NCHW; | |||||
| case TensorFormats::NCHWc4: | |||||
| return OprFormatConfigID::NCHW4; | |||||
| case TensorFormats::NCHWc8: | |||||
| return OprFormatConfigID::NCHW8; | |||||
| case TensorFormats::NCHWc32: | |||||
| return OprFormatConfigID::NCHW32; | |||||
| case TensorFormats::NCHWc64: | |||||
| return OprFormatConfigID::NCHW64; | |||||
| case TensorFormats::NHWC: | |||||
| return OprFormatConfigID::NHWC; | |||||
| case TensorFormats::CHWNc4: | |||||
| return OprFormatConfigID::CHWN4; | |||||
| default: | |||||
| mgb_throw( | |||||
| MegBrainError, "tensor format(%u) is not supported", | |||||
| static_cast<uint32_t>(tensor_format)); | |||||
| } | |||||
| } | |||||
| /* ================== ProfilerBase =================*/ | /* ================== ProfilerBase =================*/ | ||||
| std::string ProfilerBase::OperatorNodeRecord::to_string() const { | std::string ProfilerBase::OperatorNodeRecord::to_string() const { | ||||
| auto str = ssprintf( | auto str = ssprintf( | ||||
| @@ -508,7 +506,7 @@ std::string ProfilerBase::OperatorNodeRecord::to_string() const { | |||||
| opr->output(0)->shape().to_string().c_str()); | opr->output(0)->shape().to_string().c_str()); | ||||
| for (auto&& cpair : costs) { | for (auto&& cpair : costs) { | ||||
| str += ssprintf( | str += ssprintf( | ||||
| "\tformat: %s; cost:%f", opr_format_to_string(cpair.first), | |||||
| "\tconfig: %s; cost:%f", config_id_to_string(cpair.first), | |||||
| cpair.second); | cpair.second); | ||||
| } | } | ||||
| return str; | return str; | ||||
| @@ -557,7 +555,7 @@ float CachedProfiler::profile_operator( | |||||
| const OperatorNodeBase* opr, TensorFormats base_format, | const OperatorNodeBase* opr, TensorFormats base_format, | ||||
| TensorFormats tensor_format, ReformatAttribute extra_attribute) const { | TensorFormats tensor_format, ReformatAttribute extra_attribute) const { | ||||
| ProfilerCache::Key key{ | ProfilerCache::Key key{ | ||||
| opr, tensor_formats_to_opr_format(tensor_format), extra_attribute}; | |||||
| opr, tensor_formats_to_config_id(tensor_format), extra_attribute}; | |||||
| auto ret = ProfilerCache::inst().get(key); | auto ret = ProfilerCache::inst().get(key); | ||||
| if (ret.valid()) | if (ret.valid()) | ||||
| return ret.val(); | return ret.val(); | ||||
| @@ -571,7 +569,7 @@ float CachedProfiler::profile_operator( | |||||
| const OperatorNodeBase* opr, const OprTensorFormatsConfiguration& base_config, | const OperatorNodeBase* opr, const OprTensorFormatsConfiguration& base_config, | ||||
| const OprTensorFormatsConfiguration& config, | const OprTensorFormatsConfiguration& config, | ||||
| ReformatAttribute extra_attribute) const { | ReformatAttribute extra_attribute) const { | ||||
| ProfilerCache::Key key{opr, config.opr_format, extra_attribute}; | |||||
| ProfilerCache::Key key{opr, config.config_id, extra_attribute}; | |||||
| auto ret = ProfilerCache::inst().get(key); | auto ret = ProfilerCache::inst().get(key); | ||||
| if (ret.valid()) | if (ret.valid()) | ||||
| return ret.val(); | return ret.val(); | ||||
| @@ -48,7 +48,8 @@ ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profile | |||||
| }; | }; | ||||
| m_problem_filter = [](const Problem& problem) { | m_problem_filter = [](const Problem& problem) { | ||||
| auto&& base_opr_format = problem.attribute().base_opr_format; | |||||
| auto&& base_opr_format = OprTensorFormatsConfiguration::safe_cast_to_opr_format( | |||||
| problem.attribute().base_config_id); | |||||
| bool has_format_aware_opr = false; | bool has_format_aware_opr = false; | ||||
| for (auto&& opr : problem.graph_partition().all_oprs()) { | for (auto&& opr : problem.graph_partition().all_oprs()) { | ||||
| auto iter = format_aware_opr_validators.find(opr->dyn_typeinfo()); | auto iter = format_aware_opr_validators.find(opr->dyn_typeinfo()); | ||||
| @@ -40,6 +40,37 @@ static inline const char* opr_format_to_string( | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| static inline const char* config_id_to_string( | |||||
| OprTensorFormatsConfiguration::OprFormatConfigID config_id) { | |||||
| using OprFormatConfigID = OprTensorFormatsConfiguration::OprFormatConfigID; | |||||
| #define cb(_fmt) \ | |||||
| case OprFormatConfigID::_fmt: \ | |||||
| return #_fmt | |||||
| switch (config_id) { | |||||
| cb(NCHW); | |||||
| cb(NHWC); | |||||
| cb(NCHW4); | |||||
| cb(NCHW8); | |||||
| cb(NCHW4_NCHW32); | |||||
| cb(NCHW4_NCHW); | |||||
| cb(NCHW32); | |||||
| cb(NCHW32_NCHW4); | |||||
| cb(NCHW64); | |||||
| cb(CHWN4); | |||||
| cb(NCHW44); | |||||
| cb(NCHW44_HYBRID); | |||||
| cb(NCHW88); | |||||
| cb(NCHW88_HYBRID); | |||||
| cb(NCHW44_DOT); | |||||
| cb(NCHW44_DOT_HYBRID); | |||||
| default: | |||||
| mgb_assert( | |||||
| false, "Invalid config id(got:%u)", | |||||
| static_cast<uint32_t>(config_id)); | |||||
| } | |||||
| #undef cb | |||||
| } | |||||
| static inline TensorFormats opr_format_to_tensor_formats( | static inline TensorFormats opr_format_to_tensor_formats( | ||||
| OprTensorFormatsConfiguration::OprFormat opr_format) { | OprTensorFormatsConfiguration::OprFormat opr_format) { | ||||
| using OprFormat = OprTensorFormatsConfiguration::OprFormat; | using OprFormat = OprTensorFormatsConfiguration::OprFormat; | ||||
| @@ -60,6 +91,8 @@ static inline TensorFormats opr_format_to_tensor_formats( | |||||
| return TensorFormats::NCHWc8; | return TensorFormats::NCHWc8; | ||||
| case OprFormat::NCHW44: | case OprFormat::NCHW44: | ||||
| return TensorFormats::NCHWc4; | return TensorFormats::NCHWc4; | ||||
| case OprFormat::NCHW8: | |||||
| return TensorFormats::NCHWc8; | |||||
| default: | default: | ||||
| mgb_throw( | mgb_throw( | ||||
| AssertionError, "format(%s) is not supported", | AssertionError, "format(%s) is not supported", | ||||
| @@ -124,9 +157,17 @@ static inline megdnn::NamedTensorShape tensor_formats_to_named_tensor_shape( | |||||
| return {{"G"}, {"K"}, {"C"}, {"R"}, {"S"}}; | return {{"G"}, {"K"}, {"C"}, {"R"}, {"S"}}; | ||||
| case TensorFormats::C11RS: | case TensorFormats::C11RS: | ||||
| return {{"C"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}}; | return {{"C"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}}; | ||||
| case TensorFormats::KRSC: | |||||
| return {{"K"}, {"R"}, {"S"}, {"C"}}; | |||||
| case TensorFormats::KCRSc32: | |||||
| return {{"K"}, {"C//32"}, {"R"}, {"S"}, {"C%32"}}; | |||||
| case TensorFormats::KCRSc64: | |||||
| return {{"K"}, {"C//64"}, {"R"}, {"S"}, {"C%64"}}; | |||||
| case TensorFormats::CRSKc4: | |||||
| return {{"C//4"}, {"R"}, {"S"}, {"K"}, {"C%4"}}; | |||||
| default: | default: | ||||
| mgb_throw( | mgb_throw( | ||||
| AssertionError, "invalid tensor formats(%u)", | |||||
| MegBrainError, "invalid tensor formats(%u)", | |||||
| static_cast<uint32_t>(format)); | static_cast<uint32_t>(format)); | ||||
| } | } | ||||
| } | } | ||||
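The new weight-format entries added above follow the usual named-shape reading: a lowercase letter denotes the inner blocked axis. As a worked example, a dense KCRS weight of shape (K, C, R, S) = (64, 128, 3, 3) viewed as KCRSc64 becomes (K, C//64, R, S, C%64) = (64, 2, 3, 3, 64). A compile-time sketch with illustrative sizes (shapes only, no data movement):

    #include <array>
    #include <cstddef>

    constexpr std::array<std::size_t, 4> kcrs{64, 128, 3, 3};  // K, C, R, S
    constexpr std::array<std::size_t, 5> kcrs_c64{
            kcrs[0], kcrs[1] / 64, kcrs[2], kcrs[3], 64};      // K, C//64, R, S, C%64
    static_assert(kcrs_c64[1] == 2, "C=128 splits into 2 outer blocks of 64");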
| @@ -26,19 +26,48 @@ namespace gopt { | |||||
| * configuration of the opr format | * configuration of the opr format | ||||
| */ | */ | ||||
| struct OprTensorFormatsConfiguration { | struct OprTensorFormatsConfiguration { | ||||
| using OprFormat = opr::ConvBias::Param::Format; | |||||
| using OprFormat = opr::Convolution::Param::Format; | |||||
| static constexpr uint32_t FORMAT_NR_MEMBER = | |||||
| opr::Convolution::Param::FORMAT_NR_MEMBER; | |||||
| enum class OprFormatConfigID : uint32_t { | |||||
| #define cb(fmt_) fmt_ = static_cast<uint32_t>(OprFormat::fmt_) | |||||
| cb(NCHW), | |||||
| cb(NHWC), | |||||
| cb(NHWCD4), | |||||
| cb(NCHW4), | |||||
| cb(NCHW8), | |||||
| cb(NCHW32), | |||||
| cb(NCHW88), | |||||
| cb(NCHW44), | |||||
| cb(NCHW44_DOT), | |||||
| cb(NCHW4_NCHW32), | |||||
| cb(NCHW32_NCHW4), | |||||
| cb(NCHW4_NCHW), | |||||
| cb(NCHW4_NHWC), | |||||
| cb(CHWN4), | |||||
| cb(NCHW64), | |||||
| NCHW44_HYBRID = FORMAT_NR_MEMBER, | |||||
| NCHW88_HYBRID = FORMAT_NR_MEMBER + 1, | |||||
| NCHW44_DOT_HYBRID = FORMAT_NR_MEMBER + 2, | |||||
| }; | |||||
| #undef cb | |||||
| using OprTensorFormatsDispatcher = | using OprTensorFormatsDispatcher = | ||||
| thin_function<Maybe<OprTensorFormatsConfiguration>( | thin_function<Maybe<OprTensorFormatsConfiguration>( | ||||
| const cg::OperatorNodeBase*)>; | const cg::OperatorNodeBase*)>; | ||||
| Typeinfo* typeinfo; | Typeinfo* typeinfo; | ||||
| OprFormat opr_format; | OprFormat opr_format; | ||||
| OprFormatConfigID config_id; | |||||
| SmallVector<DTypeEnum> input_dtypes; | SmallVector<DTypeEnum> input_dtypes; | ||||
| SmallVector<DTypeEnum> output_dtypes; | SmallVector<DTypeEnum> output_dtypes; | ||||
| SmallVector<TensorFormats> input_tensor_formats; | SmallVector<TensorFormats> input_tensor_formats; | ||||
| SmallVector<TensorType> input_tensor_types; | SmallVector<TensorType> input_tensor_types; | ||||
| SmallVector<TensorFormats> output_tensor_formats; | SmallVector<TensorFormats> output_tensor_formats; | ||||
| static OprTensorFormatsDispatcher* find_dispatcher_by_type_format( | static OprTensorFormatsDispatcher* find_dispatcher_by_type_format( | ||||
| Typeinfo* type, OprFormat opr_format); | |||||
| Typeinfo* type, OprFormatConfigID config_id); | |||||
| static OprFormat safe_cast_to_opr_format(OprFormatConfigID config_id) { | |||||
| mgb_assert(static_cast<uint32_t>(config_id) < FORMAT_NR_MEMBER); | |||||
| return static_cast<OprFormat>(static_cast<uint32_t>(config_id)); | |||||
| } | |||||
| }; | }; | ||||
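A standalone sketch of the enum pattern used by OprFormatConfigID above: ids that have a real Param::Format counterpart reuse its numeric value, while the purely internal HYBRID configurations are appended starting at FORMAT_NR_MEMBER, so safe_cast_to_opr_format can reject them; the hybrid dispatchers set opr_format explicitly instead. The member list, constant value, and assert macro below are reduced stand-ins, not the MegBrain definitions.

    #include <cassert>
    #include <cstdint>

    enum class OprFormat : uint32_t { NCHW, NCHW44, NCHW88 };
    constexpr uint32_t FORMAT_NR_MEMBER = 3;

    enum class OprFormatConfigID : uint32_t {
        NCHW = static_cast<uint32_t>(OprFormat::NCHW),
        NCHW44 = static_cast<uint32_t>(OprFormat::NCHW44),
        NCHW88 = static_cast<uint32_t>(OprFormat::NCHW88),
        NCHW44_HYBRID = FORMAT_NR_MEMBER,      // no 1:1 OprFormat counterpart
        NCHW88_HYBRID = FORMAT_NR_MEMBER + 1,
    };

    inline OprFormat safe_cast_to_opr_format(OprFormatConfigID id) {
        assert(static_cast<uint32_t>(id) < FORMAT_NR_MEMBER &&
               "hybrid config ids cannot be converted to an OprFormat");
        return static_cast<OprFormat>(static_cast<uint32_t>(id));
    }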
| /*! | /*! | ||||
| @@ -48,14 +77,15 @@ class LayoutTransformContext { | |||||
| public: | public: | ||||
| using OprList = SubGraphExtractor::OprList; | using OprList = SubGraphExtractor::OprList; | ||||
| using OprFormat = OprTensorFormatsConfiguration::OprFormat; | using OprFormat = OprTensorFormatsConfiguration::OprFormat; | ||||
| using OprFormatConfigID = OprTensorFormatsConfiguration::OprFormatConfigID; | |||||
| using OprTensorFormatsDispatcher = | using OprTensorFormatsDispatcher = | ||||
| OprTensorFormatsConfiguration::OprTensorFormatsDispatcher; | OprTensorFormatsConfiguration::OprTensorFormatsDispatcher; | ||||
| using OprConfigTrait = | |||||
| ThinHashMap<Typeinfo*, ThinHashMap<OprFormat, OprTensorFormatsDispatcher*>>; | |||||
| using OprConfigTrait = ThinHashMap< | |||||
| Typeinfo*, ThinHashMap<OprFormatConfigID, OprTensorFormatsDispatcher*>>; | |||||
| using Target = GraphTuningOptions::Target; | using Target = GraphTuningOptions::Target; | ||||
| using ReformatAttribute = ReformatManager::ReformatKey::Attribute; | using ReformatAttribute = ReformatManager::ReformatKey::Attribute; | ||||
| struct Attribute { | struct Attribute { | ||||
| OprFormat base_opr_format; /// the base opr format indicates that the | |||||
| OprFormatConfigID base_config_id; /// the base opr format indicates that the | |||||
| /// network to be optimized is constructed | /// network to be optimized is constructed | ||||
| /// in the base opr format, i.e. all the | /// in the base opr format, i.e. all the | ||||
| /// format aware operators (conv, conv_bias, | /// format aware operators (conv, conv_bias, | ||||
| @@ -97,21 +127,22 @@ public: | |||||
| /*! | /*! | ||||
| * \brief add an op format configuration for a particular operator type | * \brief add an op format configuration for a particular operator type | ||||
| * \param opr runtime typeinfo of operator | * \param opr runtime typeinfo of operator | ||||
| * \param opr_format op format configuration which to be enabled in the | |||||
| * layout transform problem | |||||
| * \param config_id op format configuration id which is going to be enabled | |||||
| * in the layout transform problem | |||||
| */ | */ | ||||
| LayoutTransformContext& add_opr_config(Typeinfo* opr, OprFormat opr_format); | |||||
| LayoutTransformContext& add_opr_config(Typeinfo* opr, OprFormatConfigID config_id); | |||||
| /*! | /*! | ||||
| * \brief add a vector of op format configurations for a particular operator | * \brief add a vector of op format configurations for a particular operator | ||||
| * type | * type | ||||
| * \param opr runtime typeinfo of operator | * \param opr runtime typeinfo of operator | ||||
| * \param opr_format op format configuration which to be enabled in the | |||||
| * layout transform problem | |||||
| * \param config_ids ids of op format configurations which are enabled in | |||||
| * the layout transform problem | |||||
| */ | */ | ||||
| LayoutTransformContext& add_opr_config( | LayoutTransformContext& add_opr_config( | ||||
| Typeinfo* opr, SmallVector<OprFormat> opr_formats); | |||||
| Typeinfo* opr, SmallVector<OprFormatConfigID> config_ids); | |||||
| static std::unique_ptr<LayoutTransformContext> make( | static std::unique_ptr<LayoutTransformContext> make( | ||||
| Target target = Target::UNSPEC, OprFormat base_opr_format = OprFormat::NCHW, | |||||
| Target target = Target::UNSPEC, | |||||
| OprFormatConfigID base_config_id = OprFormatConfigID::NCHW, | |||||
| TensorFormats base_tensor_format = TensorFormats::NCHW); | TensorFormats base_tensor_format = TensorFormats::NCHW); | ||||
| private: | private: | ||||
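Putting the add_opr_config overloads declared above to use: a sketch mirroring the test code further below, assuming a LayoutTransformContext `ctx` has already been constructed; the chosen config ids are illustrative (all three are registered for ConvBias in this patch).

    ctx->add_opr_config(
            opr::ConvBiasForward::typeinfo(),
            {OprFormatConfigID::NCHW44, OprFormatConfigID::NCHW44_HYBRID,
             OprFormatConfigID::NCHW44_DOT_HYBRID});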
| @@ -130,6 +161,7 @@ private: | |||||
| class Problem { | class Problem { | ||||
| public: | public: | ||||
| using OprFormat = OprTensorFormatsConfiguration::OprFormat; | using OprFormat = OprTensorFormatsConfiguration::OprFormat; | ||||
| using OprFormatConfigID = OprTensorFormatsConfiguration::OprFormatConfigID; | |||||
| using OprTensorFormatsDispatcher = | using OprTensorFormatsDispatcher = | ||||
| OprTensorFormatsConfiguration::OprTensorFormatsDispatcher; | OprTensorFormatsConfiguration::OprTensorFormatsDispatcher; | ||||
| using OprConfigTrait = LayoutTransformContext::OprConfigTrait; | using OprConfigTrait = LayoutTransformContext::OprConfigTrait; | ||||
| @@ -152,13 +184,15 @@ public: | |||||
| */ | */ | ||||
| OprTensorFormatsConfiguration base_config(const cg::OperatorNodeBase* opr) const { | OprTensorFormatsConfiguration base_config(const cg::OperatorNodeBase* opr) const { | ||||
| auto _ = OprTensorFormatsConfiguration::find_dispatcher_by_type_format( | auto _ = OprTensorFormatsConfiguration::find_dispatcher_by_type_format( | ||||
| opr->dyn_typeinfo(), m_ctx.attribute().base_opr_format); | |||||
| opr->dyn_typeinfo(), m_ctx.attribute().base_config_id); | |||||
| auto rst = (*_)(opr); | auto rst = (*_)(opr); | ||||
| if (rst.valid()) | if (rst.valid()) | ||||
| return rst.val(); | return rst.val(); | ||||
| OprTensorFormatsConfiguration config; | OprTensorFormatsConfiguration config; | ||||
| config.typeinfo = opr->dyn_typeinfo(); | config.typeinfo = opr->dyn_typeinfo(); | ||||
| config.opr_format = m_ctx.attribute().base_opr_format; | |||||
| config.config_id = m_ctx.attribute().base_config_id; | |||||
| config.opr_format = OprTensorFormatsConfiguration::safe_cast_to_opr_format( | |||||
| config.config_id); | |||||
| for (const auto& i : opr->input()) { | for (const auto& i : opr->input()) { | ||||
| config.input_dtypes.emplace_back(i->dtype().enumv()); | config.input_dtypes.emplace_back(i->dtype().enumv()); | ||||
| config.input_tensor_formats.emplace_back(base_format()); | config.input_tensor_formats.emplace_back(base_format()); | ||||
| @@ -33,9 +33,10 @@ class CachedProfiler; | |||||
| class ProfilerBase { | class ProfilerBase { | ||||
| public: | public: | ||||
| using OprFormat = Problem::OprFormat; | using OprFormat = Problem::OprFormat; | ||||
| using OprFormatConfigID = Problem::OprFormatConfigID; | |||||
| struct OperatorNodeRecord { | struct OperatorNodeRecord { | ||||
| const cg::OperatorNodeBase* opr; ///< pointer to operator node | const cg::OperatorNodeBase* opr; ///< pointer to operator node | ||||
| ThinHashMap<OprFormat, float> | |||||
| ThinHashMap<OprFormatConfigID, float> | |||||
| costs; ///< costs of operator node, i.e. the elapsed device | costs; ///< costs of operator node, i.e. the elapsed device | ||||
| ///< time of the operator node on different opr format | ///< time of the operator node on different opr format | ||||
| ///< (layout configuration). | ///< (layout configuration). | ||||
| @@ -199,6 +200,8 @@ protected: | |||||
| virtual float profile_var_node( | virtual float profile_var_node( | ||||
| const VarNode* var, TensorFormats base_format, | const VarNode* var, TensorFormats base_format, | ||||
| const ReformatKey& key) const; | const ReformatKey& key) const; | ||||
| OprFormatConfigID tensor_formats_to_config_id(TensorFormats tensor_format) const; | |||||
| OprFootprint m_opr_footprint; | OprFootprint m_opr_footprint; | ||||
| float m_opr_threshold; /// a threshold, when the computation of the newly | float m_opr_threshold; /// a threshold, when the computation of the newly | ||||
| /// created operator that is built in some opr | /// created operator that is built in some opr | ||||
| @@ -224,14 +227,14 @@ class ProfilerCache : public NonCopyableObj { | |||||
| public: | public: | ||||
| using ReformatKey = ReformatManager::ReformatKey; | using ReformatKey = ReformatManager::ReformatKey; | ||||
| using ReformatAttribute = ReformatKey::Attribute; | using ReformatAttribute = ReformatKey::Attribute; | ||||
| using OprFormat = ProfilerBase::OprFormat; | |||||
| using OprFormatConfigID = ProfilerBase::OprFormatConfigID; | |||||
| class Key final : public NonCopyableObj { | class Key final : public NonCopyableObj { | ||||
| std::string m_blob_storage; | std::string m_blob_storage; | ||||
| std::string m_category; | std::string m_category; | ||||
| struct OprKey { | struct OprKey { | ||||
| const OperatorNodeBase* opr; | const OperatorNodeBase* opr; | ||||
| OprFormat opr_format; | |||||
| OprFormatConfigID config_id; | |||||
| ReformatAttribute extra_attribute; | ReformatAttribute extra_attribute; | ||||
| }; | }; | ||||
| @@ -254,9 +257,9 @@ public: | |||||
| void build_category(CompNode cn); | void build_category(CompNode cn); | ||||
| public: | public: | ||||
| Key(const OperatorNodeBase* opr, OprFormat opr_format, | |||||
| Key(const OperatorNodeBase* opr, OprFormatConfigID config_id, | |||||
| ReformatAttribute extra_attribute = ReformatAttribute::DEFAULT) { | ReformatAttribute extra_attribute = ReformatAttribute::DEFAULT) { | ||||
| m_key_impl.opr_key = {opr, opr_format, extra_attribute}; | |||||
| m_key_impl.opr_key = {opr, config_id, extra_attribute}; | |||||
| build_blob_from_opr(); | build_blob_from_opr(); | ||||
| mgb_assert( | mgb_assert( | ||||
| opr->node_prop().contain( | opr->node_prop().contain( | ||||
| @@ -28,7 +28,8 @@ class ProfilerBase; | |||||
| class SolverBase { | class SolverBase { | ||||
| public: | public: | ||||
| using OprFormat = Problem::OprFormat; | using OprFormat = Problem::OprFormat; | ||||
| using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormat>; | |||||
| using OprFormatConfigID = Problem::OprFormatConfigID; | |||||
| using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormatConfigID>; | |||||
| SolverBase() = default; | SolverBase() = default; | ||||
| virtual ~SolverBase() = default; | virtual ~SolverBase() = default; | ||||
| /*! | /*! | ||||
| @@ -1,4 +1,5 @@ | |||||
| #!/usr/bin/env python3 | #!/usr/bin/env python3 | ||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
| # | # | ||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
| @@ -95,7 +96,7 @@ static const std::vector<uint8_t> {} = {{ | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
| description='embed cache into cache header file', | |||||
| description='embed cubin into cpp source file', | |||||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter) | formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||||
| parser.add_argument('-o', '--output', help='output source file', | parser.add_argument('-o', '--output', help='output source file', | ||||
| required=True) | required=True) | ||||
| @@ -23,7 +23,7 @@ | |||||
| #include "megbrain/plugin/profiler.h" | #include "megbrain/plugin/profiler.h" | ||||
| #include "megbrain/serialization/serializer.h" | #include "megbrain/serialization/serializer.h" | ||||
| #define MGB_WITH_CACHED_TEST 1 | |||||
| #define MGB_WITH_CACHED_TEST 0 | |||||
| #if MGB_WITH_CACHED_TEST | #if MGB_WITH_CACHED_TEST | ||||
| #include "./cache_data.h" | #include "./cache_data.h" | ||||
| @@ -60,30 +60,6 @@ size_t find_opr_num(SymbolVar endpoint) { | |||||
| return opr_num; | return opr_num; | ||||
| } | } | ||||
| using OprFormat = Problem::OprFormat; | |||||
| OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) { | |||||
| switch (tensor_format) { | |||||
| case TensorFormats::NCHW: | |||||
| return OprFormat::NCHW; | |||||
| case TensorFormats::NCHWc4: | |||||
| return OprFormat::NCHW4; | |||||
| case TensorFormats::NCHWc8: | |||||
| return OprFormat::NCHW8; | |||||
| case TensorFormats::NCHWc32: | |||||
| return OprFormat::NCHW32; | |||||
| case TensorFormats::NCHWc64: | |||||
| return OprFormat::NCHW64; | |||||
| case TensorFormats::NHWC: | |||||
| return OprFormat::NHWC; | |||||
| case TensorFormats::CHWNc4: | |||||
| return OprFormat::CHWN4; | |||||
| default: | |||||
| mgb_throw( | |||||
| MegBrainError, "tensor format(%u) is not supported", | |||||
| static_cast<uint32_t>(tensor_format)); | |||||
| } | |||||
| } | |||||
| class ProfilerMock : public ProfilerImpl { | class ProfilerMock : public ProfilerImpl { | ||||
| public: | public: | ||||
| ProfilerMock(const uint8_t* bin, size_t size) { | ProfilerMock(const uint8_t* bin, size_t size) { | ||||
| @@ -105,7 +81,7 @@ private: | |||||
| ReformatAttribute extra_attribute = | ReformatAttribute extra_attribute = | ||||
| ReformatAttribute::DEFAULT) const override { | ReformatAttribute::DEFAULT) const override { | ||||
| ProfilerCache::Key key{ | ProfilerCache::Key key{ | ||||
| opr, tensor_formats_to_opr_format(tensor_format), extra_attribute}; | |||||
| opr, tensor_formats_to_config_id(tensor_format), extra_attribute}; | |||||
| auto ret = ProfilerCache::inst().get(key); | auto ret = ProfilerCache::inst().get(key); | ||||
| if (ret.valid()) | if (ret.valid()) | ||||
| return ret.val(); | return ret.val(); | ||||
| @@ -117,9 +93,7 @@ private: | |||||
| const OprTensorFormatsConfiguration& config, | const OprTensorFormatsConfiguration& config, | ||||
| ReformatAttribute extra_attribute = | ReformatAttribute extra_attribute = | ||||
| ReformatAttribute::DEFAULT) const override { | ReformatAttribute::DEFAULT) const override { | ||||
| ProfilerCache::Key key{opr, config.opr_format, extra_attribute}; | |||||
| std::string tmp; | |||||
| tmp.reserve(key.blob().size); | |||||
| ProfilerCache::Key key{opr, config.config_id, extra_attribute}; | |||||
| auto ret = ProfilerCache::inst().get(key); | auto ret = ProfilerCache::inst().get(key); | ||||
| if (ret.valid()) | if (ret.valid()) | ||||
| return ret.val(); | return ret.val(); | ||||
| @@ -161,7 +135,7 @@ TEST(TestLayoutTransform, Resnet18_QS8) { | |||||
| auto func1 = network.graph->compile({make_callback_copy(output, t1)}); | auto func1 = network.graph->compile({make_callback_copy(output, t1)}); | ||||
| func1->execute(); | func1->execute(); | ||||
| using OprFormat = LayoutTransformContext::OprFormat; | |||||
| using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; | |||||
| using OprList = LayoutTransformContext::OprList; | using OprList = LayoutTransformContext::OprList; | ||||
| using Target = LayoutTransformContext::Target; | using Target = LayoutTransformContext::Target; | ||||
| using ReformatAttribute = LayoutTransformContext::ReformatAttribute; | using ReformatAttribute = LayoutTransformContext::ReformatAttribute; | ||||
| @@ -175,17 +149,18 @@ TEST(TestLayoutTransform, Resnet18_QS8) { | |||||
| TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, | TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, | ||||
| TensorFormats::NCHWc32, TensorFormats::CHWNc4}; | TensorFormats::NCHWc32, TensorFormats::CHWNc4}; | ||||
| Attribute attribute = { | Attribute attribute = { | ||||
| OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC, | |||||
| OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC, | |||||
| ReformatAttribute::AUTO_PADDING_NHWC}; | ReformatAttribute::AUTO_PADDING_NHWC}; | ||||
| auto ctx = std::make_unique<LayoutTransformContext>( | auto ctx = std::make_unique<LayoutTransformContext>( | ||||
| std::move(opr_list), std::move(available_tensor_formats), attribute); | std::move(opr_list), std::move(available_tensor_formats), attribute); | ||||
| ctx->add_opr_config( | ctx->add_opr_config( | ||||
| opr::ConvBiasForward::typeinfo(), | opr::ConvBiasForward::typeinfo(), | ||||
| {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC}) | |||||
| {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, | |||||
| OprFormatConfigID::CHWN4, OprFormatConfigID::NHWC}) | |||||
| .add_opr_config( | .add_opr_config( | ||||
| opr::PoolingForward::typeinfo(), | opr::PoolingForward::typeinfo(), | ||||
| {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC, | |||||
| OprFormat::CHWN4}); | |||||
| {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, | |||||
| OprFormatConfigID::NHWC, OprFormatConfigID::CHWN4}); | |||||
| #if MGB_WITH_CACHED_TEST | #if MGB_WITH_CACHED_TEST | ||||
| auto profiler = std::make_unique<ProfilerMock>( | auto profiler = std::make_unique<ProfilerMock>( | ||||
| static_cast<const uint8_t*>(TestLayoutTransform_Resnet18_QS8.data()), | static_cast<const uint8_t*>(TestLayoutTransform_Resnet18_QS8.data()), | ||||
| @@ -253,7 +228,7 @@ TEST(TestLayoutTransform, Resnet18_QS4) { | |||||
| auto func1 = network.graph->compile({make_callback_copy(output, t1)}); | auto func1 = network.graph->compile({make_callback_copy(output, t1)}); | ||||
| func1->execute(); | func1->execute(); | ||||
| using OprFormat = LayoutTransformContext::OprFormat; | |||||
| using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; | |||||
| using OprList = LayoutTransformContext::OprList; | using OprList = LayoutTransformContext::OprList; | ||||
| using Attribute = LayoutTransformContext::Attribute; | using Attribute = LayoutTransformContext::Attribute; | ||||
| using Target = LayoutTransformContext::Target; | using Target = LayoutTransformContext::Target; | ||||
| @@ -267,18 +242,20 @@ TEST(TestLayoutTransform, Resnet18_QS4) { | |||||
| TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, | TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, | ||||
| TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4}; | TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4}; | ||||
| Attribute attribute = { | Attribute attribute = { | ||||
| OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC, | |||||
| OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC, | |||||
| ReformatAttribute::AUTO_PADDING_NHWC}; | ReformatAttribute::AUTO_PADDING_NHWC}; | ||||
| auto ctx = std::make_unique<LayoutTransformContext>( | auto ctx = std::make_unique<LayoutTransformContext>( | ||||
| std::move(opr_list), std::move(available_tensor_formats), attribute); | std::move(opr_list), std::move(available_tensor_formats), attribute); | ||||
| ctx->add_opr_config( | ctx->add_opr_config( | ||||
| opr::ConvBiasForward::typeinfo(), | opr::ConvBiasForward::typeinfo(), | ||||
| {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC, | |||||
| OprFormat::NCHW64}) | |||||
| {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, | |||||
| OprFormatConfigID::CHWN4, OprFormatConfigID::NHWC, | |||||
| OprFormatConfigID::NCHW64}) | |||||
| .add_opr_config( | .add_opr_config( | ||||
| opr::PoolingForward::typeinfo(), | opr::PoolingForward::typeinfo(), | ||||
| {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NCHW64, | |||||
| OprFormat::NHWC, OprFormat::CHWN4}); | |||||
| {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, | |||||
| OprFormatConfigID::NCHW64, OprFormatConfigID::NHWC, | |||||
| OprFormatConfigID::CHWN4}); | |||||
| #if MGB_WITH_CACHED_TEST | #if MGB_WITH_CACHED_TEST | ||||
| auto profiler = std::make_unique<ProfilerMock>( | auto profiler = std::make_unique<ProfilerMock>( | ||||
| static_cast<const uint8_t*>(TestLayoutTransform_Resnet18_QS4.data()), | static_cast<const uint8_t*>(TestLayoutTransform_Resnet18_QS4.data()), | ||||
| @@ -375,7 +352,7 @@ TEST(TestLayoutTransform, Detection_QS8) { | |||||
| S strategy = S::PROFILE; | S strategy = S::PROFILE; | ||||
| gopt::modify_opr_algo_strategy_inplace({outputs}, strategy); | gopt::modify_opr_algo_strategy_inplace({outputs}, strategy); | ||||
| using OprFormat = LayoutTransformContext::OprFormat; | |||||
| using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; | |||||
| using OprList = LayoutTransformContext::OprList; | using OprList = LayoutTransformContext::OprList; | ||||
| using Attribute = LayoutTransformContext::Attribute; | using Attribute = LayoutTransformContext::Attribute; | ||||
| using Target = LayoutTransformContext::Target; | using Target = LayoutTransformContext::Target; | ||||
| @@ -389,18 +366,18 @@ TEST(TestLayoutTransform, Detection_QS8) { | |||||
| TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, | TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, | ||||
| TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4}; | TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4}; | ||||
| Attribute attribute = { | Attribute attribute = { | ||||
| OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC, | |||||
| OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC, | |||||
| ReformatAttribute::AUTO_PADDING_NHWC}; | ReformatAttribute::AUTO_PADDING_NHWC}; | ||||
| auto ctx = std::make_unique<LayoutTransformContext>( | auto ctx = std::make_unique<LayoutTransformContext>( | ||||
| std::move(opr_list), std::move(available_tensor_formats), attribute); | std::move(opr_list), std::move(available_tensor_formats), attribute); | ||||
| ctx->add_opr_config( | ctx->add_opr_config( | ||||
| opr::ConvBiasForward::typeinfo(), | opr::ConvBiasForward::typeinfo(), | ||||
| {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC, | |||||
| OprFormat::NCHW64}) | |||||
| {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, | |||||
| OprFormatConfigID::CHWN4, OprFormatConfigID::NHWC, | |||||
| OprFormatConfigID::NCHW64}) | |||||
| .add_opr_config( | .add_opr_config( | ||||
| opr::PoolingForward::typeinfo(), | |||||
| {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NCHW64, | |||||
| OprFormat::NHWC, OprFormat::CHWN4}); | |||||
| opr::ConvolutionBackwardData::typeinfo(), | |||||
| {OprFormatConfigID::NCHW4, OprFormatConfigID::NHWC}); | |||||
| #if MGB_WITH_CACHED_TEST | #if MGB_WITH_CACHED_TEST | ||||
| auto profiler = std::make_unique<ProfilerMock>( | auto profiler = std::make_unique<ProfilerMock>( | ||||
| static_cast<const uint8_t*>(TestLayoutTransform_Detection_QS8.data()), | static_cast<const uint8_t*>(TestLayoutTransform_Detection_QS8.data()), | ||||
| @@ -452,7 +429,7 @@ TEST(TestLayoutTransform, Detection_QS4) { | |||||
| S strategy = S::PROFILE; | S strategy = S::PROFILE; | ||||
| gopt::modify_opr_algo_strategy_inplace({outputs}, strategy); | gopt::modify_opr_algo_strategy_inplace({outputs}, strategy); | ||||
| using OprFormat = LayoutTransformContext::OprFormat; | |||||
| using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; | |||||
| using OprList = LayoutTransformContext::OprList; | using OprList = LayoutTransformContext::OprList; | ||||
| using ReformatAttribute = LayoutTransformContext::ReformatAttribute; | using ReformatAttribute = LayoutTransformContext::ReformatAttribute; | ||||
| using Attribute = LayoutTransformContext::Attribute; | using Attribute = LayoutTransformContext::Attribute; | ||||
| @@ -466,18 +443,18 @@ TEST(TestLayoutTransform, Detection_QS4) { | |||||
| TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, | TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, | ||||
| TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4}; | TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4}; | ||||
| Attribute attribute = { | Attribute attribute = { | ||||
| OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC, | |||||
| OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC, | |||||
| ReformatAttribute::AUTO_PADDING_NHWC}; | ReformatAttribute::AUTO_PADDING_NHWC}; | ||||
| auto ctx = std::make_unique<LayoutTransformContext>( | auto ctx = std::make_unique<LayoutTransformContext>( | ||||
| std::move(opr_list), std::move(available_tensor_formats), attribute); | std::move(opr_list), std::move(available_tensor_formats), attribute); | ||||
| ctx->add_opr_config( | ctx->add_opr_config( | ||||
| opr::ConvBiasForward::typeinfo(), | opr::ConvBiasForward::typeinfo(), | ||||
| {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC, | |||||
| OprFormat::NCHW64}) | |||||
| {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, | |||||
| OprFormatConfigID::CHWN4, OprFormatConfigID::NHWC, | |||||
| OprFormatConfigID::NCHW64}) | |||||
| .add_opr_config( | .add_opr_config( | ||||
| opr::PoolingForward::typeinfo(), | |||||
| {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NCHW64, | |||||
| OprFormat::NHWC, OprFormat::CHWN4}); | |||||
| opr::ConvolutionBackwardData::typeinfo(), | |||||
| {OprFormatConfigID::NCHW4, OprFormatConfigID::NHWC}); | |||||
| #if MGB_WITH_CACHED_TEST | #if MGB_WITH_CACHED_TEST | ||||
| auto profiler = std::make_unique<ProfilerMock>( | auto profiler = std::make_unique<ProfilerMock>( | ||||
| static_cast<const uint8_t*>(TestLayoutTransform_Detection_QS4.data()), | static_cast<const uint8_t*>(TestLayoutTransform_Detection_QS4.data()), | ||||
| @@ -538,7 +515,7 @@ TEST(TestLayoutTransform, Wide) { | |||||
| S strategy = S::PROFILE; | S strategy = S::PROFILE; | ||||
| gopt::modify_opr_algo_strategy_inplace({y}, strategy); | gopt::modify_opr_algo_strategy_inplace({y}, strategy); | ||||
| using OprFormat = LayoutTransformContext::OprFormat; | |||||
| using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; | |||||
| using OprList = LayoutTransformContext::OprList; | using OprList = LayoutTransformContext::OprList; | ||||
| using ReformatAttribute = LayoutTransformContext::ReformatAttribute; | using ReformatAttribute = LayoutTransformContext::ReformatAttribute; | ||||
| using Attribute = LayoutTransformContext::Attribute; | using Attribute = LayoutTransformContext::Attribute; | ||||
| @@ -550,12 +527,13 @@ TEST(TestLayoutTransform, Wide) { | |||||
| SmallVector<TensorFormats> available_tensor_formats = { | SmallVector<TensorFormats> available_tensor_formats = { | ||||
| TensorFormats::NCHW, TensorFormats::NHWC}; | TensorFormats::NCHW, TensorFormats::NHWC}; | ||||
| Attribute attribute = { | Attribute attribute = { | ||||
| OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC, | |||||
| OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC, | |||||
| ReformatAttribute::DEFAULT}; | ReformatAttribute::DEFAULT}; | ||||
| auto ctx = std::make_unique<LayoutTransformContext>( | auto ctx = std::make_unique<LayoutTransformContext>( | ||||
| std::move(opr_list), std::move(available_tensor_formats), attribute); | std::move(opr_list), std::move(available_tensor_formats), attribute); | ||||
| ctx->add_opr_config( | ctx->add_opr_config( | ||||
| opr::ConvBiasForward::typeinfo(), {OprFormat::NCHW, OprFormat::NHWC}); | |||||
| opr::ConvBiasForward::typeinfo(), | |||||
| {OprFormatConfigID::NCHW, OprFormatConfigID::NHWC}); | |||||
| #if MGB_WITH_CACHED_TEST | #if MGB_WITH_CACHED_TEST | ||||
| auto profiler = std::make_unique<ProfilerMock>( | auto profiler = std::make_unique<ProfilerMock>( | ||||
| static_cast<const uint8_t*>(TestLayoutTransform_Wide.data()), | static_cast<const uint8_t*>(TestLayoutTransform_Wide.data()), | ||||
| @@ -580,6 +558,8 @@ TEST(TestLayoutTransform, Wide) { | |||||
| auto func = network.graph->compile({{sym_o, {}}}); | auto func = network.graph->compile({{sym_o, {}}}); | ||||
| func->execute(); | func->execute(); | ||||
| gprof.to_json_full(func.get())->writeto_fpath(output_file("wide.json")); | gprof.to_json_full(func.get())->writeto_fpath(output_file("wide.json")); | ||||
| /// check global layout transform pass: no dimshuffle is expected | |||||
| /// (the following check can be disabled to keep CI stable) | |||||
| auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(sym_o); | auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(sym_o); | ||||
| ASSERT_EQ(nr_dimshuffle, 0u); | ASSERT_EQ(nr_dimshuffle, 0u); | ||||
| auto nr_param_merge = find_opr_num<opr::MultipleDeviceTensorHolder>(sym_o); | auto nr_param_merge = find_opr_num<opr::MultipleDeviceTensorHolder>(sym_o); | ||||
| @@ -631,7 +611,7 @@ TEST(TestLayoutTransform, DetectionHead) { | |||||
| S strategy = S::PROFILE; | S strategy = S::PROFILE; | ||||
| gopt::modify_opr_algo_strategy_inplace({y}, strategy); | gopt::modify_opr_algo_strategy_inplace({y}, strategy); | ||||
| using OprFormat = LayoutTransformContext::OprFormat; | |||||
| using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; | |||||
| using OprList = LayoutTransformContext::OprList; | using OprList = LayoutTransformContext::OprList; | ||||
| using Attribute = LayoutTransformContext::Attribute; | using Attribute = LayoutTransformContext::Attribute; | ||||
| using ReformatAttribute = LayoutTransformContext::ReformatAttribute; | using ReformatAttribute = LayoutTransformContext::ReformatAttribute; | ||||
| @@ -650,27 +630,30 @@ TEST(TestLayoutTransform, DetectionHead) { | |||||
| TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, | TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, | ||||
| TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4}; | TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4}; | ||||
| Attribute attribute = { | Attribute attribute = { | ||||
| OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC, | |||||
| OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC, | |||||
| ReformatAttribute::AUTO_PADDING_NHWC}; | ReformatAttribute::AUTO_PADDING_NHWC}; | ||||
| auto ctx = std::make_unique<LayoutTransformContext>( | auto ctx = std::make_unique<LayoutTransformContext>( | ||||
| std::move(opr_list), std::move(available_tensor_formats), attribute); | std::move(opr_list), std::move(available_tensor_formats), attribute); | ||||
| ctx->add_opr_config( | ctx->add_opr_config( | ||||
| opr::ConvBiasForward::typeinfo(), | opr::ConvBiasForward::typeinfo(), | ||||
| {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW32, | |||||
| OprFormat::NCHW64, OprFormat::CHWN4}) | |||||
| {OprFormatConfigID::NCHW, OprFormatConfigID::NHWC, | |||||
| OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, | |||||
| OprFormatConfigID::NCHW64, OprFormatConfigID::CHWN4}) | |||||
| .add_opr_config( | .add_opr_config( | ||||
| opr::ConvolutionForward::typeinfo(), | opr::ConvolutionForward::typeinfo(), | ||||
| {OprFormat::NCHW, OprFormat::NCHW4}) | |||||
| {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4}) | |||||
| .add_opr_config( | .add_opr_config( | ||||
| opr::ConvolutionBackwardData::typeinfo(), | opr::ConvolutionBackwardData::typeinfo(), | ||||
| {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4}) | |||||
| {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4}) | |||||
| .add_opr_config( | .add_opr_config( | ||||
| opr::PoolingForward::typeinfo(), | opr::PoolingForward::typeinfo(), | ||||
| {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC, | |||||
| OprFormat::NCHW64, OprFormat::CHWN4}) | |||||
| {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, | |||||
| OprFormatConfigID::NHWC, OprFormatConfigID::NCHW64, | |||||
| OprFormatConfigID::CHWN4}) | |||||
| .add_opr_config( | .add_opr_config( | ||||
| opr::WarpPerspectiveForward::typeinfo(), | opr::WarpPerspectiveForward::typeinfo(), | ||||
| {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64}); | |||||
| {OprFormatConfigID::NHWC, OprFormatConfigID::NCHW4, | |||||
| OprFormatConfigID::NCHW64}); | |||||
| #if MGB_WITH_CACHED_TEST | #if MGB_WITH_CACHED_TEST | ||||
| auto profiler = std::make_unique<ProfilerMock>( | auto profiler = std::make_unique<ProfilerMock>( | ||||
| static_cast<const uint8_t*>(TestLayoutTransform_DetectionHead.data()), | static_cast<const uint8_t*>(TestLayoutTransform_DetectionHead.data()), | ||||
| @@ -765,4 +748,184 @@ TEST(TestLayoutTransform, CanonicalizeLayoutTransform) { | |||||
| MGB_ASSERT_TENSOR_EQ(t1, t2); | MGB_ASSERT_TENSOR_EQ(t1, t2); | ||||
| } | } | ||||
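| /// Global layout transform on a float32 ResNet-18 (cpu0): the DP solver is | |||||
| /// expected to pick NCHW44 for the first conv and leave a single dimshuffle. | |||||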
| TEST(TestLayoutTransform, Resnet18_F32) { | |||||
| auto cn = CompNode::load("cpu0"); | |||||
| Network network(cn); | |||||
| auto output = make_resnet18(network, 1); | |||||
| HostTensorND t1; | |||||
| auto func1 = network.graph->compile({make_callback_copy(output, t1)}); | |||||
| func1->execute(); | |||||
| using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; | |||||
| using OprList = LayoutTransformContext::OprList; | |||||
| using Target = LayoutTransformContext::Target; | |||||
| using Attribute = LayoutTransformContext::Attribute; | |||||
| OprList opr_list = { | |||||
| opr::ConvBiasForward::typeinfo(), | |||||
| opr::ConvolutionForward::typeinfo(), | |||||
| opr::ElemwiseMultiType::typeinfo(), | |||||
| opr::Elemwise::typeinfo(), | |||||
| opr::TypeCvt::typeinfo(), | |||||
| opr::Concat::typeinfo(), | |||||
| opr::PoolingForward::typeinfo(), | |||||
| opr::WarpPerspectiveForward::typeinfo(), | |||||
| opr::Resize::typeinfo(), | |||||
| }; | |||||
| SmallVector<TensorFormats> available_tensor_formats = { | |||||
| TensorFormats::NCHW, | |||||
| TensorFormats::NCHWc4, | |||||
| TensorFormats::NCHWc8, | |||||
| }; | |||||
| Attribute attribute = { | |||||
| OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC}; | |||||
| auto ctx = std::make_unique<LayoutTransformContext>( | |||||
| std::move(opr_list), std::move(available_tensor_formats), attribute); | |||||
| ctx->add_opr_config( | |||||
| opr::ConvBiasForward::typeinfo(), | |||||
| { | |||||
| OprFormatConfigID::NCHW44, | |||||
| OprFormatConfigID::NCHW, | |||||
| OprFormatConfigID::NCHW44_HYBRID, | |||||
| }) | |||||
| .add_opr_config( | |||||
| opr::ConvolutionForward::typeinfo(), | |||||
| { | |||||
| OprFormatConfigID::NCHW44, | |||||
| OprFormatConfigID::NCHW, | |||||
| OprFormatConfigID::NCHW44_HYBRID, | |||||
| }) | |||||
| .add_opr_config( | |||||
| opr::PoolingForward::typeinfo(), { | |||||
| OprFormatConfigID::NCHW, | |||||
| OprFormatConfigID::NCHW44, | |||||
| }); | |||||
| #if MGB_WITH_CACHED_TEST | |||||
| auto profiler = std::make_unique<ProfilerMock>( | |||||
| static_cast<const uint8_t*>(TestLayoutTransform_Resnet18_F32.data()), | |||||
| TestLayoutTransform_Resnet18_F32.size()); | |||||
| #else | |||||
| auto profiler = ProfilerBase::make_cached_profiler( | |||||
| "TestLayoutTransform.Resnet18_F32.cache"); | |||||
| #endif | |||||
| std::unique_ptr<SolverBase> solver{ | |||||
| new DynamicProgrammingSolver(std::move(profiler))}; | |||||
| auto new_output = | |||||
| gopt::GraphOptimizer{} | |||||
| .add_pass<FuseConvBiasNonlinPass>() | |||||
| .add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver)) | |||||
| .add_pass<ShuffleShuffleRemovePass>() | |||||
| .add_pass<ParamFusePass>() | |||||
| .add_pass<ParamMergePass>() | |||||
| .apply({{output}}) | |||||
| .endpoint_vars(); | |||||
| auto new_out_var = new_output[0]; | |||||
| /// check global layout transform pass | |||||
| auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var); | |||||
| ASSERT_EQ(nr_dimshuffle, 1u); | |||||
| /// check first conv format | |||||
| const auto& first_conv = find_opr<opr::ConvBiasForward>(new_out_var); | |||||
| const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>(); | |||||
| ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW44); | |||||
| GraphProfiler gprof{network.graph.get()}; | |||||
| HostTensorND t2; | |||||
| auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)}); | |||||
| func2->execute(); | |||||
| gprof.to_json_full(func2.get())->writeto_fpath(output_file("resnet18_f32.json")); | |||||
| /// check correctness | |||||
| MGB_ASSERT_TENSOR_EQ(t1, t2); | |||||
| } | |||||
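| /// Global layout transform on a float32 MobileNetV2; expectations mirror the | |||||
| /// ResNet-18 test above (NCHW44 first conv, single dimshuffle). | |||||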
| TEST(TestLayoutTransform, MobileNetV2) { | |||||
| auto cn = CompNode::load("cpu0"); | |||||
| Network network(cn); | |||||
| auto output = make_mobilenet_v2(network, 1); | |||||
| HostTensorND t1; | |||||
| auto func1 = network.graph->compile({make_callback_copy(output, t1)}); | |||||
| func1->execute(); | |||||
| using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; | |||||
| using OprList = LayoutTransformContext::OprList; | |||||
| using Target = LayoutTransformContext::Target; | |||||
| using Attribute = LayoutTransformContext::Attribute; | |||||
| OprList opr_list = { | |||||
| opr::ConvBiasForward::typeinfo(), | |||||
| opr::ConvolutionForward::typeinfo(), | |||||
| opr::ElemwiseMultiType::typeinfo(), | |||||
| opr::Elemwise::typeinfo(), | |||||
| opr::TypeCvt::typeinfo(), | |||||
| opr::Concat::typeinfo(), | |||||
| opr::PoolingForward::typeinfo(), | |||||
| opr::WarpPerspectiveForward::typeinfo(), | |||||
| opr::Resize::typeinfo(), | |||||
| }; | |||||
| SmallVector<TensorFormats> available_tensor_formats = { | |||||
| TensorFormats::NCHW, | |||||
| TensorFormats::NCHWc4, | |||||
| TensorFormats::NCHWc8, | |||||
| }; | |||||
| Attribute attribute = { | |||||
| OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC}; | |||||
| auto ctx = std::make_unique<LayoutTransformContext>( | |||||
| std::move(opr_list), std::move(available_tensor_formats), attribute); | |||||
| ctx->add_opr_config( | |||||
| opr::ConvBiasForward::typeinfo(), | |||||
| { | |||||
| OprFormatConfigID::NCHW44, | |||||
| OprFormatConfigID::NCHW, | |||||
| OprFormatConfigID::NCHW44_HYBRID, | |||||
| }) | |||||
| .add_opr_config( | |||||
| opr::ConvolutionForward::typeinfo(), | |||||
| { | |||||
| OprFormatConfigID::NCHW44, | |||||
| OprFormatConfigID::NCHW, | |||||
| OprFormatConfigID::NCHW44_HYBRID, | |||||
| }) | |||||
| .add_opr_config( | |||||
| opr::PoolingForward::typeinfo(), { | |||||
| OprFormatConfigID::NCHW, | |||||
| OprFormatConfigID::NCHW44, | |||||
| }); | |||||
| #if MGB_WITH_CACHED_TEST | |||||
| auto profiler = std::make_unique<ProfilerMock>( | |||||
| static_cast<const uint8_t*>(TestLayoutTransform_MobileNetV2_F32.data()), | |||||
| TestLayoutTransform_MobileNetV2_F32.size()); | |||||
| #else | |||||
| auto profiler = ProfilerBase::make_cached_profiler( | |||||
| "TestLayoutTransform.MobileNetV2_F32.cache"); | |||||
| #endif | |||||
| std::unique_ptr<SolverBase> solver{ | |||||
| new DynamicProgrammingSolver(std::move(profiler))}; | |||||
| auto new_output = | |||||
| gopt::GraphOptimizer{} | |||||
| .add_pass<FuseConvBiasNonlinPass>() | |||||
| .add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver)) | |||||
| .add_pass<ShuffleShuffleRemovePass>() | |||||
| .add_pass<ParamFusePass>() | |||||
| .add_pass<ParamMergePass>() | |||||
| .apply({{output}}) | |||||
| .endpoint_vars(); | |||||
| auto new_out_var = new_output[0]; | |||||
| /// check global layout transform pass | |||||
| auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var); | |||||
| ASSERT_EQ(nr_dimshuffle, 1u); | |||||
| /// check first conv format | |||||
| const auto& first_conv = find_opr<opr::ConvBiasForward>(new_out_var); | |||||
| const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>(); | |||||
| ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW44); | |||||
| GraphProfiler gprof{network.graph.get()}; | |||||
| HostTensorND t2; | |||||
| auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)}); | |||||
| func2->execute(); | |||||
| gprof.to_json_full(func2.get())->writeto_fpath(output_file("mobilenet_v2_f32.json")); | |||||
| /// check correctness | |||||
| MGB_ASSERT_TENSOR_EQ(t1, t2); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -45,6 +45,36 @@ SymbolVar Network::add_conv( | |||||
| return conv; | return conv; | ||||
| } | } | ||||
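| /// Add a grouped ConvBias operator; the weight is laid out as | |||||
| /// {groups, out_channels / groups, in_channels / groups, kh, kw}. | |||||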
| SymbolVar Network::add_group_conv( | |||||
| SymbolVar f, size_t output_channels, size_t groups, KernSize kern_size, | |||||
| DType out_dtype, bool has_relu, Stride stride, Padding padding) { | |||||
| static int weight_idx = 0; | |||||
| static int bias_idx = 0; | |||||
| size_t input_channels = f.node()->shape()[1]; | |||||
| auto weight = add_cvar( | |||||
| ssprintf("w%d", weight_idx).c_str(), | |||||
| {groups, output_channels / groups, input_channels / groups, kern_size[0], | |||||
| kern_size[1]}); | |||||
| auto bias = add_cvar(ssprintf("b%d", bias_idx).c_str(), {1, output_channels, 1, 1}); | |||||
| mgb_assert(out_dtype.category() == DTypeCategory::FLOAT); | |||||
| opr::ConvBias::Param param; | |||||
| param.sparse = opr::ConvBias::Param::Sparse::GROUP; | |||||
| param.stride_h = stride[0], param.stride_w = stride[1]; | |||||
| param.pad_h = padding[0], param.pad_w = padding[1]; | |||||
| if (has_relu) { | |||||
| param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; | |||||
| } else { | |||||
| param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY; | |||||
| } | |||||
| auto conv = opr::ConvBias::make( | |||||
| f, weight, bias, param, {}, OperatorNodeConfig{out_dtype}); | |||||
| weight_idx++; | |||||
| bias_idx++; | |||||
| return conv; | |||||
| } | |||||
| SymbolVar Network::add_deconv( | SymbolVar Network::add_deconv( | ||||
| SymbolVar f, size_t ratio, size_t output_channels, DType out_dtype) { | SymbolVar f, size_t ratio, size_t output_channels, DType out_dtype) { | ||||
| static int weight_idx = 0; | static int weight_idx = 0; | ||||
| @@ -208,6 +238,7 @@ SymbolVarArray fusion_pyramids_feature( | |||||
| false, {1, 1}, {0, 0}); | false, {1, 1}, {0, 0}); | ||||
| if (!touch) { | if (!touch) { | ||||
| x = f; | x = f; | ||||
| touch = true; | |||||
| } else { | } else { | ||||
| x = network.add_deconv(x, 2, 16, dtype::QuantizedS8{1.f}); | x = network.add_deconv(x, 2, 16, dtype::QuantizedS8{1.f}); | ||||
| x = network.add_elemwise( | x = network.add_elemwise( | ||||
| @@ -236,4 +267,63 @@ SymbolVarArray mgb::make_det(Network& network, size_t batch, DType out_dtype) { | |||||
| return outputs; | return outputs; | ||||
| } | } | ||||
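| /// MobileNetV2 inverted-residual block: 1x1 expansion (ratio t), 3x3 grouped | |||||
| /// (depthwise) conv with the given stride, 1x1 projection to `channels`, and | |||||
| /// a skip connection when stride is 1 and the channel counts match. | |||||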
| SymbolVar mgb::bottleneck( | |||||
| Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t, | |||||
| size_t stride) { | |||||
| size_t in_channels = f.node()->shape()[1]; | |||||
| SymbolVar x = f; | |||||
| if (t != 1) { | |||||
| x = network.add_conv( | |||||
| f, input_channels * t, {1, 1}, dtype::Float32(), true, {1, 1}, {0, 0}); | |||||
| } | |||||
| x = network.add_group_conv( | |||||
| x, input_channels * t, input_channels * t, {3, 3}, dtype::Float32(), true, | |||||
| {stride, stride}, {1, 1}); | |||||
| x = network.add_conv(x, channels, {1, 1}, dtype::Float32(), false, {1, 1}, {0, 0}); | |||||
| if (stride == 1 && in_channels == channels) | |||||
| x = f + x; | |||||
| return x; | |||||
| } | |||||
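| /// Stack of `stages` bottleneck blocks; only the first block uses stride `s`. | |||||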
| SymbolVar mgb::bottleneck_group( | |||||
| Network& network, SymbolVar f, size_t input_channels, size_t channels, | |||||
| size_t stages, size_t s, size_t t) { | |||||
| SymbolVar x = f; | |||||
| for (size_t i = 0; i < stages; ++i) { | |||||
| size_t stride = i == 0 ? s : 1; | |||||
| x = bottleneck(network, x, input_channels, channels, t, stride); | |||||
| input_channels = channels; | |||||
| } | |||||
| return x; | |||||
| } | |||||
| namespace { | |||||
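| // Round `v` to the nearest multiple of `divisor`, never going below 90% of `v` | |||||
| // (the channel-rounding rule used by common MobileNetV2 implementations). | |||||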
| size_t make_divisible(size_t v, size_t divisor) { | |||||
| size_t min_value = divisor; | |||||
| size_t new_v = std::max(min_value, (v + divisor / 2) / divisor * divisor); | |||||
| if (new_v < 0.9 * v) | |||||
| new_v += divisor; | |||||
| return new_v; | |||||
| } | |||||
| } // namespace | |||||
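| /// Build a float32 MobileNetV2 feature extractor (no classifier head) on a | |||||
| /// {batch, 3, 224, 224} input, with channel counts rounded to multiples of 8. | |||||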
| SymbolVar mgb::make_mobilenet_v2(Network& network, size_t batch) { | |||||
| auto data = network.add_var("data", {batch, 3, 224, 224}); | |||||
| constexpr size_t round_nearest = 8; | |||||
| auto x = network.add_conv( | |||||
| data, make_divisible(32, round_nearest), {3, 3}, dtype::Float32(), true, | |||||
| {2, 2}, {1, 1}); | |||||
| x = bottleneck(network, x, 32, make_divisible(16, round_nearest), 1, 1); | |||||
| x = bottleneck_group(network, x, 16, make_divisible(24, round_nearest), 2, 2, 6); | |||||
| x = bottleneck_group(network, x, 24, make_divisible(32, round_nearest), 3, 2, 6); | |||||
| x = bottleneck_group(network, x, 32, make_divisible(64, round_nearest), 4, 2, 6); | |||||
| x = bottleneck_group(network, x, 64, make_divisible(96, round_nearest), 3, 1, 6); | |||||
| x = bottleneck_group(network, x, 96, make_divisible(160, round_nearest), 3, 2, 6); | |||||
| x = bottleneck_group(network, x, 160, make_divisible(320, round_nearest), 1, 1, 6); | |||||
| x = network.add_conv( | |||||
| x, make_divisible(1280, round_nearest), {1, 1}, dtype::Float32(), true, | |||||
| {1, 1}, {0, 0}); | |||||
| return x; | |||||
| } | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -28,7 +28,7 @@ | |||||
| namespace mgb { | namespace mgb { | ||||
| class Network { | class Network { | ||||
| private: | private: | ||||
| HostTensorGenerator<> gen; | |||||
| HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen{-0.01, 0.01}; | |||||
| CompNode cn; | CompNode cn; | ||||
| public: | public: | ||||
| @@ -49,6 +49,10 @@ public: | |||||
| SymbolVar f, size_t output_channels, KernSize kern_size, | SymbolVar f, size_t output_channels, KernSize kern_size, | ||||
| DType out_dtype = dtype::Float32(), bool has_relu = true, | DType out_dtype = dtype::Float32(), bool has_relu = true, | ||||
| Stride stride = {1, 1}, Padding padding = {0, 0}); | Stride stride = {1, 1}, Padding padding = {0, 0}); | ||||
| SymbolVar add_group_conv( | |||||
| SymbolVar f, size_t output_channels, size_t groups, KernSize kern_size, | |||||
| DType out_dtype = dtype::Float32(), bool has_relu = true, | |||||
| Stride stride = {1, 1}, Padding padding = {0, 0}); | |||||
| SymbolVar add_deconv( | SymbolVar add_deconv( | ||||
| SymbolVar f, size_t ratio, size_t output_channels, DType out_dtype); | SymbolVar f, size_t ratio, size_t output_channels, DType out_dtype); | ||||
| SymbolVar add_elemwise( | SymbolVar add_elemwise( | ||||
| @@ -73,6 +77,16 @@ SymbolVar make_resnet18( | |||||
| SymbolVarArray make_det( | SymbolVarArray make_det( | ||||
| Network& network, size_t batch = 16, DType out_dtype = dtype::Float32()); | Network& network, size_t batch = 16, DType out_dtype = dtype::Float32()); | ||||
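| /// MobileNetV2 building blocks used by make_mobilenet_v2 (see network.cpp). | |||||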
| SymbolVar bottleneck( | |||||
| Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t, | |||||
| size_t stride); | |||||
| SymbolVar bottleneck_group( | |||||
| Network& network, SymbolVar f, size_t input_channels, size_t channels, | |||||
| size_t stages, size_t s, size_t t); | |||||
| SymbolVar make_mobilenet_v2(Network& network, size_t batch = 1); | |||||
| } // namespace mgb | } // namespace mgb | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -26,7 +26,7 @@ using namespace serialization; | |||||
| #if MGB_CUDA | #if MGB_CUDA | ||||
| namespace { | namespace { | ||||
| std::unique_ptr<LayoutTransformContext> make_ctx() { | std::unique_ptr<LayoutTransformContext> make_ctx() { | ||||
| using OprFormat = LayoutTransformContext::OprFormat; | |||||
| using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID; | |||||
| using OprList = LayoutTransformContext::OprList; | using OprList = LayoutTransformContext::OprList; | ||||
| using Attribute = LayoutTransformContext::Attribute; | using Attribute = LayoutTransformContext::Attribute; | ||||
| using Target = LayoutTransformContext::Target; | using Target = LayoutTransformContext::Target; | ||||
| @@ -44,26 +44,29 @@ std::unique_ptr<LayoutTransformContext> make_ctx() { | |||||
| SmallVector<TensorFormats> available_tensor_formats = { | SmallVector<TensorFormats> available_tensor_formats = { | ||||
| TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, | TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4, | ||||
| TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4}; | TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4}; | ||||
| Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::CUDA}; | |||||
| Attribute attribute = {OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::CUDA}; | |||||
| auto ctx = std::make_unique<LayoutTransformContext>( | auto ctx = std::make_unique<LayoutTransformContext>( | ||||
| std::move(opr_list), std::move(available_tensor_formats), attribute); | std::move(opr_list), std::move(available_tensor_formats), attribute); | ||||
| ctx->add_opr_config( | ctx->add_opr_config( | ||||
| opr::ConvBiasForward::typeinfo(), | opr::ConvBiasForward::typeinfo(), | ||||
| {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW32, | |||||
| OprFormat::NCHW64, OprFormat::CHWN4}) | |||||
| {OprFormatConfigID::NCHW, OprFormatConfigID::NHWC, | |||||
| OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, | |||||
| OprFormatConfigID::NCHW64, OprFormatConfigID::CHWN4}) | |||||
| .add_opr_config( | .add_opr_config( | ||||
| opr::ConvolutionForward::typeinfo(), | opr::ConvolutionForward::typeinfo(), | ||||
| {OprFormat::NCHW, OprFormat::NCHW4}) | |||||
| {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4}) | |||||
| .add_opr_config( | .add_opr_config( | ||||
| opr::ConvolutionBackwardData::typeinfo(), | opr::ConvolutionBackwardData::typeinfo(), | ||||
| {OprFormat::NCHW, OprFormat::NCHW4}) | |||||
| {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4}) | |||||
| .add_opr_config( | .add_opr_config( | ||||
| opr::PoolingForward::typeinfo(), | opr::PoolingForward::typeinfo(), | ||||
| {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC, | |||||
| OprFormat::NCHW64, OprFormat::CHWN4}) | |||||
| {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32, | |||||
| OprFormatConfigID::NHWC, OprFormatConfigID::NCHW64, | |||||
| OprFormatConfigID::CHWN4}) | |||||
| .add_opr_config( | .add_opr_config( | ||||
| opr::WarpPerspectiveForward::typeinfo(), | opr::WarpPerspectiveForward::typeinfo(), | ||||
| {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64}); | |||||
| {OprFormatConfigID::NHWC, OprFormatConfigID::NCHW4, | |||||
| OprFormatConfigID::NCHW64}); | |||||
| return ctx; | return ctx; | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||