|
- import torch.nn.functional
- import typing as _typing
-
-
def activation_func(
    tensor: torch.Tensor, function_name: _typing.Optional[str]
) -> torch.Tensor:
    """Apply the activation named by *function_name* to *tensor*.

    Args:
        tensor: Input tensor.
        function_name: Name of the activation to apply.
            - A non-``str`` value (e.g. ``None``) or ``'linear'`` means
              "no activation"; *tensor* itself is returned unchanged.
            - ``'tanh'`` and ``'sigmoid'`` dispatch to ``torch.tanh`` and
              ``torch.sigmoid``.
            - Any other name is looked up in ``torch.nn.functional``. It is
              called with *tensor* as its only argument.

    Returns:
        The activated tensor, or *tensor* itself for the identity cases.

    Raises:
        TypeError: If *function_name* does not name a callable attribute of
            ``torch.nn.functional``.
    """
    if not isinstance(function_name, str) or function_name == 'linear':
        return tensor
    # The torch.nn.functional versions of tanh/sigmoid were deprecated in
    # favour of the top-level torch functions (some PyTorch releases emit
    # deprecation warnings for them), so route these two directly.
    if function_name == 'tanh':
        return torch.tanh(tensor)
    if function_name == 'sigmoid':
        return torch.sigmoid(tensor)
    func = getattr(torch.nn.functional, function_name, None)
    # Reject non-callable module attributes (e.g. '__name__', 'torch') with
    # the intended message instead of an opaque "object is not callable".
    if not callable(func):
        raise TypeError(f"PyTorch does not support activation function {function_name}")
    return func(tensor)
|