GitOrigin-RevId: b29e374c60
tags/v1.6.0
| @@ -1,26 +0,0 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import collections | |||||
| import functools | |||||
| import inspect | |||||
| import sys | |||||
| import typing | |||||
| from abc import ABC | |||||
| class OpBase: | |||||
| pass | |||||
| class TensorBase: | |||||
| pass | |||||
| class TensorWrapperBase: | |||||
| pass | |||||
| @@ -20,7 +20,6 @@ from .._imperative_rt import GraphOptimizeOptions | |||||
| from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode | from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode | ||||
| from .._wrap import as_device | from .._wrap import as_device | ||||
| from ..ops.builtin import OpDef | from ..ops.builtin import OpDef | ||||
| from .core import TensorBase | |||||
| def set_priority_to_id(dest_vars): | def set_priority_to_id(dest_vars): | ||||
| @@ -127,7 +126,7 @@ class Graph(_imperative_rt.ComputingGraph): | |||||
| print("this function should be called after compilation.") | print("this function should be called after compilation.") | ||||
| class VarNode(TensorBase): | |||||
| class VarNode: | |||||
| def __init__(self, node: _imperative_rt.VarNode, isscalar=False): | def __init__(self, node: _imperative_rt.VarNode, isscalar=False): | ||||
| self._node = node | self._node = node | ||||
| self._isscalar = isscalar | self._isscalar = isscalar | ||||
| @@ -7,12 +7,31 @@ | |||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # pylint: disable=redefined-builtin | # pylint: disable=redefined-builtin | ||||
| from typing import Sequence | |||||
| from typing import Iterable, List, Sequence | |||||
| from ..core._imperative_rt.core2 import apply | from ..core._imperative_rt.core2 import apply | ||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| def extern_opr_subgraph( | |||||
| inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes, output_dtypes | |||||
| ): | |||||
| r"""Load a serialized extern opr subgraph and fake execute the operator. | |||||
| Args: | |||||
| inputs: list of input tensors. | |||||
| output_shapes: The output shapes. | |||||
| dump_name: The serialized subgraph name. | |||||
| dump_data: The serialized subgraph. | |||||
| """ | |||||
| if not isinstance(inputs, Iterable): | |||||
| inputs = (inputs,) | |||||
| op = builtin.ExternOpr( | |||||
| output_shapes, dump_name, dump_data, len(dump_data), output_dtypes | |||||
| ) | |||||
| return apply(op, *inputs) | |||||
| def tensorrt_runtime_opr(inputs, *, data: bytes = None): | def tensorrt_runtime_opr(inputs, *, data: bytes = None): | ||||
| # empty model will give None result | # empty model will give None result | ||||
| if data is None: | if data is None: | ||||
| @@ -29,13 +29,13 @@ from ..core._imperative_rt.core2 import ( | |||||
| from ..core._imperative_rt.ops import ( | from ..core._imperative_rt.ops import ( | ||||
| AssertEqual, | AssertEqual, | ||||
| CollectiveComm, | CollectiveComm, | ||||
| ExternOpr, | |||||
| RemoteRecv, | RemoteRecv, | ||||
| RemoteSend, | RemoteSend, | ||||
| ) | ) | ||||
| from ..core._trace_option import set_symbolic_shape | from ..core._trace_option import set_symbolic_shape | ||||
| from ..core._wrap import as_device | from ..core._wrap import as_device | ||||
| from ..core.ops.builtin import BatchNorm, OpDef | from ..core.ops.builtin import BatchNorm, OpDef | ||||
| from ..core.ops.special import Const | |||||
| from ..core.tensor import megbrain_graph as G | from ..core.tensor import megbrain_graph as G | ||||
| from ..core.tensor.utils import setscalar | from ..core.tensor.utils import setscalar | ||||
| from ..utils.naming import AutoNaming | from ..utils.naming import AutoNaming | ||||
| @@ -129,6 +129,7 @@ class trace: | |||||
| function: the function will be traced. | function: the function will be traced. | ||||
| symbolic: whether to apply symbolic execution for tracing. Default: False | symbolic: whether to apply symbolic execution for tracing. Default: False | ||||
| capture_as_const: capture global vars or closures as const value. Default: False | capture_as_const: capture global vars or closures as const value. Default: False | ||||
| record_only: if True, won't run even if call the function. Default: False | |||||
| sublinear_memory_config: configuration for sublinear memory optimization. | sublinear_memory_config: configuration for sublinear memory optimization. | ||||
| If not None, it enables sublinear memory optimization with given setting. | If not None, it enables sublinear memory optimization with given setting. | ||||
| profiling: whether to profile compiled trace. Default: False | profiling: whether to profile compiled trace. Default: False | ||||
| @@ -147,6 +148,7 @@ class trace: | |||||
| function, | function, | ||||
| symbolic=False, | symbolic=False, | ||||
| capture_as_const=False, | capture_as_const=False, | ||||
| record_only=False, | |||||
| sublinear_memory_config: SublinearMemoryConfig = None, | sublinear_memory_config: SublinearMemoryConfig = None, | ||||
| dtr_config: DTRConfig = None, | dtr_config: DTRConfig = None, | ||||
| profiling: bool = False, | profiling: bool = False, | ||||
| @@ -155,8 +157,9 @@ class trace: | |||||
| symbolic_shape: bool = True, | symbolic_shape: bool = True, | ||||
| ): | ): | ||||
| self.__wrapped__ = function | self.__wrapped__ = function | ||||
| self._symbolic = symbolic | |||||
| self._capture_as_const = capture_as_const | |||||
| self._symbolic = symbolic or record_only | |||||
| self._capture_as_const = capture_as_const or record_only | |||||
| self._record_only = record_only | |||||
| self._sublinear_memory_config = sublinear_memory_config | self._sublinear_memory_config = sublinear_memory_config | ||||
| self._dtr_config = dtr_config | self._dtr_config = dtr_config | ||||
| self._profiling = profiling | self._profiling = profiling | ||||
| @@ -418,35 +421,40 @@ class trace: | |||||
| def do_finalize(): | def do_finalize(): | ||||
| escaped_tensors = self._take_escaped_tensors() | escaped_tensors = self._take_escaped_tensors() | ||||
| if self._untraced: | if self._untraced: | ||||
| for x in escaped_tensors: | |||||
| if x(): | |||||
| info = self._tinfo[x()._mixin_handle] | |||||
| info.data_read = True | |||||
| x()._mixin_handle = -1 | |||||
| x()._recording = False | |||||
| if self._inputs_to_restore: | |||||
| for x in self._inputs_to_restore: | |||||
| x._mixin_handle = -1 | |||||
| x._recording = False | |||||
| if self._symbolic and ( | |||||
| self._lazy_eval_tensors or self._lazy_eval_links | |||||
| ): | |||||
| # eval lazy eval tensors | |||||
| self._lazy_eval( | |||||
| self._lazy_eval_graph, | |||||
| self._lazy_eval_tensors, | |||||
| self._lazy_eval_links, | |||||
| ) | |||||
| if self._record_only: | |||||
| self._lazy_eval_graph = None | self._lazy_eval_graph = None | ||||
| self._lazy_eval_tensors = None | self._lazy_eval_tensors = None | ||||
| self._lazy_eval_links = None | self._lazy_eval_links = None | ||||
| self._untraced = False | |||||
| else: | |||||
| for x in escaped_tensors: | |||||
| if x(): | |||||
| info = self._tinfo[x()._mixin_handle] | |||||
| info.data_read = True | |||||
| x()._mixin_handle = -1 | |||||
| x()._recording = False | |||||
| if self._inputs_to_restore: | |||||
| for x in self._inputs_to_restore: | |||||
| x._mixin_handle = -1 | |||||
| x._recording = False | |||||
| if self._symbolic and ( | |||||
| self._lazy_eval_tensors or self._lazy_eval_links | |||||
| ): | |||||
| # eval lazy eval tensors | |||||
| self._lazy_eval( | |||||
| self._lazy_eval_graph, | |||||
| self._lazy_eval_tensors, | |||||
| self._lazy_eval_links, | |||||
| ) | |||||
| self._lazy_eval_graph = None | |||||
| self._lazy_eval_tensors = None | |||||
| self._lazy_eval_links = None | |||||
| self._untraced = False | |||||
| else: | else: | ||||
| # compiled_tensor leaks | # compiled_tensor leaks | ||||
| if self._pc == len(self._seq): | if self._pc == len(self._seq): | ||||
| for x in escaped_tensors: | for x in escaped_tensors: | ||||
| try: | try: | ||||
| assign_raw_tensor(x(), RawTensor(x()._dev_tensor())) | |||||
| x().__init__(RawTensor(x()._dev_tensor())) | |||||
| except RuntimeError: | except RuntimeError: | ||||
| # TraceMismatchError thrown in do_exit | # TraceMismatchError thrown in do_exit | ||||
| pass | pass | ||||
| @@ -769,8 +777,8 @@ class trace: | |||||
| raise ValueError( | raise ValueError( | ||||
| "you must specify capture_as_const=True at __init__ to use dump" | "you must specify capture_as_const=True at __init__ to use dump" | ||||
| ) | ) | ||||
| if self._untraced: | |||||
| raise RuntimeError("should run at least once before calling dump") | |||||
| if self._untraced and len(self._seq) == 0: | |||||
| raise RuntimeError("should do record first before dump") | |||||
| if self._output_names and output_names: | if self._output_names and output_names: | ||||
| raise TypeError( | raise TypeError( | ||||
| "cannot specify output_names when output is already in dict format" | "cannot specify output_names when output is already in dict format" | ||||
| @@ -1104,10 +1112,6 @@ class CompiledTensorProxy: | |||||
| self.__info.data_reader.drop_value() | self.__info.data_reader.drop_value() | ||||
| def assign_raw_tensor(lhs, rhs): | |||||
| lhs.__init__(rhs) | |||||
| def apply_symbolic_mode(op: OpDef, *args: RawTensor): | def apply_symbolic_mode(op: OpDef, *args: RawTensor): | ||||
| graph = active_trace._lazy_eval_graph | graph = active_trace._lazy_eval_graph | ||||
| ivars = [] | ivars = [] | ||||
| @@ -12,11 +12,55 @@ import numpy as np | |||||
| from ..functional.external import ( | from ..functional.external import ( | ||||
| atlas_runtime_opr, | atlas_runtime_opr, | ||||
| cambricon_runtime_opr, | cambricon_runtime_opr, | ||||
| extern_opr_subgraph, | |||||
| tensorrt_runtime_opr, | tensorrt_runtime_opr, | ||||
| ) | ) | ||||
| from .module import Module | from .module import Module | ||||
| class ExternOprSubgraph(Module): | |||||
| r"""Load a serialized ExternOpr subgraph. | |||||
| See :func:`~.extern_opr` for more details. | |||||
| """ | |||||
| def __init__( | |||||
| self, output_shapes, dump_name, dump_data, output_dtypes=None, **kwargs | |||||
| ): | |||||
| super(ExternOprSubgraph, self).__init__(**kwargs) | |||||
| self._output_shapes = output_shapes | |||||
| self._dump_name = dump_name | |||||
| self._dump_data = dump_data | |||||
| self._output_dtypes = output_dtypes | |||||
| if self._output_dtypes is None: | |||||
| self._output_dtypes = [np.float32] * len(output_shapes) | |||||
| @property | |||||
| def data(self): | |||||
| return self._dump_data | |||||
| @data.setter | |||||
| def data(self, val): | |||||
| self._dump_data = np.frombuffer(val, dtype=np.uint8) | |||||
| @property | |||||
| def name(self): | |||||
| return self._dump_name | |||||
| @name.setter | |||||
| def name(self, val): | |||||
| self._dump_name = val | |||||
| def forward(self, *inputs): | |||||
| return extern_opr_subgraph( | |||||
| inputs, | |||||
| output_shapes=self._output_shapes, | |||||
| dump_name=self._dump_name, | |||||
| dump_data=self._dump_data, | |||||
| output_dtypes=self._output_dtypes, | |||||
| ) | |||||
| class TensorrtRuntimeSubgraph(Module): | class TensorrtRuntimeSubgraph(Module): | ||||
| r"""Load a serialized TensorrtRuntime subgraph. | r"""Load a serialized TensorrtRuntime subgraph. | ||||
| @@ -76,7 +76,7 @@ class XORNet(Module): | |||||
| @pytest.mark.parametrize("test_traced_module", [True, False]) | @pytest.mark.parametrize("test_traced_module", [True, False]) | ||||
| def test_training_converge(test_traced_module): | def test_training_converge(test_traced_module): | ||||
| net = XORNet() | net = XORNet() | ||||
| if test_training_converge: | |||||
| if test_traced_module: | |||||
| inp = Tensor(np.random.random((14, 2))) | inp = Tensor(np.random.random((14, 2))) | ||||
| net = trace_module(net, inp) | net = trace_module(net, inp) | ||||
| @@ -0,0 +1,52 @@ | |||||
| /** | |||||
| * \file imperative/src/impl/ops/extern_opr.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "megbrain/imperative/ops/autogen.h" | |||||
| #include "../op_trait.h" | |||||
| #include "megbrain/serialization/extern_c_opr_io.h" | |||||
| namespace mgb::imperative { | |||||
| namespace { namespace externopr { | |||||
| TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) { | |||||
| TensorShapeArray ret; | |||||
| for (auto&& i:shapes) { | |||||
| SmallVector<size_t> shape(i.begin(), i.end()); | |||||
| TensorShape shp(shape); | |||||
| ret.push_back(shp); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| auto apply_on_var_node( | |||||
| const OpDef& def, | |||||
| const VarNodeArray& inputs) { | |||||
| auto&& op = static_cast<const ExternOpr&>(def); | |||||
| SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end()); | |||||
| SmallVector<DType> output_dtypes(op.output_dtypes.begin(), op.output_dtypes.end()); | |||||
| auto&& output_shapes = get_shapes(op.output_shapes); | |||||
| cg::OperatorNodeBase* opr = opr::ExternCOprRunner::make_placeholder( | |||||
| symbol_var_inputs, output_shapes, op.name.c_str(), op.data.c_str(), | |||||
| op.data_len, {}, output_dtypes); | |||||
| return opr; | |||||
| } | |||||
| OP_TRAIT_REG(ExternOpr, ExternOpr, opr::ExternCOprRunner) | |||||
| .apply_on_var_node(apply_on_var_node) | |||||
| .fallback(); | |||||
| }} // externopr | |||||
| } // namespace mgb::imperative | |||||
| @@ -10,7 +10,7 @@ import argparse | |||||
| import numpy as np | import numpy as np | ||||
| import yaml | import yaml | ||||
| from megengine import jit | |||||
| from megengine import jit, tensor | |||||
| from megengine.module.external import ExternOprSubgraph | from megengine.module.external import ExternOprSubgraph | ||||
| @@ -27,12 +27,12 @@ def main(): | |||||
| description="load a .pb model and convert to corresponding " | description="load a .pb model and convert to corresponding " | ||||
| "load-and-run model" | "load-and-run model" | ||||
| ) | ) | ||||
| parser.add_argument("input", help="mace model file") | |||||
| parser.add_argument("param", help="mace param file") | |||||
| parser.add_argument("--input", help="mace model file") | |||||
| parser.add_argument("--param", help="mace param file") | |||||
| parser.add_argument( | parser.add_argument( | ||||
| "output", help="converted model that can be fed to dump_with_testcase_mge.py" | |||||
| "--output", help="converted model that can be fed to dump_with_testcase_mge.py" | |||||
| ) | ) | ||||
| parser.add_argument("config", help="config file with yaml format") | |||||
| parser.add_argument("--config", help="config file with yaml format") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| with open(args.config, "r") as f: | with open(args.config, "r") as f: | ||||
| @@ -90,17 +90,17 @@ def main(): | |||||
| + raw_param | + raw_param | ||||
| ) | ) | ||||
| net = ExternOprSubgraph(wk_raw_content, "mace", osizes) | |||||
| net = ExternOprSubgraph(osizes, "mace", wk_raw_content) | |||||
| net.eval() | net.eval() | ||||
| @jit.trace(symbolic=True) | |||||
| @jit.trace(record_only=True) | |||||
| def inference(inputs): | def inference(inputs): | ||||
| return net(inputs) | return net(inputs) | ||||
| inputs = [ | inputs = [ | ||||
| np.random.random(isizes[i]).astype(np.float32) for i in range(len(isizes)) | |||||
| tensor(np.random.random(isizes[i]).astype(np.float32)) for i in range(len(isizes)) | |||||
| ] | ] | ||||
| inference.trace(*inputs) | |||||
| inference(*inputs) | |||||
| inference.dump(args.output) | inference.dump(args.output) | ||||
| @@ -381,6 +381,24 @@ def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>; | |||||
| def FastpathCopy: MgbHashableOp<"FastpathCopy">; | def FastpathCopy: MgbHashableOp<"FastpathCopy">; | ||||
| def ExternOpr: MgbHashableOp<"ExternOpr"> { | |||||
| let extraArguments = (ins | |||||
| MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$output_shapes, | |||||
| MgbStringAttr:$name, | |||||
| MgbStringAttr:$data, | |||||
| MgbSizeTAddr:$data_len, | |||||
| MgbArrayAttr<MgbDTypeAttr>:$output_dtypes | |||||
| ); | |||||
| let hashFunction = [{ | |||||
| return mgb::hash_pair_combine( | |||||
| mgb::hash($_self.dyn_typeinfo()), | |||||
| mgb::hash_pair_combine( | |||||
| mgb::hash($_self.name), | |||||
| mgb::hash($_self.data)) | |||||
| ); | |||||
| }]; | |||||
| } | |||||
| def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>; | def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>; | ||||
| def Split: MgbHashableOp<"Split", [EmptyParam]> { | def Split: MgbHashableOp<"Split", [EmptyParam]> { | ||||