GitOrigin-RevId: ee58271276
tags/v1.11.0
| @@ -40,16 +40,16 @@ mgb::cg::OperatorNodeBase* custom_loader( | |||
| } // namespace serialization | |||
| } // namespace mgb | |||
| #define CUSTOM_OP_SEREG_REG(cls) \ | |||
| namespace { \ | |||
| struct _OprReg##cls { \ | |||
| static void entry() { \ | |||
| MGB_SEREG_OPR_INTL_CALL_ADD( \ | |||
| cls, ::mgb::serialization::custom_dumper, \ | |||
| ::mgb::serialization::custom_loader); \ | |||
| } \ | |||
| }; \ | |||
| } \ | |||
// Registers serialization for a custom-op class `cls`.
// Expands to a file-local struct _OprReg##cls whose entry() forwards
// ::mgb::serialization::custom_dumper / custom_loader to
// MGB_SEREG_OPR_INTL_CALL_ADD, and then passes that struct to
// MGB_SEREG_OPR_INTL_CALL_ENTRY. The trailing `true` is the fourth
// (_registerv2) argument of MGB_SEREG_OPR_INTL_CALL_ADD.
| #define CUSTOM_OP_SEREG_REG(cls) \ | |||
| namespace { \ | |||
| struct _OprReg##cls { \ | |||
| static void entry() { \ | |||
| MGB_SEREG_OPR_INTL_CALL_ADD( \ | |||
| cls, ::mgb::serialization::custom_dumper, \ | |||
| ::mgb::serialization::custom_loader, true); \ | |||
| } \ | |||
| }; \ | |||
| } \ | |||
| MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls) | |||
| #define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max) \ | |||
| @@ -131,10 +131,10 @@ cg::OperatorNodeBase* serialization::opr_shallow_copy_loop( | |||
| } | |||
| void LoopSerializer::reg_all() { | |||
| MGB_SEREG_OPR_INTL_CALL_ADD(opr::Loop, dump_loop, load_loop); | |||
| MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker); | |||
| MGB_SEREG_OPR_INTL_CALL_ADD(opr::Loop, dump_loop, load_loop, true); | |||
| MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker, true); | |||
| MGB_SEREG_OPR_INTL_CALL_ADD( | |||
| CounterProvider, dump_counter_provider, load_counter_provider); | |||
| CounterProvider, dump_counter_provider, load_counter_provider, true); | |||
| MGB_SEREG_OPR_INTL_CALL_ADD_V2( | |||
| opr::Loop, dump_loop, load_loop, nullptr, 2, CURRENT_VERSION); | |||
| @@ -1,3 +1,4 @@ | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/nn_int.h" | |||
| #include "megbrain/serialization/sereg.h" | |||
| @@ -7,10 +8,74 @@ template <> | |||
| struct OprMaker<opr::ElemwiseMultiType, 0> | |||
| : public OprMakerVariadic<opr::ElemwiseMultiType> {}; | |||
// V2 load/dump implementation for opr::ElemwiseMultiType.
// Provides three static hooks:
//   - dump():        writes the operator's Param to the dump context.
//   - replace_opr(): rebuilds the comparison/ISNAN/ISINF modes out of
//                    opr::Elemwise + opr::TypeCvt operators.
//   - load():        reads the Param back and builds the operator through
//                    OprMaker<opr::ElemwiseMultiType, 0>.
| template <> | |||
| struct OprLoadDumpImplV2<opr::ElemwiseMultiType, 0> { | |||
| using Opr = opr::ElemwiseMultiType; | |||
| using PersisParam = opr::ElemwiseMultiType::Param; | |||
| using PersisElemwseiParam = opr::Elemwise::Param; | |||
// Serializes `opr` (must be an ElemwiseMultiType; cast_final_safe checks
// this) by writing its Param.
| static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { | |||
| ctx.write_param<PersisParam>(opr.cast_final_safe<Opr>().param()); | |||
| } | |||
// Builds a replacement for `opr` from `inputs`, depending on its mode:
//   EQ / LT / LEQ : Elemwise(same-named mode) -> TypeCvt(Bool)
//   NEQ           : Elemwise(EQ) -> TypeCvt(Bool) -> Elemwise(NOT)
//   ISNAN         : Elemwise(EQ) of inputs[0] with itself -> TypeCvt(Bool)
//                   -> Elemwise(NOT)
//   ISINF         : TypeCvt(Float32) of inputs[0], Elemwise(EQ) against a
//                   scalar INFINITY -> TypeCvt(Bool)
// Any other mode returns `opr` unchanged.
// Returns the owner operator of the final output var.
| static cg::OperatorNodeBase* replace_opr( | |||
| cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { | |||
| auto mode = opr->cast_final_safe<Opr>().param().mode; | |||
// Maps the three comparison modes of ElemwiseMultiType to the Elemwise
// modes of the same name.
// NOTE(review): no return statement follows the mgb_assert below;
// presumably mgb_assert(0, ...) never returns -- confirm.
| auto change_to_elemwise_mode = [&](PersisParam::Mode multitype_mode) { | |||
| if (multitype_mode == PersisParam::Mode::EQ) { | |||
| return PersisElemwseiParam::Mode::EQ; | |||
| } else if (multitype_mode == PersisParam::Mode::LT) { | |||
| return PersisElemwseiParam::Mode::LT; | |||
| } else if (multitype_mode == PersisParam::Mode::LEQ) { | |||
| return PersisElemwseiParam::Mode::LEQ; | |||
| } | |||
| mgb_assert(0, "no supported model."); | |||
| }; | |||
| if (PersisParam::Mode::EQ == mode || PersisParam::Mode::LT == mode || | |||
| PersisParam::Mode::LEQ == mode) { | |||
| auto elemwise_mode = change_to_elemwise_mode(mode); | |||
| auto elemiwse_out = opr::Elemwise::make(inputs, {elemwise_mode}); | |||
| return opr::TypeCvt::make(elemiwse_out, dtype::Bool()).node()->owner_opr(); | |||
| } else if (PersisParam::Mode::NEQ == mode) { | |||
// NEQ is expressed as NOT(EQ(inputs)).
| auto elemiwse_out = | |||
| opr::Elemwise::make(inputs, {PersisElemwseiParam::Mode::EQ}); | |||
| auto bool_out = opr::TypeCvt::make(elemiwse_out, dtype::Bool()); | |||
| return opr::Elemwise::make({bool_out}, {PersisElemwseiParam::Mode::NOT}) | |||
| .node() | |||
| ->owner_opr(); | |||
| } else if (PersisParam::Mode::ISNAN == mode) { | |||
// ISNAN is expressed as NOT(EQ(x, x)), with x = inputs[0].
| auto elemiwse_out = opr::Elemwise::make( | |||
| {inputs[0], inputs[0]}, {PersisElemwseiParam::Mode::EQ}); | |||
| auto bool_out = opr::TypeCvt::make(elemiwse_out, dtype::Bool()); | |||
| return opr::Elemwise::make({bool_out}, {PersisElemwseiParam::Mode::NOT}) | |||
| .node() | |||
| ->owner_opr(); | |||
| } else if (PersisParam::Mode::ISINF == mode) { | |||
// ISINF is expressed as EQ(float32(x), INFINITY).
// NOTE(review): the comparison is against positive INFINITY only;
// confirm whether -inf inputs must also be reported as infinite.
| auto input_var = SymbolVar{inputs[0]}; | |||
| auto inf_var = input_var.make_scalar(INFINITY); | |||
| auto float_out = opr::TypeCvt::make(inputs[0], dtype::Float32()); | |||
| auto elemiwse_out = opr::Elemwise::make( | |||
| {float_out, inf_var}, {PersisElemwseiParam::Mode::EQ}); | |||
| return opr::TypeCvt::make(elemiwse_out, dtype::Bool()).node()->owner_opr(); | |||
| } | |||
// Modes not handled above are left as they are.
| return opr; | |||
| } | |||
// Reads the Param written by dump() and constructs the operator via
// OprMaker<opr::ElemwiseMultiType, 0>::make in ctx.graph().
| static cg::OperatorNodeBase* load( | |||
| OprLoadContext& ctx, const cg::VarNodeArray& inputs, | |||
| const OperatorNodeConfig& config) { | |||
| return OprMaker<opr::ElemwiseMultiType, 0>::make( | |||
| ctx.read_param<PersisParam>(), inputs, ctx.graph(), config); | |||
| } | |||
| }; | |||
| } // namespace serialization | |||
| namespace opr { | |||
| MGB_SEREG_OPR(ElemwiseMultiType, 0); | |||
// ElemwiseMultiType is registered with _registerv2 == false, so
// MGB_SEREG_OPR_INTL_CALL_ADD performs only OprRegistry::add and skips its
// own OprRegistryV2::versioned_add call.
| MGB_SEREG_OPR_CONDITION(ElemwiseMultiType, 0, false); | |||
// Separate V2 registration that passes OprLoadDumpImplV2<...>::replace_opr
// as the converter argument, for the version range VERSION_1..VERSION_1.
| MGB_SEREG_OPR_V2( | |||
| ElemwiseMultiType, 0, | |||
| (mgb::serialization::OprLoadDumpImplV2<opr::ElemwiseMultiType, 0>::replace_opr), | |||
| VERSION_1, VERSION_1); | |||
| MGB_SEREG_OPR(AffineInt, 3); | |||
| } // namespace opr | |||
| } // namespace mgb | |||
| @@ -125,16 +125,18 @@ ComputingGraph* serialization::OprShallowCopyContext::owner_graph( | |||
| cg::OperatorNodeBase* serialization::copy_opr_shallow( | |||
| const cg::OperatorNodeBase& opr, const VarNodeArray& inputs, | |||
| const OperatorNodeConfig& config, const OprShallowCopyContext& ctx) { | |||
| auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo()); | |||
| mgb_assert( | |||
| registry, "could not find OprReceiver to copy opr %s{%s}", opr.cname(), | |||
| opr.dyn_typeinfo()->name); | |||
| OprShallowCopy shallow_copy = nullptr; | |||
| if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) { | |||
| shallow_copy = registry->shallow_copy; | |||
| } else { | |||
| shallow_copy = intl::copy_opr_shallow_default_impl; | |||
| } | |||
| mgb_assert(inputs.size() == opr.input().size()); | |||
| auto dst_og = ctx.owner_graph(opr, inputs); | |||
| auto do_copy = [&]() { | |||
| auto nr_opr_before = opr.owner_graph()->nr_oprs_in_graph(); | |||
| auto ret = registry->shallow_copy(ctx, opr, inputs, config); | |||
| auto ret = shallow_copy(ctx, opr, inputs, config); | |||
| if (dst_og != opr.owner_graph() || | |||
| opr.owner_graph()->nr_oprs_in_graph() != nr_opr_before) { | |||
| @@ -188,18 +190,28 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl( | |||
| const OprShallowCopyContext& ctx, const cg::OperatorNodeBase& opr, | |||
| const VarNodeArray& inputs, const OperatorNodeConfig& config) { | |||
| MGB_MARK_USED_VAR(ctx); | |||
| OprDumper opr_dumper = nullptr; | |||
| OprLoaderWrapper opr_loader = nullptr; | |||
| auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo()); | |||
| if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) { | |||
| opr_loader = registry->loader; | |||
| opr_dumper = registry->dumper; | |||
| } else { | |||
| auto registryv2 = OprRegistryV2::versioned_find_by_typeinfo( | |||
| opr.dyn_typeinfo(), CURRENT_VERSION); | |||
| opr_loader = registryv2->loader; | |||
| opr_dumper = registryv2->dumper; | |||
| } | |||
| mgb_assert( | |||
| registry && registry->dumper && registry->loader, | |||
| opr_dumper && opr_loader, | |||
| "can not shallow_copy operator %s{%s}: " | |||
| "no dumper/loader registered", | |||
| opr.cname(), opr.dyn_typeinfo()->name); | |||
| OprDumpContextMemory dumper; | |||
| registry->dumper(dumper, opr); | |||
| OprDumpContextMemory memory_dumper; | |||
| opr_dumper(memory_dumper, opr); | |||
| OprLoadContextMemory loader{opr.owner_graph(), dumper}; | |||
| return registry->loader(loader, inputs, config).opr(); | |||
| OprLoadContextMemory loader{opr.owner_graph(), memory_dumper}; | |||
| return opr_loader(loader, inputs, config).opr(); | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -358,6 +358,11 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump( | |||
| auto new_output_vars = output_vars; | |||
| if (!config.no_change_graph) { | |||
| new_output_vars = converter_all_opr_to_compatiable(output_vars); | |||
| mgb_assert(output_vars.size() == new_output_vars.size()); | |||
| for (size_t id = 0; id < output_vars.size(); id++) { | |||
| auto& new_var = new_output_vars[id]; | |||
| new_var.rename(output_vars[id].node()->name()); | |||
| } | |||
| } | |||
| auto begin_pos = m_file->tell(); | |||
| @@ -151,20 +151,22 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; | |||
| //! call OprRegistryV2::versioned_add for new serialization which is compatible | |||
| //! with old serialization, convert is nullptr, this registry is only for | |||
| //! version 1 | |||
| #define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \ | |||
| do { \ | |||
| ::mgb::serialization::OprRegistry::add( \ | |||
| {_cls::typeinfo(), \ | |||
| MGB_HASH_STR(#_cls), \ | |||
| _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \ | |||
| _dump, \ | |||
| _load, \ | |||
| {}, \ | |||
| MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \ | |||
| ::mgb::serialization::OprRegistryV2::versioned_add( \ | |||
| {_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \ | |||
| _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \ | |||
| ::mgb::VERSION_1, ::mgb::VERSION_1); \ | |||
// Registers dump/load functions for operator class _cls.
//   _cls        : operator class; _cls::typeinfo() and hashes of #_cls are
//                 used as registry keys.
//   _dump/_load : dumper and loader stored in the registry entry.
//   _registerv2 : when true, additionally calls
//                 OprRegistryV2::versioned_add with the same dumper/loader,
//                 nullptr as the last entry field, and the version range
//                 VERSION_1..VERSION_1. When false, only OprRegistry::add
//                 is performed.
| #define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load, _registerv2) \ | |||
| do { \ | |||
| ::mgb::serialization::OprRegistry::add( \ | |||
| {_cls::typeinfo(), \ | |||
| MGB_HASH_STR(#_cls), \ | |||
| _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \ | |||
| _dump, \ | |||
| _load, \ | |||
| {}, \ | |||
| MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \ | |||
| if (_registerv2) { \ | |||
| ::mgb::serialization::OprRegistryV2::versioned_add( \ | |||
| {_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \ | |||
| _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \ | |||
| ::mgb::VERSION_1, ::mgb::VERSION_1); \ | |||
| } \ | |||
| } while (0) | |||
| //! call OprRegistryV2::versioned_add for new serialization, in which convert the | |||
| @@ -181,23 +183,25 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; | |||
| /*! | |||
| * \brief register opr serialization methods | |||
| */ | |||
| #define MGB_SEREG_OPR(_cls, _arity) \ | |||
| namespace { \ | |||
| namespace ser = ::mgb::serialization; \ | |||
| struct _OprReg##_cls { \ | |||
| using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \ | |||
| static ser::OprWithOutputAccessor wrap_loader( \ | |||
| ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | |||
| const mgb::cg::OperatorNodeConfig& config) { \ | |||
| return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \ | |||
| } \ | |||
| static void entry() { \ | |||
| MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader); \ | |||
| } \ | |||
| }; \ | |||
| } \ | |||
// Registers serialization for operator _cls using
// ser::OprLoadDumpImpl<_cls, _arity>.
// Expands to a file-local struct _OprReg##_cls with:
//   - wrap_loader(): wraps the result of Impl::load in an
//     OprWithOutputAccessor;
//   - entry(): calls MGB_SEREG_OPR_INTL_CALL_ADD with Impl::dump,
//     wrap_loader and _registerv2.
// The struct is then passed to MGB_SEREG_OPR_INTL_CALL_ENTRY.
// _registerv2 is forwarded unchanged to MGB_SEREG_OPR_INTL_CALL_ADD.
| #define MGB_SEREG_OPR_CONDITION(_cls, _arity, _registerv2) \ | |||
| namespace { \ | |||
| namespace ser = ::mgb::serialization; \ | |||
| struct _OprReg##_cls { \ | |||
| using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \ | |||
| static ser::OprWithOutputAccessor wrap_loader( \ | |||
| ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | |||
| const mgb::cg::OperatorNodeConfig& config) { \ | |||
| return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \ | |||
| } \ | |||
| static void entry() { \ | |||
| MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader, _registerv2); \ | |||
| } \ | |||
| }; \ | |||
| } \ | |||
| MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) | |||
// Shorthand for MGB_SEREG_OPR_CONDITION with _registerv2 fixed to true.
| #define MGB_SEREG_OPR(_cls, _arity) MGB_SEREG_OPR_CONDITION(_cls, _arity, true) | |||
| //! new dump/load function should be implemented in OprLoadDumpImplV2, _converter is | |||
| //! optional; if not implemented, pass nullptr | |||
| #define MGB_SEREG_OPR_V2(_cls, _arity, _converter, _version_min, _version_max) \ | |||
| @@ -1,3 +1,4 @@ | |||
| #include "megbrain/opr/nn_int.h" | |||
| #if MGB_ENABLE_FBS_SERIALIZATION | |||
| #include "megbrain/opr/basic_arith_wrapper.h" | |||
| @@ -1016,4 +1017,107 @@ TEST(TestSerializer2, TestSoftMaxLoadDump) { | |||
| load(); | |||
| } | |||
// Round-trip test for ElemwiseMultiType through the FLATBUFFERS_V2 format.
// For each mode it builds a graph, records the reference output in
// dst_truth, dumps the graph and checks the number of dumped operators,
// then loads the file again and compares the loaded graph's output with
// dst_truth element by element.
| TEST(TestSerializer2, TestElemwiseMultiTypeLoadDump) { | |||
| auto fname = GET_OUTPUT_FILE(GraphDumpFormat::FLATBUFFERS_V2); | |||
| TensorShape shape{3}; | |||
| auto cn = CompNode::load("xpu0"); | |||
| std::shared_ptr<HostTensorND> host0 = | |||
| std::make_shared<HostTensorND>(cn, shape, dtype::Float32{}); | |||
| std::shared_ptr<HostTensorND> host1 = | |||
| std::make_shared<HostTensorND>(cn, shape, dtype::Float32{}); | |||
| HostTensorND dst_truth; | |||
// Inputs are chosen so that the three elements compare as >, == and <.
| host0->ptr<float>()[0] = 2; | |||
| host0->ptr<float>()[1] = 2; | |||
| host0->ptr<float>()[2] = -1; | |||
| host1->ptr<float>()[0] = 1; | |||
| host1->ptr<float>()[1] = 2; | |||
| host1->ptr<float>()[2] = 3; | |||
// Builds a two-input ElemwiseMultiType graph for `mode`, dumps it to fname,
// executes it to fill dst_truth, and checks that the dump reported `nr_opr`
// operators.
| auto dump = [&](opr::ElemwiseMultiType::Param::Mode mode, size_t nr_opr) { | |||
| auto graph = ComputingGraph::make(); | |||
| OperatorNodeConfig config; | |||
| config.name("input0"); | |||
| auto h2d0 = opr::Host2DeviceCopy::make(*graph, host0, config); | |||
| config.name("input1"); | |||
| auto h2d1 = opr::Host2DeviceCopy::make(*graph, host1, config); | |||
| auto x = opr::ElemwiseMultiType::make( | |||
| {h2d0, h2d1}, {mode}, OperatorNodeConfig{dtype::Bool()}); | |||
| x.rename("out"); | |||
| auto func = graph->compile({make_callback_copy(x, dst_truth)}); | |||
| auto dumper = GraphDumper::make( | |||
| OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); | |||
| auto rst = dumper->dump({x}); | |||
| func->execute().wait(); | |||
| ASSERT_EQ(rst.nr_opr, nr_opr); | |||
| }; | |||
// Loads the graph from fname, checks that it has two input tensors and an
// output named "out", feeds host0/host1 by tensor name, executes, and
// compares the result with dst_truth.
| auto load = [&]() { | |||
| auto loader = GraphLoader::make( | |||
| InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); | |||
| auto rst = loader->load(); | |||
| ASSERT_EQ(rst.tensor_map.size(), 2); | |||
| ASSERT_EQ(rst.output_var_map.count("out"), 1); | |||
| HostTensorND host_x; | |||
| auto func = | |||
| rst.graph_compile({make_callback_copy(rst.output_var_list[0], host_x)}); | |||
| for (auto& input : rst.tensor_map) { | |||
| if (input.first == "input0") { | |||
| input.second->copy_from(*host0).sync(); | |||
| } else if (input.first == "input1") { | |||
| input.second->copy_from(*host1).sync(); | |||
| } | |||
| } | |||
| func->execute().wait(); | |||
| for (int i = 0; i < 3; i++) { | |||
| EXPECT_EQ(host_x.ptr<bool>()[i], dst_truth.ptr<bool>()[i]); | |||
| } | |||
| }; | |||
// Two-input modes: EQ/LT/LEQ are expected to dump 4 operators, NEQ 5.
| dump(opr::ElemwiseMultiType::Param::Mode::EQ, 4); | |||
| load(); | |||
| dump(opr::ElemwiseMultiType::Param::Mode::LT, 4); | |||
| load(); | |||
| dump(opr::ElemwiseMultiType::Param::Mode::LEQ, 4); | |||
| load(); | |||
| dump(opr::ElemwiseMultiType::Param::Mode::NEQ, 5); | |||
| load(); | |||
// Same as `dump`, but with host0 as the only input (used for ISINF/ISNAN).
| auto dump_single_input = [&](opr::ElemwiseMultiType::Param::Mode mode, | |||
| size_t nr_opr) { | |||
| auto graph = ComputingGraph::make(); | |||
| auto h2d0 = opr::Host2DeviceCopy::make(*graph, host0); | |||
| auto x = opr::ElemwiseMultiType::make( | |||
| {h2d0}, {mode}, OperatorNodeConfig{dtype::Bool()}); | |||
| x.rename("out"); | |||
| auto func = graph->compile({make_callback_copy(x, dst_truth)}); | |||
| auto dumper = GraphDumper::make( | |||
| OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); | |||
| auto rst = dumper->dump({x}); | |||
| func->execute().wait(); | |||
| ASSERT_EQ(rst.nr_opr, nr_opr); | |||
| }; | |||
// Same as `load`, but expects exactly one input tensor and feeds it host0.
| auto load_single_input = [&]() { | |||
| auto loader = GraphLoader::make( | |||
| InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); | |||
| auto rst = loader->load(); | |||
| ASSERT_EQ(rst.tensor_map.size(), 1); | |||
| ASSERT_EQ(rst.output_var_map.count("out"), 1); | |||
| HostTensorND host_x; | |||
| auto func = | |||
| rst.graph_compile({make_callback_copy(rst.output_var_list[0], host_x)}); | |||
| rst.tensor_map.begin()->second->copy_from(*host0).sync(); | |||
| func->execute().wait(); | |||
| for (int i = 0; i < 3; i++) { | |||
| EXPECT_EQ(host_x.ptr<bool>()[i], dst_truth.ptr<bool>()[i]); | |||
| } | |||
| }; | |||
// Single-input modes: put an INFINITY / NAN into host0[2] so that the
// ISINF / ISNAN result is not all-false; both are expected to dump 4
// operators.
// NOTE(review): only positive INFINITY is exercised here; a -INFINITY
// input is not covered by this test.
| host0->ptr<float>()[2] = INFINITY; | |||
| dump_single_input(opr::ElemwiseMultiType::Param::Mode::ISINF, 4); | |||
| load_single_input(); | |||
| host0->ptr<float>()[2] = NAN; | |||
| dump_single_input(opr::ElemwiseMultiType::Param::Mode::ISNAN, 4); | |||
| load_single_input(); | |||
| } | |||
| #endif | |||