GitOrigin-RevId: 4793adf02b
tags/v1.0.0-rc1
| @@ -725,6 +725,12 @@ class trace: | |||||
| raise RuntimeError("trace is not set with profiling=True") | raise RuntimeError("trace is not set with profiling=True") | ||||
| return json.loads(self._profiler.get()) | return json.loads(self._profiler.get()) | ||||
| def trace(self, *args, **kwargs): | |||||
| raise NotImplementedError( | |||||
| "trace is deemed unbeneficial with the new " | |||||
| "tracing mechanism. You should alwasy use __call__." | |||||
| ) | |||||
| class CompiledTensorProxy(RawTensor): | class CompiledTensorProxy(RawTensor): | ||||
| """ | """ | ||||
| @@ -174,7 +174,11 @@ class Module(metaclass=ABCMeta): | |||||
| if "requires_grad" in kwargs: | if "requires_grad" in kwargs: | ||||
| del kwargs["requires_grad"] | del kwargs["requires_grad"] | ||||
| warnings.warn("passing requires_grad has no effect currently") | |||||
| warnings.warn( | |||||
| "Tensor currently has no requires_grad attribute " | |||||
| "so requires_grad argument is ignored here", | |||||
| DeprecationWarning, | |||||
| ) | |||||
| def predicate(obj) -> bool: | def predicate(obj) -> bool: | ||||
| return _is_parameter(obj) | return _is_parameter(obj) | ||||
| @@ -197,7 +201,11 @@ class Module(metaclass=ABCMeta): | |||||
| if "requires_grad" in kwargs: | if "requires_grad" in kwargs: | ||||
| del kwargs["requires_grad"] | del kwargs["requires_grad"] | ||||
| warnings.warn("passing requires_grad has no effect currently") | |||||
| warnings.warn( | |||||
| "Tensor currently has no requires_grad attribute " | |||||
| "so requires_grad argument is ignored here", | |||||
| DeprecationWarning, | |||||
| ) | |||||
| def predicate(obj) -> bool: | def predicate(obj) -> bool: | ||||
| return _is_parameter(obj) | return _is_parameter(obj) | ||||
| @@ -339,6 +347,7 @@ class Module(metaclass=ABCMeta): | |||||
| self.apply(fn) | self.apply(fn) | ||||
| @deprecated(version="1.0") | |||||
| def replace_param( | def replace_param( | ||||
| self, params: dict, start_pos: int, seen: Optional[Set[int]] = None | self, params: dict, start_pos: int, seen: Optional[Set[int]] = None | ||||
| ): | ): | ||||
| @@ -16,6 +16,7 @@ from typing import Union | |||||
| import numpy as np | import numpy as np | ||||
| from ..tensor import Parameter, Tensor | from ..tensor import Parameter, Tensor | ||||
| from ..utils.deprecation import deprecated | |||||
| class _RequiredParameter: | class _RequiredParameter: | ||||
| @@ -149,8 +150,15 @@ class Optimizer(metaclass=ABCMeta): | |||||
| self._updates(group) | self._updates(group) | ||||
| return self | return self | ||||
| @deprecated(version="1.0", reason="use clear_grad instead") | |||||
| def zero_grad(self): | |||||
| for param_group in self.param_groups: | |||||
| for param in param_group["params"]: | |||||
| if param.grad is not None: | |||||
| param.grad.reset_zero() | |||||
| def clear_grad(self): | def clear_grad(self): | ||||
| r"""Clear the grad buffer. | |||||
| r"""Set the grad attribute to None for all parameters. | |||||
| """ | """ | ||||
| for param_group in self.param_groups: | for param_group in self.param_groups: | ||||
| @@ -224,3 +232,9 @@ class Optimizer(metaclass=ABCMeta): | |||||
| "loaded state dict contains a state that doesn't match " | "loaded state dict contains a state that doesn't match " | ||||
| "the size of optimizer's state" | "the size of optimizer's state" | ||||
| ) | ) | ||||
| def backward(self, loss): | |||||
| raise NotImplementedError("use autodiff.GradManager instead") | |||||
| def bcast_param(self): | |||||
| raise NotImplementedError("use distributed.bcast_list_ instead") | |||||
| @@ -13,6 +13,7 @@ import collections | |||||
| from .core import Tensor as _Tensor | from .core import Tensor as _Tensor | ||||
| from .core.ops.builtin import Copy | from .core.ops.builtin import Copy | ||||
| from .core.tensor.core import apply | from .core.tensor.core import apply | ||||
| from .core.tensor.raw_tensor import as_device | |||||
| from .device import get_default_device | from .device import get_default_device | ||||
| from .utils.deprecation import deprecated | from .utils.deprecation import deprecated | ||||
| @@ -35,7 +36,8 @@ class Tensor(_Tensor): | |||||
| def reset_zero(self): | def reset_zero(self): | ||||
| self *= 0 | self *= 0 | ||||
| def to(self, cn): | |||||
| def to(self, device): | |||||
| cn = as_device(device).to_c() | |||||
| return apply(Copy(comp_node=cn), self)[0] | return apply(Copy(comp_node=cn), self)[0] | ||||
| @property | @property | ||||