|
|
|
@@ -80,8 +80,8 @@ class HeteroMap(nn.Module): |
|
|
|
""" |
|
|
|
super(HeteroMap, self).__init__() |
|
|
|
|
|
|
|
self.cuda_idx = allocate_cuda_idx() if cuda_idx is None else cuda_idx |
|
|
|
self.device = choose_device(self.cuda_idx) |
|
|
|
cuda_idx = allocate_cuda_idx() if cuda_idx is None else cuda_idx |
|
|
|
device = choose_device(cuda_idx) |
|
|
|
self.model_args = { |
|
|
|
"num_partition": num_partition, |
|
|
|
"overlap_ratio": overlap_ratio, |
|
|
|
@@ -105,7 +105,7 @@ class HeteroMap(nn.Module): |
|
|
|
pad_token_id=feature_tokenizer.pad_token_id, |
|
|
|
hidden_dim=hidden_dim, |
|
|
|
hidden_dropout_prob=hidden_dropout_prob, |
|
|
|
device=self.device, |
|
|
|
device=device, |
|
|
|
) |
|
|
|
|
|
|
|
self.encoder = TransformerMultiLayer( |
|
|
|
@@ -127,7 +127,7 @@ class HeteroMap(nn.Module): |
|
|
|
self.base_temperature = base_temperature |
|
|
|
self.num_partition = num_partition |
|
|
|
self.overlap_ratio = overlap_ratio |
|
|
|
self.to(self.device) |
|
|
|
self.to(device) |
|
|
|
|
|
|
|
def to(self, device: Union[str, torch.device]): |
|
|
|
"""Moves the model and all its submodules to the specified device |
|
|
|
|