GitOrigin-RevId: b563c94451
@@ -11,7 +11,7 @@ import json
 import os
 import weakref
 from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np

@@ -256,6 +256,9 @@ def optimize_for_inference(dest_vars, **kwargs):
     * enable_chwn4 --
         whether to use CHWN4 data layout, currently
         used in nvidia backend with tensorcore.
+    * enable_nchw64 --
+        whether to use NCHW64 data layout, used for fast int4
+        support on Nvidia GPU.
     * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
         into one opr.
@@ -273,6 +276,7 @@ def optimize_for_inference(dest_vars, **kwargs):
         "enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44,
         "enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT,
         "enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4,
+        "enable_nchw64": GraphOptimizeOptions.LayoutTransform.NCHW64,
     }

     for k, v in inference_optimize_layout_transform_map.items():

@@ -293,7 +297,46 @@ def optimize_for_inference(dest_vars, **kwargs):
     dest_vars = _unwrap(dest_vars)
     res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options)
-    return _wrap(res_vars)
+    return _wrap(res_vars), inference_options.serialize()
+
+
+def deserialize_infer_option(x: int) -> Dict[str, bool]:
+    r"""
+    Deserialize optimize options generated by ``imperative_rt.GraphOptimizeOptions``.
+
+    :param x: inference options represented by int.
+    :return: inference options represented by dict.
+    """
+    inference_options = GraphOptimizeOptions.deserialize(x)
+    inference_optimize_layout_transform_map = {
+        GraphOptimizeOptions.LayoutTransform.NHWCD4: "enable_hwcd4",
+        GraphOptimizeOptions.LayoutTransform.NCHW4: "enable_nchw4",
+        GraphOptimizeOptions.LayoutTransform.NCHW88: "enable_nchw88",
+        GraphOptimizeOptions.LayoutTransform.NCHW32: "enable_nchw32",
+        GraphOptimizeOptions.LayoutTransform.NCHW44: "enable_nchw44",
+        GraphOptimizeOptions.LayoutTransform.NCHW44_DOT: "enable_nchw44_dot",
+        GraphOptimizeOptions.LayoutTransform.CHWN4: "enable_chwn4",
+        GraphOptimizeOptions.LayoutTransform.NCHW64: "enable_nchw64",
+    }
+    ret = dict()
+    layout = inference_options.layout_transform
+    if layout != GraphOptimizeOptions.LayoutTransform.DEFAULT:
+        ret[inference_optimize_layout_transform_map[layout]] = True
+    if inference_options.f16_io_f32_comp:
+        ret["enable_io16xc32"] = True
+    if inference_options.f16_io_comp:
+        ret["enable_ioc16"] = True
+    if inference_options.fuse_conv_bias_nonlinearity:
+        ret["enable_fuse_conv_bias_nonlinearity"] = True
+    if inference_options.fuse_conv_bias_with_z:
+        ret["enable_fuse_conv_bias_with_z"] = True
+    return ret


 def modify_opr_algo_strategy_inplace(dest_vars, strategy: str):
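`optimize_for_inference` now returns the packed options alongside the transformed vars, and `deserialize_infer_option` is the inverse mapping back to keyword form. A minimal round-trip sketch, assuming the bindings from this patch are built (the exact import path of `GraphOptimizeOptions` is an assumption):

```python
# Sketch only: GraphOptimizeOptions is the pybind11 class registered in this
# patch; adjust the import path to wherever your build exposes it.
from megengine.core._imperative_rt import GraphOptimizeOptions
from megengine.core.tensor import megbrain_graph as G

opts = GraphOptimizeOptions()
opts.fuse_conv_bias_nonlinearity = True
opts.layout_transform = GraphOptimizeOptions.LayoutTransform.NCHW64

packed = opts.serialize()  # 64-bit int: pass flags in bits 0-5, layout in bits 32+
assert G.deserialize_infer_option(packed) == {
    "enable_nchw64": True,
    "enable_fuse_conv_bias_nonlinearity": True,
}
```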
@@ -331,7 +374,8 @@ def dump_graph(
     keep_param_name: bool = False,
     keep_opr_priority: bool = False,
     strip_info_file=None,
-    append_json=False
+    append_json=False,
+    metadata=None
 ) -> Tuple[bytes, CompGraphDumpResult]:
     """
     serialize the computing graph of `output_vars` and get byte result.

@@ -393,6 +437,7 @@ def dump_graph(
         keep_opr_name,
         keep_param_name,
         keep_opr_priority,
+        metadata,
         stat,
         inputs,
         outputs,
@@ -427,7 +472,7 @@ def dump_graph(
 CompGraphLoadResult = collections.namedtuple(
-    "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"]
+    "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list", "metadata"]
 )

@@ -450,8 +495,8 @@ def load_graph(fpath) -> CompGraphLoadResult:
         buf = open(fpath, "rb").read()
     else:
         buf = fpath.read()
-    cg = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
-    return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list)
+    cg, metadata = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
+    return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list, metadata)


 def _wrap(x):
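Since `CompGraphLoadResult` gains a `metadata` field, callers can read the record straight off the load result. A hypothetical caller (the model path is a placeholder):

```python
from megengine.core.tensor import megbrain_graph as G

ret = G.load_graph("model.mge")  # placeholder path
graph, outputs = ret.graph, ret.output_vars_list
if ret.metadata.is_valid:  # False for models dumped without metadata
    print(ret.metadata.graph_modified, ret.metadata.optimized_for_inference)
```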
@@ -12,10 +12,12 @@ import functools
 import itertools
 import json
 import os
+import pickle
+from typing import Any

 import numpy as np

-from ..core._imperative_rt import GraphProfiler
+from ..core._imperative_rt import GraphProfiler, SerializationMetadata
 from ..core._imperative_rt.core2 import Tensor as RawTensor
 from ..core._imperative_rt.core2 import (
     TensorWeakRef,

@@ -670,6 +672,8 @@ class trace:
         strip_info_file=None,
         append_json=False,
         optimize_for_inference=True,
+        user_info: Any = None,
+        enable_metadata: bool = True,
         **kwargs
     ):
         r"""
@@ -697,6 +701,8 @@ class trace:
             if set false, will rewrite strip_info_file
         :param optimize_for_inference: enable optimizations,
             will skip all optimize options if this is False. Default: True
+        :param user_info: any type object, which will be pickled to bytes.
+        :param enable_metadata: whether to save metadata into output file.

         :Keyword Arguments:

@@ -729,6 +735,9 @@ class trace:
             * enable_chwn4 --
                 whether to use CHWN4 data layout, currently
                 used in nvidia backend with tensorcore.
+            * enable_nchw64 --
+                whether to use NCHW64 data layout, used for fast int4
+                support on Nvidia GPU.
             * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
                 into one opr.
@@ -851,7 +860,15 @@ class trace:
             dest_vars.append(v)

         if optimize_for_inference:
-            dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
+            dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs)
+
+        metadata = SerializationMetadata()
+        if enable_metadata:
+            metadata.user_info = pickle.dumps(user_info)
+            metadata.is_valid = True
+            metadata.graph_modified = False
+            if optimize_for_inference:
+                metadata.optimize_options = optimize_options

         if isinstance(file, str):
             permission = "wb" if append == False else "ab"

@@ -864,6 +881,7 @@ class trace:
             keep_opr_priority=keep_opr_priority,
             strip_info_file=strip_info_file,
             append_json=append_json,
+            metadata=metadata,
         )
         file.write(dump_content)
         return dump_info
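End to end, `trace.dump` can now attach arbitrary picklable user data to a serialized model. A short sketch in the spirit of the `test_metadata` test added later in this patch:

```python
import io

from megengine import Tensor
from megengine.jit import trace

@trace(symbolic=True, capture_as_const=True)
def fwd(x):
    return x * 2

fwd(Tensor(0))
model = io.BytesIO()
# user_info is pickled into SerializationMetadata; passing
# enable_metadata=False would skip the metadata record entirely.
fwd.dump(model, user_info={"version": 1}, optimize_for_inference=False)
```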
@@ -411,7 +411,8 @@ def main():
         args.embed_input = True

     logger.info("loading model ...")
-    graph, _, output_vars = G.load_graph(args.net)
+    ret = G.load_graph(args.net)
+    graph, output_vars = ret.graph, ret.output_vars_list
     input_vars = tools.get_dep_vars(output_vars, "Host2DeviceCopy")

     if args.output_name is not None:

@@ -391,7 +391,8 @@ class GraphInference:
         optimize_for_inference: bool = False,
         **kwargs
     ):
-        self._graph, _, output_nodes = G.load_graph(file)
+        ret = G.load_graph(file)
+        self._graph, output_nodes = ret.graph, ret.output_vars_list
         if outputs is not None:
             output_nodes = find_vars_by_name(output_nodes, outputs)
         self._origin_outputs = output_nodes
@@ -9,14 +9,12 @@
 import collections
 import fnmatch
 import itertools
+import pickle
 import re
 from collections import OrderedDict
-from typing import Dict, List, Sequence
+from typing import Any, Dict, List, Sequence

 import numpy as np

-from ..core._imperative_rt import ComputingGraph
-from ..core._imperative_rt.core2 import SymbolVar
+from ..core._imperative_rt import ComputingGraph, SerializationMetadata
 from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape
 from ..core.tensor import megbrain_graph as G
 from ..logger import get_logger
@@ -42,6 +40,30 @@ class Network:
         self.all_oprs_map = OrderedDict()
         self.all_vars_map = OrderedDict()
         self.graph = ComputingGraph()
+        self._metadata = None
+
+    @property
+    def metadata(self):
+        r"""
+        Load metadata as a dict.
+        """
+        if not self._metadata.is_valid:
+            logger.info("metadata is not valid!")
+            return None
+        ret = dict()
+        try:
+            user_info = pickle.loads(self._metadata.user_info)
+        except:  # pylint: disable=bare-except
+            logger.warning(
+                "can't parse user info by pickle, so return the original bytes object!"
+            )
+            user_info = self._metadata.user_info
+        ret["user_info"] = user_info
+        ret["graph_modified"] = self._metadata.graph_modified
+        ret["optimized_for_inference"] = self._metadata.optimized_for_inference
+        if ret["optimized_for_inference"]:
+            ret.update(G.deserialize_infer_option(self._metadata.optimize_options))
+        return ret

     @classmethod
     def load(cls, model_path: str, outspec: List[str] = None):

@@ -51,7 +73,8 @@
         :param outspec: only load the subgraph with outspec as its endpoints.
         """
         self = cls()
-        _, _, outputs = G.load_graph(model_path)
+        ret = G.load_graph(model_path)
+        outputs, self._metadata = ret.output_vars_list, ret.metadata
        if outspec is not None:
            output_spec = outspec.copy()
        all_vars = get_dep_vars(outputs) + outputs
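Continuing the `trace.dump` sketch from earlier, `Network.load` exposes the parsed record through the new `metadata` property; `graph_modified` is False because the model came from `tracing.dump` rather than `Network.dump`:

```python
from megengine.utils.network import Network as Net

model.seek(0)  # the BytesIO written by the earlier fwd.dump sketch
net = Net.load(model)
assert net.metadata == {
    "user_info": {"version": 1},
    "graph_modified": False,  # trace.dump leaves the graph untouched
    "optimized_for_inference": False,
}
```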
@@ -125,6 +148,9 @@ class Network:
            * enable_chwn4 --
                whether to use CHWN4 data layout, currently
                used in nvidia backend with tensorcore.
+           * enable_nchw64 --
+               whether to use NCHW64 data layout, used for fast int4
+               support on Nvidia GPU.
            * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
                into one opr.

@@ -152,6 +178,8 @@ class Network:
         append_json=False,
         optimize_for_inference=True,
         append=False,
+        user_info: Any = None,
+        enable_metadata=True,
         **kwargs
     ):
         """

@@ -176,6 +204,8 @@ class Network:
             if set false, will rewrite strip_info_file
         :param optimize_for_inference: enable optimizations,
             will skip all optimize options if this is False. Default: True
+        :param user_info: any type object, which will be pickled to bytes.
+        :param enable_metadata: whether to save metadata into output file.

         :Keyword Arguments:
@@ -201,7 +231,15 @@ class Network:
         )

         if optimize_for_inference:
-            out = G.optimize_for_inference(out, **kwargs)
+            out, optimize_options = G.optimize_for_inference(out, **kwargs)
+
+        metadata = SerializationMetadata()
+        if enable_metadata:
+            metadata.is_valid = True
+            metadata.graph_modified = True
+            metadata.user_info = pickle.dumps(user_info)
+            if optimize_for_inference:
+                metadata.optimize_options = optimize_options

         dump_content, _ = G.dump_graph(
             out,

@@ -211,6 +249,7 @@ class Network:
             keep_opr_priority=keep_opr_priority,
             strip_info_file=strip_info_file,
             append_json=append_json,
+            metadata=metadata,
         )
         if isinstance(file, str):
             permission = "wb" if append == False else "ab"
@@ -34,6 +34,7 @@ namespace ser = mgb::serialization;
 using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
 using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
 using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
+using _SerializationMetadata = mgb::serialization::Metadata;

 namespace {
 class _CompGraphProfilerImpl {

@@ -240,6 +241,8 @@ void init_graph_rt(py::module m) {
     auto GraphOptimizeOptions = py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions")
         .def(py::init())
+        .def("serialize", &_OptimizeForInferenceOptions::serialize)
+        .def_static("deserialize", &_OptimizeForInferenceOptions::deserialize)
         .def_readwrite("f16_io_f32_comp", &_OptimizeForInferenceOptions::f16_io_f32_comp)
         .def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
         .def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)

@@ -256,6 +259,7 @@ void init_graph_rt(py::module m) {
         .value("NCHW44_DOT", _LayoutTransform::NCHW44_DOT)
         .value("NCHW32", _LayoutTransform::NCHW32)
         .value("CHWN4", _LayoutTransform::CHWN4)
+        .value("NCHW64", _LayoutTransform::NCHW64)
         .export_values()
         ;
@@ -307,12 +311,24 @@ void init_graph_rt(py::module m) {
         })->to_string();
     });

+    py::class_<_SerializationMetadata>(m, "SerializationMetadata")
+        .def(py::init())
+        .def_property("user_info",
+                      [](const _SerializationMetadata& meta) { return py::bytes(meta.get_user_info()); },
+                      &_SerializationMetadata::set_user_info)
+        .def_readonly("optimized_for_inference", &_SerializationMetadata::optimized_for_inference)
+        .def_property("optimize_options", &_SerializationMetadata::get_optimize_options,
+                      &_SerializationMetadata::set_optimize_options)
+        .def_readwrite("graph_modified", &_SerializationMetadata::graph_modified)
+        .def_readwrite("is_valid", &_SerializationMetadata::is_valid)
+        ;
+
     m.def("dump_graph", [](
             const std::vector<VarNode*>& dest_vars,
             int keep_var_name,
             bool keep_opr_name,
             bool keep_param_name,
             bool keep_opr_priority,
+            std::optional<_SerializationMetadata> metadata,
             py::list& stat,
             py::list& inputs,
             py::list& outputs,

@@ -325,7 +341,12 @@ void init_graph_rt(py::module m) {
         ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
                                             keep_opr_priority, keep_opr_name};
-        auto rst = dumper->dump(symvars, config);
+        ser::GraphDumper::DumpResult rst;
+        if (metadata)
+            rst = dumper->dump(symvars, config, *metadata);
+        else
+            rst = dumper->dump(symvars, config);
         for (auto i : rst.inputs) {
             inputs.append(py::cast(i));
         }

@@ -377,8 +398,10 @@ void init_graph_rt(py::module m) {
         for (const auto& var : rst.output_var_list) {
             iter.add(var);
         }
-        return rst.graph;
+        auto ret = py::tuple(2);
+        ret[0] = py::cast(rst.graph);
+        ret[1] = py::cast(rst.metadata);
+        return ret;
     });

 #define CURRENT_CLASS cg::ComputingGraph::Options
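From Python, the new `SerializationMetadata` binding behaves like a plain record: `user_info` round-trips as `bytes`, and the read-only `optimized_for_inference` flag flips only through the `optimize_options` setter. A minimal sketch (the import path is an assumption about how the module is exposed):

```python
from megengine.core._imperative_rt import SerializationMetadata

meta = SerializationMetadata()
meta.user_info = b"hello"  # stored via set_user_info, read back as py::bytes
meta.is_valid = True
assert meta.optimized_for_inference is False
meta.optimize_options = 0  # the setter also marks optimized_for_inference
assert meta.optimized_for_inference is True
```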
@@ -239,8 +239,7 @@ def test_dump_volatile():
     file = io.BytesIO()
     f.dump(file, optimize_for_inference=False)
     file.seek(0)
-    cg, _, outputs = G.load_graph(file)
-    (out,) = outputs
+    (out,) = G.load_graph(file).output_vars_list
     assert (
         cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
         == "ImmutableTensor"

@@ -337,12 +336,12 @@ def test_goptions_log_exp():
     f(tensor(1.0))
     _, out = mkstemp()
     f.dump(out, optimize_for_inference=False)
-    *_, outputs = G.load_graph(out)
+    outputs = G.load_graph(out).output_vars_list
     oprs_1 = cgtools.get_oprs_seq(outputs)

     g(tensor(1.0))
     g.dump(out, optimize_for_inference=False)
-    *_, outputs = G.load_graph(out)
+    outputs = G.load_graph(out).output_vars_list
     oprs_2 = cgtools.get_oprs_seq(outputs)

     assert len(oprs_1) - len(oprs_2) == 2

@@ -88,7 +88,7 @@ def test_graph_traversal():
     file = io.BytesIO()
     fun.dump(file, optimize_for_inference=False)
     file.seek(0)
-    cg, _, outputs = mgb_graph.load_graph(file)
+    outputs = mgb_graph.load_graph(file).output_vars_list
     _, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)
     input_var = map_vars[1]

@@ -101,7 +101,9 @@ def test_load_refcnt():
     graph = mgb_graph.Graph()
     varnode = graph.make_const(0)
     buf, _ = mgb_graph.dump_graph([varnode])
-    graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf))
+    ret = mgb_graph.load_graph(io.BytesIO(buf))
+    graph, (varnode,) = ret.graph, ret.output_vars_list
+    del ret
     del graph
     varnode.owner

@@ -132,7 +134,7 @@ def test_get_opr_seq():
     file = io.BytesIO()
     func.dump(file, optimize_for_inference=False)
     file.seek(0)
-    *_, outputs = mgb_graph.load_graph(file)
+    outputs = mgb_graph.load_graph(file).output_vars_list
     seq_1 = cgtools.get_oprs_seq(outputs, True)
     assert len(seq_1) == 5

@@ -35,7 +35,7 @@ def _dump_and_load(func, symbolic, keep_opr_name=True):
         keep_var_name=2,
     )
     file.seek(0)
-    *_, outputs = G.load_graph(file)
+    outputs = G.load_graph(file).output_vars_list
     ops = cgtools.get_oprs_seq(outputs)
     return ops

@@ -223,7 +223,7 @@ def test_catch_input_name(tensor_name, var_name):
     file = io.BytesIO()
     func.dump(file, optimize_for_inference=False, keep_opr_name=True, keep_var_name=2)
     file.seek(0)
-    *_, outputs = G.load_graph(file)
+    outputs = G.load_graph(file).output_vars_list
     op = cgtools.get_oprs_seq(outputs)[-1]
     assert op.inputs[0].name == var_name
@@ -14,6 +14,50 @@ from megengine.utils.network import as_oprnode, set_symbolic_shape
 from megengine.utils.network_node import Host2DeviceCopy, VarNode


+def test_metadata():
+    x = Tensor(0)
+
+    @trace(symbolic=True, capture_as_const=True)
+    def fwd(x):
+        return x * 2
+
+    fwd(x)
+    orig_model = io.BytesIO()
+    fwd.dump(orig_model, user_info="test", optimize_for_inference=False)
+    orig_model.seek(0)
+    graph = Net.load(orig_model)
+    assert graph.metadata == {
+        "user_info": "test",
+        "graph_modified": False,  # False: tracing.dump
+        "optimized_for_inference": False,
+    }
+
+    orig_model.seek(0)
+    graph.dump(
+        orig_model,
+        user_info={"str": "x", "tensor": x, "module": M.Module, "none": None},
+        optimize_for_inference=True,
+        enable_nchw4=True,
+        enable_ioc16=True,
+    )
+    orig_model.seek(0)
+    graph = Net.load(orig_model)
+    assert graph.metadata == {
+        "user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None},
+        "graph_modified": True,  # True: Network.dump
+        "optimized_for_inference": True,
+        "enable_nchw4": True,
+        "enable_ioc16": True,
+    }
+
+    orig_model.seek(0)
+    fwd.dump(orig_model, enable_metadata=False)
+    orig_model.seek(0)
+    graph = Net.load(orig_model)
+    assert graph.metadata is None
+
+
 def test_replace_var():
     a = Tensor([1, 2])

@@ -170,7 +170,8 @@ def gen_one_testcase(args, inputs, spec):

 def make_feeds(args):
-    cg_rt, _, outputs = G.load_graph(args.input)
+    ret = G.load_graph(args.input)
+    cg_rt, outputs = ret.graph, ret.output_vars_list
     inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
     inputs = {i.name: i for i in inputs}
@@ -322,7 +322,31 @@ namespace gopt {
         static std::unique_ptr<EnableNchw44DotPass> make_nchw44_dot_converter();
     };

-    struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {};
+    struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {
+        uint64_t serialize() {
+            uint64_t ret = 0;
+            ret |= (uint64_t)layout_transform << 32;
+            if (f16_io_f32_comp) ret |= 1u;
+            if (f16_io_comp) ret |= 1u << 1;
+            if (fuse_conv_bias_nonlinearity) ret |= 1u << 2;
+            if (fuse_conv_bias_with_z) ret |= 1u << 3;
+            if (weight_preprocess) ret |= 1u << 4;
+            if (fuse_preprocess) ret |= 1u << 5;
+            return ret;
+        }
+
+        static OptimizeForInferenceOptions deserialize(uint64_t buf) {
+            OptimizeForInferenceOptions ret;
+            ret.f16_io_f32_comp = buf & 1u;
+            ret.f16_io_comp = buf & 1u << 1;
+            ret.fuse_conv_bias_nonlinearity = buf & 1u << 2;
+            ret.fuse_conv_bias_with_z = buf & 1u << 3;
+            ret.weight_preprocess = buf & 1u << 4;
+            ret.fuse_preprocess = buf & 1u << 5;
+            ret.layout_transform = (LayoutTransform)(buf >> 32);
+            return ret;
+        }
+    };
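The packing scheme above is straightforward: the six boolean passes occupy bits 0-5 and the `LayoutTransform` enum value sits in the upper 32 bits. A Python sketch of the same unpacking, handy for inspecting a serialized value by hand (field names follow the C++ struct; concrete enum values depend on the C++ enum ordering):

```python
# Mirrors OptimizeForInferenceOptions::deserialize() above.
FLAGS = [
    "f16_io_f32_comp",              # bit 0
    "f16_io_comp",                  # bit 1
    "fuse_conv_bias_nonlinearity",  # bit 2
    "fuse_conv_bias_with_z",        # bit 3
    "weight_preprocess",            # bit 4
    "fuse_preprocess",              # bit 5
]

def unpack(buf: int) -> dict:
    opts = {name: bool(buf & (1 << bit)) for bit, name in enumerate(FLAGS)}
    opts["layout_transform"] = buf >> 32  # raw LayoutTransform enum value
    return opts

layout = 7  # hypothetical enum value; real ones come from the C++ enum
print(unpack((layout << 32) | (1 << 2)))
# -> fuse_conv_bias_nonlinearity True, all other flags False, layout_transform 7
```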
 /*!
  * \brief optimize a computing graph for inference

@@ -128,6 +128,13 @@ table Operator {
     name:string;
 }

+table Metadata {
+    is_valid:bool;
+    graph_modified:bool;
+    user_info:string;
+    optimize_options:ulong;
+}
+
 struct OutputVar {
     compact_id:uint;
     original_id:uint;

@@ -141,6 +148,7 @@ table Graph {
     nr_shared_tensor:uint;
     oprs:[Operator];
     output_vars_idx:[OutputVar];
+    metadata:Metadata;
 }

 root_type Graph;
@@ -30,6 +30,7 @@
 #include "megbrain/serialization/internal/flatbuffers_helper.h"
 #include "megbrain/serialization/internal/schema_generated.h"
 #include "megbrain/serialization/opr_load_dump.h"
+#include "megbrain/serialization/metadata.h"
 #include "megbrain/serialization/serializer.h"
 #include "megbrain/version.h"

@@ -115,6 +116,7 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
     std::vector<flatbuffers::Offset<void>> m_cur_opr_param;

     void init_oprs_to_dump(const SymbolVarArray& endpoints);
+    flatbuffers::Offset<fbs::Metadata> build_metadata(const Metadata& metadata);
     flatbuffers::Offset<fbs::Operator> build_single_opr(
             cg::OperatorNodeBase* opr, const OprRegistry* registry);

@@ -123,7 +125,8 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
 public:
     GraphDumperOSS(std::unique_ptr<OutputFile> file) : m_file{std::move(file)} {}
     DumpResult dump(const SymbolVarArray& output_vars,
-                    const DumpConfig& config = {}) override;
+                    const DumpConfig& config = {},
+                    const Metadata& metadata = {}) override;
     const GraphDumpConfig& config() const override { return m_config; }
     void dump_tensor(const std::string& name, const HostTensorND& tensor,
                      TensorWriteMethod method) override;

@@ -185,6 +188,17 @@ void GraphDumperOSS::init_oprs_to_dump(const SymbolVarArray& endpoints) {
     }
 }

+flatbuffers::Offset<fbs::Metadata> GraphDumperOSS::build_metadata(
+        const Metadata& metadata) {
+    auto user_info = m_builder.CreateSharedString(metadata.user_info);
+    fbs::MetadataBuilder builder(m_builder);
+    builder.add_is_valid(metadata.is_valid);
+    builder.add_graph_modified(metadata.graph_modified);
+    builder.add_user_info(user_info);
+    builder.add_optimize_options(metadata.optimize_options);
+    return builder.Finish();
+}
+
 flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
         cg::OperatorNodeBase* opr, const OprRegistry* registry) {
     m_cur_opr = opr;

@@ -282,7 +296,8 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
 }

 GraphDumper::DumpResult GraphDumperOSS::dump(
-        const SymbolVarArray& output_vars, const DumpConfig& config) {
+        const SymbolVarArray& output_vars,
+        const DumpConfig& config, const Metadata& metadata) {
     mgb_throw_if(output_vars.empty(), SerializationError,
                  "Can't dump empty graph");

@@ -323,6 +338,9 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
     uint64_t offset_to_fbs = 0;
     m_file->write(&offset_to_fbs, sizeof(offset_to_fbs));

+    // Dump metadata
+    auto fbmeta = build_metadata(metadata);
+
     // Dump operators
     init_oprs_to_dump(output_vars);
     std::vector<flatbuffers::Offset<fbs::Operator>> oprs;

@@ -350,6 +368,7 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
     graph.add_oprs(fb_oprs);
     graph.add_output_vars_idx(fb_output_vars);
     graph.add_nr_shared_tensor(m_nr_shared_tensor);
+    graph.add_metadata(fbmeta);
     m_builder.FinishSizePrefixed(graph.Finish(), fbs::GraphIdentifier());

     // Write actual offset_to_fbs

@@ -531,6 +550,7 @@ public:
         mgb_assert(nr == 1);
     }

+    Metadata load_metadata();
     LoadResult load_oprs();
     CompNode load_comp_node(const fbs::CompNode* comp_node);

@@ -700,6 +720,22 @@ GraphLoaderOSS::OprLoadContextImpl::load_tensor_shared() {
     return sh_ptr_ref;
 }

+Metadata GraphLoaderOSS::OprLoadContextImpl::load_metadata() {
+    const auto* fbmeta = m_loader->m_graph->metadata();
+    Metadata ret;
+    ret.is_valid = fbmeta->is_valid();
+    ret.graph_modified = fbmeta->graph_modified();
+    if (fbmeta->user_info()) {
+        ret.user_info = fbmeta->user_info()->str();
+        ret.has_user_info = true;
+    }
+    if (fbmeta->optimize_options()) {
+        ret.optimize_options = fbmeta->optimize_options();
+        ret.optimized_for_inference = true;
+    }
+    return ret;
+}
+
 void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
         const fbs::Operator* fbopr) {
     m_cur_opr_tensor_cnt = 0;

@@ -872,7 +908,9 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config,
     }

     OprLoadContextImpl ctx{this, m_graph->mgb_version()};
+    auto metadata = ctx.load_metadata();
     auto result = ctx.load_oprs();
+    result.metadata = metadata;

     auto fbs_end = tensor_begin + offset_to_fbs + sizeof(size) + size;
     auto cur = m_file->tell();
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * \file src/serialization/include/megbrain/serialization/metadata.h | |||
| * | |||
| * This file is part of MegBrain, a deep learning framework developed by Megvii. | |||
| * | |||
| * \brief MegEngine model's metadata | |||
| * | |||
| * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| */ | |||
| #pragma once | |||
| #include <string> | |||
| namespace mgb { | |||
| namespace serialization { | |||
| struct Metadata { | |||
| bool is_valid = false; | |||
| bool graph_modified = false; | |||
| bool has_user_info = false; | |||
| std::string user_info; | |||
| bool optimized_for_inference = false; | |||
| uint64_t optimize_options; | |||
| #define ADD_PROPERTY(type, name) \ | |||
| type get_##name() const { return name; } \ | |||
| void set_##name(type x) { \ | |||
| name = x; \ | |||
| has_##name = true; \ | |||
| } | |||
| ADD_PROPERTY(std::string, user_info) | |||
| #undef ADD_PROPERTY | |||
| uint64_t get_optimize_options() { return optimize_options; } | |||
| void set_optimize_options(uint64_t value) { | |||
| optimized_for_inference = true; | |||
| optimize_options = value; | |||
| } | |||
| }; | |||
| } // namespace serialization | |||
| } // namespace mgb | |||
@@ -15,6 +15,7 @@
 #include "megbrain/serialization/dump_format.h"
 #include "megbrain/serialization/file.h"
 #include "megbrain/serialization/load_dump_config.h"
+#include "megbrain/serialization/metadata.h"

 namespace mgb {
 namespace serialization {

@@ -32,6 +33,9 @@ namespace serialization {
         //! explicit dtor decl to reduce binary size
         ~LoadResult() noexcept;

+        //! metadata
+        Metadata metadata;
+
         using TensorMap = std::unordered_map<
                 std::string, std::shared_ptr<HostTensorND>>;

@@ -178,7 +182,8 @@ namespace serialization {
         virtual DumpResult dump(
                 const SymbolVarArray &output_vars,
-                const DumpConfig &config = {}) = 0;
+                const DumpConfig &config = {},
+                const Metadata &metadata = {}) = 0;

         virtual GraphDumpFormat format() const = 0;
     };
@@ -92,6 +92,43 @@ TEST(TestSerializer2, MultiGraphDumpLoad) {
     load();
 }

+TEST(TestSerializer2, Metadata) {
+    auto fname = GET_OUTPUT_FILE();
+    TensorShape shape{2, 3};
+
+    auto dump = [&]() {
+        auto cn = CompNode::load("xpu0");
+        auto host_x = std::make_shared<HostTensorND>(cn, shape),
+             host_y = std::make_shared<HostTensorND>(cn, shape);
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"}),
+             y = opr::Host2DeviceCopy::make(*graph, host_y, {"y"});
+        using Mode = opr::Elemwise::Mode;
+        auto z = opr::Elemwise::make({x, y}, Mode::ADD, {"add(x, y)"});
+
+        Metadata metadata;
+        metadata.user_info = "TEST_METADATA";
+        metadata.has_user_info = true;
+
+        auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()),
+                                        GraphDumpFormat::FLATBUFFERS);
+        auto rst = dumper->dump({z.rename("z")}, {}, metadata);
+    };
+
+    auto load = [&]() {
+        HostTensorGenerator<> gen;
+        auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()),
+                                        GraphDumpFormat::FLATBUFFERS);
+        auto rst = loader->load();
+        auto metadata = rst.metadata;
+        int cmp = strcmp(metadata.user_info.c_str(), "TEST_METADATA");
+        EXPECT_EQ(cmp, 0);
+    };
+
+    dump();
+    load();
+}
+
 TEST(TestSerializer2, APlusB) {
     auto fname = GET_OUTPUT_FILE();
     TensorShape shape{2, 3};