|
|
|
@@ -29,7 +29,12 @@ from .graph_classification_full import GraphClassificationFullTrainer |
|
|
|
from .node_classification_full import NodeClassificationFullTrainer |
|
|
|
from .link_prediction_full import LinkPredictionTrainer |
|
|
|
from .node_classification_het import NodeClassificationHetTrainer |
|
|
|
from .node_classification_trainer import * |
|
|
|
if DependentBackend.is_pyg(): |
|
|
|
from .node_classification_trainer import ( |
|
|
|
NodeClassificationGraphSAINTTrainer, |
|
|
|
NodeClassificationLayerDependentImportanceSamplingTrainer, |
|
|
|
NodeClassificationNeighborSamplingTrainer |
|
|
|
) |
|
|
|
from .evaluation import get_feval, Acc, Auc, Logloss, Mrr, MicroF1 |
|
|
|
|
|
|
|
__all__ = [ |
|
|
|
@@ -42,9 +47,6 @@ __all__ = [ |
|
|
|
"GraphClassificationFullTrainer", |
|
|
|
"NodeClassificationFullTrainer", |
|
|
|
"NodeClassificationHetTrainer", |
|
|
|
"NodeClassificationGraphSAINTTrainer", |
|
|
|
"NodeClassificationLayerDependentImportanceSamplingTrainer", |
|
|
|
"NodeClassificationNeighborSamplingTrainer", |
|
|
|
"LinkPredictionTrainer", |
|
|
|
"Acc", |
|
|
|
"Auc", |
|
|
|
@@ -53,3 +55,10 @@ __all__ = [ |
|
|
|
"MicroF1", |
|
|
|
"get_feval", |
|
|
|
] |
|
|
|
|
|
|
|
if DependentBackend.is_pyg(): |
|
|
|
__all__.extend([ |
|
|
|
"NodeClassificationGraphSAINTTrainer", |
|
|
|
"NodeClassificationLayerDependentImportanceSamplingTrainer", |
|
|
|
"NodeClassificationNeighborSamplingTrainer", |
|
|
|
]) |