GitOrigin-RevId: 5d2d047d2f
tags/v1.7.0
| @@ -159,13 +159,13 @@ class Expr: | |||||
| @property | @property | ||||
| def kwargs(self): | def kwargs(self): | ||||
| r"""Get the the keyword arguments of the operation corresponding to this Expr.""" | |||||
| r"""Get the keyword arguments of the operation corresponding to this Expr.""" | |||||
| _, kwargs = self.unflatten_args(self.inputs) | _, kwargs = self.unflatten_args(self.inputs) | ||||
| return kwargs | return kwargs | ||||
| @property | @property | ||||
| def args(self): | def args(self): | ||||
| r"""Get the the positional arguments of the operation corresponding to this Expr.""" | |||||
| r"""Get the positional arguments of the operation corresponding to this Expr.""" | |||||
| args, _ = self.unflatten_args(self.inputs) | args, _ = self.unflatten_args(self.inputs) | ||||
| return args | return args | ||||
| @@ -33,7 +33,7 @@ class Node: | |||||
| _orig_name = None # type: str | _orig_name = None # type: str | ||||
| _format_spec = "" # type: str | _format_spec = "" # type: str | ||||
| def __init__(self, expr: "Expr", name: str, orig_name: str): | |||||
| def __init__(self, expr, name: str, orig_name: str): | |||||
| self.expr = expr | self.expr = expr | ||||
| self.users = [] # List[Expr] | self.users = [] # List[Expr] | ||||
| self._id = Node.__total_id | self._id = Node.__total_id | ||||
| @@ -120,7 +120,7 @@ class ModuleNode(Node): | |||||
| r"""The type of the Module correspending to the ModuleNode.""" | r"""The type of the Module correspending to the ModuleNode.""" | ||||
| _owner = None # type: weakref.ReferenceType | _owner = None # type: weakref.ReferenceType | ||||
| def __init__(self, expr: "Expr", name: str = None, orig_name: str = None): | |||||
| def __init__(self, expr, name: str = None, orig_name: str = None): | |||||
| super().__init__(expr, name, orig_name) | super().__init__(expr, name, orig_name) | ||||
| def __getstate__(self): | def __getstate__(self): | ||||
| @@ -136,9 +136,6 @@ class ModuleNode(Node): | |||||
| @property | @property | ||||
| def owner(self): | def owner(self): | ||||
| r"""Get the ``Module`` corresponding to this ``ModuleNode``. | r"""Get the ``Module`` corresponding to this ``ModuleNode``. | ||||
| Returns: | |||||
| An :calss:`~.Module`. | |||||
| """ | """ | ||||
| if self._owner: | if self._owner: | ||||
| return self._owner() | return self._owner() | ||||
| @@ -196,7 +193,7 @@ class TensorNode(Node): | |||||
| @property | @property | ||||
| def qparams(self): | def qparams(self): | ||||
| r"""Get the :calss:`QParams` of this Node.""" | |||||
| r"""Get the :class:`QParams` of this Node.""" | |||||
| return self._qparams | return self._qparams | ||||
| @qparams.setter | @qparams.setter | ||||
| @@ -210,11 +207,7 @@ class TensorNode(Node): | |||||
| @value.setter | @value.setter | ||||
| def value(self, value): | def value(self, value): | ||||
| r"""Bind a Tensor to this Node. | |||||
| Args: | |||||
| value: A :class:`Tensor`. | |||||
| """ | |||||
| r"""Bind a :class:`Tensor` to this Node.""" | |||||
| if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None: | if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None: | ||||
| setattr(value, "_NodeMixin__node", None) | setattr(value, "_NodeMixin__node", None) | ||||
| self._value = value | self._value = value | ||||
| @@ -150,8 +150,8 @@ def tree_flatten( | |||||
| is_leaf: Callable = _is_leaf, | is_leaf: Callable = _is_leaf, | ||||
| is_const_leaf: Callable = _is_const_leaf, | is_const_leaf: Callable = _is_const_leaf, | ||||
| ): | ): | ||||
| r"""Flattens a object into a list of values and a :calss:`TreeDef` that can be used | |||||
| to reconstruct the object. | |||||
| r"""Flattens a pytree into a list of values and a :class:`TreeDef` that can be used | |||||
| to reconstruct the pytree. | |||||
| """ | """ | ||||
| if type(values) not in SUPPORTED_TYPE: | if type(values) not in SUPPORTED_TYPE: | ||||
| assert is_leaf(values), values | assert is_leaf(values), values | ||||
| @@ -188,7 +188,7 @@ class TreeDef: | |||||
| self.num_leaves = sum(ch.num_leaves for ch in children_defs) | self.num_leaves = sum(ch.num_leaves for ch in children_defs) | ||||
| def unflatten(self, leaves): | def unflatten(self, leaves): | ||||
| r"""Given a list of values and a ``TreeDef``, builds a object. | |||||
| r"""Given a list of values and a ``TreeDef``, builds a pytree. | |||||
| This is the inverse operation of ``tree_flatten``. | This is the inverse operation of ``tree_flatten``. | ||||
| """ | """ | ||||
| assert len(leaves) == self.num_leaves | assert len(leaves) == self.num_leaves | ||||
| @@ -453,7 +453,7 @@ class InternalGraph: | |||||
| r"""Get the list of output Nodes of this graph. | r"""Get the list of output Nodes of this graph. | ||||
| Returns: | Returns: | ||||
| A list of Node. | |||||
| A list of ``Node``. | |||||
| """ | """ | ||||
| return self._outputs | return self._outputs | ||||
| @@ -1937,7 +1937,7 @@ class TracedModule(Module): | |||||
| @property | @property | ||||
| def graph(self) -> InternalGraph: | def graph(self) -> InternalGraph: | ||||
| """Return the ``InternalGraph`` of this ``TracedModule`` | |||||
| """Return the ``InternalGraph`` of this ``TracedModule``. | |||||
| """ | """ | ||||
| if self._is_top: | if self._is_top: | ||||
| self._update_ref() | self._update_ref() | ||||