GitOrigin-RevId: 65064452c9
tags/v1.10.0
| @@ -367,10 +367,12 @@ def dump_graph( | |||||
| keep_opr_name: bool = False, | keep_opr_name: bool = False, | ||||
| keep_param_name: bool = False, | keep_param_name: bool = False, | ||||
| keep_opr_priority: bool = False, | keep_opr_priority: bool = False, | ||||
| no_change_graph: bool = False, | |||||
| strip_info_file=None, | strip_info_file=None, | ||||
| append_json=False, | append_json=False, | ||||
| metadata=None, | metadata=None, | ||||
| dump_format=None | |||||
| dump_format=None, | |||||
| model_version: int = 2 | |||||
| ) -> Tuple[bytes, CompGraphDumpResult]: | ) -> Tuple[bytes, CompGraphDumpResult]: | ||||
| r"""serialize the computing graph of `output_vars` and get byte result. | r"""serialize the computing graph of `output_vars` and get byte result. | ||||
| @@ -386,12 +388,22 @@ def dump_graph( | |||||
| keep_param_name: whether to keep param names, so param values can be | keep_param_name: whether to keep param names, so param values can be | ||||
| easily manipulated after loading model | easily manipulated after loading model | ||||
| keep_opr_priority: whether to keep priority setting for operators | keep_opr_priority: whether to keep priority setting for operators | ||||
| no_change_graph: whether to change the compute graph when dump, for | |||||
| model compatibility, some operators will convert to its compatible | |||||
| format in this version. | |||||
| * if set False, some operators maybe convert to other operator for | |||||
| compatibility, all operators will ensure compatibility. | |||||
| * if set True, no operator will change in the graph when dump. | |||||
| strip_info_file: a string for path or a file handler. if is not None, | strip_info_file: a string for path or a file handler. if is not None, | ||||
| then the dump information for code strip would be written to ``strip_info_file`` | then the dump information for code strip would be written to ``strip_info_file`` | ||||
| append_json: will be check when `strip_info_file` is not None. if set | append_json: will be check when `strip_info_file` is not None. if set | ||||
| true, the information for code strip will be append to strip_info_file. | true, the information for code strip will be append to strip_info_file. | ||||
| if set false, will rewrite strip_info_file | if set false, will rewrite strip_info_file | ||||
| dump_format: using different dump formats. | dump_format: using different dump formats. | ||||
| model_version: the model version of "FBS_V2", begin with version 2, this | |||||
| works only when dump format is "FBS_V2". | |||||
| Note: | Note: | ||||
| The underlying C++ API only accepts a var list. If a dict is given, | The underlying C++ API only accepts a var list. If a dict is given, | ||||
| @@ -441,8 +453,10 @@ def dump_graph( | |||||
| keep_opr_name, | keep_opr_name, | ||||
| keep_param_name, | keep_param_name, | ||||
| keep_opr_priority, | keep_opr_priority, | ||||
| no_change_graph, | |||||
| metadata, | metadata, | ||||
| dump_format, | dump_format, | ||||
| model_version, | |||||
| stat, | stat, | ||||
| inputs, | inputs, | ||||
| outputs, | outputs, | ||||
| @@ -549,6 +549,7 @@ class trace: | |||||
| keep_opr_name: bool = False, | keep_opr_name: bool = False, | ||||
| keep_param_name: bool = False, | keep_param_name: bool = False, | ||||
| keep_opr_priority: bool = False, | keep_opr_priority: bool = False, | ||||
| no_change_graph: bool = False, | |||||
| strip_info_file=None, | strip_info_file=None, | ||||
| append_json=False, | append_json=False, | ||||
| optimize_for_inference=True, | optimize_for_inference=True, | ||||
| @@ -562,6 +563,7 @@ class trace: | |||||
| resize_input=False, | resize_input=False, | ||||
| input_transform=None, | input_transform=None, | ||||
| dump_format: str = None, | dump_format: str = None, | ||||
| model_version: int = 2, | |||||
| **kwargs | **kwargs | ||||
| ): | ): | ||||
| r"""Serializes trace to file system. | r"""Serializes trace to file system. | ||||
| @@ -583,6 +585,14 @@ class trace: | |||||
| keep_param_name: whether to keep param names, so param values can be | keep_param_name: whether to keep param names, so param values can be | ||||
| easily manipulated after loading model | easily manipulated after loading model | ||||
| keep_opr_priority: whether to keep priority setting for operators | keep_opr_priority: whether to keep priority setting for operators | ||||
| no_change_graph: whether to change the compute graph when dump, for | |||||
| model compatibility, some operators will convert to its compatible | |||||
| format in this version. | |||||
| * if set False, some operators maybe convert to other operator for | |||||
| compatibility, all operators will ensure compatibility. | |||||
| * if set True, no operator will change in the graph when dump. | |||||
| strip_info_file: a string for path or a file handler. if is not None, | strip_info_file: a string for path or a file handler. if is not None, | ||||
| then the dump information for code strip would be written to ``strip_info_file`` | then the dump information for code strip would be written to ``strip_info_file`` | ||||
| append_json: will be check when `strip_info_file` is not None. if set | append_json: will be check when `strip_info_file` is not None. if set | ||||
| @@ -616,6 +626,9 @@ class trace: | |||||
| dump_format: using different dump formats. the open source MegEngine | dump_format: using different dump formats. the open source MegEngine | ||||
| defaults to the FBS_V2 format, there are two format FBS_V2 and FBS to choose, | defaults to the FBS_V2 format, there are two format FBS_V2 and FBS to choose, | ||||
| internal MegEngine have an other choice of internal proprietary formats | internal MegEngine have an other choice of internal proprietary formats | ||||
| model_version: the model version of FBS_V2, begin with version 2, this | |||||
| works only when dump format is FBS_V2. | |||||
| Keyword Arguments: | Keyword Arguments: | ||||
| @@ -762,10 +775,12 @@ class trace: | |||||
| keep_opr_name=keep_opr_name, | keep_opr_name=keep_opr_name, | ||||
| keep_param_name=keep_param_name, | keep_param_name=keep_param_name, | ||||
| keep_opr_priority=keep_opr_priority, | keep_opr_priority=keep_opr_priority, | ||||
| no_change_graph=no_change_graph, | |||||
| strip_info_file=strip_info_file, | strip_info_file=strip_info_file, | ||||
| append_json=append_json, | append_json=append_json, | ||||
| metadata=metadata, | metadata=metadata, | ||||
| dump_format=dump_format, | dump_format=dump_format, | ||||
| model_version=model_version, | |||||
| ) | ) | ||||
| file.write(dump_content) | file.write(dump_content) | ||||
| @@ -381,20 +381,26 @@ void init_graph_rt(py::module m) { | |||||
| m.def("dump_graph", | m.def("dump_graph", | ||||
| [](const std::vector<VarNode*>& dest_vars, int keep_var_name, | [](const std::vector<VarNode*>& dest_vars, int keep_var_name, | ||||
| bool keep_opr_name, bool keep_param_name, bool keep_opr_priority, | bool keep_opr_name, bool keep_param_name, bool keep_opr_priority, | ||||
| std::optional<_SerializationMetadata> metadata, | |||||
| std::optional<_SerializationFormat> dump_format, py::list& stat, | |||||
| py::list& inputs, py::list& outputs, py::list& params) { | |||||
| bool no_change_graph, std::optional<_SerializationMetadata> metadata, | |||||
| std::optional<_SerializationFormat> dump_format, | |||||
| std::optional<int> model_version, py::list& stat, py::list& inputs, | |||||
| py::list& outputs, py::list& params) { | |||||
| std::vector<uint8_t> buf; | std::vector<uint8_t> buf; | ||||
| ser::GraphDumpFormat format = ser::GraphDumpFormat::FLATBUFFERS_V2; | ser::GraphDumpFormat format = ser::GraphDumpFormat::FLATBUFFERS_V2; | ||||
| int version = 2; | |||||
| if (dump_format.has_value()) { | if (dump_format.has_value()) { | ||||
| format = dump_format.value(); | format = dump_format.value(); | ||||
| } | } | ||||
| if (model_version.has_value()) { | |||||
| version = model_version.value(); | |||||
| } | |||||
| auto dumper = ser::GraphDumper::make( | auto dumper = ser::GraphDumper::make( | ||||
| ser::OutputFile::make_vector_proxy(&buf), format); | |||||
| ser::OutputFile::make_vector_proxy(&buf), format, version); | |||||
| SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); | SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); | ||||
| ser::GraphDumper::DumpConfig config{ | ser::GraphDumper::DumpConfig config{ | ||||
| keep_var_name, keep_param_name, keep_opr_priority, keep_opr_name}; | keep_var_name, keep_param_name, keep_opr_priority, keep_opr_name}; | ||||
| config.no_change_graph = no_change_graph; | |||||
| ser::GraphDumper::DumpResult rst; | ser::GraphDumper::DumpResult rst; | ||||
| if (metadata) | if (metadata) | ||||
| @@ -21,6 +21,13 @@ struct OprLoadDumpImplV2<opr::Softmax, 1> { | |||||
| ctx.write_param<PersisParam>(opr.cast_final_safe<Opr>().param()); | ctx.write_param<PersisParam>(opr.cast_final_safe<Opr>().param()); | ||||
| } | } | ||||
| /** This converter is just a example for Operator serialization compatible, | |||||
| * Just in this situation: when optimize the softmax Operator by | |||||
| * fusing the elemwise and reduce to a big Operator, but the whole softmax | |||||
| * Operator can't be recognized by old version, in order to model | |||||
| * compatibility the softmax Operator should be covert to elemwise and | |||||
| * reduce Operators when dump the model | |||||
| */ | |||||
| static cg::OperatorNodeBase* replace_opr( | static cg::OperatorNodeBase* replace_opr( | ||||
| cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { | cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { | ||||
| int32_t axis = opr->cast_final_safe<Opr>().param().axis; | int32_t axis = opr->cast_final_safe<Opr>().param().axis; | ||||
| @@ -196,9 +203,11 @@ namespace opr { | |||||
| #define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \ | #define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \ | ||||
| MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION); | MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION); | ||||
| SERGE_OPR_V2_CONVERTER( | |||||
| //! this is just a example for Operator compatibility | |||||
| /*SERGE_OPR_V2_CONVERTER( | |||||
| Softmax, 1, | Softmax, 1, | ||||
| (mgb::serialization::OprLoadDumpImplV2<opr::Softmax, 1>::replace_opr)); | |||||
| (mgb::serialization::OprLoadDumpImplV2<opr::Softmax, 1>::replace_opr));*/ | |||||
| SERGE_OPR_V2_NO_CONVERTER(Softmax, 1) | |||||
| SERGE_OPR_V2_NO_CONVERTER(ConvBiasForward, 0) | SERGE_OPR_V2_NO_CONVERTER(ConvBiasForward, 0) | ||||
| SERGE_OPR_V2_NO_CONVERTER(BatchConvBiasForward, 0); | SERGE_OPR_V2_NO_CONVERTER(BatchConvBiasForward, 0); | ||||
| @@ -59,7 +59,8 @@ std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file); | |||||
| std::unique_ptr<GraphDumper> make_fbs_dumper(std::unique_ptr<OutputFile> file); | std::unique_ptr<GraphDumper> make_fbs_dumper(std::unique_ptr<OutputFile> file); | ||||
| std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file); | std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file); | ||||
| std::unique_ptr<GraphDumper> make_fbs_v2_dumper(std::unique_ptr<OutputFile> file); | |||||
| std::unique_ptr<GraphDumper> make_fbs_v2_dumper( | |||||
| std::unique_ptr<OutputFile> file, int version); | |||||
| bool is_fbs_file(InputFile& file); | bool is_fbs_file(InputFile& file); | ||||
| bool is_fbs_v2_file(InputFile& file); | bool is_fbs_v2_file(InputFile& file); | ||||
| @@ -72,7 +73,7 @@ bool GraphDumper::should_remove_in_dump(cg::OperatorNodeBase* opr) { | |||||
| } | } | ||||
| std::unique_ptr<GraphDumper> GraphDumper::make( | std::unique_ptr<GraphDumper> GraphDumper::make( | ||||
| std::unique_ptr<OutputFile> file, GraphDumpFormat format) { | |||||
| std::unique_ptr<OutputFile> file, GraphDumpFormat format, int version) { | |||||
| switch (format) { | switch (format) { | ||||
| case GraphDumpFormat::FLATBUFFERS: | case GraphDumpFormat::FLATBUFFERS: | ||||
| #if MGB_ENABLE_FBS_SERIALIZATION | #if MGB_ENABLE_FBS_SERIALIZATION | ||||
| @@ -81,7 +82,7 @@ std::unique_ptr<GraphDumper> GraphDumper::make( | |||||
| MGB_FALLTHRU | MGB_FALLTHRU | ||||
| case GraphDumpFormat::FLATBUFFERS_V2: | case GraphDumpFormat::FLATBUFFERS_V2: | ||||
| #if MGB_ENABLE_FBS_SERIALIZATION | #if MGB_ENABLE_FBS_SERIALIZATION | ||||
| return make_fbs_v2_dumper(std::move(file)); | |||||
| return make_fbs_v2_dumper(std::move(file), version); | |||||
| #endif | #endif | ||||
| MGB_FALLTHRU | MGB_FALLTHRU | ||||
| default: | default: | ||||
| @@ -194,7 +194,7 @@ void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) { | |||||
| } | } | ||||
| } else { | } else { | ||||
| auto registry = OprRegistryV2::versioned_find_by_typeinfo( | auto registry = OprRegistryV2::versioned_find_by_typeinfo( | ||||
| opr->dyn_typeinfo(), CURRENT_VERSION); | |||||
| opr->dyn_typeinfo(), m_version); | |||||
| if (!registry || !registry->dumper) { | if (!registry || !registry->dumper) { | ||||
| mgb_throw( | mgb_throw( | ||||
| cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make<MegBrainError>, | cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make<MegBrainError>, | ||||
| @@ -202,6 +202,9 @@ void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) { | |||||
| "operator %s", | "operator %s", | ||||
| opr->dyn_typeinfo()->name); | opr->dyn_typeinfo()->name); | ||||
| } | } | ||||
| mgb_assert( | |||||
| registry->version <= m_version, | |||||
| "The Operator version should less than model version"); | |||||
| m_oprs_to_dump.emplace_back(opr, registry); | m_oprs_to_dump.emplace_back(opr, registry); | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -352,7 +355,10 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump( | |||||
| const Metadata& metadata) { | const Metadata& metadata) { | ||||
| mgb_throw_if(output_vars.empty(), SerializationError, "Can't dump empty graph"); | mgb_throw_if(output_vars.empty(), SerializationError, "Can't dump empty graph"); | ||||
| auto&& new_output_vars = converter_all_opr_to_compatiable(output_vars); | |||||
| auto new_output_vars = output_vars; | |||||
| if (!config.no_change_graph) { | |||||
| new_output_vars = converter_all_opr_to_compatiable(output_vars); | |||||
| } | |||||
| auto begin_pos = m_file->tell(); | auto begin_pos = m_file->tell(); | ||||
| m_config = config; | m_config = config; | ||||
| @@ -416,6 +422,7 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump( | |||||
| fbs::v2::ModelBuilder model(m_builder); | fbs::v2::ModelBuilder model(m_builder); | ||||
| model.add_mge_version(MGB_VERSION); | model.add_mge_version(MGB_VERSION); | ||||
| model.add_model_version(m_version); | |||||
| model.add_oprs(fb_oprs); | model.add_oprs(fb_oprs); | ||||
| model.add_middle_tensors(fb_mid_tensor); | model.add_middle_tensors(fb_mid_tensor); | ||||
| model.add_output_vars_idx(fb_output_vars); | model.add_output_vars_idx(fb_output_vars); | ||||
| @@ -694,10 +701,8 @@ void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr( | |||||
| OprRegistryV2::versioned_find_by_id(type_id, opr_version); | OprRegistryV2::versioned_find_by_id(type_id, opr_version); | ||||
| mgb_throw_if( | mgb_throw_if( | ||||
| !registry, SerializationError, | !registry, SerializationError, | ||||
| "failed to find opr with type %s , use python env " | |||||
| "config.dump_registered_oprs() to get a dict that maps from " | |||||
| "opr id to opr name", | |||||
| fbopr->type()->str().c_str()); | |||||
| "failed to find opr with type %s and version %d.", | |||||
| fbopr->type()->str().c_str(), opr_version); | |||||
| // load inputs | // load inputs | ||||
| VarNodeArray inputs; | VarNodeArray inputs; | ||||
| @@ -811,12 +816,19 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re | |||||
| m_model = fbs::v2::GetModel(m_model_buf.data()); | m_model = fbs::v2::GetModel(m_model_buf.data()); | ||||
| m_mgb_version = m_model->mge_version(); | m_mgb_version = m_model->mge_version(); | ||||
| m_model_version = m_model->model_version(); | |||||
| if (m_model->mge_version() > MGB_VERSION) { | if (m_model->mge_version() > MGB_VERSION) { | ||||
| mgb_log_warn( | mgb_log_warn( | ||||
| "loading model from future runtime: version=%u " | "loading model from future runtime: version=%u " | ||||
| "model_version=%u", | "model_version=%u", | ||||
| MGB_VERSION, m_model->mge_version()); | MGB_VERSION, m_model->mge_version()); | ||||
| } | } | ||||
| if (m_model_version > CURRENT_VERSION) { | |||||
| mgb_log_warn( | |||||
| "The model dump in the future version %d, try to load it, maybe case " | |||||
| "load error in %d version.", | |||||
| m_model_version, CURRENT_VERSION); | |||||
| } | |||||
| if (m_shared_tensor_map.empty()) { | if (m_shared_tensor_map.empty()) { | ||||
| m_shared_tensor_map.resize(m_model->nr_shared_tensor()); | m_shared_tensor_map.resize(m_model->nr_shared_tensor()); | ||||
| @@ -845,8 +857,9 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re | |||||
| return result; | return result; | ||||
| } | } | ||||
| std::unique_ptr<GraphDumper> make_fbs_v2_dumper(std::unique_ptr<OutputFile> file) { | |||||
| return std::make_unique<GraphDumperOSSV2>(std::move(file)); | |||||
| std::unique_ptr<GraphDumper> make_fbs_v2_dumper( | |||||
| std::unique_ptr<OutputFile> file, int version) { | |||||
| return std::make_unique<GraphDumperOSSV2>(std::move(file), version); | |||||
| } | } | ||||
| std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file) { | std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file) { | ||||
| @@ -58,18 +58,25 @@ struct GraphDumpConfig { | |||||
| //! names. this list record the mapping between output node and it's name | //! names. this list record the mapping between output node and it's name | ||||
| std::vector<std::pair<std::string, SymbolVar>> alias_name_map; | std::vector<std::pair<std::string, SymbolVar>> alias_name_map; | ||||
| //! whether just to dump all the op with no change the graph, sometimes the | |||||
| //! opr maybe not compatible, if false, some opr will converter to the compatibility | |||||
| //! format and then dump | |||||
| bool no_change_graph; | |||||
| GraphDumpConfig( | GraphDumpConfig( | ||||
| int keep_var_name_ = 1, bool keep_param_name_ = false, | int keep_var_name_ = 1, bool keep_param_name_ = false, | ||||
| bool keep_opr_priority_ = false, bool keep_op_name_ = true, | bool keep_opr_priority_ = false, bool keep_op_name_ = true, | ||||
| const std::shared_ptr<UserDataContainer>& user_data_ = | const std::shared_ptr<UserDataContainer>& user_data_ = | ||||
| std::make_shared<UserDataContainer>(), | std::make_shared<UserDataContainer>(), | ||||
| const TensorValueDumper& tensor_value_dumper_ = {}) | |||||
| const TensorValueDumper& tensor_value_dumper_ = {}, | |||||
| bool no_change_graph_ = false) | |||||
| : keep_var_name{keep_var_name_}, | : keep_var_name{keep_var_name_}, | ||||
| keep_param_name{keep_param_name_}, | keep_param_name{keep_param_name_}, | ||||
| keep_opr_priority{keep_opr_priority_}, | keep_opr_priority{keep_opr_priority_}, | ||||
| keep_op_name{keep_op_name_}, | keep_op_name{keep_op_name_}, | ||||
| user_data{user_data_}, | user_data{user_data_}, | ||||
| tensor_value_dumper{tensor_value_dumper_} {} | |||||
| tensor_value_dumper{tensor_value_dumper_}, | |||||
| no_change_graph{no_change_graph_} {} | |||||
| }; | }; | ||||
| //! config for loading a whole graph; setup in GraphLoader | //! config for loading a whole graph; setup in GraphLoader | ||||
| @@ -15,6 +15,7 @@ namespace serialization { | |||||
| class GraphDumperOSSV2 final : public GraphDumper, OprDumpContextFlatBuffers { | class GraphDumperOSSV2 final : public GraphDumper, OprDumpContextFlatBuffers { | ||||
| const std::unique_ptr<OutputFile> m_file; | const std::unique_ptr<OutputFile> m_file; | ||||
| int m_version; | |||||
| flatbuffers::FlatBufferBuilder m_builder; | flatbuffers::FlatBufferBuilder m_builder; | ||||
| DumpConfig m_config; | DumpConfig m_config; | ||||
| @@ -51,7 +52,8 @@ class GraphDumperOSSV2 final : public GraphDumper, OprDumpContextFlatBuffers { | |||||
| flatbuffers::Offset<fbs::DType> build_dtype(DType dtype); | flatbuffers::Offset<fbs::DType> build_dtype(DType dtype); | ||||
| public: | public: | ||||
| GraphDumperOSSV2(std::unique_ptr<OutputFile> file) : m_file{std::move(file)} {} | |||||
| GraphDumperOSSV2(std::unique_ptr<OutputFile> file, int version) | |||||
| : m_file{std::move(file)}, m_version{version} {} | |||||
| DumpResult dump( | DumpResult dump( | ||||
| const SymbolVarArray& output_vars, const DumpConfig& config = {}, | const SymbolVarArray& output_vars, const DumpConfig& config = {}, | ||||
| @@ -95,6 +97,7 @@ class GraphLoaderOSSV2 final : public GraphLoader { | |||||
| const fbs::v2::Model* m_model; | const fbs::v2::Model* m_model; | ||||
| SharedTensorIDMap m_shared_tensor_map; | SharedTensorIDMap m_shared_tensor_map; | ||||
| uint32_t m_mgb_version = 0; | uint32_t m_mgb_version = 0; | ||||
| uint32_t m_model_version = CURRENT_VERSION; | |||||
| bool m_model_loaded = false; | bool m_model_loaded = false; | ||||
| void verify(); | void verify(); | ||||
| @@ -5,6 +5,7 @@ | |||||
| #include "megbrain/serialization/file.h" | #include "megbrain/serialization/file.h" | ||||
| #include "megbrain/serialization/load_dump_config.h" | #include "megbrain/serialization/load_dump_config.h" | ||||
| #include "megbrain/serialization/metadata.h" | #include "megbrain/serialization/metadata.h" | ||||
| #include "megbrain/serialization/opr_load_dump.h" | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace serialization { | namespace serialization { | ||||
| @@ -160,7 +161,8 @@ public: | |||||
| }; | }; | ||||
| MGE_WIN_DECLSPEC_FUC static std::unique_ptr<GraphDumper> make( | MGE_WIN_DECLSPEC_FUC static std::unique_ptr<GraphDumper> make( | ||||
| std::unique_ptr<OutputFile> file, GraphDumpFormat format = {}); | |||||
| std::unique_ptr<OutputFile> file, GraphDumpFormat format = {}, | |||||
| int version = VERSION_2); | |||||
| virtual ~GraphDumper() = default; | virtual ~GraphDumper() = default; | ||||
| @@ -987,7 +987,9 @@ TEST(TestSerializer2, TestSoftMaxLoadDump) { | |||||
| OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); | OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); | ||||
| auto rst = dumper->dump({x}); | auto rst = dumper->dump({x}); | ||||
| func->execute().wait(); | func->execute().wait(); | ||||
| ASSERT_EQ(rst.nr_opr, 6); | |||||
| //! if convert to reduce and elemwise, nr_opr is 6 | |||||
| // ASSERT_EQ(rst.nr_opr, 6); | |||||
| ASSERT_EQ(rst.nr_opr, 2); | |||||
| ASSERT_EQ(rst.inputs.size(), 1); | ASSERT_EQ(rst.inputs.size(), 1); | ||||
| ASSERT_EQ(rst.outputs.size(), 1); | ASSERT_EQ(rst.outputs.size(), 1); | ||||
| ASSERT_EQ(rst.params.size(), 0); | ASSERT_EQ(rst.params.size(), 0); | ||||