GitOrigin-RevId: 3b95da02c8
tags/v1.7.0
| @@ -76,7 +76,7 @@ public: | |||
| } | |||
| }; | |||
| cg::OperatorNodeBase* apply_on_var_node( | |||
| VarNodeArray apply_on_var_node( | |||
| const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& attr = def.cast_final_safe<OprAttr>(); | |||
| auto config = attr.config; | |||
| @@ -85,7 +85,7 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
| auto registry = serialization::OprRegistry::find_by_name(attr.type); | |||
| mgb_assert(registry, "operator %s not found", attr.type.c_str()); | |||
| OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()}; | |||
| return registry->loader(ctx, inputs, config); | |||
| return registry->loader(ctx, inputs, config).usable_output(); | |||
| } | |||
| std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) { | |||
| @@ -200,7 +200,7 @@ TEST(TestImperative, BatchNormGrad) { | |||
| LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; | |||
| LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; | |||
| { | |||
| auto op = OprAttr::make("BatchNorm"); | |||
| auto op = OprAttr::make("BatchNormV1"); | |||
| auto&& attr = op->cast_final_safe<OprAttr>(); | |||
| Param param; | |||
| param.fwd_mode = Param::FwdMode::TRAINING; | |||
| @@ -210,7 +210,7 @@ TEST(TestImperative, BatchNormGrad) { | |||
| {false, false, false, false, false, true}); | |||
| } | |||
| { | |||
| auto op = OprAttr::make("BatchNorm"); | |||
| auto op = OprAttr::make("BatchNormV1"); | |||
| auto&& attr = op->cast_final_safe<OprAttr>(); | |||
| Param param; | |||
| param.fwd_mode = Param::FwdMode::TRAINING; | |||
| @@ -59,7 +59,7 @@ TEST(TestImperative, Reduce) { | |||
| } | |||
| TEST(TestImperative, BatchNorm) { | |||
| auto op = OprAttr::make("BatchNorm"); | |||
| auto op = OprAttr::make("BatchNormV1"); | |||
| auto&& attr = op->cast_final_safe<OprAttr>(); | |||
| using Param = opr::BatchNorm::Param; | |||
| Param param; | |||
| @@ -16,14 +16,13 @@ | |||
| #include "megbrain/opr/dnn/correlation.h" | |||
| #include "megbrain/opr/dnn/fake_quant.h" | |||
| #include "megbrain/opr/dnn/images2neibs.h" | |||
| #include "megbrain/opr/dnn/sliding_window_transpose.h" | |||
| #include "megbrain/opr/dnn/adaptive_pooling.h" | |||
| #include "megbrain/opr/dnn/local.h" | |||
| #include "megbrain/opr/dnn/lrn.h" | |||
| #include "megbrain/opr/dnn/lsq.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/dnn/roi_align.h" | |||
| #include "megbrain/opr/dnn/roi_pooling.h" | |||
| #include "megbrain/opr/dnn/sliding_window_transpose.h" | |||
| #include "megbrain/opr/dnn/tqt.h" | |||
| #include "megbrain/serialization/sereg.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| @@ -390,6 +389,7 @@ struct OprMaker<opr::BatchNorm, 0> { | |||
| } | |||
| }; | |||
| // OprMaker in MGB_SEREG_OPR only supports oprs with a unique output | |||
| template <> | |||
| struct OprMaker<opr::BatchNormBackward, 6> { | |||
| using Param = opr::BatchNormBackward::Param; | |||
| @@ -398,8 +398,8 @@ struct OprMaker<opr::BatchNormBackward, 6> { | |||
| ComputingGraph& graph, | |||
| const OperatorNodeConfig& config) { | |||
| MGB_MARK_USED_VAR(graph); | |||
| return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5], param, | |||
| config)[0] | |||
| return opr::BatchNormBackward::make(i[0], i[1], i[2], i[3], i[4], i[5], | |||
| param, config)[0] | |||
| .node() | |||
| ->owner_opr(); | |||
| } | |||
| @@ -575,8 +575,10 @@ MGB_SEREG_OPR(Convolution3DBackwardFilter, 0); | |||
| using ConvBiasForwardV4 = ConvBiasForward; | |||
| MGB_SEREG_OPR(ConvBiasForwardV4, 0); | |||
| MGB_SEREG_OPR(BatchNorm, 0); | |||
| MGB_SEREG_OPR(BatchNormBackward, 6); | |||
| using BatchNormV1 = BatchNorm; | |||
| using BatchNormBackwardV1 = BatchNormBackward; | |||
| MGB_SEREG_OPR(BatchNormV1, 0); | |||
| MGB_SEREG_OPR(BatchNormBackwardV1, 6); | |||
| using LocalShareForwardV1 = LocalShareForward; | |||
| using LocalShareBackwardDataV1 = LocalShareBackwardData; | |||
| @@ -39,7 +39,7 @@ namespace { | |||
| return inst; | |||
| } | |||
| cg::OperatorNodeBase* dynamic_loader( | |||
| OprWithOutputAccessor dynamic_loader( | |||
| OprLoadContext &ctx, const cg::VarNodeArray &inputs, | |||
| const OperatorNodeConfig &config) { | |||
| auto name = ctx.load_buf_with_len(); | |||
| @@ -171,4 +171,20 @@ std::vector<std::pair<size_t, std::string>> OprRegistry::dump_registries() { | |||
| } | |||
| #endif | |||
| namespace { | |||
| const VarNodeArray& default_accessor(const VarNodeArray& outputs) { | |||
| return outputs; | |||
| } | |||
| } | |||
| OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr) : m_opr(opr){ | |||
| m_accessor = &default_accessor; | |||
| }; | |||
| OprWithOutputAccessor::OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor) | |||
| : OprWithOutputAccessor(opr) { | |||
| if (accessor) { | |||
| m_accessor = accessor; | |||
| } | |||
| }; | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -207,7 +207,7 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl( | |||
| registry->dumper(dumper, opr); | |||
| OprLoadContextMemory loader{opr.owner_graph(), dumper}; | |||
| return registry->loader(loader, inputs, config); | |||
| return registry->loader(loader, inputs, config).opr(); | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -782,7 +782,8 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr( | |||
| } | |||
| // call loader | |||
| auto opr = registry->loader(*this, inputs, config); | |||
| auto accessor = registry->loader(*this, inputs, config); | |||
| auto opr = accessor.opr(); | |||
| // check opr type; note that: | |||
| // 1. registry->type may be empty for dynamic opr loaders or legacy oprs | |||
| @@ -794,7 +795,7 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr( | |||
| opr ? opr->dyn_typeinfo()->name : nullptr, registry->type->name); | |||
| // record output vars; read output names | |||
| size_t i = 0; | |||
| for (auto ovar : opr->output()) { | |||
| for (auto ovar : accessor.output()) { | |||
| if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | |||
| m_id2varnode.push_back(ovar); | |||
| if (fbopr->output_name()) { | |||
| @@ -19,16 +19,36 @@ namespace serialization { | |||
| class OprDumpContext; | |||
| class OprLoadContext; | |||
| class OprShallowCopyContext; | |||
| class OprWithOutputAccessor { | |||
| cg::OperatorNodeBase* m_opr; | |||
| using Accessor = thin_function<const VarNodeArray(const VarNodeArray&)>; | |||
| Accessor m_accessor; | |||
| public: | |||
| OprWithOutputAccessor(cg::OperatorNodeBase* opr); | |||
| OprWithOutputAccessor(cg::OperatorNodeBase* opr, Accessor accessor); | |||
| VarNode* output(size_t idx) const { return output().at(idx); } | |||
| VarNodeArray output() const { return m_accessor(m_opr->output()); } | |||
| VarNodeArray usable_output() const { return m_accessor(m_opr->usable_output()); } | |||
| cg::OperatorNodeBase* opr() { return m_opr; } | |||
| }; | |||
| //! dump opr internal params to OprDumpContext | |||
| using OprDumper = thin_function<void( | |||
| OprDumpContext &ctx, const cg::OperatorNodeBase &opr)>; | |||
| //! load and restore operator from OprLoadContext | |||
| //! is also used by GraphLoadConfig. | |||
| using OprLoader = thin_function<cg::OperatorNodeBase*( | |||
| OprLoadContext &ctx, const cg::VarNodeArray &inputs, | |||
| const OperatorNodeConfig &config)>; | |||
| //! loader that can change opr output map for compatibility | |||
| using OprLoaderWrapper = thin_function<OprWithOutputAccessor( | |||
| OprLoadContext &ctx, const cg::VarNodeArray &inputs, | |||
| const OperatorNodeConfig &config)>; | |||
| //! shallow copy function for a single operator | |||
| using OprShallowCopy = thin_function<cg::OperatorNodeBase*( | |||
| const OprShallowCopyContext &ctx, | |||
| @@ -41,7 +61,7 @@ namespace serialization { | |||
| uint64_t persist_type_id; | |||
| std::string name; | |||
| OprDumper dumper; | |||
| OprLoader loader; | |||
| OprLoaderWrapper loader; | |||
| OprShallowCopy shallow_copy; //!< set to empty to use default impl | |||
| uint64_t unversioned_type_id; | |||
| @@ -167,16 +167,22 @@ namespace { \ | |||
| /*! | |||
| * \brief register opr serialization methods | |||
| */ | |||
| #define MGB_SEREG_OPR(_cls, _arity) \ | |||
| namespace { \ | |||
| struct _OprReg##_cls { \ | |||
| static void entry() { \ | |||
| using Impl = ::mgb::serialization::OprLoadDumpImpl< \ | |||
| _cls, _arity>; \ | |||
| MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, Impl::load); \ | |||
| } \ | |||
| }; \ | |||
| } \ | |||
#define MGB_SEREG_OPR(_cls, _arity)                                          \
    namespace {                                                              \
    namespace ser = ::mgb::serialization;                                    \
    struct _OprReg##_cls {                                                   \
        using Impl = ser::OprLoadDumpImpl<_cls, _arity>;                     \
        /* adapt Impl::load to the OprWithOutputAccessor-returning loader */ \
        /* signature by wrapping the loaded opr (single-argument ctor) */    \
        static ser::OprWithOutputAccessor wrap_loader(                       \
                ser::OprLoadContext& ctx,                                    \
                const mgb::cg::VarNodeArray& inputs,                         \
                const mgb::cg::OperatorNodeConfig& config) {                 \
            auto loaded = Impl::load(ctx, inputs, config);                   \
            return ser::OprWithOutputAccessor(loaded);                       \
        }                                                                    \
        static void entry() {                                                \
            MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader);      \
        }                                                                    \
    };                                                                       \
    }                                                                        \
    MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls)
| //! use to check type is complete or not, midout need a complete type | |||
| @@ -187,33 +193,35 @@ template <class T> | |||
| struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {}; | |||
| //! call OprRegistry::add with only loader, used for backward compatibility | |||
| #define MGB_SEREG_OPR_COMPAT(_name, _load) \ | |||
| namespace { \ | |||
| static_assert(IsComplete<_name>(), \ | |||
| "need a complete type for MGB_SEREG_OPR_COMPAT"); \ | |||
| struct _OprReg##_name { \ | |||
| static cg::OperatorNodeBase* compat_loader( \ | |||
| serialization::OprLoadContext& ctx, \ | |||
| const cg::VarNodeArray& inputs, \ | |||
| const OperatorNodeConfig& config) { \ | |||
| return _load( \ | |||
| static_cast<serialization::OprLoadContextRawPOD&>(ctx), \ | |||
| inputs, config); \ | |||
| } \ | |||
| static void entry() { \ | |||
| ::mgb::serialization::OprRegistry::add( \ | |||
| {nullptr, \ | |||
| MGB_HASH_STR(#_name), \ | |||
| _MGB_SEREG_OPR_NAME_FROM_CLS(_name), \ | |||
| nullptr, \ | |||
| compat_loader, \ | |||
| {}, \ | |||
| {}}); \ | |||
| } \ | |||
| }; \ | |||
| } \ | |||
#define MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_name, _load, _accessor)          \
    namespace {                                                              \
    static_assert(IsComplete<_name>(),                                       \
                  "need a complete type for MGB_SEREG_OPR_COMPAT");          \
    namespace ser = ::mgb::serialization;                                    \
    struct _OprReg##_name {                                                  \
        /* run _load on the context viewed as OprLoadContextRawPOD, then */  \
        /* pair the resulting opr with _accessor */                          \
        static ser::OprWithOutputAccessor compat_loader(                     \
                ser::OprLoadContext& ctx,                                    \
                const mgb::cg::VarNodeArray& inputs,                         \
                const mgb::cg::OperatorNodeConfig& config) {                 \
            auto& raw_ctx = static_cast<ser::OprLoadContextRawPOD&>(ctx);    \
            return ser::OprWithOutputAccessor(                               \
                    _load(raw_ctx, inputs, config), _accessor);              \
        }                                                                    \
        /* registry record: only name hash, name and loader are filled in */ \
        static void entry() {                                                \
            ser::OprRegistry::add(                                           \
                    {nullptr,                                                \
                     MGB_HASH_STR(#_name),                                   \
                     _MGB_SEREG_OPR_NAME_FROM_CLS(_name),                    \
                     nullptr,                                                \
                     compat_loader,                                          \
                     {},                                                     \
                     {}});                                                   \
        }                                                                    \
    };                                                                       \
    }                                                                        \
    MGB_SEREG_OPR_INTL_CALL_ENTRY(_name, _OprReg##_name)
//! loader-only registration for \p _cls: forwards to
//! MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR with a null accessor argument
#define MGB_SEREG_OPR_COMPAT(_cls, _loader) \
    MGB_SEREG_OPR_COMPAT_WITH_ACCESSOR(_cls, _loader, nullptr)
| /*! | |||
| * \brief use \p _copy to implement shallow copy for given operator | |||
| */ | |||