GitOrigin-RevId: 7245e669a7
tags/v1.3.0
@@ -48,8 +48,8 @@ class Softmax(Module):
     """

-    def __init__(self, axis=None):
-        super().__init__()
+    def __init__(self, axis=None, **kwargs):
+        super().__init__(**kwargs)
         self.axis = axis

     def forward(self, inputs):
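
Every hunk in this diff applies the same change: the subclass `__init__` gains a trailing `**kwargs` parameter and forwards it to the parent constructor, so keyword arguments understood by the base `Module` can be supplied through any subclass. A minimal usage sketch, assuming the base `Module.__init__` accepts a `name` keyword (an assumption about the base-class signature, not something the diff itself shows):

    import megengine.module as M

    # Extra keywords now pass through the subclass into Module.__init__:
    softmax = M.Softmax(axis=1, name="cls_softmax")
    dropout = M.Dropout(drop_prob=0.5, name="head_dropout")
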
@@ -167,8 +167,8 @@ class PReLU(Module):
     """

-    def __init__(self, num_parameters: int = 1, init: float = 0.25):
-        super().__init__()
+    def __init__(self, num_parameters: int = 1, init: float = 0.25, **kwargs):
+        super().__init__(**kwargs)
         self.num_parameters = num_parameters
         if num_parameters > 1:
             # Assume format is NCHW
@@ -225,8 +225,8 @@ class LeakyReLU(Module):
     """

-    def __init__(self, negative_slope: float = 0.01):
-        super().__init__()
+    def __init__(self, negative_slope: float = 0.01, **kwargs):
+        super().__init__(**kwargs)
         self.negative_slope = negative_slope

     def forward(self, inputs):
@@ -15,10 +15,8 @@ from .module import Module

 class _AdaptivePoolNd(Module):
-    def __init__(
-        self, oshp: Union[Tuple[int, int], int, Tensor],
-    ):
-        super(_AdaptivePoolNd, self).__init__()
+    def __init__(self, oshp: Union[Tuple[int, int], int, Tensor], **kwargs):
+        super(_AdaptivePoolNd, self).__init__(**kwargs)
         self.oshp = oshp

     @abstractmethod
@@ -26,8 +26,9 @@ class _BatchNorm(Module):
         affine=True,
         track_running_stats=True,
         freeze=False,
+        **kwargs
     ):
-        super(_BatchNorm, self).__init__()
+        super(_BatchNorm, self).__init__(**kwargs)
         self.num_features = num_features
         self.eps = eps
         self.momentum = momentum
@@ -151,9 +152,10 @@ class SyncBatchNorm(_BatchNorm):
         track_running_stats=True,
         freeze=False,
         group: Optional[Group] = WORLD,
+        **kwargs
     ) -> None:
         super().__init__(
-            num_features, eps, momentum, affine, track_running_stats, freeze
+            num_features, eps, momentum, affine, track_running_stats, freeze, **kwargs
         )
         self.group = group
@@ -37,8 +37,9 @@ class _ConvNd(Module):
         dilation: Union[int, Tuple[int, int]],
         groups: int,
         bias: bool = True,
+        **kwargs
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         if in_channels % groups != 0:
             raise ValueError("in_channels must be divisible by groups")
         if out_channels % groups != 0:
@@ -176,6 +177,7 @@ class Conv1d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        **kwargs
     ):
         kernel_size = kernel_size
         stride = stride
@@ -192,6 +194,7 @@ class Conv1d(_ConvNd):
             dilation,
             groups,
             bias,
+            **kwargs,
         )

     def _get_fanin(self):
@@ -334,6 +337,7 @@ class Conv2d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        **kwargs
     ):
         kernel_size = _pair_nonzero(kernel_size)
         stride = _pair_nonzero(stride)
@@ -350,6 +354,7 @@ class Conv2d(_ConvNd):
             dilation,
             groups,
             bias,
+            **kwargs,
         )

     def _get_fanin(self):
@@ -444,6 +449,7 @@ class ConvTranspose2d(_ConvNd):
         bias: bool = True,
         conv_mode: str = "CROSS_CORRELATION",
         compute_mode: str = "DEFAULT",
+        **kwargs
     ):
         kernel_size = _pair_nonzero(kernel_size)
         stride = _pair_nonzero(stride)
@@ -460,6 +466,7 @@ class ConvTranspose2d(_ConvNd):
             dilation,
             groups,
             bias,
+            **kwargs,
         )

     def _get_fanin(self):
@@ -536,6 +543,7 @@ class LocalConv2d(Conv2d):
         dilation: Union[int, Tuple[int, int]] = 1,
         groups: int = 1,
         conv_mode: str = "CROSS_CORRELATION",
+        **kwargs
     ):
         self.input_height = input_height
         self.input_width = input_width
@@ -548,6 +556,7 @@ class LocalConv2d(Conv2d):
             dilation,
             groups,
             bias=False,
+            **kwargs,
         )

     def _infer_weight_shape(self):
@@ -30,6 +30,7 @@ class _ConvBnActivation2d(Module):
         momentum=0.9,
         affine=True,
         track_running_stats=True,
+        **kwargs
     ):
         super().__init__()
         self.conv = Conv2d(
@@ -43,6 +44,7 @@ class _ConvBnActivation2d(Module):
             bias,
             conv_mode,
             compute_mode,
+            **kwargs,
         )
         self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
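
`_ConvBnActivation2d` is the one exception in this diff: its own `super().__init__()` stays bare, and the extra keywords are forwarded to the inner `Conv2d` instead. A self-contained sketch of where the keywords land (hypothetical stand-in classes, not MegEngine's real ones):

    class FakeConv2d:
        def __init__(self, channels, name=None):  # `name` is an assumed keyword
            self.name = name

    class FakeConvBn:
        def __init__(self, channels, **kwargs):
            # mirrors the diff: the fused block passes nothing extra to its own
            # base class and hands the keywords to the inner conv instead
            self.conv = FakeConv2d(channels, **kwargs)

    block = FakeConvBn(8, name="stem")
    assert block.conv.name == "stem"      # the inner conv receives the keyword
    assert not hasattr(block, "name")     # the fused block itself does not
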
@@ -20,8 +20,8 @@ class Dropout(Module):
     :param drop_prob: The probability to drop (set to zero) each single element
     """

-    def __init__(self, drop_prob=0.0):
-        super().__init__()
+    def __init__(self, drop_prob=0.0, **kwargs):
+        super().__init__(**kwargs)
         self.drop_prob = drop_prob

     def forward(self, inputs):
@@ -72,8 +72,8 @@ class Elemwise(Module):
     * "NOT": bool unary: ~x
     """

-    def __init__(self, method):
-        super().__init__()
+    def __init__(self, method, **kwargs):
+        super().__init__(**kwargs)
         self.method = method

     def forward(self, *inps):
@@ -64,8 +64,9 @@ class Embedding(Module):
         norm_type: Optional[float] = None,
         initial_weight: Parameter = None,
         freeze: bool = False,
+        **kwargs
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         if padding_idx is not None:
             raise ValueError("Not support padding index now.")
         if max_norm is not None or norm_type is not None:
@@ -19,10 +19,8 @@ class TensorrtRuntimeSubgraph(Module):
     See :func:`~.tensorrt_runtime_opr` for more details.
     """

-    def __init__(
-        self, data,
-    ):
-        super(TensorrtRuntimeSubgraph, self).__init__()
+    def __init__(self, data, **kwargs):
+        super(TensorrtRuntimeSubgraph, self).__init__(**kwargs)
         self._data = data

     @property
@@ -20,8 +20,8 @@ class GroupNorm(Module):
     Reference: https://arxiv.org/pdf/1803.08494.pdf.
     """

-    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
-        super().__init__()
+    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs):
+        super().__init__(**kwargs)
         assert num_channels % num_groups == 0
         self.num_groups = num_groups
         self.num_channels = num_channels
@@ -70,8 +70,8 @@ class InstanceNorm(Module):
     Note that InstanceNorm equals using GroupNorm with num_groups=num_channels.
     """

-    def __init__(self, num_channels, eps=1e-05, affine=True):
-        super().__init__()
+    def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs):
+        super().__init__(**kwargs)
         self.num_channels = num_channels
         self.eps = eps
         self.affine = affine
@@ -114,8 +114,8 @@ class LayerNorm(Module):
     Note that LayerNorm equals using GroupNorm with num_groups=1.
     """

-    def __init__(self, num_channels, eps=1e-05, affine=True):
-        super().__init__()
+    def __init__(self, num_channels, eps=1e-05, affine=True, **kwargs):
+        super().__init__(**kwargs)
         self.num_channels = num_channels
         self.eps = eps
         self.affine = affine
@@ -19,8 +19,9 @@ class _PoolNd(Module):
         kernel_size: Union[int, Tuple[int, int]],
         stride: Union[int, Tuple[int, int]] = None,
         padding: Union[int, Tuple[int, int]] = 0,
+        **kwargs
     ):
-        super(_PoolNd, self).__init__()
+        super(_PoolNd, self).__init__(**kwargs)
         self.kernel_size = kernel_size
         self.stride = stride or kernel_size
         self.padding = padding
@@ -46,8 +46,8 @@ class Sequential(Module):
         pred1 = net1(data)
     """

-    def __init__(self, *args):
-        super().__init__()
+    def __init__(self, *args, **kwargs):
+        super().__init__(**kwargs)
         self.layer_keys = []
         if len(args) == 1 and isinstance(args[0], OrderedDict):
             for key, module in args[0].items():
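
Taken together, the hunks implement the standard cooperative-`**kwargs` pattern: every `__init__` in the hierarchy accepts unknown keywords and passes them up unchanged, so a keyword consumed at the top of the hierarchy works no matter which subclass is instantiated. A self-contained sketch of the mechanism (illustrative classes only, not MegEngine's):

    class Base:
        def __init__(self, name=None):   # the keyword ultimately consumed here
            self.name = name

    class Mid(Base):
        def __init__(self, units, **kwargs):
            super().__init__(**kwargs)   # forward unknown keywords upward
            self.units = units

    class Leaf(Mid):
        def __init__(self, units, slope=0.01, **kwargs):
            super().__init__(units, **kwargs)
            self.slope = slope

    leaf = Leaf(4, name="head")          # `name` reaches Base.__init__ intact
    assert leaf.name == "head" and leaf.units == 4

Without the forwarding in each intermediate class, `Leaf(4, name="head")` would raise a `TypeError` before the keyword ever reached `Base`.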