You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

cvae.py 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Conditional Variational auto-encoder (CVAE)."""
  16. from mindspore.ops import composite as C
  17. from mindspore.ops import operations as P
  18. from mindspore._checkparam import Validator
  19. from ....cell import Cell
  20. from ....layer.basic import Dense, OneHot
  21. class ConditionalVAE(Cell):
  22. r"""
  23. Conditional Variational Auto-Encoder (CVAE).
  24. The difference with VAE is that CVAE uses labels information.
  25. For more details, refer to `Learning Structured Output Representation using Deep Conditional Generative Models
  26. <http://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-
  27. generative-models>`_.
  28. Note:
  29. When encoder and decoder ard defined, the shape of the encoder's output tensor and decoder's input tensor
  30. must be :math:`(N, hidden\_size)`.
  31. The latent_size must be less than or equal to the hidden_size.
  32. Args:
  33. encoder(Cell): The Deep Neural Network (DNN) model defined as encoder.
  34. decoder(Cell): The DNN model defined as decoder.
  35. hidden_size(int): The size of encoder's output tensor.
  36. latent_size(int): The size of the latent space.
  37. num_classes(int): The number of classes.
  38. Inputs:
  39. - **input_x** (Tensor) - The shape of input tensor is :math:`(N, C, H, W)`, which is the same as the input of
  40. encoder.
  41. - **input_y** (Tensor) - The tensor of the target data, the shape is :math:`(N,)`.
  42. Outputs:
  43. - **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).
  44. """
  45. def __init__(self, encoder, decoder, hidden_size, latent_size, num_classes):
  46. super(ConditionalVAE, self).__init__()
  47. self.encoder = encoder
  48. self.decoder = decoder
  49. if (not isinstance(encoder, Cell)) or (not isinstance(decoder, Cell)):
  50. raise TypeError('The encoder and decoder should be Cell type.')
  51. self.hidden_size = Validator.check_positive_int(hidden_size)
  52. self.latent_size = Validator.check_positive_int(latent_size)
  53. if hidden_size < latent_size:
  54. raise ValueError('The latent_size should be less than or equal to the hidden_size.')
  55. self.num_classes = Validator.check_positive_int(num_classes)
  56. self.normal = C.normal
  57. self.exp = P.Exp()
  58. self.reshape = P.Reshape()
  59. self.shape = P.Shape()
  60. self.concat = P.Concat(axis=1)
  61. self.to_tensor = P.ScalarToArray()
  62. self.one_hot = OneHot(depth=num_classes)
  63. self.dense1 = Dense(self.hidden_size, self.latent_size)
  64. self.dense2 = Dense(self.hidden_size, self.latent_size)
  65. self.dense3 = Dense(self.latent_size + self.num_classes, self.hidden_size)
  66. def _encode(self, x, y):
  67. en_x = self.encoder(x, y)
  68. mu = self.dense1(en_x)
  69. log_var = self.dense2(en_x)
  70. return mu, log_var
  71. def _decode(self, z):
  72. z = self.dense3(z)
  73. recon_x = self.decoder(z)
  74. return recon_x
  75. def construct(self, x, y):
  76. """
  77. The input are x and y, so the WithLossCell method needs to be rewritten when using cvae interface.
  78. """
  79. mu, log_var = self._encode(x, y)
  80. std = self.exp(0.5 * log_var)
  81. z = self.normal(self.shape(mu), mu, std, seed=0)
  82. y = self.one_hot(y)
  83. z_c = self.concat((z, y))
  84. recon_x = self._decode(z_c)
  85. return recon_x, x, mu, std
  86. def generate_sample(self, sample_y, generate_nums, shape):
  87. """
  88. Randomly sample from the latent space to generate samples.
  89. Args:
  90. sample_y (Tensor): Define the label of samples. Tensor of shape (generate_nums, ) and type mindspore.int32.
  91. generate_nums (int): The number of samples to generate.
  92. shape(tuple): The shape of sample, which must be the format of (generate_nums, C, H, W) or (-1, C, H, W).
  93. Returns:
  94. Tensor, the generated samples.
  95. """
  96. generate_nums = Validator.check_positive_int(generate_nums)
  97. if not isinstance(shape, tuple) or len(shape) != 4 or (shape[0] != -1 and shape[0] != generate_nums):
  98. raise ValueError('The shape should be (generate_nums, C, H, W) or (-1, C, H, W).')
  99. sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0)
  100. sample_y = self.one_hot(sample_y)
  101. sample_c = self.concat((sample_z, sample_y))
  102. sample = self._decode(sample_c)
  103. sample = self.reshape(sample, shape)
  104. return sample
  105. def reconstruct_sample(self, x, y):
  106. """
  107. Reconstruct samples from original data.
  108. Args:
  109. x (Tensor): The input tensor to be reconstructed, the shape is (N, C, H, W).
  110. y (Tensor): The label of the input tensor, the shape is (N,).
  111. Returns:
  112. Tensor, the reconstructed sample.
  113. """
  114. mu, log_var = self._encode(x, y)
  115. std = self.exp(0.5 * log_var)
  116. z = self.normal(mu.shape, mu, std, seed=0)
  117. y = self.one_hot(y)
  118. z_c = self.concat((z, y))
  119. recon_x = self._decode(z_c)
  120. return recon_x