diff --git a/autogl/module/train/graph_classification.py b/autogl/module/train/graph_classification.py index 3dc16e7..87e9aac 100644 --- a/autogl/module/train/graph_classification.py +++ b/autogl/module/train/graph_classification.py @@ -72,7 +72,7 @@ class GraphClassificationTrainer(BaseTrainer): init=True, feval=[Logloss], loss="nll_loss", - lr_scheduler_type='steplr', + lr_scheduler_type=None, *args, **kwargs ): @@ -93,6 +93,7 @@ class GraphClassificationTrainer(BaseTrainer): elif isinstance(model, BaseModel): self.model = model + self.opt_received = optimizer if type(optimizer) == str and optimizer.lower() == "adam": self.optimizer = torch.optim.Adam elif type(optimizer) == str and optimizer.lower() == "sgd": @@ -270,8 +271,8 @@ class GraphClassificationTrainer(BaseTrainer): self.early_stopping(val_loss, self.model.model) if self.early_stopping.early_stop: LOGGER.debug("Early stopping at", epoch) - self.early_stopping.load_checkpoint(self.model.model) break + self.early_stopping.load_checkpoint(self.model.model) def predict_only(self, loader): """ @@ -551,7 +552,7 @@ class GraphClassificationTrainer(BaseTrainer): num_features=self.num_features, num_classes=self.num_classes, num_graph_features=self.num_graph_features, - optimizer=self.optimizer, + optimizer=self.opt_received, lr=hp["lr"], max_epoch=hp["max_epoch"], batch_size=hp["batch_size"], @@ -559,6 +560,8 @@ class GraphClassificationTrainer(BaseTrainer): weight_decay=hp["weight_decay"], device=self.device, feval=self.feval, + loss=self.loss_type, + lr_scheduler_type=self.lr_scheduler_type, init=True, *self.args, **self.kwargs diff --git a/autogl/module/train/node_classification.py b/autogl/module/train/node_classification.py index de50ca5..061b2b7 100644 --- a/autogl/module/train/node_classification.py +++ b/autogl/module/train/node_classification.py @@ -69,7 +69,7 @@ class NodeClassificationTrainer(BaseTrainer): init=True, feval=[Logloss], loss="nll_loss", - lr_scheduler_type='steplr', + lr_scheduler_type=None, *args, **kwargs ): @@ -86,7 +86,8 @@ class NodeClassificationTrainer(BaseTrainer): self.model = MODEL_DICT[model](num_features, num_classes, device, init=init) elif isinstance(model, BaseModel): self.model = model - + + self.opt_received = optimizer if type(optimizer) == str and optimizer.lower() == "adam": self.optimizer = torch.optim.Adam elif type(optimizer) == str and optimizer.lower() == "sgd": @@ -243,8 +244,8 @@ class NodeClassificationTrainer(BaseTrainer): self.early_stopping(val_loss, self.model.model) if self.early_stopping.early_stop: LOGGER.debug("Early stopping at %d", epoch) - self.early_stopping.load_checkpoint(self.model.model) break + self.early_stopping.load_checkpoint(self.model.model) def predict_only(self, data, test_mask=None): """ @@ -499,13 +500,15 @@ class NodeClassificationTrainer(BaseTrainer): model=model, num_features=self.num_features, num_classes=self.num_classes, - optimizer=self.optimizer, + optimizer=self.opt_received, lr=hp["lr"], max_epoch=hp["max_epoch"], early_stopping_round=hp["early_stopping_round"], device=self.device, weight_decay=hp["weight_decay"], feval=self.feval, + loss=self.loss_type, + lr_scheduler_type=self.lr_scheduler_type, init=True, *self.args, **self.kwargs