- """
- Helpful functions.
- """
-
- import numpy as np
- import torch
- import torch.nn.functional as F
-
-
- def unsqueeze(input, dims):
- """ Implement multi-dimension unsqueeze function. """
- if isinstance(dims, (list, tuple)):
- dims = [
- dim if dim >= 0 else dim + len(input.shape) + 1 for dim in dims
- ]
- dims = sorted(dims, reverse=True)
- shape = list(input.shape)
- for dim in dims:
- shape.insert(dim, 1)
- return torch.reshape(input, shape)
- elif isinstance(dims, int):
- return input.unsqueeze(dims)
- else:
- raise ValueError('Warning: type(dims) must in (list, tuple, int)!')
-
-
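# Illustrative usage (shapes assumed, not taken from the original script):
#   x = torch.zeros(2, 2)
#   unsqueeze(x, dims=[0, 1]).shape  ->  torch.Size([1, 2, 1, 2])
# For a single int dim this simply defers to torch.Tensor.unsqueeze.

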
def gumbel_softmax(input, tau=1, eps=1e-10):
    """ Basic implementation of gumbel_softmax (soft samples only). """
    # Sample Gumbel(0, 1) noise: -log(-log(U)), with eps for numerical
    # stability; `input` is expected to be floating-point logits.
    U = torch.rand_like(input)
    gumbel = -torch.log(-torch.log(U + eps) + eps)
    y = input + gumbel
    return F.softmax(y / tau, dim=-1)


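# torch.nn.functional.gumbel_softmax provides the same sampling (plus a `hard`
# straight-through option); the version above is kept as a simple reference.
# Quick sanity check (illustrative):
#   probs = gumbel_softmax(torch.randn(4, 5))
#   probs.sum(dim=-1)  # each row sums to ~1.0

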
def equal(x, y, dtype=None):
    """ Implement equal in dygraph mode. (paddle) """
    if dtype is None:
        dtype = 'float32'
    # Compare on the host with numpy, then wrap the 0/1 result as a tensor.
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    if isinstance(y, torch.Tensor):
        y = y.detach().cpu().numpy()
    out = np.equal(x, y).astype(dtype)
    return torch.tensor(out)


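# Illustrative: equal(torch.tensor([1, 2]), 1) -> tensor([1., 0.]) (float32 by
# default), unlike torch.equal, which compares whole tensors and returns a bool.

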
def not_equal(x, y, dtype=None):
    """ Implement not_equal in dygraph mode. (paddle) """
    return 1 - equal(x, y, dtype)


if __name__ == '__main__':
    a = torch.tensor([[1, 1], [3, 4]])
    b = torch.tensor([[1, 1], [3, 4]])
    c = torch.equal(a, b)
    c1 = equal(a, 3)
    d = 1 - torch.not_equal(a, 3).float()
    print(c)
    print(c1)
    print(d)
    # gumbel_softmax expects floating-point logits.
    e = F.gumbel_softmax(a.float())
    f = a.unsqueeze(0)
    g = unsqueeze(a, dims=[0, 0, 1])  # shape: [1, 1, 2, 1, 2]
    print(g, g.shape)
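    # Additional illustrative checks (not part of the original script):
    print(not_equal(a, 3))  # complement of equal(a, 3)
    probs = gumbel_softmax(torch.randn(2, 4))
    print(probs, probs.sum(dim=-1))  # each row sums to ~1.0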