Browse Source

black format

tags/v0.3.1
Frozenmad 5 years ago
parent
commit
e6d42bbc4f
2 changed files with 28 additions and 14 deletions
  1. +14
    -7
      autogl/module/train/graph_classification.py
  2. +14
    -7
      autogl/module/train/node_classification.py

+ 14
- 7
autogl/module/train/graph_classification.py View File

@@ -1,6 +1,11 @@
 from . import register_trainer, BaseTrainer, Evaluation, EVALUATE_DICT, EarlyStopping
 import torch
-from torch.optim.lr_scheduler import StepLR, MultiStepLR, ExponentialLR, ReduceLROnPlateau
+from torch.optim.lr_scheduler import (
+    StepLR,
+    MultiStepLR,
+    ExponentialLR,
+    ReduceLROnPlateau,
+)
 import torch.nn.functional as F
 from ..model import MODEL_DICT, BaseModel
 from .evaluate import Logloss
@@ -229,14 +234,16 @@ class GraphClassificationTrainer(BaseTrainer):
 
         # scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
         lr_scheduler_type = self.lr_scheduler_type
-        if type(lr_scheduler_type) == str and lr_scheduler_type == 'steplr':
+        if type(lr_scheduler_type) == str and lr_scheduler_type == "steplr":
             scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
-        elif type(lr_scheduler_type) == str and lr_scheduler_type == 'multisteplr':
-            scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
-        elif type(lr_scheduler_type) == str and lr_scheduler_type == 'exponentiallr':
+        elif type(lr_scheduler_type) == str and lr_scheduler_type == "multisteplr":
+            scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
+        elif type(lr_scheduler_type) == str and lr_scheduler_type == "exponentiallr":
             scheduler = ExponentialLR(optimizer, gamma=0.1)
-        elif type(lr_scheduler_type) == str and lr_scheduler_type == 'reducelronplateau':
-            scheduler = ReduceLROnPlateau(optimizer, 'min')
+        elif (
+            type(lr_scheduler_type) == str and lr_scheduler_type == "reducelronplateau"
+        ):
+            scheduler = ReduceLROnPlateau(optimizer, "min")
         else:
             scheduler = None
 



+ 14
- 7
autogl/module/train/node_classification.py View File

@@ -1,6 +1,11 @@
 from . import register_trainer, BaseTrainer, Evaluation, EVALUATE_DICT, EarlyStopping
 import torch
-from torch.optim.lr_scheduler import StepLR, MultiStepLR, ExponentialLR, ReduceLROnPlateau
+from torch.optim.lr_scheduler import (
+    StepLR,
+    MultiStepLR,
+    ExponentialLR,
+    ReduceLROnPlateau,
+)
 import torch.nn.functional as F
 from ..model import MODEL_DICT, BaseModel
 from .evaluate import Logloss, Acc, Auc
@@ -86,7 +91,7 @@ class NodeClassificationTrainer(BaseTrainer):
             self.model = MODEL_DICT[model](num_features, num_classes, device, init=init)
         elif isinstance(model, BaseModel):
             self.model = model
-        self.opt_received = optimizer
+        self.opt_received = optimizer
         if type(optimizer) == str and optimizer.lower() == "adam":
             self.optimizer = torch.optim.Adam
@@ -207,14 +212,16 @@ class NodeClassificationTrainer(BaseTrainer):
         )
         # scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
         lr_scheduler_type = self.lr_scheduler_type
-        if type(lr_scheduler_type) == str and lr_scheduler_type == 'steplr':
+        if type(lr_scheduler_type) == str and lr_scheduler_type == "steplr":
             scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
-        elif type(lr_scheduler_type) == str and lr_scheduler_type == 'multisteplr':
+        elif type(lr_scheduler_type) == str and lr_scheduler_type == "multisteplr":
             scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
-        elif type(lr_scheduler_type) == str and lr_scheduler_type == 'exponentiallr':
+        elif type(lr_scheduler_type) == str and lr_scheduler_type == "exponentiallr":
             scheduler = ExponentialLR(optimizer, gamma=0.1)
-        elif type(lr_scheduler_type) == str and lr_scheduler_type == 'reducelronplateau':
-            scheduler = ReduceLROnPlateau(optimizer, 'min')
+        elif (
+            type(lr_scheduler_type) == str and lr_scheduler_type == "reducelronplateau"
+        ):
+            scheduler = ReduceLROnPlateau(optimizer, "min")
         else:
             scheduler = None
 



Loading…
Cancel
Save