GitOrigin-RevId: ee58271276
tags/v1.11.0
| @@ -40,16 +40,16 @@ mgb::cg::OperatorNodeBase* custom_loader( | |||||
| } // namespace serialization | } // namespace serialization | ||||
| } // namespace mgb | } // namespace mgb | ||||
| #define CUSTOM_OP_SEREG_REG(cls) \ | |||||
| namespace { \ | |||||
| struct _OprReg##cls { \ | |||||
| static void entry() { \ | |||||
| MGB_SEREG_OPR_INTL_CALL_ADD( \ | |||||
| cls, ::mgb::serialization::custom_dumper, \ | |||||
| ::mgb::serialization::custom_loader); \ | |||||
| } \ | |||||
| }; \ | |||||
| } \ | |||||
| #define CUSTOM_OP_SEREG_REG(cls) \ | |||||
| namespace { \ | |||||
| struct _OprReg##cls { \ | |||||
| static void entry() { \ | |||||
| MGB_SEREG_OPR_INTL_CALL_ADD( \ | |||||
| cls, ::mgb::serialization::custom_dumper, \ | |||||
| ::mgb::serialization::custom_loader, true); \ | |||||
| } \ | |||||
| }; \ | |||||
| } \ | |||||
| MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls) | MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls) | ||||
| #define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max) \ | #define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max) \ | ||||
| @@ -131,10 +131,10 @@ cg::OperatorNodeBase* serialization::opr_shallow_copy_loop( | |||||
| } | } | ||||
| void LoopSerializer::reg_all() { | void LoopSerializer::reg_all() { | ||||
| MGB_SEREG_OPR_INTL_CALL_ADD(opr::Loop, dump_loop, load_loop); | |||||
| MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker); | |||||
| MGB_SEREG_OPR_INTL_CALL_ADD(opr::Loop, dump_loop, load_loop, true); | |||||
| MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker, true); | |||||
| MGB_SEREG_OPR_INTL_CALL_ADD( | MGB_SEREG_OPR_INTL_CALL_ADD( | ||||
| CounterProvider, dump_counter_provider, load_counter_provider); | |||||
| CounterProvider, dump_counter_provider, load_counter_provider, true); | |||||
| MGB_SEREG_OPR_INTL_CALL_ADD_V2( | MGB_SEREG_OPR_INTL_CALL_ADD_V2( | ||||
| opr::Loop, dump_loop, load_loop, nullptr, 2, CURRENT_VERSION); | opr::Loop, dump_loop, load_loop, nullptr, 2, CURRENT_VERSION); | ||||
| @@ -1,3 +1,4 @@ | |||||
| #include "megbrain/opr/basic_arith.h" | |||||
| #include "megbrain/opr/nn_int.h" | #include "megbrain/opr/nn_int.h" | ||||
| #include "megbrain/serialization/sereg.h" | #include "megbrain/serialization/sereg.h" | ||||
| @@ -7,10 +8,74 @@ template <> | |||||
| struct OprMaker<opr::ElemwiseMultiType, 0> | struct OprMaker<opr::ElemwiseMultiType, 0> | ||||
| : public OprMakerVariadic<opr::ElemwiseMultiType> {}; | : public OprMakerVariadic<opr::ElemwiseMultiType> {}; | ||||
| template <> | |||||
| struct OprLoadDumpImplV2<opr::ElemwiseMultiType, 0> { | |||||
| using Opr = opr::ElemwiseMultiType; | |||||
| using PersisParam = opr::ElemwiseMultiType::Param; | |||||
| using PersisElemwseiParam = opr::Elemwise::Param; | |||||
| static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { | |||||
| ctx.write_param<PersisParam>(opr.cast_final_safe<Opr>().param()); | |||||
| } | |||||
| static cg::OperatorNodeBase* replace_opr( | |||||
| cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { | |||||
| auto mode = opr->cast_final_safe<Opr>().param().mode; | |||||
| auto change_to_elemwise_mode = [&](PersisParam::Mode multitype_mode) { | |||||
| if (multitype_mode == PersisParam::Mode::EQ) { | |||||
| return PersisElemwseiParam::Mode::EQ; | |||||
| } else if (multitype_mode == PersisParam::Mode::LT) { | |||||
| return PersisElemwseiParam::Mode::LT; | |||||
| } else if (multitype_mode == PersisParam::Mode::LEQ) { | |||||
| return PersisElemwseiParam::Mode::LEQ; | |||||
| } | |||||
| mgb_assert(0, "no supported model."); | |||||
| }; | |||||
| if (PersisParam::Mode::EQ == mode || PersisParam::Mode::LT == mode || | |||||
| PersisParam::Mode::LEQ == mode) { | |||||
| auto elemwise_mode = change_to_elemwise_mode(mode); | |||||
| auto elemiwse_out = opr::Elemwise::make(inputs, {elemwise_mode}); | |||||
| return opr::TypeCvt::make(elemiwse_out, dtype::Bool()).node()->owner_opr(); | |||||
| } else if (PersisParam::Mode::NEQ == mode) { | |||||
| auto elemiwse_out = | |||||
| opr::Elemwise::make(inputs, {PersisElemwseiParam::Mode::EQ}); | |||||
| auto bool_out = opr::TypeCvt::make(elemiwse_out, dtype::Bool()); | |||||
| return opr::Elemwise::make({bool_out}, {PersisElemwseiParam::Mode::NOT}) | |||||
| .node() | |||||
| ->owner_opr(); | |||||
| } else if (PersisParam::Mode::ISNAN == mode) { | |||||
| auto elemiwse_out = opr::Elemwise::make( | |||||
| {inputs[0], inputs[0]}, {PersisElemwseiParam::Mode::EQ}); | |||||
| auto bool_out = opr::TypeCvt::make(elemiwse_out, dtype::Bool()); | |||||
| return opr::Elemwise::make({bool_out}, {PersisElemwseiParam::Mode::NOT}) | |||||
| .node() | |||||
| ->owner_opr(); | |||||
| } else if (PersisParam::Mode::ISINF == mode) { | |||||
| auto input_var = SymbolVar{inputs[0]}; | |||||
| auto inf_var = input_var.make_scalar(INFINITY); | |||||
| auto float_out = opr::TypeCvt::make(inputs[0], dtype::Float32()); | |||||
| auto elemiwse_out = opr::Elemwise::make( | |||||
| {float_out, inf_var}, {PersisElemwseiParam::Mode::EQ}); | |||||
| return opr::TypeCvt::make(elemiwse_out, dtype::Bool()).node()->owner_opr(); | |||||
| } | |||||
| return opr; | |||||
| } | |||||
| static cg::OperatorNodeBase* load( | |||||
| OprLoadContext& ctx, const cg::VarNodeArray& inputs, | |||||
| const OperatorNodeConfig& config) { | |||||
| return OprMaker<opr::ElemwiseMultiType, 0>::make( | |||||
| ctx.read_param<PersisParam>(), inputs, ctx.graph(), config); | |||||
| } | |||||
| }; | |||||
| } // namespace serialization | } // namespace serialization | ||||
| namespace opr { | namespace opr { | ||||
| MGB_SEREG_OPR(ElemwiseMultiType, 0); | |||||
| MGB_SEREG_OPR_CONDITION(ElemwiseMultiType, 0, false); | |||||
| MGB_SEREG_OPR_V2( | |||||
| ElemwiseMultiType, 0, | |||||
| (mgb::serialization::OprLoadDumpImplV2<opr::ElemwiseMultiType, 0>::replace_opr), | |||||
| VERSION_1, VERSION_1); | |||||
| MGB_SEREG_OPR(AffineInt, 3); | MGB_SEREG_OPR(AffineInt, 3); | ||||
| } // namespace opr | } // namespace opr | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -125,16 +125,18 @@ ComputingGraph* serialization::OprShallowCopyContext::owner_graph( | |||||
| cg::OperatorNodeBase* serialization::copy_opr_shallow( | cg::OperatorNodeBase* serialization::copy_opr_shallow( | ||||
| const cg::OperatorNodeBase& opr, const VarNodeArray& inputs, | const cg::OperatorNodeBase& opr, const VarNodeArray& inputs, | ||||
| const OperatorNodeConfig& config, const OprShallowCopyContext& ctx) { | const OperatorNodeConfig& config, const OprShallowCopyContext& ctx) { | ||||
| auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo()); | |||||
| mgb_assert( | |||||
| registry, "could not find OprReceiver to copy opr %s{%s}", opr.cname(), | |||||
| opr.dyn_typeinfo()->name); | |||||
| OprShallowCopy shallow_copy = nullptr; | |||||
| if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) { | |||||
| shallow_copy = registry->shallow_copy; | |||||
| } else { | |||||
| shallow_copy = intl::copy_opr_shallow_default_impl; | |||||
| } | |||||
| mgb_assert(inputs.size() == opr.input().size()); | mgb_assert(inputs.size() == opr.input().size()); | ||||
| auto dst_og = ctx.owner_graph(opr, inputs); | auto dst_og = ctx.owner_graph(opr, inputs); | ||||
| auto do_copy = [&]() { | auto do_copy = [&]() { | ||||
| auto nr_opr_before = opr.owner_graph()->nr_oprs_in_graph(); | auto nr_opr_before = opr.owner_graph()->nr_oprs_in_graph(); | ||||
| auto ret = registry->shallow_copy(ctx, opr, inputs, config); | |||||
| auto ret = shallow_copy(ctx, opr, inputs, config); | |||||
| if (dst_og != opr.owner_graph() || | if (dst_og != opr.owner_graph() || | ||||
| opr.owner_graph()->nr_oprs_in_graph() != nr_opr_before) { | opr.owner_graph()->nr_oprs_in_graph() != nr_opr_before) { | ||||
| @@ -188,18 +190,28 @@ cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl( | |||||
| const OprShallowCopyContext& ctx, const cg::OperatorNodeBase& opr, | const OprShallowCopyContext& ctx, const cg::OperatorNodeBase& opr, | ||||
| const VarNodeArray& inputs, const OperatorNodeConfig& config) { | const VarNodeArray& inputs, const OperatorNodeConfig& config) { | ||||
| MGB_MARK_USED_VAR(ctx); | MGB_MARK_USED_VAR(ctx); | ||||
| OprDumper opr_dumper = nullptr; | |||||
| OprLoaderWrapper opr_loader = nullptr; | |||||
| auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo()); | |||||
| if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) { | |||||
| opr_loader = registry->loader; | |||||
| opr_dumper = registry->dumper; | |||||
| } else { | |||||
| auto registryv2 = OprRegistryV2::versioned_find_by_typeinfo( | |||||
| opr.dyn_typeinfo(), CURRENT_VERSION); | |||||
| opr_loader = registryv2->loader; | |||||
| opr_dumper = registryv2->dumper; | |||||
| } | |||||
| mgb_assert( | mgb_assert( | ||||
| registry && registry->dumper && registry->loader, | |||||
| opr_dumper && opr_loader, | |||||
| "can not shallow_copy operator %s{%s}: " | "can not shallow_copy operator %s{%s}: " | ||||
| "no dumper/loader registered", | "no dumper/loader registered", | ||||
| opr.cname(), opr.dyn_typeinfo()->name); | opr.cname(), opr.dyn_typeinfo()->name); | ||||
| OprDumpContextMemory dumper; | |||||
| registry->dumper(dumper, opr); | |||||
| OprDumpContextMemory memory_dumper; | |||||
| opr_dumper(memory_dumper, opr); | |||||
| OprLoadContextMemory loader{opr.owner_graph(), dumper}; | |||||
| return registry->loader(loader, inputs, config).opr(); | |||||
| OprLoadContextMemory loader{opr.owner_graph(), memory_dumper}; | |||||
| return opr_loader(loader, inputs, config).opr(); | |||||
| } | } | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -358,6 +358,11 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump( | |||||
| auto new_output_vars = output_vars; | auto new_output_vars = output_vars; | ||||
| if (!config.no_change_graph) { | if (!config.no_change_graph) { | ||||
| new_output_vars = converter_all_opr_to_compatiable(output_vars); | new_output_vars = converter_all_opr_to_compatiable(output_vars); | ||||
| mgb_assert(output_vars.size() == new_output_vars.size()); | |||||
| for (size_t id = 0; id < output_vars.size(); id++) { | |||||
| auto& new_var = new_output_vars[id]; | |||||
| new_var.rename(output_vars[id].node()->name()); | |||||
| } | |||||
| } | } | ||||
| auto begin_pos = m_file->tell(); | auto begin_pos = m_file->tell(); | ||||
| @@ -151,20 +151,22 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; | |||||
| //! call OprRegistryV2::versioned_add for new serialization which is compatiable | //! call OprRegistryV2::versioned_add for new serialization which is compatiable | ||||
| //! with old serialization, convert is nullptr, this registry is just only for | //! with old serialization, convert is nullptr, this registry is just only for | ||||
| //! varsion 1 | //! varsion 1 | ||||
| #define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load) \ | |||||
| do { \ | |||||
| ::mgb::serialization::OprRegistry::add( \ | |||||
| {_cls::typeinfo(), \ | |||||
| MGB_HASH_STR(#_cls), \ | |||||
| _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \ | |||||
| _dump, \ | |||||
| _load, \ | |||||
| {}, \ | |||||
| MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \ | |||||
| ::mgb::serialization::OprRegistryV2::versioned_add( \ | |||||
| {_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \ | |||||
| _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \ | |||||
| ::mgb::VERSION_1, ::mgb::VERSION_1); \ | |||||
| #define MGB_SEREG_OPR_INTL_CALL_ADD(_cls, _dump, _load, _registerv2) \ | |||||
| do { \ | |||||
| ::mgb::serialization::OprRegistry::add( \ | |||||
| {_cls::typeinfo(), \ | |||||
| MGB_HASH_STR(#_cls), \ | |||||
| _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), \ | |||||
| _dump, \ | |||||
| _load, \ | |||||
| {}, \ | |||||
| MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls)}); \ | |||||
| if (_registerv2) { \ | |||||
| ::mgb::serialization::OprRegistryV2::versioned_add( \ | |||||
| {_cls::typeinfo(), MGB_HASH_STR_WITHOUT_TAIL_0_AND_VERSION(#_cls), \ | |||||
| _MGB_SEREG_OPR_NAME_FROM_CLS(_cls), _dump, _load, nullptr}, \ | |||||
| ::mgb::VERSION_1, ::mgb::VERSION_1); \ | |||||
| } \ | |||||
| } while (0) | } while (0) | ||||
| //! call OprRegistryV2::versioned_add for new serialization, in which convert the | //! call OprRegistryV2::versioned_add for new serialization, in which convert the | ||||
| @@ -181,23 +183,25 @@ struct OprRegistryCaller : public OprRegistryCallerDefaultImpl<Callee> {}; | |||||
| /*! | /*! | ||||
| * \brief register opr serialization methods | * \brief register opr serialization methods | ||||
| */ | */ | ||||
| #define MGB_SEREG_OPR(_cls, _arity) \ | |||||
| namespace { \ | |||||
| namespace ser = ::mgb::serialization; \ | |||||
| struct _OprReg##_cls { \ | |||||
| using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \ | |||||
| static ser::OprWithOutputAccessor wrap_loader( \ | |||||
| ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | |||||
| const mgb::cg::OperatorNodeConfig& config) { \ | |||||
| return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \ | |||||
| } \ | |||||
| static void entry() { \ | |||||
| MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader); \ | |||||
| } \ | |||||
| }; \ | |||||
| } \ | |||||
| #define MGB_SEREG_OPR_CONDITION(_cls, _arity, _registerv2) \ | |||||
| namespace { \ | |||||
| namespace ser = ::mgb::serialization; \ | |||||
| struct _OprReg##_cls { \ | |||||
| using Impl = ser::OprLoadDumpImpl<_cls, _arity>; \ | |||||
| static ser::OprWithOutputAccessor wrap_loader( \ | |||||
| ser::OprLoadContext& ctx, const mgb::cg::VarNodeArray& inputs, \ | |||||
| const mgb::cg::OperatorNodeConfig& config) { \ | |||||
| return ser::OprWithOutputAccessor(Impl::load(ctx, inputs, config)); \ | |||||
| } \ | |||||
| static void entry() { \ | |||||
| MGB_SEREG_OPR_INTL_CALL_ADD(_cls, Impl::dump, wrap_loader, _registerv2); \ | |||||
| } \ | |||||
| }; \ | |||||
| } \ | |||||
| MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) | MGB_SEREG_OPR_INTL_CALL_ENTRY(_cls, _OprReg##_cls) | ||||
| #define MGB_SEREG_OPR(_cls, _arity) MGB_SEREG_OPR_CONDITION(_cls, _arity, true) | |||||
| //! new dump/load function should implement in OprLoadDumpImplV2, _converter is | //! new dump/load function should implement in OprLoadDumpImplV2, _converter is | ||||
| //! optional , if not implement pass nullptr | //! optional , if not implement pass nullptr | ||||
| #define MGB_SEREG_OPR_V2(_cls, _arity, _converter, _version_min, _version_max) \ | #define MGB_SEREG_OPR_V2(_cls, _arity, _converter, _version_min, _version_max) \ | ||||
| @@ -1,3 +1,4 @@ | |||||
| #include "megbrain/opr/nn_int.h" | |||||
| #if MGB_ENABLE_FBS_SERIALIZATION | #if MGB_ENABLE_FBS_SERIALIZATION | ||||
| #include "megbrain/opr/basic_arith_wrapper.h" | #include "megbrain/opr/basic_arith_wrapper.h" | ||||
| @@ -1016,4 +1017,107 @@ TEST(TestSerializer2, TestSoftMaxLoadDump) { | |||||
| load(); | load(); | ||||
| } | } | ||||
| TEST(TestSerializer2, TestElemwiseMultiTypeLoadDump) { | |||||
| auto fname = GET_OUTPUT_FILE(GraphDumpFormat::FLATBUFFERS_V2); | |||||
| TensorShape shape{3}; | |||||
| auto cn = CompNode::load("xpu0"); | |||||
| std::shared_ptr<HostTensorND> host0 = | |||||
| std::make_shared<HostTensorND>(cn, shape, dtype::Float32{}); | |||||
| std::shared_ptr<HostTensorND> host1 = | |||||
| std::make_shared<HostTensorND>(cn, shape, dtype::Float32{}); | |||||
| HostTensorND dst_truth; | |||||
| host0->ptr<float>()[0] = 2; | |||||
| host0->ptr<float>()[1] = 2; | |||||
| host0->ptr<float>()[2] = -1; | |||||
| host1->ptr<float>()[0] = 1; | |||||
| host1->ptr<float>()[1] = 2; | |||||
| host1->ptr<float>()[2] = 3; | |||||
| auto dump = [&](opr::ElemwiseMultiType::Param::Mode mode, size_t nr_opr) { | |||||
| auto graph = ComputingGraph::make(); | |||||
| OperatorNodeConfig config; | |||||
| config.name("input0"); | |||||
| auto h2d0 = opr::Host2DeviceCopy::make(*graph, host0, config); | |||||
| config.name("input1"); | |||||
| auto h2d1 = opr::Host2DeviceCopy::make(*graph, host1, config); | |||||
| auto x = opr::ElemwiseMultiType::make( | |||||
| {h2d0, h2d1}, {mode}, OperatorNodeConfig{dtype::Bool()}); | |||||
| x.rename("out"); | |||||
| auto func = graph->compile({make_callback_copy(x, dst_truth)}); | |||||
| auto dumper = GraphDumper::make( | |||||
| OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); | |||||
| auto rst = dumper->dump({x}); | |||||
| func->execute().wait(); | |||||
| ASSERT_EQ(rst.nr_opr, nr_opr); | |||||
| }; | |||||
| auto load = [&]() { | |||||
| auto loader = GraphLoader::make( | |||||
| InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); | |||||
| auto rst = loader->load(); | |||||
| ASSERT_EQ(rst.tensor_map.size(), 2); | |||||
| ASSERT_EQ(rst.output_var_map.count("out"), 1); | |||||
| HostTensorND host_x; | |||||
| auto func = | |||||
| rst.graph_compile({make_callback_copy(rst.output_var_list[0], host_x)}); | |||||
| for (auto& input : rst.tensor_map) { | |||||
| if (input.first == "input0") { | |||||
| input.second->copy_from(*host0).sync(); | |||||
| } else if (input.first == "input1") { | |||||
| input.second->copy_from(*host1).sync(); | |||||
| } | |||||
| } | |||||
| func->execute().wait(); | |||||
| for (int i = 0; i < 3; i++) { | |||||
| EXPECT_EQ(host_x.ptr<bool>()[i], dst_truth.ptr<bool>()[i]); | |||||
| } | |||||
| }; | |||||
| dump(opr::ElemwiseMultiType::Param::Mode::EQ, 4); | |||||
| load(); | |||||
| dump(opr::ElemwiseMultiType::Param::Mode::LT, 4); | |||||
| load(); | |||||
| dump(opr::ElemwiseMultiType::Param::Mode::LEQ, 4); | |||||
| load(); | |||||
| dump(opr::ElemwiseMultiType::Param::Mode::NEQ, 5); | |||||
| load(); | |||||
| auto dump_single_input = [&](opr::ElemwiseMultiType::Param::Mode mode, | |||||
| size_t nr_opr) { | |||||
| auto graph = ComputingGraph::make(); | |||||
| auto h2d0 = opr::Host2DeviceCopy::make(*graph, host0); | |||||
| auto x = opr::ElemwiseMultiType::make( | |||||
| {h2d0}, {mode}, OperatorNodeConfig{dtype::Bool()}); | |||||
| x.rename("out"); | |||||
| auto func = graph->compile({make_callback_copy(x, dst_truth)}); | |||||
| auto dumper = GraphDumper::make( | |||||
| OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); | |||||
| auto rst = dumper->dump({x}); | |||||
| func->execute().wait(); | |||||
| ASSERT_EQ(rst.nr_opr, nr_opr); | |||||
| }; | |||||
| auto load_single_input = [&]() { | |||||
| auto loader = GraphLoader::make( | |||||
| InputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); | |||||
| auto rst = loader->load(); | |||||
| ASSERT_EQ(rst.tensor_map.size(), 1); | |||||
| ASSERT_EQ(rst.output_var_map.count("out"), 1); | |||||
| HostTensorND host_x; | |||||
| auto func = | |||||
| rst.graph_compile({make_callback_copy(rst.output_var_list[0], host_x)}); | |||||
| rst.tensor_map.begin()->second->copy_from(*host0).sync(); | |||||
| func->execute().wait(); | |||||
| for (int i = 0; i < 3; i++) { | |||||
| EXPECT_EQ(host_x.ptr<bool>()[i], dst_truth.ptr<bool>()[i]); | |||||
| } | |||||
| }; | |||||
| host0->ptr<float>()[2] = INFINITY; | |||||
| dump_single_input(opr::ElemwiseMultiType::Param::Mode::ISINF, 4); | |||||
| load_single_input(); | |||||
| host0->ptr<float>()[2] = NAN; | |||||
| dump_single_input(opr::ElemwiseMultiType::Param::Mode::ISNAN, 4); | |||||
| load_single_input(); | |||||
| } | |||||
| #endif | #endif | ||||