import torch
from typing import Union


def get_device(device: Union[str, torch.device]) -> torch.device:
    """
    Get device of passed argument.

    Will return a torch.device based on passed arguments.
    Can parse auto, cpu, gpu, cpu:x, gpu:x, etc. If auto is given,
    will automatically find available devices.

    Parameters
    ----------
    device: ``str`` or ``torch.device``
        The device to parse. If ``auto`` is given, will determine
        automatically (``cuda`` when ``torch.cuda.is_available()`` is
        true, ``cpu`` otherwise). ``gpu`` and ``gpu:x`` are accepted as
        aliases for ``cuda`` and ``cuda:x``.

    Returns
    -------
    device: ``torch.device``
        The parsed device.

    Raises
    ------
    AssertionError
        If ``device`` is neither a ``str`` nor a ``torch.device``.
    RuntimeError
        Raised by ``torch.device`` when the string cannot be parsed.
    """
    assert isinstance(
        device, (str, torch.device)
    ), "Only support device of str or torch.device, get {} instead".format(device)

    if isinstance(device, str):
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        elif device == "gpu" or device.startswith("gpu:"):
            # torch.device has no "gpu" device type; translate the documented
            # alias to "cuda" while keeping any ":<index>" suffix intact.
            device = "cuda" + device[len("gpu"):]

    # torch.device() accepts both strings and existing torch.device objects.
    return torch.device(device)