GitOrigin-RevId: 3b95da02c8
tags/v1.7.0
| @@ -76,7 +76,7 @@ public: | |||||
| } | } | ||||
| }; | }; | ||||
| cg::OperatorNodeBase* apply_on_var_node( | |||||
| VarNodeArray apply_on_var_node( | |||||
| const OpDef& def, const VarNodeArray& inputs) { | const OpDef& def, const VarNodeArray& inputs) { | ||||
| auto&& attr = def.cast_final_safe<OprAttr>(); | auto&& attr = def.cast_final_safe<OprAttr>(); | ||||
| auto config = attr.config; | auto config = attr.config; | ||||
| @@ -85,7 +85,7 @@ cg::OperatorNodeBase* apply_on_var_node( | |||||
| auto registry = serialization::OprRegistry::find_by_name(attr.type); | auto registry = serialization::OprRegistry::find_by_name(attr.type); | ||||
| mgb_assert(registry, "operator %s not found", attr.type.c_str()); | mgb_assert(registry, "operator %s not found", attr.type.c_str()); | ||||
| OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()}; | OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()}; | ||||
| return registry->loader(ctx, inputs, config); | |||||
| return registry->loader(ctx, inputs, config).usable_output(); | |||||
| } | } | ||||
| std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) { | std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) { | ||||
| @@ -200,7 +200,7 @@ TEST(TestImperative, BatchNormGrad) { | |||||
| LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; | LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; | ||||
| LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; | LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; | ||||
| { | { | ||||
| auto op = OprAttr::make("BatchNorm"); | |||||
| auto op = OprAttr::make("BatchNormV1"); | |||||
| auto&& attr = op->cast_final_safe<OprAttr>(); | auto&& attr = op->cast_final_safe<OprAttr>(); | ||||
| Param param; | Param param; | ||||
| param.fwd_mode = Param::FwdMode::TRAINING; | param.fwd_mode = Param::FwdMode::TRAINING; | ||||
| @@ -210,7 +210,7 @@ TEST(TestImperative, BatchNormGrad) { | |||||
| {false, false, false, false, false, true}); | {false, false, false, false, false, true}); | ||||
| } | } | ||||
| { | { | ||||
| auto op = OprAttr::make("BatchNorm"); | |||||
| auto op = OprAttr::make("BatchNormV1"); | |||||
| auto&& attr = op->cast_final_safe<OprAttr>(); | auto&& attr = op->cast_final_safe<OprAttr>(); | ||||
| Param param; | Param param; | ||||
| param.fwd_mode = Param::FwdMode::TRAINING; | param.fwd_mode = Param::FwdMode::TRAINING; | ||||
| @@ -59,7 +59,7 @@ TEST(TestImperative, Reduce) { | |||||
| } | } | ||||
| TEST(TestImperative, BatchNorm) { | TEST(TestImperative, BatchNorm) { | ||||
| auto op = OprAttr::make("BatchNorm"); | |||||
| auto op = OprAttr::make("BatchNormV1"); | |||||
| auto&& attr = op->cast_final_safe<OprAttr>(); | auto&& attr = op->cast_final_safe<OprAttr>(); | ||||
| using Param = opr::BatchNorm::Param; | using Param = opr::BatchNorm::Param; | ||||
| Param param; | Param param; | ||||
| @@ -16,14 +16,13 @@ | |||||
| #include "megbrain/opr/dnn/correlation.h" | #include "megbrain/opr/dnn/correlation.h" | ||||
| #include "megbrain/opr/dnn/fake_quant.h" | #include "megbrain/opr/dnn/fake_quant.h" | ||||
| #include "megbrain/opr/dnn/images2neibs.h" | #include "megbrain/opr/dnn/images2neibs.h" | ||||
| #include "megbrain/opr/dnn/sliding_window_transpose.h" | |||||
| #include "megbrain/opr/dnn/adaptive_pooling.h" | |||||
| #include "megbrain/opr/dnn/local.h" | #include "megbrain/opr/dnn/local.h" | ||||
| #include "megbrain/opr/dnn/lrn.h" | #include "megbrain/opr/dnn/lrn.h" | ||||
| #include "megbrain/opr/dnn/lsq.h" | #include "megbrain/opr/dnn/lsq.h" | ||||
| #include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
| #include "megbrain/opr/dnn/roi_align.h" | #include "megbrain/opr/dnn/roi_align.h" | ||||
| #include "megbrain/opr/dnn/roi_pooling.h" | #include "megbrain/opr/dnn/roi_pooling.h" | ||||
| #include "megbrain/opr/dnn/sliding_window_transpose.h" | |||||
| #include "megbrain/opr/dnn/tqt.h" | #include "megbrain/opr/dnn/tqt.h" | ||||
| #include "megbrain/serialization/sereg.h" | #include "megbrain/serialization/sereg.h" | ||||
| #include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
| @@ -390,6 +389,7 @@ struct OprMaker<opr::BatchNorm, 0> { | |||||
| } | } | ||||
| }; | }; | ||||
| // OprMaker in MGB_SEREG_OPR only support unique output opr | |||||
| template <> | template <> | ||||
| struct OprMaker<opr::BatchNormBackward, 6> { | struct OprMaker<opr::BatchNormBackward, 6> { | ||||
| using Param = opr::BatchNormBackward::Param; | using Param = opr::BatchNormBackward::Param; | ||||
| @@ -398,8 +398,8 @@ struct OprMaker<opr::BatchNormBackward, 6> { | |||||
| ComputingGraph& graph, | ComputingGraph& graph, | ||||
| const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
| MGB_MARK_USED_VAR(graph); | MGB_MARK_USED_VAR(graph); | ||||
| return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5], param, | |||||
| config)[0] | |||||
| return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5], | |||||
| param, config)[0] | |||||
| .node() | .node() | ||||
| ->owner_opr(); | ->owner_opr(); | ||||
| } | } | ||||
| @@ -575,8 +575,10 @@ MGB_SEREG_OPR(Convolution3DBackwardFilter, 0); | |||||
| using ConvBiasForwardV4 = ConvBiasForward; | using ConvBiasForwardV4 = ConvBiasForward; | ||||
| MGB_SEREG_OPR(ConvBiasForwardV4, 0); | MGB_SEREG_OPR(ConvBiasForwardV4, 0); | ||||
| MGB_SEREG_OPR(BatchNorm, 0); | |||||
| MGB_SEREG_OPR(BatchNormBackward, 6); | |||||
| using BatchNormV1 = BatchNorm; | |||||
| using BatchNormBackwardV1 = BatchNormBackward; | |||||
| MGB_SEREG_OPR(BatchNormV1, 0); | |||||
| MGB_SEREG_OPR(BatchNormBackwardV1, 6); | |||||
| using LocalShareForwardV1 = LocalShareForward; | using LocalShareForwardV1 = LocalShareForward; | ||||
| using LocalShareBackwardDataV1 = LocalShareBackwardData; | using LocalShareBackwardDataV1 = LocalShareBackwardData; | ||||
| @@ -39,7 +39,7 @@ namespace { | |||||
| return inst; | return inst; | ||||
| } | } | ||||
| cg::OperatorNodeBase* dynamic_loader( | |||||
| OprWithOutputAccessor dynamic_loader( | |||||
| OprLoadContext &ctx, const cg::VarNodeArray &inputs, | OprLoadContext &ctx, const cg::VarNodeArray &inputs, | ||||
| const OperatorNodeConfig &config) { | const OperatorNodeConfig &config) { | ||||
| auto name = ctx.load_buf_with_len(); | auto name = ctx.load_buf_with_len(); | ||||
| @@ -171,4 +171,20 @@ std::vector<std::pair<size_t, std::string>> OprRegistry::dump_registries() { | |||||
| } | } | ||||
| #endif | #endif | ||||
| namespace { | |||||
| const VarNodeArray& default_accessor(const VarNodeArray& outputs) { | |||||
| return outputs; | |||||
| } | |||||
| } | |||||
| OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr) : m_opr(opr){ | |||||
| m_accessor = &default_accessor; | |||||
| }; | |||||
| OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor) | |||||
| : OprWithOutputAccessor(opr) { | |||||
| if (accessor) { | |||||
| m_accessor = accessor; | |||||
| } | |||||
| }; | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -207,7 +207,7 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl( | |||||
| registry->dumper(dumper, opr); | registry->dumper(dumper, opr); | ||||
| OprLoadContextMemory loader{opr.owner_graph(), dumper}; | OprLoadContextMemory loader{opr.owner_graph(), dumper}; | ||||
| return registry->loader(loader, inputs, config); | |||||
| return registry->loader(loader, inputs, config).opr(); | |||||
| } | } | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -782,7 +782,8 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr( | |||||
| } | } | ||||
| // call loader | // call loader | ||||
| auto opr = registry->loader(*this, inputs, config); | |||||
| auto accessor = registry->loader(*this, inputs, config); | |||||
| auto opr = accessor.opr(); | |||||
| // check opr type; note that: | // check opr type; note that: | ||||
| // 1. registry->type may be empty for dynamic opr loaders or legacy oprs | // 1. registry->type may be empty for dynamic opr loaders or legacy oprs | ||||
| @@ -794,7 +795,7 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr( | |||||
| opr ? opr->dyn_typeinfo()->name : nullptr, registry->type->name); | opr ? opr->dyn_typeinfo()->name : nullptr, registry->type->name); | ||||
| // record output vars; read output names | // record output vars; read output names | ||||
| size_t i = 0; | size_t i = 0; | ||||
| for (auto ovar : opr->output()) { | |||||
| for (auto ovar : accessor.output()) { | |||||
| if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | ||||
| m_id2varnode.push_back(ovar); | m_id2varnode.push_back(ovar); | ||||
| if (fbopr->output_name()) { | if (fbopr->output_name()) { | ||||
| @@ -19,16 +19,36 @@ namespace serialization { | |||||
| class OprDumpContext; | class OprDumpContext; | ||||
| class OprLoadContext; | class OprLoadContext; | ||||
| class OprShallowCopyContext; | class OprShallowCopyContext; | ||||
| class OprWithOutputAccessor { | |||||
| cg::OperatorNodeBase* m_opr; | |||||
| using Accessor = thin_function<const VarNodeArray(const VarNodeArray&)>; | |||||
| Accessor m_accessor; | |||||
| public: | |||||
| OprWithOutputAccessor(cg::OperatorNodeBase* opr); | |||||
| OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor); | |||||
| VarNode* output(size_t idx) const { return output().at(idx); } | |||||
| VarNodeArray output() const { return m_accessor(m_opr->output()); } | |||||
| VarNodeArray usable_output() const { return m_accessor(m_opr->usable_output()); } | |||||
| cg::OperatorNodeBase* opr() { return m_opr; } | |||||
| }; | |||||
| //! dump opr internal params to OprDumpContext | //! dump opr internal params to OprDumpContext | ||||
| using OprDumper = thin_function<void( | using OprDumper = thin_function<void( | ||||
| OprDumpContext &ctx, const cg::OperatorNodeBase &opr)>; | OprDumpContext &ctx, const cg::OperatorNodeBase &opr)>; | ||||
| //! load and restore operator from OprLoadContext | //! load and restore operator from OprLoadContext | ||||
| //! is also used by GraphLoadConfig. | |||||
| using OprLoader = thin_function<cg::OperatorNodeBase*( | using OprLoader = thin_function<cg::OperatorNodeBase*( | ||||
| OprLoadContext &ctx, const cg::VarNodeArray &inputs, | OprLoadContext &ctx, const cg::VarNodeArray &inputs, | ||||
| const OperatorNodeConfig &config)>; | const OperatorNodeConfig &config)>; | ||||
| //! loader that can change opr output map for compatibility | |||||
| using OprLoaderWrapper = thin_function<OprWithOutputAccessor( | |||||
| OprLoadContext &ctx, const cg::VarNodeArray &inputs, | |||||
| const OperatorNodeConfig &config)>; | |||||
| //! shallow copy function for a single operator | //! shallow copy function for a single operator | ||||
| using OprShallowCopy = thin_function<cg::OperatorNodeBase*( | using OprShallowCopy = thin_function<cg::OperatorNodeBase*( | ||||
| const OprShallowCopyContext &ctx, | const OprShallowCopyContext &ctx, | ||||
| @@ -41,7 +61,7 @@ namespace serialization { | |||||
| uint64_t persist_type_id; | uint64_t persist_type_id; | ||||
| std::string name; | std::string name; | ||||
| OprDumper dumper; | OprDumper dumper; | ||||
| OprLoader loader; | |||||
| OprLoaderWrapper loader; | |||||
| OprShallowCopy shallow_copy; //!< set to empty to use default impl | OprShallowCopy shallow_copy; //!< set to empty to use default impl | ||||
| uint64_t unversioned_type_id; | uint64_t unversioned_type_id; | ||||
| @@ -167,16 +167,22 @@ namespace { \ | |||||
| /*! | /*! | ||||
| * \brief register opr serialization methods | * \brief register opr serialization methods | ||||
| */ | */ | ||||
| #define MGB_SEREG_OPR(_cls, _arity) \ | |||||
| namespace { \ | |||||
| struct _OprReg##_cls { \ | |||||
| static void entry() { \ | |||||
| using Impl = ::mgb::serialization::OprLoadDumpImpl< \ | |||||
| _cls, _arity>; \ | |||||
| MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, Impl::load); \ | |||||
| } \ | |||||
| }; \ | |||||
| } \ | |||||
| #define MGB_SEREG_OPR(_cls, _arity) \ | |||||
| namespace { \ | |||||
| namespace ser = ::mgb::serialization; \ | |||||
| struct _OprReg##_cls { \ | |||||
| using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \ | |||||
| static ser::OprWithOutputAccessor wrap_loader( \ | |||||
| ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | |||||
| const mgb::cg::OperatorNodeConfig& config) { \ | |||||
| return ser::OprWithOutputAccessor( \ | |||||
| Impl::load(ctx, inputs, config)); \ | |||||
| } \ | |||||
| static void entry() { \ | |||||
| MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader); \ | |||||
| } \ | |||||
| }; \ | |||||
| } \ | |||||
| MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) | MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) | ||||
| //! use to check type is complete or not, midout need a complete type | //! use to check type is complete or not, midout need a complete type | ||||
| @@ -187,33 +193,35 @@ template <class T> | |||||
| struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {}; | struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {}; | ||||
| //! call OprRegistry::add with only loader, used for backward compatibility | //! call OprRegistry::add with only loader, used for backward compatibility | ||||
| #define MGB_SEREG_OPR_COMPAT(_name, _load) \ | |||||
| namespace { \ | |||||
| static_assert(IsComplete<_name>(), \ | |||||
| "need a complete type for MGB_SEREG_OPR_COMPAT"); \ | |||||
| struct _OprReg##_name { \ | |||||
| static cg::OperatorNodeBase* compat_loader( \ | |||||
| serialization::OprLoadContext& ctx, \ | |||||
| const cg::VarNodeArray& inputs, \ | |||||
| const OperatorNodeConfig& config) { \ | |||||
| return _load( \ | |||||
| static_cast<serialization::OprLoadContextRawPOD&>(ctx), \ | |||||
| inputs, config); \ | |||||
| } \ | |||||
| static void entry() { \ | |||||
| ::mgb::serialization::OprRegistry::add( \ | |||||
| {nullptr, \ | |||||
| MGB_HASH_STR(#_name), \ | |||||
| _MGB_SEREG_OPR_NAME_FROM_CLS(_name), \ | |||||
| nullptr, \ | |||||
| compat_loader, \ | |||||
| {}, \ | |||||
| {}}); \ | |||||
| } \ | |||||
| }; \ | |||||
| } \ | |||||
| #define MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, _accessor) \ | |||||
| namespace { \ | |||||
| static_assert(IsComplete<_name>(), \ | |||||
| "need a complete type for MGB_SEREG_OPR_COMPAT"); \ | |||||
| namespace ser = ::mgb::serialization; \ | |||||
| struct _OprReg##_name { \ | |||||
| static ser::OprWithOutputAccessor compat_loader( \ | |||||
| ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | |||||
| const mgb::cg::OperatorNodeConfig& config) { \ | |||||
| auto&& ctx_ = static_cast<ser::OprLoadContextRawPOD&>(ctx); \ | |||||
| return ser::OprWithOutputAccessor(_load(ctx_, inputs, config), \ | |||||
| _accessor); \ | |||||
| } \ | |||||
| static void entry() { \ | |||||
| ser::OprRegistry::add({nullptr, \ | |||||
| MGB_HASH_STR(#_name), \ | |||||
| _MGB_SEREG_OPR_NAME_FROM_CLS(_name), \ | |||||
| nullptr, \ | |||||
| compat_loader, \ | |||||
| {}, \ | |||||
| {}}); \ | |||||
| } \ | |||||
| }; \ | |||||
| } \ | |||||
| MGB_SEREG_OPR_INTL_CALL_ENTRY(_name, _OprReg##_name) | MGB_SEREG_OPR_INTL_CALL_ENTRY(_name, _OprReg##_name) | ||||
| #define MGB_SEREG_OPR_COMPAT(_name, _load) \ | |||||
| MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, nullptr) | |||||
| /*! | /*! | ||||
| * \brief use \p _copy to implement shallow copy for given operator | * \brief use \p _copy to implement shallow copy for given operator | ||||
| */ | */ | ||||