| @@ -6,6 +6,7 @@ | |||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import collections | import collections | ||||
| import copy | |||||
| import functools | import functools | ||||
| import itertools | import itertools | ||||
| import weakref | import weakref | ||||
| @@ -674,6 +675,22 @@ class Tensor: | |||||
| snd = mgb.make_shared(device, value=data, dtype=dtype) | snd = mgb.make_shared(device, value=data, dtype=dtype) | ||||
| self._reset(snd, requires_grad=requires_grad) | self._reset(snd, requires_grad=requires_grad) | ||||
| def __deepcopy__(self, memo): | |||||
| """ | |||||
| Since Tensor have __getstate__ and __setstate__ method, | |||||
| deepcopy only process the that and ignore the attribute of Parameter. | |||||
| So we need to add __deepcopy__ method to deepcopy correct attribute. | |||||
| """ | |||||
| assert (self.__val is not None) and ( | |||||
| self.__sym is None | |||||
| ), "Only SharedND initialized Tensor can be serialized or deep copied" | |||||
| cls = self.__class__ | |||||
| result = cls.__new__(cls) | |||||
| memo[id(self)] = result | |||||
| for k, v in self.__dict__.items(): | |||||
| setattr(result, k, copy.deepcopy(v, memo)) | |||||
| return result | |||||
| def tensor( | def tensor( | ||||
| data: Union[list, np.ndarray] = None, | data: Union[list, np.ndarray] = None, | ||||