You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base_decoder.py 1.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import torch
  2. import typing as _typing
  3. from ...hpo import AutoModule
  4. from ..encoders import base_encoder
  5. class BaseDecoderMaintainer(AutoModule):
  6. def _initialize(
  7. self, encoder: base_encoder.AutoHomogeneousEncoderMaintainer, *args, **kwargs
  8. ) -> _typing.Optional[bool]:
  9. """ Abstract initialization method to override """
  10. raise NotImplementedError
  11. def from_hyper_parameter_and_encoder(
  12. self, hp: _typing.Mapping[str, _typing.Any],
  13. encoder: base_encoder.BaseEncoderMaintainer
  14. ) -> 'BaseDecoderMaintainer':
  15. duplicate = self.__class__(
  16. self.output_dimension, self.device
  17. )
  18. new_hp = dict(self.hyper_parameters)
  19. new_hp.update(hp)
  20. duplicate.hyper_parameters = new_hp
  21. duplicate.initialize(encoder)
  22. return duplicate
  23. def __init__(
  24. self, output_dimension: _typing.Optional[int] = ...,
  25. device: _typing.Union[torch.device, str, int, None] = ...,
  26. *args, **kwargs
  27. ):
  28. super(BaseDecoderMaintainer, self).__init__(
  29. device, *args, **kwargs
  30. )
  31. self.output_dimension = output_dimension
  32. self._decoder: _typing.Optional[torch.nn.Module] = None
  33. @property
  34. def output_dimension(self):
  35. return self.__output_dimension
  36. @output_dimension.setter
  37. def output_dimension(self, output_dimension):
  38. self.__output_dimension = output_dimension
  39. @property
  40. def decoder(self) -> _typing.Optional[torch.nn.Module]:
  41. return self._decoder
  42. def to_device(self, device: _typing.Union[torch.device, str, int, None]):
  43. self.device = device
  44. if (
  45. self._decoder not in (Ellipsis, None) and
  46. isinstance(self._decoder, torch.nn.Module)
  47. ):
  48. self._decoder.to(self.device)