GitOrigin-RevId: 5ac525f010
tags/v1.5.0
| @@ -609,14 +609,6 @@ class Module(metaclass=ABCMeta): | |||||
| return set(loaded), set(skipped) | return set(loaded), set(skipped) | ||||
| def __getattribute__(self, name: str): | |||||
| value = super().__getattribute__(name) | |||||
| if name == "__dict__": | |||||
| return value | |||||
| for prefix, variable in _expand_structure(name, value): | |||||
| variable._name = prefix | |||||
| return value | |||||
| def __setattr__(self, name: str, value): | def __setattr__(self, name: str, value): | ||||
| is_module_like = _is_module(value) or isinstance(value, (list, tuple, dict)) | is_module_like = _is_module(value) or isinstance(value, (list, tuple, dict)) | ||||
| if name != "_modules": | if name != "_modules": | ||||
| @@ -631,6 +623,15 @@ class Module(metaclass=ABCMeta): | |||||
| else: | else: | ||||
| if modules is not None and name in modules: | if modules is not None and name in modules: | ||||
| modules.remove(name) | modules.remove(name) | ||||
| for k, v in _expand_structure(name, value): | |||||
| if not v._name: | |||||
| v._name = k | |||||
| else: | |||||
| logger.warning( | |||||
| "try setting the submodule `{}` to a new attribute `{}`, its name `{}` will remain unchanged".format( | |||||
| v._name, k, v._name | |||||
| ) | |||||
| ) | |||||
| super().__setattr__(name, value) | super().__setattr__(name, value) | ||||
| def __delattr__(self, name: str): | def __delattr__(self, name: str): | ||||
| @@ -368,10 +368,10 @@ class AssertModule(Module): | |||||
| def test_assert_message(): | def test_assert_message(): | ||||
| m = AssertModule() | |||||
| with pytest.raises( | with pytest.raises( | ||||
| AssertionError, match="keys for Tensor and Module must be str, error key: True" | AssertionError, match="keys for Tensor and Module must be str, error key: True" | ||||
| ): | ): | ||||
| m = AssertModule() | |||||
| list(m._flatten()) | list(m._flatten()) | ||||
| @@ -155,13 +155,13 @@ def test_with_submodule_in_container(symbolic): | |||||
| m = Simple("simple") | m = Simple("simple") | ||||
| ops = _dump_and_load(m, symbolic) | ops = _dump_and_load(m, symbolic) | ||||
| assert ops[-1].outputs[0].name == "simple.l2.l2-1.ADD" | |||||
| assert ops[-1].name == "simple.l2.l2-1.ADD" | |||||
| assert ops[-2].name == "simple.l2.l2-1.MatrixMul" | |||||
| assert ops[-3].name == "simple.l1.1.ADD" | |||||
| assert ops[-4].name == "simple.l1.1.MatrixMul" | |||||
| assert ops[-5].name == "simple.l0.1.ADD" | |||||
| assert ops[-6].name == "simple.l0.1.MatrixMul" | |||||
| assert ops[-1].outputs[0].name == "simple.l0.1.ADD[2]" | |||||
| assert ops[-1].name == "simple.l0.1.ADD[2]" | |||||
| assert ops[-2].name == "simple.l0.1.MatrixMul[2]" | |||||
| assert ops[-3].name == "simple.l0.1.ADD[1]" | |||||
| assert ops[-4].name == "simple.l0.1.MatrixMul[1]" | |||||
| assert ops[-5].name == "simple.l0.1.ADD[0]" | |||||
| assert ops[-6].name == "simple.l0.1.MatrixMul[0]" | |||||
| @pytest.mark.parametrize("symbolic", [False, True]) | @pytest.mark.parametrize("symbolic", [False, True]) | ||||