import importlib import sys from ...backend import DependentBackend from . import _utils from .decoders import ( BaseDecoderMaintainer, DecoderUniversalRegistry, LogSoftmaxDecoderMaintainer, DotProductLinkPredictionDecoderMaintainer ) from .encoders import ( BaseEncoderMaintainer, AutoHomogeneousEncoderMaintainer, EncoderUniversalRegistry, GCNEncoderMaintainer, GATEncoderMaintainer, GINEncoderMaintainer, SAGEEncoderMaintainer ) if DependentBackend.is_dgl(): from .decoders import ( TopKDecoderMaintainer, JKSumPoolDecoderMaintainer ) else: from .decoders import ( DiffPoolDecoderMaintainer, SumPoolMLPDecoderMaintainer ) # load corresponding backend model of subclass def _load_subclass_backend(backend): sub_module = importlib.import_module(f'.{backend.get_backend_name()}', __name__) this = sys.modules[__name__] for api, obj in sub_module.__dict__.items(): setattr(this, api, obj) _load_subclass_backend(DependentBackend) __all__.extend([ "BaseDecoderMaintainer", "DecoderUniversalRegistry", "LogSoftmaxDecoderMaintainer", "DotProductLinkPredictionDecoderMaintainer", "BaseEncoderMaintainer", "AutoHomogeneousEncoderMaintainer", "EncoderUniversalRegistry", "GCNEncoderMaintainer", "GATEncoderMaintainer", "GINEncoderMaintainer", "SAGEEncoderMaintainer" ]) if DependentBackend.is_dgl(): __all__.extend([ "TopKDecoderMaintainer", "JKSumPoolDecoderMaintainer", ]) else: __all__.extend([ "DiffPoolDecoderMaintainer", "SumPoolMLPDecoderMaintainer" ])