|
|
|
@@ -29,14 +29,20 @@ from ...core._imperative_rt.core2 import ( |
|
|
|
from ...core._trace_option import set_symbolic_shape |
|
|
|
from ...core.tensor.array_method import ArrayMethodMixin |
|
|
|
from ...module import Module |
|
|
|
from ...quantization.fake_quant import LSQ, TQT, FakeQuantize |
|
|
|
from ...module.qat import QATModule |
|
|
|
from ...quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize |
|
|
|
from ...quantization.observer import ( |
|
|
|
ExponentialMovingAverageObserver, |
|
|
|
HistogramObserver, |
|
|
|
MinMaxObserver, |
|
|
|
Observer, |
|
|
|
PassiveObserver, |
|
|
|
SyncExponentialMovingAverageObserver, |
|
|
|
SyncMinMaxObserver, |
|
|
|
) |
|
|
|
from ...tensor import Tensor |
|
|
|
from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input |
|
|
|
from .fake_quant import FakeQuantize as TM_FakeQuant |
|
|
|
from .module_tracer import ( |
|
|
|
Patcher, |
|
|
|
active_module_tracer, |
|
|
|
@@ -613,7 +619,8 @@ def _wrapped_function(orig_func): |
|
|
|
if isinstance(i, (RawTensor, NodeMixin)): |
|
|
|
NodeMixin.wrap_safe(i, Constant.make(i)) |
|
|
|
meth_name = _get_meth_name(args[0], wrapped_fn) if args else None |
|
|
|
if meth_name: |
|
|
|
arg_type = args[0] if isinstance(args[0], type) else type(args[0]) |
|
|
|
if meth_name and issubclass(arg_type, RawTensor): |
|
|
|
self = inputs[0] |
|
|
|
if meth_name == "__new__": |
|
|
|
if all([not isinstance(i, RawTensor) for i in inputs]): |
|
|
|
@@ -680,7 +687,15 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
self._mod = mod |
|
|
|
self._body = None |
|
|
|
self._is_top = is_top_module |
|
|
|
self._is_builtin = module_tracer.is_builtin(mod) |
|
|
|
self._is_builtin = ( |
|
|
|
True |
|
|
|
if isinstance(mod, (Observer, _FakeQuantize)) |
|
|
|
else module_tracer.is_builtin(mod) |
|
|
|
) |
|
|
|
if isinstance(self._mod, QATModule): |
|
|
|
unset_module_tracing() |
|
|
|
self._check_qat_module(self._mod) |
|
|
|
set_module_tracing() |
|
|
|
self._argdef_graph_map = {} |
|
|
|
self._argdef_outdef_map = {} |
|
|
|
|
|
|
|
@@ -693,15 +708,65 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
dict(TracedModuleBuilder.__dict__), |
|
|
|
) |
|
|
|
|
|
|
|
def _check_qat_module(self, qat_module): |
|
|
|
def isbuiltin(m): |
|
|
|
return m is None or module_tracer.is_builtin(m) |
|
|
|
|
|
|
|
if qat_module.with_act: |
|
|
|
act_observer = qat_module.act_observer |
|
|
|
act_fakequant = qat_module.act_fake_quant |
|
|
|
if not isbuiltin(act_observer) or not isbuiltin(act_fakequant): |
|
|
|
qparams = ( |
|
|
|
act_observer.get_qparams() |
|
|
|
if hasattr(act_observer, "get_qparams") |
|
|
|
else act_fakequant.get_qparams() |
|
|
|
) |
|
|
|
dtype = ( |
|
|
|
act_observer.dtype |
|
|
|
if hasattr(act_observer, "dtype") |
|
|
|
else act_fakequant.dtype |
|
|
|
) |
|
|
|
qat_module.act_observer = None |
|
|
|
qat_module.act_fake_quant = TM_FakeQuant(dtype) |
|
|
|
qat_module.act_fake_quant.set_qparams(qparams) |
|
|
|
|
|
|
|
if qat_module.with_weight: |
|
|
|
weight_observer = qat_module.weight_observer |
|
|
|
weight_fakequant = qat_module.weight_fake_quant |
|
|
|
if not isbuiltin(weight_observer) or not isbuiltin(weight_fakequant): |
|
|
|
qparams = ( |
|
|
|
weight_observer.get_qparams() |
|
|
|
if hasattr(weight_observer, "get_qparams") |
|
|
|
else weight_fakequant.get_qparams() |
|
|
|
) |
|
|
|
dtype = ( |
|
|
|
weight_observer.dtype |
|
|
|
if hasattr(weight_observer, "dtype") |
|
|
|
else weight_fakequant.dtype |
|
|
|
) |
|
|
|
qat_module.weight_observer = None |
|
|
|
qat_module.weight_fake_quant = TM_FakeQuant(dtype) |
|
|
|
qat_module.weight_fake_quant.set_qparams(qparams) |
|
|
|
|
|
|
|
def build(self): |
|
|
|
if self._is_builtin or isinstance(self._mod, TracedModule): |
|
|
|
if module_tracer.is_builtin(self._mod) or isinstance( |
|
|
|
self._mod, TracedModule |
|
|
|
): |
|
|
|
mod_type = type(self._mod) |
|
|
|
else: |
|
|
|
assert isinstance(self._mod, (Observer, _FakeQuantize)) |
|
|
|
mod_type = ( |
|
|
|
Observer if isinstance(self._mod, Observer) else _FakeQuantize |
|
|
|
) |
|
|
|
for node in self.nodes: |
|
|
|
node.module_type = type(self._mod) |
|
|
|
# node._owner = weakref.ref(self._mod) |
|
|
|
node.module_type = mod_type |
|
|
|
|
|
|
|
return self._mod |
|
|
|
else: |
|
|
|
is_qat = isinstance(self._mod, QATModule) |
|
|
|
traced_module = TracedModule( |
|
|
|
self._is_top, self._argdef_graph_map, self._argdef_outdef_map |
|
|
|
self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat |
|
|
|
) |
|
|
|
for _, g in self._argdef_graph_map.items(): |
|
|
|
g.compile() |
|
|
|
@@ -712,6 +777,20 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
v = v.build() |
|
|
|
setattr(traced_module, k, v) |
|
|
|
|
|
|
|
if isinstance(self._mod, QATModule): |
|
|
|
unset_module_tracing() |
|
|
|
traced_module.with_act = self._mod.with_act |
|
|
|
traced_module.with_weight = self._mod.with_weight |
|
|
|
if not hasattr(traced_module, "act_fake_quant"): |
|
|
|
traced_module.act_fakequant = None |
|
|
|
if not hasattr(traced_module, "act_observer"): |
|
|
|
traced_module.act_observer = None |
|
|
|
if not hasattr(traced_module, "weight_fake_quant"): |
|
|
|
traced_module.weight_fakequant = None |
|
|
|
if not hasattr(traced_module, "weight_observer"): |
|
|
|
traced_module.weight_observer = None |
|
|
|
set_module_tracing() |
|
|
|
|
|
|
|
return traced_module |
|
|
|
|
|
|
|
def _record_wrapped_nodes(self, node): |
|
|
|
@@ -846,7 +925,8 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
attr = getattr(self._mod, name) |
|
|
|
if isinstance(attr, Module): |
|
|
|
attr = TracedModuleBuilder(attr) |
|
|
|
setattr(self, name, attr) |
|
|
|
if isinstance(attr, (Module, RawTensor)): |
|
|
|
setattr(self, name, attr) |
|
|
|
NodeMixin.wrap( |
|
|
|
attr, |
|
|
|
lambda: GetAttr.make( |
|
|
|
@@ -1066,7 +1146,7 @@ class TracedModule(Module): |
|
|
|
argdef_graph_map = None |
|
|
|
argdef_outdef_map = None |
|
|
|
|
|
|
|
def __init__(self, is_top, argdef_graph_map, argdef_outdef_map): |
|
|
|
def __init__(self, is_top, argdef_graph_map, argdef_outdef_map, is_qat=False): |
|
|
|
super(TracedModule, self).__init__() |
|
|
|
self.argdef_graph_map = argdef_graph_map |
|
|
|
self.argdef_outdef_map = argdef_outdef_map |
|
|
|
@@ -1074,6 +1154,7 @@ class TracedModule(Module): |
|
|
|
self.watch_points = [] |
|
|
|
self.watch_node_value = {} |
|
|
|
self.end_points = [] |
|
|
|
self.is_qat = is_qat |
|
|
|
|
|
|
|
def forward(self, *args, **kwargs): |
|
|
|
inputs, treedef = tree_flatten(((self, *args), kwargs)) |
|
|
|
@@ -1195,8 +1276,8 @@ class TracedModule(Module): |
|
|
|
): |
|
|
|
if graph is not None and prefix_name and prefix_name[-1] != "_": |
|
|
|
prefix_name += "_" |
|
|
|
if graph is None: |
|
|
|
assert not isinstance(module, TracedModule) |
|
|
|
if graph is None or module.is_qat: |
|
|
|
assert not isinstance(module, TracedModule) or module.is_qat |
|
|
|
const = Constant(module, "self.%s" % module2name[id(module)]) |
|
|
|
m_node = call.inputs[0] |
|
|
|
if m_node.top_graph != active_module_tracer().current_scope(): |
|
|
|
@@ -1326,9 +1407,23 @@ def _register_all_builtin_module(): |
|
|
|
isclass(m[1]) |
|
|
|
and issubclass(m[1], M.Module) |
|
|
|
and m[1] is not M.Sequential |
|
|
|
and m[1] is not M.ModuleList |
|
|
|
): |
|
|
|
module_tracer.register_as_builtin(m[1]) |
|
|
|
|
|
|
|
module_tracer.register_as_builtin(Observer) |
|
|
|
module_tracer.register_as_builtin(MinMaxObserver) |
|
|
|
module_tracer.register_as_builtin(SyncMinMaxObserver) |
|
|
|
module_tracer.register_as_builtin(ExponentialMovingAverageObserver) |
|
|
|
module_tracer.register_as_builtin(SyncExponentialMovingAverageObserver) |
|
|
|
module_tracer.register_as_builtin(HistogramObserver) |
|
|
|
module_tracer.register_as_builtin(PassiveObserver) |
|
|
|
|
|
|
|
module_tracer.register_as_builtin(LSQ) |
|
|
|
module_tracer.register_as_builtin(TQT) |
|
|
|
module_tracer.register_as_builtin(FakeQuantize) |
|
|
|
module_tracer.register_as_builtin(TM_FakeQuant) |
|
|
|
|
|
|
|
|
|
|
|
def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: |
|
|
|
""" |
|
|
|
|