Browse Source

Merge branch 'feature/hetero' of https://github.com/Learnware-LAMDA/Learnware into feature/hetero

tags/v0.3.2
bxdd 2 years ago
parent
commit
a5dcda4b0e
1 changed files with 21 additions and 2 deletions
  1. +21
    -2
      learnware/market/heterogeneous/organizer/hetero_map/__init__.py

+ 21
- 2
learnware/market/heterogeneous/organizer/hetero_map/__init__.py View File

@@ -5,6 +5,7 @@ import pandas as pd
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from loguru import logger

from .....specification import HeteroMapTableSpecification, RKMETableSpecification
from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer
@@ -73,7 +74,7 @@ class HeteroMap(nn.Module):
activation : Union[str, Callable], optional
Activation function for transformer layer, by default "relu"
device : Union[str, torch.device], optional
Device to run the model on, by default "cuda:0"
Device to run the model on, by default "cpu"
kwargs:
Additional arguments to be passed to the feature tokenizer
"""
@@ -124,8 +125,26 @@ class HeteroMap(nn.Module):
self.base_temperature = base_temperature
self.num_partition = num_partition
self.overlap_ratio = overlap_ratio
self.device = device
self.to(device)
def to(self, device: Union[str, torch.device]):
"""Moves the model and all its submodules to the specified device

Parameters
----------
device : Union[str, torch.device]
The target device to which the model and its components should be moved.

Returns
-------
HeteroMap
The instance of HeteroMap after moving to the specified device.
"""
super(HeteroMap, self).to(device)
if hasattr(self, 'feature_processor'):
self.feature_processor.device = device
self.device = device
return self

@staticmethod
def load(checkpoint: str = None):


Loading…
Cancel
Save