GitOrigin-RevId: b29e374c60
tags/v1.6.0
| @@ -1,26 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import functools | |||
| import inspect | |||
| import sys | |||
| import typing | |||
| from abc import ABC | |||
| class OpBase: | |||
| pass | |||
| class TensorBase: | |||
| pass | |||
| class TensorWrapperBase: | |||
| pass | |||
| @@ -20,7 +20,6 @@ from .._imperative_rt import GraphOptimizeOptions | |||
| from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode | |||
| from .._wrap import as_device | |||
| from ..ops.builtin import OpDef | |||
| from .core import TensorBase | |||
| def set_priority_to_id(dest_vars): | |||
| @@ -127,7 +126,7 @@ class Graph(_imperative_rt.ComputingGraph): | |||
| print("this function should be called after compilation.") | |||
| class VarNode(TensorBase): | |||
| class VarNode: | |||
| def __init__(self, node: _imperative_rt.VarNode, isscalar=False): | |||
| self._node = node | |||
| self._isscalar = isscalar | |||
| @@ -7,12 +7,31 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # pylint: disable=redefined-builtin | |||
| from typing import Sequence | |||
| from typing import Iterable, List, Sequence | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core.ops import builtin | |||
def extern_opr_subgraph(
    inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes, output_dtypes
):
    r"""Load a serialized extern opr subgraph and fake execute the operator.

    Args:
        inputs: list of input tensors. A value that is not iterable is
            wrapped into a one-element tuple first.
        output_shapes: The output shapes.
        dump_name: The serialized subgraph name.
        dump_data: The serialized subgraph. Its length is handed to the
            operator together with the data itself.
        output_dtypes: The output dtypes.

    Returns:
        The result of ``apply`` on the constructed ``ExternOpr`` and the inputs.
    """
    # Normalize a single (non-iterable) input into a tuple so it can be unpacked.
    operands = inputs if isinstance(inputs, Iterable) else (inputs,)
    data_len = len(dump_data)
    extern_op = builtin.ExternOpr(
        output_shapes, dump_name, dump_data, data_len, output_dtypes
    )
    return apply(extern_op, *operands)
| def tensorrt_runtime_opr(inputs, *, data: bytes = None): | |||
| # empty model will give None result | |||
| if data is None: | |||
| @@ -29,13 +29,13 @@ from ..core._imperative_rt.core2 import ( | |||
| from ..core._imperative_rt.ops import ( | |||
| AssertEqual, | |||
| CollectiveComm, | |||
| ExternOpr, | |||
| RemoteRecv, | |||
| RemoteSend, | |||
| ) | |||
| from ..core._trace_option import set_symbolic_shape | |||
| from ..core._wrap import as_device | |||
| from ..core.ops.builtin import BatchNorm, OpDef | |||
| from ..core.ops.special import Const | |||
| from ..core.tensor import megbrain_graph as G | |||
| from ..core.tensor.utils import setscalar | |||
| from ..utils.naming import AutoNaming | |||
| @@ -129,6 +129,7 @@ class trace: | |||
| function: the function will be traced. | |||
| symbolic: whether to apply symbolic execution for tracing. Default: False | |||
| capture_as_const: capture global vars or closures as const value. Default: False | |||
| record_only: if True, the function won't be run even when it is called. Default: False | |||
| sublinear_memory_config: configuration for sublinear memory optimization. | |||
| If not None, it enables sublinear memory optimization with given setting. | |||
| profiling: whether to profile compiled trace. Default: False | |||
| @@ -147,6 +148,7 @@ class trace: | |||
| function, | |||
| symbolic=False, | |||
| capture_as_const=False, | |||
| record_only=False, | |||
| sublinear_memory_config: SublinearMemoryConfig = None, | |||
| dtr_config: DTRConfig = None, | |||
| profiling: bool = False, | |||
| @@ -155,8 +157,9 @@ class trace: | |||
| symbolic_shape: bool = True, | |||
| ): | |||
| self.__wrapped__ = function | |||
| self._symbolic = symbolic | |||
| self._capture_as_const = capture_as_const | |||
| self._symbolic = symbolic or record_only | |||
| self._capture_as_const = capture_as_const or record_only | |||
| self._record_only = record_only | |||
| self._sublinear_memory_config = sublinear_memory_config | |||
| self._dtr_config = dtr_config | |||
| self._profiling = profiling | |||
| @@ -418,35 +421,40 @@ class trace: | |||
| def do_finalize(): | |||
| escaped_tensors = self._take_escaped_tensors() | |||
| if self._untraced: | |||
| for x in escaped_tensors: | |||
| if x(): | |||
| info = self._tinfo[x()._mixin_handle] | |||
| info.data_read = True | |||
| x()._mixin_handle = -1 | |||
| x()._recording = False | |||
| if self._inputs_to_restore: | |||
| for x in self._inputs_to_restore: | |||
| x._mixin_handle = -1 | |||
| x._recording = False | |||
| if self._symbolic and ( | |||
| self._lazy_eval_tensors or self._lazy_eval_links | |||
| ): | |||
| # eval lazy eval tensors | |||
| self._lazy_eval( | |||
| self._lazy_eval_graph, | |||
| self._lazy_eval_tensors, | |||
| self._lazy_eval_links, | |||
| ) | |||
| if self._record_only: | |||
| self._lazy_eval_graph = None | |||
| self._lazy_eval_tensors = None | |||
| self._lazy_eval_links = None | |||
| self._untraced = False | |||
| else: | |||
| for x in escaped_tensors: | |||
| if x(): | |||
| info = self._tinfo[x()._mixin_handle] | |||
| info.data_read = True | |||
| x()._mixin_handle = -1 | |||
| x()._recording = False | |||
| if self._inputs_to_restore: | |||
| for x in self._inputs_to_restore: | |||
| x._mixin_handle = -1 | |||
| x._recording = False | |||
| if self._symbolic and ( | |||
| self._lazy_eval_tensors or self._lazy_eval_links | |||
| ): | |||
| # eval lazy eval tensors | |||
| self._lazy_eval( | |||
| self._lazy_eval_graph, | |||
| self._lazy_eval_tensors, | |||
| self._lazy_eval_links, | |||
| ) | |||
| self._lazy_eval_graph = None | |||
| self._lazy_eval_tensors = None | |||
| self._lazy_eval_links = None | |||
| self._untraced = False | |||
| else: | |||
| # compiled_tensor leaks | |||
| if self._pc == len(self._seq): | |||
| for x in escaped_tensors: | |||
| try: | |||
| assign_raw_tensor(x(), RawTensor(x()._dev_tensor())) | |||
| x().__init__(RawTensor(x()._dev_tensor())) | |||
| except RuntimeError: | |||
| # TraceMismatchError thrown in do_exit | |||
| pass | |||
| @@ -769,8 +777,8 @@ class trace: | |||
| raise ValueError( | |||
| "you must specify capture_as_const=True at __init__ to use dump" | |||
| ) | |||
| if self._untraced: | |||
| raise RuntimeError("should run at least once before calling dump") | |||
| if self._untraced and len(self._seq) == 0: | |||
| raise RuntimeError("should do record first before dump") | |||
| if self._output_names and output_names: | |||
| raise TypeError( | |||
| "cannot specify output_names when output is already in dict format" | |||
| @@ -1104,10 +1112,6 @@ class CompiledTensorProxy: | |||
| self.__info.data_reader.drop_value() | |||
| def assign_raw_tensor(lhs, rhs): | |||
| lhs.__init__(rhs) | |||
| def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||
| graph = active_trace._lazy_eval_graph | |||
| ivars = [] | |||
| @@ -12,11 +12,55 @@ import numpy as np | |||
| from ..functional.external import ( | |||
| atlas_runtime_opr, | |||
| cambricon_runtime_opr, | |||
| extern_opr_subgraph, | |||
| tensorrt_runtime_opr, | |||
| ) | |||
| from .module import Module | |||
class ExternOprSubgraph(Module):
    r"""Load a serialized ExternOpr subgraph.

    See :func:`~.extern_opr_subgraph` for more details.

    Args:
        output_shapes: shapes of the outputs, forwarded to
            ``extern_opr_subgraph`` on every ``forward`` call.
        dump_name: name of the serialized subgraph.
        dump_data: the serialized subgraph.
        output_dtypes: dtypes of the outputs. When ``None``, ``np.float32``
            is used for every entry of ``output_shapes``.
        **kwargs: forwarded unchanged to the base ``Module`` constructor.
    """

    def __init__(
        self, output_shapes, dump_name, dump_data, output_dtypes=None, **kwargs
    ):
        super().__init__(**kwargs)
        self._output_shapes = output_shapes
        self._dump_name = dump_name
        self._dump_data = dump_data
        # Default to one float32 dtype per declared output shape.
        if output_dtypes is None:
            output_dtypes = [np.float32] * len(output_shapes)
        self._output_dtypes = output_dtypes

    @property
    def data(self):
        """The serialized subgraph as currently stored on the module."""
        return self._dump_data

    @data.setter
    def data(self, val):
        # NOTE(review): the setter stores a uint8 ndarray view of ``val``,
        # whereas ``__init__`` keeps ``dump_data`` as passed in — confirm that
        # the ExternOpr builtin accepts both representations.
        self._dump_data = np.frombuffer(val, dtype=np.uint8)

    @property
    def name(self):
        """The serialized subgraph name."""
        return self._dump_name

    @name.setter
    def name(self, val):
        self._dump_name = val

    def forward(self, *inputs):
        """Apply the extern opr subgraph to ``inputs`` and return its result."""
        return extern_opr_subgraph(
            inputs,
            output_shapes=self._output_shapes,
            dump_name=self._dump_name,
            dump_data=self._dump_data,
            output_dtypes=self._output_dtypes,
        )
| class TensorrtRuntimeSubgraph(Module): | |||
| r"""Load a serialized TensorrtRuntime subgraph. | |||
| @@ -76,7 +76,7 @@ class XORNet(Module): | |||
| @pytest.mark.parametrize("test_traced_module", [True, False]) | |||
| def test_training_converge(test_traced_module): | |||
| net = XORNet() | |||
| if test_training_converge: | |||
| if test_traced_module: | |||
| inp = Tensor(np.random.random((14, 2))) | |||
| net = trace_module(net, inp) | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * \file imperative/src/impl/ops/extern_opr.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "../op_trait.h" | |||
| #include "megbrain/serialization/extern_c_opr_io.h" | |||
| namespace mgb::imperative { | |||
| namespace { namespace externopr { | |||
| TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) { | |||
| TensorShapeArray ret; | |||
| for (auto&& i:shapes) { | |||
| SmallVector<size_t> shape(i.begin(), i.end()); | |||
| TensorShape shp(shape); | |||
| ret.push_back(shp); | |||
| } | |||
| return ret; | |||
| } | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const ExternOpr&>(def); | |||
| SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end()); | |||
| SmallVector<DType> output_dtypes(op.output_dtypes.begin(), op.output_dtypes.end()); | |||
| auto&& output_shapes = get_shapes(op.output_shapes); | |||
| cg::OperatorNodeBase* opr = opr::ExternCOprRunner::make_placeholder( | |||
| symbol_var_inputs, output_shapes, op.name.c_str(), op.data.c_str(), | |||
| op.data_len, {}, output_dtypes); | |||
| return opr; | |||
| } | |||
| OP_TRAIT_REG(ExternOpr, ExternOpr, opr::ExternCOprRunner) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // externopr | |||
| } // namespace mgb::imperative | |||
| @@ -10,7 +10,7 @@ import argparse | |||
| import numpy as np | |||
| import yaml | |||
| from megengine import jit | |||
| from megengine import jit, tensor | |||
| from megengine.module.external import ExternOprSubgraph | |||
| @@ -27,12 +27,12 @@ def main(): | |||
| description="load a .pb model and convert to corresponding " | |||
| "load-and-run model" | |||
| ) | |||
| parser.add_argument("input", help="mace model file") | |||
| parser.add_argument("param", help="mace param file") | |||
| parser.add_argument("--input", help="mace model file") | |||
| parser.add_argument("--param", help="mace param file") | |||
| parser.add_argument( | |||
| "output", help="converted model that can be fed to dump_with_testcase_mge.py" | |||
| "--output", help="converted model that can be fed to dump_with_testcase_mge.py" | |||
| ) | |||
| parser.add_argument("config", help="config file with yaml format") | |||
| parser.add_argument("--config", help="config file with yaml format") | |||
| args = parser.parse_args() | |||
| with open(args.config, "r") as f: | |||
| @@ -90,17 +90,17 @@ def main(): | |||
| + raw_param | |||
| ) | |||
| net = ExternOprSubgraph(wk_raw_content, "mace", osizes) | |||
| net = ExternOprSubgraph(osizes, "mace", wk_raw_content) | |||
| net.eval() | |||
| @jit.trace(symbolic=True) | |||
| @jit.trace(record_only=True) | |||
| def inference(inputs): | |||
| return net(inputs) | |||
| inputs = [ | |||
| np.random.random(isizes[i]).astype(np.float32) for i in range(len(isizes)) | |||
| tensor(np.random.random(isizes[i]).astype(np.float32)) for i in range(len(isizes)) | |||
| ] | |||
| inference.trace(*inputs) | |||
| inference(*inputs) | |||
| inference.dump(args.output) | |||
| @@ -381,6 +381,24 @@ def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>; | |||
| def FastpathCopy: MgbHashableOp<"FastpathCopy">; | |||
// Op definition for ExternOpr: carries the output shapes, the subgraph name,
// the serialized data with its length, and the output dtypes.
def ExternOpr: MgbHashableOp<"ExternOpr"> {
  let extraArguments = (ins
    MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$output_shapes,
    MgbStringAttr:$name,
    MgbStringAttr:$data,
    MgbSizeTAddr:$data_len,
    MgbArrayAttr<MgbDTypeAttr>:$output_dtypes
  );
  // Custom hash: combines only the op type, `name` and `data`.
  // `output_shapes`, `data_len` and `output_dtypes` do not contribute to it.
  let hashFunction = [{
    return mgb::hash_pair_combine(
      mgb::hash($_self.dyn_typeinfo()),
      mgb::hash_pair_combine(
        mgb::hash($_self.name),
        mgb::hash($_self.data))
    );
  }];
}
| def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>; | |||
| def Split: MgbHashableOp<"Split", [EmptyParam]> { | |||