@@ -14,6 +14,7 @@ from concurrent.futures import Future, ThreadPoolExecutor
 import numpy as np

 from .. import _imperative_rt
+from .._imperative_rt import GraphOptimizeOptions
 from .._imperative_rt.ops import BackwardGraph
 from .._wrap import device as as_device
 from ..ops.builtin import OpDef
@@ -83,6 +84,84 @@ class Graph(_imperative_rt.ComputingGraph):
         return self._wrap(_imperative_rt.make_h2d(self, device, dtype, shape, name))

+
+def optimize_for_inference(dest_vars, **kwargs):
+    r"""Applies the optimize_for_inference pass to a computing graph.
+
+    :param dest_vars: list of output vars in the computing graph.
+
+    :Keyword Arguments:
+
+        * enable_io16xc32 --
+            whether to use float16 for I/O between oprs and float32 as
+            the internal computation precision. Note that the output var
+            dtype would be changed to float16.
+        * enable_ioc16 --
+            whether to use float16 for both I/O and computation
+            precision.
+        * enable_hwcd4 --
+            whether to use the NHWCD4 data layout, which is faster on
+            some OpenCL backends.
+        * enable_nchw88 --
+            whether to use the NCHW88 data layout, currently used on
+            the x86 AVX backend.
+        * enable_nchw44 --
+            whether to use the NCHW44 data layout, currently used on
+            ARM backends.
+        * enable_nchw44_dot --
+            whether to use the NCHW44_DOT data layout, currently used
+            on ARMv8.2+dotprod backends.
+        * enable_nchw4 --
+            whether to use the NCHW4 data layout, currently used on
+            NVIDIA backends (based on cuDNN).
+        * enable_nchw32 --
+            whether to use the NCHW32 data layout, currently used on
+            NVIDIA backends with TensorCore (based on cuDNN).
+        * enable_chwn4 --
+            whether to use the CHWN4 data layout, currently used on
+            NVIDIA backends with TensorCore.
+        * enable_fuse_conv_bias_nonlinearity --
+            whether to fuse conv+bias+nonlinearity into one opr.
+        * enable_fuse_conv_bias_with_z --
+            whether to fuse conv_bias with the z input for inference on
+            NVIDIA backends (this optimization pass will result in a
+            precision mismatch between the outputs of training and
+            inference).
+    """
+    inference_options = GraphOptimizeOptions()
+    inference_optimize_layout_transform_map = {
+        "enable_hwcd4": GraphOptimizeOptions.LayoutTransform.NHWCD4,
+        "enable_nchw4": GraphOptimizeOptions.LayoutTransform.NCHW4,
+        "enable_nchw88": GraphOptimizeOptions.LayoutTransform.NCHW88,
+        "enable_nchw32": GraphOptimizeOptions.LayoutTransform.NCHW32,
+        "enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44,
+        "enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT,
+        "enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4,
+    }
+    for k, v in inference_optimize_layout_transform_map.items():
+        if kwargs.pop(k, False):
+            inference_options.layout_transform = v
+
+    if kwargs.pop("enable_io16xc32", False):
+        inference_options.f16_io_f32_comp = True
+    if kwargs.pop("enable_ioc16", False):
+        inference_options.f16_io_comp = True
+    if kwargs.pop("enable_fuse_conv_bias_nonlinearity", False):
+        inference_options.fuse_conv_bias_nonlinearity = True
+    if kwargs.pop("enable_fuse_conv_bias_with_z", False):
+        inference_options.fuse_conv_bias_with_z = True
+
+    if kwargs:
+        raise ValueError("unknown options: %s" % list(kwargs))
+
+    res_vars = _imperative_rt.optimize_for_inference(
+        [i._node for i in dest_vars], inference_options
+    )
+    return [VarNode(i) for i in res_vars]
+
+
 def dump(*args):
     return _imperative_rt.dump_graph([i._node for i in args])
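A hedged sketch of how the new pass is meant to be driven; `dest_vars` stands in for output `VarNode`s you already hold, and the import path follows the test file further down in this diff:

```python
# Sketch only: assumes `dest_vars` holds the output VarNodes of a built graph.
from megengine.core.tensor import megbrain_graph as G

# Request float16 I/O with float32 accumulation, plus conv+bias+activation fusion.
optimized = G.optimize_for_inference(
    dest_vars,
    enable_io16xc32=True,
    enable_fuse_conv_bias_nonlinearity=True,
)

# Unknown keywords are rejected, since the implementation pops known keys
# and raises ValueError on anything left over:
# G.optimize_for_inference(dest_vars, enable_foo=True)  # ValueError
```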
@@ -11,6 +11,7 @@ import numpy as np
 from ..core._imperative_rt import GraphProfiler
 from ..core._imperative_rt.ops import OprAttr
+from ..core._trace_option import set_tensor_shape
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
 from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
@@ -76,6 +77,22 @@ class TensorInfo:
 class trace:
+    """
+    Wraps a callable and provides:
+
+    * tracing via :meth:`.trace` and :meth:`.dump`
+    * accelerated evaluation via :meth:`.__call__`
+
+    :param function: the function to be traced.
+    :param symbolic: whether to apply symbolic execution for tracing. Default: False
+    :param capture_as_const: whether to capture global vars or closures as const values. Default: False
+    :param sublinear_memory_config: configuration for sublinear memory optimization.
+        If not None, it enables sublinear memory optimization with the given setting.
+    :param profiling: whether to profile the compiled trace. Default: False
+    :param opt_level: optimization level for compiling the trace.
+    :param tensor_shape: whether to use static tensor shapes during tracing. Default: True
+    """
+
     def __new__(cls, *args, **kwargs):
         if not args:
             return functools.partial(cls, **kwargs)
@@ -88,6 +105,8 @@ class trace:
         capture_as_const=False,
         sublinear_memory_config: SublinearMemoryConfig = None,
         profiling: bool = False,
+        opt_level: int = None,
+        tensor_shape: bool = True,
     ):
         self.__wrapped__ = function
         self._symbolic = symbolic
@@ -95,6 +114,8 @@ class trace:
         self._sublinear_memory_config = sublinear_memory_config
         self._profiling = profiling
         self._profiler = None
+        self._graph_opt_level = opt_level
+        self._tensor_shape = tensor_shape

         self._untraced = True
         self._tinfo = []  # handle -> TensorInfo
@@ -112,6 +133,8 @@ class trace:
         self._output_bindings = None
         self._output_names = None

+        set_tensor_shape(self._tensor_shape)
+
     def _new_handle(self):
         handle = len(self._tinfo)
         info = TensorInfo()
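The two new constructor arguments feed `set_tensor_shape` here and `_apply_graph_options` further down; a minimal sketch of a caller setting them (the function body is invented):

```python
from megengine import tensor
from megengine.jit import trace

# opt_level=0 keeps more oprs intact (the x / x test below relies on this);
# tensor_shape=True makes the trace use static tensor shapes.
@trace(symbolic=True, opt_level=0, tensor_shape=True)
def step(x):
    return x * x + 1

step(tensor(2.0))  # first call traces; subsequent calls replay the graph
```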
@@ -307,6 +330,9 @@ class trace:
     def _apply_graph_options(self, graph):

         graph.options.seq_opt.enable_seq_comp_node_opt = False
+        # graph opt level
+        if self._graph_opt_level is not None:
+            graph.options.graph_opt_level = self._graph_opt_level
         # sublinear
         if self._sublinear_memory_config is not None:
             graph.options.enable_sublinear_memory_opt = True
@@ -320,6 +346,7 @@ class trace:
             )
             sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
             sublinear_config.num_worker = self._sublinear_memory_config.num_worker
+        # profile
         if self._profiling:
             self._profiler = GraphProfiler(graph)
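Since `_apply_graph_options` now honors `opt_level` alongside the existing sublinear and profiling settings, a combined configuration might look like this sketch (field values are illustrative):

```python
from megengine import tensor
from megengine.jit import SublinearMemoryConfig, trace

# Illustrative values; both knobs end up on graph.options in _apply_graph_options.
config = SublinearMemoryConfig(thresh_nr_try=10, num_worker=4)

@trace(symbolic=True, opt_level=1, sublinear_memory_config=config, profiling=True)
def step(x):
    return x * x

step(tensor(1.0))
print(step.get_profile())  # available because profiling=True
```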
@@ -416,7 +443,55 @@ class trace:
         self._process_outputs(outputs)
         return outputs

-    def dump(self, file, *, arg_names=None, output_names=None):
+    def dump(self, file, *, arg_names=None, output_names=None, append=False, **kwargs):
+        r"""Serializes the trace to the file system.
+
+        :param file: output file, which can be a file object or a filename.
+        :param arg_names: names of the input tensors in the traced function.
+        :param output_names: names of the output tensors in the traced function;
+            the default names are used if not specified.
+        :param append: whether the output is appended to ``file``.
+            Only works when ``file`` is a str.
+
+        :Keyword Arguments:
+
+            * enable_io16xc32 --
+                whether to use float16 for I/O between oprs and float32
+                as the internal computation precision. Note that the
+                output var dtype would be changed to float16.
+            * enable_ioc16 --
+                whether to use float16 for both I/O and computation
+                precision.
+            * enable_hwcd4 --
+                whether to use the NHWCD4 data layout, which is faster
+                on some OpenCL backends.
+            * enable_nchw88 --
+                whether to use the NCHW88 data layout, currently used
+                on the x86 AVX backend.
+            * enable_nchw44 --
+                whether to use the NCHW44 data layout, currently used
+                on ARM backends.
+            * enable_nchw44_dot --
+                whether to use the NCHW44_DOT data layout, currently
+                used on ARMv8.2+dotprod backends.
+            * enable_nchw4 --
+                whether to use the NCHW4 data layout, currently used
+                on NVIDIA backends (based on cuDNN).
+            * enable_nchw32 --
+                whether to use the NCHW32 data layout, currently used
+                on NVIDIA backends with TensorCore (based on cuDNN).
+            * enable_chwn4 --
+                whether to use the CHWN4 data layout, currently used
+                on NVIDIA backends with TensorCore.
+            * enable_fuse_conv_bias_nonlinearity --
+                whether to fuse conv+bias+nonlinearity into one opr.
+            * enable_fuse_conv_bias_with_z --
+                whether to fuse conv_bias with the z input for inference
+                on NVIDIA backends (this optimization pass will result
+                in a precision mismatch between the outputs of training
+                and inference).
+        """
         if not self._capture_as_const:
             raise ValueError(
                 "you must specify capture_as_const=True at __init__ to use dump"
@@ -482,8 +557,11 @@ class trace:
                 v.name = output_names[i]
             dest_vars.append(v)

+        dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
+
         if isinstance(file, str):
-            file = open(file, "wb")
+            permission = "ab" if append else "wb"
+            file = open(file, permission)
         file.write(G.dump(*dest_vars))

     def _process_inputs(self, *args, **kwargs):
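Putting the new dump arguments together, a hypothetical call (the file name is invented) could be:

```python
from megengine import tensor
from megengine.functional import exp
from megengine.jit import trace

@trace(symbolic=True, capture_as_const=True)  # capture_as_const is required by dump
def f(x):
    return exp(x)

f(tensor(5.0))  # trace once before dumping
f.dump(
    "model.mge",           # a str, so append can take effect ("ab" instead of "wb")
    arg_names=["data"],
    append=False,
    enable_io16xc32=True,  # forwarded as **kwargs to G.optimize_for_inference
)
```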
@@ -20,12 +20,17 @@
 #include "./helper.h"
 #include "megbrain/plugin/profiler.h"
 #include "./common.h"
+#include "megbrain/gopt/inference.h"

 namespace py = pybind11;

 using namespace mgb;
 using namespace imperative;

+using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
+using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
+
 namespace {
 class _CompGraphProfilerImpl {
     std::shared_ptr<ComputingGraph> m_comp_graph;
@@ -138,6 +143,37 @@ void init_graph_rt(py::module m) {
         return py::bytes(reinterpret_cast<const char*>(&buf[0]), buf.size());
     });

+    auto GraphOptimizeOptions = py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions")
+            .def(py::init())
+            .def_readwrite("f16_io_f32_comp", &_OptimizeForInferenceOptions::f16_io_f32_comp)
+            .def_readwrite("f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
+            .def_readwrite("fuse_conv_bias_nonlinearity", &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)
+            .def_readwrite("fuse_conv_bias_with_z", &_OptimizeForInferenceOptions::fuse_conv_bias_with_z)
+            .def_readwrite("layout_transform", &_OptimizeForInferenceOptions::layout_transform);
+
+    py::enum_<_LayoutTransform>(GraphOptimizeOptions, "LayoutTransform")
+            .value("DEFAULT", _LayoutTransform::DEFAULT)
+            .value("NCHW4", _LayoutTransform::NCHW4)
+            .value("NHWCD4", _LayoutTransform::NHWCD4)
+            .value("NCHW88", _LayoutTransform::NCHW88)
+            .value("NCHW44", _LayoutTransform::NCHW44)
+            .value("NCHW44_DOT", _LayoutTransform::NCHW44_DOT)
+            .value("NCHW32", _LayoutTransform::NCHW32)
+            .value("CHWN4", _LayoutTransform::CHWN4)
+            .export_values();
+
+    m.def("optimize_for_inference", [](const VarNodeArray& dest_vars, const _OptimizeForInferenceOptions& opt) {
+        SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
+        auto res_symvars = mgb::gopt::optimize_for_inference(symvars, opt);
+        VarNodeArray vars;
+        for (auto& si : res_symvars)
+            vars.push_back(si.node());
+        return vars;
+    });
+
 #define CURRENT_CLASS cg::ComputingGraph::Options

 auto PyComputingGraphOptions = py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")
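The pybind layer above is what the Python wrapper calls into; driving it directly would look roughly like this untested sketch (`nodes` stands for raw VarNode handles you already hold):

```python
from megengine.core import _imperative_rt
from megengine.core._imperative_rt import GraphOptimizeOptions

opts = GraphOptimizeOptions()
opts.f16_io_f32_comp = True
opts.layout_transform = GraphOptimizeOptions.LayoutTransform.NCHW88

# nodes: list of raw VarNode handles for the graph outputs.
# result = _imperative_rt.optimize_for_inference(nodes, opts)
```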
@@ -1,29 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import pytest
-
-from megengine.core import Tensor
-
-# from megengine.core.interpreter.hints import function
-
-
-@pytest.mark.skip(reason="under rewrite")
-def test_1():
-    @function
-    def f(x, p):
-        x = x + 1
-        if p:
-            return x * x
-        return x * 2
-
-    x = Tensor(0)
-    for _ in range(5):
-        assert f(x, 0).numpy() == 2
-        assert f(x, 1).numpy() == 1
@@ -1,10 +1,23 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import io
+from tempfile import mkstemp

 import numpy as np
+import pytest

+from megengine import tensor
 from megengine.core.ops import builtin as ops
+from megengine.core.tensor import megbrain_graph as G
 from megengine.core.tensor.core import apply
 from megengine.core.tensor.raw_tensor import as_raw_tensor
+from megengine.functional import exp, log
 from megengine.jit import exclude_from_trace, trace
@@ -101,3 +114,85 @@ def test_trace_profiler():
     out = f.get_profile()
     assert out.get("profiler")

+
+@pytest.mark.skip(reason="eq_to_unit failed in inplace.cpp")
+def test_goptions_div_zero():
+    @trace(symbolic=True, opt_level=0)
+    def f(x):
+        return x / x
+
+    @trace(symbolic=True, opt_level=1)
+    def g(x):
+        return x / x
+
+    out = f(tensor(0.0))
+    if out == out:
+        raise ValueError("actual result should be nan")
+
+    out = g(tensor(0.0))
+    if out != out:
+        raise ValueError("actual result should be 1")
+
+
+@pytest.mark.skip(reason="cast to Elemwise failed in inplace.cpp")
+def test_goptions_log_exp():
+    @trace(symbolic=True, opt_level=0, capture_as_const=True)
+    def f(x):
+        return log(exp(x))
+
+    @trace(symbolic=True, opt_level=1, capture_as_const=True)
+    def g(x):
+        return log(exp(x))
+
+    f(tensor(1.0))
+    _, out = mkstemp()
+    f.dump(out)
+    *_, outputs = G.load_comp_graph_from_file(out)
+    oprs_1 = cgtools.get_oprs_seq(outputs)
+
+    g(tensor(1.0))
+    g.dump(out)
+    *_, outputs = G.load_comp_graph_from_file(out)
+    oprs_2 = cgtools.get_oprs_seq(outputs)
+
+    assert len(oprs_1) - len(oprs_2) == 2
+
+
+@pytest.mark.skip(reason="need cgtools to check final oprs")
+def test_goptions_log_sum_exp():
+    @trace(symbolic=True, opt_level=0, capture_as_const=True)
+    def f(x, y):
+        return log(exp(x) + exp(y))
+
+    @trace(symbolic=True, opt_level=1, capture_as_const=True)
+    def g(x, y):
+        return log(exp(x) + exp(y))
+
+    f(tensor(1.0), tensor(2.0))
+    _, out = mkstemp()
+    f.dump(out)
+    *_, outputs = G.load_comp_graph_from_file(out)
+    oprs_1 = cgtools.get_oprs_seq(outputs)
+
+    g(tensor(1.0), tensor(2.0))
+    g.dump(out)
+    *_, outputs = G.load_comp_graph_from_file(out)
+    oprs_2 = cgtools.get_oprs_seq(outputs)
+
+    assert len(oprs_1) - len(oprs_2) == 2
| @pytest.mark.skip(reason="need cgtools to check computing input dtype") | |||||
| def test_optimize_for_inference(): | |||||
| @trace(symbolic=True, capture_as_const=True) | |||||
| def f(x): | |||||
| return exp(x) | |||||
| _, out = mkstemp() | |||||
| f(tensor(5.0)) | |||||
| f.dump(out, optimize_for_inference=True, optimize_options={"enable_io16xc32": True}) | |||||
| res = G.load_comp_graph_from_file(out) | |||||
| computing_input = res.output_vars_list[0].owner.inputs[0] | |||||
| assert computing_input.dtype == np.float16 | |||||