|
|
|
@@ -11,6 +11,7 @@ import json |
|
|
|
import threading |
|
|
|
import weakref |
|
|
|
from concurrent.futures import Future, ThreadPoolExecutor |
|
|
|
from typing import Dict, List, Union |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
@@ -85,6 +86,97 @@ class Graph(_imperative_rt.ComputingGraph): |
|
|
|
return self._wrap(_imperative_rt.make_h2d(self, device, dtype, shape, name)) |
|
|
|
|
|
|
|
|
|
|
|
class VarNode(TensorBase):
    """Thin Python wrapper around a raw ``_imperative_rt.VarNode``.

    On construction the wrapper registers itself in the owning graph's
    ``_var_cache`` (when the graph provides one) so a raw node maps back to
    its wrapper.
    """

    def __init__(self, node: _imperative_rt.VarNode):
        self._node = node
        graph = self.graph
        if hasattr(graph, "_var_cache"):
            graph._var_cache[node] = self

    @property
    def graph(self) -> Graph:
        """Computing graph that owns this variable."""
        return self._node.graph

    @property
    def op(self):
        """Owner operator, wrapped via the graph's ``_wrap`` when available."""
        graph = self.graph
        if hasattr(graph, "_wrap"):
            return graph._wrap(self._node.owner)
        return self._node.owner

    @property
    def name(self):
        """Name of the underlying var node (settable)."""
        return self._node.name

    @name.setter
    def name(self, name):
        self._node.name = name

    @property
    def id(self):
        """Unique id of the underlying var node."""
        return self._node.id

    @property
    def dtype(self):
        """Data type of the variable."""
        return self._node.dtype

    @property
    def device(self):
        """Device of the variable, converted through ``as_device``."""
        return as_device(self._node.comp_node)

    @property
    def shape(self):
        """Shape of the variable."""
        return self._node.shape

    @property
    def value(self):
        """Value held by the underlying var node."""
        return self._node.value
|
|
|
|
|
class OpNode:
    """Thin Python wrapper around a raw ``_imperative_rt.OperatorNode``.

    On construction the wrapper registers itself in the owning graph's
    ``_op_cache`` (when the graph provides one) so a raw node maps back to
    its wrapper.
    """

    def __init__(self, node: _imperative_rt.OperatorNode):
        self._node = node
        graph = self.graph
        if hasattr(graph, "_op_cache"):
            graph._op_cache[node] = self

    @property
    def graph(self) -> Graph:
        """Computing graph that owns this operator."""
        return self._node.graph

    @property
    def name(self):
        """Name of the underlying operator node (settable)."""
        return self._node.name

    @name.setter
    def name(self, name):
        self._node.name = name

    @property
    def id(self):
        """Unique id of the underlying operator node."""
        return self._node.id

    @property
    def inputs(self):
        """Input vars, wrapped via the graph's ``_wrap`` when available."""
        graph = self.graph
        if hasattr(graph, "_wrap"):
            return tuple(graph._wrap(v) for v in self._node.inputs)
        return self._node.inputs

    @property
    def outputs(self):
        """Output vars, wrapped via the graph's ``_wrap`` when available."""
        graph = self.graph
        if hasattr(graph, "_wrap"):
            return tuple(graph._wrap(v) for v in self._node.outputs)
        return self._node.outputs

    @property
    def params(self):
        """Operator parameters, decoded from the node's JSON string."""
        return json.loads(self._node.params)

    @property
    def type(self):
        """Type name of the underlying operator node."""
        return self._node.type
|
|
|
|
|
|
|
def optimize_for_inference(dest_vars, **kwargs): |
|
|
|
r"""Applies optimize_for_inference pass for computing graph. |
|
|
|
|
|
|
|
@@ -162,8 +254,100 @@ def optimize_for_inference(dest_vars, **kwargs): |
|
|
|
return [VarNode(i) for i in res_vars] |
|
|
|
|
|
|
|
|
|
|
|
def dump_graph(*args):
    """Serialize the computing graph reachable from the given wrapped vars."""
    raw_nodes = [var._node for var in args]
    return _imperative_rt.dump_graph(raw_nodes)
|
|
|
# Statistics and name lists produced by a graph dump; see ``dump_graph``.
CompGraphDumpResult = collections.namedtuple(
    "CompGraphDumpResult",
    (
        "nr_opr",  # number of operators dumped
        "tot_bytes",  # total bytes for the whole graph
        "tensor_value_bytes",  # bytes consumed for dumping tensor values
        "content_hash",  # hash of the dumped content
        "inputs",  # names of input tensors
        "outputs",  # names of output vars
        "params",  # names of dumped params
    ),
)
|
|
|
|
|
|
|
|
|
|
|
def dump_graph(
    output_vars: Union[Dict[str, VarNode], List[VarNode]],
    *,
    keep_var_name: int = 1,
    keep_param_name: bool = False,
    keep_opr_priority: bool = False,
    strip_info_file=None,
):
    """serialize the computing graph of `output_vars` and get byte result.

    :param output_vars: output variables which are the graph's end point.

    .. note::

        The underlying C++ API only accepts a var list. If a dict is given,
        the vars would be renamed to the given names.

    :param keep_var_name: level for keeping variable names:

        * 0: none of the names are kept
        * 1: (default)keep names of output vars
        * 2: keep names of all (output and internal) vars
    :param keep_param_name: whether to keep param names, so param values can be
        easily manipulated after loading model
    :param keep_opr_priority: whether to keep priority setting for operators
    :param strip_info_file: a string for path or a file handler. if is not None,
        then the dump information for code strip would be written to ``strip_info_file``
    :return: dump result as byte string, and an instance of namedtuple
        :class:`CompGraphDumpResult`, whose fields are:

        * ``nr_opr`` number of operators dumped
        * ``tot_bytes`` total bytes for the whole graph
        * ``tensor_value_bytes`` bytes consumed for dumping tensor values
        * ``content_hash`` hash of the dumped content
        * ``inputs`` names of input tensors
        * ``params`` list of names of dumped params
        * ``outputs`` names of output vars
    """
    ov = []
    if isinstance(output_vars, dict):
        used_vars = set()
        for name, var in output_vars.items():
            assert isinstance(var, VarNode), "bad output var: {!r}".format(var)
            assert var.id not in used_vars, (
                "var name is associated with a var object, so we can not have "
                "two names given to the same var: {}".format(var)
            )
            used_vars.add(var.id)
            # rename to the key so the dumped graph carries the given name
            var.name = name
            ov.append(var._node)
    else:
        for var in output_vars:
            assert isinstance(var, VarNode), "bad output var: {!r}".format(var)
            ov.append(var._node)

    # Output buffers filled in-place by the C++ serializer.
    stat = []
    inputs = []
    outputs = []
    params = []

    dump_content = _imperative_rt.dump_graph(
        ov,
        keep_var_name,
        keep_param_name,
        keep_opr_priority,
        stat,
        inputs,
        outputs,
        params,
    )

    dump_info = CompGraphDumpResult(*stat, inputs, outputs, params)

    if strip_info_file is not None:
        strip_info = json.loads(_imperative_rt.get_info_for_strip(ov))
        strip_info["hash"] = dump_info.content_hash
        if isinstance(strip_info_file, str):
            # BUGFIX: the original opened the path and never closed it, leaking
            # the handle and possibly leaving buffered JSON unflushed. Use a
            # context manager so the file is always closed.
            with open(strip_info_file, "w") as fout:
                json.dump(strip_info, fout)
        else:
            # caller-supplied file object: write but leave lifetime to caller
            json.dump(strip_info, strip_info_file)

    return dump_content, dump_info
|
|
|
|
|
|
|
|
|
|
|
CompGraphLoadResult = collections.namedtuple( |
|
|
|
@@ -193,97 +377,6 @@ def load_graph(fpath): |
|
|
|
return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list) |
|
|
|
|
|
|
|
|
|
|
|
class VarNode(TensorBase):
    """Thin Python wrapper around a raw ``_imperative_rt.VarNode``.

    On construction the wrapper registers itself in the owning graph's
    ``_var_cache`` (when the graph provides one) so a raw node maps back to
    its wrapper.
    """

    def __init__(self, node: _imperative_rt.VarNode):
        self._node = node
        graph = self.graph
        if hasattr(graph, "_var_cache"):
            graph._var_cache[node] = self

    @property
    def graph(self) -> Graph:
        """Computing graph that owns this variable."""
        return self._node.graph

    @property
    def op(self):
        """Owner operator, wrapped via the graph's ``_wrap`` when available."""
        graph = self.graph
        if hasattr(graph, "_wrap"):
            return graph._wrap(self._node.owner)
        return self._node.owner

    @property
    def name(self):
        """Name of the underlying var node (settable)."""
        return self._node.name

    @name.setter
    def name(self, name):
        self._node.name = name

    @property
    def id(self):
        """Unique id of the underlying var node."""
        return self._node.id

    @property
    def dtype(self):
        """Data type of the variable."""
        return self._node.dtype

    @property
    def device(self):
        """Device of the variable, converted through ``as_device``."""
        return as_device(self._node.comp_node)

    @property
    def shape(self):
        """Shape of the variable."""
        return self._node.shape

    @property
    def value(self):
        """Value held by the underlying var node."""
        return self._node.value
|
|
|
|
|
|
|
class OpNode:
    """Thin Python wrapper around a raw ``_imperative_rt.OperatorNode``.

    On construction the wrapper registers itself in the owning graph's
    ``_op_cache`` (when the graph provides one) so a raw node maps back to
    its wrapper.
    """

    def __init__(self, node: _imperative_rt.OperatorNode):
        self._node = node
        graph = self.graph
        if hasattr(graph, "_op_cache"):
            graph._op_cache[node] = self

    @property
    def graph(self) -> Graph:
        """Computing graph that owns this operator."""
        return self._node.graph

    @property
    def name(self):
        """Name of the underlying operator node (settable)."""
        return self._node.name

    @name.setter
    def name(self, name):
        self._node.name = name

    @property
    def id(self):
        """Unique id of the underlying operator node."""
        return self._node.id

    @property
    def inputs(self):
        """Input vars, wrapped via the graph's ``_wrap`` when available."""
        graph = self.graph
        if hasattr(graph, "_wrap"):
            return tuple(graph._wrap(v) for v in self._node.inputs)
        return self._node.inputs

    @property
    def outputs(self):
        """Output vars, wrapped via the graph's ``_wrap`` when available."""
        graph = self.graph
        if hasattr(graph, "_wrap"):
            return tuple(graph._wrap(v) for v in self._node.outputs)
        return self._node.outputs

    @property
    def params(self):
        """Operator parameters, decoded from the node's JSON string."""
        return json.loads(self._node.params)

    @property
    def type(self):
        """Type name of the underlying operator node."""
        return self._node.type
|
|
|
|
|
|
|
def _wrap(x): |
|
|
|
if isinstance(x, collections.abc.Sequence): |
|
|
|
return type(x)(map(_wrap, x)) |
|
|
|
|