GitOrigin-RevId: 8eacd5e77c
tags/v1.10.0
| @@ -153,4 +153,6 @@ struct EnsureHashConstexpr { | |||
| //! hash \p v at compile time: the length is taken with sizeof(v), so the | |||
| //! terminating NUL of a string literal is part of the hashed bytes; the result | |||
| //! is used as a template argument and therefore must be a constant expression | |||
| #define MGB_HASH_STR(v) \ | |||
| ::mgb::EnsureHashConstexpr<::mgb::XXHash64CT::hash(v, sizeof(v), 20160701)>::val | |||
| #define MGB_HASH_RUNTIME(v) XXHash64CT::hash((v).c_str(), (v).size() + 1, 20160701) | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -52,6 +52,21 @@ mgb::cg::OperatorNodeBase* custom_loader( | |||
| } \ | |||
| MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls) | |||
| //! register \p cls in the versioned registry for the versions | |||
| //! [_version_min, _version_max], using custom_dumper / custom_loader and no | |||
| //! converter (nullptr) | |||
| #define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max) \ | |||
| namespace { \ | |||
| struct _OprRegV2##cls { \ | |||
| static void entry() { \ | |||
| MGB_SEREG_OPR_INTL_CALL_ADD_V2( \ | |||
| cls, ::mgb::serialization::custom_dumper, \ | |||
| ::mgb::serialization::custom_loader, nullptr, _version_min, \ | |||
| _version_max); \ | |||
| } \ | |||
| }; \ | |||
| } \ | |||
| MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(cls, _OprRegV2##cls) | |||
| using namespace mgb; | |||
| using CustomOpNode = opr::CustomOpNode; | |||
| //! registration through CUSTOM_OP_SEREG_REG (its body is only partly visible | |||
| //! in this hunk) | |||
| CUSTOM_OP_SEREG_REG(CustomOpNode); | |||
| //! versioned registration covering version 2 up to CURRENT_VERSION | |||
| CUSTOM_OP_SEREG_REG_V2(CustomOpNode, 2, CURRENT_VERSION); | |||
| @@ -0,0 +1,228 @@ | |||
| #include "megbrain/graph/symbol_var.h" | |||
| #include "megdnn/oprs/general.h" | |||
| #if MGB_ENABLE_FBS_SERIALIZATION | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/dnn/softmax.h" | |||
| #include "megbrain/serialization/oss_opr_load_dump.h" | |||
| #include "megbrain/serialization/sereg.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "megdnn/oprs/nn.h" | |||
| namespace mgb { | |||
| namespace serialization { | |||
| //! Versioned load/dump implementation of opr::Softmax, plus a converter that | |||
| //! rebuilds softmax from Reduce/Elemwise operators. | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::Softmax, 1> { | |||
| using Opr = opr::Softmax; | |||
| using PersisParam = opr::Softmax::Param; | |||
| //! write the Softmax param of \p opr into \p ctx | |||
| static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { | |||
| const auto& softmax = opr.cast_final_safe<Opr>(); | |||
| ctx.write_param<PersisParam>(softmax.param()); | |||
| } | |||
| //! express softmax over \p inputs[0] along the operator's axis as | |||
| //! exp(x - max(x)) / sum(exp(x - max(x))) and return the operator that owns | |||
| //! the final division | |||
| static cg::OperatorNodeBase* replace_opr( | |||
| cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { | |||
| using ReduceMode = megdnn::Reduce::Mode; | |||
| using ElemMode = megdnn::Elemwise::Mode; | |||
| int32_t axis = opr->cast_final_safe<Opr>().param().axis; | |||
| auto x = inputs[0]; | |||
| auto x_max = opr::Reduce::make(x, {ReduceMode::MAX, axis}); | |||
| auto shifted = opr::Elemwise::make({x, x_max}, {ElemMode::SUB}); | |||
| auto exp_val = opr::Elemwise::make({shifted}, {ElemMode::EXP}); | |||
| auto exp_sum = opr::Reduce::make(exp_val, {ReduceMode::SUM, axis}); | |||
| auto result = opr::Elemwise::make({exp_val, exp_sum}, {ElemMode::TRUE_DIV}); | |||
| return result.node()->owner_opr(); | |||
| } | |||
| //! read the param stored at index 0 and rebuild the Softmax operator on | |||
| //! \p inputs[0] | |||
| static cg::OperatorNodeBase* load( | |||
| OprLoadContext& ctx, const cg::VarNodeArray& inputs, | |||
| const OperatorNodeConfig& config) { | |||
| auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx); | |||
| auto softmax_param = fbs_ctx.read_param<PersisParam>(0); | |||
| auto softmax_var = Opr::make(inputs[0], softmax_param, config); | |||
| return softmax_var.node()->owner_opr(); | |||
| } | |||
| }; | |||
| //! Shared versioned load/dump implementation for operators that persist a | |||
| //! param of type ConvParam followed by a megdnn::param::ExecutionPolicy. | |||
| //! Maker0, Maker1 and Maker2 are tried in that order when the operator is | |||
| //! rebuilt. | |||
| template < | |||
| class Opr, class Maker0, class MegDNNConv, | |||
| class Maker1 = MakeConvCallerEmpty<MegDNNConv>, | |||
| class Maker2 = MakeConvCallerEmpty<MegDNNConv>, | |||
| typename ConvParam = megdnn::param::Convolution> | |||
| struct WithPolicyOprLoadDumpImpl { | |||
| //! write the operator param first, then its transient execution policy | |||
| static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { | |||
| auto&& opr = opr_.cast_final_safe<Opr>(); | |||
| ctx.write_param<ConvParam>(opr.param()); | |||
| ctx.write_param<megdnn::param::ExecutionPolicy>( | |||
| opr.execution_policy_transient()); | |||
| } | |||
| //! return the first non-null result of Maker0, Maker1, Maker2; asserts that | |||
| //! one of them produced a var | |||
| static VarNode* make( | |||
| const cg::VarNodeArray& inputs, const ConvParam& param, | |||
| const megdnn::param::ExecutionPolicy& execution_policy, | |||
| const OperatorNodeConfig& config) { | |||
| VarNode* ret = | |||
| Maker0::template make<Opr>(inputs, param, execution_policy, config); | |||
| if (!ret) { | |||
| ret = Maker1::template make<Opr>(inputs, param, execution_policy, config); | |||
| } | |||
| if (!ret) { | |||
| ret = Maker2::template make<Opr>(inputs, param, execution_policy, config); | |||
| } | |||
| mgb_assert(ret); | |||
| return ret; | |||
| } | |||
| //! read the param at index 0 and, when the record carries additional | |||
| //! params, the execution policy at index 1; otherwise a default-constructed | |||
| //! policy is used | |||
| static cg::OperatorNodeBase* load( | |||
| OprLoadContext& ctx, const cg::VarNodeArray& inputs, | |||
| const OperatorNodeConfig& config) { | |||
| auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx); | |||
| auto fopr = reinterpret_cast<const fbs::v2::Operator*>( | |||
| fbs_ctx.get_current_opr_data()); | |||
| auto conv_param = fbs_ctx.read_param<ConvParam>(0); | |||
| megdnn::param::ExecutionPolicy policy; | |||
| //! guard against a null operator record before dereferencing it | |||
| if (fopr && fopr->additional_params() && fopr->additional_params()->size()) { | |||
| policy = fbs_ctx.read_param<megdnn::param::ExecutionPolicy>(1); | |||
| } | |||
| return make(inputs, conv_param, policy, config)->owner_opr(); | |||
| } | |||
| }; | |||
| //! 2D convolution family: param type is left at the template default | |||
| //! (megdnn::param::Convolution) | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::Convolution, 0> | |||
| : WithPolicyOprLoadDumpImpl< | |||
| opr::Convolution, | |||
| MakeConvCaller2<megdnn::Convolution>, | |||
| megdnn::Convolution> {}; | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::ConvolutionBackwardData, 0> | |||
| : WithPolicyOprLoadDumpImpl< | |||
| opr::ConvolutionBackwardData, | |||
| MakeConvCaller2<megdnn::Convolution>, | |||
| megdnn::Convolution, | |||
| MakeConvCaller3<megdnn::Convolution>> {}; | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::ConvolutionBackwardFilter, 0> | |||
| : WithPolicyOprLoadDumpImpl< | |||
| opr::ConvolutionBackwardFilter, | |||
| MakeConvCaller3<megdnn::Convolution>, | |||
| megdnn::Convolution> {}; | |||
| //! 3D convolution family: persists megdnn::param::Convolution3D | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::Convolution3D, 0> | |||
| : WithPolicyOprLoadDumpImpl< | |||
| opr::Convolution3D, | |||
| MakeConvCaller2<megdnn::Convolution3D>, | |||
| megdnn::Convolution3D, | |||
| MakeConvCallerEmpty<megdnn::Convolution3D>, | |||
| MakeConvCallerEmpty<megdnn::Convolution3D>, | |||
| megdnn::param::Convolution3D> {}; | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::Convolution3DBackwardData, 0> | |||
| : WithPolicyOprLoadDumpImpl< | |||
| opr::Convolution3DBackwardData, | |||
| MakeConvCaller2<megdnn::Convolution3D>, | |||
| megdnn::Convolution3D, | |||
| MakeConvCaller3<megdnn::Convolution3D>, | |||
| MakeConvCallerEmpty<megdnn::Convolution3D>, | |||
| megdnn::param::Convolution3D> {}; | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::Convolution3DBackwardFilter, 0> | |||
| : WithPolicyOprLoadDumpImpl< | |||
| opr::Convolution3DBackwardFilter, | |||
| MakeConvCaller3<megdnn::Convolution3D>, | |||
| megdnn::Convolution3D, | |||
| MakeConvCallerEmpty<megdnn::Convolution3D>, | |||
| MakeConvCallerEmpty<megdnn::Convolution3D>, | |||
| megdnn::param::Convolution3D> {}; | |||
| //! conv-bias family: makers for 2, 3 and 4 inputs are tried in that order | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::ConvBiasForward, 0> | |||
| : WithPolicyOprLoadDumpImpl< | |||
| opr::ConvBiasForward, | |||
| MakeConvCaller2<megdnn::ConvBiasForward>, | |||
| megdnn::ConvBiasForward, | |||
| MakeConvCaller3<megdnn::ConvBiasForward>, | |||
| MakeConvCaller4<megdnn::ConvBiasForward>, | |||
| megdnn::param::ConvBias> {}; | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::BatchConvBiasForward, 0> | |||
| : WithPolicyOprLoadDumpImpl< | |||
| opr::BatchConvBiasForward, | |||
| MakeConvCaller2<megdnn::BatchConvBiasForward>, | |||
| megdnn::BatchConvBiasForward, | |||
| MakeConvCaller3<megdnn::BatchConvBiasForward>, | |||
| MakeConvCaller4<megdnn::BatchConvBiasForward>, | |||
| megdnn::param::BatchConvBias> {}; | |||
| //! local-share family: persists megdnn::param::LocalShare | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::LocalShare, 0> | |||
| : WithPolicyOprLoadDumpImpl< | |||
| opr::LocalShare, | |||
| MakeLocalShareCaller2<megdnn::LocalShare>, | |||
| megdnn::LocalShare, | |||
| MakeLocalShareCallerEmpty<megdnn::LocalShare>, | |||
| MakeLocalShareCallerEmpty<megdnn::LocalShare>, | |||
| megdnn::param::LocalShare> {}; | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::LocalShareBackwardData, 0> | |||
| : WithPolicyOprLoadDumpImpl< | |||
| opr::LocalShareBackwardData, | |||
| MakeLocalShareCaller3<megdnn::LocalShare>, | |||
| megdnn::LocalShare, | |||
| MakeLocalShareCallerEmpty<megdnn::LocalShare>, | |||
| MakeLocalShareCallerEmpty<megdnn::LocalShare>, | |||
| megdnn::param::LocalShare> {}; | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::LocalShareBackwardFilter, 0> | |||
| : WithPolicyOprLoadDumpImpl< | |||
| opr::LocalShareBackwardFilter, | |||
| MakeLocalShareCaller3<megdnn::LocalShare>, | |||
| megdnn::LocalShare, | |||
| MakeLocalShareCallerEmpty<megdnn::LocalShare>, | |||
| MakeLocalShareCallerEmpty<megdnn::LocalShare>, | |||
| megdnn::param::LocalShare> {}; | |||
| //! deformable convolution family: param type is left at the template default | |||
| //! (megdnn::param::Convolution) | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::DeformableConvForward, 0> | |||
| : WithPolicyOprLoadDumpImpl< | |||
| opr::DeformableConvForward, | |||
| MakeConvCaller4<megdnn::DeformableConvForward>, | |||
| megdnn::Convolution> {}; | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::DeformableConvBackwardData, 0> | |||
| : WithPolicyOprLoadDumpImpl< | |||
| opr::DeformableConvBackwardData, | |||
| MakeConvCaller5<megdnn::DeformableConvBackwardData>, | |||
| megdnn::Convolution> {}; | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::DeformableConvBackwardFilter, 0> | |||
| : WithPolicyOprLoadDumpImpl< | |||
| opr::DeformableConvBackwardFilter, | |||
| MakeConvCaller5<megdnn::DeformableConvBackwardFilter>, | |||
| megdnn::Convolution> {}; | |||
| } // namespace serialization | |||
| namespace opr { | |||
| //! register \p _cls for VERSION_2..CURRENT_VERSION with \p _converter | |||
| #define SERGE_OPR_V2_CONVERTER(_cls, _arity, _converter) \ | |||
| MGB_SEREG_OPR_V2(_cls, _arity, _converter, VERSION_2, CURRENT_VERSION); | |||
| //! register \p _cls for VERSION_2..CURRENT_VERSION without a converter | |||
| #define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \ | |||
| MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION); | |||
| //! Softmax is the only operator here registered with a converter | |||
| //! (replace_opr) | |||
| SERGE_OPR_V2_CONVERTER( | |||
| Softmax, 1, | |||
| (mgb::serialization::OprLoadDumpImplV2<opr::Softmax, 1>::replace_opr)); | |||
| //! operators whose param is stored together with an execution policy | |||
| SERGE_OPR_V2_NO_CONVERTER(ConvBiasForward, 0) | |||
| SERGE_OPR_V2_NO_CONVERTER(BatchConvBiasForward, 0); | |||
| SERGE_OPR_V2_NO_CONVERTER(Convolution, 0) | |||
| SERGE_OPR_V2_NO_CONVERTER(ConvolutionBackwardData, 0) | |||
| SERGE_OPR_V2_NO_CONVERTER(ConvolutionBackwardFilter, 0) | |||
| SERGE_OPR_V2_NO_CONVERTER(Convolution3D, 0); | |||
| SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardData, 0); | |||
| SERGE_OPR_V2_NO_CONVERTER(Convolution3DBackwardFilter, 0); | |||
| SERGE_OPR_V2_NO_CONVERTER(LocalShareForward, 0); | |||
| SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardData, 0); | |||
| SERGE_OPR_V2_NO_CONVERTER(LocalShareBackwardFilter, 0); | |||
| SERGE_OPR_V2_NO_CONVERTER(DeformableConvForward, 0); | |||
| SERGE_OPR_V2_NO_CONVERTER(DeformableConvBackwardData, 0); | |||
| SERGE_OPR_V2_NO_CONVERTER(DeformableConvBackwardFilter, 0); | |||
| //! the helper macros are local to this registration list | |||
| #undef SERGE_OPR_V2_CONVERTER | |||
| #undef SERGE_OPR_V2_NO_CONVERTER | |||
| } // namespace opr | |||
| } // namespace mgb | |||
| #endif | |||
| // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,197 @@ | |||
| #if MGB_ENABLE_FBS_SERIALIZATION | |||
| #include "megbrain/comp_node_env.h" | |||
| #include "megbrain/opr/dnn/softmax.h" | |||
| #include "megbrain/opr/io.h" | |||
| #include "megbrain/serialization/oss_opr_load_dump.h" | |||
| #include "megbrain/serialization/sereg.h" | |||
| #include "megbrain/serialization/internal/mgb_cpp_opr_generated.h" | |||
| #include "megbrain/serialization/internal/schema_v2_generated.h" | |||
| namespace mgb { | |||
| namespace serialization { | |||
| //! Versioned load/dump of opr::ImmutableTensor: the value is stored as one | |||
| //! anonymous tensor. | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::ImmutableTensor, 0> { | |||
| using Opr = opr::ImmutableTensor; | |||
| //! copy the operator value into a host tensor, sync it and dump it with | |||
| //! VALUE_ANONYMOUS | |||
| static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { | |||
| using Meth = OprDumpContext::TensorWriteMethod; | |||
| auto&& opr = opr_.cast_final_safe<Opr>(); | |||
| ctx.dump_tensor( | |||
| {}, HostTensorND{}.copy_from(opr.value()).sync(), | |||
| Meth::VALUE_ANONYMOUS); | |||
| } | |||
| //! rebuild the operator from the first stored tensor; throws | |||
| //! SerializationError when the record is missing or carries no tensor | |||
| static cg::OperatorNodeBase* load( | |||
| OprLoadContext& ctx, const cg::VarNodeArray& inputs, | |||
| const OperatorNodeConfig& config) { | |||
| auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx); | |||
| mgb_assert(inputs.empty()); | |||
| auto fopr = reinterpret_cast<const fbs::v2::Operator*>( | |||
| fbs_ctx.get_current_opr_data()); | |||
| //! a null record is treated like a record without tensors instead of | |||
| //! being dereferenced | |||
| if (fopr && fopr->tensors() && fopr->tensors()->size() > 0) { | |||
| auto val = fbs_ctx.load_tensor(); | |||
| return Opr::make(fbs_ctx.graph(), *val, config).node()->owner_opr(); | |||
| } else { | |||
| mgb_throw(SerializationError, "ImmutableTensor load with no tensor data."); | |||
| } | |||
| } | |||
| }; | |||
| //! Versioned load/dump of opr::Host2DeviceCopy. | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::Host2DeviceCopy, 0> { | |||
| using Opr = opr::Host2DeviceCopy; | |||
| //! write the param, then dump the host tensor under the operator name: | |||
| //! as VALUE_INPUT when param().dump_default_value is set, otherwise as | |||
| //! META_INPUT | |||
| static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { | |||
| using Meth = OprDumpContext::TensorWriteMethod; | |||
| auto&& h2d = opr_.cast_final_safe<Opr>(); | |||
| ctx.write_param(h2d.param()); | |||
| Meth method = Meth::META_INPUT; | |||
| if (h2d.param().dump_default_value) { | |||
| method = Meth::VALUE_INPUT; | |||
| } | |||
| ctx.dump_tensor(h2d.name(), *h2d.host_data(), method); | |||
| } | |||
| //! read the param at index 0, load the tensor and rebuild the operator; | |||
| //! the operator takes no inputs | |||
| static cg::OperatorNodeBase* load( | |||
| OprLoadContext& ctx, const cg::VarNodeArray& inputs, | |||
| const OperatorNodeConfig& config) { | |||
| mgb_assert(inputs.empty()); | |||
| auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx); | |||
| auto h2d_param = fbs_ctx.read_param<Opr::Param>(0); | |||
| auto host_tensor = fbs_ctx.load_tensor(); | |||
| auto h2d_var = Opr::make(fbs_ctx.graph(), host_tensor, h2d_param, config); | |||
| return h2d_var.node()->owner_opr(); | |||
| } | |||
| }; | |||
| //! Versioned load/dump of opr::SharedDeviceTensorWithFormat. | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::SharedDeviceTensorWithFormat, 0> { | |||
| //! copy the device tensor to the host, sync, and dump it as an anonymous | |||
| //! value | |||
| static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { | |||
| using Meth = OprDumpContext::TensorWriteMethod; | |||
| auto&& shared_opr = opr_.cast_final_safe<opr::SharedDeviceTensorWithFormat>(); | |||
| HostTensorND host_copy; | |||
| host_copy.copy_from(shared_opr.get_dev_tensor()).sync(); | |||
| ctx.dump_tensor({}, host_copy, Meth::VALUE_ANONYMOUS); | |||
| } | |||
| //! load the tensor, copy it into a device tensor with the same comp node | |||
| //! and layout, build the operator, then sync the device tensor | |||
| static cg::OperatorNodeBase* load( | |||
| OprLoadContext& ctx, const cg::VarNodeArray& inputs, | |||
| const OperatorNodeConfig& config) { | |||
| mgb_assert(inputs.empty()); | |||
| auto loaded = ctx.load_tensor(); | |||
| auto device_value = std::make_shared<DeviceTensorND>( | |||
| loaded->comp_node(), loaded->layout()); | |||
| device_value->copy_from_fixlayout(*loaded); | |||
| auto shared_var = opr::SharedDeviceTensorWithFormat::make( | |||
| ctx.graph(), device_value, config); | |||
| device_value->sync(); | |||
| return shared_var.node()->owner_opr(); | |||
| } | |||
| }; | |||
| //! Versioned load/dump of opr::MultipleDeviceTensorHolder. | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::MultipleDeviceTensorHolder, 0> { | |||
| using Opr = opr::MultipleDeviceTensorHolder; | |||
| //! dump every held value as a VALUE_SHARED tensor named after the matching | |||
| //! output var | |||
| static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { | |||
| using Meth = OprDumpContext::TensorWriteMethod; | |||
| auto&& holder = opr_.cast_final_safe<Opr>(); | |||
| const uint32_t value_count = holder.values().size(); | |||
| for (uint32_t idx = 0; idx != value_count; ++idx) { | |||
| HostTensorND host_copy; | |||
| host_copy.copy_from(*holder.values()[idx]).sync(); | |||
| ctx.dump_tensor(holder.output(idx)->name(), host_copy, Meth::VALUE_SHARED); | |||
| } | |||
| } | |||
| //! load as many shared tensors as the operator record lists (none when the | |||
| //! record or its tensor list is absent) and rebuild the holder | |||
| static cg::OperatorNodeBase* load( | |||
| OprLoadContext& ctx, const cg::VarNodeArray& inputs, | |||
| const OperatorNodeConfig& config) { | |||
| mgb_assert(inputs.empty()); | |||
| auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx); | |||
| auto fopr = reinterpret_cast<const fbs::v2::Operator*>( | |||
| fbs_ctx.get_current_opr_data()); | |||
| const bool has_tensors = fopr && fopr->tensors(); | |||
| uint32_t tensor_count = 0; | |||
| if (has_tensors) { | |||
| tensor_count = fopr->tensors()->size(); | |||
| } | |||
| Opr::ValueArray values(tensor_count); | |||
| for (auto&& slot : values) { | |||
| slot = ctx.load_tensor_shared(); | |||
| } | |||
| auto outputs = Opr::make(ctx.graph(), std::move(values), config); | |||
| return outputs[0].node()->owner_opr(); | |||
| } | |||
| }; | |||
| //! Versioned load/dump of opr::MultipleDeviceTensorWithFormatHolder. | |||
| template <> | |||
| struct OprLoadDumpImplV2<opr::MultipleDeviceTensorWithFormatHolder, 0> { | |||
| using Opr = opr::MultipleDeviceTensorWithFormatHolder; | |||
| //! dump every held value as a VALUE_SHARED tensor named after the matching | |||
| //! output var | |||
| static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { | |||
| using Meth = OprDumpContext::TensorWriteMethod; | |||
| auto&& holder = opr_.cast_final_safe<Opr>(); | |||
| const uint32_t value_count = holder.values().size(); | |||
| for (uint32_t idx = 0; idx != value_count; ++idx) { | |||
| auto dev_value = *holder.values()[idx]; | |||
| HostTensorND host_copy; | |||
| host_copy.copy_from(dev_value).sync(); | |||
| ctx.dump_tensor(holder.output(idx)->name(), host_copy, Meth::VALUE_SHARED); | |||
| } | |||
| } | |||
| //! load as many shared tensors as the operator record lists, re-attach the | |||
| //! layout (with format) to each of them, and rebuild the holder | |||
| static cg::OperatorNodeBase* load( | |||
| OprLoadContext& ctx, const cg::VarNodeArray& inputs, | |||
| const OperatorNodeConfig& config) { | |||
| mgb_assert(inputs.empty()); | |||
| auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx); | |||
| auto fopr = reinterpret_cast<const fbs::v2::Operator*>( | |||
| fbs_ctx.get_current_opr_data()); | |||
| const bool has_tensors = fopr && fopr->tensors(); | |||
| uint32_t tensor_count = 0; | |||
| if (has_tensors) { | |||
| tensor_count = fopr->tensors()->size(); | |||
| } | |||
| Opr::ValueArray values(tensor_count); | |||
| for (auto&& tensor : values) { | |||
| tensor = ctx.load_tensor_shared(); | |||
| //! set tensor format | |||
| TensorLayout layout_with_format = tensor->layout(); | |||
| const bool on_default_cpu = tensor->storage().comp_node().mem_node() == | |||
| CompNode::default_cpu().mem_node(); | |||
| if (!on_default_cpu) { | |||
| //! only the layout of this tensor is used later, see | |||
| //! src/serialization/impl/batched_device_value_loader.cpp:49. There | |||
| //! is no way to reset the layout alone, so an invalid storage of the | |||
| //! matching size is constructed instead | |||
| auto size = layout_with_format.span().dist_byte(); | |||
| DeviceTensorStorage storage; | |||
| storage.reset(tensor->comp_node(), size, nullptr); | |||
| tensor->reset(storage, layout_with_format); | |||
| continue; | |||
| } | |||
| mgb_assert( | |||
| tensor->storage().ptr(), | |||
| "storage should not be nullptr if mem_node is " | |||
| "default_cpu"); | |||
| HostTensorND src{tensor->storage().comp_node(), layout_with_format}; | |||
| src.copy_from_fixlayout(*tensor).sync(); | |||
| *tensor = DeviceTensorND::make_proxy(src); | |||
| } | |||
| auto outputs = Opr::make(ctx.graph(), std::move(values), config); | |||
| return outputs[0].node()->owner_opr(); | |||
| } | |||
| }; | |||
| } // namespace serialization | |||
| namespace opr { | |||
| //! register \p _cls for VERSION_2..CURRENT_VERSION without a converter | |||
| #define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \ | |||
| MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION); | |||
| SERGE_OPR_V2_NO_CONVERTER(ImmutableTensor, 0); | |||
| SERGE_OPR_V2_NO_CONVERTER(Host2DeviceCopy, 0); | |||
| SERGE_OPR_V2_NO_CONVERTER(SharedDeviceTensorWithFormat, 0); | |||
| SERGE_OPR_V2_NO_CONVERTER(MultipleDeviceTensorWithFormatHolder, 0); | |||
| SERGE_OPR_V2_NO_CONVERTER(MultipleDeviceTensorHolder, 0); | |||
| //! do not leak the helper macro past this registration list | |||
| #undef SERGE_OPR_V2_NO_CONVERTER | |||
| } // namespace opr | |||
| } // namespace mgb | |||
| #endif | |||
| // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -135,6 +135,16 @@ void LoopSerializer::reg_all() { | |||
| MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker); | |||
| MGB_SEREG_OPR_INTL_CALL_ADD( | |||
| CounterProvider, dump_counter_provider, load_counter_provider); | |||
| MGB_SEREG_OPR_INTL_CALL_ADD_V2( | |||
| opr::Loop, dump_loop, load_loop, nullptr, 2, | |||
| CURRENT_VERSION); | |||
| MGB_SEREG_OPR_INTL_CALL_ADD_V2( | |||
| InputMaker, dump_input_maker, load_input_maker, nullptr, 2, | |||
| CURRENT_VERSION); | |||
| MGB_SEREG_OPR_INTL_CALL_ADD_V2( | |||
| CounterProvider, dump_counter_provider, load_counter_provider, nullptr, 2, | |||
| CURRENT_VERSION); | |||
| } | |||
| void LoopSerializer::dump_loop(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { | |||
| @@ -20,6 +20,11 @@ struct StaticData { | |||
| ThinHashMap<Typeinfo*, OprRegistry*> type2reg; | |||
| std::unordered_map<std::string, OprRegistry*> name2reg; | |||
| ThinHashMap<size_t, OprRegistry*> unversioned_id2reg; | |||
| //! versioned OprRegistryV2, version_id_reg_map is used for Operator | |||
| //! load/shallow copy and version_type_reg_map is used for Operator dump | |||
| ThinHashMap<uint8_t, ThinHashMap<size_t, OprRegistryV2>> version_id_reg_map; | |||
| ThinHashMap<uint8_t, ThinHashMap<Typeinfo*, OprRegistryV2*>> version_type_reg_map; | |||
| }; | |||
| StaticData& static_data() { | |||
| @@ -47,6 +52,20 @@ const OprRegistry* dynamic_registry() { | |||
| return ret; | |||
| } | |||
| //! Return the versioned registry entry stored under the "dynamic" id, | |||
| //! registering it (with dynamic_loader, for CURRENT_VERSION only) on first | |||
| //! use; the pointer is cached in a function-local static. | |||
| const OprRegistryV2* dynamic_registry_v2() { | |||
| static const OprRegistryV2* cached = nullptr; | |||
| if (!cached) { | |||
| auto dynamic_id = MGB_HASH_STR("dynamic"); | |||
| OprRegistryV2::versioned_add( | |||
| {nullptr, dynamic_id, {}, {}, dynamic_loader, {}}, CURRENT_VERSION, | |||
| CURRENT_VERSION); | |||
| cached = OprRegistryV2::versioned_find_by_id(dynamic_id, CURRENT_VERSION); | |||
| mgb_assert(cached); | |||
| } | |||
| return cached; | |||
| } | |||
| class _Init { | |||
| public: | |||
| _Init() { | |||
| @@ -63,8 +82,7 @@ void OprRegistry::add(const OprRegistry& record) { | |||
| auto registry_ins = sd.id2reg.emplace(persist_id, record); | |||
| mgb_assert( | |||
| registry_ins.second || persist_id == dynamic_registry()->persist_type_id, | |||
| "duplicated operator persist_type_id: %s", | |||
| std::to_string(persist_id).c_str()); | |||
| "duplicated operator name : %s", record.name.c_str()); | |||
| OprRegistry* persis_record_ptr; | |||
| if (registry_ins.second) { | |||
| @@ -129,6 +147,73 @@ const OprRegistry* OprRegistry::find_by_unversioned_id(size_t unversioned_id) { | |||
| return iter == uid2reg.end() ? nullptr : iter->second; | |||
| } | |||
| //! Look up the registry whose version equals \p version exactly; returns | |||
| //! nullptr when either the version or the id is not registered. | |||
| const OprRegistryV2* OprRegistryV2::versioned_find_by_id( | |||
| const size_t id, uint8_t version) { | |||
| auto&& version_map = static_data().version_id_reg_map; | |||
| auto version_it = version_map.find(version); | |||
| if (version_it == version_map.end()) { | |||
| return nullptr; | |||
| } | |||
| auto&& id_map = version_it->second; | |||
| auto found = id_map.find(id); | |||
| if (found == id_map.end()) { | |||
| return nullptr; | |||
| } | |||
| return &found->second; | |||
| } | |||
| //! Look up the registry of \p type at the highest registered version that is | |||
| //! not above \p version; versions are probed from \p version down to 1. | |||
| //! Returns nullptr when none of them has an entry for \p type. | |||
| const OprRegistryV2* OprRegistryV2::versioned_find_by_typeinfo( | |||
| Typeinfo* type, uint8_t version) { | |||
| const auto& version_map = static_data().version_type_reg_map; | |||
| for (int candidate = version; candidate > 0; --candidate) { | |||
| auto version_it = version_map.find(candidate); | |||
| if (version_it == version_map.end()) { | |||
| continue; | |||
| } | |||
| const auto& type_map = version_it->second; | |||
| auto found = type_map.find(type); | |||
| if (found != type_map.end()) { | |||
| return found->second; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| //! Register \p record for every version in [min_version, max_version]. | |||
| //! For each version a copy of the record (with its version field set) is | |||
| //! stored under record.type_id and, when it differs, also under the hash of | |||
| //! record.type->name; the type -> registry map of that version is updated to | |||
| //! point at the entry stored under record.type_id. | |||
| //! Asserts on duplicated ids, except for the id of the dynamic registry, | |||
| //! whose entry is overwritten (and must not carry a converter). | |||
| void OprRegistryV2::versioned_add( | |||
| const OprRegistryV2& record, uint8_t min_version, uint8_t max_version) { | |||
| mgb_assert(max_version >= min_version); | |||
| auto&& sd = static_data(); | |||
| auto id = record.type_id; | |||
| uint64_t type_id = id; | |||
| //! record.type->name is nullptr when MGB_VERBOSE_TYPEINFO_NAME==0 | |||
| if (record.type && record.type->name) { | |||
| type_id = MGB_HASH_RUNTIME(std::string(record.type->name)); | |||
| } | |||
| //! the counter is wider than uint8_t on purpose: a uint8_t counter would | |||
| //! wrap around to 0 and never terminate when max_version == 255 | |||
| for (unsigned int version_idx = min_version; version_idx <= max_version; | |||
| ++version_idx) { | |||
| const uint8_t version = static_cast<uint8_t>(version_idx); | |||
| auto&& registry_map = sd.version_id_reg_map[version]; | |||
| auto versioned_record = record; | |||
| versioned_record.version = version; | |||
| mgb_assert( | |||
| registry_map.find(id) == registry_map.end() || | |||
| id == dynamic_registry_v2()->type_id, | |||
| "duplicated OprRegistryV2 of %s\n", record.name.c_str()); | |||
| auto registry_ins = registry_map.emplace(id, versioned_record); | |||
| if (!registry_ins.second) { | |||
| //! the registry is dynamic | |||
| mgb_assert(!record.converter); | |||
| registry_map[id] = versioned_record; | |||
| } | |||
| //! sometimes the register id and the hash typeinfo is not same, just as | |||
| //! dynamic Operator | |||
| if (id != type_id) { | |||
| mgb_assert( | |||
| registry_map.find(type_id) == registry_map.end(), | |||
| "duplicated OprRegistryV2 of %s\n", record.name.c_str()); | |||
| registry_map.emplace(type_id, versioned_record); | |||
| } | |||
| auto&& registry_type_map = sd.version_type_reg_map[version]; | |||
| auto& stored_record = registry_map[id]; | |||
| registry_type_map.emplace(record.type, &stored_record); | |||
| } | |||
| } | |||
| void OprRegistry::add_using_dynamic_loader( | |||
| Typeinfo* type, const std::string& name, const OprDumper& dumper) { | |||
| // dynamic oprs are implemented by mapping different opr types to the same | |||
| @@ -140,6 +225,11 @@ void OprRegistry::add_using_dynamic_loader( | |||
| {}, | |||
| {}, | |||
| dynamic_registry()->unversioned_type_id}); | |||
| mgb_assert(type, "type must be not nullptr"); | |||
| OprRegistryV2::versioned_add( | |||
| {type, dynamic_registry_v2()->type_id, type->name, dumper, | |||
| dynamic_registry_v2()->loader, nullptr}, | |||
| CURRENT_VERSION, CURRENT_VERSION); | |||
| } | |||
| #if MGB_ENABLE_DEBUG_UTIL | |||
| @@ -9,10 +9,12 @@ void call_sereg() {} | |||
| #include "../../opr/impl/blas.sereg.h" | |||
| #include "../../opr/impl/cond.sereg.h" | |||
| #include "../../opr/impl/dnn/dnn.sereg.h" | |||
| #include "../../opr/impl/dnn/dnn.sereg.v2.h" | |||
| #include "./extern_c_opr.sereg.h" | |||
| #include "../../opr/impl/imgproc.sereg.h" | |||
| #include "../../opr/impl/indexing.sereg.h" | |||
| #include "../../opr/impl/io.sereg.h" | |||
| #include "../../opr/impl/io.sereg.v2.h" | |||
| #include "../../opr/impl/loop/forward.sereg.h" | |||
| #include "../../opr/impl/loop/grad.sereg.h" | |||
| #include "../../opr/impl/misc.sereg.h" | |||
| @@ -53,7 +53,6 @@ struct OprRegistry { | |||
| uint64_t unversioned_type_id; | |||
| MGE_WIN_DECLSPEC_FUC static void add(const OprRegistry& record); | |||
| /*! | |||
| * \brief register an operator to use dynamic loader | |||
| * | |||
| @@ -89,6 +88,39 @@ struct OprRegistry { | |||
| #endif | |||
| }; | |||
| //! Callback that converts a modified operator into a compatible one: it | |||
| //! receives the operator and a VarNodeArray and returns an operator | |||
| using OprConvertToCompatible = thin_function<cg::OperatorNodeBase*( | |||
| cg::OperatorNodeBase*, const VarNodeArray&)>; | |||
| //! versioned record of a single operator; it is brace-initialized | |||
| //! positionally by the registration code, so the member order matters | |||
| struct OprRegistryV2 { | |||
| //! type info of the operator (nullptr for the dynamic registry entry) | |||
| Typeinfo* type; | |||
| //! id under which the record is stored in the per-version id map | |||
| uint64_t type_id; | |||
| //! operator name, used in duplicate-registration messages | |||
| std::string name; | |||
| OprDumper dumper; | |||
| OprLoaderWrapper loader; | |||
| //! optional converter; may be empty / nullptr | |||
| OprConvertToCompatible converter; | |||
| //! overwritten by versioned_add with the version the copy is stored for | |||
| uint8_t version = 2; | |||
| MGE_WIN_DECLSPEC_FUC uint8_t get_version() const { return version; } | |||
| //! register opr load/dump to version2regmap | |||
| MGE_WIN_DECLSPEC_FUC static void versioned_add( | |||
| const OprRegistryV2& record, uint8_t min_version, uint8_t max_version); | |||
| //! find the record registered for exactly \p version, or nullptr | |||
| MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_id( | |||
| const size_t id, uint8_t version); | |||
| //! find the record of \p type at \p version or the nearest lower version, | |||
| //! or nullptr | |||
| MGE_WIN_DECLSPEC_FUC static const OprRegistryV2* versioned_find_by_typeinfo( | |||
| Typeinfo* type, uint8_t version); | |||
| #if MGB_ENABLE_DEBUG_UTIL | |||
| //! dump registered oprs | |||
| MGE_WIN_DECLSPEC_FUC static std::vector<std::pair<size_t, std::string>> | |||
| dump_registries(); | |||
| #endif | |||
| }; | |||
| } // namespace serialization | |||
| } // namespace mgb | |||
| @@ -3,6 +3,7 @@ | |||
| #include "megbrain/serialization/opr_load_dump.h" | |||
| #include "megbrain/serialization/opr_registry.h" | |||
| #include "megbrain/serialization/opr_shallow_copy.h" | |||
| #include "megbrain/serialization/oss_opr_load_dump.h" | |||
| #include "megbrain/utils/hash_ct.h" | |||
| namespace mgb { | |||
| @@ -66,6 +67,9 @@ struct OprLoadDumpImpl { | |||
| } | |||
| }; | |||
| //! versioned load/dump implementation; by default it inherits | |||
| //! OprLoadDumpImpl<Opr, arity> unchanged, and can be specialized per operator | |||
| template <class Opr, size_t arity> | |||
| struct OprLoadDumpImplV2 : public OprLoadDumpImpl<Opr, arity> {}; | |||
| #define IMPL_OPR_MAKER(_arity, _args...) \ | |||
| template <class Opr> \ | |||
| struct OprMaker<Opr, _arity> { \ | |||
| @@ -124,6 +128,12 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; | |||
| __caller_OprReg##_cls##_ins; \ | |||
| } | |||
| //! define an OprRegistryCaller<_cls, _impl> object in an anonymous namespace | |||
| //! for the versioned registration of \p _cls; presumably the caller invokes | |||
| //! _impl::entry() -- confirm against OprRegistryCaller | |||
| #define MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _impl) \ | |||
| namespace { \ | |||
| [[gnu::unused]] ::mgb::serialization::OprRegistryCaller<_cls, _impl> \ | |||
| __caller_V2_OprReg##_cls##_ins; \ | |||
| } | |||
| // Trim the terminating null character and a "V0" like suffix from the string | |||
| // then hash it. | |||
| // TODO: Get rid of this. | |||
| @@ -138,17 +148,35 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; | |||
| : 0), \ | |||
| 20160701)>::val | |||
| //! call OprRegistry::add | |||
| #define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \ | |||
| do { \ | |||
| ::mgb::serialization::OprRegistry::add( \ | |||
| {_cls::typeinfo(), \ | |||
| MGB_HASH_STR(#_cls), \ | |||
| _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \ | |||
| _dump, \ | |||
| _load, \ | |||
| {}, \ | |||
| MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \ | |||
| //! call OprRegistry::add for the old serialization, and also call | |||
| //! OprRegistryV2::versioned_add so that the new serialization stays compatible | |||
| //! with the old one; the converter is nullptr and this registry entry is only | |||
| //! for VERSION_1 | |||
| #define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \ | |||
| do { \ | |||
| ::mgb::serialization::OprRegistry::add( \ | |||
| {_cls::typeinfo(), \ | |||
| MGB_HASH_STR(#_cls), \ | |||
| _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \ | |||
| _dump, \ | |||
| _load, \ | |||
| {}, \ | |||
| MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \ | |||
| ::mgb::serialization::OprRegistryV2::versioned_add( \ | |||
| {_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \ | |||
| _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \ | |||
| ::mgb::VERSION_1, ::mgb::VERSION_1); \ | |||
| } while (0) | |||
| //! call OprRegistryV2::versioned_add for the new serialization; _convert is | |||
| //! the function that converts the operator to a compatible one (may be nullptr) | |||
| #define MGB_SEREG_OPR_INTL_CALL_ADD_V2( \ | |||
| _cls, _dump, _load, _convert, _version_min, _version_max) \ | |||
| do { \ | |||
| ::mgb::serialization::OprRegistryV2::versioned_add( \ | |||
| {_cls::typeinfo(), MGB_HASH_STR(#_cls), \ | |||
| _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, _convert}, \ | |||
| _version_min, _version_max); \ | |||
| } while (0) | |||
| /*! | |||
| @@ -171,6 +199,27 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; | |||
| } \ | |||
| MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) | |||
| //! new dump/load functions should be implemented in OprLoadDumpImplV2; | |||
| //! _converter is optional, pass nullptr when there is none | |||
| #define MGB_SEREG_OPR_V2(_cls, _arity, _converter, _version_min, _version_max) \ | |||
| namespace { \ | |||
| namespace ser = ::mgb::serialization; \ | |||
| struct _OprRegV2##_cls { \ | |||
| using Impl = ser::OprLoadDumpImplV2<_cls, _arity>; \ | |||
| static ser::OprWithOutputAccessor wrap_loader( \ | |||
| ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | |||
| const mgb::cg::OperatorNodeConfig& config) { \ | |||
| return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \ | |||
| } \ | |||
| static void entry() { \ | |||
| MGB_SEREG_OPR_INTL_CALL_ADD_V2( \ | |||
| _cls, Impl::dump, wrap_loader, _converter, _version_min, \ | |||
| _version_max); \ | |||
| } \ | |||
| }; \ | |||
| } \ | |||
| MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(_cls, _OprRegV2##_cls) | |||
| //! use to check type is complete or not, midout need a complete type | |||
| template <class T, class = void> | |||
| struct IsComplete : std::false_type {}; | |||