"""
Auto graph model.

A list of models with their hyper parameters.

NOTE: neural architecture search (NAS) may be included here.
"""

import copy
import logging
import typing as _typing
from copy import deepcopy

import torch
import torch.nn.functional as F

from autogl.module.hpo.auto_module import AutoModule

base_approach_logger: logging.Logger = logging.getLogger("BaseModel")


def activate_func(x, func):
    """
    Apply the activation function named ``func`` to ``x``.

    :param x: value handed to the activation function
    :param func: ``"tanh"``, the name of an attribute of
        ``torch.nn.functional``, or ``""`` to leave ``x`` untouched
    :return: result of the activation, or ``x`` itself when ``func`` is ``""``
    :raise TypeError: if ``func`` is none of the accepted names
    """
    if func == "tanh":
        return torch.tanh(x)
    if hasattr(F, func):
        return getattr(F, func)(x)
    if func == "":
        # An empty name means "no activation".
        return x
    raise TypeError("PyTorch does not support activation function {}".format(func))


def _as_non_negative_int(value) -> int:
    """
    Return ``value`` when it is a strictly positive ``int``, otherwise ``0``.

    ``type(value) is int`` is used on purpose (rather than ``isinstance``) so
    that ``bool`` values and the ``Ellipsis`` "not provided" sentinel both
    fall back to ``0``.
    """
    if type(value) is int and value > 0:
        return value
    return 0


class BaseAutoModel(AutoModule):
    """
    ``AutoModule`` that stores an input dimension, an output dimension,
    the wrapped model and the extra keyword arguments of its constructor.
    """

    def __init__(self, input_dimension, output_dimension, device, **kwargs):
        """
        :param input_dimension: stored as ``input_dimension``
        :param output_dimension: stored as ``output_dimension``
        :param device: forwarded to ``AutoModule.__init__``
        :param kwargs: kept and replayed by :meth:`from_hyper_parameter`
        """
        # The attributes are assigned before delegating to the parent
        # constructor; this ordering is kept from the original code.
        self._input_dimension = input_dimension
        self._output_dimension = output_dimension
        self._model = None
        self._kwargs = kwargs
        super(BaseAutoModel, self).__init__(device)

    def to(self, device):
        """
        Alias of :meth:`to_device`, kept for compatibility with v0.2.

        :param device: target device
        :return: ``self`` (the original returned ``None``)
        """
        self.to_device(device)
        return self

    def to_device(self, device):
        """
        Record ``device`` and move the wrapped model to it, if there is one.

        :param device: target device
        """
        self.device = device
        if self.model is not None:
            self.model.to(self.device)

    @property
    def model(self):
        """The wrapped model, ``None`` until one has been assigned."""
        return self._model

    def from_hyper_parameter(self, hp, **kwargs):
        """
        Build and initialize a new instance of the same class.

        The new instance receives the stored constructor keyword arguments
        overridden by ``kwargs`` and the current hyper parameters overridden
        by ``hp``; neither ``self._kwargs`` nor ``self.hyper_parameters`` is
        modified.

        :param hp: hyper parameters overriding the current ones
        :param kwargs: keyword arguments overriding the stored ones
        :return: the new, initialized instance
        """
        merged_kwargs = deepcopy(self._kwargs)
        merged_kwargs.update(kwargs)
        ret_self = self.__class__(
            self.input_dimension,
            self.output_dimension,
            self.device,
            **merged_kwargs
        )
        merged_hp = dict(self.hyper_parameters)
        merged_hp.update(hp)
        ret_self.hyper_parameters = merged_hp
        ret_self.initialize()
        return ret_self

    @property
    def input_dimension(self):
        """Input dimension given at construction or set afterwards."""
        return self._input_dimension

    @input_dimension.setter
    def input_dimension(self, input_dimension):
        self._input_dimension = input_dimension

    @property
    def output_dimension(self):
        """Output dimension given at construction or set afterwards."""
        return self._output_dimension

    @output_dimension.setter
    def output_dimension(self, output_dimension):
        self._output_dimension = output_dimension


class _BaseBaseModel:
    # todo: after renaming the experimental base class _BaseModel to BaseModel,
    #       rename this class to _BaseModel
    """
    The base class for class BaseModel,
    designed to implement some basic functionality of BaseModel.
    -- Designed by ZiXin Sun
    """

    @classmethod
    def __formulate_device(
        cls, device: _typing.Union[str, torch.device] = ...
    ) -> torch.device:
        """
        Normalize a device specification into a ``torch.device``.

        :param device: a ``torch.device``, a device string, or anything else
            (``"auto"``, ``None``, ``Ellipsis`` ...) to request automatic
            selection
        :return: the requested device; on automatic selection ``cuda`` when
            CUDA is available with at least one device, otherwise ``cpu``
        """
        is_explicit_name = (
            isinstance(device, str) and device.strip().lower() != "auto"
        )
        if isinstance(device, torch.device) or is_explicit_name:
            return torch.device(device)
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            return torch.device("cuda")
        return torch.device("cpu")

    @property
    def device(self) -> torch.device:
        """Device this model is associated with."""
        return self.__device

    @device.setter
    def device(self, device: _typing.Union[str, torch.device, None]):
        self.__device: torch.device = self.__formulate_device(device)

    @property
    def model(self) -> _typing.Optional[torch.nn.Module]:
        """The wrapped module; a debug message is logged while it is ``None``."""
        if self._model is None:
            base_approach_logger.debug(
                "property of model NOT initialized before accessing"
            )
        return self._model

    @model.setter
    def model(self, _model: torch.nn.Module) -> None:
        if not isinstance(_model, torch.nn.Module):
            raise TypeError(
                "the property of model MUST be an instance of " "torch.nn.Module"
            )
        self._model = _model

    def _initialize(self):
        """Hook performing the actual initialization; subclasses override it."""
        raise NotImplementedError

    def initialize(self) -> bool:
        """
        Initialize the model in case that the model has NOT been initialized
        :return: whether self._initialize() method called
        """
        if self.__is_initialized:
            return False
        self._initialize()
        self.__is_initialized = True
        return True

    # def to(self, *args, **kwargs):
    #     """
    #     Due to the signature of to() method in class BaseApproach
    #     is inconsistent with the signature of the method
    #     in the base class torch.nn.Module,
    #     this intermediate overridden method is necessary to
    #     work around (bypass) the inspection for
    #     signature of overriding method.
    #     :param args: positional arguments list
    #     :param kwargs: keyword arguments dict
    #     :return: self
    #     """
    #     return super(_BaseBaseModel, self).to(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """
        Call the wrapped module with the given arguments.

        :return: whatever the wrapped module returns
        :raise NotImplementedError: if no ``torch.nn.Module`` is wrapped
        """
        wrapped = self.model
        if isinstance(wrapped, torch.nn.Module):
            return wrapped(*args, **kwargs)
        raise NotImplementedError

    def __init__(
        self,
        model: _typing.Optional[torch.nn.Module] = None,
        initialize: bool = False,
        device: _typing.Union[str, torch.device] = ...,
    ):
        """
        :param model: module to wrap, may be ``None``
        :param initialize: whether to call :meth:`initialize` right away
        :param device: device specification, see ``__formulate_device``
        :raise TypeError: if ``initialize`` is not a ``bool``
        """
        if not isinstance(initialize, bool):
            raise TypeError
        # NOTE(review): for _BaseModel instances the next class in the MRO is
        # BaseAutoModel, whose __init__ requires positional arguments, so this
        # argument-less call looks like it raises TypeError there -- confirm
        # whether the legacy hierarchy is still meant to be instantiable.
        super(_BaseBaseModel, self).__init__()
        self.__device: torch.device = self.__formulate_device(device)
        self._model: _typing.Optional[torch.nn.Module] = model
        self.__is_initialized: bool = False
        if initialize:
            self.initialize()


class _BaseModel(_BaseBaseModel, BaseAutoModel):
    """
    The upcoming root base class for Model, i.e. BaseModel
    -- Designed by ZiXin Sun
    """

    # todo: Deprecate and remove the legacy class "BaseModel",
    #       then rename this class to "BaseModel",
    #       correspondingly, this class will no longer extend
    #       the legacy class "BaseModel" after the removal.

    def _initialize(self):
        """Hook performing the actual initialization; subclasses override it."""
        raise NotImplementedError

    def to(self, device: torch.device):
        """
        Record ``device``, move the wrapped module to it and then delegate to
        the next ``to`` implementation in the MRO.

        :param device: target device
        :return: the result of the parent ``to`` call
        """
        self.device = device
        if self.model is not None and isinstance(self.model, torch.nn.Module):
            self.model.to(self.device)
        return super().to(device)

    @property
    def space(self) -> _typing.Sequence[_typing.Dict[str, _typing.Any]]:
        # todo: deprecate and remove in future major version
        return self.__hyper_parameter_space

    @property
    def hyper_parameter_space(self):
        """Hyper parameter search space (same object as :attr:`space`)."""
        return self.__hyper_parameter_space

    @hyper_parameter_space.setter
    def hyper_parameter_space(
        self, space: _typing.Sequence[_typing.Dict[str, _typing.Any]]
    ):
        self.__hyper_parameter_space = space

    @property
    def hyper_parameter(self) -> _typing.Dict[str, _typing.Any]:
        """
        Current hyper parameters.

        The stored dict itself is returned, not a copy; use
        :meth:`get_hyper_parameter` for an independent copy.
        """
        return self.__hyper_parameter

    @hyper_parameter.setter
    def hyper_parameter(self, _hyper_parameter: _typing.Dict[str, _typing.Any]):
        if not isinstance(_hyper_parameter, dict):
            raise TypeError
        self.__hyper_parameter = _hyper_parameter

    def get_hyper_parameter(self) -> _typing.Dict[str, _typing.Any]:
        """
        todo: consider deprecating this trivial getter method in the future
        :return: copied hyper parameter
        """
        return copy.deepcopy(self.__hyper_parameter)

    def __init__(
        self,
        model: _typing.Optional[torch.nn.Module] = None,
        initialize: bool = False,
        hyper_parameter_space: _typing.Sequence[_typing.Any] = ...,
        hyper_parameter: _typing.Dict[str, _typing.Any] = ...,
        device: _typing.Union[str, torch.device] = ...,
    ):
        """
        :param model: module to wrap, may be ``None``
        :param initialize: whether to call :meth:`initialize` once the
            instance is fully constructed
        :param hyper_parameter_space: search space; anything that is not a
            sequence is replaced by an empty list
        :param hyper_parameter: hyper parameters; anything that is not a
            ``dict`` is replaced by an empty dict
        :param device: device specification
        :raise TypeError: if ``initialize`` is not a ``bool``
        """
        if not isinstance(initialize, bool):
            raise TypeError
        # Initialization is deferred (False is passed to the parent) so that
        # _initialize() only runs after the hyper parameter attributes below
        # exist; previously it ran inside the parent constructor, before them.
        super(_BaseModel, self).__init__(model, False, device)
        if hyper_parameter_space is not Ellipsis and isinstance(
            hyper_parameter_space, _typing.Sequence
        ):
            self.__hyper_parameter_space: _typing.Sequence[
                _typing.Dict[str, _typing.Any]
            ] = hyper_parameter_space
        else:
            self.__hyper_parameter_space: _typing.Sequence[
                _typing.Dict[str, _typing.Any]
            ] = []
        if hyper_parameter is not Ellipsis and isinstance(hyper_parameter, dict):
            self.__hyper_parameter: _typing.Dict[str, _typing.Any] = hyper_parameter
        else:
            self.__hyper_parameter: _typing.Dict[str, _typing.Any] = {}
        if initialize:
            self.initialize()

    def from_hyper_parameter(self, hyper_parameter: _typing.Dict[str, _typing.Any]):
        """Create a new model from ``hyper_parameter``; subclasses override it."""
        raise NotImplementedError


class ClassificationModel(_BaseModel):
    """
    ``_BaseModel`` that additionally tracks the number of features, classes
    and graph features as non-negative integers.
    """

    def _initialize(self):
        """Hook performing the actual initialization; subclasses override it."""
        raise NotImplementedError

    def from_hyper_parameter(
        self, hyper_parameter: _typing.Dict[str, _typing.Any]
    ) -> "ClassificationModel":
        """
        Build and initialize a new model of the same class.

        The new model is created with this model's ``num_features``,
        ``num_classes`` and ``device``; its hyper parameters are a copy of
        the current ones overridden by ``hyper_parameter``.

        :param hyper_parameter: hyper parameters overriding the current ones
        :return: the new, initialized model
        """
        new_model: ClassificationModel = self.__class__(
            num_features=self.num_features,
            num_classes=self.num_classes,
            device=self.device,
            init=False,
        )
        # Merge into a copy: the hyper_parameter property returns the live
        # dict, so updating it in place would alter this model and share one
        # dict between the two models.
        merged_hyper_parameter = copy.deepcopy(self.hyper_parameter)
        merged_hyper_parameter.update(hyper_parameter)
        new_model.hyper_parameter = merged_hyper_parameter
        new_model.initialize()
        return new_model

    def __init__(
        self,
        num_features: int = ...,
        num_classes: int = ...,
        num_graph_features: int = ...,
        device: _typing.Union[str, torch.device] = ...,
        hyper_parameter_space: _typing.Sequence[_typing.Any] = ...,
        hyper_parameter: _typing.Dict[str, _typing.Any] = ...,
        init: bool = False,
        **kwargs
    ):
        """
        :param num_features: number of features; non-``int`` or non-positive
            values are stored as ``0``
        :param num_classes: number of classes; same normalization
        :param num_graph_features: number of graph features; same normalization
        :param device: device specification
        :param hyper_parameter_space: forwarded to ``_BaseModel.__init__``
        :param hyper_parameter: forwarded to ``_BaseModel.__init__``
        :param init: whether to call :meth:`initialize` once the instance is
            fully constructed
        :param kwargs: forwarded to ``_BaseModel.__init__``; an
            ``"initialize"`` entry is discarded in favour of ``init``
        :raise TypeError: if ``init`` is not a ``bool``
        """
        if not isinstance(init, bool):
            raise TypeError
        # "init" is the authoritative flag; drop a conflicting keyword.
        kwargs.pop("initialize", None)
        # Initialization is deferred until the counters below are assigned,
        # so that _initialize() can rely on them.
        super(ClassificationModel, self).__init__(
            initialize=False,
            hyper_parameter_space=hyper_parameter_space,
            hyper_parameter=hyper_parameter,
            device=device,
            **kwargs
        )
        self.__num_classes: int = _as_non_negative_int(num_classes)
        self.__num_features: int = _as_non_negative_int(num_features)
        self.__num_graph_features: int = _as_non_negative_int(num_graph_features)
        if init:
            self.initialize()

    def __repr__(self) -> str:
        """Return the hyper parameters dumped as YAML."""
        import yaml

        return yaml.dump(self.hyper_parameter)

    @property
    def num_classes(self) -> int:
        """Number of classes (``0`` when unset)."""
        return self.__num_classes

    @num_classes.setter
    def num_classes(self, num_classes: int):
        if type(num_classes) is not int:
            raise TypeError
        if not num_classes > 0:
            raise ValueError
        self.__num_classes = num_classes

    @property
    def num_features(self) -> int:
        """Number of features (``0`` when unset)."""
        return self.__num_features

    @num_features.setter
    def num_features(self, num_features: int):
        if type(num_features) is not int:
            raise TypeError
        if not num_features > 0:
            raise ValueError
        self.__num_features = num_features

    def get_num_classes(self) -> int:
        # todo: consider replacing with property with getter and setter
        return self.__num_classes

    def set_num_classes(self, num_classes: int) -> None:
        # todo: consider replacing with property with getter and setter
        # Unlike the property setter, non-positive values are clamped to 0
        # instead of raising ValueError.
        if type(num_classes) is not int:
            raise TypeError
        self.__num_classes = _as_non_negative_int(num_classes)

    def get_num_features(self) -> int:
        # todo: consider replacing with property with getter and setter
        return self.__num_features

    def set_num_features(self, num_features: int):
        # todo: consider replacing with property with getter and setter
        # Unlike the property setter, non-positive values are clamped to 0
        # instead of raising ValueError.
        if type(num_features) is not int:
            raise TypeError
        self.__num_features = _as_non_negative_int(num_features)

    def set_num_graph_features(self, num_graph_features: int):
        # todo: consider replacing with property with getter and setter
        # Non-positive values are clamped to 0.
        if type(num_graph_features) is not int:
            raise TypeError
        self.__num_graph_features = _as_non_negative_int(num_graph_features)


class _ClassificationModel(torch.nn.Module):
    """
    ``torch.nn.Module`` exposing an encode / decode pair for classification;
    ``cls_forward`` chains the two.
    """

    def __init__(self):
        super(_ClassificationModel, self).__init__()

    def cls_encode(self, data) -> torch.Tensor:
        """Encode ``data``; subclasses override it."""
        raise NotImplementedError

    def cls_decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode the encoded tensor ``x``; subclasses override it."""
        raise NotImplementedError

    def cls_forward(self, data) -> torch.Tensor:
        """Return ``cls_decode(cls_encode(data))``."""
        return self.cls_decode(self.cls_encode(data))


class ClassificationSupportedSequentialModel(_ClassificationModel):
    """
    ``_ClassificationModel`` that additionally exposes its encoding layers
    as a ``torch.nn.ModuleList``.
    """

    def __init__(self):
        super(ClassificationSupportedSequentialModel, self).__init__()

    @property
    def sequential_encoding_layers(self) -> torch.nn.ModuleList:
        """Encoding layers of the model; subclasses override it."""
        raise NotImplementedError

    def cls_encode(self, data) -> torch.Tensor:
        """Encode ``data``; subclasses override it."""
        raise NotImplementedError

    def cls_decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode the encoded tensor ``x``; subclasses override it."""
        raise NotImplementedError