from ...backend import DependentBackend

# Registry filled by ``register_trainer``: registered name -> trainer class.
TRAINER_DICT = {}

from .base import (
    BaseTrainer,
    Evaluation,
    BaseNodeClassificationTrainer,
    BaseGraphClassificationTrainer,
    BaseLinkPredictionTrainer,
    BaseNodeClassificationHetTrainer,
)


def register_trainer(name):
    """Build a class decorator that records a trainer class under ``name``.

    Parameters
    ----------
    name:
        Key under which the decorated class is stored in ``TRAINER_DICT``.

    Returns
    -------
    callable
        A decorator that takes a class, stores it in ``TRAINER_DICT[name]``
        and returns the class unchanged.

    Raises
    ------
    ValueError
        Raised by the returned decorator when ``name`` is already present in
        ``TRAINER_DICT`` or when the decorated class does not extend
        ``BaseTrainer``.
    """

    def _decorate(cls):
        # Refuse to overwrite an existing registration.
        if name in TRAINER_DICT:
            duplicate_msg = "Cannot register duplicate trainer ({})".format(name)
            raise ValueError(duplicate_msg)

        # Only BaseTrainer subclasses may be registered.
        if not issubclass(cls, BaseTrainer):
            base_msg = "Trainer ({}: {}) must extend BaseTrainer".format(
                name, cls.__name__
            )
            raise ValueError(base_msg)

        TRAINER_DICT[name] = cls
        return cls

    return _decorate


# NOTE(review): these imports sit below ``register_trainer`` and
# ``TRAINER_DICT`` on purpose in the original layout; presumably the trainer
# modules import the decorator from this package -- confirm before reordering.
from .graph_classification_full import GraphClassificationFullTrainer
from .node_classification_full import NodeClassificationFullTrainer
from .link_prediction_full import LinkPredictionTrainer
from .node_classification_het import NodeClassificationHetTrainer

# These trainers are imported only when DependentBackend.is_pyg() is true.
if DependentBackend.is_pyg():
    from .node_classification_trainer import (
        NodeClassificationGraphSAINTTrainer,
        NodeClassificationLayerDependentImportanceSamplingTrainer,
        NodeClassificationNeighborSamplingTrainer,
    )
    from .ssl import GraphCLSemisupervisedTrainer, GraphCLUnsupervisedTrainer

from .evaluation import get_feval, Acc, Auc, Logloss, Mrr, MicroF1

# Names exported regardless of the active backend.
__all__ = [
    "BaseTrainer",
    "Evaluation",
    "BaseGraphClassificationTrainer",
    "BaseNodeClassificationTrainer",
    "BaseNodeClassificationHetTrainer",
    "BaseLinkPredictionTrainer",
    "GraphClassificationFullTrainer",
    "NodeClassificationFullTrainer",
    "NodeClassificationHetTrainer",
    "LinkPredictionTrainer",
    "Acc",
    "Auc",
    "Logloss",
    "Mrr",
    "MicroF1",
    "get_feval",
]

# Export the conditionally imported trainers only when they were imported.
if DependentBackend.is_pyg():
    __all__ += [
        "NodeClassificationGraphSAINTTrainer",
        "NodeClassificationLayerDependentImportanceSamplingTrainer",
        "NodeClassificationNeighborSamplingTrainer",
        "GraphCLSemisupervisedTrainer",
        "GraphCLUnsupervisedTrainer",
    ]