|
|
|
@@ -3,16 +3,17 @@ import typing as _typing |
|
|
|
|
|
|
|
import torch |
|
|
|
import pickle |
|
|
|
|
|
|
|
from autogl.module.model.encoders.base_encoder import AutoHomogeneousEncoderMaintainer |
|
|
|
|
|
|
|
from ..model import ( |
|
|
|
EncoderRegistry, |
|
|
|
DecoderRegistry, |
|
|
|
BaseAutoModel, |
|
|
|
BaseAutoEncoder, |
|
|
|
BaseAutoDecoder, |
|
|
|
AutoClassifierDecoder, |
|
|
|
AutoHomogeneousEncoder |
|
|
|
BaseAutoEncoderMaintainer, |
|
|
|
BaseAutoDecoderMaintainer, |
|
|
|
BaseAutoModel |
|
|
|
) |
|
|
|
from autogl.utils.autobase import AutoModule |
|
|
|
from ..hpo import AutoModule |
|
|
|
import logging |
|
|
|
from .evaluation import Evaluation, get_feval, Acc |
|
|
|
from ...utils import get_logger |
|
|
|
@@ -94,8 +95,8 @@ class EarlyStopping: |
|
|
|
class BaseTrainer(AutoModule): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
encoder: _typing.Union[BaseAutoModel, BaseAutoEncoder, None], |
|
|
|
decoder: _typing.Union[BaseAutoDecoder, None], |
|
|
|
encoder: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, None], |
|
|
|
decoder: _typing.Union[BaseAutoDecoderMaintainer, None], |
|
|
|
device: _typing.Union[torch.device, str], |
|
|
|
init: bool = True, |
|
|
|
feval: _typing.Union[ |
|
|
|
@@ -116,39 +117,11 @@ class BaseTrainer(AutoModule): |
|
|
|
init: `bool` |
|
|
|
If True(False), the model will (not) be initialized. |
|
|
|
""" |
|
|
|
super().__init__() |
|
|
|
self.init: bool = init |
|
|
|
self.encoder = encoder |
|
|
|
self.decoder = None if isinstance(encoder, BaseAutoModel) else decoder |
|
|
|
if type(device) == torch.device or ( |
|
|
|
type(device) == str and device.lower() != "auto" |
|
|
|
): |
|
|
|
self.__device: torch.device = torch.device(device) |
|
|
|
else: |
|
|
|
self.__device: torch.device = torch.device( |
|
|
|
"cuda" |
|
|
|
if torch.cuda.is_available() and torch.cuda.device_count() > 0 |
|
|
|
else "cpu" |
|
|
|
) |
|
|
|
self.__feval: _typing.Sequence[_typing.Type[Evaluation]] = get_feval(feval) |
|
|
|
self.loss: str = loss |
|
|
|
|
|
|
|
@property |
|
|
|
def device(self) -> torch.device: |
|
|
|
return self.__device |
|
|
|
|
|
|
|
@device.setter |
|
|
|
def device(self, __device: _typing.Union[torch.device, str]): |
|
|
|
if type(__device) == torch.device or ( |
|
|
|
type(__device) == str and __device.lower() != "auto" |
|
|
|
): |
|
|
|
self.__device: torch.device = torch.device(__device) |
|
|
|
else: |
|
|
|
self.__device: torch.device = torch.device( |
|
|
|
"cuda" |
|
|
|
if torch.cuda.is_available() and torch.cuda.device_count() > 0 |
|
|
|
else "cpu" |
|
|
|
) |
|
|
|
self.feval = feval |
|
|
|
self.loss = loss |
|
|
|
super().__init__(init, device) |
|
|
|
|
|
|
|
@property |
|
|
|
def feval(self) -> _typing.Sequence[_typing.Type[Evaluation]]: |
|
|
|
@@ -174,13 +147,9 @@ class BaseTrainer(AutoModule): |
|
|
|
""" |
|
|
|
self.device = device |
|
|
|
if self.encoder is not None: |
|
|
|
self.encoder.to(self.device) |
|
|
|
self.encoder.to_device(self.device) |
|
|
|
if self.decoder is not None: |
|
|
|
self.decoder.to(self.device) |
|
|
|
|
|
|
|
def initialize(self): |
|
|
|
"""Initialize the auto model in trainer.""" |
|
|
|
pass |
|
|
|
self.decoder.to_device(self.device) |
|
|
|
|
|
|
|
def get_feval( |
|
|
|
self, return_major: bool = False |
|
|
|
@@ -335,8 +304,8 @@ class _BaseClassificationTrainer(BaseTrainer): |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
encoder: _typing.Union[BaseAutoModel, BaseAutoEncoder, str, None], |
|
|
|
decoder: _typing.Union[BaseAutoDecoder, str, None], |
|
|
|
encoder: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, str, None], |
|
|
|
decoder: _typing.Union[BaseAutoDecoderMaintainer, str, None], |
|
|
|
num_features: int, |
|
|
|
num_classes: int, |
|
|
|
last_dim: _typing.Union[int, str] = "auto", |
|
|
|
@@ -371,12 +340,12 @@ class _BaseClassificationTrainer(BaseTrainer): |
|
|
|
return self.__encoder |
|
|
|
|
|
|
|
@encoder.setter |
|
|
|
def encoder(self, enc: _typing.Union[BaseAutoModel, BaseAutoEncoder, str, None]): |
|
|
|
def encoder(self, enc: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, str, None]): |
|
|
|
if isinstance(enc, str): |
|
|
|
self.__encoder = EncoderRegistry.get_model(enc)( |
|
|
|
self.num_features, last_dim=self.last_dim, device=self.device, init=self.init |
|
|
|
self.num_features, last_dim=self.last_dim, device=self.device, init=self.initialized |
|
|
|
) |
|
|
|
elif isinstance(enc, BaseAutoEncoder): |
|
|
|
elif isinstance(enc, BaseAutoEncoderMaintainer): |
|
|
|
self.__encoder = enc |
|
|
|
elif isinstance(enc, BaseAutoModel): |
|
|
|
self.__encoder = enc |
|
|
|
@@ -393,16 +362,16 @@ class _BaseClassificationTrainer(BaseTrainer): |
|
|
|
return self.__decoder |
|
|
|
|
|
|
|
@decoder.setter |
|
|
|
def decoder(self, dec: _typing.Union[BaseAutoDecoder, str, None]): |
|
|
|
def decoder(self, dec: _typing.Union[BaseAutoDecoderMaintainer, str, None]): |
|
|
|
if isinstance(self.encoder, BaseAutoModel): |
|
|
|
logging.warn("Ignore passed dec since enc is a whole model") |
|
|
|
self.__decoder = None |
|
|
|
return |
|
|
|
if isinstance(dec, str): |
|
|
|
self.__decoder = DecoderRegistry.get_model(dec)( |
|
|
|
self.num_classes, input_dim=self.last_dim, device=self.device, init=self.init |
|
|
|
self.num_classes, input_dim=self.last_dim, device=self.device, init=self.initialized |
|
|
|
) |
|
|
|
elif isinstance(dec, BaseAutoDecoder): |
|
|
|
elif isinstance(dec, BaseAutoDecoderMaintainer): |
|
|
|
self.__decoder = dec |
|
|
|
elif dec is None: |
|
|
|
self.__decoder = None |
|
|
|
@@ -417,9 +386,19 @@ class _BaseClassificationTrainer(BaseTrainer): |
|
|
|
def num_classes(self, num_classes): |
|
|
|
self.__num_classes = num_classes |
|
|
|
if isinstance(self.encoder, BaseAutoModel): |
|
|
|
self.encoder.num_classes = num_classes |
|
|
|
elif isinstance(self.decoder, AutoClassifierDecoder): |
|
|
|
self.decoder.num_classes = num_classes |
|
|
|
self.encoder.output_dimension = num_classes |
|
|
|
elif isinstance(self.decoder, BaseAutoDecoderMaintainer): |
|
|
|
self.decoder.output_dimension = num_classes |
|
|
|
|
|
|
|
@property |
|
|
|
def last_dim(self): |
|
|
|
return self.__last_dim |
|
|
|
|
|
|
|
@last_dim.setter |
|
|
|
def last_dim(self, dim): |
|
|
|
self.__last_dim = dim |
|
|
|
if isinstance(self.encoder, AutoHomogeneousEncoderMaintainer): |
|
|
|
self.encoder.final_dimension = self.__last_dim |
|
|
|
|
|
|
|
@property |
|
|
|
def num_features(self): |
|
|
|
@@ -429,13 +408,13 @@ class _BaseClassificationTrainer(BaseTrainer): |
|
|
|
def num_features(self, num_features): |
|
|
|
self.__num_features = num_features |
|
|
|
if self.encoder is not None: |
|
|
|
self.encoder.num_features = num_features |
|
|
|
self.encoder.input_dimension = num_features |
|
|
|
|
|
|
|
class BaseNodeClassificationTrainer(_BaseClassificationTrainer): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
encoder: _typing.Union[BaseAutoModel, BaseAutoEncoder, str, None], |
|
|
|
decoder: _typing.Union[BaseAutoDecoder, str, None], |
|
|
|
encoder: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, str, None], |
|
|
|
decoder: _typing.Union[BaseAutoDecoderMaintainer, str, None], |
|
|
|
num_features: int, |
|
|
|
num_classes: int, |
|
|
|
device: _typing.Union[torch.device, str, None] = None, |
|
|
|
@@ -448,13 +427,27 @@ class BaseNodeClassificationTrainer(_BaseClassificationTrainer): |
|
|
|
super(BaseNodeClassificationTrainer, self).__init__( |
|
|
|
encoder, decoder, num_features, num_classes, num_classes, device, init, feval, loss |
|
|
|
) |
|
|
|
|
|
|
|
# override num_classes property to support last_dim setting |
|
|
|
@property |
|
|
|
def num_classes(self): |
|
|
|
return self.__num_classes |
|
|
|
|
|
|
|
@num_classes.setter |
|
|
|
def num_classes(self, num_classes): |
|
|
|
self.__num_classes = num_classes |
|
|
|
if isinstance(self.encoder, BaseAutoModel): |
|
|
|
self.encoder.output_dimension = num_classes |
|
|
|
elif isinstance(self.decoder, BaseAutoDecoderMaintainer): |
|
|
|
self.decoder.output_dimension = num_classes |
|
|
|
self.last_dim = num_classes |
|
|
|
|
|
|
|
|
|
|
|
class BaseGraphClassificationTrainer(_BaseClassificationTrainer): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
encoder: _typing.Union[BaseAutoModel, BaseAutoEncoder, str, None] = None, |
|
|
|
decoder: _typing.Union[BaseAutoDecoder, str, None] = None, |
|
|
|
encoder: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, str, None] = None, |
|
|
|
decoder: _typing.Union[BaseAutoDecoderMaintainer, str, None] = None, |
|
|
|
num_features: _typing.Optional[int] = None, |
|
|
|
num_classes: _typing.Optional[int] = None, |
|
|
|
num_graph_features: int = 0, |
|
|
|
@@ -477,16 +470,16 @@ class BaseGraphClassificationTrainer(_BaseClassificationTrainer): |
|
|
|
return self.__encoder |
|
|
|
|
|
|
|
@encoder.setter |
|
|
|
def encoder(self, enc: _typing.Union[BaseAutoModel, BaseAutoEncoder, str, None]): |
|
|
|
def encoder(self, enc: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, str, None]): |
|
|
|
if isinstance(enc, str): |
|
|
|
self.__encoder = EncoderRegistry.get_model(enc)( |
|
|
|
self.num_features, |
|
|
|
last_dim=self.last_dim, |
|
|
|
num_graph_features=self.num_graph_features, |
|
|
|
device=self.device, |
|
|
|
init=self.init |
|
|
|
init=self.initialized |
|
|
|
) |
|
|
|
elif isinstance(enc, (BaseAutoModel, BaseAutoEncoder)): |
|
|
|
elif isinstance(enc, (BaseAutoModel, BaseAutoEncoderMaintainer)): |
|
|
|
self.__encoder = enc |
|
|
|
if isinstance(enc, BaseAutoModel) and self.decoder is not None: |
|
|
|
logging.warn("will disable decoder since a whole model is passed") |
|
|
|
@@ -502,7 +495,7 @@ class BaseGraphClassificationTrainer(_BaseClassificationTrainer): |
|
|
|
return self.__decoder |
|
|
|
|
|
|
|
@decoder.setter |
|
|
|
def decoder(self, dec: _typing.Union[BaseAutoDecoder, str, None]): |
|
|
|
def decoder(self, dec: _typing.Union[BaseAutoDecoderMaintainer, str, None]): |
|
|
|
if isinstance(self.encoder, BaseAutoModel): |
|
|
|
logging.warn("Ignore passed dec since enc is a whole model") |
|
|
|
self.__decoder = None |
|
|
|
@@ -513,19 +506,34 @@ class BaseGraphClassificationTrainer(_BaseClassificationTrainer): |
|
|
|
input_dim=self.last_dim, |
|
|
|
num_graph_features=self.num_graph_features, |
|
|
|
device=self.device, |
|
|
|
init=self.init |
|
|
|
init=self.initialized |
|
|
|
) |
|
|
|
elif isinstance(dec, (BaseAutoDecoder, None)): |
|
|
|
elif isinstance(dec, (BaseAutoDecoderMaintainer, None)): |
|
|
|
self.__decoder = dec |
|
|
|
else: |
|
|
|
raise ValueError("Invalid decoder setting") |
|
|
|
|
|
|
|
# override num_classes property to support last_dim setting |
|
|
|
@property |
|
|
|
def num_classes(self): |
|
|
|
return self.__num_classes |
|
|
|
|
|
|
|
@num_classes.setter |
|
|
|
def num_classes(self, num_classes): |
|
|
|
self.__num_classes = num_classes |
|
|
|
if isinstance(self.encoder, BaseAutoModel): |
|
|
|
self.encoder.output_dimension = num_classes |
|
|
|
elif isinstance(self.decoder, BaseAutoDecoderMaintainer): |
|
|
|
self.decoder.output_dimension = num_classes |
|
|
|
self.last_dim = num_classes |
|
|
|
|
|
|
|
|
|
|
|
# TODO: according to discussion, link prediction may not belong to classification tasks |
|
|
|
class BaseLinkPredictionTrainer(_BaseClassificationTrainer): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
encoder: _typing.Union[BaseAutoModel, BaseAutoEncoder, str, None] = None, |
|
|
|
decoder: _typing.Union[BaseAutoDecoder, str, None] = None, |
|
|
|
encoder: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, str, None] = None, |
|
|
|
decoder: _typing.Union[BaseAutoDecoderMaintainer, str, None] = None, |
|
|
|
num_features: _typing.Optional[int] = None, |
|
|
|
last_dim: _typing.Union[int, str] = "auto", |
|
|
|
device: _typing.Union[torch.device, str, None] = None, |
|
|
|
@@ -546,7 +554,7 @@ class BaseLinkPredictionTrainer(_BaseClassificationTrainer): |
|
|
|
return self.__decoder |
|
|
|
|
|
|
|
@decoder.setter |
|
|
|
def decoder(self, dec: _typing.Union[BaseAutoDecoder, str, None]): |
|
|
|
def decoder(self, dec: _typing.Union[BaseAutoDecoderMaintainer, str, None]): |
|
|
|
if isinstance(self.encoder, BaseAutoModel): |
|
|
|
logging.warn("Ignore passed dec since enc is a whole model") |
|
|
|
self.__decoder = None |
|
|
|
@@ -555,9 +563,9 @@ class BaseLinkPredictionTrainer(_BaseClassificationTrainer): |
|
|
|
self.__decoder = DecoderRegistry.get_model(dec)( |
|
|
|
input_dim=self.last_dim, |
|
|
|
device=self.device, |
|
|
|
init=self.init |
|
|
|
init=self.initialized |
|
|
|
) |
|
|
|
elif isinstance(dec, (BaseAutoDecoder, None)): |
|
|
|
elif isinstance(dec, (BaseAutoDecoderMaintainer, None)): |
|
|
|
self.__decoder = dec |
|
|
|
else: |
|
|
|
raise ValueError("Invalid decoder setting") |