GitOrigin-RevId: 62eb3bfb10
tags/v1.9.0
| @@ -138,11 +138,7 @@ class Module(metaclass=ABCMeta): | |||||
| return HookHandler(self._forward_hooks, hook) | return HookHandler(self._forward_hooks, hook) | ||||
| def __call__(self, *inputs, **kwargs): | def __call__(self, *inputs, **kwargs): | ||||
| AutoNaming.push_scope( | |||||
| self.name | |||||
| if self.name is not None | |||||
| else (self._short_name if hasattr(self, "_short_name") else self._name) | |||||
| ) | |||||
| AutoNaming.push_scope(self.name if self.name is not None else self._short_name) | |||||
| for hook in self._forward_pre_hooks.values(): | for hook in self._forward_pre_hooks.values(): | ||||
| modified_inputs = hook(self, inputs) | modified_inputs = hook(self, inputs) | ||||
| if modified_inputs is not None: | if modified_inputs is not None: | ||||
| @@ -685,6 +681,12 @@ class Module(metaclass=ABCMeta): | |||||
| set_name(self, prefix, k, v) | set_name(self, prefix, k, v) | ||||
| super().__setattr__(name, value) | super().__setattr__(name, value) | ||||
| def __setstate__(self, state): | |||||
| if "_short_name" not in state: | |||||
| state["_short_name"] = state["_name"] | |||||
| state["_name"] = None | |||||
| self.__dict__.update(state) | |||||
| def __delattr__(self, name: str): | def __delattr__(self, name: str): | ||||
| if name in self.__dict__ and _is_module(self.__dict__[name]): | if name in self.__dict__ and _is_module(self.__dict__[name]): | ||||
| modules = self.__dict__.get("_modules") | modules = self.__dict__.get("_modules") | ||||
| @@ -50,7 +50,7 @@ class _ModuleState: | |||||
| if self.obj is None: | if self.obj is None: | ||||
| typem = getattr(import_module(self.module[0]), self.module[1]) | typem = getattr(import_module(self.module[0]), self.module[1]) | ||||
| m_obj = typem.__new__(typem) | m_obj = typem.__new__(typem) | ||||
| m_obj.__dict__.update(self.state) | |||||
| m_obj.__setstate__(self.state) | |||||
| self.obj = m_obj | self.obj = m_obj | ||||
| return self.obj | return self.obj | ||||
| @@ -1681,11 +1681,13 @@ class TracedModuleBuilder(NodeMixin): | |||||
| if isinstance(wrapped, TracedModuleBuilder): | if isinstance(wrapped, TracedModuleBuilder): | ||||
| if not isinstance(mod_attr, (List, Dict, QATModule)): | if not isinstance(mod_attr, (List, Dict, QATModule)): | ||||
| assert mod_attr is wrapped._mod | |||||
| else: | |||||
| assert ( | |||||
| mod_attr is wrapped._mod | |||||
| ), "TracedModule do not support modify module attributes, please check your code." | |||||
| if isinstance(wrapped, RawTensor): | |||||
| assert ( | assert ( | ||||
| mod_attr is wrapped | mod_attr is wrapped | ||||
| ), "TracedModule do not support modify attributes, please check your code." | |||||
| ), "TracedModule do not support modify tensor attributes, please check your code." | |||||
| if isinstance(wrapped, (NodeMixin, RawTensor)): | if isinstance(wrapped, (NodeMixin, RawTensor)): | ||||
| NodeMixin.wrap( | NodeMixin.wrap( | ||||
| @@ -2296,7 +2298,7 @@ class TracedModule(Module): | |||||
| for k, v in state.items(): | for k, v in state.items(): | ||||
| if isinstance(v, _ModuleState): | if isinstance(v, _ModuleState): | ||||
| state[k] = v.to_module() | state[k] = v.to_module() | ||||
| self.__dict__.update(state) | |||||
| super().__setstate__(state) | |||||
| self._update_ref() | self._update_ref() | ||||
| for _, graph in self.argdef_graph_map.items(): | for _, graph in self.argdef_graph_map.items(): | ||||
| @@ -87,3 +87,17 @@ def test_compatibility(): | |||||
| test_old_tensor("tensor_v1_1.mge") | test_old_tensor("tensor_v1_1.mge") | ||||
| test_old_tensor("tensor_v1_2.mge") | test_old_tensor("tensor_v1_2.mge") | ||||
| t = mge.tensor([1]) | |||||
| getattr(t, "qparams") | |||||
| new_args = t.__getnewargs__() | |||||
| assert ( | |||||
| len(new_args) == 3 | |||||
| and isinstance(new_args[0], np.ndarray) | |||||
| and new_args[1] == np.int32 | |||||
| and isinstance(new_args[2], str) | |||||
| ), "Modify Tensor __getnewargs__ may break pickle serialization compatible" | |||||
| state = t.__getstate__() | |||||
| assert set(state.keys()) == set( | |||||
| ["qparams"] | |||||
| ), "Modify Tensor __getstate__ may break pickle serialization compatible" | |||||
| @@ -681,3 +681,27 @@ def test_repr_module_reset_attr(): | |||||
| m1 = ResetAttrModule(False) | m1 = ResetAttrModule(False) | ||||
| output = [m0.__repr__(), m1.__repr__()] | output = [m0.__repr__(), m1.__repr__()] | ||||
| assert output == ground_truth | assert output == ground_truth | ||||
| def test_module_compatible(): | |||||
| class Empty(Module): | |||||
| def forward(self): | |||||
| pass | |||||
| empty_module = Empty() | |||||
| old_attributes = set( | |||||
| [ | |||||
| "_modules", | |||||
| "name", | |||||
| "training", | |||||
| "quantize_disabled", | |||||
| "_forward_pre_hooks", | |||||
| "_forward_hooks", | |||||
| "_name", | |||||
| "_short_name", | |||||
| ] | |||||
| ) | |||||
| current_attributes = set(empty_module.__dict__.keys()) | |||||
| assert ( | |||||
| old_attributes == current_attributes | |||||
| ), "Add or delete attributes in Module class may break compatibility of pickle serialization" | |||||