|
- import torch
- from typing import Union
-
def get_device(device: Union[str, torch.device]) -> torch.device:
    """
    Get device of passed argument. Will return a torch.device based on passed arguments.

    Can parse ``auto``, ``cpu``, ``gpu``, ``cpu:x``, ``gpu:x``, ``cuda:x``, or any other
    string accepted by ``torch.device``. ``gpu`` is accepted as an alias of ``cuda``.
    If ``auto`` is given, CUDA is chosen when available, otherwise CPU.

    Parameters
    ----------
    device: ``str`` or ``torch.device``
        The device to parse. If ``auto`` is given, will determine automatically.

    Returns
    -------
    device: ``torch.device``
        The parsed device.

    Raises
    ------
    TypeError
        If ``device`` is neither a ``str`` nor a ``torch.device``.
    RuntimeError
        Propagated from ``torch.device`` when the string is not a valid device spec.
    """
    # torch.device objects are immutable, so an already-parsed device can be returned as-is.
    if isinstance(device, torch.device):
        return device
    # Explicit check instead of `assert`, which is stripped when running with `python -O`.
    if not isinstance(device, str):
        raise TypeError(
            "Only support device of str or torch.device, got {} instead".format(device)
        )
    if device == 'auto':
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # torch has no "gpu" device type; map the documented "gpu" / "gpu:x" aliases to CUDA.
    if device == 'gpu' or device.startswith('gpu:'):
        device = 'cuda' + device[len('gpu'):]
    return torch.device(device)
|