Browse Source

[FIX] fix bugs

tags/v0.3.2
bxdd 2 years ago
parent
commit
995db27e86
3 changed files with 6 additions and 5 deletions
  1. +1
    -0
      learnware/config.py
  2. +4
    -4
      learnware/market/heterogeneous/organizer/hetero_map/__init__.py
  3. +1
    -1
      learnware/utils/gpu.py

+ 1
- 0
learnware/config.py View File

@@ -119,6 +119,7 @@ semantic_config = {
_DEFAULT_CONFIG = {
"root_path": ROOT_DIRPATH,
"package_path": PACKAGE_DIRPATH,
"database_path": DATABASE_PATH,
"stdout_path": STDOUT_PATH,
"cache_path": CACHE_PATH,
"logging_level": logging.INFO,


+ 4
- 4
learnware/market/heterogeneous/organizer/hetero_map/__init__.py View File

@@ -80,8 +80,8 @@ class HeteroMap(nn.Module):
"""
super(HeteroMap, self).__init__()

self.cuda_idx = allocate_cuda_idx() if cuda_idx is None else cuda_idx
self.device = choose_device(self.cuda_idx)
cuda_idx = allocate_cuda_idx() if cuda_idx is None else cuda_idx
device = choose_device(cuda_idx)
self.model_args = {
"num_partition": num_partition,
"overlap_ratio": overlap_ratio,
@@ -105,7 +105,7 @@ class HeteroMap(nn.Module):
pad_token_id=feature_tokenizer.pad_token_id,
hidden_dim=hidden_dim,
hidden_dropout_prob=hidden_dropout_prob,
device=self.device,
device=device,
)

self.encoder = TransformerMultiLayer(
@@ -127,7 +127,7 @@ class HeteroMap(nn.Module):
self.base_temperature = base_temperature
self.num_partition = num_partition
self.overlap_ratio = overlap_ratio
self.to(self.device)
self.to(device)

def to(self, device: Union[str, torch.device]):
"""Moves the model and all its submodules to the specified device


+ 1
- 1
learnware/utils/gpu.py View File

@@ -47,7 +47,7 @@ def choose_device(cuda_idx=-1):
return device


def allocate_cuda_idx(self):
def allocate_cuda_idx():
if is_torch_available(verbose=False):
import torch



Loading…
Cancel
Save