@@ -9,17 +9,7 @@
 # pylint: disable=redefined-builtin
 from .elemwise import *
 from .graph import add_update
-from .loss import (
-    binary_cross_entropy,
-    cross_entropy,
-    cross_entropy_with_softmax,
-    hinge_loss,
-    l1_loss,
-    nll_loss,
-    smooth_l1_loss,
-    square_loss,
-    triplet_margin_loss,
-)
+from .loss import *
 from .math import *
 from .nn import *
 from .quantized import conv_bias_activation
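
Reviewer note: with the explicit import list gone, what this package re-exports is governed by the `__all__` that a later hunk adds to loss.py; a star-import binds only the names listed there. A minimal sketch of the mechanism, with hypothetical module names:

    # pkg/loss.py
    __all__ = ["l1_loss"]

    def l1_loss(pred, label):
        return abs(pred - label).mean()

    def nll_loss(pred, label):  # defined, but deliberately absent from __all__
        raise NotImplementedError

    # pkg/__init__.py
    from .loss import *  # binds l1_loss only; nll_loss stays hidden
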
@@ -25,10 +25,6 @@ __all__ = [
     "asinh",
     "acosh",
     "atanh",
-    "bitwise_and",  # TODO
-    "bitwise_not",  # TODO
-    "bitwise_or",  # TODO
-    "bitwise_xor",  # TODO
     "ceil",
     "clamp",
     "cos",
@@ -339,22 +335,6 @@ def right_shift(x, y):
     return _elwise(x, y, mode="shr")
-def bitwise_and(x, y):
-    raise NotImplementedError
-def bitwise_not(x):
-    raise NotImplementedError
-def bitwise_or(x, y):
-    raise NotImplementedError
-def bitwise_xor(x, y):
-    raise NotImplementedError
 # logical functions
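
A quick sanity check on the shift modes, sketched against the patched module (this assumes int32 inputs and that both shift functions are star-exported from this file):

    import megengine.functional as F
    from megengine import tensor

    x = tensor([1, 2, 4])
    print(F.left_shift(x, 1).numpy())   # [2 4 8]  -- x * 2**1
    print(F.right_shift(x, 1).numpy())  # [0 1 2]  -- x // 2**1
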
@@ -15,6 +15,14 @@ from .nn import assert_equal, indexing_one_hot
 from .tensor import where
 from .utils import zero_grad
+__all__ = [
+    "l1_loss",
+    "square_loss",
+    "cross_entropy_with_softmax",
+    "binary_cross_entropy",
+    "hinge_loss",
+]
 def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
     r"""
@@ -93,59 +101,6 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor:
     return (diff ** 2).mean()
-def cross_entropy(
-    inp: Tensor, target: Tensor, axis: int = 1, ignore_index: int = -1
-) -> Tensor:
-    r"""
-    Returns the cross entropy loss in a classification problem.
-    .. math:: \textrm{CrossEntropy}(x, y) = -\sum_{i} y_i \log(x_i)
-    :param inp: The input tensor representing the predicted probability.
-    :param target: The input tensor representing the classification label.
-    :param axis: An axis along which cross_entropy will be applied. Default: 1
-    :param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. Default: -1
-    Examples:
-    .. testcode::
-        import numpy as np
-        from megengine import tensor
-        import megengine.functional as F
-        data_shape = (1, 2)
-        label_shape = (1, )
-        pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(data_shape))
-        label = tensor(np.ones(label_shape, dtype=np.int32))
-        loss = F.cross_entropy(pred, label)
-        print(loss.numpy())
-    Outputs:
-    .. testoutput::
-        [0.69]
-    """
-    raise NotImplementedError
-    # n0 = inp.ndim
-    # n1 = target.ndim
-    # assert n0 == n1 + 1, (
-    #     "target ndim must be one less than input ndim; input_ndim={} "
-    #     "target_ndim={}".format(n0, n1)
-    # )
-    # if ignore_index != -1:
-    #     mask = 1 - equal(target, ignore_index)
-    #     target = target * mask
-    #     loss = -log(indexing_one_hot(inp, target, axis)) * mask
-    #     return loss.sum() / maximum(mask.sum(), 1.0)
-    # else:
-    #     return -log(indexing_one_hot(inp, target, axis)).mean()
 def cross_entropy_with_softmax(
     pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0
 ) -> Tensor:
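
For the record, the removed doctest's expected value checks out by hand: the label selects index 1, where the predicted probability is 0.5, so the loss is -log(0.5) = ln 2 ≈ 0.6931, consistent with the doctest's abbreviated [0.69].
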
@@ -189,49 +144,6 @@ def cross_entropy_with_softmax(
     return (log(down) - up).mean()
-def triplet_margin_loss(
-    anchor: Tensor, positive: Tensor, negative: Tensor, margin: float = 1.0, p: int = 2
-) -> Tensor:
-    r"""
-    Creates a criterion that measures the triplet loss given input tensors.
-    .. math::
-        L(a, p, n) = \max\left\{d\left(a_{i}, p_{i}\right) - d\left(a_{i}, n_{i}\right) + \text{margin}, 0\right\}, \quad
-        d\left(x_{i}, y_{i}\right) = \left\|x_{i} - y_{i}\right\|_{p}
-    :param anchor: The input tensor representing the anchor samples.
-    :param positive: The input tensor representing the positive samples.
-    :param negative: The input tensor representing the negative samples.
-    :param margin: The margin value. Default: 1.0
-    :param p: The norm degree for pairwise distance. Default: 2
-    """
-    s0 = anchor.shapeof()
-    s1 = positive.shapeof()
-    s2 = negative.shapeof()
-    assert_equal(s0, s1)
-    assert_equal(s1, s2)
-    n0 = anchor.ndim
-    n1 = positive.ndim
-    n2 = negative.ndim
-    assert n0 == 2 and n1 == 2 and n2 == 2, (
-        "anchor ndim, positive ndim, and negative ndim must be 2; "
-        "anchor_ndim={} positive_ndim={} negative_ndim={}".format(n0, n1, n2)
-    )
-    assert p > 0, "norm degree p must be greater than 0; p={}".format(p)
-    diff0 = abs(anchor - positive)
-    diff1 = abs(anchor - negative)
-    d1 = power(power(diff0, p).sum(axis=1, keepdims=True), 1 / p)
-    d2 = power(power(diff1, p).sum(axis=1, keepdims=True), 1 / p)
-    loss = maximum(d1 - d2 + margin, 0)
-    return loss.mean()
 def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
     r"""Function that measures the Binary Cross Entropy between the target and the prediction.
@@ -244,59 +156,6 @@ def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
     return -1.0 * (label * log(pred) + (1.0 - label) * log(1 - pred)).mean()
-def nll_loss(
-    pred: Tensor, label: Tensor, axis: int = 1, ignore_index: int = -1
-) -> Tensor:
-    r"""
-    The negative log likelihood loss.
-    :param pred: The predicted result from the model.
-    :param label: The ground truth to compare.
-    Examples:
-    .. testcode::
-        import numpy as np
-        from megengine import tensor
-        import megengine.functional as F
-        data_shape = (2, 2)
-        label_shape = (2, )
-        data = tensor(
-            np.array([[1, 0.5], [0.3, 1.2]], dtype=np.float32).reshape(data_shape),
-        )
-        label = tensor(
-            np.ones(label_shape, dtype=np.int32)
-        )
-        pred = F.log(F.softmax(data))
-        loss1 = F.nll_loss(pred, label)
-        loss2 = F.cross_entropy_with_softmax(data, label)
-        print(loss1.numpy(), loss2.numpy())
-    Outputs:
-    .. testoutput::
-        [0.6576154] [0.6576154]
-    """
-    raise NotImplementedError
-    # n0 = pred.ndim
-    # n1 = label.ndim
-    # assert n0 == n1 + 1, (
-    #     "target ndim must be one less than input ndim; input_ndim={} "
-    #     "target_ndim={}".format(n0, n1)
-    # )
-    # mask = 1.0 - equal(label, ignore_index)
-    # label = label * mask
-    # loss = indexing_one_hot(pred, label, axis) * mask
-    # return -1.0 * loss.sum() / maximum(mask.sum(), 1.0)
 def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
     r"""
     Calculate the hinge loss which is often used in SVMs.
@@ -337,53 +196,3 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
         return loss.sum(axis=1).mean()
     else:
         return (loss ** 2).sum(axis=1).mean()
-def smooth_l1_loss(pred: Tensor, label: Tensor) -> Tensor:
-    r"""
-    Calculate the smooth l1 loss proposed in `Fast R-CNN paper by Ross Girshick`.
-    The smooth l1 loss can be described as:
-    .. math::
-        \text{loss}(x, y) = \frac{1}{n} \sum_{i} l_{i}
-    where :math:`l_{i}` is given by:
-    .. math::
-        l_{i} =
-        \begin{cases}
-            0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\
-            |x_i - y_i| - 0.5, & \text{otherwise}
-        \end{cases}
-    :param pred: The predicted result from the model.
-    :param label: The ground truth to compare.
-    Examples:
-    .. testcode::
-        from megengine import tensor
-        import megengine.functional as F
-        pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
-        label = tensor([[0.4, 1.5, 1.2], [0., 0.1, 2.2]])
-        loss = F.smooth_l1_loss(pred, label)
-        print(loss.numpy())
-    Outputs:
-    .. testoutput::
-        [0.5608334]
-    """
-    raise NotImplementedError
-    # diff = abs(pred - label)
-    # l2_loss = 0.5 * (diff ** 2)
-    # l1_loss = diff - 0.5
-    # mask = diff < 1
-    # loss = where(mask, l2_loss, l1_loss)
-    # return loss.mean()
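
Likewise, a numpy stand-in for the removed smooth_l1_loss, mirroring the commented-out body; on the doctest's inputs above it reproduces the expected 0.5608334 (a sketch, not MegEngine API):

    import numpy as np

    def smooth_l1_ref(pred, label):
        diff = np.abs(pred - label)
        # quadratic below the |diff| < 1 knee, linear above it
        return np.where(diff < 1, 0.5 * diff ** 2, diff - 0.5).mean()
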
@@ -21,47 +21,26 @@ from .elemwise import clamp, exp, log, log1p
 from .tensor import remove_axis, reshape
 __all__ = [
-    "all",  # TODO
-    "all_close",  # TODO
-    "any",  # TODO
     "argmax",
     "argmin",
     "argsort",
     "isinf",
-    "isnan",  # TODO
+    "isnan",
     "max",
     "mean",
-    "median",  # TODO
     "min",
     "norm",
     "normalize",
     "prod",
-    "sign",  # TODO
+    "sign",
     "sort",
     "std",
     "sum",
     "topk",
-    "unique",  # TODO
     "var",
 ]
-def all(inp):
-    raise NotImplementedError
-def all_close(inp):
-    raise NotImplementedError
-def any(inp):
-    raise NotImplementedError
-def unique(inp):
-    raise NotImplementedError
 def isnan(inp: Tensor) -> Tensor:
     r"""Returns a new tensor representing if each element is NaN or not.
@@ -77,15 +56,14 @@ def isnan(inp: Tensor) -> Tensor:
         x = tensor([1, float("nan"), 0])
-        print(F.isnan(x))
+        print(F.isnan(x).numpy())
     .. testoutput::
-        Tensor([0 1 0], dtype=uint8)
+        [False  True False]
     """
-    raise NotImplementedError
-    # return (inp != inp).astype("uint8")
+    return inp != inp
 def isinf(inp: Tensor) -> Tensor:
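
The new isnan body leans on the IEEE-754 rule that NaN is the only value unequal to itself, which is why a plain comparison now yields the boolean mask directly (numpy check, a sketch):

    import numpy as np

    x = np.array([1.0, np.nan, 0.0])
    print(x != x)        # [False  True False] -- True exactly where x is NaN
    print(np.isnan(x))   # same mask
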
@@ -103,18 +81,39 @@ def isinf(inp: Tensor) -> Tensor:
         x = tensor([1, float("inf"), 0])
-        print(F.isinf(x))
+        print(F.isinf(x).numpy())
     .. testoutput::
-        Tensor([0 1 0], dtype=uint8)
+        [False  True False]
     """
-    return (abs(inp).astype("float32") == float("inf")).astype("uint8")
+    return abs(inp).astype("float32") == float("inf")
 def sign(inp: Tensor):
-    raise NotImplementedError
+    r"""Returns the sign of each element in the input tensor.
+    :param inp: The input tensor.
+    :return: The sign tensor.
+    Examples:
+    .. testcode::
+        from megengine import tensor
+        import megengine.functional as F
+        x = tensor([1, -1, 0])
+        print(F.sign(x).numpy())
+    .. testoutput::
+        [ 1 -1  0]
+    """
+    return (inp > 0).astype(inp.dtype) - (inp < 0).astype(inp.dtype)
 def sum(
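
The sign body is the usual branch-free comparison trick, and casting both comparisons back to inp.dtype keeps the result in the input's dtype rather than bool. Equivalence check in numpy (a sketch):

    import numpy as np

    x = np.array([3, -2, 0])
    # (x > 0) - (x < 0): 1 for positive, -1 for negative, 0 for zero
    sign = (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype)
    assert (sign == np.sign(x)).all()
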
@@ -623,7 +623,7 @@ def batch_norm2d(
         Default: True
     """
-    from .tensor import expand_dims, squeeze, broadcast
+    from .tensor import add_axis, remove_axis, broadcast
     def full(value):
         C = data.shape[1]
@@ -633,7 +633,7 @@
     def expand_or_full(x, value):
         if x is None:
             return full(value)
-        return expand_dims(x, [0, 2, 3])
+        return add_axis(x, [0, 2, 3])
     def make_full_if_none(x, value):
         if x is None:
@@ -1229,14 +1229,14 @@
     return ret
-def dropout(inp: Tensor, drop_prob: float, rescale: bool = True) -> Tensor:
+def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
     """
     Returns a new tensor where each element is randomly set to zero
     with probability P = ``drop_prob``. Optionally rescale the output tensor.
     :param inp: The input tensor
    :param drop_prob: The probability to drop (set to zero) a single element
-    :param rescale: The default behavior of ``dropout`` during training is to rescale the output,
+    :param training: The default behavior of ``dropout`` during training is to rescale the output,
         so that it can be replaced by an :class:`~.Identity` during inference. Default: True
     :return: The output tensor
@@ -1266,7 +1266,7 @@ def dropout(inp: Tensor, drop_prob: float, rescale: bool = True) -> Tensor:
     rv = uniform(inp.shape)
     mask = rv > drop_prob
     inp *= mask.astype(inp.dtype)
-    if rescale:
+    if training:
         inp *= 1 / (1 - drop_prob)
     return inp
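
The rescale guarded by the renamed flag is standard inverted dropout: dividing the surviving elements by 1 - drop_prob keeps the expected activation unchanged, which is what lets inference replace the op with an identity. A numpy check of the expectation (a sketch):

    import numpy as np

    p = 0.4
    x = np.ones(1_000_000)
    mask = np.random.rand(x.size) > p
    y = x * mask / (1 - p)  # training-mode inverted dropout
    print(y.mean())         # ~1.0: expectation matches the un-dropped input
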
@@ -14,6 +14,7 @@ from typing import Iterable, List, Optional, Sequence, Tuple, Union
 import numpy as np
 from ..core._imperative_rt import CompNode
+from ..core._wrap import device as as_device
 from ..core.ops import builtin
 from ..core.ops._internal import param_defs as P
 from ..core.ops.special import Const
@@ -30,31 +31,32 @@ from ..tensor import Tensor
 from .elemwise import ceil
 __all__ = [
-    "add_axis",  # expand_dims
+    "add_axis",
     "arange",
     "broadcast",
     "concat",
     "cond_take",
-    "dimshuffle",  # transpose, permute
+    "dimshuffle",
     "expand_dims",
+    "eye",
     "full",
     "full_like",
     "gather",
-    "eye",
     "linspace",
     "ones",
     "ones_like",
-    "remove_axis",  # squeeze
+    "param_pack_concat",
+    "param_pack_split",
+    "reshape",
+    "remove_axis",
     "split",
     "squeeze",
     "stack",
-    "reshape",
     "scatter",
+    "transpose",
     "where",
     "zeros",
     "zeros_like",
-    "param_pack_split",
-    "param_pack_concat",
 ]
@@ -97,6 +99,8 @@ def eye(n: int, *, dtype=None, device: Optional[CompNode] = None) -> Tensor:
 def full(shape, value, dtype="float32", device=None):
+    if isinstance(shape, int):
+        shape = (shape,)
     if device is None:
         device = get_default_device()
     (x,) = Const(value, dtype=dtype, device=device)(
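
With the new guard, a bare int is accepted as a 1-D shape, matching the numpy convention (usage sketch, assuming the patched signature):

    import megengine.functional as F

    print(F.full(3, 2.0).numpy())       # [2. 2. 2.]
    print(F.full((2, 3), 2.0).numpy())  # a 2x3 array of 2.0
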
@@ -196,16 +200,13 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
     return result
-def concat(
-    inps: Iterable[Tensor], axis: int = 0, device: Optional[CompNode] = None,
-) -> Tensor:
+def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
     r"""
     Concat some tensors
     :param inps: Input tensors to concat
     :param axis: The dimension over which the tensors are concatenated. Default: 0
     :param device: The comp node where the output resides. Default: None
-    :param comp_graph: The graph in which output is. Default: None
     :return: The output tensor
     Examples:
@@ -235,7 +236,9 @@
         return inps[0]
     dtype = dtype_promotion(inps)
-    device = get_device(inps)
+    if device is None:
+        device = get_device(inps)
+    device = as_device(device)
     def convert(x):
         return convert_single_value(x, inps, dtype=dtype)
@@ -245,12 +248,13 @@
     return result
-def stack(inps, axis=0):
+def stack(inps, axis=0, device=None):
     """Concats a sequence of tensors along a new axis.
     The input tensors must have the same shape.
     :param inps: The input tensors.
     :param axis: The axis along which the inputs are stacked.
+    :param device: The comp node where the output resides. Default: None
     :return: The output concatenated tensor.
     Examples:
@@ -283,7 +287,7 @@ def stack(inps, axis=0):
         raise ValueError("All input tensors must have the same shape")
     inps = [add_axis(inp, axis=axis) for inp in inps]
-    return concat(inps, axis=axis)
+    return concat(inps, axis=axis, device=device)
 def split(inp, nsplits_or_sections, axis=0):
@@ -609,7 +613,10 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
 def cond_take(mask: Tensor, x: Tensor) -> Tensor:
     r"""
-    Take elements from data if specific condition is satisfied on mask. This operator has two outputs: the first is the elements taken, and the second is the indices corresponding to those elements; they are both 1-dimensional. High-dimension input would first be flattened.
+    Take elements from data if specific condition is satisfied on mask.
+    This operator has two outputs: the first is the elements taken,
+    and the second is the indices corresponding to those elements;
+    they are both 1-dimensional. High-dimensional input would first be flattened.
     :param mask: condition param; must have the same shape as data
     :param x: input tensor from which to take elements
@@ -692,6 +699,9 @@ def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor:
     return result
+transpose = dimshuffle
 def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
     r"""
     Reshape a tensor to given target shape; total number of logical elements must
@@ -748,9 +758,6 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
     return x
-transpose = dimshuffle
 AxisAddRemove = builtin.AxisAddRemove
 AxisDesc = AxisAddRemove.AxisDesc
@@ -803,12 +810,14 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
 expand_dims = add_axis
-def remove_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
+def remove_axis(
+    inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None
+) -> Tensor:
     r"""
     Remove dimensions of shape 1.
     :param inp: Input tensor
-    :param axis: Place of axis to be removed
+    :param axis: Place of axis to be removed; if None, all axes of shape 1 will be removed. Default: None
     :return: The output tensor
     Examples:
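
Usage sketch for the new default (assuming the patched behavior):

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.zeros((1, 3, 1, 4), dtype="float32"))
    print(F.remove_axis(x).shape)          # (3, 4): every axis of shape 1 dropped
    print(F.remove_axis(x, axis=0).shape)  # (3, 1, 4): only axis 0 dropped
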
@@ -897,8 +906,8 @@ def linspace(
 def arange(
-    start: Union[int, float, Tensor],
-    end: Union[int, float, Tensor],
+    start: Union[int, float, Tensor] = 0,
+    end: Optional[Union[int, float, Tensor]] = None,
     step: Union[int, float, Tensor] = 1,
     dtype="float32",
     device: Optional[CompNode] = None,
@@ -919,7 +928,7 @@
         import numpy as np
         import megengine.functional as F
-        a = F.arange(1, 5, 1)
+        a = F.arange(5)
         print(a.numpy())
     .. testoutput::
@@ -927,6 +936,9 @@
-        [1. 2. 3. 4.]
+        [0. 1. 2. 3. 4.]
     """
+    if end is None:
+        start, end = 0, start
     if isinstance(start, Tensor):
         start = start.astype("float32")
     if isinstance(end, Tensor):
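
The `end is None` shuffle gives arange its numpy-style call forms (usage sketch of the patched function):

    import megengine.functional as F

    print(F.arange(5).numpy())        # [0. 1. 2. 3. 4.] -- start defaults to 0
    print(F.arange(1, 5).numpy())     # [1. 2. 3. 4.]
    print(F.arange(1, 5, 2).numpy())  # [1. 3.]
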
@@ -1,2 +1,10 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .sublinear_memory_config import SublinearMemoryConfig
 from .tracing import exclude_from_trace, trace
@@ -6,7 +6,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
 from ..device import get_device_count
@@ -1,3 +1,11 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
 import contextlib
 import functools
@@ -13,7 +13,7 @@ from typing import Optional, Tuple, Union
 import numpy as np
 from ..functional import full
-from ..random import gaussian, uniform
+from ..random import normal, uniform
 from ..tensor import Tensor
@@ -50,7 +50,7 @@ def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None:
     :param a: Lower bound of the sampling interval
     :param b: Upper bound of the sampling interval
     """
-    tensor._reset(uniform(tensor.shape, low=a, high=b).astype(tensor.dtype))
+    tensor._reset(uniform(size=tensor.shape, low=a, high=b).astype(tensor.dtype))
 def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
@@ -61,7 +61,7 @@ def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
     :param mean: The mean of the normal distribution
     :param std: The standard deviation of the normal distribution
     """
-    tensor._reset(gaussian(tensor.shape, mean=mean, std=std).astype(tensor.dtype))
+    tensor._reset(normal(size=tensor.shape, mean=mean, std=std).astype(tensor.dtype))
 def calculate_gain(
@@ -6,8 +6,8 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from .distribution import gaussian, uniform
-from .rng import manual_seed
+from .distribution import normal, uniform
+from .rng import seed
 # pylint: disable=undefined-variable
 del distribution, rng  # type: ignore[name-defined]
@@ -15,13 +15,15 @@ from ..core.tensor import utils
 from ..core.tensor.core import apply
 from .rng import _random_seed_generator
-__all__ = ["gaussian", "uniform"]
+__all__ = ["normal", "uniform"]
-def gaussian(shape: Iterable[int], mean: float = 0, std: float = 1,) -> Tensor:
+def normal(
+    mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
+) -> Tensor:
     r"""Random variable with Gaussian distribution $N(\mu, \sigma)$
-    :param shape: Output tensor shape
+    :param size: Output tensor size
     :param mean: The mean or expectation of the distribution
     :param std: The standard deviation of the distribution (variance = $\sigma ^ 2$)
     :return: The output tensor
@@ -33,7 +35,7 @@ def gaussian(shape: Iterable[int], mean: float = 0, std: float = 1,) -> Tensor:
         import megengine as mge
         import megengine.random as rand
-        x = rand.gaussian((2, 2), mean=0, std=1)
+        x = rand.normal(mean=0, std=1, size=(2, 2))
         print(x.numpy())
     .. testoutput::
@@ -43,17 +45,21 @@ def gaussian(shape: Iterable[int], mean: float = 0, std: float = 1,) -> Tensor:
          [-1.4939808 -1.5824696 ]]
     """
+    if size is None:
+        size = (1,)
     seed = _random_seed_generator().__next__()
     op = GaussianRNG(seed=seed, mean=mean, std=std)
-    shape = Tensor(shape, dtype="int32")
-    (output,) = apply(op, shape)
+    size = Tensor(size, dtype="int32")
+    (output,) = apply(op, size)
     return output
-def uniform(shape: Iterable[int], low: float = 0, high: float = 1,) -> Tensor:
+def uniform(
+    low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
+) -> Tensor:
     r"""Random variable with uniform distribution $U(0, 1)$
-    :param shape: Output tensor shape
+    :param size: Output tensor size
     :param low: Lower range
     :param high: Upper range
     :return: The output tensor
@@ -65,7 +71,7 @@ def uniform(shape: Iterable[int], low: float = 0, high: float = 1,) -> Tensor:
         import megengine as mge
         import megengine.random as rand
-        x = rand.uniform((2, 2))
+        x = rand.uniform(size=(2, 2))
         print(x.numpy())
     .. testoutput::
@@ -77,9 +83,11 @@
     """
     assert low < high, "Uniform is not defined when low >= high"
+    if size is None:
+        size = (1,)
     seed = _random_seed_generator().__next__()
     op = UniformRNG(seed=seed)
-    shape = Tensor(shape, dtype="int32")
-    (output,) = apply(op, shape)
+    size = Tensor(size, dtype="int32")
+    (output,) = apply(op, size)
     return low + (high - low) * output
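
After the rename, both samplers take the output shape as a numpy-style `size` keyword (usage sketch of the patched API):

    import megengine.random as rand

    x = rand.normal(mean=0.0, std=1.0, size=(2, 2))   # was rand.gaussian((2, 2), ...)
    y = rand.uniform(low=0.0, high=1.0, size=(2, 2))  # was rand.uniform((2, 2), ...)
    z = rand.uniform()                                # size=None falls back to shape (1,)
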
@@ -17,11 +17,11 @@ def _random_seed_generator():
     if _rng is None:
         from ..distributed.group import get_rank
-        manual_seed(seed=int(time.time()) + get_rank())
+        seed(seed=int(time.time()) + get_rank())
     while True:
         yield _rng.random_raw()
-def manual_seed(seed: int):
+def seed(seed: int):
     global _rng  # pylint: disable=global-statement
     _rng = MT19937(seed=seed)
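
And reseeding goes through the renamed entry point; since the generator is a seeded MT19937, the same seed should replay the same stream (usage sketch):

    import megengine.random as rand

    rand.seed(42)                # was rand.manual_seed(42)
    a = rand.uniform(size=(2,))
    rand.seed(42)
    b = rand.uniform(size=(2,))  # should match a elementwise
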
@@ -55,14 +55,20 @@ def test_clamp():
     assertTensorClose(F.clamp(tensor(x) - 3, -6, 0).numpy(), np.clip(x - 3, -6, 0))
-# def test_isnan():
-#     for case in [[1, float("nan"), 0]]:
-#         assertTensorClose(F.isnan(tensor(case)), np.isnan(case).astype("uint8"))
+def test_isnan():
+    for case in [[1, float("nan"), 0]]:
+        assertTensorClose(F.isnan(tensor(case)).numpy(), np.isnan(case))
 def test_isinf():
     for case in [[1, float("inf"), 0]]:
-        assertTensorClose(F.isinf(tensor(case)).numpy(), np.isinf(case).astype("uint8"))
+        assertTensorClose(F.isinf(tensor(case)).numpy(), np.isinf(case))
+def test_sign():
+    for case in [[1, -1, 0]]:
+        x = tensor(case)
+        assertTensorClose(F.sign(x).numpy(), np.sign(case).astype(x.dtype))
 def test_cosh():
@@ -110,6 +110,14 @@ def test_concat():
     opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]))
+def test_concat_device():
+    data1 = tensor(np.random.random((3, 2, 2)).astype("float32"), device="cpu0")
+    data2 = tensor(np.random.random((2, 2, 2)).astype("float32"), device="cpu1")
+    out = F.concat([data1, data2], device="cpu0")
+    assert str(out.device).split(":")[0] == "cpu0"
 def test_stack():
     data1 = np.random.random((3, 2, 2)).astype("float32")
     data2 = np.random.random((3, 2, 2)).astype("float32")