diff --git a/autogl/module/train/graph_classification.py b/autogl/module/train/graph_classification.py index 87e9aac..5a10689 100644 --- a/autogl/module/train/graph_classification.py +++ b/autogl/module/train/graph_classification.py @@ -1,6 +1,11 @@ from . import register_trainer, BaseTrainer, Evaluation, EVALUATE_DICT, EarlyStopping import torch -from torch.optim.lr_scheduler import StepLR, MultiStepLR, ExponentialLR, ReduceLROnPlateau +from torch.optim.lr_scheduler import ( + StepLR, + MultiStepLR, + ExponentialLR, + ReduceLROnPlateau, +) import torch.nn.functional as F from ..model import MODEL_DICT, BaseModel from .evaluate import Logloss @@ -229,14 +234,16 @@ class GraphClassificationTrainer(BaseTrainer): # scheduler = StepLR(optimizer, step_size=100, gamma=0.1) lr_scheduler_type = self.lr_scheduler_type - if type(lr_scheduler_type) == str and lr_scheduler_type == 'steplr': + if type(lr_scheduler_type) == str and lr_scheduler_type == "steplr": scheduler = StepLR(optimizer, step_size=100, gamma=0.1) - elif type(lr_scheduler_type) == str and lr_scheduler_type == 'multisteplr': - scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) - elif type(lr_scheduler_type) == str and lr_scheduler_type == 'exponentiallr': + elif type(lr_scheduler_type) == str and lr_scheduler_type == "multisteplr": + scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1) + elif type(lr_scheduler_type) == str and lr_scheduler_type == "exponentiallr": scheduler = ExponentialLR(optimizer, gamma=0.1) - elif type(lr_scheduler_type) == str and lr_scheduler_type == 'reducelronplateau': - scheduler = ReduceLROnPlateau(optimizer, 'min') + elif ( + type(lr_scheduler_type) == str and lr_scheduler_type == "reducelronplateau" + ): + scheduler = ReduceLROnPlateau(optimizer, "min") else: scheduler = None diff --git a/autogl/module/train/node_classification.py b/autogl/module/train/node_classification.py index 061b2b7..b5a69cd 100644 --- a/autogl/module/train/node_classification.py +++ b/autogl/module/train/node_classification.py @@ -1,6 +1,11 @@ from . import register_trainer, BaseTrainer, Evaluation, EVALUATE_DICT, EarlyStopping import torch -from torch.optim.lr_scheduler import StepLR, MultiStepLR, ExponentialLR, ReduceLROnPlateau +from torch.optim.lr_scheduler import ( + StepLR, + MultiStepLR, + ExponentialLR, + ReduceLROnPlateau, +) import torch.nn.functional as F from ..model import MODEL_DICT, BaseModel from .evaluate import Logloss, Acc, Auc @@ -86,7 +91,7 @@ class NodeClassificationTrainer(BaseTrainer): self.model = MODEL_DICT[model](num_features, num_classes, device, init=init) elif isinstance(model, BaseModel): self.model = model - + self.opt_received = optimizer if type(optimizer) == str and optimizer.lower() == "adam": self.optimizer = torch.optim.Adam @@ -207,14 +212,16 @@ class NodeClassificationTrainer(BaseTrainer): ) # scheduler = StepLR(optimizer, step_size=100, gamma=0.1) lr_scheduler_type = self.lr_scheduler_type - if type(lr_scheduler_type) == str and lr_scheduler_type == 'steplr': + if type(lr_scheduler_type) == str and lr_scheduler_type == "steplr": scheduler = StepLR(optimizer, step_size=100, gamma=0.1) - elif type(lr_scheduler_type) == str and lr_scheduler_type == 'multisteplr': + elif type(lr_scheduler_type) == str and lr_scheduler_type == "multisteplr": scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1) - elif type(lr_scheduler_type) == str and lr_scheduler_type == 'exponentiallr': + elif type(lr_scheduler_type) == str and lr_scheduler_type == "exponentiallr": scheduler = ExponentialLR(optimizer, gamma=0.1) - elif type(lr_scheduler_type) == str and lr_scheduler_type == 'reducelronplateau': - scheduler = ReduceLROnPlateau(optimizer, 'min') + elif ( + type(lr_scheduler_type) == str and lr_scheduler_type == "reducelronplateau" + ): + scheduler = ReduceLROnPlateau(optimizer, "min") else: scheduler = None