| @@ -11,13 +11,12 @@ import json | |||
| import os | |||
| import weakref | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from typing import Dict, List, Optional, Tuple, Union | |||
| from typing import Dict, List, Tuple, Union | |||
| import numpy as np | |||
| from .. import _imperative_rt | |||
| from .._imperative_rt import GraphOptimizeOptions | |||
| from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode | |||
| from .._imperative_rt import GraphOptimizeOptions, SerializationFormat | |||
| from .._wrap import as_device | |||
| from ..ops.builtin import OpDef | |||
| @@ -377,7 +376,8 @@ def dump_graph( | |||
| keep_opr_priority: bool = False, | |||
| strip_info_file=None, | |||
| append_json=False, | |||
| metadata=None | |||
| metadata=None, | |||
| dump_format=None | |||
| ) -> Tuple[bytes, CompGraphDumpResult]: | |||
| r"""serialize the computing graph of `output_vars` and get byte result. | |||
| @@ -398,6 +398,7 @@ def dump_graph( | |||
| append_json: will be check when `strip_info_file` is not None. if set | |||
| true, the information for code strip will be append to strip_info_file. | |||
| if set false, will rewrite strip_info_file | |||
| dump_format: using different dump formats. | |||
| Note: | |||
| The underlying C++ API only accepts a var list. If a dict is given, | |||
| @@ -434,6 +435,12 @@ def dump_graph( | |||
| outputs = [] | |||
| params = [] | |||
| dump_format_map = { | |||
| None: None, | |||
| "FBS": SerializationFormat.FBS, | |||
| } | |||
| dump_format = dump_format_map[dump_format] | |||
| dump_content = _imperative_rt.dump_graph( | |||
| ov, | |||
| keep_var_name, | |||
| @@ -441,6 +448,7 @@ def dump_graph( | |||
| keep_param_name, | |||
| keep_opr_priority, | |||
| metadata, | |||
| dump_format, | |||
| stat, | |||
| inputs, | |||
| outputs, | |||
| @@ -1008,6 +1008,7 @@ class trace: | |||
| maxerr=1e-4, | |||
| resize_input=False, | |||
| input_transform=None, | |||
| dump_format: str = None, | |||
| **kwargs | |||
| ): | |||
| r"""Serializes trace to file system. | |||
| @@ -1059,6 +1060,7 @@ class trace: | |||
| resize_input: whether resize input image to fit input var shape. | |||
| input_transform: a python expression to transform the input data. | |||
| Example: data / np.std(data) | |||
| dump_format: using different dump formats. | |||
| Keyword Arguments: | |||
| @@ -1265,6 +1267,7 @@ class trace: | |||
| strip_info_file=strip_info_file, | |||
| append_json=append_json, | |||
| metadata=metadata, | |||
| dump_format=dump_format, | |||
| ) | |||
| file.write(dump_content) | |||
| @@ -35,6 +35,7 @@ using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions; | |||
| using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform; | |||
| using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||
| using _SerializationMetadata = mgb::serialization::Metadata; | |||
| using _SerializationFormat = mgb::serialization::GraphDumpFormat; | |||
| namespace { | |||
| class _CompGraphProfilerImpl { | |||
| @@ -310,6 +311,10 @@ void init_graph_rt(py::module m) { | |||
| .value("NCHW64", _LayoutTransform::NCHW64) | |||
| .export_values(); | |||
| py::enum_<_SerializationFormat>(m, "SerializationFormat") | |||
| .value("FBS", _SerializationFormat::FLATBUFFERS) | |||
| .export_values(); | |||
| m.def("optimize_for_inference", | |||
| [](const VarNodeArray& dest_vars, const _OptimizeForInferenceOptions& opt) { | |||
| SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); | |||
| @@ -380,11 +385,18 @@ void init_graph_rt(py::module m) { | |||
| m.def("dump_graph", | |||
| [](const std::vector<VarNode*>& dest_vars, int keep_var_name, | |||
| bool keep_opr_name, bool keep_param_name, bool keep_opr_priority, | |||
| std::optional<_SerializationMetadata> metadata, py::list& stat, | |||
| std::optional<_SerializationMetadata> metadata, | |||
| std::optional<_SerializationFormat> dump_format, py::list& stat, | |||
| py::list& inputs, py::list& outputs, py::list& params) { | |||
| std::vector<uint8_t> buf; | |||
| auto dumper = | |||
| ser::GraphDumper::make(ser::OutputFile::make_vector_proxy(&buf)); | |||
| ser::GraphDumpFormat format; | |||
| if (dump_format.has_value()) { | |||
| format = dump_format.value(); | |||
| } else { | |||
| format = {}; | |||
| } | |||
| auto dumper = ser::GraphDumper::make( | |||
| ser::OutputFile::make_vector_proxy(&buf), format); | |||
| SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); | |||
| ser::GraphDumper::DumpConfig config{ | |||
| @@ -190,7 +190,13 @@ def test_print_in_trace(): | |||
| np.testing.assert_equal(z, buf) | |||
| def test_dump(): | |||
| @pytest.mark.parametrize( | |||
| "dump_format", | |||
| [ | |||
| "FBS", | |||
| ], | |||
| ) | |||
| def test_dump(dump_format): | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def f(a, b): | |||
| return a + b | |||
| @@ -205,7 +211,7 @@ def test_dump(): | |||
| np.testing.assert_equal(f(a, b).numpy(), y) | |||
| file = io.BytesIO() | |||
| dump_info = f.dump(file) | |||
| dump_info = f.dump(file, dump_format=dump_format) | |||
| assert dump_info.nr_opr == 3 | |||
| np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) | |||
| np.testing.assert_equal(dump_info.outputs, ["ADD"]) | |||