| @@ -179,7 +179,9 @@ def random_splits_mask_class( | |||
| return dataset | |||
| def graph_cross_validation(dataset, n_splits=10, shuffle=True, random_seed=42, stratify=False): | |||
| def graph_cross_validation( | |||
| dataset, n_splits=10, shuffle=True, random_seed=42, stratify=False | |||
| ): | |||
| r"""Cross validation for graph classification data, returning one fold with specific idx in autogl.datasets or pyg.Dataloader(default) | |||
| Parameters | |||
| @@ -197,7 +199,9 @@ def graph_cross_validation(dataset, n_splits=10, shuffle=True, random_seed=42, s | |||
| random_state for sklearn.model_selection.StratifiedKFold | |||
| """ | |||
| if stratify: | |||
| skf = StratifiedKFold(n_splits=n_splits, shuffle=shuffle, random_state=random_seed) | |||
| skf = StratifiedKFold( | |||
| n_splits=n_splits, shuffle=shuffle, random_state=random_seed | |||
| ) | |||
| else: | |||
| skf = KFold(n_splits=n_splits, shuffle=shuffle, random_state=random_seed) | |||
| idx_list = [] | |||
| @@ -318,7 +322,9 @@ def graph_random_splits(dataset, train_ratio=0.2, val_ratio=0.4, seed=None): | |||
| return dataset | |||
| def graph_get_split(dataset, mask="train", is_loader=True, batch_size=128, num_workers = 0): | |||
| def graph_get_split( | |||
| dataset, mask="train", is_loader=True, batch_size=128, num_workers=0 | |||
| ): | |||
| r"""Get train/test dataset/dataloader after cross validation. | |||
| Parameters | |||
| @@ -340,7 +346,11 @@ def graph_get_split(dataset, mask="train", is_loader=True, batch_size=128, num_w | |||
| dataset, "%s_split" % (mask) | |||
| ), "Given dataset do not have %s split" % (mask) | |||
| if is_loader: | |||
| return DataLoader(getattr(dataset, "%s_split" % (mask)), batch_size=batch_size, num_workers = num_workers) | |||
| return DataLoader( | |||
| getattr(dataset, "%s_split" % (mask)), | |||
| batch_size=batch_size, | |||
| num_workers=num_workers, | |||
| ) | |||
| else: | |||
| return getattr(dataset, "%s_split" % (mask)) | |||
| @@ -14,7 +14,7 @@ def register_model(name): | |||
| ) | |||
| MODEL_DICT[name] = cls | |||
| return cls | |||
| return register_model_cls | |||
| @@ -9,6 +9,7 @@ import typing as _typing | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from copy import deepcopy | |||
| base_approach_logger: logging.Logger = logging.getLogger("BaseModel") | |||
| @@ -49,7 +50,11 @@ class BaseModel: | |||
def to(self, device):
    """
    Record ``device`` on this object and move the wrapped module onto it.

    ``device`` is stored only when it is a ``str`` or ``torch.device``;
    any other value leaves ``self.device`` untouched. The wrapped
    ``self.model`` is moved only if the attribute exists, is not ``None``
    and is a ``torch.nn.Module``.

    :param device: target device
    :return: self
    """
    if isinstance(device, (str, torch.device)):
        self.device = device
    # Guard clauses: nothing to move unless a real torch module is attached.
    if not hasattr(self, "model"):
        return self
    if self.model is None or not isinstance(self.model, torch.nn.Module):
        return self
    self.model.to(self.device)
    return self
| @@ -95,28 +100,28 @@ class _BaseBaseModel: | |||
| designed to implement some basic functionality of BaseModel. | |||
| -- Designed by ZiXin Sun | |||
| """ | |||
@classmethod
def __formulate_device(
    cls, device: _typing.Union[str, torch.device] = ...
) -> torch.device:
    """
    Normalize a user supplied device specification into a ``torch.device``.

    A ``torch.device`` instance, or any string other than ``"auto"``
    (compared case-insensitively, ignoring surrounding whitespace), is
    treated as an explicit request. Everything else (``"auto"``, ``None``,
    the ``...`` default) selects CUDA when it is available and falls back
    to the CPU otherwise.

    :param device: explicit device, ``"auto"``, or a placeholder value
    :return: the resolved ``torch.device``
    """
    if isinstance(device, torch.device):
        return torch.device(device)
    if isinstance(device, str):
        # Strip once and reuse the result, so that a padded name such as
        # " cpu " is accepted the same way " auto " already was.
        requested = device.strip()
        if requested.lower() != "auto":
            return torch.device(requested)
    if torch.cuda.is_available() and torch.cuda.device_count() > 0:
        return torch.device("cuda")
    return torch.device("cpu")
@property
def device(self) -> torch.device:
    """The device previously stored by the setter."""
    return self.__device

@device.setter
def device(self, __device: _typing.Union[str, torch.device, None]):
    """Resolve the given value through ``__formulate_device`` and store it."""
    resolved: torch.device = self.__formulate_device(__device)
    self.__device = resolved
| @property | |||
| def model(self) -> _typing.Optional[torch.nn.Module]: | |||
| if self._model is None: | |||
| @@ -124,19 +129,18 @@ class _BaseBaseModel: | |||
| "property of model NOT initialized before accessing" | |||
| ) | |||
| return self._model | |||
@model.setter
def model(self, _model: torch.nn.Module) -> None:
    """
    Attach the wrapped torch module.

    :param _model: module to store on ``self._model``
    :raises TypeError: if ``_model`` is not a ``torch.nn.Module``
    """
    if isinstance(_model, torch.nn.Module):
        self._model = _model
        return
    raise TypeError("the property of model MUST be an instance of torch.nn.Module")
def _initialize(self):
    """Placeholder hook; this base implementation always raises."""
    raise NotImplementedError
| def initialize(self) -> bool: | |||
| """ | |||
| Initialize the model in case that the model has NOT been initialized | |||
| @@ -147,7 +151,7 @@ class _BaseBaseModel: | |||
| self.__is_initialized = True | |||
| return True | |||
| return False | |||
| # def to(self, *args, **kwargs): | |||
| # """ | |||
| # Due to the signature of to() method in class BaseApproach | |||
| @@ -161,17 +165,18 @@ class _BaseBaseModel: | |||
| # :return: self | |||
| # """ | |||
| # return super(_BaseBaseModel, self).to(*args, **kwargs) | |||
def forward(self, *args, **kwargs):
    """
    Delegate the call to the wrapped ``self.model``.

    :return: whatever ``self.model(*args, **kwargs)`` returns
    :raises NotImplementedError: if ``self.model`` is ``None`` or is not
        a ``torch.nn.Module``
    """
    if self.model is None or not isinstance(self.model, torch.nn.Module):
        raise NotImplementedError
    return self.model(*args, **kwargs)
| def __init__( | |||
| self, model: _typing.Optional[torch.nn.Module] = None, | |||
| initialize: bool = False, | |||
| device: _typing.Union[str, torch.device] = ... | |||
| self, | |||
| model: _typing.Optional[torch.nn.Module] = None, | |||
| initialize: bool = False, | |||
| device: _typing.Union[str, torch.device] = ..., | |||
| ): | |||
| if type(initialize) != bool: | |||
| raise TypeError | |||
| @@ -188,64 +193,65 @@ class _BaseModel(_BaseBaseModel, BaseModel): | |||
| The upcoming root base class for Model, i.e. BaseModel | |||
| -- Designed by ZiXin Sun | |||
| """ | |||
| # todo: Deprecate and remove the legacy class "BaseModel", | |||
| # then rename this class to "BaseModel", | |||
| # correspondingly, this class will no longer extend | |||
| # the legacy class "BaseModel" after the removal. | |||
def _initialize(self):
    """Placeholder hook; this base implementation always raises."""
    raise NotImplementedError
def to(self, device: torch.device):
    """
    Store ``device``, move the wrapped module onto it, then defer to the
    next ``to`` implementation in the MRO.

    :param device: target device
    :return: the result of ``super().to(device)``
    """
    self.device = device
    holds_module = self.model is not None and isinstance(
        self.model, torch.nn.Module
    )
    if holds_module:
        # Move to self.device (the value as stored by the setter), not the raw
        # argument.
        self.model.to(self.device)
    return super().to(device)
@property
def space(self) -> _typing.Sequence[_typing.Dict[str, _typing.Any]]:
    """Legacy alias that returns the same object as ``hyper_parameter_space``."""
    # todo: deprecate and remove in future major version
    search_space = self.__hyper_parameter_space
    return search_space
@property
def hyper_parameter_space(self):
    """The stored hyper parameter search space (returned as-is, not copied)."""
    return self.__hyper_parameter_space

@hyper_parameter_space.setter
def hyper_parameter_space(
    self, space: _typing.Sequence[_typing.Dict[str, _typing.Any]]
):
    """Replace the stored search space; the value is not validated here."""
    self.__hyper_parameter_space = space
@property
def hyper_parameter(self) -> _typing.Dict[str, _typing.Any]:
    """The stored hyper parameter dict itself (no copy is made)."""
    return self.__hyper_parameter

@hyper_parameter.setter
def hyper_parameter(self, _hyper_parameter: _typing.Dict[str, _typing.Any]):
    """
    Replace the stored hyper parameters.

    :raises TypeError: if the given value is not a ``dict``
    """
    if isinstance(_hyper_parameter, dict):
        self.__hyper_parameter = _hyper_parameter
        return
    raise TypeError
def get_hyper_parameter(self) -> _typing.Dict[str, _typing.Any]:
    """
    todo: consider deprecating this trivial getter method in the future
    :return: a deep copy of the stored hyper parameters
    """
    snapshot = copy.deepcopy(self.__hyper_parameter)
    return snapshot
| def __init__( | |||
| self, model: _typing.Optional[torch.nn.Module] = None, | |||
| initialize: bool = False, | |||
| hyper_parameter_space: _typing.Sequence[_typing.Any] = ..., | |||
| hyper_parameter: _typing.Dict[str, _typing.Any] = ..., | |||
| device: _typing.Union[str, torch.device] = ... | |||
| self, | |||
| model: _typing.Optional[torch.nn.Module] = None, | |||
| initialize: bool = False, | |||
| hyper_parameter_space: _typing.Sequence[_typing.Any] = ..., | |||
| hyper_parameter: _typing.Dict[str, _typing.Any] = ..., | |||
| device: _typing.Union[str, torch.device] = ..., | |||
| ): | |||
| if type(initialize) != bool: | |||
| raise TypeError | |||
| super(_BaseModel, self).__init__(model, initialize, device) | |||
| if ( | |||
| hyper_parameter_space != Ellipsis and | |||
| isinstance(hyper_parameter_space, _typing.Sequence) | |||
| if hyper_parameter_space != Ellipsis and isinstance( | |||
| hyper_parameter_space, _typing.Sequence | |||
| ): | |||
| self.__hyper_parameter_space: _typing.Sequence[ | |||
| _typing.Dict[str, _typing.Any] | |||
| @@ -266,27 +272,30 @@ class _BaseModel(_BaseBaseModel, BaseModel): | |||
| class ClassificationModel(_BaseModel): | |||
def _initialize(self):
    """Placeholder hook; this base implementation always raises."""
    raise NotImplementedError
def from_hyper_parameter(
    self, hyper_parameter: _typing.Dict[str, _typing.Any]
) -> "ClassificationModel":
    """
    Create a new, initialized model of the same class whose hyper
    parameters are this model's values overridden by ``hyper_parameter``.

    The new instance is constructed with this model's ``num_features``,
    ``num_classes`` and ``device`` and with ``init=False``; it is then
    given the merged hyper parameters and ``initialize()`` is called on it.
    This model's own hyper parameters are left unchanged.

    :param hyper_parameter: values overriding the current hyper parameters
    :return: the newly created model
    """
    new_model: "ClassificationModel" = self.__class__(
        num_features=self.num_features,
        num_classes=self.num_classes,
        device=self.device,
        init=False,
    )
    # Merge into a fresh dict: updating the object returned by
    # ``self.hyper_parameter`` in place would silently rewrite this model's
    # own settings and make both models share a single dict.
    merged_hyper_parameter = dict(self.hyper_parameter)
    merged_hyper_parameter.update(hyper_parameter)
    new_model.hyper_parameter = merged_hyper_parameter
    new_model.initialize()
    return new_model
| def __init__( | |||
| self, num_features: int = ..., num_classes: int = ..., | |||
| num_graph_features: int = ..., | |||
| device: _typing.Union[str, torch.device] = ..., | |||
| init: bool = False, **kwargs | |||
| self, | |||
| num_features: int = ..., | |||
| num_classes: int = ..., | |||
| num_graph_features: int = ..., | |||
| device: _typing.Union[str, torch.device] = ..., | |||
| init: bool = False, | |||
| **kwargs | |||
| ): | |||
| if "initialize" in kwargs: | |||
| del kwargs["initialize"] | |||
| @@ -308,11 +317,11 @@ class ClassificationModel(_BaseModel): | |||
| self.__num_graph_features: int = 0 | |||
| else: | |||
| self.__num_graph_features: int = 0 | |||
@property
def num_classes(self) -> int:
    """The stored number of classes."""
    class_count = self.__num_classes
    return class_count
| @num_classes.setter | |||
| def num_classes(self, __num_classes: int): | |||
| if type(__num_classes) != int: | |||
| @@ -320,11 +329,11 @@ class ClassificationModel(_BaseModel): | |||
| if not __num_classes > 0: | |||
| raise ValueError | |||
| self.__num_classes = __num_classes if __num_classes > 0 else 0 | |||
@property
def num_features(self) -> int:
    """The stored number of features."""
    feature_count = self.__num_features
    return feature_count
| @num_features.setter | |||
| def num_features(self, __num_features: int): | |||
| if type(__num_features) != int: | |||
| @@ -332,27 +341,27 @@ class ClassificationModel(_BaseModel): | |||
| if not __num_features > 0: | |||
| raise ValueError | |||
| self.__num_features = __num_features if __num_features > 0 else 0 | |||
def get_num_classes(self) -> int:
    """Return the stored number of classes."""
    # todo: consider replacing with property with getter and setter
    class_count = self.__num_classes
    return class_count
def set_num_classes(self, num_classes: int) -> None:
    """
    Store the number of classes, clamping non-positive values to ``0``.

    :raises TypeError: if ``num_classes`` is not exactly an ``int``
        (``bool`` is rejected as well)
    """
    # todo: consider replacing with property with getter and setter
    if type(num_classes) is not int:
        raise TypeError
    self.__num_classes = max(num_classes, 0)
def get_num_features(self) -> int:
    """Return the stored number of features."""
    # todo: consider replacing with property with getter and setter
    feature_count = self.__num_features
    return feature_count
def set_num_features(self, num_features: int):
    """
    Store the number of features, clamping non-positive values to ``0``.

    :raises TypeError: if ``num_features`` is not exactly an ``int``
        (``bool`` is rejected as well)
    """
    # todo: consider replacing with property with getter and setter
    if type(num_features) is not int:
        raise TypeError
    self.__num_features = max(num_features, 0)
| def set_num_graph_features(self, num_graph_features: int): | |||
| # todo: consider replacing with property with getter and setter | |||
| if type(num_graph_features) != int: | |||
| @@ -11,9 +11,12 @@ LOGGER = get_logger("GCNModel") | |||
| class GCN(torch.nn.Module): | |||
| def __init__( | |||
| self, num_features: int, num_classes: int, | |||
| hidden_features: _typing.Sequence[int], | |||
| dropout: float, activation_name: str | |||
| self, | |||
| num_features: int, | |||
| num_classes: int, | |||
| hidden_features: _typing.Sequence[int], | |||
| dropout: float, | |||
| activation_name: str, | |||
| ): | |||
| super().__init__() | |||
| self.__convolution_layers: torch.nn.ModuleList = torch.nn.ModuleList() | |||
| @@ -25,31 +28,33 @@ class GCN(torch.nn.Module): | |||
| ) | |||
| ) | |||
| else: | |||
| self.__convolution_layers.append(torch_geometric.nn.GCNConv( | |||
| num_features, hidden_features[0], add_self_loops=False | |||
| )) | |||
| self.__convolution_layers.append( | |||
| torch_geometric.nn.GCNConv( | |||
| num_features, hidden_features[0], add_self_loops=False | |||
| ) | |||
| ) | |||
| for i in range(len(hidden_features)): | |||
| self.__convolution_layers.append( | |||
| torch_geometric.nn.GCNConv( | |||
| hidden_features[i], hidden_features[i + 1] | |||
| ) if i + 1 < len(hidden_features) | |||
| else torch_geometric.nn.GCNConv( | |||
| hidden_features[i], num_classes | |||
| ) | |||
| if i + 1 < len(hidden_features) | |||
| else torch_geometric.nn.GCNConv(hidden_features[i], num_classes) | |||
| ) | |||
| self.__dropout: float = dropout | |||
| self.__activation_name: str = activation_name | |||
def __layer_wise_forward(self, data):
    """Forward pass for sampled training; not implemented, always raises."""
    # todo: Implement this forward method
    #       in case that data.edge_indexes property is provided
    #       for Layer-wise and Node-wise sampled training
    raise NotImplementedError
| def __basic_forward( | |||
| self, x: torch.Tensor, | |||
| edge_index: torch.Tensor, | |||
| edge_weight: _typing.Optional[torch.Tensor] = None | |||
| self, | |||
| x: torch.Tensor, | |||
| edge_index: torch.Tensor, | |||
| edge_weight: _typing.Optional[torch.Tensor] = None, | |||
| ) -> torch.Tensor: | |||
| for layer_index in range(len(self.__convolution_layers)): | |||
| x: torch.Tensor = self.__convolution_layers[layer_index]( | |||
| @@ -57,31 +62,32 @@ class GCN(torch.nn.Module): | |||
| ) | |||
| if layer_index + 1 < len(self.__convolution_layers): | |||
| x = activate_func(x, self.__activation_name) | |||
| x = torch.nn.functional.dropout(x, p=self.__dropout, training=self.training) | |||
| x = torch.nn.functional.dropout( | |||
| x, p=self.__dropout, training=self.training | |||
| ) | |||
| return torch.nn.functional.log_softmax(x, dim=1) | |||
def forward(self, data) -> torch.Tensor:
    """
    Dispatch ``data`` to the matching forward implementation.

    When ``data.edge_indexes`` exists and is not ``None`` the layer-wise
    path is taken. Otherwise ``data.x`` and ``data.edge_index`` must both
    be plain ``torch.Tensor`` objects; ``data.edge_weight`` is forwarded
    only if it is a plain ``torch.Tensor`` holding exactly one value per
    column of ``edge_index``, and is replaced by ``None`` in every other case.

    :raises AttributeError: if ``x`` or ``edge_index`` is missing
    :raises TypeError: if ``x`` or ``edge_index`` is not exactly a
        ``torch.Tensor``
    """
    if getattr(data, "edge_indexes", None) is not None:
        return self.__layer_wise_forward(data)

    if not hasattr(data, "x") or not hasattr(data, "edge_index"):
        raise AttributeError
    x = getattr(data, "x")
    edge_index = getattr(data, "edge_index")
    if type(x) is not torch.Tensor or type(edge_index) is not torch.Tensor:
        raise TypeError

    # An absent, wrongly typed or wrongly sized weight vector is ignored.
    edge_weight = getattr(data, "edge_weight", None)
    if type(edge_weight) is not torch.Tensor or edge_weight.size() != (
        edge_index.size(1),
    ):
        edge_weight = None
    return self.__basic_forward(x, edge_index, edge_weight)
| @@ -120,18 +126,22 @@ class AutoGCN(ClassificationModel): | |||
| """ | |||
def __init__(
    self,
    num_features: int = ...,
    num_classes: int = ...,
    device: _typing.Union[str, torch.device] = ...,
    init: bool = False,
    **kwargs
) -> None:
    """
    Forward every argument unchanged to the parent class initializer.

    ``num_features`` and ``num_classes`` are passed positionally;
    ``device``, ``init`` and any extra keyword arguments are passed by
    keyword.
    """
    parent = super(AutoGCN, self)
    parent.__init__(num_features, num_classes, device=device, init=init, **kwargs)
def _initialize(self):
    """
    Build the ``GCN`` network from the current settings and store it as
    ``self.model`` after moving it to ``self.device``.

    The ``"hidden"``, ``"dropout"`` and ``"act"`` hyper parameters are
    read with ``dict.get``, so a missing key is passed on as ``None``.
    """
    num_features = self.num_features
    num_classes = self.num_classes
    hidden = self.hyper_parameter.get("hidden")
    dropout = self.hyper_parameter.get("dropout")
    activation = self.hyper_parameter.get("act")
    network = GCN(num_features, num_classes, hidden, dropout, activation)
    self.model = network.to(self.device)
| @@ -9,19 +9,23 @@ from .base import BaseModel, activate_func | |||
| class GraphSAGE(torch.nn.Module): | |||
| def __init__( | |||
| self, num_features: int, num_classes: int, | |||
| hidden_features: _typing.Sequence[int], | |||
| dropout: float, activation_name: str, | |||
| aggr: str = "mean", **kwargs | |||
| self, | |||
| num_features: int, | |||
| num_classes: int, | |||
| hidden_features: _typing.Sequence[int], | |||
| dropout: float, | |||
| activation_name: str, | |||
| aggr: str = "mean", | |||
| **kwargs | |||
| ): | |||
| super(GraphSAGE, self).__init__() | |||
| if type(aggr) != str: | |||
| raise TypeError | |||
| if aggr not in ("add", "max", "mean"): | |||
| aggr = "mean" | |||
| self.__convolution_layers: torch.nn.ModuleList = torch.nn.ModuleList() | |||
| num_layers: int = len(hidden_features) + 1 | |||
| if num_layers == 1: | |||
| self.__convolution_layers.append( | |||
| @@ -42,7 +46,7 @@ class GraphSAGE(torch.nn.Module): | |||
| ) | |||
| self.__dropout: float = dropout | |||
| self.__activation_name: str = activation_name | |||
| def __full_forward(self, data): | |||
| x: torch.Tensor = getattr(data, "x") | |||
| edge_index: torch.Tensor = getattr(data, "edge_index") | |||
| @@ -52,24 +56,26 @@ class GraphSAGE(torch.nn.Module): | |||
| x = activate_func(x, self.__activation_name) | |||
| x = F.dropout(x, p=self.__dropout, training=self.training) | |||
| return F.log_softmax(x, dim=1) | |||
def __distributed_forward(self, data):
    """
    Forward pass that feeds each convolution layer its own edge index.

    ``data.edge_indexes`` must hold exactly one entry per convolution
    layer. The activation and dropout are applied after every layer except
    the last, and ``log_softmax`` over dimension 1 is applied to the final
    output.

    :raises AttributeError: if the number of edge indexes differs from
        the number of convolution layers
    """
    x = getattr(data, "x")
    edge_indexes = getattr(data, "edge_indexes")
    layers = self.__convolution_layers
    if len(edge_indexes) != len(layers):
        raise AttributeError
    final_position = len(layers) - 1
    for position, (convolution, edge_index) in enumerate(zip(layers, edge_indexes)):
        x = convolution(x, edge_index)
        if position < final_position:
            x = activate_func(x, self.__activation_name)
            x = F.dropout(x, p=self.__dropout, training=self.training)
    return F.log_softmax(x, dim=1)
| def forward(self, data): | |||
| if ( | |||
| hasattr(data, "edge_indexes") and | |||
| isinstance(getattr(data, "edge_indexes"), _typing.Sequence) and | |||
| len(getattr(data, "edge_indexes")) == len(self.__convolution_layers) | |||
| hasattr(data, "edge_indexes") | |||
| and isinstance(getattr(data, "edge_indexes"), _typing.Sequence) | |||
| and len(getattr(data, "edge_indexes")) == len(self.__convolution_layers) | |||
| ): | |||
| return self.__distributed_forward(data) | |||
| else: | |||
| @@ -79,15 +85,20 @@ class GraphSAGE(torch.nn.Module): | |||
| @register_model("sage") | |||
| class AutoSAGE(BaseModel): | |||
| def __init__( | |||
| self, num_features: int = 1, num_classes: int = 1, | |||
| device: _typing.Optional[torch.device] = torch.device("cpu"), | |||
| init: bool = False, **kwargs | |||
| self, | |||
| num_features: int = 1, | |||
| num_classes: int = 1, | |||
| device: _typing.Optional[torch.device] = torch.device("cpu"), | |||
| init: bool = False, | |||
| **kwargs | |||
| ): | |||
| super(AutoSAGE, self).__init__(init) | |||
| self.__num_features: int = num_features | |||
| self.__num_classes: int = num_classes | |||
| self.__device: torch.device = device if device is not None else torch.device("cpu") | |||
| self.__device: torch.device = ( | |||
| device if device is not None else torch.device("cpu") | |||
| ) | |||
| self.hyperparams = { | |||
| "num_layers": 3, | |||
| "hidden": [64, 32], | |||
| @@ -97,26 +108,27 @@ class AutoSAGE(BaseModel): | |||
| } | |||
| self.params = { | |||
| "num_features": self.__num_features, | |||
| "num_classes": self.__num_classes | |||
| "num_classes": self.__num_classes, | |||
| } | |||
| self._model: GraphSAGE = GraphSAGE( | |||
| self.__num_features, self.__num_classes, [64, 32], 0.5, "relu" | |||
| ) | |||
| self._initialized: bool = False | |||
| if init: | |||
| self.initialize() | |||
@property
def model(self) -> GraphSAGE:
    """The ``GraphSAGE`` network currently stored on ``self._model``."""
    network = self._model
    return network
| def initialize(self): | |||
| """ Initialize model """ | |||
| if not self._initialized: | |||
| self._model: GraphSAGE = GraphSAGE( | |||
| self.__num_features, self.__num_classes, | |||
| self.__num_features, | |||
| self.__num_classes, | |||
| hidden_features=self.hyperparams["hidden"], | |||
| activation_name=self.hyperparams["act"], | |||
| **self.hyperparams | |||
| @@ -83,15 +83,14 @@ class EarlyStopping: | |||
| class BaseTrainer: | |||
| def __init__( | |||
| self, | |||
| model: BaseModel, | |||
| device: _typing.Union[torch.device, str], | |||
| init: bool = True, | |||
| feval: _typing.Union[ | |||
| _typing.Sequence[str], | |||
| _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = (Acc,), | |||
| loss: str = "nll_loss", | |||
| self, | |||
| model: BaseModel, | |||
| device: _typing.Union[torch.device, str], | |||
| init: bool = True, | |||
| feval: _typing.Union[ | |||
| _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = (Acc,), | |||
| loss: str = "nll_loss", | |||
| ): | |||
| """ | |||
| The basic trainer. | |||
| @@ -108,47 +107,50 @@ class BaseTrainer: | |||
| """ | |||
| super().__init__() | |||
| self.model: BaseModel = model | |||
| if ( | |||
| type(device) == torch.device or | |||
| (type(device) == str and device.lower() != "auto") | |||
| if type(device) == torch.device or ( | |||
| type(device) == str and device.lower() != "auto" | |||
| ): | |||
| self.__device: torch.device = torch.device(device) | |||
| else: | |||
| self.__device: torch.device = torch.device( | |||
| "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu" | |||
| "cuda" | |||
| if torch.cuda.is_available() and torch.cuda.device_count() > 0 | |||
| else "cpu" | |||
| ) | |||
| self.init: bool = init | |||
| self.__feval: _typing.Sequence[_typing.Type[Evaluation]] = get_feval(feval) | |||
| self.loss: str = loss | |||
@property
def device(self) -> torch.device:
    """The device stored on this trainer."""
    return self.__device

@device.setter
def device(self, __device: _typing.Union[torch.device, str]):
    """
    Resolve and store the trainer's device.

    A ``torch.device`` instance, or any string other than ``"auto"``
    (compared case-insensitively, ignoring surrounding whitespace), is
    used as given. Every other value, including ``"auto"`` and ``None``,
    selects CUDA when it is available and the CPU otherwise.
    """
    if isinstance(__device, str):
        # Normalize once so " auto " means auto-detect and " cpu " is a
        # valid explicit name, instead of failing inside torch.device().
        __device = __device.strip()
    is_explicit = isinstance(__device, torch.device) or (
        isinstance(__device, str) and __device.lower() != "auto"
    )
    if is_explicit:
        self.__device: torch.device = torch.device(__device)
    else:
        cuda_usable = torch.cuda.is_available() and torch.cuda.device_count() > 0
        self.__device: torch.device = torch.device("cuda" if cuda_usable else "cpu")
@property
def feval(self) -> _typing.Sequence[_typing.Type[Evaluation]]:
    """The evaluation classes stored on this trainer."""
    return self.__feval

@feval.setter
def feval(
    self,
    _feval: _typing.Union[
        _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
    ],
):
    """Store the result of passing ``_feval`` through ``get_feval``."""
    resolved: _typing.Sequence[_typing.Type[Evaluation]] = get_feval(_feval)
    self.__feval = resolved
| def to(self, device: torch.device): | |||
| """ | |||
| Transfer the trainer to another device | |||
| @@ -168,7 +170,9 @@ class BaseTrainer: | |||
| """Get auto model used in trainer.""" | |||
| raise NotImplementedError() | |||
| def get_feval(self, return_major: bool = False) -> _typing.Union[ | |||
| def get_feval( | |||
| self, return_major: bool = False | |||
| ) -> _typing.Union[ | |||
| _typing.Type[Evaluation], _typing.Sequence[_typing.Type[Evaluation]] | |||
| ]: | |||
| """ | |||
| @@ -212,7 +216,7 @@ class BaseTrainer: | |||
| pass | |||
def duplicate_from_hyper_parameter(
    self, hp, model: _typing.Optional[BaseModel] = ...
) -> "BaseTrainer":
    """
    Create a new trainer with the given hyper parameter.

    This base implementation always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
| @@ -322,30 +326,30 @@ class BaseTrainer: | |||
| class _BaseClassificationTrainer(BaseTrainer): | |||
| """ Base class of trainer for classification tasks """ | |||
| def __init__( | |||
| self, | |||
| model: _typing.Union[BaseModel, str], | |||
| num_features: int, | |||
| num_classes: int, | |||
| device: _typing.Union[torch.device, str, None] = "auto", | |||
| init: bool = True, | |||
| feval: _typing.Union[ | |||
| _typing.Sequence[str], | |||
| _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = (Acc,), | |||
| loss: str = "nll_loss", | |||
| self, | |||
| model: _typing.Union[BaseModel, str], | |||
| num_features: int, | |||
| num_classes: int, | |||
| device: _typing.Union[torch.device, str, None] = "auto", | |||
| init: bool = True, | |||
| feval: _typing.Union[ | |||
| _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = (Acc,), | |||
| loss: str = "nll_loss", | |||
| ): | |||
| self.num_features: int = num_features | |||
| self.num_classes: int = num_classes | |||
| if ( | |||
| type(device) == torch.device or | |||
| (type(device) == str and device.lower() != "auto") | |||
| if type(device) == torch.device or ( | |||
| type(device) == str and device.lower() != "auto" | |||
| ): | |||
| __device: torch.device = torch.device(device) | |||
| else: | |||
| __device: torch.device = torch.device( | |||
| "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu" | |||
| "cuda" | |||
| if torch.cuda.is_available() and torch.cuda.device_count() > 0 | |||
| else "cpu" | |||
| ) | |||
| if type(model) == str: | |||
| _model: BaseModel = ModelUniversalRegistry.get_model(model)( | |||
| @@ -357,22 +361,23 @@ class _BaseClassificationTrainer(BaseTrainer): | |||
| raise TypeError( | |||
| f"Model argument only support str or BaseModel, got ${model}." | |||
| ) | |||
| super(_BaseClassificationTrainer, self).__init__(_model, __device, init, feval, loss) | |||
| super(_BaseClassificationTrainer, self).__init__( | |||
| _model, __device, init, feval, loss | |||
| ) | |||
| class BaseNodeClassificationTrainer(_BaseClassificationTrainer): | |||
| def __init__( | |||
| self, | |||
| model: _typing.Union[BaseModel, str], | |||
| num_features: int, | |||
| num_classes: int, | |||
| device: _typing.Union[torch.device, str, None] = None, | |||
| init: bool = True, | |||
| feval: _typing.Union[ | |||
| _typing.Sequence[str], | |||
| _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = (Acc,), | |||
| loss: str = "nll_loss", | |||
| self, | |||
| model: _typing.Union[BaseModel, str], | |||
| num_features: int, | |||
| num_classes: int, | |||
| device: _typing.Union[torch.device, str, None] = None, | |||
| init: bool = True, | |||
| feval: _typing.Union[ | |||
| _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = (Acc,), | |||
| loss: str = "nll_loss", | |||
| ): | |||
| super(BaseNodeClassificationTrainer, self).__init__( | |||
| model, num_features, num_classes, device, init, feval, loss | |||
| @@ -381,18 +386,17 @@ class BaseNodeClassificationTrainer(_BaseClassificationTrainer): | |||
| class BaseGraphClassificationTrainer(_BaseClassificationTrainer): | |||
| def __init__( | |||
| self, | |||
| model: _typing.Union[BaseModel, str], | |||
| num_features: int, | |||
| num_classes: int, | |||
| num_graph_features: int = 0, | |||
| device: _typing.Union[torch.device, str, None] = None, | |||
| init: bool = True, | |||
| feval: _typing.Union[ | |||
| _typing.Sequence[str], | |||
| _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = (Acc,), | |||
| loss: str = "nll_loss", | |||
| self, | |||
| model: _typing.Union[BaseModel, str], | |||
| num_features: int, | |||
| num_classes: int, | |||
| num_graph_features: int = 0, | |||
| device: _typing.Union[torch.device, str, None] = None, | |||
| init: bool = True, | |||
| feval: _typing.Union[ | |||
| _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = (Acc,), | |||
| loss: str = "nll_loss", | |||
| ): | |||
| self.num_graph_features: int = num_graph_features | |||
| super(BaseGraphClassificationTrainer, self).__init__( | |||
| @@ -13,12 +13,12 @@ class Evaluation: | |||
| def get_eval_name() -> str: | |||
| """ Expected to return the name of this evaluation method """ | |||
| raise NotImplementedError | |||
@staticmethod
def is_higher_better() -> bool:
    """ Expected to return whether this evaluation method is higher better (bool) """
    # This base implementation reports True.
    higher_is_better = True
    return higher_is_better
| @staticmethod | |||
| def evaluate(predict, label) -> float: | |||
| """ Expected to return the evaluation result (float) """ | |||
| @@ -39,6 +39,7 @@ def register_evaluate(*name): | |||
| ) | |||
| EVALUATE_DICT[n] = cls | |||
| return cls | |||
| return register_evaluate_cls | |||
| @@ -54,22 +55,26 @@ def get_feval(feval): | |||
| class EvaluationUniversalRegistry: | |||
@classmethod
def register_evaluation(
    cls, *names
) -> _typing.Callable[[_typing.Type[Evaluation]], _typing.Type[Evaluation]]:
    """
    Build a class decorator that registers an evaluator under ``names``.

    :param names: keys under which the class is stored in ``EVALUATE_DICT``
    :return: a decorator that registers the class and returns it unchanged;
        it raises ``ValueError`` if a name is already registered or if the
        class does not extend ``Evaluation``
    """

    def _register_evaluation(
        _class: _typing.Type[Evaluation],
    ) -> _typing.Type[Evaluation]:
        for n in names:
            if n in EVALUATE_DICT:
                raise ValueError(
                    "Cannot register duplicate evaluator ({})".format(n)
                )
            if not issubclass(_class, Evaluation):
                # Name the rejected class itself; ``cls`` is the registry
                # class this classmethod is bound to, not the evaluator.
                raise ValueError(
                    "Evaluator ({}: {}) must extend Evaluation".format(
                        n, _class.__name__
                    )
                )
            EVALUATE_DICT[n] = _class
        return _class

    return _register_evaluation
| @@ -100,7 +100,7 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): | |||
| self.batch_size = batch_size if batch_size is not None else 64 | |||
| self.num_workers = num_workers if num_workers is not None else 4 | |||
| if self.num_workers > 0: | |||
| mp.set_start_method('fork', force=True) | |||
| mp.set_start_method("fork", force=True) | |||
| self.early_stopping_round = ( | |||
| early_stopping_round if early_stopping_round is not None else 100 | |||
| ) | |||
| @@ -305,10 +305,10 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): | |||
| """ | |||
| train_loader = utils.graph_get_split( | |||
| dataset, "train", batch_size=self.batch_size, num_workers = self.num_workers | |||
| dataset, "train", batch_size=self.batch_size, num_workers=self.num_workers | |||
| ) # DataLoader(dataset['train'], batch_size=self.batch_size) | |||
| valid_loader = utils.graph_get_split( | |||
| dataset, "val", batch_size=self.batch_size, num_workers = self.num_workers | |||
| dataset, "val", batch_size=self.batch_size, num_workers=self.num_workers | |||
| ) # DataLoader(dataset['val'], batch_size=self.batch_size) | |||
| self.train_only(train_loader, valid_loader) | |||
| if keep_valid_result and valid_loader: | |||
| @@ -332,7 +332,9 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): | |||
| ------- | |||
| The prediction result of ``predict_proba``. | |||
| """ | |||
| loader = utils.graph_get_split(dataset, mask, batch_size=self.batch_size, num_workers = self.num_workers) | |||
| loader = utils.graph_get_split( | |||
| dataset, mask, batch_size=self.batch_size, num_workers=self.num_workers | |||
| ) | |||
| return self._predict_proba(loader, in_log_format=True).max(1)[1] | |||
| def predict_proba(self, dataset, mask="test", in_log_format=False): | |||
| @@ -353,7 +355,9 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): | |||
| ------- | |||
| The prediction result. | |||
| """ | |||
| loader = utils.graph_get_split(dataset, mask, batch_size=self.batch_size, num_workers = self.num_workers) | |||
| loader = utils.graph_get_split( | |||
| dataset, mask, batch_size=self.batch_size, num_workers=self.num_workers | |||
| ) | |||
| return self._predict_proba(loader, in_log_format) | |||
| def _predict_proba(self, loader, in_log_format=False): | |||
| @@ -436,7 +440,9 @@ class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): | |||
| res: The evaluation result on the given dataset. | |||
| """ | |||
| loader = utils.graph_get_split(dataset, mask, batch_size=self.batch_size, num_workers = self.num_workers) | |||
| loader = utils.graph_get_split( | |||
| dataset, mask, batch_size=self.batch_size, num_workers=self.num_workers | |||
| ) | |||
| return self._evaluate(loader, feval) | |||
| def _evaluate(self, loader, feval=None): | |||
| @@ -21,79 +21,90 @@ class NodeClassificationNeighborSamplingTrainer(BaseNodeClassificationTrainer): | |||
| for automatically training the node classification tasks | |||
| with neighbour sampling | |||
| """ | |||
| def __init__( | |||
| self, | |||
| model: _typing.Union[BaseModel, str], | |||
| num_features: int, | |||
| num_classes: int, | |||
| optimizer: _typing.Union[ | |||
| _typing.Type[torch.optim.Optimizer], str, None | |||
| ] = None, | |||
| lr: float = 1e-4, | |||
| max_epoch: int = 100, | |||
| early_stopping_round: int = 100, | |||
| weight_decay: float = 1e-4, | |||
| device: _typing.Optional[torch.device] = None, | |||
| init: bool = True, | |||
| feval: _typing.Union[ | |||
| _typing.Sequence[str], | |||
| _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = (Logloss,), | |||
| loss: str = "nll_loss", | |||
| lr_scheduler_type: _typing.Optional[str] = None, | |||
| **kwargs | |||
| self, | |||
| model: _typing.Union[BaseModel, str], | |||
| num_features: int, | |||
| num_classes: int, | |||
| optimizer: _typing.Union[_typing.Type[torch.optim.Optimizer], str, None] = None, | |||
| lr: float = 1e-4, | |||
| max_epoch: int = 100, | |||
| early_stopping_round: int = 100, | |||
| weight_decay: float = 1e-4, | |||
| device: _typing.Optional[torch.device] = None, | |||
| init: bool = True, | |||
| feval: _typing.Union[ | |||
| _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = (Logloss,), | |||
| loss: str = "nll_loss", | |||
| lr_scheduler_type: _typing.Optional[str] = None, | |||
| **kwargs, | |||
| ) -> None: | |||
| if isinstance(optimizer, type) and issubclass(optimizer, torch.optim.Optimizer): | |||
| self._optimizer_class: _typing.Type[torch.optim.Optimizer] = optimizer | |||
| elif type(optimizer) == str: | |||
| if optimizer.lower() == "adam": | |||
| self._optimizer_class: _typing.Type[torch.optim.Optimizer] = torch.optim.Adam | |||
| self._optimizer_class: _typing.Type[ | |||
| torch.optim.Optimizer | |||
| ] = torch.optim.Adam | |||
| elif optimizer.lower() == "adam" + "w": | |||
| self._optimizer_class: _typing.Type[torch.optim.Optimizer] = torch.optim.AdamW | |||
| self._optimizer_class: _typing.Type[ | |||
| torch.optim.Optimizer | |||
| ] = torch.optim.AdamW | |||
| elif optimizer.lower() == "sgd": | |||
| self._optimizer_class: _typing.Type[torch.optim.Optimizer] = torch.optim.SGD | |||
| self._optimizer_class: _typing.Type[ | |||
| torch.optim.Optimizer | |||
| ] = torch.optim.SGD | |||
| else: | |||
| self._optimizer_class: _typing.Type[torch.optim.Optimizer] = torch.optim.Adam | |||
| self._optimizer_class: _typing.Type[ | |||
| torch.optim.Optimizer | |||
| ] = torch.optim.Adam | |||
| else: | |||
| self._optimizer_class: _typing.Type[torch.optim.Optimizer] = torch.optim.Adam | |||
| self._optimizer_class: _typing.Type[ | |||
| torch.optim.Optimizer | |||
| ] = torch.optim.Adam | |||
| self._learning_rate: float = lr if lr > 0 else 1e-4 | |||
| self._lr_scheduler_type: _typing.Optional[str] = lr_scheduler_type | |||
| self._max_epoch: int = max_epoch if max_epoch > 0 else 1e2 | |||
| self.__sampling_sizes: _typing.Sequence[int] = kwargs.get("sampling_sizes") | |||
| self._weight_decay: float = weight_decay if weight_decay > 0 else 1e-4 | |||
| early_stopping_round: int = early_stopping_round if early_stopping_round > 0 else 1e2 | |||
| self._early_stopping = EarlyStopping(patience=early_stopping_round, verbose=False) | |||
| early_stopping_round: int = ( | |||
| early_stopping_round if early_stopping_round > 0 else 1e2 | |||
| ) | |||
| self._early_stopping = EarlyStopping( | |||
| patience=early_stopping_round, verbose=False | |||
| ) | |||
| super(NodeClassificationNeighborSamplingTrainer, self).__init__( | |||
| model, num_features, num_classes, device, init, feval, loss | |||
| ) | |||
| self._valid_result: torch.Tensor = torch.zeros(0) | |||
| self._valid_result_prob: torch.Tensor = torch.zeros(0) | |||
| self._valid_score: _typing.Sequence[float] = [] | |||
| self._hyper_parameter_space: _typing.Sequence[_typing.Dict[str, _typing.Any]] = [] | |||
| self._hyper_parameter_space: _typing.Sequence[ | |||
| _typing.Dict[str, _typing.Any] | |||
| ] = [] | |||
| self.__initialized: bool = False | |||
| if init: | |||
| self.initialize() | |||
| def initialize(self) -> "NodeClassificationNeighborSamplingTrainer": | |||
| if self.__initialized: | |||
| return self | |||
| self.model.initialize() | |||
| self.__initialized = True | |||
| return self | |||
| def get_model(self) -> BaseModel: | |||
| return self.model | |||
| def __train_only( | |||
| self, data | |||
| ) -> "NodeClassificationNeighborSamplingTrainer": | |||
| def __train_only(self, data) -> "NodeClassificationNeighborSamplingTrainer": | |||
| """ | |||
| The function of training on the given dataset and mask. | |||
| :param data: data of a specific graph | |||
| @@ -102,38 +113,41 @@ class NodeClassificationNeighborSamplingTrainer(BaseNodeClassificationTrainer): | |||
| data = data.to(self.device) | |||
| optimizer: torch.optim.Optimizer = self._optimizer_class( | |||
| self.model.model.parameters(), | |||
| lr=self._learning_rate, weight_decay=self._weight_decay | |||
| lr=self._learning_rate, | |||
| weight_decay=self._weight_decay, | |||
| ) | |||
| if type(self._lr_scheduler_type) == str: | |||
| if self._lr_scheduler_type.lower() == "step" + "lr": | |||
| lr_scheduler: torch.optim.lr_scheduler.StepLR = \ | |||
| torch.optim.lr_scheduler.StepLR( | |||
| optimizer, step_size=100, gamma=0.1 | |||
| ) | |||
| lr_scheduler: torch.optim.lr_scheduler.StepLR = ( | |||
| torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1) | |||
| ) | |||
| elif self._lr_scheduler_type.lower() == "multi" + "step" + "lr": | |||
| lr_scheduler: torch.optim.lr_scheduler.MultiStepLR = \ | |||
| lr_scheduler: torch.optim.lr_scheduler.MultiStepLR = ( | |||
| torch.optim.lr_scheduler.MultiStepLR( | |||
| optimizer, milestones=[30, 80], gamma=0.1 | |||
| ) | |||
| ) | |||
| elif self._lr_scheduler_type.lower() == "exponential" + "lr": | |||
| lr_scheduler: torch.optim.lr_scheduler.ExponentialLR = \ | |||
| torch.optim.lr_scheduler.ExponentialLR( | |||
| optimizer, gamma=0.1 | |||
| ) | |||
| lr_scheduler: torch.optim.lr_scheduler.ExponentialLR = ( | |||
| torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1) | |||
| ) | |||
| elif self._lr_scheduler_type.lower() == "ReduceLROnPlateau".lower(): | |||
| lr_scheduler: torch.optim.lr_scheduler.ReduceLROnPlateau = \ | |||
| lr_scheduler: torch.optim.lr_scheduler.ReduceLROnPlateau = ( | |||
| torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min") | |||
| ) | |||
| else: | |||
| lr_scheduler: torch.optim.lr_scheduler.LambdaLR = \ | |||
| lr_scheduler: torch.optim.lr_scheduler.LambdaLR = ( | |||
| torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0) | |||
| ) | |||
| else: | |||
| lr_scheduler: torch.optim.lr_scheduler.LambdaLR = \ | |||
| lr_scheduler: torch.optim.lr_scheduler.LambdaLR = ( | |||
| torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0) | |||
| ) | |||
| train_sampler: NeighborSampler = NeighborSampler( | |||
| data, self.__sampling_sizes, batch_size=20 | |||
| ) | |||
| for current_epoch in range(self._max_epoch): | |||
| self.model.model.train() | |||
| """ epoch start """ | |||
| @@ -147,20 +161,20 @@ class NodeClassificationNeighborSamplingTrainer(BaseNodeClassificationTrainer): | |||
| ) | |||
| loss_function = getattr(torch.nn.functional, self.loss) | |||
| loss: torch.Tensor = loss_function( | |||
| prediction[target_node_indexes], | |||
| data.y[target_node_indexes] | |||
| prediction[target_node_indexes], data.y[target_node_indexes] | |||
| ) | |||
| loss.backward() | |||
| optimizer.step() | |||
| if lr_scheduler is not None: | |||
| lr_scheduler.step() | |||
| """ Validate performance """ | |||
| if hasattr(data, "val_mask") and getattr(data, "val_mask") is not None: | |||
| validation_results: _typing.Sequence[float] = \ | |||
| self.evaluate((data,), "val", [self.feval[0]]) | |||
| validation_results: _typing.Sequence[float] = self.evaluate( | |||
| (data,), "val", [self.feval[0]] | |||
| ) | |||
| if self.feval[0].is_higher_better(): | |||
| validation_loss: float = -validation_results[0] | |||
| else: | |||
| @@ -172,7 +186,7 @@ class NodeClassificationNeighborSamplingTrainer(BaseNodeClassificationTrainer): | |||
| if hasattr(data, "val_mask") and data.val_mask is not None: | |||
| self._early_stopping.load_checkpoint(self.model.model) | |||
| return self | |||
| def __predict_only(self, data): | |||
| """ | |||
| The function of predicting on the given data. | |||
| @@ -184,7 +198,7 @@ class NodeClassificationNeighborSamplingTrainer(BaseNodeClassificationTrainer): | |||
| with torch.no_grad(): | |||
| prediction = self.model.model(data) | |||
| return prediction | |||
| def train(self, dataset, keep_valid_result: bool = True): | |||
| """ | |||
| The function of training on the given dataset and keeping valid result. | |||
| @@ -198,10 +212,9 @@ class NodeClassificationNeighborSamplingTrainer(BaseNodeClassificationTrainer): | |||
| self._valid_result: torch.Tensor = prediction[data.val_mask].max(1)[1] | |||
| self._valid_result_prob: torch.Tensor = prediction[data.val_mask] | |||
| self._valid_score = self.evaluate(dataset, "val") | |||
| def predict_proba( | |||
| self, dataset, mask: _typing.Optional[str] = None, | |||
| in_log_format: bool = False | |||
| self, dataset, mask: _typing.Optional[str] = None, in_log_format: bool = False | |||
| ) -> torch.Tensor: | |||
| """ | |||
| The function of predicting the probability on the given dataset. | |||
| @@ -224,29 +237,22 @@ class NodeClassificationNeighborSamplingTrainer(BaseNodeClassificationTrainer): | |||
| _mask = data.test_mask | |||
| result = self.__predict_only(data)[_mask] | |||
| return result if in_log_format else torch.exp(result) | |||
| def predict(self, dataset, mask: _typing.Optional[str] = None) -> torch.Tensor: | |||
| return self.predict_proba( | |||
| dataset, mask, in_log_format=True | |||
| ).max(1)[1] | |||
| return self.predict_proba(dataset, mask, in_log_format=True).max(1)[1] | |||
| def get_valid_predict(self) -> torch.Tensor: | |||
| return self._valid_result | |||
| def get_valid_predict_proba(self) -> torch.Tensor: | |||
| return self._valid_result_prob | |||
| def get_valid_score(self, return_major: bool = True): | |||
| if return_major: | |||
| return ( | |||
| self._valid_score[0], | |||
| self.feval[0].is_higher_better() | |||
| ) | |||
| return (self._valid_score[0], self.feval[0].is_higher_better()) | |||
| else: | |||
| return ( | |||
| self._valid_score, [f.is_higher_better() for f in self.feval] | |||
| ) | |||
| return (self._valid_score, [f.is_higher_better() for f in self.feval]) | |||
| def get_name_with_hp(self) -> str: | |||
| name = "-".join( | |||
| [ | |||
| @@ -259,25 +265,24 @@ class NodeClassificationNeighborSamplingTrainer(BaseNodeClassificationTrainer): | |||
| ] | |||
| ) | |||
| name = ( | |||
| name | |||
| + "|" | |||
| + "-".join( | |||
| [ | |||
| str(x[0]) + "-" + str(x[1]) | |||
| for x in self.model.get_hyper_parameter().items() | |||
| ] | |||
| ) | |||
| name | |||
| + "|" | |||
| + "-".join( | |||
| [ | |||
| str(x[0]) + "-" + str(x[1]) | |||
| for x in self.model.get_hyper_parameter().items() | |||
| ] | |||
| ) | |||
| ) | |||
| return name | |||
| def evaluate( | |||
| self, | |||
| dataset, | |||
| mask: _typing.Optional[str] = None, | |||
| feval: _typing.Union[ | |||
| None, _typing.Sequence[str], | |||
| _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = None | |||
| self, | |||
| dataset, | |||
| mask: _typing.Optional[str] = None, | |||
| feval: _typing.Union[ | |||
| None, _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = None, | |||
| ) -> _typing.Sequence[float]: | |||
| data = dataset[0] | |||
| data = data.to(self.device) | |||
| @@ -295,53 +300,60 @@ class NodeClassificationNeighborSamplingTrainer(BaseNodeClassificationTrainer): | |||
| _mask = data.test_mask | |||
| prediction_probability: torch.Tensor = self.predict_proba(dataset, mask) | |||
| y_ground_truth = data.y[_mask] | |||
| results = [] | |||
| for f in _feval: | |||
| try: | |||
| results.append( | |||
| f.evaluate(prediction_probability, y_ground_truth) | |||
| ) | |||
| results.append(f.evaluate(prediction_probability, y_ground_truth)) | |||
| except: | |||
| results.append( | |||
| f.evaluate(prediction_probability.cpu().numpy(), y_ground_truth.cpu().numpy()) | |||
| f.evaluate( | |||
| prediction_probability.cpu().numpy(), | |||
| y_ground_truth.cpu().numpy(), | |||
| ) | |||
| ) | |||
| return results | |||
| def to(self, device: torch.device): | |||
| self.device = device | |||
| if self.model is not None: | |||
| self.model.to(self.device) | |||
| def duplicate_from_hyper_parameter( | |||
| self, hp: _typing.Dict[str, _typing.Any], | |||
| model: _typing.Union[BaseModel, str, None] = None | |||
| self, | |||
| hp: _typing.Dict[str, _typing.Any], | |||
| model: _typing.Union[BaseModel, str, None] = None, | |||
| ) -> "NodeClassificationNeighborSamplingTrainer": | |||
| if model is None or not isinstance(model, BaseModel): | |||
| model = self.model | |||
| model = model.from_hyper_parameter( | |||
| dict( | |||
| [ | |||
| x for x in hp.items() | |||
| x | |||
| for x in hp.items() | |||
| if x[0] in [y["parameterName"] for y in model.hyper_parameter_space] | |||
| ] | |||
| ) | |||
| ) | |||
| return NodeClassificationNeighborSamplingTrainer( | |||
| model, self.num_features, self.num_classes, | |||
| model, | |||
| self.num_features, | |||
| self.num_classes, | |||
| self._optimizer_class, | |||
| device=self.device, init=True, | |||
| feval=self.feval, loss=self.loss, | |||
| device=self.device, | |||
| init=True, | |||
| feval=self.feval, | |||
| loss=self.loss, | |||
| lr_scheduler_type=self._lr_scheduler_type, | |||
| **hp | |||
| **hp, | |||
| ) | |||
| @property | |||
| def hyper_parameter_space(self): | |||
| return self._hyper_parameter_space | |||
| @hyper_parameter_space.setter | |||
| def hyper_parameter_space(self, hp_space): | |||
| self._hyper_parameter_space = hp_space | |||
| @@ -350,50 +362,63 @@ class NodeClassificationNeighborSamplingTrainer(BaseNodeClassificationTrainer): | |||
| @register_trainer("NodeClassificationGraphSAINTTrainer") | |||
| class NodeClassificationGraphSAINTTrainer(BaseNodeClassificationTrainer): | |||
| def __init__( | |||
| self, | |||
| model: _typing.Union[BaseModel], | |||
| num_features: int, | |||
| num_classes: int, | |||
| optimizer: _typing.Union[ | |||
| _typing.Type[torch.optim.Optimizer], str, None | |||
| ], | |||
| lr: float = 1e-4, | |||
| max_epoch: int = 100, | |||
| early_stopping_round: int = 100, | |||
| weight_decay: float = 1e-4, | |||
| device: _typing.Optional[torch.device] = None, | |||
| init: bool = True, | |||
| feval: _typing.Union[ | |||
| _typing.Sequence[str], | |||
| _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = (Logloss,), | |||
| loss: str = "nll_loss", | |||
| lr_scheduler_type: _typing.Optional[str] = None, | |||
| **kwargs | |||
| self, | |||
| model: _typing.Union[BaseModel], | |||
| num_features: int, | |||
| num_classes: int, | |||
| optimizer: _typing.Union[_typing.Type[torch.optim.Optimizer], str, None], | |||
| lr: float = 1e-4, | |||
| max_epoch: int = 100, | |||
| early_stopping_round: int = 100, | |||
| weight_decay: float = 1e-4, | |||
| device: _typing.Optional[torch.device] = None, | |||
| init: bool = True, | |||
| feval: _typing.Union[ | |||
| _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = (Logloss,), | |||
| loss: str = "nll_loss", | |||
| lr_scheduler_type: _typing.Optional[str] = None, | |||
| **kwargs, | |||
| ) -> None: | |||
| if isinstance(optimizer, type) and issubclass(optimizer, torch.optim.Optimizer): | |||
| self._optimizer_class: _typing.Type[torch.optim.Optimizer] = optimizer | |||
| elif type(optimizer) == str: | |||
| if optimizer.lower() == "adam": | |||
| self._optimizer_class: _typing.Type[torch.optim.Optimizer] = torch.optim.Adam | |||
| self._optimizer_class: _typing.Type[ | |||
| torch.optim.Optimizer | |||
| ] = torch.optim.Adam | |||
| elif optimizer.lower() == "adam" + "w": | |||
| self._optimizer_class: _typing.Type[torch.optim.Optimizer] = torch.optim.AdamW | |||
| self._optimizer_class: _typing.Type[ | |||
| torch.optim.Optimizer | |||
| ] = torch.optim.AdamW | |||
| elif optimizer.lower() == "sgd": | |||
| self._optimizer_class: _typing.Type[torch.optim.Optimizer] = torch.optim.SGD | |||
| self._optimizer_class: _typing.Type[ | |||
| torch.optim.Optimizer | |||
| ] = torch.optim.SGD | |||
| else: | |||
| self._optimizer_class: _typing.Type[torch.optim.Optimizer] = torch.optim.Adam | |||
| self._optimizer_class: _typing.Type[ | |||
| torch.optim.Optimizer | |||
| ] = torch.optim.Adam | |||
| else: | |||
| self._optimizer_class: _typing.Type[torch.optim.Optimizer] = torch.optim.Adam | |||
| self._optimizer_class: _typing.Type[ | |||
| torch.optim.Optimizer | |||
| ] = torch.optim.Adam | |||
| self._learning_rate: float = lr if lr > 0 else 1e-4 | |||
| self._lr_scheduler_type: _typing.Optional[str] = lr_scheduler_type | |||
| self._max_epoch: int = max_epoch if max_epoch > 0 else 1e2 | |||
| self._weight_decay: float = weight_decay if weight_decay > 0 else 1e-4 | |||
| early_stopping_round: int = early_stopping_round if early_stopping_round > 0 else 1e2 | |||
| self._early_stopping = EarlyStopping(patience=early_stopping_round, verbose=False) | |||
| early_stopping_round: int = ( | |||
| early_stopping_round if early_stopping_round > 0 else 1e2 | |||
| ) | |||
| self._early_stopping = EarlyStopping( | |||
| patience=early_stopping_round, verbose=False | |||
| ) | |||
| # Assign an empty initial hyper parameter space | |||
| self._hyper_parameter_space: _typing.Sequence[_typing.Dict[str, _typing.Any]] = [] | |||
| self._hyper_parameter_space: _typing.Sequence[ | |||
| _typing.Dict[str, _typing.Any] | |||
| ] = [] | |||
| self._valid_result: torch.Tensor = torch.zeros(0) | |||
| self._valid_result_prob: torch.Tensor = torch.zeros(0) | |||
| self._valid_score: _typing.Sequence[float] = () | |||
| @@ -401,7 +426,7 @@ class NodeClassificationGraphSAINTTrainer(BaseNodeClassificationTrainer): | |||
| super(NodeClassificationGraphSAINTTrainer, self).__init__( | |||
| model, num_features, num_classes, device, init, feval, loss | |||
| ) | |||
| """ Set hyper parameters """ | |||
| if "num_subgraphs" not in kwargs: | |||
| raise KeyError | |||
| @@ -427,23 +452,23 @@ class NodeClassificationGraphSAINTTrainer(BaseNodeClassificationTrainer): | |||
| self.__sampling_method_identifier: str = kwargs.get("sampling_method") | |||
| if self.__sampling_method_identifier.lower() not in ("node", "edge"): | |||
| self.__sampling_method_identifier: str = "node" | |||
| self.__is_initialized: bool = False | |||
| if init: | |||
| self.initialize() | |||
| def initialize(self): | |||
| if self.__is_initialized: | |||
| return self | |||
| self.model.initialize() | |||
| self.__is_initialized = True | |||
| return self | |||
| def to(self, device: torch.device): | |||
| self.device = device | |||
| if self.model is not None: | |||
| self.model.to(self.device) | |||
| def get_model(self): | |||
| return self.model | |||
| @@ -456,34 +481,37 @@ class NodeClassificationGraphSAINTTrainer(BaseNodeClassificationTrainer): | |||
| data = data.to(self.device) | |||
| optimizer: torch.optim.Optimizer = self._optimizer_class( | |||
| self.model.parameters(), | |||
| lr=self._learning_rate, weight_decay=self._weight_decay | |||
| lr=self._learning_rate, | |||
| weight_decay=self._weight_decay, | |||
| ) | |||
| if type(self._lr_scheduler_type) == str: | |||
| if self._lr_scheduler_type.lower() == "step" + "lr": | |||
| lr_scheduler: torch.optim.lr_scheduler.StepLR = \ | |||
| torch.optim.lr_scheduler.StepLR( | |||
| optimizer, step_size=100, gamma=0.1 | |||
| ) | |||
| lr_scheduler: torch.optim.lr_scheduler.StepLR = ( | |||
| torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1) | |||
| ) | |||
| elif self._lr_scheduler_type.lower() == "multi" + "step" + "lr": | |||
| lr_scheduler: torch.optim.lr_scheduler.MultiStepLR = \ | |||
| lr_scheduler: torch.optim.lr_scheduler.MultiStepLR = ( | |||
| torch.optim.lr_scheduler.MultiStepLR( | |||
| optimizer, milestones=[30, 80], gamma=0.1 | |||
| ) | |||
| ) | |||
| elif self._lr_scheduler_type.lower() == "exponential" + "lr": | |||
| lr_scheduler: torch.optim.lr_scheduler.ExponentialLR = \ | |||
| torch.optim.lr_scheduler.ExponentialLR( | |||
| optimizer, gamma=0.1 | |||
| ) | |||
| lr_scheduler: torch.optim.lr_scheduler.ExponentialLR = ( | |||
| torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1) | |||
| ) | |||
| elif self._lr_scheduler_type.lower() == "ReduceLROnPlateau".lower(): | |||
| lr_scheduler: torch.optim.lr_scheduler.ReduceLROnPlateau = \ | |||
| lr_scheduler: torch.optim.lr_scheduler.ReduceLROnPlateau = ( | |||
| torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min") | |||
| ) | |||
| else: | |||
| lr_scheduler: torch.optim.lr_scheduler.LambdaLR = \ | |||
| lr_scheduler: torch.optim.lr_scheduler.LambdaLR = ( | |||
| torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0) | |||
| ) | |||
| else: | |||
| lr_scheduler: torch.optim.lr_scheduler.LambdaLR = \ | |||
| lr_scheduler: torch.optim.lr_scheduler.LambdaLR = ( | |||
| torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0) | |||
| ) | |||
| if self.__sampling_method_identifier.lower() == "edge": | |||
| sub_graph_sampler = GraphSAINTRandomEdgeSampler( | |||
| self.__sampling_budget, self.__num_subgraphs | |||
| @@ -492,60 +520,58 @@ class NodeClassificationGraphSAINTTrainer(BaseNodeClassificationTrainer): | |||
| sub_graph_sampler = GraphSAINTRandomNodeSampler( | |||
| self.__sampling_budget, self.__num_subgraphs | |||
| ) | |||
| for current_epoch in range(self._max_epoch): | |||
| self.model.model.train() | |||
| """ epoch start """ | |||
| """ Sample sub-graphs """ | |||
| sub_graph_set = sub_graph_sampler.sample(data) | |||
| sub_graphs_loader: torch.utils.data.DataLoader = \ | |||
| sub_graphs_loader: torch.utils.data.DataLoader = ( | |||
| torch.utils.data.DataLoader(sub_graph_set) | |||
| ) | |||
| integral_alpha: torch.Tensor = getattr(sub_graph_set, "alpha") | |||
| integral_lambda: torch.Tensor = getattr(sub_graph_set, "lambda") | |||
| """ iterate sub-graphs """ | |||
| for sub_graph_data in sub_graphs_loader: | |||
| optimizer.zero_grad() | |||
| sampled_edge_indexes: torch.Tensor = \ | |||
| sub_graph_data.sampled_edge_indexes | |||
| sampled_node_indexes: torch.Tensor = \ | |||
| sub_graph_data.sampled_node_indexes | |||
| sampled_train_mask: torch.Tensor = \ | |||
| sub_graph_data.train_mask | |||
| sampled_edge_indexes: torch.Tensor = sub_graph_data.sampled_edge_indexes | |||
| sampled_node_indexes: torch.Tensor = sub_graph_data.sampled_node_indexes | |||
| sampled_train_mask: torch.Tensor = sub_graph_data.train_mask | |||
| sampled_alpha = integral_alpha[sampled_edge_indexes] | |||
| sub_graph_data.edge_weight = 1 / sampled_alpha | |||
| prediction: torch.Tensor = self.model.model(sub_graph_data) | |||
| if not hasattr(torch.nn.functional, self.loss): | |||
| raise TypeError( | |||
| f"PyTorch does not support loss type {self.loss}" | |||
| ) | |||
| raise TypeError(f"PyTorch does not support loss type {self.loss}") | |||
| loss_func = getattr(torch.nn.functional, self.loss) | |||
| unreduced_loss: torch.Tensor = loss_func( | |||
| prediction[sampled_train_mask], | |||
| data.y[sampled_train_mask], | |||
| reduction="none" | |||
| reduction="none", | |||
| ) | |||
| sampled_lambda: torch.Tensor = integral_lambda[sampled_node_indexes] | |||
| sampled_train_lambda: torch.Tensor = sampled_lambda[sampled_train_mask] | |||
| assert unreduced_loss.size() == sampled_train_lambda.size() | |||
| loss_weighted_sum: torch.Tensor = \ | |||
| torch.sum(unreduced_loss / sampled_train_lambda) | |||
| loss_weighted_sum: torch.Tensor = torch.sum( | |||
| unreduced_loss / sampled_train_lambda | |||
| ) | |||
| loss_weighted_sum.backward() | |||
| optimizer.step() | |||
| if lr_scheduler is not None: | |||
| lr_scheduler.step() | |||
| """ Validate performance """ | |||
| if ( | |||
| hasattr(data, "val_mask") and | |||
| type(getattr(data, "val_mask")) == torch.Tensor | |||
| hasattr(data, "val_mask") | |||
| and type(getattr(data, "val_mask")) == torch.Tensor | |||
| ): | |||
| validation_results: _typing.Sequence[float] = \ | |||
| self.evaluate((data,), "val", [self.feval[0]]) | |||
| validation_results: _typing.Sequence[float] = self.evaluate( | |||
| (data,), "val", [self.feval[0]] | |||
| ) | |||
| if self.feval[0].is_higher_better(): | |||
| validation_loss: float = -validation_results[0] | |||
| else: | |||
| @@ -557,7 +583,7 @@ class NodeClassificationGraphSAINTTrainer(BaseNodeClassificationTrainer): | |||
| if hasattr(data, "val_mask") and data.val_mask is not None: | |||
| self._early_stopping.load_checkpoint(self.model.model) | |||
| return self | |||
| def __predict_only(self, data): | |||
| """ | |||
| The function of predicting on the given data. | |||
| @@ -569,10 +595,9 @@ class NodeClassificationGraphSAINTTrainer(BaseNodeClassificationTrainer): | |||
| with torch.no_grad(): | |||
| predicted_x: torch.Tensor = self.model.model(data) | |||
| return predicted_x | |||
| def predict_proba( | |||
| self, dataset, mask: _typing.Optional[str] = None, | |||
| in_log_format=False | |||
| self, dataset, mask: _typing.Optional[str] = None, in_log_format=False | |||
| ): | |||
| """ | |||
| The function of predicting the probability on the given dataset. | |||
| @@ -595,17 +620,17 @@ class NodeClassificationGraphSAINTTrainer(BaseNodeClassificationTrainer): | |||
| _mask: torch.Tensor = data.test_mask | |||
| result = self.__predict_only(data)[_mask] | |||
| return result if in_log_format else torch.exp(result) | |||
| def predict(self, dataset, mask: _typing.Optional[str] = None) -> torch.Tensor: | |||
| return self.predict_proba(dataset, mask, in_log_format=True).max(1)[1] | |||
| def evaluate( | |||
| self, dataset, | |||
| mask: _typing.Optional[str] = None, | |||
| feval: _typing.Union[ | |||
| None, _typing.Sequence[str], | |||
| _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = None | |||
| self, | |||
| dataset, | |||
| mask: _typing.Optional[str] = None, | |||
| feval: _typing.Union[ | |||
| None, _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]] | |||
| ] = None, | |||
| ) -> _typing.Sequence[float]: | |||
| data = dataset[0] | |||
| data = data.to(self.device) | |||
| @@ -624,24 +649,22 @@ class NodeClassificationGraphSAINTTrainer(BaseNodeClassificationTrainer): | |||
| _mask: torch.Tensor = data.test_mask | |||
| else: | |||
| _mask: torch.Tensor = data.test_mask | |||
| prediction_probability: torch.Tensor = \ | |||
| self.predict_proba(dataset, mask) | |||
| prediction_probability: torch.Tensor = self.predict_proba(dataset, mask) | |||
| y_ground_truth: torch.Tensor = data.y[_mask] | |||
| eval_results = [] | |||
| for f in _feval: | |||
| try: | |||
| eval_results.append( | |||
| f.evaluate(prediction_probability, y_ground_truth) | |||
| ) | |||
| eval_results.append(f.evaluate(prediction_probability, y_ground_truth)) | |||
| except: | |||
| eval_results.append( | |||
| f.evaluate( | |||
| prediction_probability.cpu().numpy(), y_ground_truth.cpu().numpy() | |||
| prediction_probability.cpu().numpy(), | |||
| y_ground_truth.cpu().numpy(), | |||
| ) | |||
| ) | |||
| return eval_results | |||
| def train(self, dataset, keep_valid_result: bool = True): | |||
| """ | |||
| The function of training on the given dataset and keeping valid result. | |||
| @@ -655,36 +678,36 @@ class NodeClassificationGraphSAINTTrainer(BaseNodeClassificationTrainer): | |||
| self._valid_result: torch.Tensor = prediction[data.val_mask].max(1)[1] | |||
| self._valid_result_prob: torch.Tensor = prediction[data.val_mask] | |||
| self._valid_score: _typing.Sequence[float] = self.evaluate(dataset, "val") | |||
| def get_valid_predict(self) -> torch.Tensor: | |||
| return self._valid_result | |||
| def get_valid_predict_proba(self) -> torch.Tensor: | |||
| return self._valid_result_prob | |||
| def get_valid_score(self, return_major: bool = True) -> _typing.Tuple[ | |||
| def get_valid_score( | |||
| self, return_major: bool = True | |||
| ) -> _typing.Tuple[ | |||
| _typing.Union[float, _typing.Sequence[float]], | |||
| _typing.Union[bool, _typing.Sequence[bool]] | |||
| _typing.Union[bool, _typing.Sequence[bool]], | |||
| ]: | |||
| if return_major: | |||
| return self._valid_score[0], self.feval[0].is_higher_better() | |||
| else: | |||
| return ( | |||
| self._valid_score, [f.is_higher_better() for f in self.feval] | |||
| ) | |||
| return (self._valid_score, [f.is_higher_better() for f in self.feval]) | |||
| @property | |||
| def hyper_parameter_space(self) -> _typing.Sequence[_typing.Dict[str, _typing.Any]]: | |||
| return self._hyper_parameter_space | |||
| @hyper_parameter_space.setter | |||
| def hyper_parameter_space( | |||
| self, hp_space: _typing.Sequence[_typing.Dict[str, _typing.Any]] | |||
| self, hp_space: _typing.Sequence[_typing.Dict[str, _typing.Any]] | |||
| ) -> None: | |||
| if not isinstance(hp_space, _typing.Sequence): | |||
| raise TypeError | |||
| self._hyper_parameter_space = hp_space | |||
| def get_name_with_hp(self) -> str: | |||
| name = "-".join( | |||
| [ | |||
| @@ -697,36 +720,42 @@ class NodeClassificationGraphSAINTTrainer(BaseNodeClassificationTrainer): | |||
| ] | |||
| ) | |||
| name = ( | |||
| name | |||
| + "|" | |||
| + "-".join( | |||
| [ | |||
| str(x[0]) + "-" + str(x[1]) | |||
| for x in self.model.get_hyper_parameter().items() | |||
| ] | |||
| ) | |||
| name | |||
| + "|" | |||
| + "-".join( | |||
| [ | |||
| str(x[0]) + "-" + str(x[1]) | |||
| for x in self.model.get_hyper_parameter().items() | |||
| ] | |||
| ) | |||
| ) | |||
| return name | |||
| def duplicate_from_hyper_parameter( | |||
| self, hp: _typing.Dict[str, _typing.Any], | |||
| model: _typing.Optional[BaseModel] = None | |||
| self, | |||
| hp: _typing.Dict[str, _typing.Any], | |||
| model: _typing.Optional[BaseModel] = None, | |||
| ) -> "NodeClassificationGraphSAINTTrainer": | |||
| if model is None or not isinstance(model, BaseModel): | |||
| model: BaseModel = self.model | |||
| model = model.from_hyper_parameter( | |||
| dict( | |||
| [ | |||
| x for x in hp.items() | |||
| x | |||
| for x in hp.items() | |||
| if x[0] in [y["parameterName"] for y in model.hyper_parameter_space] | |||
| ] | |||
| ) | |||
| ) | |||
| return NodeClassificationGraphSAINTTrainer( | |||
| model, self.num_features, self.num_classes, | |||
| model, | |||
| self.num_features, | |||
| self.num_classes, | |||
| self._optimizer_class, | |||
| device=self.device, init=True, | |||
| feval=self.feval, loss=self.loss, | |||
| device=self.device, | |||
| init=True, | |||
| feval=self.feval, | |||
| loss=self.loss, | |||
| lr_scheduler_type=self._lr_scheduler_type, | |||
| **hp | |||
| **hp, | |||
| ) | |||
| @@ -10,10 +10,10 @@ class _SubGraphSet(torch.utils.data.Dataset): | |||
| self.__remaining_args: _typing.Sequence[_typing.Any] = args | |||
| for key, value in kwargs.items(): | |||
| setattr(self, key, value) | |||
| def __len__(self) -> int: | |||
| return len(self.__graphs) | |||
| def __getitem__(self, index: int) -> _typing.Any: | |||
| if not 0 <= index < len(self.__graphs): | |||
| raise IndexError | |||
| @@ -22,8 +22,12 @@ class _SubGraphSet(torch.utils.data.Dataset): | |||
| class _GraphSAINTSubGraphSampler: | |||
| def __init__( | |||
| self, sampler_class: _typing.Type[torch_geometric.data.GraphSAINTSampler], | |||
| budget: int, num_graphs: int = 1, walk_length: int = 1, num_workers: int = 0 | |||
| self, | |||
| sampler_class: _typing.Type[torch_geometric.data.GraphSAINTSampler], | |||
| budget: int, | |||
| num_graphs: int = 1, | |||
| walk_length: int = 1, | |||
| num_workers: int = 0, | |||
| ): | |||
| """ | |||
| :param sampler_class: class of torch_geometric.data.GraphSAINTSampler | |||
| @@ -40,7 +44,7 @@ class _GraphSAINTSubGraphSampler: | |||
| self.__num_graphs: int = num_graphs | |||
| self.__walk_length: int = walk_length | |||
| self.__num_workers: int = num_workers if num_workers > 0 else 0 | |||
| def sample(self, _integral_data) -> _SubGraphSet: | |||
| """ | |||
| :param _integral_data: conventional data for an integral graph | |||
| @@ -49,18 +53,23 @@ class _GraphSAINTSubGraphSampler: | |||
| data = copy.copy(_integral_data) | |||
| data.sampled_node_indexes = torch.arange(data.num_nodes, dtype=torch.int64) | |||
| data.sampled_edge_indexes = torch.arange(data.num_edges, dtype=torch.int64) | |||
| if type(self.__sampler_class) == torch_geometric.data.GraphSAINTRandomWalkSampler: | |||
| _sampler: torch_geometric.data.GraphSAINTRandomWalkSampler = \ | |||
| if ( | |||
| type(self.__sampler_class) | |||
| == torch_geometric.data.GraphSAINTRandomWalkSampler | |||
| ): | |||
| _sampler: torch_geometric.data.GraphSAINTRandomWalkSampler = ( | |||
| torch_geometric.data.GraphSAINTRandomWalkSampler( | |||
| data, self.__budget, self.__walk_length, self.__num_graphs, | |||
| num_workers=self.__num_workers | |||
| data, | |||
| self.__budget, | |||
| self.__walk_length, | |||
| self.__num_graphs, | |||
| num_workers=self.__num_workers, | |||
| ) | |||
| ) | |||
| else: | |||
| _sampler: torch_geometric.data.GraphSAINTSampler = \ | |||
| self.__sampler_class( | |||
| data, self.__budget, self.__num_graphs, | |||
| num_workers=self.__num_workers | |||
| ) | |||
| _sampler: torch_geometric.data.GraphSAINTSampler = self.__sampler_class( | |||
| data, self.__budget, self.__num_graphs, num_workers=self.__num_workers | |||
| ) | |||
| """ Sample sub-graphs """ | |||
| datalist: list = [d for d in _sampler] | |||
| """ Compute the normalization """ | |||
| @@ -73,12 +82,16 @@ class _GraphSAINTSubGraphSampler: | |||
| [sub_graph.sampled_edge_indexes for sub_graph in datalist] | |||
| ) | |||
| for current_sampled_node_index in concatenated_sampled_nodes.unique(): | |||
| node_sampled_count[current_sampled_node_index] = \ | |||
| torch.where(concatenated_sampled_nodes == current_sampled_node_index)[0].size(0) | |||
| node_sampled_count[current_sampled_node_index] = torch.where( | |||
| concatenated_sampled_nodes == current_sampled_node_index | |||
| )[0].size(0) | |||
| for current_sampled_edge_index in concatenated_sampled_edges.unique(): | |||
| edge_sampled_count[current_sampled_edge_index] = \ | |||
| torch.where(concatenated_sampled_edges == current_sampled_edge_index)[0].size(0) | |||
| _alpha: torch.Tensor = edge_sampled_count / node_sampled_count[data.edge_index[1]] | |||
| edge_sampled_count[current_sampled_edge_index] = torch.where( | |||
| concatenated_sampled_edges == current_sampled_edge_index | |||
| )[0].size(0) | |||
| _alpha: torch.Tensor = ( | |||
| edge_sampled_count / node_sampled_count[data.edge_index[1]] | |||
| ) | |||
| _alpha[torch.isnan(_alpha) | torch.isinf(_alpha)] = 0 | |||
| _lambda: torch.Tensor = node_sampled_count / self.__num_graphs | |||
| return _SubGraphSet(datalist, **{"alpha": _alpha, "lambda": _lambda}) | |||
| @@ -101,5 +114,8 @@ class GraphSAINTRandomEdgeSampler(_GraphSAINTSubGraphSampler): | |||
| class GraphSAINTRandomWalkSampler(_GraphSAINTSubGraphSampler): | |||
| def __init__(self, edge_budget: int, num_graphs: int = 1, walk_length: int = 4): | |||
| super(GraphSAINTRandomWalkSampler, self).__init__( | |||
| torch_geometric.data.GraphSAINTRandomWalkSampler, edge_budget, num_graphs, walk_length | |||
| torch_geometric.data.GraphSAINTRandomWalkSampler, | |||
| edge_budget, | |||
| num_graphs, | |||
| walk_length, | |||
| ) | |||
| @@ -9,37 +9,41 @@ class NeighborSampler(torch.utils.data.DataLoader, collections.Iterable): | |||
| class _NodeIndexesDataset(torch.utils.data.Dataset): | |||
| def __init__(self, node_indexes): | |||
| self.__node_indexes: _typing.Sequence[int] = node_indexes | |||
| def __getitem__(self, index) -> int: | |||
| if not 0 <= index < len(self.__node_indexes): | |||
| raise IndexError("Index out of range") | |||
| else: | |||
| return self.__node_indexes[index] | |||
| def __len__(self) -> int: | |||
| return len(self.__node_indexes) | |||
| def __init__( | |||
| self, data, | |||
| sampling_sizes: _typing.Sequence[int], | |||
| target_node_indexes: _typing.Optional[_typing.Sequence[int]] = None, | |||
| batch_size: _typing.Optional[int] = 1, | |||
| *args, **kwargs | |||
| self, | |||
| data, | |||
| sampling_sizes: _typing.Sequence[int], | |||
| target_node_indexes: _typing.Optional[_typing.Sequence[int]] = None, | |||
| batch_size: _typing.Optional[int] = 1, | |||
| *args, | |||
| **kwargs | |||
| ): | |||
| self._data = data | |||
| self.__sampling_sizes: _typing.Sequence[int] = sampling_sizes | |||
| if not ( | |||
| target_node_indexes is not None and | |||
| isinstance(target_node_indexes, _typing.Sequence) | |||
| target_node_indexes is not None | |||
| and isinstance(target_node_indexes, _typing.Sequence) | |||
| ): | |||
| if hasattr(data, "train_mask"): | |||
| target_node_indexes: _typing.Sequence[int] = \ | |||
| torch.where(getattr(data, "train_mask"))[0] | |||
| target_node_indexes: _typing.Sequence[int] = torch.where( | |||
| getattr(data, "train_mask") | |||
| )[0] | |||
| else: | |||
| target_node_indexes: _typing.Sequence[int] = \ | |||
| list(np.arange(0, data.x.shape[0])) | |||
| target_node_indexes: _typing.Sequence[int] = list( | |||
| np.arange(0, data.x.shape[0]) | |||
| ) | |||
| self.__edge_index_map: _typing.Dict[ | |||
| int, _typing.Union[torch.Tensor, _typing.Sequence[int]] | |||
| ] = {} | |||
| @@ -47,9 +51,11 @@ class NeighborSampler(torch.utils.data.DataLoader, collections.Iterable): | |||
| super(NeighborSampler, self).__init__( | |||
| self._NodeIndexesDataset(target_node_indexes), | |||
| batch_size=batch_size if batch_size > 0 else 1, | |||
| collate_fn=self.__sample, *args, **kwargs | |||
| collate_fn=self.__sample, | |||
| *args, | |||
| **kwargs | |||
| ) | |||
| def __init_edge_index_map(self): | |||
| self.__edge_index_map.clear() | |||
| all_edge_index: torch.Tensor = getattr(self._data, "edge_index") | |||
| @@ -58,12 +64,12 @@ class NeighborSampler(torch.utils.data.DataLoader, collections.Iterable): | |||
| self.__edge_index_map[target_node_index] = torch.where( | |||
| all_edge_index[1] == target_node_index | |||
| )[0] | |||
| def __iter__(self): | |||
| return super(NeighborSampler, self).__iter__() | |||
| def __sample( | |||
| self, target_nodes_indexes: _typing.List[int] | |||
| self, target_nodes_indexes: _typing.List[int] | |||
| ) -> _typing.Tuple[torch.Tensor, _typing.List[torch.Tensor]]: | |||
| """ | |||
| Sample a sub-graph with neighborhood sampling | |||
| @@ -71,14 +77,15 @@ class NeighborSampler(torch.utils.data.DataLoader, collections.Iterable): | |||
| """ | |||
| original_edge_index: torch.Tensor = self._data.edge_index | |||
| edges_indexes: _typing.List[torch.Tensor] = [] | |||
| current_target_nodes_indexes: _typing.List[int] = target_nodes_indexes | |||
| for current_sampling_size in self.__sampling_sizes: | |||
| current_edge_index: _typing.Optional[torch.Tensor] = None | |||
| for current_target_node_index in current_target_nodes_indexes: | |||
| if current_target_node_index in self.__edge_index_map: | |||
| all_indexes: torch.Tensor = \ | |||
| self.__edge_index_map.get(current_target_node_index) | |||
| all_indexes: torch.Tensor = self.__edge_index_map.get( | |||
| current_target_node_index | |||
| ) | |||
| else: | |||
| all_indexes: torch.Tensor = torch.where( | |||
| original_edge_index[1] == current_target_node_index | |||
| @@ -89,25 +96,38 @@ class NeighborSampler(torch.utils.data.DataLoader, collections.Iterable): | |||
| ) | |||
| if current_edge_index is not None: | |||
| current_edge_index: torch.Tensor = torch.cat( | |||
| [current_edge_index, original_edge_index[:, sampled_indexes]], dim=1 | |||
| [ | |||
| current_edge_index, | |||
| original_edge_index[:, sampled_indexes], | |||
| ], | |||
| dim=1, | |||
| ) | |||
| else: | |||
| current_edge_index: torch.Tensor = original_edge_index[:, sampled_indexes] | |||
| current_edge_index: torch.Tensor = original_edge_index[ | |||
| :, sampled_indexes | |||
| ] | |||
| else: | |||
| all_indexes_list = all_indexes.tolist() | |||
| random.shuffle(all_indexes_list) | |||
| shuffled_indexes_list: _typing.List[int] = \ | |||
| all_indexes_list[0: current_sampling_size] | |||
| shuffled_indexes_list: _typing.List[int] = all_indexes_list[ | |||
| 0:current_sampling_size | |||
| ] | |||
| if current_edge_index is not None: | |||
| current_edge_index: torch.Tensor = torch.cat( | |||
| [current_edge_index, original_edge_index[:, shuffled_indexes_list]], dim=1 | |||
| [ | |||
| current_edge_index, | |||
| original_edge_index[:, shuffled_indexes_list], | |||
| ], | |||
| dim=1, | |||
| ) | |||
| else: | |||
| current_edge_index: torch.Tensor = original_edge_index[:, shuffled_indexes_list] | |||
| current_edge_index: torch.Tensor = original_edge_index[ | |||
| :, shuffled_indexes_list | |||
| ] | |||
| edges_indexes.append(current_edge_index) | |||
| if len(edges_indexes) < len(self.__sampling_sizes): | |||
| next_target_nodes_indexes: torch.Tensor = current_edge_index[0].unique() | |||
| current_target_nodes_indexes = next_target_nodes_indexes.tolist() | |||
| return torch.tensor(target_nodes_indexes), edges_indexes[::-1] | |||
| @@ -70,12 +70,10 @@ class AutoNodeClassifier(BaseClassifier): | |||
| Default ``auto``. | |||
| """ | |||
| # pylint: disable=W0102 | |||
| def __init__( | |||
| self, | |||
| feature_module=None, | |||
| graph_models=["gat", "gcn"], | |||
| graph_models=("gat", "gcn"), | |||
| hpo_module="anneal", | |||
| ensemble_module="voting", | |||
| max_evals=50, | |||
| @@ -27,7 +27,7 @@ if __name__ == "__main__": | |||
| choices=["mutag", "imdb-b", "imdb-m", "proteins", "collab"], | |||
| ) | |||
| parser.add_argument( | |||
| "--configs", default="../configs/graph_classification.yaml", help="config files" | |||
| "--configs", default="../configs/graphclf_full.yml", help="config files" | |||
| ) | |||
| parser.add_argument("--device", type=int, default=0, help="device to run on") | |||
| parser.add_argument("--seed", type=int, default=0, help="random seed") | |||