| @@ -72,7 +72,7 @@ class GraphClassificationTrainer(BaseTrainer): | |||
| init=True, | |||
| feval=[Logloss], | |||
| loss="nll_loss", | |||
| lr_scheduler_type='steplr', | |||
| lr_scheduler_type=None, | |||
| *args, | |||
| **kwargs | |||
| ): | |||
| @@ -93,6 +93,7 @@ class GraphClassificationTrainer(BaseTrainer): | |||
| elif isinstance(model, BaseModel): | |||
| self.model = model | |||
| self.opt_received = optimizer | |||
| if type(optimizer) == str and optimizer.lower() == "adam": | |||
| self.optimizer = torch.optim.Adam | |||
| elif type(optimizer) == str and optimizer.lower() == "sgd": | |||
| @@ -270,8 +271,8 @@ class GraphClassificationTrainer(BaseTrainer): | |||
| self.early_stopping(val_loss, self.model.model) | |||
| if self.early_stopping.early_stop: | |||
| LOGGER.debug("Early stopping at", epoch) | |||
| self.early_stopping.load_checkpoint(self.model.model) | |||
| break | |||
| self.early_stopping.load_checkpoint(self.model.model) | |||
| def predict_only(self, loader): | |||
| """ | |||
| @@ -551,7 +552,7 @@ class GraphClassificationTrainer(BaseTrainer): | |||
| num_features=self.num_features, | |||
| num_classes=self.num_classes, | |||
| num_graph_features=self.num_graph_features, | |||
| optimizer=self.optimizer, | |||
| optimizer=self.opt_received, | |||
| lr=hp["lr"], | |||
| max_epoch=hp["max_epoch"], | |||
| batch_size=hp["batch_size"], | |||
| @@ -559,6 +560,8 @@ class GraphClassificationTrainer(BaseTrainer): | |||
| weight_decay=hp["weight_decay"], | |||
| device=self.device, | |||
| feval=self.feval, | |||
| loss=self.loss_type, | |||
| lr_scheduler_type=self.lr_scheduler_type, | |||
| init=True, | |||
| *self.args, | |||
| **self.kwargs | |||
| @@ -69,7 +69,7 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| init=True, | |||
| feval=[Logloss], | |||
| loss="nll_loss", | |||
| lr_scheduler_type='steplr', | |||
| lr_scheduler_type=None, | |||
| *args, | |||
| **kwargs | |||
| ): | |||
| @@ -86,7 +86,8 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| self.model = MODEL_DICT[model](num_features, num_classes, device, init=init) | |||
| elif isinstance(model, BaseModel): | |||
| self.model = model | |||
| self.opt_received = optimizer | |||
| if type(optimizer) == str and optimizer.lower() == "adam": | |||
| self.optimizer = torch.optim.Adam | |||
| elif type(optimizer) == str and optimizer.lower() == "sgd": | |||
| @@ -243,8 +244,8 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| self.early_stopping(val_loss, self.model.model) | |||
| if self.early_stopping.early_stop: | |||
| LOGGER.debug("Early stopping at %d", epoch) | |||
| self.early_stopping.load_checkpoint(self.model.model) | |||
| break | |||
| self.early_stopping.load_checkpoint(self.model.model) | |||
| def predict_only(self, data, test_mask=None): | |||
| """ | |||
| @@ -499,13 +500,15 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| model=model, | |||
| num_features=self.num_features, | |||
| num_classes=self.num_classes, | |||
| optimizer=self.optimizer, | |||
| optimizer=self.opt_received, | |||
| lr=hp["lr"], | |||
| max_epoch=hp["max_epoch"], | |||
| early_stopping_round=hp["early_stopping_round"], | |||
| device=self.device, | |||
| weight_decay=hp["weight_decay"], | |||
| feval=self.feval, | |||
| loss=self.loss_type, | |||
| lr_scheduler_type=self.lr_scheduler_type, | |||
| init=True, | |||
| *self.args, | |||
| **self.kwargs | |||