Browse Source

Merge pull request #18 from Frozenmad/cora_gat

Fix bugs in trainer
tags/v0.3.1
Frozenmad GitHub 5 years ago
parent
commit
683f4db8e6
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 7 deletions
  1. +6
    -3
      autogl/module/train/graph_classification.py
  2. +7
    -4
      autogl/module/train/node_classification.py

+ 6
- 3
autogl/module/train/graph_classification.py View File

@@ -72,7 +72,7 @@ class GraphClassificationTrainer(BaseTrainer):
init=True,
feval=[Logloss],
loss="nll_loss",
lr_scheduler_type='steplr',
lr_scheduler_type=None,
*args,
**kwargs
):
@@ -93,6 +93,7 @@ class GraphClassificationTrainer(BaseTrainer):
elif isinstance(model, BaseModel):
self.model = model

self.opt_received = optimizer
if type(optimizer) == str and optimizer.lower() == "adam":
self.optimizer = torch.optim.Adam
elif type(optimizer) == str and optimizer.lower() == "sgd":
@@ -270,8 +271,8 @@ class GraphClassificationTrainer(BaseTrainer):
self.early_stopping(val_loss, self.model.model)
if self.early_stopping.early_stop:
LOGGER.debug("Early stopping at", epoch)
self.early_stopping.load_checkpoint(self.model.model)
break
self.early_stopping.load_checkpoint(self.model.model)

def predict_only(self, loader):
"""
@@ -551,7 +552,7 @@ class GraphClassificationTrainer(BaseTrainer):
num_features=self.num_features,
num_classes=self.num_classes,
num_graph_features=self.num_graph_features,
optimizer=self.optimizer,
optimizer=self.opt_received,
lr=hp["lr"],
max_epoch=hp["max_epoch"],
batch_size=hp["batch_size"],
@@ -559,6 +560,8 @@ class GraphClassificationTrainer(BaseTrainer):
weight_decay=hp["weight_decay"],
device=self.device,
feval=self.feval,
loss=self.loss_type,
lr_scheduler_type=self.lr_scheduler_type,
init=True,
*self.args,
**self.kwargs


+ 7
- 4
autogl/module/train/node_classification.py View File

@@ -69,7 +69,7 @@ class NodeClassificationTrainer(BaseTrainer):
init=True,
feval=[Logloss],
loss="nll_loss",
lr_scheduler_type='steplr',
lr_scheduler_type=None,
*args,
**kwargs
):
@@ -86,7 +86,8 @@ class NodeClassificationTrainer(BaseTrainer):
self.model = MODEL_DICT[model](num_features, num_classes, device, init=init)
elif isinstance(model, BaseModel):
self.model = model

self.opt_received = optimizer
if type(optimizer) == str and optimizer.lower() == "adam":
self.optimizer = torch.optim.Adam
elif type(optimizer) == str and optimizer.lower() == "sgd":
@@ -243,8 +244,8 @@ class NodeClassificationTrainer(BaseTrainer):
self.early_stopping(val_loss, self.model.model)
if self.early_stopping.early_stop:
LOGGER.debug("Early stopping at %d", epoch)
self.early_stopping.load_checkpoint(self.model.model)
break
self.early_stopping.load_checkpoint(self.model.model)

def predict_only(self, data, test_mask=None):
"""
@@ -499,13 +500,15 @@ class NodeClassificationTrainer(BaseTrainer):
model=model,
num_features=self.num_features,
num_classes=self.num_classes,
optimizer=self.optimizer,
optimizer=self.opt_received,
lr=hp["lr"],
max_epoch=hp["max_epoch"],
early_stopping_round=hp["early_stopping_round"],
device=self.device,
weight_decay=hp["weight_decay"],
feval=self.feval,
loss=self.loss_type,
lr_scheduler_type=self.lr_scheduler_type,
init=True,
*self.args,
**self.kwargs


Loading…
Cancel
Save