You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

vae_mnist.py 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ VAE """
  16. import os
  17. import numpy as np
  18. from utils import create_dataset, save_img
  19. import mindspore.nn as nn
  20. from mindspore import context
  21. from mindspore import Tensor
  22. from mindspore.train import Model
  23. from mindspore.train.callback import LossMonitor
  24. from mindspore.ops import operations as P
  25. from mindspore.common import dtype as mstype
  26. import zhusuan as zs
  class ReduceMeanLoss(nn.L1Loss):
      """Pass-through loss cell: hands the network output straight back.

      ``Model`` is given this cell as its loss function and feeds it the
      network output together with a label. The wrapped network (the ELBO cell
      built in ``main``) presumably already returns the training objective
      (TODO confirm against zhusuan's ``ELBO``), so nothing is computed here.
      Note: despite the class name, no mean reduction is performed.
      """

      def construct(self, base, target):
          """Return ``base`` unchanged; ``target`` is accepted but ignored."""
          objective = base
          return objective
  class Generator(zs.BayesianNet):
      """Generative model p(z) p(x | z) used as the VAE decoder.

      The latent prior is a Normal with zero mean and unit std. The likelihood
      of a flattened image is a Bernoulli whose probabilities come from an MLP
      (z_dim -> 500 -> 500 -> x_dim, ReLU hidden layers, sigmoid output).

      Args:
          x_dim: size of a flattened image (32*32 in ``main``).
          z_dim: size of the latent code.
          batch_size: examples per batch; fixes the shape of the prior
              parameters and of the reshaped observation.
      """

      def __init__(self, x_dim, z_dim, batch_size):
          super().__init__()
          self.x_dim = x_dim
          self.z_dim = z_dim
          self.batch_size = batch_size
          self.fc1 = nn.Dense(z_dim, 500)
          self.act1 = nn.ReLU()
          self.fc2 = nn.Dense(500, 500)
          self.act2 = nn.ReLU()
          self.fc3 = nn.Dense(500, x_dim)
          self.fill = P.Fill()
          self.sigmoid = P.Sigmoid()
          self.reshape_op = P.Reshape()

      def ones(self, shape):
          """Return a float32 tensor of ``shape`` filled with 1."""
          return self.fill(mstype.float32, shape, 1.)

      def zeros(self, shape):
          """Return a float32 tensor of ``shape`` filled with 0."""
          return self.fill(mstype.float32, shape, 0.)

      def construct(self, x, z, y):
          """Evaluate, or generate from, the generative model.

          Args:
              x: observed images, or None to generate. When given, it is
                  reshaped to (batch_size, x_dim).
              z: latent codes passed as the observation of the 'latent' node,
                  or None (presumably zhusuan then samples from the prior --
                  confirm).
              y: conditioning input; must be None (unconditional model).

          Returns:
              Tuple ``(x, log_prob_x, z, log_prob_z)`` as returned by the
              'data' and 'latent' nodes.
          """
          assert y is None  # unconditional model: no conditioning input
          if x is not None:
              # Use the configured sizes rather than a hard-coded (32, 32*32)
              # so other batch sizes / image sizes work.
              x = self.reshape_op(x, (self.batch_size, self.x_dim))
          z_mean = self.zeros((self.batch_size, self.z_dim))
          z_std = self.ones((self.batch_size, self.z_dim))
          z, log_prob_z = self.Normal('latent', observation=z, mean=z_mean,
                                      std=z_std, shape=(), reparameterize=False)
          # Decoder MLP: z -> Bernoulli probabilities for every pixel.
          hidden = self.act1(self.fc1(z))
          hidden = self.act2(self.fc2(hidden))
          x_mean = self.sigmoid(self.fc3(hidden))
          if x is None:
              # When generating, the Bernoulli probabilities themselves are
              # passed as the observation instead of drawing binary samples.
              x = x_mean
          x, log_prob_x = self.Bernoulli('data', observation=x, shape=(), probs=x_mean)
          return x, log_prob_x, z, log_prob_z
  class Variational(zs.BayesianNet):
      """Variational posterior q(z | x) used as the VAE encoder.

      A shared MLP (x_dim -> 500 -> 500, ReLU) feeds two linear heads: one
      gives the Normal mean, the other gives the log std (exponentiated so the
      std is strictly positive).

      Args:
          x_dim: size of a flattened image (32*32 in ``main``).
          z_dim: size of the latent code.
          batch_size: examples per batch; fixes the reshaped input shape.
      """

      def __init__(self, x_dim, z_dim, batch_size):
          super().__init__()
          self.x_dim = x_dim
          self.z_dim = z_dim
          self.batch_size = batch_size
          self.reshape_op = P.Reshape()
          self.fc1 = nn.Dense(x_dim, 500)
          self.act1 = nn.ReLU()
          self.fc2 = nn.Dense(500, 500)
          self.act2 = nn.ReLU()
          self.fc3 = nn.Dense(500, z_dim)
          self.fc4 = nn.Dense(500, z_dim)
          self.fill = P.Fill()
          self.exp = P.Exp()

      # ones/zeros are not used by construct; kept for API parity with Generator.
      def ones(self, shape):
          """Return a float32 tensor of ``shape`` filled with 1."""
          return self.fill(mstype.float32, shape, 1.)

      def zeros(self, shape):
          """Return a float32 tensor of ``shape`` filled with 0."""
          return self.fill(mstype.float32, shape, 0.)

      def construct(self, x, z, y):
          """Encode ``x`` into a reparameterized Normal over the latent code.

          Args:
              x: input images (required); reshaped to (batch_size, x_dim).
              z: latent codes passed as the observation of the 'latent' node,
                  or None (presumably zhusuan then samples -- confirm).
              y: conditioning input; must be None (unconditional model).

          Returns:
              Tuple ``(z, log_prob_z)`` from the 'latent' node.
          """
          assert y is None  # unconditional model: no conditioning input
          # Use the configured sizes rather than a hard-coded (32, 32*32).
          x = self.reshape_op(x, (self.batch_size, self.x_dim))
          z_logit = self.act2(self.fc2(self.act1(self.fc1(x))))
          z_mean = self.fc3(z_logit)
          z_std = self.exp(self.fc4(z_logit))  # fc4 predicts log std
          # reparameterize=True keeps the sample differentiable w.r.t. the encoder.
          z, log_prob_z = self.Normal('latent', observation=z, mean=z_mean,
                                      std=z_std, shape=(), reparameterize=True)
          return z, log_prob_z
  def main():
      """Train the VAE on MNIST, then save reconstructions and fresh samples.

      Writes ``result/origin_x.png`` (first training batch),
      ``result/reconstruct_x.png`` (its reconstruction) and
      ``result/sample_x.png`` (images generated from the prior).
      The dataset root can be overridden with the ``MNIST_PATH`` environment
      variable; otherwise the original hard-coded location is used.
      """
      # We currently support pynative mode with device GPU
      context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
      epoch_size = 1
      batch_size = 32
      mnist_path = os.environ.get("MNIST_PATH", "/data/chengzi/zhusuan-mindspore/data/MNIST")
      repeat_size = 1
      # Define model parameters
      z_dim = 40
      x_dim = 32*32
      num_sample_batches = 4

      # create the network
      generator = Generator(x_dim, z_dim, batch_size)
      variational = Variational(x_dim, z_dim, batch_size)
      network = zs.variational.ELBO(generator, variational)

      # loss, learning rate and optimizer
      lr = 0.001
      net_loss = ReduceMeanLoss()
      print(network.trainable_params()[0])
      net_opt = nn.Adam(network.trainable_params(), lr)
      model = Model(network, net_loss, net_opt)

      ds_train = create_dataset(os.path.join(mnist_path, "train"), batch_size, repeat_size)
      model.train(epoch_size, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)
      print(network.trainable_params()[0])

      # save_img writes into result/; make sure the directory exists.
      os.makedirs('result', exist_ok=True)

      # Reconstruct the first training batch: x -> q(z|x) -> p(x|z).
      # next() fails loudly on an empty dataset instead of leaving batch_x unbound.
      first_batch = next(iter(ds_train.create_tuple_iterator()))
      batch_x = first_batch[0].reshape(batch_size, x_dim)
      z, _ = network.variational(Tensor(batch_x), None, None)
      reconstruction, _, _, _ = network.generator(None, z, None)
      reconstruction = reconstruction.asnumpy()
      save_img(batch_x, 'result/origin_x.png')
      save_img(reconstruction, 'result/reconstruct_x.png')

      # Generate several batches from the prior and stack them in one pass.
      sample_batches = [network.generator(None, None, None)[0].asnumpy()
                        for _ in range(num_sample_batches)]
      samples = np.concatenate(sample_batches, axis=0)
      save_img(samples, 'result/sample_x.png', num=num_sample_batches * batch_size)
  if __name__ == "__main__":
      # Run training and image export only when executed as a script,
      # not when this module is imported.
      main()