You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_gpu_svi_cvae.py 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import os
  16. import mindspore.common.dtype as mstype
  17. import mindspore.dataset as ds
  18. import mindspore.dataset.transforms.vision.c_transforms as CV
  19. import mindspore.nn as nn
  20. from mindspore import context, Tensor
  21. from mindspore.ops import operations as P
  22. from mindspore.nn.probability.dpn import ConditionalVAE
  23. from mindspore.nn.probability.infer import ELBO, SVI
  24. context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
  25. IMAGE_SHAPE = (-1, 1, 32, 32)
  26. image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train")
  27. class Encoder(nn.Cell):
  28. def __init__(self, num_classes):
  29. super(Encoder, self).__init__()
  30. self.fc1 = nn.Dense(1024 + num_classes, 400)
  31. self.relu = nn.ReLU()
  32. self.flatten = nn.Flatten()
  33. self.concat = P.Concat(axis=1)
  34. self.one_hot = nn.OneHot(depth=num_classes)
  35. def construct(self, x, y):
  36. x = self.flatten(x)
  37. y = self.one_hot(y)
  38. input_x = self.concat((x, y))
  39. input_x = self.fc1(input_x)
  40. input_x = self.relu(input_x)
  41. return input_x
  42. class Decoder(nn.Cell):
  43. def __init__(self):
  44. super(Decoder, self).__init__()
  45. self.fc2 = nn.Dense(400, 1024)
  46. self.sigmoid = nn.Sigmoid()
  47. self.reshape = P.Reshape()
  48. def construct(self, z):
  49. z = self.fc2(z)
  50. z = self.reshape(z, IMAGE_SHAPE)
  51. z = self.sigmoid(z)
  52. return z
  53. class CVAEWithLossCell(nn.WithLossCell):
  54. """
  55. Rewrite WithLossCell for CVAE
  56. """
  57. def construct(self, data, label):
  58. out = self._backbone(data, label)
  59. return self._loss_fn(out, label)
  60. def create_dataset(data_path, batch_size=32, repeat_size=1,
  61. num_parallel_workers=1):
  62. """
  63. create dataset for train or test
  64. """
  65. # define dataset
  66. mnist_ds = ds.MnistDataset(data_path)
  67. resize_height, resize_width = 32, 32
  68. rescale = 1.0 / 255.0
  69. shift = 0.0
  70. # define map operations
  71. resize_op = CV.Resize((resize_height, resize_width)) # Bilinear mode
  72. rescale_op = CV.Rescale(rescale, shift)
  73. hwc2chw_op = CV.HWC2CHW()
  74. # apply map operations on images
  75. mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
  76. mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
  77. mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
  78. # apply DatasetOps
  79. mnist_ds = mnist_ds.batch(batch_size)
  80. mnist_ds = mnist_ds.repeat(repeat_size)
  81. return mnist_ds
  82. def test_svi_cvae():
  83. # define the encoder and decoder
  84. encoder = Encoder(num_classes=10)
  85. decoder = Decoder()
  86. # define the cvae model
  87. cvae = ConditionalVAE(encoder, decoder, hidden_size=400, latent_size=20, num_classes=10)
  88. # define the loss function
  89. net_loss = ELBO(latent_prior='Normal', output_prior='Normal')
  90. # define the optimizer
  91. optimizer = nn.Adam(params=cvae.trainable_params(), learning_rate=0.001)
  92. # define the training dataset
  93. ds_train = create_dataset(image_path, 128, 1)
  94. # define the WithLossCell modified
  95. net_with_loss = CVAEWithLossCell(cvae, net_loss)
  96. # define the variational inference
  97. vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
  98. # run the vi to return the trained network.
  99. cvae = vi.run(train_dataset=ds_train, epochs=5)
  100. # get the trained loss
  101. trained_loss = vi.get_train_loss()
  102. # test function: generate_sample
  103. sample_label = Tensor([i for i in range(0, 8)] * 8, dtype=mstype.int32)
  104. generated_sample = cvae.generate_sample(sample_label, 64, IMAGE_SHAPE)
  105. # test function: reconstruct_sample
  106. for sample in ds_train.create_dict_iterator():
  107. sample_x = Tensor(sample['image'], dtype=mstype.float32)
  108. sample_y = Tensor(sample['label'], dtype=mstype.int32)
  109. reconstructed_sample = cvae.reconstruct_sample(sample_x, sample_y)
  110. print('The loss of the trained network is ', trained_loss)
  111. print('The shape of the generated sample is ', generated_sample.shape)
  112. print('The shape of the reconstructed sample is ', reconstructed_sample.shape)