@@ -20,3 +20,5 @@ The high-level components used to construct the probabilistic network.
from . import bijector
from . import distribution
from . import infer
from . import dpn
@@ -0,0 +1,24 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Deep Probability Network (dpn).

Deep probability networks such as BNN and VAE networks.
"""
from .vae import *

__all__ = []
__all__.extend(vae.__all__)
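
# A minimal usage sketch (the package is exposed as mindspore.nn.probability.dpn):
#     from mindspore.nn.probability.dpn import VAE, ConditionalVAE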
@@ -0,0 +1,25 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Variational auto-encoder (VAE).

The VAE interface allows probability models to be constructed in the same way as DNN models.
"""
from .vae import VAE
from .cvae import ConditionalVAE

__all__ = ['VAE',
           'ConditionalVAE']
@@ -0,0 +1,127 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Conditional Variational auto-encoder (CVAE)."""
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore._checkparam import check_int_positive
from ...distribution.normal import Normal
from ....cell import Cell
from ....layer.basic import Dense, OneHot


class ConditionalVAE(Cell):
| r""" | |||
| Conditional Variational auto-encoder (CVAE). | |||
| The difference with VAE is that CVAE uses labels information. | |||
| see more details in `<http://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep- | |||
| conditional-generative-models>`. | |||
| Note: | |||
| When define the encoder and decoder, the shape of the encoder's output tensor and decoder's input tensor | |||
| should be math:`(N, hidden_size)`. | |||
| The latent_size should be less than or equal to the hidden_size. | |||
| Args: | |||
| encoder(Cell): The DNN model defined as encoder. | |||
| decoder(Cell): The DNN model defined as decoder. | |||
| hidden_size(int): The size of encoder's output tensor. | |||
| latent_size(int): The size of the latent space. | |||
| num_classes(int): The number of classes. | |||
| Inputs: | |||
| - **input_x** (Tensor) - the same shape as the input of encoder. | |||
| - **input_y** (Tensor) - the tensor of the target data, the shape is math:`(N, 1)`. | |||
| Outputs: | |||
| - **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor), z(Tensor), prior(Cell)). | |||
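
    Examples:
        >>> # A minimal sketch; `Encoder` and `Decoder` are user-defined Cells whose
        >>> # output/input size matches hidden_size (hypothetical definitions):
        >>> encoder = Encoder(num_classes=10)
        >>> decoder = Decoder()
        >>> cvae = ConditionalVAE(encoder, decoder, hidden_size=400, latent_size=20, num_classes=10)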
| """ | |||

    def __init__(self, encoder, decoder, hidden_size, latent_size, num_classes):
        super(ConditionalVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.hidden_size = check_int_positive(hidden_size)
        self.latent_size = check_int_positive(latent_size)
        self.num_classes = check_int_positive(num_classes)
        self.normal = C.normal
        self.exp = P.Exp()
        self.reshape = P.Reshape()
        self.concat = P.Concat(axis=1)
        self.to_tensor = P.ScalarToArray()
        self.normal_dis = Normal()
        self.one_hot = OneHot(depth=num_classes)
        self.standard_normal_dis = Normal([0] * self.latent_size, [1] * self.latent_size)
        self.dense1 = Dense(self.hidden_size, self.latent_size)
        self.dense2 = Dense(self.hidden_size, self.latent_size)
        self.dense3 = Dense(self.latent_size + self.num_classes, self.hidden_size)

    def _encode(self, x, y):
        en_x = self.encoder(x, y)
        mu = self.dense1(en_x)
        log_var = self.dense2(en_x)
        return mu, log_var

    def _decode(self, z):
        z = self.dense3(z)
        recon_x = self.decoder(z)
        return recon_x

    def construct(self, x, y):
        mu, log_var = self._encode(x, y)
        # Reparameterization: z ~ N(mu, std), with std = exp(0.5 * log_var).
        std = self.exp(0.5 * log_var)
        z = self.normal_dis('sample', mean=mu, sd=std)
        # Concatenate the one-hot label with the latent code before decoding.
        y = self.one_hot(y)
        z_c = self.concat((z, y))
        recon_x = self._decode(z_c)
        return recon_x, x, mu, std, z, self.standard_normal_dis

    def generate_sample(self, sample_y, generate_nums=None, shape=None):
        """
        Randomly sample from the latent space to generate samples.

        Args:
            sample_y (Tensor): The labels of the samples, an int tensor.
            generate_nums (int): The number of samples to generate.
            shape (tuple): The shape of the samples, which should be :math:`(generate_nums, C, H, W)`.

        Returns:
            Tensor, the generated samples.
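
        Examples:
            >>> # Hypothetical usage: generate 64 samples with labels 0-7 repeated
            >>> # (assumes a trained `cvae` and mindspore.common.dtype as `mstype`):
            >>> sample_y = Tensor([i for i in range(0, 8)] * 8, dtype=mstype.int32)
            >>> samples = cvae.generate_sample(sample_y, 64, (64, 1, 32, 32))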
| """ | |||
| sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0) | |||
| sample_y = self.one_hot(sample_y) | |||
| sample_c = self.concat((sample_z, sample_y)) | |||
| sample = self._decode(sample_c) | |||
| sample = self.reshape(sample, shape) | |||
| return sample | |||

    def reconstruct_sample(self, x, y):
        """
        Reconstruct samples from original data.

        Args:
            x (Tensor): The input tensor to be reconstructed.
            y (Tensor): The labels of the input tensor.

        Returns:
            Tensor, the reconstructed samples.
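
        Examples:
            >>> # Hypothetical usage on a batch (sample_x, sample_y) drawn from the
            >>> # training set of a trained `cvae`:
            >>> recon_x = cvae.reconstruct_sample(sample_x, sample_y)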
| """ | |||
| mu, log_var = self._encode(x, y) | |||
| std = self.exp(0.5 * log_var) | |||
| z = self.normal(mu.shape, mu, std, seed=0) | |||
| y = self.one_hot(y) | |||
| z_c = self.concat((z, y)) | |||
| recon_x = self._decode(z_c) | |||
| return recon_x | |||
@@ -0,0 +1,113 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Variational auto-encoder (VAE)."""
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore._checkparam import check_int_positive
from ...distribution.normal import Normal
from ....cell import Cell
from ....layer.basic import Dense


class VAE(Cell):
| r""" | |||
| Variational auto-encoder (VAE). | |||
| The VAE defines a generative model, `Z` is sampled from the prior, then used to reconstruct `X` by a decoder. | |||
| see more details in `Auto-Encoding Variational Bayes<https://arxiv.org/abs/1312.6114>`_. | |||
| Note: | |||
| When define the encoder and decoder, the shape of the encoder's output tensor and decoder's input tensor | |||
| should be math:`(N, hidden_size)`. | |||
| The latent_size should be less than or equal to the hidden_size. | |||
| Args: | |||
| encoder(Cell): The DNN model defined as encoder. | |||
| decoder(Cell): The DNN model defined as decoder. | |||
| hidden_size(int): The size of encoder's output tensor. | |||
| latent_size(int): The size of the latent space. | |||
| Inputs: | |||
| - **input** (Tensor) - the same shape as the input of encoder. | |||
| Outputs: | |||
| - **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor), z(Tensor), prior(Cell)). | |||
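
    Examples:
        >>> # A minimal sketch; `Encoder` and `Decoder` are user-defined Cells whose
        >>> # output/input size matches hidden_size (hypothetical definitions):
        >>> encoder = Encoder()
        >>> decoder = Decoder()
        >>> vae = VAE(encoder, decoder, hidden_size=400, latent_size=20)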
| """ | |||

    def __init__(self, encoder, decoder, hidden_size, latent_size):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.hidden_size = check_int_positive(hidden_size)
        self.latent_size = check_int_positive(latent_size)
        self.normal = C.normal
        self.exp = P.Exp()
        self.reshape = P.Reshape()
        self.to_tensor = P.ScalarToArray()
        self.normal_dis = Normal()
        self.standard_normal_dis = Normal([0] * self.latent_size, [1] * self.latent_size)
        self.dense1 = Dense(self.hidden_size, self.latent_size)
        self.dense2 = Dense(self.hidden_size, self.latent_size)
        self.dense3 = Dense(self.latent_size, self.hidden_size)

    def _encode(self, x):
        en_x = self.encoder(x)
        mu = self.dense1(en_x)
        log_var = self.dense2(en_x)
        return mu, log_var

    def _decode(self, z):
        z = self.dense3(z)
        recon_x = self.decoder(z)
        return recon_x

    def construct(self, x):
        mu, log_var = self._encode(x)
        # Reparameterization: z ~ N(mu, std), with std = exp(0.5 * log_var).
        std = self.exp(0.5 * log_var)
        z = self.normal_dis('sample', mean=mu, sd=std)
        recon_x = self._decode(z)
        return recon_x, x, mu, std, z, self.standard_normal_dis

    def generate_sample(self, generate_nums, shape):
        """
        Randomly sample from the latent space to generate samples.

        Args:
            generate_nums (int): The number of samples to generate.
            shape (tuple): The shape of the samples, which should be :math:`(generate_nums, C, H, W)`.

        Returns:
            Tensor, the generated samples.
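
        Examples:
            >>> # Hypothetical usage after training: 64 samples reshaped to images.
            >>> samples = vae.generate_sample(64, (64, 1, 32, 32))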
| """ | |||
| sample_z = self.normal((generate_nums, self.latent_size), self.to_tensor(0.0), self.to_tensor(1.0), seed=0) | |||
| sample = self._decode(sample_z) | |||
| sample = self.reshape(sample, shape) | |||
| return sample | |||

    def reconstruct_sample(self, x):
        """
        Reconstruct samples from original data.

        Args:
            x (Tensor): The input tensor to be reconstructed.

        Returns:
            Tensor, the reconstructed samples.
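
        Examples:
            >>> # Hypothetical usage on a batch of training images `sample_x`:
            >>> recon_x = vae.reconstruct_sample(sample_x)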
| """ | |||
| mu, log_var = self._encode(x) | |||
| std = self.exp(0.5 * log_var) | |||
| z = self.normal(mu.shape, mu, std, seed=0) | |||
| recon_x = self._decode(z) | |||
| return recon_x | |||
@@ -0,0 +1,22 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Inference algorithms in probabilistic programming.
"""
from .variational import *

__all__ = []
__all__.extend(variational.__all__)
@@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
SVI and ELBO.

The SVI interface is for variational inference.
The ELBO interface is used as the loss function during model training.
"""
from .svi import SVI
from .elbo import ELBO

__all__ = ['SVI',
           'ELBO']
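
# A minimal usage sketch (the package is exposed as mindspore.nn.probability.infer):
#     from mindspore.nn.probability.infer import ELBO, SVI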
@@ -0,0 +1,64 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The Evidence Lower Bound (ELBO)."""
from mindspore.ops import operations as P
from ...distribution.normal import Normal
from ....cell import Cell
from ....loss.loss import MSELoss


class ELBO(Cell):
| r""" | |||
| The Evidence Lower Bound (ELBO). | |||
| Variational inference minimizes the Kullback-Leibler (KL) divergence from the variational distribution to | |||
| the posterior distribution. It maximizes the evidence lower bound (ELBO), a lower bound on the logarithm of | |||
| the marginal probability of the observations log p(x). The ELBO is equal to the negative KL divergence up to | |||
| an additive constant. | |||
| see more details in `Variational Inference: A Review for Statisticians<https://arxiv.org/abs/1601.00670>`_. | |||
| Args: | |||
| latent_prior(str): The prior distribution of latent space. Default: Normal. | |||
| - Normal: The prior distribution of latent space is Normal. | |||
| output_prior(str): The distribution of output data. Default: Normal. | |||
| - Normal: If the distribution of output data is Normal, the reconstruct loss is MSELoss. | |||
| Inputs: | |||
| - **input_data** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor), z(Tensor), prior(Cell)). | |||
| - **target_data** (Tensor) - the target tensor. | |||
| Outputs: | |||
| Tensor, loss float tensor. | |||
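
    For one observation :math:`x`, with variational posterior :math:`q(z|x)` and latent prior :math:`p(z)`,
    the bound being maximized is (stated here for reference):

    .. math::
        \mathrm{ELBO}(x) = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \mathrm{KL}(q(z|x)\,\|\,p(z))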
| """ | |||

    def __init__(self, latent_prior='Normal', output_prior='Normal'):
        super(ELBO, self).__init__()
        self.sum = P.ReduceSum()
        if latent_prior == 'Normal':
            self.posterior = Normal()
        else:
            raise ValueError('The value of latent_prior currently only supports Normal')
        if output_prior == 'Normal':
            self.recon_loss = MSELoss(reduction='sum')
        else:
            raise ValueError('The value of output_prior currently only supports Normal')

    def construct(self, data, label):
        recon_x, x, mu, std, z, prior = data
        reconstruct_loss = self.recon_loss(x, recon_x)
        # Estimate the KL term at the sampled z, weighted by the posterior density q(z|x).
        kl_loss = -(prior('log_prob', z) - self.posterior('log_prob', z, mu, std)) \
            * self.posterior('prob', z, mu, std)
        elbo = reconstruct_loss + self.sum(kl_loss)
        return elbo
@@ -0,0 +1,72 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Stochastic Variational Inference (SVI)."""
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from ....wrap.cell_wrapper import TrainOneStepCell


class SVI:
| r""" | |||
| Stochastic Variational Inference(SVI). | |||
| Variational inference casts the inference problem as an optimization. Some distributions over the hidden | |||
| variables that is indexed by a set of free parameters, and then optimize the parameters to make it closest to | |||
| the posterior of interest. | |||
| see more details in `Variational Inference: A Review for Statisticians<https://arxiv.org/abs/1601.00670>`_. | |||
| Args: | |||
| net_with_loss(Cell): Cell with loss function. | |||
| optimizer (Cell): Optimizer for updating the weights. | |||
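
    Examples:
        >>> # A minimal sketch, assuming `net_with_loss` wraps a probability network
        >>> # (e.g. a VAE) together with an ELBO loss via a WithLossCell:
        >>> optimizer = nn.Adam(params=vae.trainable_params(), learning_rate=0.001)
        >>> vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)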
| """ | |||

    def __init__(self, net_with_loss, optimizer):
        self.net_with_loss = net_with_loss
        self.optimizer = optimizer
        self._loss = 0.0

    def run(self, train_dataset, epochs=10):
        """
        Optimize the parameters by training the probability network, and return the trained network.

        Args:
            train_dataset (Dataset): A training dataset iterator.
            epochs (int): Total number of iterations on the data. Default: 10.

        Returns:
            Cell, the trained probability network.
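
        Examples:
            >>> # Hypothetical usage with a MindSpore dataset providing 'image' and
            >>> # 'label' columns:
            >>> trained_net = vi.run(train_dataset=ds_train, epochs=10)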
| """ | |||
| train_net = TrainOneStepCell(self.net_with_loss, self.optimizer) | |||
| train_net.set_train() | |||
| for _ in range(1, epochs+1): | |||
| train_loss = 0 | |||
| dataset_size = 0 | |||
| for data in train_dataset.create_dict_iterator(): | |||
| x = Tensor(data['image'], dtype=mstype.float32) | |||
| y = Tensor(data['label'], dtype=mstype.int32) | |||
| dataset_size += len(x) | |||
| loss = train_net(x, y).asnumpy() | |||
| train_loss += loss | |||
| self._loss = train_loss / dataset_size | |||
| model = self.net_with_loss.backbone_network | |||
| return model | |||

    def get_train_loss(self):
        """
        Returns:
            float, the average training loss of the final epoch.
        """
        return self._loss
@@ -0,0 +1,130 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.ops import operations as P
from mindspore.nn.probability.dpn import ConditionalVAE
from mindspore.nn.probability.infer import ELBO, SVI

context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
IMAGE_SHAPE = (-1, 1, 32, 32)
image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train")


class Encoder(nn.Cell):
    def __init__(self, num_classes):
        super(Encoder, self).__init__()
        self.fc1 = nn.Dense(1024 + num_classes, 400)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.concat = P.Concat(axis=1)
        self.one_hot = nn.OneHot(depth=num_classes)

    def construct(self, x, y):
        x = self.flatten(x)
        y = self.one_hot(y)
        input_x = self.concat((x, y))
        input_x = self.fc1(input_x)
        input_x = self.relu(input_x)
        return input_x


class Decoder(nn.Cell):
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc2 = nn.Dense(400, 1024)
        self.sigmoid = nn.Sigmoid()
        self.reshape = P.Reshape()

    def construct(self, z):
        z = self.fc2(z)
        z = self.reshape(z, IMAGE_SHAPE)
        z = self.sigmoid(z)
        return z


class WithLossCell(nn.Cell):
    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, data, label):
        out = self._backbone(data, label)
        return self._loss_fn(out, label)


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Create a dataset for training or testing.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # define map operations
    resize_op = CV.Resize((resize_height, resize_width))  # Bilinear mode
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    # apply map operations on images
    mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
    # apply DatasetOps
    mnist_ds = mnist_ds.batch(batch_size)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds


if __name__ == "__main__":
    # define the encoder and decoder
    encoder = Encoder(num_classes=10)
    decoder = Decoder()
    # define the cvae model
    cvae = ConditionalVAE(encoder, decoder, hidden_size=400, latent_size=20, num_classes=10)
    # define the loss function
    net_loss = ELBO(latent_prior='Normal', output_prior='Normal')
    # define the optimizer
    optimizer = nn.Adam(params=cvae.trainable_params(), learning_rate=0.001)
    # define the training dataset
    ds_train = create_dataset(image_path, 128, 1)
    # define the modified WithLossCell, which also feeds the label to the backbone
    net_with_loss = WithLossCell(cvae, net_loss)
    # define the variational inference
    vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
    # run the vi to return the trained network
    cvae = vi.run(train_dataset=ds_train, epochs=10)
    # get the trained loss
    trained_loss = vi.get_train_loss()
    # test function: generate_sample
    sample_label = Tensor([i for i in range(0, 8)] * 8, dtype=mstype.int32)
    generated_sample = cvae.generate_sample(sample_label, 64, IMAGE_SHAPE)
    # test function: reconstruct_sample
    for sample in ds_train.create_dict_iterator():
        sample_x = Tensor(sample['image'], dtype=mstype.float32)
        sample_y = Tensor(sample['label'], dtype=mstype.int32)
        reconstructed_sample = cvae.reconstruct_sample(sample_x, sample_y)
@@ -0,0 +1,115 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.ops import operations as P
from mindspore.nn.probability.dpn import VAE
from mindspore.nn.probability.infer import ELBO, SVI

context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
IMAGE_SHAPE = (-1, 1, 32, 32)
image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train")


class Encoder(nn.Cell):
    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Dense(1024, 800)
        self.fc2 = nn.Dense(800, 400)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        return x


class Decoder(nn.Cell):
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc1 = nn.Dense(400, 1024)
        self.sigmoid = nn.Sigmoid()
        self.reshape = P.Reshape()

    def construct(self, z):
        z = self.fc1(z)
        z = self.reshape(z, IMAGE_SHAPE)
        z = self.sigmoid(z)
        return z


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Create a dataset for training or testing.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # define map operations
    resize_op = CV.Resize((resize_height, resize_width))  # Bilinear mode
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    # apply map operations on images
    mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
    # apply DatasetOps
    mnist_ds = mnist_ds.batch(batch_size)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds


if __name__ == "__main__":
    # define the encoder and decoder
    encoder = Encoder()
    decoder = Decoder()
    # define the vae model
    vae = VAE(encoder, decoder, hidden_size=400, latent_size=20)
    # define the loss function
    net_loss = ELBO(latent_prior='Normal', output_prior='Normal')
    # define the optimizer
    optimizer = nn.Adam(params=vae.trainable_params(), learning_rate=0.001)
    # define the training dataset
    ds_train = create_dataset(image_path, 128, 1)
    net_with_loss = nn.WithLossCell(vae, net_loss)
    # define the variational inference
    vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
    # run the vi to return the trained network
    vae = vi.run(train_dataset=ds_train, epochs=10)
    # get the trained loss
    trained_loss = vi.get_train_loss()
    # test function: generate_sample
    generated_sample = vae.generate_sample(64, IMAGE_SHAPE)
    # test function: reconstruct_sample
    for sample in ds_train.create_dict_iterator():
        sample_x = Tensor(sample['image'], dtype=mstype.float32)
        reconstructed_sample = vae.reconstruct_sample(sample_x)
@@ -0,0 +1,164 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
The VAE interface can be used to construct a VAE-GAN network.
"""
import os

import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context
from mindspore.ops import operations as P
from mindspore.nn.probability.dpn import VAE
from mindspore.nn.probability.infer import ELBO, SVI

context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
IMAGE_SHAPE = (-1, 1, 32, 32)
image_path = os.path.join('/home/workspace/mindspore_dataset/mnist', "train")


class Encoder(nn.Cell):
    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Dense(1024, 400)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        return x


class Decoder(nn.Cell):
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc1 = nn.Dense(400, 1024)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.reshape = P.Reshape()

    def construct(self, z):
        z = self.fc1(z)
        z = self.reshape(z, IMAGE_SHAPE)
        z = self.sigmoid(z)
        return z


class Discriminator(nn.Cell):
    """
    The Discriminator of the GAN network.
    """
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Dense(1024, 400)
        self.fc2 = nn.Dense(400, 720)
        self.fc3 = nn.Dense(720, 1024)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        return x


class VaeGan(nn.Cell):
    def __init__(self):
        super(VaeGan, self).__init__()
        self.E = Encoder()
        self.G = Decoder()
        self.D = Discriminator()
        self.dense = nn.Dense(20, 400)
        self.vae = VAE(self.E, self.G, 400, 20)
        self.shape = P.Shape()
        self.to_tensor = P.ScalarToArray()

    def construct(self, x):
        recon_x, x, mu, std, z, prior = self.vae(x)
        # Draw z_p from the prior and decode it to get a purely generated image x_p.
        z_p = prior('sample', self.shape(mu), self.to_tensor(0.0), self.to_tensor(1.0))
        z_p = self.dense(z_p)
        x_p = self.G(z_p)
        # Discriminator scores for the real, reconstructed and generated images.
        ld_real = self.D(x)
        ld_fake = self.D(recon_x)
        ld_p = self.D(x_p)
        return ld_real, ld_fake, ld_p, recon_x, x, mu, std, z, prior


class VaeGanLoss(nn.Cell):
    def __init__(self):
        super(VaeGanLoss, self).__init__()
        self.zeros = P.ZerosLike()
        self.mse = nn.MSELoss(reduction='sum')
        self.elbo = ELBO(latent_prior='Normal', output_prior='Normal')

    def construct(self, data, label):
        ld_real, ld_fake, ld_p, recon_x, x, mean, std, z, prior = data
        y_real = self.zeros(ld_real) + 1
        y_fake = self.zeros(ld_fake)
        elbo_data = (recon_x, x, mean, std, z, prior)
        # MSE-based GAN losses over the discriminator scores, plus the ELBO term.
        loss_D = self.mse(ld_real, y_real)
        loss_GD = self.mse(ld_p, y_fake)
        loss_G = self.mse(ld_fake, y_real)
        elbo_loss = self.elbo(elbo_data, label)
        return loss_D + loss_G + loss_GD + elbo_loss


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Create a dataset for training or testing.
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # define map operations
    resize_op = CV.Resize((resize_height, resize_width))  # Bilinear mode
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    # apply map operations on images
    mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
    # apply DatasetOps
    mnist_ds = mnist_ds.batch(batch_size)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds


if __name__ == "__main__":
    vae_gan = VaeGan()
    net_loss = VaeGanLoss()
    optimizer = nn.Adam(params=vae_gan.trainable_params(), learning_rate=0.001)
    ds_train = create_dataset(image_path, 128, 1)
    net_with_loss = nn.WithLossCell(vae_gan, net_loss)
    vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
    vae_gan = vi.run(train_dataset=ds_train, epochs=10)
@@ -0,0 +1,57 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test VAE interface """
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import _executor
from mindspore.nn.probability.dpn import VAE


class Encoder(nn.Cell):
    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Dense(6, 3)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        return x


class Decoder(nn.Cell):
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc1 = nn.Dense(3, 6)
        self.sigmoid = nn.Sigmoid()

    def construct(self, z):
        z = self.fc1(z)
        z = self.sigmoid(z)
        return z


def test_vae():
    """
    Test the VAE interface with a DNN model.
    """
    encoder = Encoder()
    decoder = Decoder()
    net = VAE(encoder, decoder, hidden_size=3, latent_size=2)
    input_data = Tensor(np.random.rand(32, 6), dtype=mstype.float32)
    _executor.compile(net, input_data)