You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

vae.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Variational auto-encoder (VAE)"""
  16. from mindspore.ops import composite as C
  17. from mindspore.ops import operations as P
  18. from mindspore._checkparam import Validator
  19. from ....cell import Cell
  20. from ....layer.basic import Dense
  class VAE(Cell):
      r"""
      Variational Auto-Encoder (VAE).

      The VAE defines a generative model, `Z` is sampled from the prior, then used to reconstruct `X` by a decoder.
      For more details, refer to `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.

      Note:
          When the encoder and decoder are defined, the shape of the encoder's output tensor and decoder's input
          tensor must be :math:`(N, hidden\_size)`.
          The latent_size must be less than or equal to the hidden_size.

      Args:
          encoder(Cell): The Deep Neural Network (DNN) model defined as encoder.
          decoder(Cell): The DNN model defined as decoder.
          hidden_size(int): The size of encoder's output tensor.
          latent_size(int): The size of the latent space.

      Raises:
          TypeError: If `encoder` or `decoder` is not a Cell.
          ValueError: If `hidden_size` is less than `latent_size`.

      Inputs:
          - **input** (Tensor) - The shape of input tensor is :math:`(N, C, H, W)`, which is the same as the input
            of encoder.

      Outputs:
          - **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
      """

      def __init__(self, encoder, decoder, hidden_size, latent_size):
          super(VAE, self).__init__()
          # Check the sub-networks before storing them, so an invalid object is never attached to the cell.
          if not (isinstance(encoder, Cell) and isinstance(decoder, Cell)):
              raise TypeError('The encoder and decoder should be Cell type.')
          self.encoder = encoder
          self.decoder = decoder
          # Both sizes go through Validator.check_positive_int; the returned values are what the layers use.
          self.hidden_size = Validator.check_positive_int(hidden_size)
          self.latent_size = Validator.check_positive_int(latent_size)
          # Compare the validated values rather than the raw arguments.
          if self.hidden_size < self.latent_size:
              raise ValueError('The latent_size should be less than or equal to the hidden_size.')
          self.normal = C.normal
          self.exp = P.Exp()
          self.reshape = P.Reshape()
          self.shape = P.Shape()
          self.to_tensor = P.ScalarToArray()
          # dense1 and dense2 map the encoder output to `mu` and `log_var` respectively (see `_encode`);
          # dense3 maps a latent vector back to hidden_size before the decoder (see `_decode`).
          self.dense1 = Dense(self.hidden_size, self.latent_size)
          self.dense2 = Dense(self.hidden_size, self.latent_size)
          self.dense3 = Dense(self.latent_size, self.hidden_size)

      def _encode(self, x):
          """Run the encoder on `x` and project its output to `(mu, log_var)` with dense1 and dense2."""
          en_x = self.encoder(x)
          mu = self.dense1(en_x)
          log_var = self.dense2(en_x)
          return mu, log_var

      def _decode(self, z):
          """Project the latent tensor `z` with dense3 and pass the result through the decoder."""
          z = self.dense3(z)
          recon_x = self.decoder(z)
          return recon_x

      def construct(self, x):
          """
          Encode `x`, sample a latent tensor and decode it.

          Args:
              x (Tensor): The input tensor, passed unchanged to the encoder.

          Returns:
              Tuple, (recon_x, x, mu, std).
          """
          mu, log_var = self._encode(x)
          # std = exp(log_var / 2), i.e. the square root of exp(log_var).
          std = self.exp(0.5 * log_var)
          # Sample z with the same shape as mu, using mu as mean and std as standard deviation.
          z = self.normal(self.shape(mu), mu, std, seed=0)
          recon_x = self._decode(z)
          return recon_x, x, mu, std

      def generate_sample(self, generate_nums, shape):
          """
          Randomly sample from latent space to generate samples.

          Args:
              generate_nums (int): The number of samples to generate.
              shape(tuple): The shape of sample, it must be (generate_nums, C, H, W) or (-1, C, H, W).

          Returns:
              Tensor, the generated samples.

          Raises:
              ValueError: If `shape` is not a tuple of length 4 whose first element is -1 or `generate_nums`.
          """
          generate_nums = Validator.check_positive_int(generate_nums)
          if (not isinstance(shape, tuple) or len(shape) != 4
                  or (shape[0] != -1 and shape[0] != generate_nums)):
              raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
          # Latent vectors are drawn with mean 0.0 and standard deviation 1.0.
          sample_z = self.normal((generate_nums, self.latent_size),
                                 self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
          sample = self._decode(sample_z)
          # The decoder output is reshaped to the caller-requested sample shape.
          sample = self.reshape(sample, shape)
          return sample

      def reconstruct_sample(self, x):
          """
          Reconstruct samples from original data.

          Args:
              x (Tensor): The input tensor to be reconstructed, the shape is (N, C, H, W).

          Returns:
              Tensor, the reconstructed sample.
          """
          mu, log_var = self._encode(x)
          std = self.exp(0.5 * log_var)
          # Use the Shape operator here as well, matching how `construct` obtains the sampling shape.
          z = self.normal(self.shape(mu), mu, std, seed=0)
          recon_x = self._decode(z)
          return recon_x