| @@ -158,10 +158,12 @@ def set_log_level(level, update_existing=True): | |||||
| update_existing: whether to update existing loggers | update_existing: whether to update existing loggers | ||||
| """ | """ | ||||
| global _default_level # pylint: disable=global-statement | global _default_level # pylint: disable=global-statement | ||||
| origin_level = _default_level | |||||
| _default_level = level | _default_level = level | ||||
| if update_existing: | if update_existing: | ||||
| for i in _all_loggers: | for i in _all_loggers: | ||||
| i.setLevel(level) | i.setLevel(level) | ||||
| return origin_level | |||||
| _logger = get_logger(__name__) | _logger = get_logger(__name__) | ||||
| @@ -32,7 +32,7 @@ class _ConvBnActivation2d(Module): | |||||
| track_running_stats=True, | track_running_stats=True, | ||||
| **kwargs | **kwargs | ||||
| ): | ): | ||||
| super().__init__() | |||||
| super().__init__(**kwargs) | |||||
| self.conv = Conv2d( | self.conv = Conv2d( | ||||
| in_channels, | in_channels, | ||||
| out_channels, | out_channels, | ||||
| @@ -49,17 +49,13 @@ def _access_structure(obj, key, callback=None): | |||||
| parent = None | parent = None | ||||
| for k in key_list: | for k in key_list: | ||||
| parent = cur | parent = cur | ||||
| if isinstance(cur, (Tensor, Module)): | |||||
| cur = getattr(cur, k) | |||||
| elif isinstance(cur, (list, tuple)): | |||||
| if isinstance(cur, (list, tuple)): | |||||
| k = int(k) | k = int(k) | ||||
| cur = cur[k] | cur = cur[k] | ||||
| elif isinstance(cur, dict): | elif isinstance(cur, dict): | ||||
| cur = cur[k] | cur = cur[k] | ||||
| else: | else: | ||||
| raise ValueError( | |||||
| "Unsupport value type {} to access attribute".format(type(cur)) | |||||
| ) | |||||
| cur = getattr(cur, k) | |||||
| return callback(parent, k, cur) | return callback(parent, k, cur) | ||||
| @@ -650,8 +646,8 @@ class Module(metaclass=ABCMeta): | |||||
| v._name = k | v._name = k | ||||
| elif v._name != k: | elif v._name != k: | ||||
| logger.warning( | logger.warning( | ||||
| "try setting the submodule `{}` to a new attribute `{}`, its name `{}` will remain unchanged".format( | |||||
| v._name, k, v._name | |||||
| "try setting the submodule `{}` to `{}`'s new attribute `{}`, its name `{}` will remain unchanged".format( | |||||
| type(v), type(self), k, v._name | |||||
| ) | ) | ||||
| ) | ) | ||||
| super().__setattr__(name, value) | super().__setattr__(name, value) | ||||
| @@ -111,10 +111,8 @@ class QParams: | |||||
| return "QParams({})".format(content) | return "QParams({})".format(content) | ||||
| class LSQParams: | |||||
| r"""To standardize LSQ's qparams format. If custom | |||||
| qparams is needed, inherit this class and add custom ``__slots__``. | |||||
| """ | |||||
| class LSQParams(QParams): | |||||
| r"""LSQ qparams with extra grad_scale slot.""" | |||||
| __slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale" | __slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale" | ||||
| @@ -126,30 +124,9 @@ class LSQParams: | |||||
| zero_point: Tensor, | zero_point: Tensor, | ||||
| grad_scale: Tensor, | grad_scale: Tensor, | ||||
| ): | ): | ||||
| self.mode = mode | |||||
| self.dtype_meta = dtype_meta | |||||
| self.scale = scale | |||||
| self.zero_point = zero_point | |||||
| super().__init__(mode, dtype_meta, scale, zero_point) | |||||
| self.grad_scale = grad_scale | self.grad_scale = grad_scale | ||||
| def update(self, lsqparams: "LSQParams"): | |||||
| for key in self.__slots__: | |||||
| setattr(self, key, getattr(lsqparams, key)) | |||||
| def __eq__(self, other): | |||||
| if len(self.__slots__) != len(other.__slots__): | |||||
| return False | |||||
| for key in self.__slots__: | |||||
| if not hasattr(other, key) or getattr(self, key) != getattr(other, key): | |||||
| return False | |||||
| return True | |||||
| def __repr__(self): | |||||
| content = ", ".join( | |||||
| ["{}={}".format(key, getattr(self, key)) for key in self.__slots__] | |||||
| ) | |||||
| return "LSQParams({})".format(content) | |||||
| class QParamsModuleMixin(abc.ABC): | class QParamsModuleMixin(abc.ABC): | ||||
| def get_quantized_dtype(self): | def get_quantized_dtype(self): | ||||
| @@ -642,7 +642,6 @@ class InternalGraph: | |||||
| Returns: | Returns: | ||||
| A :class:`~.TracedModule.NodeFilterType`. | A :class:`~.TracedModule.NodeFilterType`. | ||||
| """ | """ | ||||
| assert issubclass(module_cls, Module) | |||||
| return self.nodes(recursive).type(module_cls) | return self.nodes(recursive).type(module_cls) | ||||
| def get_node_by_id(self, node_id: List[int] = None, recursive=True): | def get_node_by_id(self, node_id: List[int] = None, recursive=True): | ||||
| @@ -96,6 +96,12 @@ class _ModuleList(Module, MutableSequence): | |||||
| raise IndexError("list index out of range") | raise IndexError("list index out of range") | ||||
| return rst if len(rst) > 1 else rst[0] | return rst if len(rst) > 1 else rst[0] | ||||
| def __setattr__(self, key, value): | |||||
| # clear mod name to avoid warning in Module's setattr | |||||
| if isinstance(value, Module): | |||||
| value._name = None | |||||
| super().__setattr__(key, value) | |||||
| def __setitem__(self, idx: int, mod: Module): | def __setitem__(self, idx: int, mod: Module): | ||||
| if not isinstance(mod, Module): | if not isinstance(mod, Module): | ||||
| raise ValueError("invalid sub-module") | raise ValueError("invalid sub-module") | ||||
| @@ -159,6 +165,12 @@ class _ModuleDict(Module, MutableMapping): | |||||
| def __getitem__(self, key): | def __getitem__(self, key): | ||||
| return getattr(self, key) | return getattr(self, key) | ||||
| def __setattr__(self, key, value): | |||||
| # clear mod name to avoid warning in Module's setattr | |||||
| if isinstance(value, Module): | |||||
| value._name = None | |||||
| super().__setattr__(key, value) | |||||
| def __setitem__(self, key, value): | def __setitem__(self, key, value): | ||||
| if not isinstance(value, Module): | if not isinstance(value, Module): | ||||
| raise ValueError("invalid sub-module") | raise ValueError("invalid sub-module") | ||||