GitOrigin-RevId: dbc1f27ff7
tags/v1.4.0-rc1
@@ -65,7 +65,7 @@ def _elwise(*args, mode):
 def _matmul(inp1, inp2):
     op = builtin.MatrixMul(
-        transposeA=False, transposeB=False, compute_mode="DEFAULT", format="DEFAULT"
+        transposeA=False, transposeB=False, compute_mode="default", format="default"
     )
     inp1, inp2 = utils.convert_inputs(inp1, inp2)
     (result,) = apply(op, inp1, inp2)
@@ -178,7 +178,7 @@ def _reduce(mode):
     def f(self, axis=None, keepdims: bool = False):
         data = self
         (data,) = utils.convert_inputs(data)
-        if mode == "MEAN":
+        if mode == "mean":
             data = data.astype("float32")
         elif self.dtype == np.bool_:
             data = data.astype("int32")
@@ -204,7 +204,7 @@ def _reduce(mode):
         if not keepdims:
             result = _remove_axis(result, axis)
         if self.dtype == np.bool_:
-            if mode in ["MIN", "MAX"]:
+            if mode in ["min", "max"]:
                 result = result.astype("bool")
         if axis is None or self.ndim == 1:
             setscalar(result)
@@ -479,7 +479,7 @@ class ArrayMethodMixin(abc.ABC):
             10.0
         """
-        return _reduce("SUM")(self, axis, keepdims)
+        return _reduce("sum")(self, axis, keepdims)
     def prod(self, axis=None, keepdims: bool = False):
         r"""
@@ -512,7 +512,7 @@ class ArrayMethodMixin(abc.ABC):
             24.0
         """
-        return _reduce("PRODUCT")(self, axis, keepdims)
+        return _reduce("product")(self, axis, keepdims)
     def min(self, axis=None, keepdims: bool = False):
         r"""
@@ -545,7 +545,7 @@ class ArrayMethodMixin(abc.ABC):
             1.0
         """
-        return _reduce("MIN")(self, axis, keepdims)
+        return _reduce("min")(self, axis, keepdims)
     def max(self, axis=None, keepdims: bool = False):
         r"""
@@ -578,7 +578,7 @@ class ArrayMethodMixin(abc.ABC):
             4.0
         """
-        return _reduce("MAX")(self, axis, keepdims)
+        return _reduce("max")(self, axis, keepdims)
     def mean(self, axis=None, keepdims: bool = False):
         r"""
@@ -611,4 +611,4 @@ class ArrayMethodMixin(abc.ABC):
             2.5
         """
-        return _reduce("MEAN")(self, axis, keepdims)
+        return _reduce("mean")(self, axis, keepdims)
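As a quick cross-check of the lowercase reduce modes above, a minimal sketch with illustrative values consistent with the doctest outputs quoted in these hunks (10.0, 24.0, 1.0, 4.0, 2.5); the exact input used by the original doctests is assumed here:

    import numpy as np
    from megengine import tensor

    x = tensor(np.arange(1, 5, dtype=np.float32).reshape(2, 2))  # [[1, 2], [3, 4]]
    print(x.sum().numpy())   # expected 10.0
    print(x.prod().numpy())  # expected 24.0
    print(x.min().numpy())   # expected 1.0
    print(x.max().numpy())   # expected 4.0
    print(x.mean().numpy())  # expected 2.5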
@@ -267,6 +267,7 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
         1.5
     """
+    norm = norm.upper()
     assert norm in ["L1", "L2"], "norm must be L1 or L2"
     # Converts binary labels to -1/1 labels.
     loss = relu(1.0 - pred * label)
@@ -604,9 +604,9 @@ def argsort(inp: Tensor, descending: bool = False) -> Tensor:
     """
     assert len(inp.shape) <= 2, "Input should be 1d or 2d"
     if descending:
-        order = "DESCENDING"
+        order = "descending"
     else:
-        order = "ASCENDING"
+        order = "ascending"
     op = builtin.Argsort(order=order)
     if len(inp.shape) == 1:
@@ -646,9 +646,9 @@ def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
     """
     assert len(inp.shape) <= 2, "Input should be 1d or 2d"
     if descending:
-        order = "DESCENDING"
+        order = "descending"
     else:
-        order = "ASCENDING"
+        order = "ascending"
     op = builtin.Argsort(order=order)
     if len(inp.shape) == 1:
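For reference, a minimal usage sketch of the lowercase "ascending"/"descending" handling above (illustrative values; the expected outputs are inferred, not taken from the diff):

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.array([1.0, 3.0, 2.0], dtype=np.float32))
    values, indices = F.sort(x)          # ascending: values [1. 2. 3.], indices [0 2 1]
    idx = F.argsort(x, descending=True)  # descending: [1 2 0]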
@@ -699,11 +699,11 @@ def topk(
         inp = -inp
     if kth_only:
-        mode = "KTH_ONLY"
+        mode = "kth_only"
     elif no_sort:
-        mode = "VALUE_IDX_NOSORT"
+        mode = "value_idx_nosort"
     else:
-        mode = "VALUE_IDX_SORTED"
+        mode = "value_idx_sorted"
     op = builtin.TopK(mode=mode)
     if not isinstance(k, Tensor):
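A hedged usage sketch of the topk modes above, assuming the usual F.topk(inp, k, descending=False, kth_only=False, no_sort=False) signature returning (values, indices); by default the smallest k elements are returned, which is why the code negates the input for descending order:

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.array([2.0, 4.0, 6.0, 8.0, 7.0, 5.0], dtype=np.float32))
    values, indices = F.topk(x, 3)                   # smallest 3: expected values [2. 4. 5.]
    values, indices = F.topk(x, 3, descending=True)  # largest 3: expected values [8. 7. 6.]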
@@ -765,8 +765,8 @@ def matmul(
     inp2: Tensor,
     transpose_a=False,
     transpose_b=False,
-    compute_mode="DEFAULT",
-    format="DEFAULT",
+    compute_mode="default",
+    format="default",
 ) -> Tensor:
     """
     Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.
@@ -776,7 +776,9 @@ def matmul(
     - Both 1-D tensor, simply forward to ``dot``.
     - Both 2-D tensor, normal matrix multiplication.
     - If one input tensor is 1-D, matrix vector multiplication.
-    - If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2, the batched matrix-matrix is returned, and the tensor with smaller dimension will be broadcasted. For example:
+    - If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2,
+      the batched matrix-matrix is returned, and the tensor with smaller dimension will be broadcasted.
+      For example:
         - inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
         - inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
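For the plain 2-D case described above, matmul reduces to an ordinary matrix product; a minimal sketch reusing the example values that appear later in this diff:

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
    data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
    out = F.matmul(data1, data2)
    print(out.numpy())  # expected: [[10. 13.] [28. 40.]]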
@@ -52,6 +52,8 @@ __all__ = [
     "deformable_psroi_pooling",
     "dropout",
     "embedding",
+    "hsigmoid",
+    "hswish",
     "indexing_one_hot",
     "leaky_relu",
     "linear",
@@ -62,17 +64,14 @@ __all__ = [
     "max_pool2d",
     "one_hot",
     "prelu",
-    "softmax",
-    "softplus",
-    "sync_batch_norm",
-    "conv1d",
-    "sigmoid",
-    "hsigmoid",
     "relu",
     "relu6",
-    "hswish",
-    "resize",
     "remap",
+    "resize",
+    "sigmoid",
+    "softmax",
+    "softplus",
+    "sync_batch_norm",
     "warp_affine",
     "warp_perspective",
 ]
@@ -106,6 +105,83 @@ def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor
     return ret
+def conv1d(
+    inp: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    stride: int = 1,
+    padding: int = 0,
+    dilation: int = 1,
+    groups: int = 1,
+    conv_mode="cross_correlation",
+    compute_mode="default",
+) -> Tensor:
+    """1D convolution operation.
+    Refer to :class:`~.Conv1d` for more information.
+    :param inp: The feature map of the convolution operation
+    :param weight: The convolution kernel
+    :param bias: The bias added to the result of convolution (if given)
+    :param stride: Stride of the 1D convolution operation. Default: 1
+    :param padding: Size of the paddings added to the input on both sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: Dilation of the 1D convolution operation. Default: 1
+    :param groups: number of groups to divide input and output channels into,
+        so as to perform a "grouped convolution". When ``groups`` is not 1,
+        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
+        and the shape of weight should be ``(groups, out_channel // groups,
+        in_channels // groups, height, width)``.
+    :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
+    :param conv_mode: Supports 'cross_correlation'. Default:
+        'cross_correlation'.
+    :type compute_mode: string or
+        :class:`mgb.opr_param_defs.Convolution.ComputeMode`
+    :param compute_mode: When set to 'default', no special requirements will be
+        placed on the precision of intermediate results. When set to 'float32',
+        float32 would be used for accumulator and intermediate result, but only
+        effective when input and output are of float16 dtype.
+    """
+    assert (
+        conv_mode.lower() == "cross_correlation"
+        or conv_mode.name == "CROSS_CORRELATION"
+    )
+    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
+    assert inp.ndim == 3, "the input dimension of conv1d should be 3"
+    assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
+    inp = expand_dims(inp, 3)
+    weight = expand_dims(weight, 3)
+    if bias is not None:
+        assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
+        bias = expand_dims(bias, 3)
+    stride_h = stride
+    pad_h = padding
+    dilate_h = dilation
+    sparse_type = "dense" if groups == 1 else "group"
+    op = builtin.Convolution(
+        stride_h=stride_h,
+        stride_w=1,
+        pad_h=pad_h,
+        pad_w=0,
+        dilate_h=dilate_h,
+        dilate_w=1,
+        strategy=get_execution_strategy(),
+        mode=conv_mode,
+        compute_mode=compute_mode,
+        sparse=sparse_type,
+    )
+    inp, weight = utils.convert_inputs(inp, weight)
+    (output,) = apply(op, inp, weight)
+    if bias is not None:
+        output += bias
+    output = squeeze(output, 3)
+    return output
 def conv2d(
     inp: Tensor,
     weight: Tensor,
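A hedged usage sketch of the conv1d added above (shapes are illustrative; the (batch, in_channels, length) layout is an assumption consistent with the 3-D input/weight asserts and the Conv1d module it refers to):

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    inp = tensor(np.random.randn(2, 3, 10).astype(np.float32))    # (N, C_in, L)
    weight = tensor(np.random.randn(4, 3, 3).astype(np.float32))  # (C_out, C_in, kernel_size)
    out = F.conv1d(inp, weight, stride=1, padding=0)
    print(out.shape)  # expected (2, 4, 8): L_out = (10 + 2*0 - (3 - 1) - 1) // 1 + 1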
@@ -114,8 +190,8 @@ def conv2d(
     padding: Union[int, Tuple[int, int]] = 0,
     dilation: Union[int, Tuple[int, int]] = 1,
     groups: int = 1,
-    conv_mode="CROSS_CORRELATION",
-    compute_mode="DEFAULT",
+    conv_mode="cross_correlation",
+    compute_mode="default",
 ) -> Tensor:
     """
     2D convolution operation.
@@ -135,24 +211,27 @@ def conv2d(
         and the shape of weight should be `(groups, out_channel // groups,
         in_channels // groups, height, width)`.
     :type conv_mode: string or :class:`Convolution.Mode`
-    :param conv_mode: supports "CROSS_CORRELATION". Default:
-        "CROSS_CORRELATION"
+    :param conv_mode: supports "cross_correlation". Default:
+        "cross_correlation"
     :type compute_mode: string or
         :class:`Convolution.ComputeMode`
-    :param compute_mode: when set to "DEFAULT", no special requirements will be
-        placed on the precision of intermediate results. When set to "FLOAT32",
-        "Float32" would be used for accumulator and intermediate result, but only
-        effective when input and output are of Float16 dtype.
+    :param compute_mode: when set to "default", no special requirements will be
+        placed on the precision of intermediate results. When set to "float32",
+        "float32" would be used for accumulator and intermediate result, but only
+        effective when input and output are of float16 dtype.
     :return: output tensor.
     """
-    assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
-    assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"
+    assert (
+        conv_mode.lower() == "cross_correlation"
+        or conv_mode.name == "CROSS_CORRELATION"
+    )
+    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
     dilate_h, dilate_w = expand_hw(dilation)
-    sparse_type = "DENSE" if groups == 1 else "GROUP"
+    sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution(
         stride_h=stride_h,
         stride_w=stride_w,
@@ -180,7 +259,7 @@ def conv3d(
     padding: Union[int, Tuple[int, int, int]] = 0,
     dilation: Union[int, Tuple[int, int, int]] = 1,
     groups: int = 1,
-    conv_mode: str = "CROSS_CORRELATION",
+    conv_mode: str = "cross_correlation",
 ) -> Tensor:
     """
     3D convolution operation.
@@ -194,15 +273,16 @@ def conv3d(
     :param padding: size of the paddings added to the input on both sides of its
         spatial dimensions. Only zero-padding is supported. Default: 0
     :param dilation: dilation of the 3D convolution operation. Default: 1
-    :param groups: number of groups into which the input and output channels are divided, so as to perform a ``grouped convolution``. When ``groups`` is not 1,
+    :param groups: number of groups into which the input and output channels are divided,
+        so as to perform a ``grouped convolution``. When ``groups`` is not 1,
         ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
         and the shape of weight should be `(groups, out_channel // groups,
         in_channels // groups, t, height, width)`.
-    :param conv_mode: supports "CROSS_CORRELATION". Default:
-        "CROSS_CORRELATION"
+    :param conv_mode: supports "cross_correlation". Default:
+        "cross_correlation"
     :return: output tensor.
     """
-    assert conv_mode == "CROSS_CORRELATION"
+    assert conv_mode.lower() == "cross_correlation"
     D, H, W = 0, 1, 2
@@ -210,7 +290,7 @@ def conv3d(
     stride = _triple_nonzero(stride)
     dilate = _triple_nonzero(dilation)
-    sparse_type = "DENSE" if groups == 1 else "GROUP"
+    sparse_type = "dense" if groups == 1 else "group"
     op = builtin.Convolution3D(
         pad_d=pad[D],
         pad_h=pad[H],
@@ -240,8 +320,8 @@ def conv_transpose2d(
     padding: Union[int, Tuple[int, int]] = 0,
     dilation: Union[int, Tuple[int, int]] = 1,
     groups: int = 1,
-    conv_mode="CROSS_CORRELATION",
-    compute_mode="DEFAULT",
+    conv_mode="cross_correlation",
+    compute_mode="default",
 ) -> Tensor:
     """
     2D transposed convolution operation.
@@ -261,18 +341,21 @@ def conv_transpose2d(
         and the shape of weight should be `(groups, out_channel // groups,
         in_channels // groups, height, width)`. Default: 1
     :type conv_mode: string or :class:`Convolution.Mode`
-    :param conv_mode: supports "CROSS_CORRELATION". Default:
-        "CROSS_CORRELATION"
+    :param conv_mode: supports "cross_correlation". Default:
+        "cross_correlation"
     :type compute_mode: string or
         :class:`Convolution.ComputeMode`
-    :param compute_mode: when set to "DEFAULT", no special requirements will be
-        placed on the precision of intermediate results. When set to "FLOAT32",
-        "Float32" would be used for accumulator and intermediate result, but only
-        effective when input and output are of Float16 dtype.
+    :param compute_mode: when set to "default", no special requirements will be
+        placed on the precision of intermediate results. When set to "float32",
+        "float32" would be used for accumulator and intermediate result, but only
+        effective when input and output are of float16 dtype.
     :return: output tensor.
     """
-    assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
-    assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"
+    assert (
+        conv_mode.lower() == "cross_correlation"
+        or conv_mode.name == "CROSS_CORRELATION"
+    )
+    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
     if groups != 1:
         raise NotImplementedError("group transposed conv2d is not supported yet.")
@@ -307,8 +390,8 @@ def deformable_conv2d(
     padding: Union[int, Tuple[int, int]] = 0,
     dilation: Union[int, Tuple[int, int]] = 1,
     groups: int = 1,
-    conv_mode="CROSS_CORRELATION",
-    compute_mode="DEFAULT",
+    conv_mode="cross_correlation",
+    compute_mode="default",
 ) -> Tensor:
     """
     Deformable Convolution.
@@ -328,24 +411,27 @@ def deformable_conv2d(
         and the shape of weight should be `(groups, out_channel // groups,
         in_channels // groups, height, width)`. Default: 1
     :type conv_mode: string or :class:`Convolution.Mode`
-    :param conv_mode: supports "CROSS_CORRELATION". Default:
-        "CROSS_CORRELATION"
+    :param conv_mode: supports "cross_correlation". Default:
+        "cross_correlation"
     :type compute_mode: string or
         :class:`Convolution.ComputeMode`
-    :param compute_mode: when set to "DEFAULT", no special requirements will be
-        placed on the precision of intermediate results. When set to "FLOAT32",
-        "Float32" would be used for accumulator and intermediate result, but only
-        effective when input and output are of Float16 dtype.
+    :param compute_mode: when set to "default", no special requirements will be
+        placed on the precision of intermediate results. When set to "float32",
+        "float32" would be used for accumulator and intermediate result, but only
+        effective when input and output are of float16 dtype.
     :return: output tensor.
     """
-    assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
-    assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"
+    assert (
+        conv_mode.lower() == "cross_correlation"
+        or conv_mode.name == "CROSS_CORRELATION"
+    )
+    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
     dilate_h, dilate_w = expand_hw(dilation)
-    sparse_type = "DENSE" if groups == 1 else "GROUP"
+    sparse_type = "dense" if groups == 1 else "group"
     op = builtin.DeformableConv(
         stride_h=stride_h,
         stride_w=stride_w,
@@ -372,10 +458,13 @@ def local_conv2d(
     stride: Union[int, Tuple[int, int]] = 1,
     padding: Union[int, Tuple[int, int]] = 0,
     dilation: Union[int, Tuple[int, int]] = 1,
-    conv_mode="CROSS_CORRELATION",
+    conv_mode="cross_correlation",
 ):
     """Applies spatial 2D convolution over an groupped channeled image with untied kernels."""
-    assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
+    assert (
+        conv_mode.lower() == "cross_correlation"
+        or conv_mode.name == "CROSS_CORRELATION"
+    )
     stride_h, stride_w = expand_hw(stride)
     pad_h, pad_w = expand_hw(padding)
@@ -389,8 +478,8 @@ def local_conv2d(
         dilate_h=dilate_h,
         dilate_w=dilate_w,
         mode=conv_mode,
-        compute_mode="DEFAULT",
-        sparse="DENSE",
+        compute_mode="default",
+        sparse="dense",
     )
     inp, weight = utils.convert_inputs(inp, weight)
     (output,) = apply(op, inp, weight)
@@ -430,7 +519,7 @@ def max_pool2d(
         stride_w=stride_w,
         pad_h=padding_h,
         pad_w=padding_w,
-        mode="MAX",
+        mode="max",
     )
     (output,) = apply(op, inp)
     return output
@@ -441,7 +530,7 @@ def avg_pool2d(
     kernel_size: Union[int, Tuple[int, int]],
     stride: Optional[Union[int, Tuple[int, int]]] = None,
     padding: Union[int, Tuple[int, int]] = 0,
-    mode: str = "AVERAGE_COUNT_EXCLUDE_PADDING",
+    mode: str = "average_count_exclude_padding",
 ) -> Tensor:
     """
     Applies 2D average pooling over an input tensor.
@@ -453,7 +542,8 @@ def avg_pool2d(
     :param stride: stride of the window. If not provided, its value is set to ``kernel_size``.
         Default: None
     :param padding: implicit zero padding added on both sides. Default: 0
-    :param mode: whether to count padding values. Default: "AVERAGE_COUNT_EXCLUDE_PADDING"
+    :param mode: whether to count padding values, set to "average" will do counting.
+        Default: "average_count_exclude_padding"
     :return: output tensor.
     """
     if stride is None:
@@ -490,7 +580,7 @@ def adaptive_max_pool2d(
     if isinstance(oshp, int):
         oshp = (oshp, oshp)
-    op = builtin.AdaptivePooling(mode="MAX", format="NCHW",)
+    op = builtin.AdaptivePooling(mode="max", format="NCHW",)
     oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
     (output,) = apply(op, inp, oshp)
     return output
@@ -511,7 +601,7 @@ def adaptive_avg_pool2d(
     if isinstance(oshp, int):
         oshp = (oshp, oshp)
-    op = builtin.AdaptivePooling(mode="AVERAGE", format="NCHW",)
+    op = builtin.AdaptivePooling(mode="average", format="NCHW",)
     oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
     (output,) = apply(op, inp, oshp)
     return output
@@ -556,6 +646,53 @@ def deformable_psroi_pooling(
     return output
+def hswish(x):
+    """
+    Element-wise `x * relu6(x + 3) / 6`.
+    :param x: input tensor.
+    :return: computed tensor.
+    Example:
+    .. testcode::
+        import numpy as np
+        from megengine import tensor
+        import megengine.functional as F
+        x = tensor(np.arange(5).astype(np.float32))
+        out = F.hswish(x)
+        print(out.numpy().round(decimals=4))
+    .. testoutput::
+        [0. 0.6667 1.6667 3. 4. ]
+    """
+    return _elwise(x, mode=Elemwise.Mode.H_SWISH)
+def sigmoid(x):
+    """Element-wise `1 / ( 1 + exp( -x ) )`."""
+    return _elwise(x, mode=Elemwise.Mode.SIGMOID)
+def hsigmoid(x):
+    """Element-wise `relu6(x + 3) / 6`."""
+    return relu6(x + 3) / 6
+def relu(x):
+    """Element-wise `max(x, 0)`."""
+    return _elwise(x, mode=Elemwise.Mode.RELU)
+def relu6(x):
+    """Element-wise `min(max(x, 0), 6)`."""
+    return minimum(maximum(x, 0), 6)
 def prelu(inp: Tensor, weight: Tensor) -> Tensor:
     r"""
     Applies the element-wise PReLU function.
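A quick numeric check of the element-wise definitions added above (illustrative; the expected values follow from the formulas, rounded to four decimals):

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.arange(5).astype(np.float32))  # [0, 1, 2, 3, 4]
    print(F.hswish(x).numpy().round(4))    # x * relu6(x + 3) / 6 -> [0. 0.6667 1.6667 3. 4.]
    print(F.hsigmoid(x).numpy().round(4))  # relu6(x + 3) / 6     -> [0.5 0.6667 0.8333 1. 1.]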
@@ -872,14 +1009,14 @@ def batch_norm(
     if not training:
         op = builtin.BatchNorm(
-            fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="DIM_1C11"
+            fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="dim_1c11"
         )
         ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
         return ret
     else:
         op = builtin.BatchNorm(
-            avg_factor=1 - momentum, epsilon=eps, param_dim="DIM_1C11"
+            avg_factor=1 - momentum, epsilon=eps, param_dim="dim_1c11"
         )
         if has_mean or has_var:
             running_mean = make_full_if_none(running_mean, 0)
@@ -915,7 +1052,7 @@ def sync_batch_norm(
     training: bool = False,
     momentum: Union[float, Tensor] = 0.9,
     eps: float = 1e-5,
-    eps_mode="ADDITIVE",
+    eps_mode="additive",
     group=WORLD,
 ) -> Tensor:
     r"""
@@ -939,7 +1076,9 @@ def sync_batch_norm(
         Default: 1e-5
     :return: output tensor.
     """
-    assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode)
+    assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
+        eps_mode
+    )
     _channels = inp.shape[1]
     _ndim = inp.ndim
     _device = inp.device
@@ -979,7 +1118,7 @@ def sync_batch_norm(
     channel_mean = running_mean.reshape(*_param_shape)
     invsqrt_channel_variance = (
-        maximum(channel_variance, eps) if eps_mode == "MAX" else channel_variance + eps
+        maximum(channel_variance, eps) if eps_mode == "max" else channel_variance + eps
     ) ** -0.5
     if weight is not None:
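The two eps_mode branches above differ only in how eps enters the inverse square root; a standalone NumPy sketch of the same arithmetic (illustrative values):

    import numpy as np

    var, eps = np.float32(1e-3), np.float32(1e-5)
    additive = (var + eps) ** -0.5          # eps_mode="additive" (the default)
    clamped = np.maximum(var, eps) ** -0.5  # eps_mode="max"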
@@ -1019,13 +1158,16 @@ def sync_batch_norm(
     return outvar
-def one_hot(inp: Tensor, num_classes: int) -> Tensor:
-    r"""
-    Performs one-hot encoding for the input tensor.
+def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
+    """
+    Returns a new tensor where each of the elements are randomly set to zero
+    with probability P = ``drop_prob``. Optionally rescale the output tensor if ``training`` is True.
     :param inp: input tensor.
-    :param num_classes: number of classes denotes the last dimension of the output tensor.
-    :return: output tensor.
+    :param drop_prob: probability to drop (set to zero) a single element.
+    :param training: the default behavior of ``dropout`` during training is to rescale the output,
+        then it can be replaced by an :class:`~.Identity` during inference. Default: True
+    :return: the output tensor
     Examples:
@@ -1035,51 +1177,33 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor:
         from megengine import tensor
         import megengine.functional as F
-        x = tensor(np.arange(1, 4, dtype=np.int32))
-        out = F.one_hot(x, num_classes=4)
+        x = tensor(np.ones(10, dtype=np.float32))
+        out = F.dropout(x, 1./3.)
         print(out.numpy())
     Outputs:
     .. testoutput::
+        :options: +SKIP
-        [[0 1 0 0]
-         [0 0 1 0]
-         [0 0 0 1]]
-    """
-    zeros_tensor = zeros(list(inp.shape) + [num_classes], inp.dtype, inp.device)
-    ones_tensor = ones(list(inp.shape) + [1], inp.dtype, inp.device)
-    op = builtin.IndexingSetOneHot(axis=inp.ndim)
-    (result,) = apply(op, zeros_tensor, inp, ones_tensor)
-    return result
+        [1.5 1.5 0. 1.5 1.5 1.5 1.5 1.5 1.5 1.5]
-def matmul(
-    inp1: Tensor,
-    inp2: Tensor,
-    transpose_a=False,
-    transpose_b=False,
-    compute_mode="DEFAULT",
-    format="DEFAULT",
-) -> Tensor:
     """
-    Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.
+    assert 0 <= drop_prob < 1
+    rv = uniform(size=inp.shape)
+    mask = rv > drop_prob
+    inp *= mask.astype(inp.dtype)
+    if training:
+        inp *= 1 / (1 - drop_prob)
+    return inp
-    With different inputs dim, this function behaves differently:
-    - Both 1-D tensor, simply forward to ``dot``.
-    - Both 2-D tensor, normal matrix multiplication.
-    - If one input tensor is 1-D, matrix vector multiplication.
-    - If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2, the batched matrix-matrix is returned, and the tensor with smaller dimension will
-      be broadcasted. For example:
-        - inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
-        - inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
-        - inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`
+def one_hot(inp: Tensor, num_classes: int) -> Tensor:
+    r"""
+    Performs one-hot encoding for the input tensor.
-    :param inp1: first matrix to be multiplied.
-    :param inp2: second matrix to be multiplied.
+    :param inp: input tensor.
+    :param num_classes: number of classes denotes the last dimension of the output tensor.
     :return: output tensor.
     Examples:
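The 1.5 entries in the dropout doctest above follow from the rescaling in the body: surviving elements are multiplied by 1 / (1 - drop_prob), so with drop_prob = 1/3 each kept 1.0 becomes 1.5. A plain NumPy sketch of the same logic (illustrative only):

    import numpy as np

    drop_prob = 1.0 / 3.0
    x = np.ones(10, dtype=np.float32)
    mask = np.random.uniform(size=x.shape) > drop_prob     # mirrors rv > drop_prob above
    out = x * mask.astype(np.float32) / (1.0 - drop_prob)  # kept entries become 1.5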
@@ -1090,182 +1214,27 @@ def matmul(
         from megengine import tensor
         import megengine.functional as F
-        data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
-        data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
-        out = F.matmul(data1, data2)
+        x = tensor(np.arange(1, 4, dtype=np.int32))
+        out = F.one_hot(x, num_classes=4)
         print(out.numpy())
     Outputs:
     .. testoutput::
-        [[10. 13.]
-         [28. 40.]]
-    """
-    remove_row, remove_col = False, False
-    inp1, inp2 = utils.convert_inputs(inp1, inp2)
-    dim1, dim2 = inp1.ndim, inp2.ndim
-    # handle dim=1 cases, dot and matrix-vector multiplication
-    if dim1 == 1 and dim2 == 1:
-        return dot(inp1, inp2)
-    # the underlying matmul op requires input dims to be at least 2
-    if dim1 == 1:
-        inp1 = expand_dims(inp1, 0)
-        dim1 = 2
-        remove_row = True
-    if dim2 == 1:
-        inp2 = expand_dims(inp2, 1)
-        dim2 = 2
-        remove_col = True
-    batch_shape = None
-    shape1 = inp1.shape
-    shape2 = inp2.shape
-    maxdim = dim1 if dim1 > dim2 else dim2
-    if dim1 >= 3 or dim2 >= 3:
-        if use_symbolic_shape():
-            if dim1 > dim2:
-                shape2 = concat([shape1[:-2], shape2[-2:]])
-                inp2 = broadcast_to(inp2, shape2)
-            if dim1 < dim2:
-                shape1 = concat([shape2[:-2], shape1[-2:]])
-                inp1 = broadcast_to(inp1, shape1)
-            if maxdim > 3:
-                batch_shape = shape1[:-2]
-                # compress inputs to 3d
-                (inp1,) = apply(
-                    builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]])
-                )
-                (inp2,) = apply(
-                    builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]])
-                )
-        else:
-            if dim1 > dim2:
-                shape2 = shape1[:-2] + shape2[-2:]
-                inp2 = broadcast_to(inp2, shape2)
-            if dim1 < dim2:
-                shape1 = shape2[:-2] + shape1[-2:]
-                inp1 = broadcast_to(inp1, shape1)
-            if maxdim > 3:
-                batch_shape = shape1[:-2]
-                # compress inputs to 3d
-                inp1 = inp1.reshape((-1, shape1[-2], shape1[-1]))
-                inp2 = inp2.reshape((-1, shape2[-2], shape2[-1]))
-        op = builtin.BatchedMatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-            strategy=get_execution_strategy(),
-        )
-    else:
-        op = builtin.MatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-            strategy=get_execution_strategy(),
-        )
-    (result,) = apply(op, inp1, inp2)
-    if maxdim > 3:
-        if use_symbolic_shape():
-            (result,) = apply(
-                builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]])
-            )
-        else:
-            result = result.reshape(batch_shape + result.shape[-2:])
-    if remove_row:
-        result = squeeze(result, axis=-2)
-    if remove_col:
-        result = squeeze(result, axis=-1)
-    return result
+        [[0 1 0 0]
+         [0 0 1 0]
+         [0 0 0 1]]
-def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
     """
-    Computes dot-product of two vectors ``inp1`` and ``inp2``.
-    inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
-    Refer to :func:`~.matmul` for more general usage.
-    :param inp1: first vector.
-    :param inp2: second vector.
-    :return: output value.
-    Examples:
-    .. testcode::
-        import numpy as np
-        from megengine import tensor
-        import megengine.functional as F
-        data1 = tensor(np.arange(0, 6, dtype=np.float32))
-        data2 = tensor(np.arange(0, 6, dtype=np.float32))
-        out = F.dot(data1, data2)
-        print(out.numpy())
-    Outputs:
-    .. testoutput::
-        55.
+    zeros_tensor = zeros(list(inp.shape) + [num_classes], inp.dtype, inp.device)
+    ones_tensor = ones(list(inp.shape) + [1], inp.dtype, inp.device)
-    """
-    op = builtin.Dot()
-    inp1, inp2 = utils.convert_inputs(inp1, inp2)
-    assert (
-        inp1.ndim <= 1 and inp2.ndim <= 1
-    ), "Input tensors for dot must be 1-dimensional or scalar"
-    (result,) = apply(op, inp1, inp2)
-    setscalar(result)
+    op = builtin.IndexingSetOneHot(axis=inp.ndim)
+    (result,) = apply(op, zeros_tensor, inp, ones_tensor)
     return result
-def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
-    """
-    Returns a new tensor where each of the elements are randomly set to zero
-    with probability P = ``drop_prob``. Optionally rescale the output tensor if ``training`` is True.
-    :param inp: input tensor.
-    :param drop_prob: probability to drop (set to zero) a single element.
-    :param training: the default behavior of ``dropout`` during training is to rescale the output,
-        then it can be replaced by an :class:`~.Identity` during inference. Default: True
-    :return: the output tensor
-    Examples:
-    .. testcode::
-        import numpy as np
-        from megengine import tensor
-        import megengine.functional as F
-        x = tensor(np.ones(10, dtype=np.float32))
-        out = F.dropout(x, 1./3.)
-        print(out.numpy())
-    Outputs:
-    .. testoutput::
-        :options: +SKIP
-        [1.5 1.5 0. 1.5 1.5 1.5 1.5 1.5 1.5 1.5]
-    """
-    assert 0 <= drop_prob < 1
-    rv = uniform(size=inp.shape)
-    mask = rv > drop_prob
-    inp *= mask.astype(inp.dtype)
-    if training:
-        inp *= 1 / (1 - drop_prob)
-    return inp
 def embedding(
     inp: Tensor,
     weight: Tensor,
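For reference, the one_hot semantics shown above, as a minimal sketch reusing the doctest values from this hunk:

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.arange(1, 4, dtype=np.int32))
    out = F.one_hot(x, num_classes=4)
    print(out.numpy())
    # expected:
    # [[0 1 0 0]
    #  [0 0 1 0]
    #  [0 0 0 1]]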
@@ -1334,128 +1303,6 @@ def indexing_one_hot(
     return result
-def conv1d(
-    inp: Tensor,
-    weight: Tensor,
-    bias: Optional[Tensor] = None,
-    stride: int = 1,
-    padding: int = 0,
-    dilation: int = 1,
-    groups: int = 1,
-    conv_mode="CROSS_CORRELATION",
-    compute_mode="DEFAULT",
-) -> Tensor:
-    """1D convolution operation.
-    Refer to :class:`~.Conv1d` for more information.
-    :param inp: The feature map of the convolution operation
-    :param weight: The convolution kernel
-    :param bias: The bias added to the result of convolution (if given)
-    :param stride: Stride of the 1D convolution operation. Default: 1
-    :param padding: Size of the paddings added to the input on both sides of its
-        spatial dimensions. Only zero-padding is supported. Default: 0
-    :param dilation: Dilation of the 1D convolution operation. Default: 1
-    :param groups: number of groups to divide input and output channels into,
-        so as to perform a "grouped convolution". When ``groups`` is not 1,
-        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
-        and the shape of weight should be ``(groups, out_channel // groups,
-        in_channels // groups, height, width)``.
-    :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
-    :param conv_mode: Supports 'CROSS_CORRELATION'. Default:
-        'CROSS_CORRELATION'.
-    :type compute_mode: string or
-        :class:`mgb.opr_param_defs.Convolution.ComputeMode`
-    :param compute_mode: When set to 'DEFAULT', no special requirements will be
-        placed on the precision of intermediate results. When set to 'FLOAT32',
-        Float32 would be used for accumulator and intermediate result, but only
-        effective when input and output are of Float16 dtype.
-    """
-    assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
-    assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"
-    assert inp.ndim == 3, "the input dimension of conv1d should be 3"
-    assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
-    inp = expand_dims(inp, 3)
-    weight = expand_dims(weight, 3)
-    if bias is not None:
-        assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
-        bias = expand_dims(bias, 3)
-    stride_h = stride
-    pad_h = padding
-    dilate_h = dilation
-    sparse_type = "DENSE" if groups == 1 else "GROUP"
-    op = builtin.Convolution(
-        stride_h=stride_h,
-        stride_w=1,
-        pad_h=pad_h,
-        pad_w=0,
-        dilate_h=dilate_h,
-        dilate_w=1,
-        strategy=get_execution_strategy(),
-        mode=conv_mode,
-        compute_mode=compute_mode,
-        sparse=sparse_type,
-    )
-    inp, weight = utils.convert_inputs(inp, weight)
-    (output,) = apply(op, inp, weight)
-    if bias is not None:
-        output += bias
-    output = squeeze(output, 3)
-    return output
-def hswish(x):
-    """
-    Element-wise `x * relu6(x + 3) / 6`.
-    :param x: input tensor.
-    :return: computed tensor.
-    Example:
-    .. testcode::
-        import numpy as np
-        from megengine import tensor
-        import megengine.functional as F
-        x = tensor(np.arange(5).astype(np.float32))
-        out = F.hswish(x)
-        print(out.numpy().round(decimals=4))
-    .. testoutput::
-        [0. 0.6667 1.6667 3. 4. ]
-    """
-    return _elwise(x, mode=Elemwise.Mode.H_SWISH)
-def sigmoid(x):
-    """Element-wise `1 / ( 1 + exp( -x ) )`."""
-    return _elwise(x, mode=Elemwise.Mode.SIGMOID)
-def hsigmoid(x):
-    """Element-wise `relu6(x + 3) / 6`."""
-    return relu6(x + 3) / 6
-def relu(x):
-    """Element-wise `max(x, 0)`."""
-    return _elwise(x, mode=Elemwise.Mode.RELU)
-def relu6(x):
-    """Element-wise `min(max(x, 0), 6)`."""
-    return minimum(maximum(x, 0), 6)
 interpolate = deprecated_func("1.3", "megengine.functional.vision", "interpolate", True)
 roi_pooling = deprecated_func("1.3", "megengine.functional.vision", "roi_pooling", True)
 roi_align = deprecated_func("1.3", "megengine.functional.vision", "roi_align", True)
@@ -24,9 +24,9 @@ def conv_bias_activation(
     padding: Union[int, Tuple[int, int]] = 0,
     dilation: Union[int, Tuple[int, int]] = 1,
     groups: int = 1,
-    nonlinear_mode="IDENTITY",
-    conv_mode="CROSS_CORRELATION",
-    compute_mode="DEFAULT",
+    nonlinear_mode="identity",
+    conv_mode="cross_correlation",
+    compute_mode="default",
 ) -> Tensor:
     """
     Convolution bias with activation operation, only for inference.
@@ -35,27 +35,30 @@ def conv_bias_activation(
     :param weight: convolution kernel.
     :param bias: bias added to the result of convolution
     :param stride: stride of the 2D convolution operation. Default: 1
-    :param padding: size of the paddings added to the input on both sides of its spatial dimensions. Only zero-padding is supported. Default: 0
+    :param padding: size of the paddings added to the input on both sides
+        of its spatial dimensions. Only zero-padding is supported. Default: 0
     :param dilation: dilation of the 2D convolution operation. Default: 1
-    :param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
+    :param groups: number of groups into which the input and output channels are divided,
+        so as to perform a "grouped convolution". When ``groups`` is not 1,
         ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
         and the shape of weight should be `(groups, out_channel // groups,
         in_channels // groups, height, width)`.
     :type conv_mode: string or :class:`Convolution.Mode`.
-    :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
-        'CROSS_CORRELATION'
+    :param conv_mode: supports 'cross_correlation' or 'convolution'. Default:
+        'cross_correlation'
     :param dtype: support for ``np.dtype``, Default: np.int8
     :type compute_mode: string or
         :class:`Convolution.ComputeMode`.
-    :param compute_mode: when set to "DEFAULT", no special requirements will be
-        placed on the precision of intermediate results. When set to "FLOAT32",
-        "Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype.
+    :param compute_mode: when set to "default", no special requirements will be
+        placed on the precision of intermediate results. When set to "float32",
+        "float32" would be used for accumulator and intermediate result,
+        but only effective when input and output are of float16 dtype.
     """
     ph, pw = _pair(padding)
     sh, sw = _pair_nonzero(stride)
     dh, dw = _pair_nonzero(dilation)
-    sparse_type = "DENSE" if groups == 1 else "GROUP"
+    sparse_type = "dense" if groups == 1 else "group"
     op = builtin.ConvBias(
         stride_h=sh,
         stride_w=sw,
@@ -84,9 +87,9 @@ def batch_conv_bias_activation(
     padding: Union[int, Tuple[int, int]] = 0,
     dilation: Union[int, Tuple[int, int]] = 1,
     groups: int = 1,
-    nonlinear_mode="IDENTITY",
-    conv_mode="CROSS_CORRELATION",
-    compute_mode="DEFAULT",
+    nonlinear_mode="identity",
+    conv_mode="cross_correlation",
+    compute_mode="default",
 ) -> Tensor:
     """
     Batch convolution bias with activation operation, only for inference.
@@ -95,27 +98,30 @@ def batch_conv_bias_activation(
     :param weight: convolution kernel in batched way.
     :param bias: bias added to the result of convolution
     :param stride: stride of the 2D convolution operation. Default: 1
-    :param padding: size of the paddings added to the input on both sides of its spatial dimensions. Only zero-padding is supported. Default: 0
+    :param padding: size of the paddings added to the input on both sides
+        of its spatial dimensions. Only zero-padding is supported. Default: 0
     :param dilation: dilation of the 2D convolution operation. Default: 1
-    :param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
+    :param groups: number of groups into which the input and output channels are divided,
+        so as to perform a "grouped convolution". When ``groups`` is not 1,
         ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
         and the shape of weight should be `(groups, out_channel // groups,
        in_channels // groups, height, width)`.
     :type conv_mode: string or :class:`Convolution.Mode`.
-    :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
-        'CROSS_CORRELATION'
+    :param conv_mode: supports 'cross_correlation' or 'convolution'. Default:
+        'cross_correlation'
     :param dtype: support for ``np.dtype``, Default: np.int8
     :type compute_mode: string or
         :class:`Convolution.ComputeMode`.
-    :param compute_mode: when set to "DEFAULT", no special requirements will be
-        placed on the precision of intermediate results. When set to "FLOAT32",
-        "Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype.
+    :param compute_mode: when set to "default", no special requirements will be
+        placed on the precision of intermediate results. When set to "float32",
+        "float32" would be used for accumulator and intermediate result,
+        but only effective when input and output are of float16 dtype.
     """
     ph, pw = _pair(padding)
     sh, sw = _pair_nonzero(stride)
     dh, dw = _pair_nonzero(dilation)
-    sparse_type = "DENSE" if groups == 1 else "GROUP"
+    sparse_type = "dense" if groups == 1 else "group"
     op = builtin.BatchConvBias(
         stride_h=sh,
         stride_w=sw,
@@ -335,12 +335,8 @@ def split(inp, nsplits_or_sections, axis=0):
         y = F.split(x, 3)
         z = F.split(x, [6, 17], axis=1)
-        if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"):
-            print([tuple(i.shape.numpy().tolist()) for i in y])
-            print([tuple(i.shape.numpy().tolist()) for i in z])
-        else:
-            print([i.shape for i in y])
-            print([i.shape for i in z])
+        print([i.numpy().shape for i in y])
+        print([i.numpy().shape for i in z])
     Outputs:
@@ -46,6 +46,7 @@ def cvt_color(inp: Tensor, mode: str = ""):
         [[[[0.86555195]]]]
     """
+    mode = mode.upper()
     assert mode in builtin.CvtColor.Mode.__dict__, "unspport mode for cvt_color"
     mode = getattr(builtin.CvtColor.Mode, mode)
     assert isinstance(mode, builtin.CvtColor.Mode)
@@ -92,9 +93,8 @@ def roi_pooling(
         [[[-0.1383 -0.1383]
           [-0.5035 -0.5035]]]
     """
-    assert mode in ["max", "average"], "only max/average mode is supported"
+    assert mode.lower() in ["max", "average"], "only max/average mode is supported"
     if isinstance(output_shape, int):
         output_shape = (output_shape, output_shape)
@@ -151,6 +151,7 @@ def roi_align(
          [0.1359 0.1359]]]
     """
+    mode = mode.lower()
     assert mode in ["max", "average"], "only max/average mode is supported"
     if isinstance(output_shape, int):
         output_shape = (output_shape, output_shape)
| @@ -244,9 +245,9 @@ def nms( | |||||
| def remap( | def remap( | ||||
| inp: Tensor, | inp: Tensor, | ||||
| map_xy: Tensor, | map_xy: Tensor, | ||||
| border_mode: str = "REPLICATE", | |||||
| border_mode: str = "replicate", | |||||
| scalar: float = 0.0, | scalar: float = 0.0, | ||||
| interp_mode: str = "LINEAR", | |||||
| interp_mode: str = "linear", | |||||
| ) -> Tensor: | ) -> Tensor: | ||||
| r""" | r""" | ||||
| Applies remap transformation to batched 2D images. | Applies remap transformation to batched 2D images. | ||||
| @@ -257,11 +258,11 @@ def remap( | |||||
| :param inp: input image | :param inp: input image | ||||
| :param map_xy: (batch, oh, ow, 2) transformation matrix | :param map_xy: (batch, oh, ow, 2) transformation matrix | ||||
| :param border_mode: pixel extrapolation method. | :param border_mode: pixel extrapolation method. | ||||
| Default: "REPLICATE". Currently also support "CONSTANT", "REFLECT", | |||||
| "REFLECT_101", "WRAP". | |||||
| Default: "replicate". Currently also support "constant", "reflect", | |||||
| "reflect_101", "wrap". | |||||
| :param scalar: value used in case of a constant border. Default: 0 | :param scalar: value used in case of a constant border. Default: 0 | ||||
| :param interp_mode: interpolation methods. | :param interp_mode: interpolation methods. | ||||
| Default: "LINEAR". Currently only support "LINEAR" mode. | |||||
| Default: "linear". Currently only support "linear" mode. | |||||
| :return: output tensor. | :return: output tensor. | ||||
| Examples: | Examples: | ||||
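A minimal sketch of `remap` with the lowercase mode names, assuming `map_xy` holds OpenCV-style absolute (x, y) sampling coordinates for each output pixel:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    inp = tensor(np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4))
    map_xy = tensor(np.array([[[[0.0, 0.0], [3.0, 0.0]],
                               [[0.0, 3.0], [3.0, 3.0]]]], dtype=np.float32))  # (batch, oh, ow, 2)
    out = F.vision.remap(inp, map_xy, border_mode="replicate", interp_mode="linear")
    print(out.numpy().reshape(2, 2))  # samples the four corner pixels of the input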
| @@ -301,10 +302,10 @@ def warp_affine( | |||||
| inp: Tensor, | inp: Tensor, | ||||
| weight: Tensor, | weight: Tensor, | ||||
| out_shape, | out_shape, | ||||
| border_mode="REPLICATE", | |||||
| border_mode="replicate", | |||||
| border_val=0, | border_val=0, | ||||
| format="NHWC", | format="NHWC", | ||||
| imode="LINEAR", | |||||
| imode="linear", | |||||
| ): | ): | ||||
| """ | """ | ||||
| Batched affine transform on 2D images. | Batched affine transform on 2D images. | ||||
| @@ -313,13 +314,13 @@ def warp_affine( | |||||
| :param weight: weight tensor. | :param weight: weight tensor. | ||||
| :param out_shape: output tensor shape. | :param out_shape: output tensor shape. | ||||
| :param border_mode: pixel extrapolation method. | :param border_mode: pixel extrapolation method. | ||||
| Default: "WRAP". Currently "CONSTANT", "REFLECT", | |||||
| "REFLECT_101", "ISOLATED", "WRAP", "REPLICATE", "TRANSPARENT" are supported. | |||||
| Default: "wrap". Currently "constant", "reflect", | |||||
| "reflect_101", "isolated", "wrap", "replicate", "transparent" are supported. | |||||
| :param border_val: value used in case of a constant border. Default: 0 | :param border_val: value used in case of a constant border. Default: 0 | ||||
| :param format: "NHWC" as default based on historical concerns, | :param format: "NHWC" as default based on historical concerns, | ||||
| "NCHW" is also supported. Default: "NCHW". | |||||
| :param imode: interpolation methods. Could be "LINEAR", "NEAREST", "CUBIC", "AREA". | |||||
| Default: "LINEAR". | |||||
| "NCHW" is also supported. Default: "NHWC". | |||||
| :param imode: interpolation methods. Could be "linear", "nearest", "cubic", "area". | |||||
| Default: "linear". | |||||
| :return: output tensor. | :return: output tensor. | ||||
| .. note:: | .. note:: | ||||
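A minimal sketch of `warp_affine` with the lowercase mode names, using the default `NHWC` layout and an identity 2x3 matrix:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.arange(27, dtype=np.float32).reshape(1, 3, 3, 3))   # NHWC
    weight = tensor(np.array([[[1.0, 0.0, 0.0],
                               [0.0, 1.0, 0.0]]], dtype=np.float32))  # (batch, 2, 3)
    out = F.vision.warp_affine(x, weight, (2, 2), border_mode="wrap", imode="linear")
    print(out.numpy().shape)  # (1, 2, 2, 3)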
| @@ -340,9 +341,9 @@ def warp_perspective( | |||||
| inp: Tensor, | inp: Tensor, | ||||
| M: Tensor, | M: Tensor, | ||||
| dsize: Union[Tuple[int, int], int, Tensor], | dsize: Union[Tuple[int, int], int, Tensor], | ||||
| border_mode: str = "REPLICATE", | |||||
| border_mode: str = "replicate", | |||||
| border_val: float = 0.0, | border_val: float = 0.0, | ||||
| interp_mode: str = "LINEAR", | |||||
| interp_mode: str = "linear", | |||||
| ) -> Tensor: | ) -> Tensor: | ||||
| r""" | r""" | ||||
| Applies perspective transformation to batched 2D images. | Applies perspective transformation to batched 2D images. | ||||
| @@ -359,11 +360,11 @@ def warp_perspective( | |||||
| :param M: `(batch, 3, 3)` transformation matrix. | :param M: `(batch, 3, 3)` transformation matrix. | ||||
| :param dsize: `(h, w)` size of the output image. | :param dsize: `(h, w)` size of the output image. | ||||
| :param border_mode: pixel extrapolation method. | :param border_mode: pixel extrapolation method. | ||||
| Default: "REPLICATE". Currently also support "CONSTANT", "REFLECT", | |||||
| "REFLECT_101", "WRAP". | |||||
| Default: "replicate". Currently also support "constant", "reflect", | |||||
| "reflect_101", "wrap". | |||||
| :param border_val: value used in case of a constant border. Default: 0 | :param border_val: value used in case of a constant border. Default: 0 | ||||
| :param interp_mode: interpolation methods. | :param interp_mode: interpolation methods. | ||||
| Default: "LINEAR". Currently only support "LINEAR" mode. | |||||
| Default: "linear". Currently only support "linear" mode. | |||||
| :return: output tensor. | :return: output tensor. | ||||
| Note: | Note: | ||||
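A minimal sketch of `warp_perspective` with the lowercase mode names; `M` is a batch of 3x3 homographies (here a pure translation):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    inp = tensor(np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4))
    M = tensor(np.array([[[1.0, 0.0, 1.0],
                          [0.0, 1.0, 1.0],
                          [0.0, 0.0, 1.0]]], dtype=np.float32))  # (batch, 3, 3)
    out = F.vision.warp_perspective(inp, M, (2, 2),
                                    border_mode="replicate", interp_mode="linear")
    print(out.numpy().shape)  # (1, 1, 2, 2)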
| @@ -409,7 +410,7 @@ def interpolate( | |||||
| inp: Tensor, | inp: Tensor, | ||||
| size: Optional[Union[int, Tuple[int, int]]] = None, | size: Optional[Union[int, Tuple[int, int]]] = None, | ||||
| scale_factor: Optional[Union[float, Tuple[float, float]]] = None, | scale_factor: Optional[Union[float, Tuple[float, float]]] = None, | ||||
| mode: str = "BILINEAR", | |||||
| mode: str = "bilinear", | |||||
| align_corners: Optional[bool] = None, | align_corners: Optional[bool] = None, | ||||
| ) -> Tensor: | ) -> Tensor: | ||||
| r""" | r""" | ||||
| @@ -419,9 +420,9 @@ def interpolate( | |||||
| :param size: size of the output tensor. Default: None | :param size: size of the output tensor. Default: None | ||||
| :param scale_factor: scaling factor of the output tensor. Default: None | :param scale_factor: scaling factor of the output tensor. Default: None | ||||
| :param mode: interpolation methods, acceptable values are: | :param mode: interpolation methods, acceptable values are: | ||||
| "BILINEAR", "LINEAR". Default: "BILINEAR" | |||||
| "bilinear", "linear". Default: "bilinear" | |||||
| :param align_corners: This only has an effect when `mode` | :param align_corners: This only has an effect when `mode` | ||||
| is "BILINEAR" or "LINEAR". Geometrically, we consider the pixels of the input | |||||
| is "bilinear" or "linear". Geometrically, we consider the pixels of the input | |||||
| and output as squares rather than points. If set to ``True``, the input | and output as squares rather than points. If set to ``True``, the input | ||||
| and output tensors are aligned by the center points of their corner | and output tensors are aligned by the center points of their corner | ||||
| pixels, preserving the values at the corner pixels. If set to ``False``, | pixels, preserving the values at the corner pixels. If set to ``False``, | ||||
| @@ -455,10 +456,10 @@ def interpolate( | |||||
| [3. 3.25 3.75 4. ]]]] | [3. 3.25 3.75 4. ]]]] | ||||
| """ | """ | ||||
| mode = mode.upper() | |||||
| if mode not in ["BILINEAR", "LINEAR"]: | |||||
| mode = mode.lower() | |||||
| if mode not in ["bilinear", "linear"]: | |||||
| raise ValueError("interpolate only support linear or bilinear mode") | raise ValueError("interpolate only support linear or bilinear mode") | ||||
| if mode not in ["BILINEAR", "LINEAR"]: | |||||
| if mode not in ["bilinear", "linear"]: | |||||
| if align_corners is not None: | if align_corners is not None: | ||||
| raise ValueError( | raise ValueError( | ||||
| "align_corners option can only be set in the bilinear/linear interpolating mode" | "align_corners option can only be set in the bilinear/linear interpolating mode" | ||||
| @@ -471,16 +472,16 @@ def interpolate( | |||||
| size is not None | size is not None | ||||
| and scale_factor is None | and scale_factor is None | ||||
| and not align_corners | and not align_corners | ||||
| and mode == "BILINEAR" | |||||
| and mode == "bilinear" | |||||
| and inp.ndim in [4, 5] | and inp.ndim in [4, 5] | ||||
| ): | ): | ||||
| # fastpath for interpolate | # fastpath for interpolate | ||||
| op = builtin.Resize(imode="LINEAR", format="NCHW") | |||||
| op = builtin.Resize(imode="linear", format="NCHW") | |||||
| shape = astensor1d(size, inp, dtype="int32", device=inp.device) | shape = astensor1d(size, inp, dtype="int32", device=inp.device) | ||||
| (result,) = apply(op, inp, shape) | (result,) = apply(op, inp, shape) | ||||
| return result | return result | ||||
| if mode == "LINEAR": | |||||
| if mode == "linear": | |||||
| inp = expand_dims(inp, 3) | inp = expand_dims(inp, 3) | ||||
| if inp.ndim != 4: | if inp.ndim != 4: | ||||
| @@ -492,14 +493,14 @@ def interpolate( | |||||
| if isinstance(scale_factor, (float, int)): | if isinstance(scale_factor, (float, int)): | ||||
| scale_factor = float(scale_factor) | scale_factor = float(scale_factor) | ||||
| if mode == "LINEAR": | |||||
| if mode == "linear": | |||||
| scale_factor = (scale_factor, float(1)) | scale_factor = (scale_factor, float(1)) | ||||
| else: | else: | ||||
| scale_factor = (scale_factor, scale_factor) | scale_factor = (scale_factor, scale_factor) | ||||
| else: | else: | ||||
| if mode == "LINEAR": | |||||
| if mode == "linear": | |||||
| raise ValueError( | raise ValueError( | ||||
| "under LINEAR mode, scale_factor can only be single value" | |||||
| "under linear mode, scale_factor can only be single value" | |||||
| ) | ) | ||||
| assert len(scale_factor) == 2, "shape of scale_factor must be equal to (2, )" | assert len(scale_factor) == 2, "shape of scale_factor must be equal to (2, )" | ||||
| @@ -524,8 +525,8 @@ def interpolate( | |||||
| if isinstance(size, int): | if isinstance(size, int): | ||||
| size = (size, 1) | size = (size, 1) | ||||
| else: | else: | ||||
| if mode == "LINEAR": | |||||
| raise ValueError("under LINEAR mode, size can only be single value") | |||||
| if mode == "linear": | |||||
| raise ValueError("under linear mode, size can only be single value") | |||||
| dsize = size | dsize = size | ||||
| oh, ow = dsize[0], dsize[1] | oh, ow = dsize[0], dsize[1] | ||||
| @@ -534,7 +535,7 @@ def interpolate( | |||||
| if align_corners: | if align_corners: | ||||
| hscale = (ih - 1.0) / (oh - 1.0) | hscale = (ih - 1.0) / (oh - 1.0) | ||||
| wscale = 1.0 * iw / ow | wscale = 1.0 * iw / ow | ||||
| if mode != "LINEAR": | |||||
| if mode != "linear": | |||||
| wscale = (iw - 1.0) / (ow - 1.0) | wscale = (iw - 1.0) / (ow - 1.0) | ||||
| row0 = concat( | row0 = concat( | ||||
| [wscale, Tensor([0, 0], dtype="float32", device=inp.device)], axis=0 | [wscale, Tensor([0, 0], dtype="float32", device=inp.device)], axis=0 | ||||
| @@ -570,8 +571,8 @@ def interpolate( | |||||
| weight = broadcast_to(weight, (inp.shape[0], 3, 3)) | weight = broadcast_to(weight, (inp.shape[0], 3, 3)) | ||||
| weight = weight.astype("float32") | weight = weight.astype("float32") | ||||
| ret = warp_perspective(inp, weight, dsize, interp_mode="LINEAR") | |||||
| if mode == "LINEAR": | |||||
| ret = warp_perspective(inp, weight, dsize, interp_mode="linear") | |||||
| if mode == "linear": | |||||
| ret = reshape(ret, ret.shape[0:3]) | ret = reshape(ret, ret.shape[0:3]) | ||||
| return ret | return ret | ||||
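Putting the lowercase modes together, a minimal sketch of `interpolate` covering both the 1D "linear" path and the 2D "bilinear" path (shapes match the tests further down):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    inp1d = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
    out1d = F.vision.interpolate(inp1d, scale_factor=2.0, mode="linear")   # shape (1, 1, 4)

    inp2d = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
    out2d = F.vision.interpolate(inp2d, size=(4, 4), mode="bilinear")      # shape (1, 1, 4, 4)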
| @@ -24,7 +24,7 @@ class BatchMatMulActivation(Module): | |||||
| in_features: int, | in_features: int, | ||||
| out_features: int, | out_features: int, | ||||
| bias: bool = True, | bias: bool = True, | ||||
| nonlinear_mode="IDENTITY", | |||||
| nonlinear_mode="identity", | |||||
| **kwargs | **kwargs | ||||
| ): | ): | ||||
| super().__init__(**kwargs) | super().__init__(**kwargs) | ||||
| @@ -37,7 +37,7 @@ class BatchMatMulActivation(Module): | |||||
| if bias: | if bias: | ||||
| b_shape = (out_features,) | b_shape = (out_features,) | ||||
| self.bias = Parameter(np.zeros(b_shape, dtype=np.float32)) | self.bias = Parameter(np.zeros(b_shape, dtype=np.float32)) | ||||
| self.nonlinear_mode = nonlinear_mode | |||||
| self.nonlinear_mode = nonlinear_mode.lower() | |||||
| self.reset_parameters() | self.reset_parameters() | ||||
| def _get_fanin(self): | def _get_fanin(self): | ||||
| @@ -54,7 +54,7 @@ class BatchMatMulActivation(Module): | |||||
| res = matmul(weight, x) | res = matmul(weight, x) | ||||
| if self.bias is not None: | if self.bias is not None: | ||||
| res += bias | res += bias | ||||
| if self.nonlinear_mode == "RELU": | |||||
| if self.nonlinear_mode == "relu": | |||||
| res = relu(res) | res = relu(res) | ||||
| return res | return res | ||||
| @@ -138,11 +138,11 @@ class Conv1d(_ConvNd): | |||||
| out_channel // groups, in_channels // groups, *kernel_size)`. | out_channel // groups, in_channels // groups, *kernel_size)`. | ||||
| :param bias: whether to add a bias onto the result of convolution. Default: | :param bias: whether to add a bias onto the result of convolution. Default: | ||||
| True | True | ||||
| :param conv_mode: Supports `CROSS_CORRELATION`. Default: | |||||
| `CROSS_CORRELATION` | |||||
| :param compute_mode: When set to "DEFAULT", no special requirements will be | |||||
| placed on the precision of intermediate results. When set to "FLOAT32", | |||||
| "Float32" would be used for accumulator and intermediate result, but only | |||||
| :param conv_mode: Supports `cross_correlation`. Default: | |||||
| `cross_correlation` | |||||
| :param compute_mode: When set to "default", no special requirements will be | |||||
| placed on the precision of intermediate results. When set to "float32", | |||||
| "float32" would be used for accumulator and intermediate result, but only | |||||
| effective when input and output are of float16 dtype. | effective when input and output are of float16 dtype. | ||||
| Examples: | Examples: | ||||
| @@ -176,8 +176,8 @@ class Conv1d(_ConvNd): | |||||
| dilation: int = 1, | dilation: int = 1, | ||||
| groups: int = 1, | groups: int = 1, | ||||
| bias: bool = True, | bias: bool = True, | ||||
| conv_mode: str = "CROSS_CORRELATION", | |||||
| compute_mode: str = "DEFAULT", | |||||
| conv_mode: str = "cross_correlation", | |||||
| compute_mode: str = "default", | |||||
| **kwargs | **kwargs | ||||
| ): | ): | ||||
| kernel_size = kernel_size | kernel_size = kernel_size | ||||
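A minimal sketch of the module constructor with the lowercase defaults, assuming the usual positional `in_channels`/`out_channels` arguments described above:

    import numpy as np
    import megengine as mge
    import megengine.module as M

    conv = M.Conv1d(4, 8, kernel_size=3,
                    conv_mode="cross_correlation", compute_mode="default")
    x = mge.Tensor(np.random.random((2, 4, 16)).astype("float32"))
    print(conv(x).numpy().shape)  # (2, 8, 14)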
| @@ -298,11 +298,11 @@ class Conv2d(_ConvNd): | |||||
| out_channel // groups, in_channels // groups, *kernel_size)`. | out_channel // groups, in_channels // groups, *kernel_size)`. | ||||
| :param bias: whether to add a bias onto the result of convolution. Default: | :param bias: whether to add a bias onto the result of convolution. Default: | ||||
| True | True | ||||
| :param conv_mode: Supports `CROSS_CORRELATION`. Default: | |||||
| `CROSS_CORRELATION` | |||||
| :param compute_mode: When set to "DEFAULT", no special requirements will be | |||||
| placed on the precision of intermediate results. When set to "FLOAT32", | |||||
| "Float32" would be used for accumulator and intermediate result, but only | |||||
| :param conv_mode: Supports `cross_correlation`. Default: | |||||
| `cross_correlation` | |||||
| :param compute_mode: When set to "default", no special requirements will be | |||||
| placed on the precision of intermediate results. When set to "float32", | |||||
| "float32" would be used for accumulator and intermediate result, but only | |||||
| effective when input and output are of float16 dtype. | effective when input and output are of float16 dtype. | ||||
| Examples: | Examples: | ||||
| @@ -336,8 +336,8 @@ class Conv2d(_ConvNd): | |||||
| dilation: Union[int, Tuple[int, int]] = 1, | dilation: Union[int, Tuple[int, int]] = 1, | ||||
| groups: int = 1, | groups: int = 1, | ||||
| bias: bool = True, | bias: bool = True, | ||||
| conv_mode: str = "CROSS_CORRELATION", | |||||
| compute_mode: str = "DEFAULT", | |||||
| conv_mode: str = "cross_correlation", | |||||
| compute_mode: str = "default", | |||||
| **kwargs | **kwargs | ||||
| ): | ): | ||||
| kernel_size = _pair_nonzero(kernel_size) | kernel_size = _pair_nonzero(kernel_size) | ||||
| @@ -436,15 +436,16 @@ class Conv3d(_ConvNd): | |||||
| :param padding: size of the paddings added to the input on both sides of its | :param padding: size of the paddings added to the input on both sides of its | ||||
| spatial dimensions. Only zero-padding is supported. Default: 0 | spatial dimensions. Only zero-padding is supported. Default: 0 | ||||
| :param dilation: dilation of the 3D convolution operation. Default: 1 | :param dilation: dilation of the 3D convolution operation. Default: 1 | ||||
| :param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1, | |||||
| :param groups: number of groups into which the input and output channels are divided, | |||||
| so as to perform a "grouped convolution". When ``groups`` is not 1, | |||||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | ||||
| and there would be an extra dimension at the beginning of the weight's | and there would be an extra dimension at the beginning of the weight's | ||||
| shape. Specifically, the shape of weight would be `(groups, | shape. Specifically, the shape of weight would be `(groups, | ||||
| out_channel // groups, in_channels // groups, *kernel_size)`. | out_channel // groups, in_channels // groups, *kernel_size)`. | ||||
| :param bias: whether to add a bias onto the result of convolution. Default: | :param bias: whether to add a bias onto the result of convolution. Default: | ||||
| True | True | ||||
| :param conv_mode: Supports `CROSS_CORRELATION`. Default: | |||||
| `CROSS_CORRELATION` | |||||
| :param conv_mode: Supports `cross_correlation`. Default: | |||||
| `cross_correlation` | |||||
| Examples: | Examples: | ||||
| @@ -477,7 +478,7 @@ class Conv3d(_ConvNd): | |||||
| dilation: Union[int, Tuple[int, int, int]] = 1, | dilation: Union[int, Tuple[int, int, int]] = 1, | ||||
| groups: int = 1, | groups: int = 1, | ||||
| bias: bool = True, | bias: bool = True, | ||||
| conv_mode: str = "CROSS_CORRELATION", | |||||
| conv_mode: str = "cross_correlation", | |||||
| ): | ): | ||||
| kernel_size = _triple_nonzero(kernel_size) | kernel_size = _triple_nonzero(kernel_size) | ||||
| stride = _triple_nonzero(stride) | stride = _triple_nonzero(stride) | ||||
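The reflowed note on `groups` describes grouped convolution; a minimal sketch with `groups=2`, where both channel counts are divisible by the group count (shapes below follow the weight layout stated in the docstring):

    import numpy as np
    import megengine as mge
    import megengine.module as M

    conv = M.Conv3d(4, 8, kernel_size=3, groups=2, conv_mode="cross_correlation")
    x = mge.Tensor(np.random.random((1, 4, 8, 8, 8)).astype("float32"))
    print(conv(x).numpy().shape)          # (1, 8, 6, 6, 6)
    print(conv.weight.numpy().shape)      # grouped layout: (2, 4, 2, 3, 3, 3)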
| @@ -566,11 +567,11 @@ class ConvTranspose2d(_ConvNd): | |||||
| out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1 | out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1 | ||||
| :param bias: whether to add a bias onto the result of convolution. Default: | :param bias: whether to add a bias onto the result of convolution. Default: | ||||
| True | True | ||||
| :param conv_mode: Supports `CROSS_CORRELATION`. Default: | |||||
| `CROSS_CORRELATION` | |||||
| :param compute_mode: When set to "DEFAULT", no special requirements will be | |||||
| placed on the precision of intermediate results. When set to "FLOAT32", | |||||
| "Float32" would be used for accumulator and intermediate result, but only | |||||
| :param conv_mode: Supports `cross_correlation`. Default: | |||||
| `cross_correlation` | |||||
| :param compute_mode: When set to "default", no special requirements will be | |||||
| placed on the precision of intermediate results. When set to "float32", | |||||
| "float32" would be used for accumulator and intermediate result, but only | |||||
| effective when input and output are of float16 dtype. | effective when input and output are of float16 dtype. | ||||
| """ | """ | ||||
| @@ -584,8 +585,8 @@ class ConvTranspose2d(_ConvNd): | |||||
| dilation: Union[int, Tuple[int, int]] = 1, | dilation: Union[int, Tuple[int, int]] = 1, | ||||
| groups: int = 1, | groups: int = 1, | ||||
| bias: bool = True, | bias: bool = True, | ||||
| conv_mode: str = "CROSS_CORRELATION", | |||||
| compute_mode: str = "DEFAULT", | |||||
| conv_mode: str = "cross_correlation", | |||||
| compute_mode: str = "default", | |||||
| **kwargs | **kwargs | ||||
| ): | ): | ||||
| kernel_size = _pair_nonzero(kernel_size) | kernel_size = _pair_nonzero(kernel_size) | ||||
| @@ -679,7 +680,7 @@ class LocalConv2d(Conv2d): | |||||
| padding: Union[int, Tuple[int, int]] = 0, | padding: Union[int, Tuple[int, int]] = 0, | ||||
| dilation: Union[int, Tuple[int, int]] = 1, | dilation: Union[int, Tuple[int, int]] = 1, | ||||
| groups: int = 1, | groups: int = 1, | ||||
| conv_mode: str = "CROSS_CORRELATION", | |||||
| conv_mode: str = "cross_correlation", | |||||
| **kwargs | **kwargs | ||||
| ): | ): | ||||
| self.input_height = input_height | self.input_height = input_height | ||||
| @@ -758,11 +759,11 @@ class DeformableConv2d(_ConvNd): | |||||
| out_channel // groups, in_channels // groups, *kernel_size)`. | out_channel // groups, in_channels // groups, *kernel_size)`. | ||||
| :param bias: whether to add a bias onto the result of convolution. Default: | :param bias: whether to add a bias onto the result of convolution. Default: | ||||
| True | True | ||||
| :param conv_mode: Supports `CROSS_CORRELATION`. Default: | |||||
| `CROSS_CORRELATION` | |||||
| :param compute_mode: When set to "DEFAULT", no special requirements will be | |||||
| placed on the precision of intermediate results. When set to "FLOAT32", | |||||
| "Float32" would be used for accumulator and intermediate result, but only | |||||
| :param conv_mode: Supports `cross_correlation`. Default: | |||||
| `cross_correlation` | |||||
| :param compute_mode: When set to "default", no special requirements will be | |||||
| placed on the precision of intermediate results. When set to "float32", | |||||
| "float32" would be used for accumulator and intermediate result, but only | |||||
| effective when input and output are of float16 dtype. | effective when input and output are of float16 dtype. | ||||
| """ | """ | ||||
| @@ -776,8 +777,8 @@ class DeformableConv2d(_ConvNd): | |||||
| dilation: Union[int, Tuple[int, int]] = 1, | dilation: Union[int, Tuple[int, int]] = 1, | ||||
| groups: int = 1, | groups: int = 1, | ||||
| bias: bool = True, | bias: bool = True, | ||||
| conv_mode: str = "CROSS_CORRELATION", | |||||
| compute_mode: str = "DEFAULT", | |||||
| conv_mode: str = "cross_correlation", | |||||
| compute_mode: str = "default", | |||||
| **kwargs | **kwargs | ||||
| ): | ): | ||||
| kernel_size = _pair_nonzero(kernel_size) | kernel_size = _pair_nonzero(kernel_size) | ||||
| @@ -24,8 +24,8 @@ class _ConvBnActivation2d(Module): | |||||
| dilation: Union[int, Tuple[int, int]] = 1, | dilation: Union[int, Tuple[int, int]] = 1, | ||||
| groups: int = 1, | groups: int = 1, | ||||
| bias: bool = True, | bias: bool = True, | ||||
| conv_mode: str = "CROSS_CORRELATION", | |||||
| compute_mode: str = "DEFAULT", | |||||
| conv_mode: str = "cross_correlation", | |||||
| compute_mode: str = "default", | |||||
| eps=1e-5, | eps=1e-5, | ||||
| momentum=0.9, | momentum=0.9, | ||||
| affine=True, | affine=True, | ||||
| @@ -18,58 +18,58 @@ class Elemwise(Module): | |||||
| :param method: the elemwise method, support the following string. | :param method: the elemwise method, support the following string. | ||||
| It will do the normal elemwise operator for float. | It will do the normal elemwise operator for float. | ||||
| * "ADD": a + b | |||||
| * "FUSE_ADD_RELU": max(x+y, 0) | |||||
| * "MUL": x * y | |||||
| * "MIN": min(x, y) | |||||
| * "MAX": max(x, y) | |||||
| * "SUB": x - y | |||||
| * "TRUE_DIV": x / y | |||||
| * "FUSE_ADD_SIGMOID": sigmoid(x + y) | |||||
| * "FUSE_ADD_TANH": tanh(x + y) | |||||
| * "RELU": x > 0 ? x : 0 | |||||
| * "ABS": x > 0 ? x : -x | |||||
| * "SIGMOID": sigmoid(x) | |||||
| * "EXP": exp(x) | |||||
| * "TANH": tanh(x) | |||||
| * "FUSE_MUL_ADD3": x * y + z | |||||
| * "FAST_TANH": x * (27. + x * x) / (27. + 9. * x * x) | |||||
| * "NEGATE": -x | |||||
| * "ACOS": acos(x) | |||||
| * "ASIN": asin(x) | |||||
| * "CEIL": ceil(x) | |||||
| * "COS": cos(x) | |||||
| * "EXPM1": expm1(x) | |||||
| * "FLOOR": floor(x) | |||||
| * "LOG": log(x) | |||||
| * "LOG1P": log1p(x) | |||||
| * "SIN": sin(x) | |||||
| * "ROUND": round(x) | |||||
| * "ERF": erf(x) | |||||
| * "ERFINV": erfinv(x) | |||||
| * "ERFC": erfc(x) | |||||
| * "ERFCINV": erfcinv(x) | |||||
| * "ABS_GRAD": abs_grad | |||||
| * "FLOOR_DIV": floor_div | |||||
| * "MOD": mod | |||||
| * "SIGMOID_GRAD": sigmoid_grad | |||||
| * "SWITCH_GT0": switch_gt0 | |||||
| * "TANH_GRAD": tanh_grad | |||||
| * "LT": less | |||||
| * "LEQ": leq | |||||
| * "EQ": equal | |||||
| * "POW": pow | |||||
| * "LOG_SUM_EXP": log_sum_exp | |||||
| * "FAST_TANH_GRAD": fast_tanh_grad | |||||
| * "ATAN2": atan2 | |||||
| * "COND_LEQ_MOV": cond_leq_mov | |||||
| * "H_SWISH": h_swish | |||||
| * "FUSE_ADD_H_SWISH": h_swish(x+y) | |||||
| * "H_SWISH_GRAD": h_swish_grad | |||||
| * "AND": bool binary: x && y | |||||
| * "OR": bool binary: x || y | |||||
| * "XOR": bool binary: x ^ y | |||||
| * "NOT": bool unary: ~x | |||||
| * "add": a + b | |||||
| * "fuse_add_relu": max(x+y, 0) | |||||
| * "mul": x * y | |||||
| * "min": min(x, y) | |||||
| * "max": max(x, y) | |||||
| * "sub": x - y | |||||
| * "true_div": x / y | |||||
| * "fuse_add_sigmoid": sigmoid(x + y) | |||||
| * "fuse_add_tanh": tanh(x + y) | |||||
| * "relu": x > 0 ? x : 0 | |||||
| * "abs": x > 0 ? x : -x | |||||
| * "sigmoid": sigmoid(x) | |||||
| * "exp": exp(x) | |||||
| * "tanh": tanh(x) | |||||
| * "fuse_mul_add3": x * y + z | |||||
| * "fast_tanh": x * (27. + x * x) / (27. + 9. * x * x) | |||||
| * "negate": -x | |||||
| * "acos": acos(x) | |||||
| * "asin": asin(x) | |||||
| * "ceil": ceil(x) | |||||
| * "cos": cos(x) | |||||
| * "expm1": expm1(x) | |||||
| * "floor": floor(x) | |||||
| * "log": log(x) | |||||
| * "log1p": log1p(x) | |||||
| * "sin": sin(x) | |||||
| * "round": round(x) | |||||
| * "erf": erf(x) | |||||
| * "erfinv": erfinv(x) | |||||
| * "erfc": erfc(x) | |||||
| * "erfcinv": erfcinv(x) | |||||
| * "abs_grad": abs_grad | |||||
| * "floor_div": floor_div | |||||
| * "mod": mod | |||||
| * "sigmoid_grad": sigmoid_grad | |||||
| * "switch_gt0": switch_gt0 | |||||
| * "tanh_grad": tanh_grad | |||||
| * "lt": less | |||||
| * "leq": leq | |||||
| * "eq": equal | |||||
| * "pow": pow | |||||
| * "log_sum_exp": log_sum_exp | |||||
| * "fast_tanh_grad": fast_tanh_grad | |||||
| * "atan2": atan2 | |||||
| * "cond_leq_mov": cond_leq_mov | |||||
| * "h_swish": h_swish | |||||
| * "fuse_add_h_swish": h_swish(x+y) | |||||
| * "h_swish_grad": h_swish_grad | |||||
| * "and": bool binary: x && y | |||||
| * "or": bool binary: x || y | |||||
| * "xor": bool binary: x ^ y | |||||
| * "not": bool unary: ~x | |||||
| """ | """ | ||||
| def __init__(self, method, **kwargs): | def __init__(self, method, **kwargs): | ||||
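With the method names now lowercase, the module is constructed as below; a minimal sketch using the fused add-relu method from the list above:

    import numpy as np
    import megengine as mge
    import megengine.module as M

    add_relu = M.Elemwise("fuse_add_relu")   # computes max(x + y, 0)
    x = mge.Tensor(np.array([-1.0, 2.0], dtype="float32"))
    y = mge.Tensor(np.array([0.5, 1.0], dtype="float32"))
    print(add_relu(x, y).numpy())            # [0. 3.]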
| @@ -27,7 +27,7 @@ class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule): | |||||
| in_features: int, | in_features: int, | ||||
| out_features: int, | out_features: int, | ||||
| bias: bool = True, | bias: bool = True, | ||||
| nonlinear_mode="IDENTITY", | |||||
| nonlinear_mode="identity", | |||||
| dtype=None, | dtype=None, | ||||
| **kwargs | **kwargs | ||||
| ): | ): | ||||
| @@ -34,8 +34,8 @@ class Conv2d(Float.Conv2d, QuantizedModule): | |||||
| padding: Union[int, Tuple[int, int]] = 0, | padding: Union[int, Tuple[int, int]] = 0, | ||||
| dilation: Union[int, Tuple[int, int]] = 1, | dilation: Union[int, Tuple[int, int]] = 1, | ||||
| groups: int = 1, | groups: int = 1, | ||||
| conv_mode: str = "CROSS_CORRELATION", | |||||
| compute_mode: str = "DEFAULT", | |||||
| conv_mode: str = "cross_correlation", | |||||
| compute_mode: str = "default", | |||||
| dtype=None, | dtype=None, | ||||
| **kwargs | **kwargs | ||||
| ): | ): | ||||
| @@ -53,7 +53,7 @@ class Conv2d(Float.Conv2d, QuantizedModule): | |||||
| ) | ) | ||||
| self.output_dtype = dtype | self.output_dtype = dtype | ||||
| def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"): | |||||
| def calc_conv_quantized(self, inp, nonlinear_mode="identity"): | |||||
| inp_scale = dtype.get_scale(inp.dtype) | inp_scale = dtype.get_scale(inp.dtype) | ||||
| w_scale = dtype.get_scale(self.weight.dtype) | w_scale = dtype.get_scale(self.weight.dtype) | ||||
| bias_scale = inp_scale * w_scale | bias_scale = inp_scale * w_scale | ||||
| @@ -100,11 +100,11 @@ class Conv2d(Float.Conv2d, QuantizedModule): | |||||
| return qconv | return qconv | ||||
| def forward(self, inp): | def forward(self, inp): | ||||
| return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY") | |||||
| return self.calc_conv_quantized(inp, nonlinear_mode="identity") | |||||
| class ConvRelu2d(Conv2d): | class ConvRelu2d(Conv2d): | ||||
| r"""Quantized version of :class:`~.qat.ConvRelu2d`.""" | r"""Quantized version of :class:`~.qat.ConvRelu2d`.""" | ||||
| def forward(self, inp): | def forward(self, inp): | ||||
| return self.calc_conv_quantized(inp, nonlinear_mode="RELU") | |||||
| return self.calc_conv_quantized(inp, nonlinear_mode="relu") | |||||
| @@ -50,11 +50,11 @@ class ConvBn2d(_ConvBnActivation2d): | |||||
| r"""Quantized version of :class:`~.qat.ConvBn2d`.""" | r"""Quantized version of :class:`~.qat.ConvBn2d`.""" | ||||
| def forward(self, inp): | def forward(self, inp): | ||||
| return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY") | |||||
| return self.calc_conv_quantized(inp, nonlinear_mode="identity") | |||||
| class ConvBnRelu2d(_ConvBnActivation2d): | class ConvBnRelu2d(_ConvBnActivation2d): | ||||
| r"""Quantized version of :class:`~.qat.ConvBnRelu2d`.""" | r"""Quantized version of :class:`~.qat.ConvBnRelu2d`.""" | ||||
| def forward(self, inp): | def forward(self, inp): | ||||
| return self.calc_conv_quantized(inp, nonlinear_mode="RELU") | |||||
| return self.calc_conv_quantized(inp, nonlinear_mode="relu") | |||||
| @@ -16,7 +16,7 @@ class Elemwise(QuantizedModule): | |||||
| def __init__(self, method, dtype=None, **kwargs): | def __init__(self, method, dtype=None, **kwargs): | ||||
| super().__init__(**kwargs) | super().__init__(**kwargs) | ||||
| self.method = "Q" + method | |||||
| self.method = "q" + method | |||||
| self.output_dtype = dtype | self.output_dtype = dtype | ||||
| def forward(self, *inps): | def forward(self, *inps): | ||||
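The quantized module now prepends a lowercase "q" to the method name (e.g. "add" becomes "qadd"), matching the `_elemwise_multi_type` modes used in the tests below. A minimal sketch of the underlying call, assuming the dtype helpers are imported the way the tests do:

    import numpy as np
    import megengine.functional as F
    from megengine.core.tensor import dtype
    from megengine import tensor

    scale = np.float32(0.05)
    x = tensor(np.random.normal(size=(3, 3)).astype("float32"), dtype=dtype.qint8(scale))
    y = tensor(np.random.normal(size=(3, 3)).astype("float32"), dtype=dtype.qint8(scale))
    out = F.elemwise._elemwise_multi_type(x, y, mode="qadd", dtype=dtype.qint8(scale))
    print(out.astype("float32").numpy().shape)  # (3, 3)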
| @@ -16,9 +16,9 @@ fi | |||||
| export MEGENGINE_LOGGING_LEVEL="ERROR" | export MEGENGINE_LOGGING_LEVEL="ERROR" | ||||
| pushd $(dirname "${BASH_SOURCE[0]}")/.. >/dev/null | pushd $(dirname "${BASH_SOURCE[0]}")/.. >/dev/null | ||||
| PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest $test_dirs -m 'not isolated_distributed' | |||||
| PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'not isolated_distributed' | |||||
| if [[ "$TEST_PLAT" == cuda ]]; then | if [[ "$TEST_PLAT" == cuda ]]; then | ||||
| echo "test GPU pytest now" | echo "test GPU pytest now" | ||||
| PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest $test_dirs -m 'isolated_distributed' | |||||
| PYTHONPATH="." PY_IGNORE_IMPORTMISMATCH=1 python3 -m pytest -v $test_dirs -m 'isolated_distributed' | |||||
| fi | fi | ||||
| popd >/dev/null | popd >/dev/null | ||||
| @@ -372,7 +372,7 @@ def test_interpolate_fastpath(): | |||||
| x = mge.Tensor(x_np) | x = mge.Tensor(x_np) | ||||
| grad = Grad().wrt(x, callback=save_to(x)) | grad = Grad().wrt(x, callback=save_to(x)) | ||||
| y = F.vision.interpolate(x, size=(16, 16), mode="BILINEAR") | |||||
| y = F.vision.interpolate(x, size=(16, 16), mode="bilinear") | |||||
| grad(y, F.ones_like(y)) | grad(y, F.ones_like(y)) | ||||
| np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy()) | np.testing.assert_equal(np.ones(x_np.shape, dtype=np.float32) / 4, x.grad.numpy()) | ||||
| @@ -162,7 +162,7 @@ def test_qadd(): | |||||
| x = tensor(x, dtype=dtype.qint8(inp_scale)) | x = tensor(x, dtype=dtype.qint8(inp_scale)) | ||||
| y = tensor(y, dtype=dtype.qint8(inp_scale)) | y = tensor(y, dtype=dtype.qint8(inp_scale)) | ||||
| result_mge = F.elemwise._elemwise_multi_type( | result_mge = F.elemwise._elemwise_multi_type( | ||||
| x, y, mode="QADD", dtype=dtype.qint8(outp_scale) | |||||
| x, y, mode="qadd", dtype=dtype.qint8(outp_scale) | |||||
| ) | ) | ||||
| result_mge = result_mge.astype("float32").numpy() | result_mge = result_mge.astype("float32").numpy() | ||||
| result_expect = x.astype("float32").numpy() + y.astype("float32").numpy() | result_expect = x.astype("float32").numpy() + y.astype("float32").numpy() | ||||
| @@ -140,8 +140,8 @@ def test_interpolate(): | |||||
| def linear_interpolate(): | def linear_interpolate(): | ||||
| inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2)) | inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2)) | ||||
| out = F.vision.interpolate(inp, scale_factor=2.0, mode="LINEAR") | |||||
| out2 = F.vision.interpolate(inp, 4, mode="LINEAR") | |||||
| out = F.vision.interpolate(inp, scale_factor=2.0, mode="linear") | |||||
| out2 = F.vision.interpolate(inp, 4, mode="linear") | |||||
| np.testing.assert_allclose( | np.testing.assert_allclose( | ||||
| out.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32) | out.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32) | ||||
| @@ -170,13 +170,13 @@ def test_interpolate(): | |||||
| inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2)) | inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2)) | ||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| F.vision.interpolate(inp, scale_factor=2.0, mode="LINEAR") | |||||
| F.vision.interpolate(inp, scale_factor=2.0, mode="linear") | |||||
| def inappropriate_scale_linear_interpolate(): | def inappropriate_scale_linear_interpolate(): | ||||
| inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2)) | inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2)) | ||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| F.vision.interpolate(inp, scale_factor=[2.0, 3.0], mode="LINEAR") | |||||
| F.vision.interpolate(inp, scale_factor=[2.0, 3.0], mode="linear") | |||||
| linear_interpolate() | linear_interpolate() | ||||
| many_batch_interpolate() | many_batch_interpolate() | ||||
| @@ -339,18 +339,18 @@ def test_interpolate_fastpath(): | |||||
| ] | ] | ||||
| for inp_shape, target_shape in test_cases: | for inp_shape, target_shape in test_cases: | ||||
| x = tensor(np.random.randn(*inp_shape), dtype=np.float32) | x = tensor(np.random.randn(*inp_shape), dtype=np.float32) | ||||
| out = F.vision.interpolate(x, target_shape, mode="BILINEAR") | |||||
| out = F.vision.interpolate(x, target_shape, mode="bilinear") | |||||
| assert out.shape[0] == x.shape[0] and out.shape[1] == x.shape[1] | assert out.shape[0] == x.shape[0] and out.shape[1] == x.shape[1] | ||||
| assert out.shape[2] == target_shape[0] and out.shape[3] == target_shape[1] | assert out.shape[2] == target_shape[0] and out.shape[3] == target_shape[1] | ||||
| # check value | # check value | ||||
| x = tensor(np.ones((3, 3, 10, 10)), dtype=np.float32) | x = tensor(np.ones((3, 3, 10, 10)), dtype=np.float32) | ||||
| out = F.vision.interpolate(x, (15, 5), mode="BILINEAR") | |||||
| out = F.vision.interpolate(x, (15, 5), mode="bilinear") | |||||
| np.testing.assert_equal(out.numpy(), np.ones((3, 3, 15, 5)).astype(np.float32)) | np.testing.assert_equal(out.numpy(), np.ones((3, 3, 15, 5)).astype(np.float32)) | ||||
| np_x = np.arange(32) | np_x = np.arange(32) | ||||
| x = tensor(np_x).astype(np.float32).reshape(1, 1, 32, 1) | x = tensor(np_x).astype(np.float32).reshape(1, 1, 32, 1) | ||||
| out = F.vision.interpolate(x, (1, 1), mode="BILINEAR") | |||||
| out = F.vision.interpolate(x, (1, 1), mode="bilinear") | |||||
| np.testing.assert_equal(out.item(), np_x.mean()) | np.testing.assert_equal(out.item(), np_x.mean()) | ||||
| @@ -374,7 +374,7 @@ def test_warp_affine(): | |||||
| inp_shape = (1, 3, 3, 3) | inp_shape = (1, 3, 3, 3) | ||||
| x = tensor(np.arange(27, dtype=np.float32).reshape(inp_shape)) | x = tensor(np.arange(27, dtype=np.float32).reshape(inp_shape)) | ||||
| weightv = [[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]] | weightv = [[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]] | ||||
| outp = F.vision.warp_affine(x, tensor(weightv), (2, 2), border_mode="WRAP") | |||||
| outp = F.vision.warp_affine(x, tensor(weightv), (2, 2), border_mode="wrap") | |||||
| res = np.array( | res = np.array( | ||||
| [ | [ | ||||
| [ | [ | ||||
| @@ -509,7 +509,7 @@ def test_conv_bias(): | |||||
| SH, | SH, | ||||
| SW, | SW, | ||||
| has_bias=True, | has_bias=True, | ||||
| nonlinear_mode="IDENTITY", | |||||
| nonlinear_mode="identity", | |||||
| ): | ): | ||||
| inp_v = np.random.normal(size=(N, IC, IH, IW)) | inp_v = np.random.normal(size=(N, IC, IH, IW)) | ||||
| w_v = np.random.normal(size=(OC, IC, KH, KW)) | w_v = np.random.normal(size=(OC, IC, KH, KW)) | ||||
| @@ -541,7 +541,7 @@ def test_conv_bias(): | |||||
| O = F.conv2d( | O = F.conv2d( | ||||
| inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW), | inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW), | ||||
| ) | ) | ||||
| if nonlinear_mode == "RELU": | |||||
| if nonlinear_mode == "relu": | |||||
| return F.relu(O) | return F.relu(O) | ||||
| else: | else: | ||||
| return O | return O | ||||
| @@ -583,8 +583,8 @@ def test_conv_bias(): | |||||
| run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1) | run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1) | ||||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2) | run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2) | ||||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU") | |||||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU") | |||||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu") | |||||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu") | |||||
| @pytest.mark.skipif( | @pytest.mark.skipif( | ||||
| @@ -23,8 +23,8 @@ def test_module_elemwise(): | |||||
| y = np.random.rand(100).astype("float32") | y = np.random.rand(100).astype("float32") | ||||
| x, y = tensor(x), tensor(y) | x, y = tensor(x), tensor(y) | ||||
| np.testing.assert_almost_equal( | np.testing.assert_almost_equal( | ||||
| test_func("H_SWISH", x), F.hswish(x).numpy(), decimal=6 | |||||
| test_func("h_swish", x), F.hswish(x).numpy(), decimal=6 | |||||
| ) | ) | ||||
| np.testing.assert_almost_equal( | np.testing.assert_almost_equal( | ||||
| test_func("ADD", x, y), F.add(x, y).numpy(), decimal=6 | |||||
| test_func("add", x, y), F.add(x, y).numpy(), decimal=6 | |||||
| ) | ) | ||||
| @@ -133,7 +133,7 @@ def test_dequant_stub(): | |||||
| np.testing.assert_allclose(q, fake_quant_normal.numpy()) | np.testing.assert_allclose(q, fake_quant_normal.numpy()) | ||||
| @pytest.mark.parametrize("kind", ["COS", "RELU", "ADD", "MUL", "FUSE_ADD_RELU"]) | |||||
| @pytest.mark.parametrize("kind", ["cos", "relu", "add", "mul", "fuse_add_relu"]) | |||||
| def test_elemwise(kind): | def test_elemwise(kind): | ||||
| normal_net = Float.Elemwise(kind) | normal_net = Float.Elemwise(kind) | ||||
| normal_net.eval() | normal_net.eval() | ||||
| @@ -167,7 +167,7 @@ def test_elemwise(kind): | |||||
| x2_int8 = quant(x2, x2_scale) | x2_int8 = quant(x2, x2_scale) | ||||
| # test correctness of `Float`, `QAT` and `Quantized` | # test correctness of `Float`, `QAT` and `Quantized` | ||||
| if kind in ("ADD", "MUL", "FUSE_ADD_RELU"): | |||||
| if kind in ("add", "mul", "fuse_add_relu"): | |||||
| normal = normal_net(x1, x2) | normal = normal_net(x1, x2) | ||||
| qat_without_fakequant = qat_from_float(x1, x2) | qat_without_fakequant = qat_from_float(x1, x2) | ||||
| fake_quant_normal = fake_quant_act(normal_net(x1, x2), act_scale) | fake_quant_normal = fake_quant_act(normal_net(x1, x2), act_scale) | ||||
| @@ -22,7 +22,7 @@ def fake_quant(x, scale): | |||||
| return x | return x | ||||
| @pytest.mark.parametrize("kind", ["ABS", "SIN", "SUB", "MUL", "FUSE_ADD_TANH"]) | |||||
| @pytest.mark.parametrize("kind", ["abs", "sin", "sub", "mul", "fuse_add_tanh"]) | |||||
| def test_elemwise(kind): | def test_elemwise(kind): | ||||
| x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | ||||
| x1_scale = np.float32(np.random.rand() + 1) | x1_scale = np.float32(np.random.rand() + 1) | ||||
| @@ -39,8 +39,8 @@ def test_elemwise(kind): | |||||
| output_scale = np.float32(np.random.rand() + 1) | output_scale = np.float32(np.random.rand() + 1) | ||||
| output_dtype = dtype.qint8(output_scale) | output_dtype = dtype.qint8(output_scale) | ||||
| quantized_kind = "Q" + kind | |||||
| if kind in ("ABS", "SIN"): | |||||
| quantized_kind = "q" + kind | |||||
| if kind in ("abs", "sin"): | |||||
| desired_out = fake_quant(_elwise(x1, mode=kind), output_scale) | desired_out = fake_quant(_elwise(x1, mode=kind), output_scale) | ||||
| actual_out = ( | actual_out = ( | ||||
| _elemwise_multi_type( | _elemwise_multi_type( | ||||
| @@ -84,7 +84,7 @@ def test_conv_bias(): | |||||
| SH, | SH, | ||||
| SW, | SW, | ||||
| has_bias=True, | has_bias=True, | ||||
| nonlinear_mode="IDENTITY", | |||||
| nonlinear_mode="identity", | |||||
| ): | ): | ||||
| inp_v = np.random.normal(size=(N, IC, IH, IW)) | inp_v = np.random.normal(size=(N, IC, IH, IW)) | ||||
| w_v = np.random.normal(size=(OC, IC, KH, KW)) | w_v = np.random.normal(size=(OC, IC, KH, KW)) | ||||
| @@ -116,7 +116,7 @@ def test_conv_bias(): | |||||
| O = F.conv2d( | O = F.conv2d( | ||||
| inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW), | inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW), | ||||
| ) | ) | ||||
| if nonlinear_mode == "RELU": | |||||
| if nonlinear_mode == "relu": | |||||
| return F.relu(O) | return F.relu(O) | ||||
| else: | else: | ||||
| return O | return O | ||||
| @@ -158,5 +158,5 @@ def test_conv_bias(): | |||||
| run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1) | run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1) | ||||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2) | run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2) | ||||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU") | |||||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU") | |||||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu") | |||||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu") | |||||
| @@ -280,7 +280,7 @@ def test_convbias(): | |||||
| @trace(symbolic=True, capture_as_const=True) | @trace(symbolic=True, capture_as_const=True) | ||||
| def fwd(inp, weight, bias): | def fwd(inp, weight, bias): | ||||
| return F.quantized.conv_bias_activation( | return F.quantized.conv_bias_activation( | ||||
| inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU" | |||||
| inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="relu" | |||||
| ) | ) | ||||
| inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0)) | inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0)) | ||||
| @@ -297,7 +297,7 @@ def test_batch_convbias(): | |||||
| @trace(symbolic=True, capture_as_const=True) | @trace(symbolic=True, capture_as_const=True) | ||||
| def fwd(inp, weight, bias): | def fwd(inp, weight, bias): | ||||
| return F.quantized.batch_conv_bias_activation( | return F.quantized.batch_conv_bias_activation( | ||||
| inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU" | |||||
| inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="relu" | |||||
| ) | ) | ||||
| inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0)) | inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0)) | ||||
| @@ -358,7 +358,7 @@ def test_warpaffine(): | |||||
| @trace(symbolic=True, capture_as_const=True) | @trace(symbolic=True, capture_as_const=True) | ||||
| def fwd(x, weightv): | def fwd(x, weightv): | ||||
| return F.vision.warp_affine(x, weightv, (2, 2), border_mode="WRAP") | |||||
| return F.vision.warp_affine(x, weightv, (2, 2), border_mode="wrap") | |||||
| outp = fwd(x, weightv) | outp = fwd(x, weightv) | ||||
| check_pygraph_dump(fwd, [x, weightv], [outp]) | check_pygraph_dump(fwd, [x, weightv], [outp]) | ||||
| @@ -387,7 +387,7 @@ def test_resize(): | |||||
| @trace(symbolic=True, capture_as_const=True) | @trace(symbolic=True, capture_as_const=True) | ||||
| def fwd(x): | def fwd(x): | ||||
| return F.vision.interpolate(x, size=(16, 16), mode="BILINEAR") | |||||
| return F.vision.interpolate(x, size=(16, 16), mode="bilinear") | |||||
| out = fwd(x) | out = fwd(x) | ||||
| check_pygraph_dump(fwd, [x], [out]) | check_pygraph_dump(fwd, [x], [out]) | ||||
| @@ -697,7 +697,7 @@ def test_assert_equal(): | |||||
| def test_elemwise_multitype(): | def test_elemwise_multitype(): | ||||
| op = builtin.ElemwiseMultiType(mode="QADD", dtype=dtype.qint32(2.0)) | |||||
| op = builtin.ElemwiseMultiType(mode="qadd", dtype=dtype.qint32(2.0)) | |||||
| @trace(symbolic=True, capture_as_const=True) | @trace(symbolic=True, capture_as_const=True) | ||||
| def fwd(x, y): | def fwd(x, y): | ||||