@@ -1,6 +1,11 @@
 from . import register_trainer, BaseTrainer, Evaluation, EVALUATE_DICT, EarlyStopping
 import torch
-from torch.optim.lr_scheduler import StepLR, MultiStepLR, ExponentialLR, ReduceLROnPlateau
+from torch.optim.lr_scheduler import (
+    StepLR,
+    MultiStepLR,
+    ExponentialLR,
+    ReduceLROnPlateau,
+)
 import torch.nn.functional as F
 from ..model import MODEL_DICT, BaseModel
 from .evaluate import Logloss, Acc, Auc
@@ -86,7 +91,7 @@ class NodeClassificationTrainer(BaseTrainer):
             self.model = MODEL_DICT[model](num_features, num_classes, device, init=init)
         elif isinstance(model, BaseModel):
             self.model = model
 
         self.opt_received = optimizer
-        if type(optimizer) == str and optimizer.lower() == 'adam':
+        if type(optimizer) == str and optimizer.lower() == "adam":
             self.optimizer = torch.optim.Adam
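Note on the hunk above: when a string is passed, the trainer stores the optimizer class itself (torch.optim.Adam), not an instance. A minimal standalone sketch of that "store the class, instantiate later" pattern, assuming the class is later called with the model's parameters; resolve_optimizer and the Linear stand-in model are illustrative names, not the trainer's API:

import torch


def resolve_optimizer(optimizer):
    # Accept either a string name or an optimizer class, mirroring the hunk above.
    if isinstance(optimizer, str) and optimizer.lower() == "adam":
        return torch.optim.Adam
    return optimizer  # assumed to already be a torch.optim.Optimizer subclass


model = torch.nn.Linear(4, 2)                     # stand-in model
opt_cls = resolve_optimizer("adam")               # -> torch.optim.Adam (the class)
optimizer = opt_cls(model.parameters(), lr=1e-2)  # instantiated with parameters later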
@@ -207,14 +212,16 @@ class NodeClassificationTrainer(BaseTrainer):
         )
         # scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
         lr_scheduler_type = self.lr_scheduler_type
-        if type(lr_scheduler_type) == str and lr_scheduler_type == 'steplr':
+        if type(lr_scheduler_type) == str and lr_scheduler_type == "steplr":
             scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
-        elif type(lr_scheduler_type) == str and lr_scheduler_type == 'multisteplr':
+        elif type(lr_scheduler_type) == str and lr_scheduler_type == "multisteplr":
             scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
-        elif type(lr_scheduler_type) == str and lr_scheduler_type == 'exponentiallr':
+        elif type(lr_scheduler_type) == str and lr_scheduler_type == "exponentiallr":
             scheduler = ExponentialLR(optimizer, gamma=0.1)
-        elif type(lr_scheduler_type) == str and lr_scheduler_type == 'reducelronplateau':
-            scheduler = ReduceLROnPlateau(optimizer, 'min')
+        elif (
+            type(lr_scheduler_type) == str and lr_scheduler_type == "reducelronplateau"
+        ):
+            scheduler = ReduceLROnPlateau(optimizer, "min")
         else:
             scheduler = None
 
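The elif chain reformatted in the last hunk maps a lowercase scheduler name to a torch.optim.lr_scheduler class and falls through to None for any other value. A minimal standalone sketch of that dispatch and of how the two scheduler families are stepped; build_scheduler, the Linear stand-in model, and the dummy validation metric are illustrative assumptions, not the trainer's actual API:

import torch
from torch.optim.lr_scheduler import (
    StepLR,
    MultiStepLR,
    ExponentialLR,
    ReduceLROnPlateau,
)


def build_scheduler(optimizer, lr_scheduler_type):
    # Mirrors the elif chain above; unknown names fall through to None.
    if lr_scheduler_type == "steplr":
        return StepLR(optimizer, step_size=100, gamma=0.1)
    if lr_scheduler_type == "multisteplr":
        return MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
    if lr_scheduler_type == "exponentiallr":
        return ExponentialLR(optimizer, gamma=0.1)
    if lr_scheduler_type == "reducelronplateau":
        return ReduceLROnPlateau(optimizer, "min")
    return None


model = torch.nn.Linear(4, 2)  # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
scheduler = build_scheduler(optimizer, "reducelronplateau")

for epoch in range(5):
    optimizer.step()
    val_loss = 1.0 / (epoch + 1)  # dummy validation metric
    if isinstance(scheduler, ReduceLROnPlateau):
        scheduler.step(val_loss)  # the plateau scheduler needs the monitored metric
    elif scheduler is not None:
        scheduler.step()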