Merge pull request !6551 from mcgrady00h/mindspore-zhusuan
Tags: v1.1.0
@@ -0,0 +1,18 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Zhusuan package: a probabilistic programming library """
from .framework import *
from .variational import *
@@ -0,0 +1,18 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Core functionality for Zhusuan """
from .bn import *
@@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Bayesian Network """
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P


class BayesianNet(nn.Cell):
    """
    We currently support 3 types of variables: x = observation, z = latent, y = condition.
    A Bayesian network models a generative process over these variables: p(x,z|y), p(z|x,y) or p(x|z,y).
    """
    def __init__(self):
        super().__init__()
        self.normal_dist = msd.Normal(dtype=mstype.float32)
        self.bernoulli_dist = msd.Bernoulli(dtype=mstype.float32)
        self.reduce_sum = P.ReduceSum(keep_dims=True)

    def Normal(self,
               name,
               observation=None,
               mean=None,
               std=None,
               seed=0,
               dtype=mstype.float32,
               shape=(),
               reparameterize=True):
        """ Normal distribution wrapper: returns a sample and its summed log-probability. """
        assert name is not None
        assert seed is not None
        assert dtype is not None
        if observation is None:
            if reparameterize:
                # Reparameterization trick: sample epsilon ~ N(0, 1), then shift and scale it,
                # so gradients can flow through mean and std.
                # Note: self.zeros and self.ones are expected to be provided by the subclass.
                epsilon = self.normal_dist('sample', shape, self.zeros(mean.shape), self.ones(std.shape))
                sample = mean + std * epsilon
            else:
                sample = self.normal_dist('sample', shape, mean, std)
        else:
            sample = observation
        log_prob = self.reduce_sum(self.normal_dist('log_prob', sample, mean, std), 1)
        return sample, log_prob

    def Bernoulli(self,
                  name,
                  observation=None,
                  probs=None,
                  seed=0,
                  dtype=mstype.float32,
                  shape=()):
        """ Bernoulli distribution wrapper: returns a sample and its summed log-probability. """
        assert name is not None
        assert seed is not None
        assert dtype is not None
        if observation is None:
            sample = self.bernoulli_dist('sample', shape, probs)
        else:
            sample = observation
        log_prob = self.reduce_sum(self.bernoulli_dist('log_prob', sample, probs), 1)
        return sample, log_prob

    def construct(self, *inputs, **kwargs):
        """
        We currently fix the signature of the construct function.

        Args:
            The inputs must consist of 3 variables, in order:
            x: data sample (observation)
            z: latent variable
            y: conditional information
        """
        raise NotImplementedError
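The Generator and Variational networks added later in this PR are the real examples. Purely for orientation, here is a minimal, hypothetical subclass sketch (not part of the PR) showing the calling convention: construct(x, z, y) threads each variable through a distribution wrapper and returns it together with its summed log-probability. It also illustrates that the reparameterized path of the Normal wrapper calls self.zeros/self.ones, which the subclass is expected to provide. All layer sizes and names below are assumptions for illustration only.

# Sketch only, not part of the PR: a toy model p(x|z) with a standard normal prior.
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype


class ToyModel(BayesianNet):
    def __init__(self, x_dim, z_dim, batch_size):
        super().__init__()
        self.fc = nn.Dense(z_dim, x_dim)
        self.sigmoid = P.Sigmoid()
        self.fill = P.Fill()
        self.batch_size = batch_size
        self.z_dim = z_dim

    # Helpers required by the Normal wrapper above.
    def ones(self, shape):
        return self.fill(mstype.float32, shape, 1.)

    def zeros(self, shape):
        return self.fill(mstype.float32, shape, 0.)

    def construct(self, x, z, y):
        # Standard normal prior over z; if z is passed in, it is scored instead of resampled.
        z, log_pz = self.Normal('latent', observation=z,
                                mean=self.zeros((self.batch_size, self.z_dim)),
                                std=self.ones((self.batch_size, self.z_dim)),
                                shape=(), reparameterize=False)
        probs = self.sigmoid(self.fc(z))
        # Bernoulli likelihood over x; reuses x as the observation if one is given.
        x, log_px = self.Bernoulli('data', observation=x, probs=probs, shape=())
        return x, log_px, z, log_pz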
@@ -0,0 +1,18 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Variational inference related code """
from .elbo import *
@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" ELBO """
import mindspore.nn as nn
from mindspore.ops import operations as P


class ELBO(nn.Cell):
    """ ELBO class """
    def __init__(self, generator, variational):
        super().__init__()
        self.generator = generator
        self.variational = variational
        self.reshape_op = P.Reshape()
        self.reduce_mean = P.ReduceMean(keep_dims=False)
        self.square = P.Square()

    def construct(self, *inputs, **kwargs):
        if len(inputs) >= 2:
            x, y = inputs[0], inputs[1]
        else:
            x = inputs[0]
            y = None
        # Sample z from the variational posterior, then score x and z under the generator.
        z, log_prob_z = self.variational(x, None, y)
        _, log_prob_x_, _, log_prob_z_ = self.generator(x, z, y)
        elbo = self.reduce_mean(log_prob_x_) + self.reduce_mean(log_prob_z_) - self.reduce_mean(log_prob_z)
        # Return the negative ELBO so it can be minimized directly as a loss.
        return -elbo
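For reference, the quantity computed in construct above is the standard single-sample evidence lower bound, averaged over the batch: log_prob_x_ and log_prob_z_ are the generator's log p(x|z,y) and log p(z|y), and log_prob_z is the variational log q(z|x,y).

\mathrm{ELBO}(x) = \mathbb{E}_{z \sim q(z \mid x, y)}\bigl[\log p(x \mid z, y) + \log p(z \mid y) - \log q(z \mid x, y)\bigr]

Because the Cell returns -ELBO, the surrounding training loop can treat its output as an ordinary loss to minimize.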
@@ -0,0 +1,16 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Zhusuan examples """
@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" VAE examples """
@@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Utils """
from PIL import Image
import numpy as np
from mindspore.common import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV
from mindspore.dataset.transforms.vision import Inter


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """ Create the dataset for training or testing.

    Args:
        data_path: Data path
        batch_size: The number of data records in each group
        repeat_size: The number of replicated data records
        num_parallel_workers: The number of parallel workers
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)
    #mnist_ds = ds.MnistDataset(data_path, num_samples=32)

    # define operation parameters
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # resize images to (32, 32)
    rescale_op = CV.Rescale(rescale, shift)  # rescale pixel values to [0, 1]
    hwc2chw_op = CV.HWC2CHW()  # change shape from (height, width, channel) to (channel, height, width) to fit the network
    type_cast_op = C.TypeCast(mstype.int32)  # change the label data type to int32 to fit the network

    # apply map operations on images
    mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)

    # apply dataset operations
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in the LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds


def save_img(data, name, size=32, num=32):
    """
    Visualize the data and save it to the target file.

    Args:
        data: ndarray of shape (num, size, size)
        name: output file name
        size: image size
        num: number of images
    """
    col = num // 8
    row = 8
    imgs = Image.new('L', (size * col, size * row))
    for i in range(num):
        j = i // 8
        img_data = data[i]
        img_data = np.resize(img_data, (size, size))
        img_data = img_data * 255
        img_data = img_data.astype(np.uint8)
        im = Image.fromarray(img_data, 'L')
        imgs.paste(im, (j * size, (i % 8) * size))
    imgs.save(name)
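A small usage sketch for these two helpers (not part of the PR; the data path and output file name are assumptions). The training script later in this PR uses them in the same way.

# Sketch only: quick sanity check of create_dataset and save_img.
# Point the path at an MNIST directory in the standard MnistDataset layout.
if __name__ == '__main__':
    train_ds = create_dataset("./data/MNIST/train", batch_size=32)
    for item in train_ds.create_tuple_iterator():
        batch_x = item[0].reshape(32, 32 * 32)  # flatten the 32x32 images
        save_img(batch_x, 'check.png')          # writes a 4x8 grid of digits
        break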
@@ -0,0 +1,165 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" VAE """
import os
import numpy as np
from utils import create_dataset, save_img
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
import zhusuan as zs


class ReduceMeanLoss(nn.L1Loss):
    """ Pass-through loss: the ELBO network already outputs the scalar loss (-ELBO), so this simply forwards it. """
    def construct(self, base, target):
        return base
class Generator(zs.BayesianNet):
    """ Generator: p(x, z) with a standard normal prior over z and a Bernoulli likelihood over x. """
    def __init__(self, x_dim, z_dim, batch_size):
        super().__init__()
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.batch_size = batch_size
        self.fc1 = nn.Dense(z_dim, 500)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Dense(500, 500)
        self.act2 = nn.ReLU()
        self.fc3 = nn.Dense(500, x_dim)
        self.fill = P.Fill()
        self.sigmoid = P.Sigmoid()
        self.reshape_op = P.Reshape()

    def ones(self, shape):
        return self.fill(mstype.float32, shape, 1.)

    def zeros(self, shape):
        return self.fill(mstype.float32, shape, 0.)

    def construct(self, x, z, y):
        """ construct """
        assert y is None  # we have no conditional information
        if x is not None:
            x = self.reshape_op(x, (32, 32 * 32))
        # Standard normal prior over z; if z is given (e.g. from the variational net), it is scored rather than resampled.
        z_mean = self.zeros((self.batch_size, self.z_dim))
        z_std = self.ones((self.batch_size, self.z_dim))
        z, log_prob_z = self.Normal('latent', observation=z, mean=z_mean, std=z_std, shape=(), reparameterize=False)
        x_mean = self.sigmoid(self.fc3(self.act2(self.fc2(self.act1(self.fc1(z))))))
        if x is None:
            # When generating, use the Bernoulli mean as the output image.
            x = x_mean
        x, log_prob_x = self.Bernoulli('data', observation=x, shape=(), probs=x_mean)
        return x, log_prob_x, z, log_prob_z
class Variational(zs.BayesianNet):
    """ Variational posterior q(z|x): a diagonal Gaussian parameterized by an MLP. """
    def __init__(self, x_dim, z_dim, batch_size):
        super().__init__()
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.batch_size = batch_size
        self.reshape_op = P.Reshape()
        self.fc1 = nn.Dense(x_dim, 500)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Dense(500, 500)
        self.act2 = nn.ReLU()
        self.fc3 = nn.Dense(500, z_dim)
        self.fc4 = nn.Dense(500, z_dim)
        self.fill = P.Fill()
        self.exp = P.Exp()

    def ones(self, shape):
        return self.fill(mstype.float32, shape, 1.)

    def zeros(self, shape):
        return self.fill(mstype.float32, shape, 0.)

    def construct(self, x, z, y):
        """ construct """
        assert y is None  # we have no conditional information
        x = self.reshape_op(x, (32, 32 * 32))
        z_logit = self.act2(self.fc2(self.act1(self.fc1(x))))
        z_mean = self.fc3(z_logit)
        # Parameterize the standard deviation through exp() so it stays positive.
        z_std = self.exp(self.fc4(z_logit))
        z, log_prob_z = self.Normal('latent', observation=z, mean=z_mean, std=z_std, shape=(), reparameterize=True)
        return z, log_prob_z
def main():
    # We currently support PyNative mode with GPU as the device target.
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

    epoch_size = 1
    batch_size = 32
    mnist_path = "/data/chengzi/zhusuan-mindspore/data/MNIST"
    repeat_size = 1

    # Define model parameters
    z_dim = 40
    x_dim = 32 * 32

    # Create the network
    generator = Generator(x_dim, z_dim, batch_size)
    variational = Variational(x_dim, z_dim, batch_size)
    network = zs.variational.ELBO(generator, variational)

    # Define the loss (a pass-through, since the network outputs -ELBO) and the learning rate
    lr = 0.001
    net_loss = ReduceMeanLoss()

    # Define the optimizer
    print(network.trainable_params()[0])
    net_opt = nn.Adam(network.trainable_params(), lr)

    model = Model(network, net_loss, net_opt)

    ds_train = create_dataset(os.path.join(mnist_path, "train"), batch_size, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)
    print(network.trainable_params()[0])

    # Reconstruction: encode one training batch and decode it again.
    iterator = ds_train.create_tuple_iterator()
    for item in iterator:
        batch_x = item[0].reshape(32, 32 * 32)
        break
    z, _ = network.variational(Tensor(batch_x), None, None)
    sample, _, _, _ = network.generator(None, z, None)
    sample = sample.asnumpy()
    save_img(batch_x, 'result/origin_x.png')
    save_img(sample, 'result/reconstruct_x.png')

    # Unconditional generation: sample z from the prior and decode 4 batches of images.
    for i in range(4):
        sample, _, _, _ = network.generator(None, None, None)
        sample = sample.asnumpy()
        samples = sample if i == 0 else np.concatenate([samples, sample], axis=0)
    save_img(samples, 'result/sample_x.png', num=4 * batch_size)


if __name__ == '__main__':
    main()