From b16b71b9e84db48e1a85aceecd2f6c510a576a1b Mon Sep 17 00:00:00 2001
From: zhangxinfeng3
Date: Mon, 10 Aug 2020 14:27:12 +0800
Subject: [PATCH] add infer and dpn

---
 mindspore/nn/probability/__init__.py          |   2 +
 mindspore/nn/probability/dpn/__init__.py      |  24 +++
 mindspore/nn/probability/dpn/vae/__init__.py  |  25 +++
 mindspore/nn/probability/dpn/vae/cvae.py      | 127 ++++++++++++++
 mindspore/nn/probability/dpn/vae/vae.py       | 113 ++++++++++++
 mindspore/nn/probability/infer/__init__.py    |  22 +++
 .../probability/infer/variational/__init__.py |  26 +++
 .../nn/probability/infer/variational/elbo.py  |  64 +++++++
 .../nn/probability/infer/variational/svi.py   |  72 ++++++++
 tests/st/probability/test_gpu_svi_cvae.py     | 130 ++++++++++++++
 tests/st/probability/test_gpu_svi_vae.py      | 115 ++++++++++++
 tests/st/probability/test_gpu_vae_gan.py      | 164 ++++++++++++++++++
 tests/ut/python/nn/probability/test_vae.py    |  57 ++++++
 13 files changed, 941 insertions(+)
 create mode 100644 mindspore/nn/probability/dpn/__init__.py
 create mode 100644 mindspore/nn/probability/dpn/vae/__init__.py
 create mode 100644 mindspore/nn/probability/dpn/vae/cvae.py
 create mode 100644 mindspore/nn/probability/dpn/vae/vae.py
 create mode 100644 mindspore/nn/probability/infer/__init__.py
 create mode 100644 mindspore/nn/probability/infer/variational/__init__.py
 create mode 100644 mindspore/nn/probability/infer/variational/elbo.py
 create mode 100644 mindspore/nn/probability/infer/variational/svi.py
 create mode 100644 tests/st/probability/test_gpu_svi_cvae.py
 create mode 100644 tests/st/probability/test_gpu_svi_vae.py
 create mode 100644 tests/st/probability/test_gpu_vae_gan.py
 create mode 100644 tests/ut/python/nn/probability/test_vae.py

diff --git a/mindspore/nn/probability/__init__.py b/mindspore/nn/probability/__init__.py
index 5bc8a54c40..bc5966ee4e 100644
--- a/mindspore/nn/probability/__init__.py
+++ b/mindspore/nn/probability/__init__.py
@@ -20,3 +20,5 @@ The high-level components used to construct the probabilistic network.
 from . import bijector
 from . import distribution
+from . import infer
+from . import dpn
diff --git a/mindspore/nn/probability/dpn/__init__.py b/mindspore/nn/probability/dpn/__init__.py
new file mode 100644
index 0000000000..d5636469a9
--- /dev/null
+++ b/mindspore/nn/probability/dpn/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Deep Probability Network (DPN).
+
+Deep probability networks, such as BNN and VAE networks.
+""" + +from .vae import * + +__all__ = [] +__all__.extend(vae.__all__) diff --git a/mindspore/nn/probability/dpn/vae/__init__.py b/mindspore/nn/probability/dpn/vae/__init__.py new file mode 100644 index 0000000000..350839fec9 --- /dev/null +++ b/mindspore/nn/probability/dpn/vae/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Variational auto-encoder (VAE). + +The interface of VAE, which allows to construct probablity model like DNN model. +""" + +from .vae import VAE +from .cvae import ConditionalVAE + +__all__ = ['VAE', + 'ConditionalVAE'] diff --git a/mindspore/nn/probability/dpn/vae/cvae.py b/mindspore/nn/probability/dpn/vae/cvae.py new file mode 100644 index 0000000000..87839c914f --- /dev/null +++ b/mindspore/nn/probability/dpn/vae/cvae.py @@ -0,0 +1,127 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Conditional Variational auto-encoder (CVAE).""" +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore._checkparam import check_int_positive +from ...distribution.normal import Normal +from ....cell import Cell +from ....layer.basic import Dense, OneHot + + +class ConditionalVAE(Cell): + r""" + Conditional Variational auto-encoder (CVAE). + + The difference with VAE is that CVAE uses labels information. + see more details in ``. + + Note: + When define the encoder and decoder, the shape of the encoder's output tensor and decoder's input tensor + should be math:`(N, hidden_size)`. + The latent_size should be less than or equal to the hidden_size. + + Args: + encoder(Cell): The DNN model defined as encoder. + decoder(Cell): The DNN model defined as decoder. + hidden_size(int): The size of encoder's output tensor. + latent_size(int): The size of the latent space. + num_classes(int): The number of classes. + + Inputs: + - **input_x** (Tensor) - the same shape as the input of encoder. + - **input_y** (Tensor) - the tensor of the target data, the shape is math:`(N, 1)`. + + Outputs: + - **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor), z(Tensor), prior(Cell)). 
+ """ + + def __init__(self, encoder, decoder, hidden_size, latent_size, num_classes): + super(ConditionalVAE, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.hidden_size = check_int_positive(hidden_size) + self.latent_size = check_int_positive(latent_size) + self.num_classes = check_int_positive(num_classes) + self.normal = C.normal + self.exp = P.Exp() + self.reshape = P.Reshape() + self.concat = P.Concat(axis=1) + self.to_tensor = P.ScalarToArray() + self.normal_dis = Normal() + self.one_hot = OneHot(depth=num_classes) + self.standard_normal_dis = Normal([0] * self.latent_size, [1] * self.latent_size) + self.dense1 = Dense(self.hidden_size, self.latent_size) + self.dense2 = Dense(self.hidden_size, self.latent_size) + self.dense3 = Dense(self.latent_size + self.num_classes, self.hidden_size) + + def _encode(self, x, y): + en_x = self.encoder(x, y) + mu = self.dense1(en_x) + log_var = self.dense2(en_x) + return mu, log_var + + def _decode(self, z): + z = self.dense3(z) + recon_x = self.decoder(z) + return recon_x + + def construct(self, x, y): + mu, log_var = self._encode(x, y) + std = self.exp(0.5 * log_var) + z = self.normal_dis('sample', mean=mu, sd=std) + y = self.one_hot(y) + z_c = self.concat((z, y)) + recon_x = self._decode(z_c) + return recon_x, x, mu, std, z, self.standard_normal_dis + + def generate_sample(self, sample_y, generate_nums=None, shape=None): + """ + Randomly sample from latent space to generate sample. + + Args: + sample_y (Tensor): Define the label of sample, int tensor. + generate_nums (int): The number of samples to generate. + shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)`. + + Returns: + Tensor, the generated sample. + """ + sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0) + sample_y = self.one_hot(sample_y) + sample_c = self.concat((sample_z, sample_y)) + sample = self._decode(sample_c) + sample = self.reshape(sample, shape) + return sample + + def reconstruct_sample(self, x, y): + """ + Reconstruct sample from original data. + + Args: + x (Tensor): The input tensor to be reconstructed. + y (Tensor): The label of the input tensor. + + Returns: + Tensor, the reconstructed sample. + """ + mu, log_var = self._encode(x, y) + std = self.exp(0.5 * log_var) + z = self.normal(mu.shape, mu, std, seed=0) + y = self.one_hot(y) + z_c = self.concat((z, y)) + recon_x = self._decode(z_c) + return recon_x diff --git a/mindspore/nn/probability/dpn/vae/vae.py b/mindspore/nn/probability/dpn/vae/vae.py new file mode 100644 index 0000000000..25ef01f1fb --- /dev/null +++ b/mindspore/nn/probability/dpn/vae/vae.py @@ -0,0 +1,113 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Variational auto-encoder (VAE)""" +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore._checkparam import check_int_positive +from ...distribution.normal import Normal +from ....cell import Cell +from ....layer.basic import Dense + + +class VAE(Cell): + r""" + Variational auto-encoder (VAE). + + The VAE defines a generative model, `Z` is sampled from the prior, then used to reconstruct `X` by a decoder. + see more details in `Auto-Encoding Variational Bayes`_. + + Note: + When define the encoder and decoder, the shape of the encoder's output tensor and decoder's input tensor + should be math:`(N, hidden_size)`. + The latent_size should be less than or equal to the hidden_size. + + Args: + encoder(Cell): The DNN model defined as encoder. + decoder(Cell): The DNN model defined as decoder. + hidden_size(int): The size of encoder's output tensor. + latent_size(int): The size of the latent space. + + Inputs: + - **input** (Tensor) - the same shape as the input of encoder. + + Outputs: + - **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor), z(Tensor), prior(Cell)). + """ + + def __init__(self, encoder, decoder, hidden_size, latent_size): + super(VAE, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.hidden_size = check_int_positive(hidden_size) + self.latent_size = check_int_positive(latent_size) + self.normal = C.normal + self.exp = P.Exp() + self.reshape = P.Reshape() + self.to_tensor = P.ScalarToArray() + self.normal_dis = Normal() + self.standard_normal_dis = Normal([0]*self.latent_size, [1]*self.latent_size) + self.dense1 = Dense(self.hidden_size, self.latent_size) + self.dense2 = Dense(self.hidden_size, self.latent_size) + self.dense3 = Dense(self.latent_size, self.hidden_size) + + def _encode(self, x): + en_x = self.encoder(x) + mu = self.dense1(en_x) + log_var = self.dense2(en_x) + return mu, log_var + + def _decode(self, z): + z = self.dense3(z) + recon_x = self.decoder(z) + return recon_x + + def construct(self, x): + mu, log_var = self._encode(x) + std = self.exp(0.5 * log_var) + z = self.normal_dis('sample', mean=mu, sd=std) + recon_x = self._decode(z) + return recon_x, x, mu, std, z, self.standard_normal_dis + + def generate_sample(self, generate_nums, shape): + """ + Randomly sample from latent space to generate sample. + + Args: + generate_nums (int): The number of samples to generate. + shape(tuple): The shape of sample, it should be math:`(generate_nums, C, H, W)`. + + Returns: + Tensor, the generated sample. + """ + sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0) + sample = self._decode(sample_z) + sample = self.reshape(sample, shape) + return sample + + def reconstruct_sample(self, x): + """ + Reconstruct sample from original data. + + Args: + x (Tensor): The input tensor to be reconstructed. + + Returns: + Tensor, the reconstructed sample. 
+ """ + mu, log_var = self._encode(x) + std = self.exp(0.5 * log_var) + z = self.normal(mu.shape, mu, std, seed=0) + recon_x = self._decode(z) + return recon_x diff --git a/mindspore/nn/probability/infer/__init__.py b/mindspore/nn/probability/infer/__init__.py new file mode 100644 index 0000000000..52968a8e0d --- /dev/null +++ b/mindspore/nn/probability/infer/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Infer algorithms in Probabilistic Programming. +""" + +from .variational import * + +__all__ = [] +__all__.extend(variational.__all__) diff --git a/mindspore/nn/probability/infer/variational/__init__.py b/mindspore/nn/probability/infer/variational/__init__.py new file mode 100644 index 0000000000..f05dffbfd0 --- /dev/null +++ b/mindspore/nn/probability/infer/variational/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +SVI and ELBO. + +The SVI interface is for variational inference. +The ELBO interface is called as loss while model training. +""" + +from .svi import SVI +from .elbo import ELBO + +__all__ = ['SVI', + 'ELBO'] diff --git a/mindspore/nn/probability/infer/variational/elbo.py b/mindspore/nn/probability/infer/variational/elbo.py new file mode 100644 index 0000000000..df2a7c369b --- /dev/null +++ b/mindspore/nn/probability/infer/variational/elbo.py @@ -0,0 +1,64 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""The Evidence Lower Bound (ELBO).""" +from mindspore.ops import operations as P +from ...distribution.normal import Normal +from ....cell import Cell +from ....loss.loss import MSELoss + + +class ELBO(Cell): + r""" + The Evidence Lower Bound (ELBO). + + Variational inference minimizes the Kullback-Leibler (KL) divergence from the variational distribution to + the posterior distribution. It maximizes the evidence lower bound (ELBO), a lower bound on the logarithm of + the marginal probability of the observations log p(x). The ELBO is equal to the negative KL divergence up to + an additive constant. + see more details in `Variational Inference: A Review for Statisticians`_. + + Args: + latent_prior(str): The prior distribution of latent space. Default: Normal. + - Normal: The prior distribution of latent space is Normal. + output_prior(str): The distribution of output data. Default: Normal. + - Normal: If the distribution of output data is Normal, the reconstruct loss is MSELoss. + + Inputs: + - **input_data** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor), z(Tensor), prior(Cell)). + - **target_data** (Tensor) - the target tensor. + + Outputs: + Tensor, loss float tensor. + """ + + def __init__(self, latent_prior='Normal', output_prior='Normal'): + super(ELBO, self).__init__() + self.sum = P.ReduceSum() + if latent_prior == 'Normal': + self.posterior = Normal() + else: + raise ValueError('The values of latent_prior now only support Normal') + if output_prior == 'Normal': + self.recon_loss = MSELoss(reduction='sum') + else: + raise ValueError('The values of output_dis now only support Normal') + + def construct(self, data, label): + recon_x, x, mu, std, z, prior = data + reconstruct_loss = self.recon_loss(x, recon_x) + kl_loss = -(prior('log_prob', z) - self.posterior('log_prob', z, mu, std)) \ + * self.posterior('prob', z, mu, std) + elbo = reconstruct_loss + self.sum(kl_loss) + return elbo diff --git a/mindspore/nn/probability/infer/variational/svi.py b/mindspore/nn/probability/infer/variational/svi.py new file mode 100644 index 0000000000..8aca1221ac --- /dev/null +++ b/mindspore/nn/probability/infer/variational/svi.py @@ -0,0 +1,72 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Stochastic Variational Inference(SVI).""" +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +from ....wrap.cell_wrapper import TrainOneStepCell + + +class SVI: + r""" + Stochastic Variational Inference(SVI). + + Variational inference casts the inference problem as an optimization. Some distributions over the hidden + variables that is indexed by a set of free parameters, and then optimize the parameters to make it closest to + the posterior of interest. + see more details in `Variational Inference: A Review for Statisticians`_. 
+
+    Args:
+        net_with_loss(Cell): Cell with loss function.
+        optimizer (Cell): Optimizer for updating the weights.
+
+    """
+
+    def __init__(self, net_with_loss, optimizer):
+        self.net_with_loss = net_with_loss
+        self.optimizer = optimizer
+        self._loss = 0.0
+
+    def run(self, train_dataset, epochs=10):
+        """
+        Optimize the parameters by training the probability network, and return the trained network.
+
+        Args:
+            train_dataset (Dataset): A training dataset iterator.
+            epochs (int): Total number of iterations on the data. Default: 10.
+
+        Returns:
+            Cell, the trained probability network.
+        """
+        train_net = TrainOneStepCell(self.net_with_loss, self.optimizer)
+        train_net.set_train()
+        for _ in range(1, epochs+1):
+            train_loss = 0
+            dataset_size = 0
+            for data in train_dataset.create_dict_iterator():
+                x = Tensor(data['image'], dtype=mstype.float32)
+                y = Tensor(data['label'], dtype=mstype.int32)
+                dataset_size += len(x)
+                loss = train_net(x, y).asnumpy()
+                train_loss += loss
+            self._loss = train_loss / dataset_size
+        model = self.net_with_loss.backbone_network
+        return model
+
+    def get_train_loss(self):
+        """
+        Returns:
+            float, the average training loss of the final epoch.
+        """
+        return self._loss
diff --git a/tests/st/probability/test_gpu_svi_cvae.py b/tests/st/probability/test_gpu_svi_cvae.py
new file mode 100644
index 0000000000..08c645179e
--- /dev/null
+++ b/tests/st/probability/test_gpu_svi_cvae.py
@@ -0,0 +1,130 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +import os + +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as CV +import mindspore.nn as nn +from mindspore import context, Tensor +from mindspore.ops import operations as P +from mindspore.nn.probability.dpn import ConditionalVAE +from mindspore.nn.probability.infer import ELBO, SVI + +context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU") +IMAGE_SHAPE = (-1, 1, 32, 32) +image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train") + + +class Encoder(nn.Cell): + def __init__(self, num_classes): + super(Encoder, self).__init__() + self.fc1 = nn.Dense(1024 + num_classes, 400) + self.relu = nn.ReLU() + self.flatten = nn.Flatten() + self.concat = P.Concat(axis=1) + self.one_hot = nn.OneHot(depth=num_classes) + + def construct(self, x, y): + x = self.flatten(x) + y = self.one_hot(y) + input_x = self.concat((x, y)) + input_x = self.fc1(input_x) + input_x = self.relu(input_x) + return input_x + + +class Decoder(nn.Cell): + def __init__(self): + super(Decoder, self).__init__() + self.fc2 = nn.Dense(400, 1024) + self.sigmoid = nn.Sigmoid() + self.reshape = P.Reshape() + + def construct(self, z): + z = self.fc2(z) + z = self.reshape(z, IMAGE_SHAPE) + z = self.sigmoid(z) + return z + + +class WithLossCell(nn.Cell): + def __init__(self, backbone, loss_fn): + super(WithLossCell, self).__init__(auto_prefix=False) + self._backbone = backbone + self._loss_fn = loss_fn + + def construct(self, data, label): + out = self._backbone(data, label) + return self._loss_fn(out, label) + + +def create_dataset(data_path, batch_size=32, repeat_size=1, + num_parallel_workers=1): + """ + create dataset for train or test + """ + # define dataset + mnist_ds = ds.MnistDataset(data_path) + + resize_height, resize_width = 32, 32 + rescale = 1.0 / 255.0 + shift = 0.0 + + # define map operations + resize_op = CV.Resize((resize_height, resize_width)) # Bilinear mode + rescale_op = CV.Rescale(rescale, shift) + hwc2chw_op = CV.HWC2CHW() + + # apply map operations on images + mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers) + + # apply DatasetOps + mnist_ds = mnist_ds.batch(batch_size) + mnist_ds = mnist_ds.repeat(repeat_size) + + return mnist_ds + + +if __name__ == "__main__": + # define the encoder and decoder + encoder = Encoder(num_classes=10) + decoder = Decoder() + # define the cvae model + cvae = ConditionalVAE(encoder, decoder, hidden_size=400, latent_size=20, num_classes=10) + # define the loss function + net_loss = ELBO(latent_prior='Normal', output_dis='Normal') + # define the optimizer + optimizer = nn.Adam(params=cvae.trainable_params(), learning_rate=0.001) + # define the training dataset + ds_train = create_dataset(image_path, 128, 1) + # define the WithLossCell modified + net_with_loss = WithLossCell(cvae, net_loss) + # define the variational inference + vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer) + # run the vi to return the trained network. 
+    cvae = vi.run(train_dataset=ds_train, epochs=10)
+    # get the trained loss
+    trained_loss = vi.get_train_loss()
+    # test function: generate_sample
+    sample_label = Tensor([i for i in range(0, 8)] * 8, dtype=mstype.int32)
+    generated_sample = cvae.generate_sample(sample_label, 64, IMAGE_SHAPE)
+    # test function: reconstruct_sample
+    for sample in ds_train.create_dict_iterator():
+        sample_x = Tensor(sample['image'], dtype=mstype.float32)
+        sample_y = Tensor(sample['label'], dtype=mstype.int32)
+        reconstructed_sample = cvae.reconstruct_sample(sample_x, sample_y)
diff --git a/tests/st/probability/test_gpu_svi_vae.py b/tests/st/probability/test_gpu_svi_vae.py
new file mode 100644
index 0000000000..c18a290e51
--- /dev/null
+++ b/tests/st/probability/test_gpu_svi_vae.py
@@ -0,0 +1,115 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import os
+
+import mindspore.common.dtype as mstype
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.vision.c_transforms as CV
+import mindspore.nn as nn
+from mindspore import context, Tensor
+from mindspore.ops import operations as P
+from mindspore.nn.probability.dpn import VAE
+from mindspore.nn.probability.infer import ELBO, SVI
+
+context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
+IMAGE_SHAPE = (-1, 1, 32, 32)
+image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train")
+
+
+class Encoder(nn.Cell):
+    def __init__(self):
+        super(Encoder, self).__init__()
+        self.fc1 = nn.Dense(1024, 800)
+        self.fc2 = nn.Dense(800, 400)
+        self.relu = nn.ReLU()
+        self.flatten = nn.Flatten()
+
+    def construct(self, x):
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.relu(x)
+        return x
+
+
+class Decoder(nn.Cell):
+    def __init__(self):
+        super(Decoder, self).__init__()
+        self.fc1 = nn.Dense(400, 1024)
+        self.sigmoid = nn.Sigmoid()
+        self.reshape = P.Reshape()
+
+    def construct(self, z):
+        z = self.fc1(z)
+        z = self.reshape(z, IMAGE_SHAPE)
+        z = self.sigmoid(z)
+        return z
+
+
+def create_dataset(data_path, batch_size=32, repeat_size=1,
+                   num_parallel_workers=1):
+    """
+    Create dataset for training or testing.
+    """
+    # define dataset
+    mnist_ds = ds.MnistDataset(data_path)
+
+    resize_height, resize_width = 32, 32
+    rescale = 1.0 / 255.0
+    shift = 0.0
+
+    # define map operations
+    resize_op = CV.Resize((resize_height, resize_width))  # Bilinear mode
+    rescale_op = CV.Rescale(rescale, shift)
+    hwc2chw_op = CV.HWC2CHW()
+
+    # apply map operations on images
+    mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
+    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
+    mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
+
+    # apply DatasetOps
+    mnist_ds = mnist_ds.batch(batch_size)
+    mnist_ds = mnist_ds.repeat(repeat_size)
+
+    return mnist_ds
+
+
+if __name__ == "__main__":
+    # define the encoder and decoder
+    encoder = Encoder()
+    decoder = Decoder()
+    # define the VAE model
+    vae = VAE(encoder, decoder, hidden_size=400, latent_size=20)
+    # define the loss function
+    net_loss = ELBO(latent_prior='Normal', output_prior='Normal')
+    # define the optimizer
+    optimizer = nn.Adam(params=vae.trainable_params(), learning_rate=0.001)
+    # define the training dataset
+    ds_train = create_dataset(image_path, 128, 1)
+    net_with_loss = nn.WithLossCell(vae, net_loss)
+    # define the variational inference
+    vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
+    # run the vi to return the trained network.
+    vae = vi.run(train_dataset=ds_train, epochs=10)
+    # get the trained loss
+    trained_loss = vi.get_train_loss()
+    # test function: generate_sample
+    generated_sample = vae.generate_sample(64, IMAGE_SHAPE)
+    # test function: reconstruct_sample
+    for sample in ds_train.create_dict_iterator():
+        sample_x = Tensor(sample['image'], dtype=mstype.float32)
+        reconstructed_sample = vae.reconstruct_sample(sample_x)
diff --git a/tests/st/probability/test_gpu_vae_gan.py b/tests/st/probability/test_gpu_vae_gan.py
new file mode 100644
index 0000000000..b4f62d10e8
--- /dev/null
+++ b/tests/st/probability/test_gpu_vae_gan.py
@@ -0,0 +1,164 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+The VAE interface can be called to construct a VAE-GAN network.
+"""
+import os
+
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.vision.c_transforms as CV
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.ops import operations as P
+from mindspore.nn.probability.dpn import VAE
+from mindspore.nn.probability.infer import ELBO, SVI
+
+context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
+IMAGE_SHAPE = (-1, 1, 32, 32)
+image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train")
+
+
+class Encoder(nn.Cell):
+    def __init__(self):
+        super(Encoder, self).__init__()
+        self.fc1 = nn.Dense(1024, 400)
+        self.relu = nn.ReLU()
+        self.flatten = nn.Flatten()
+
+    def construct(self, x):
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = self.relu(x)
+        return x
+
+
+class Decoder(nn.Cell):
+    def __init__(self):
+        super(Decoder, self).__init__()
+        self.fc1 = nn.Dense(400, 1024)
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
+        self.reshape = P.Reshape()
+
+    def construct(self, z):
+        z = self.fc1(z)
+        z = self.reshape(z, IMAGE_SHAPE)
+        z = self.sigmoid(z)
+        return z
+
+
+class Discriminator(nn.Cell):
+    """
+    The Discriminator of the GAN network.
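+    It maps an image to a vector of scores in (0, 1); the VaeGanLoss below trains these
+    scores toward 1 for real images and toward 0 for reconstructed or generated ones.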
+ """ + + def __init__(self): + super(Discriminator, self).__init__() + self.fc1 = nn.Dense(1024, 400) + self.fc2 = nn.Dense(400, 720) + self.fc3 = nn.Dense(720, 1024) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + self.flatten = nn.Flatten() + + def construct(self, x): + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + x = self.sigmoid(x) + return x + + +class VaeGan(nn.Cell): + def __init__(self): + super(VaeGan, self).__init__() + self.E = Encoder() + self.G = Decoder() + self.D = Discriminator() + self.dense = nn.Dense(20, 400) + self.vae = VAE(self.E, self.G, 400, 20) + self.shape = P.Shape() + self.to_tensor = P.ScalarToArray() + + def construct(self, x): + recon_x, x, mu, std, z, prior = self.vae(x) + z_p = prior('sample', self.shape(mu), self.to_tensor(0.0), self.to_tensor(1.0)) + z_p = self.dense(z_p) + x_p = self.G(z_p) + ld_real = self.D(x) + ld_fake = self.D(recon_x) + ld_p = self.D(x_p) + return ld_real, ld_fake, ld_p, recon_x, x, mu, std, z, prior + + +class VaeGanLoss(nn.Cell): + def __init__(self): + super(VaeGanLoss, self).__init__() + self.zeros = P.ZerosLike() + self.mse = nn.MSELoss(reduction='sum') + self.elbo = ELBO(latent_prior='Normal', output_dis='Normal') + + def construct(self, data, label): + ld_real, ld_fake, ld_p, recon_x, x, mean, std, z, prior = data + y_real = self.zeros(ld_real) + 1 + y_fake = self.zeros(ld_fake) + elbo_data = (recon_x, x, mean, std, z, prior) + loss_D = self.mse(ld_real, y_real) + loss_GD = self.mse(ld_p, y_fake) + loss_G = self.mse(ld_fake, y_real) + elbo_loss = self.elbo(elbo_data, label) + return loss_D + loss_G + loss_GD + elbo_loss + + +def create_dataset(data_path, batch_size=32, repeat_size=1, + num_parallel_workers=1): + """ + create dataset for train or test + """ + # define dataset + mnist_ds = ds.MnistDataset(data_path) + + resize_height, resize_width = 32, 32 + rescale = 1.0 / 255.0 + shift = 0.0 + + # define map operations + resize_op = CV.Resize((resize_height, resize_width)) # Bilinear mode + rescale_op = CV.Rescale(rescale, shift) + hwc2chw_op = CV.HWC2CHW() + + # apply map operations on images + mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers) + + # apply DatasetOps + mnist_ds = mnist_ds.batch(batch_size) + mnist_ds = mnist_ds.repeat(repeat_size) + + return mnist_ds + + +if __name__ == "__main__": + vae_gan = VaeGan() + net_loss = VaeGanLoss() + optimizer = nn.Adam(params=vae_gan.trainable_params(), learning_rate=0.001) + ds_train = create_dataset(image_path, 128, 1) + net_with_loss = nn.WithLossCell(vae_gan, net_loss) + vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer) + vae_gan = vi.run(train_dataset=ds_train, epochs=10) diff --git a/tests/ut/python/nn/probability/test_vae.py b/tests/ut/python/nn/probability/test_vae.py new file mode 100644 index 0000000000..aff129574e --- /dev/null +++ b/tests/ut/python/nn/probability/test_vae.py @@ -0,0 +1,57 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test VAE interface """
+import numpy as np
+
+import mindspore.common.dtype as mstype
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common.api import _executor
+from mindspore.nn.probability.dpn import VAE
+
+
+class Encoder(nn.Cell):
+    def __init__(self):
+        super(Encoder, self).__init__()
+        self.fc1 = nn.Dense(6, 3)
+        self.relu = nn.ReLU()
+
+    def construct(self, x):
+        x = self.fc1(x)
+        x = self.relu(x)
+        return x
+
+
+class Decoder(nn.Cell):
+    def __init__(self):
+        super(Decoder, self).__init__()
+        self.fc1 = nn.Dense(3, 6)
+        self.sigmoid = nn.Sigmoid()
+
+    def construct(self, z):
+        z = self.fc1(z)
+        z = self.sigmoid(z)
+        return z
+
+
+def test_vae():
+    """
+    Test the VAE interface with a DNN model.
+    """
+    encoder = Encoder()
+    decoder = Decoder()
+    net = VAE(encoder, decoder, hidden_size=3, latent_size=2)
+    input_data = Tensor(np.random.rand(32, 6), dtype=mstype.float32)
+    _executor.compile(net, input_data)
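+
+
+# A minimal compile-only sketch for the ConditionalVAE interface, mirroring test_vae above.
+# The extra imports and the toy sizes (6 input features, hidden_size=3, latent_size=2,
+# num_classes=4) are assumptions chosen only so that the shapes line up.
+from mindspore.ops import operations as P
+from mindspore.nn.probability.dpn import ConditionalVAE
+
+
+class EncoderC(nn.Cell):
+    def __init__(self, num_classes):
+        super(EncoderC, self).__init__()
+        self.fc1 = nn.Dense(6 + num_classes, 3)
+        self.relu = nn.ReLU()
+        self.concat = P.Concat(axis=1)
+        self.one_hot = nn.OneHot(depth=num_classes)
+
+    def construct(self, x, y):
+        y = self.one_hot(y)
+        x = self.concat((x, y))
+        x = self.fc1(x)
+        x = self.relu(x)
+        return x
+
+
+def test_cvae():
+    """
+    Test the ConditionalVAE interface with a DNN model (a sketch).
+    """
+    encoder = EncoderC(num_classes=4)
+    decoder = Decoder()
+    net = ConditionalVAE(encoder, decoder, hidden_size=3, latent_size=2, num_classes=4)
+    input_x = Tensor(np.random.rand(32, 6), dtype=mstype.float32)
+    input_y = Tensor(np.random.randint(0, 4, size=32), dtype=mstype.int32)
+    _executor.compile(net, input_x, input_y)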