import torch import typing as _typing from ...hpo import AutoModule class BaseEncoderMaintainer(AutoModule): def __init__( self, device: _typing.Union[torch.device, str, int, None] = ..., *args, **kwargs ): super(BaseEncoderMaintainer, self).__init__( device, *args, **kwargs ) self._encoder: _typing.Optional[torch.nn.Module] = None @property def encoder(self) -> _typing.Optional[torch.nn.Module]: return self._encoder def to_device(self, device: _typing.Union[torch.device, str, int, None]): self.device = device if ( self._encoder not in (Ellipsis, None) and isinstance(self._encoder, torch.nn.Module) ): self._encoder.to(self.device) def from_hyper_parameter( self, hyper_parameter: _typing.Mapping[str, _typing.Any], **kwargs ): raise NotImplementedError def _initialize(self) -> _typing.Optional[bool]: raise NotImplementedError class AutoHomogeneousEncoderMaintainer(BaseEncoderMaintainer): def _initialize(self) -> _typing.Optional[bool]: raise NotImplementedError def __init__( self, input_dimension: _typing.Optional[int] = ..., final_dimension: _typing.Optional[int] = ..., device: _typing.Union[torch.device, str, int, None] = ..., *args, **kwargs ): self._input_dimension: _typing.Optional[int] = input_dimension self._final_dimension: _typing.Optional[int] = final_dimension super(AutoHomogeneousEncoderMaintainer, self).__init__( device, *args, **kwargs ) self.__args: _typing.Tuple[_typing.Any, ...] = args self.__kwargs: _typing.Mapping[str, _typing.Any] = kwargs @property def input_dimension(self) -> _typing.Optional[int]: return self._input_dimension @input_dimension.setter def input_dimension(self, input_dimension): self._input_dimension = input_dimension @property def final_dimension(self): return self._final_dimension @final_dimension.setter def final_dimension(self, final_dimension): # TODO: may mutate search space according to the final dimension self._final_dimension = final_dimension def from_hyper_parameter( self, hyper_parameter: _typing.Mapping[str, _typing.Any], **kwargs ): new_kwargs = dict(self.__kwargs) new_kwargs.update(kwargs) duplicate: AutoHomogeneousEncoderMaintainer = self.__class__( self.input_dimension, self.final_dimension, self.device, **new_kwargs ) hp = dict(self.hyper_parameters) hp.update(hyper_parameter) duplicate.hyper_parameters = hp duplicate.initialize() return duplicate def get_output_dimensions(self) -> _typing.Iterable[int]: """""" ''' Note that this is a default implicit assumption ''' _output_dimensions = [self._input_dimension] _output_dimensions.extend(self.hyper_parameters["hidden"]) if ( isinstance(self.final_dimension, int) and self.final_dimension > 0 ): _output_dimensions.append(self.final_dimension) return _output_dimensions