You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

vae.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Variational auto-encoder (VAE)"""
  16. from mindspore.ops import composite as C
  17. from mindspore.ops import operations as P
  18. from mindspore._checkparam import Validator
  19. from ....cell import Cell
  20. from ....layer.basic import Dense
  21. class VAE(Cell):
  22. r"""
  23. Variational Auto-Encoder (VAE).
  24. The VAE defines a generative model, `Z` is sampled from the prior, then used to reconstruct `X` by a decoder.
  25. For more details, refer to `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.
  26. Note:
  27. When the encoder and decoder are defined, the shape of the encoder's output tensor and decoder's input tensor
  28. must be :math:`(N, hidden\_size)`.
  29. The latent_size must be less than or equal to the hidden_size.
  30. Args:
  31. encoder(Cell): The Deep Neural Network (DNN) model defined as encoder.
  32. decoder(Cell): The DNN model defined as decoder.
  33. hidden_size(int): The size of encoder's output tensor.
  34. latent_size(int): The size of the latent space.
  35. Inputs:
  36. - **input** (Tensor) - The shape of input tensor is :math:`(N, C, H, W)`, which is the same as the input of
  37. encoder.
  38. Outputs:
  39. - **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
  40. Supported Platforms:
  41. ``Ascend`` ``GPU``
  42. """
  43. def __init__(self, encoder, decoder, hidden_size, latent_size):
  44. super(VAE, self).__init__()
  45. self.encoder = encoder
  46. self.decoder = decoder
  47. if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
  48. raise TypeError('The encoder and decoder should be Cell type.')
  49. self.hidden_size = Validator.check_positive_int(hidden_size)
  50. self.latent_size = Validator.check_positive_int(latent_size)
  51. if hidden_size < latent_size:
  52. raise ValueError('The latent_size should be less than or equal to the hidden_size.')
  53. self.normal = C.normal
  54. self.exp = P.Exp()
  55. self.reshape = P.Reshape()
  56. self.shape = P.Shape()
  57. self.to_tensor = P.ScalarToArray()
  58. self.dense1 = Dense(self.hidden_size, self.latent_size)
  59. self.dense2 = Dense(self.hidden_size, self.latent_size)
  60. self.dense3 = Dense(self.latent_size, self.hidden_size)
  61. def _encode(self, x):
  62. en_x = self.encoder(x)
  63. mu = self.dense1(en_x)
  64. log_var = self.dense2(en_x)
  65. return mu, log_var
  66. def _decode(self, z):
  67. z = self.dense3(z)
  68. recon_x = self.decoder(z)
  69. return recon_x
  70. def construct(self, x):
  71. mu, log_var = self._encode(x)
  72. std = self.exp(0.5 * log_var)
  73. z = self.normal(self.shape(mu), mu, std, seed=0)
  74. recon_x = self._decode(z)
  75. return recon_x, x, mu, std
  76. def generate_sample(self, generate_nums, shape):
  77. """
  78. Randomly sample from latent space to generate samples.
  79. Args:
  80. generate_nums (int): The number of samples to generate.
  81. shape(tuple): The shape of sample, it must be (generate_nums, C, H, W) or (-1, C, H, W).
  82. Returns:
  83. Tensor, the generated samples.
  84. """
  85. generate_nums = Validator.check_positive_int(generate_nums)
  86. if not isinstance(shape, tuple) or len(shape) != 4 or (shape[0] != -1 and shape[0] != generate_nums):
  87. raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
  88. sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
  89. sample = self._decode(sample_z)
  90. sample = self.reshape(sample, shape)
  91. return sample
  92. def reconstruct_sample(self, x):
  93. """
  94. Reconstruct samples from original data.
  95. Args:
  96. x (Tensor): The input tensor to be reconstructed, the shape is (N, C, H, W).
  97. Returns:
  98. Tensor, the reconstructed sample.
  99. """
  100. mu, log_var = self._encode(x)
  101. std = self.exp(0.5 * log_var)
  102. z = self.normal(mu.shape, mu, std, seed=0)
  103. recon_x = self._decode(z)
  104. return recon_x