From a28ed9743a5d80de582dd3be4778dd5685269270 Mon Sep 17 00:00:00 2001 From: zhangxinfeng3 Date: Thu, 13 Aug 2020 20:18:10 +0800 Subject: [PATCH] update variational inference and toolbox --- mindspore/nn/probability/dpn/vae/cvae.py | 10 +- mindspore/nn/probability/dpn/vae/vae.py | 10 +- .../nn/probability/infer/variational/elbo.py | 8 +- .../toolbox/uncertainty_evaluation.py | 104 +++++++++++------- tests/st/probability/test_gpu_svi_cvae.py | 2 +- tests/st/probability/test_gpu_svi_vae.py | 2 +- tests/st/probability/test_uncertainty.py | 10 +- 7 files changed, 86 insertions(+), 60 deletions(-) diff --git a/mindspore/nn/probability/dpn/vae/cvae.py b/mindspore/nn/probability/dpn/vae/cvae.py index 87839c914f..81ed36c610 100644 --- a/mindspore/nn/probability/dpn/vae/cvae.py +++ b/mindspore/nn/probability/dpn/vae/cvae.py @@ -16,7 +16,6 @@ from mindspore.ops import composite as C from mindspore.ops import operations as P from mindspore._checkparam import check_int_positive -from ...distribution.normal import Normal from ....cell import Cell from ....layer.basic import Dense, OneHot @@ -46,7 +45,7 @@ class ConditionalVAE(Cell): - **input_y** (Tensor) - the tensor of the target data, the shape is math:`(N, 1)`. Outputs: - - **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor), z(Tensor), prior(Cell)). + - **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)). """ def __init__(self, encoder, decoder, hidden_size, latent_size, num_classes): @@ -59,11 +58,10 @@ class ConditionalVAE(Cell): self.normal = C.normal self.exp = P.Exp() self.reshape = P.Reshape() + self.shape = P.Shape() self.concat = P.Concat(axis=1) self.to_tensor = P.ScalarToArray() - self.normal_dis = Normal() self.one_hot = OneHot(depth=num_classes) - self.standard_normal_dis = Normal([0] * self.latent_size, [1] * self.latent_size) self.dense1 = Dense(self.hidden_size, self.latent_size) self.dense2 = Dense(self.hidden_size, self.latent_size) self.dense3 = Dense(self.latent_size + self.num_classes, self.hidden_size) @@ -82,11 +80,11 @@ class ConditionalVAE(Cell): def construct(self, x, y): mu, log_var = self._encode(x, y) std = self.exp(0.5 * log_var) - z = self.normal_dis('sample', mean=mu, sd=std) + z = self.normal(self.shape(mu), mu, std, seed=0) y = self.one_hot(y) z_c = self.concat((z, y)) recon_x = self._decode(z_c) - return recon_x, x, mu, std, z, self.standard_normal_dis + return recon_x, x, mu, std def generate_sample(self, sample_y, generate_nums=None, shape=None): """ diff --git a/mindspore/nn/probability/dpn/vae/vae.py b/mindspore/nn/probability/dpn/vae/vae.py index 25ef01f1fb..9e47b5d14d 100644 --- a/mindspore/nn/probability/dpn/vae/vae.py +++ b/mindspore/nn/probability/dpn/vae/vae.py @@ -16,7 +16,6 @@ from mindspore.ops import composite as C from mindspore.ops import operations as P from mindspore._checkparam import check_int_positive -from ...distribution.normal import Normal from ....cell import Cell from ....layer.basic import Dense @@ -43,7 +42,7 @@ class VAE(Cell): - **input** (Tensor) - the same shape as the input of encoder. Outputs: - - **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor), z(Tensor), prior(Cell)). + - **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)). """ def __init__(self, encoder, decoder, hidden_size, latent_size): @@ -55,9 +54,8 @@ class VAE(Cell): self.normal = C.normal self.exp = P.Exp() self.reshape = P.Reshape() + self.shape = P.Shape() self.to_tensor = P.ScalarToArray() - self.normal_dis = Normal() - self.standard_normal_dis = Normal([0]*self.latent_size, [1]*self.latent_size) self.dense1 = Dense(self.hidden_size, self.latent_size) self.dense2 = Dense(self.hidden_size, self.latent_size) self.dense3 = Dense(self.latent_size, self.hidden_size) @@ -76,9 +74,9 @@ class VAE(Cell): def construct(self, x): mu, log_var = self._encode(x) std = self.exp(0.5 * log_var) - z = self.normal_dis('sample', mean=mu, sd=std) + z = self.normal(self.shape(mu), mu, std, seed=0) recon_x = self._decode(z) - return recon_x, x, mu, std, z, self.standard_normal_dis + return recon_x, x, mu, std def generate_sample(self, generate_nums, shape): """ diff --git a/mindspore/nn/probability/infer/variational/elbo.py b/mindspore/nn/probability/infer/variational/elbo.py index df2a7c369b..9eb573ddea 100644 --- a/mindspore/nn/probability/infer/variational/elbo.py +++ b/mindspore/nn/probability/infer/variational/elbo.py @@ -36,7 +36,7 @@ class ELBO(Cell): - Normal: If the distribution of output data is Normal, the reconstruct loss is MSELoss. Inputs: - - **input_data** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor), z(Tensor), prior(Cell)). + - **input_data** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)). - **target_data** (Tensor) - the target tensor. Outputs: @@ -46,6 +46,7 @@ class ELBO(Cell): def __init__(self, latent_prior='Normal', output_prior='Normal'): super(ELBO, self).__init__() self.sum = P.ReduceSum() + self.zeros = P.ZerosLike() if latent_prior == 'Normal': self.posterior = Normal() else: @@ -56,9 +57,8 @@ class ELBO(Cell): raise ValueError('The values of output_dis now only support Normal') def construct(self, data, label): - recon_x, x, mu, std, z, prior = data + recon_x, x, mu, std = data reconstruct_loss = self.recon_loss(x, recon_x) - kl_loss = -(prior('log_prob', z) - self.posterior('log_prob', z, mu, std)) \ - * self.posterior('prob', z, mu, std) + kl_loss = self.posterior('kl_loss', 'Normal', self.zeros(mu), self.zeros(mu)+1, mu, std) elbo = reconstruct_loss + self.sum(kl_loss) return elbo diff --git a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py index 173f641a05..a61b19cd13 100644 --- a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py +++ b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py @@ -19,7 +19,7 @@ from mindspore._checkparam import check_int_positive from mindspore.ops import composite as C from mindspore.ops import operations as P from mindspore.train import Model -from mindspore.train.callback import LossMonitor +from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig from mindspore.train.serialization import load_checkpoint, load_param_into_net from ...cell import Cell from ...layer.basic import Dense, Flatten, Dropout @@ -43,8 +43,13 @@ class UncertaintyEvaluation: num_classes (int): The number of labels of classification. If the task type is classification, it must be set; if not classification, it need not to be set. Default: None. - epochs (int): Total number of iterations on the data. Default: None. - uncertainty_model_path (str): The save or read path of the uncertainty model. + epochs (int): Total number of iterations on the data. Default: 1. + epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model. + ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model. + save_model (bool): Save the uncertainty model or not, if True, the epi_uncer_model_path + and ale_uncer_model_path should not be None. If False, give the path of + the uncertainty model, it will load the model to evaluate, if not given + the path, it will not save or load the uncertainty model. Examples: >>> network = LeNet() @@ -55,21 +60,26 @@ class UncertaintyEvaluation: >>> train_dataset=ds_train, >>> task_type='classification', >>> num_classes=10, - >>> epochs=5, - >>> uncertainty_model_path=None) - >>> epistemic_uncertainty = evaluation.eval_epistemic(eval_data) - >>> aleatoric_uncertainty = evaluation.eval_aleatoric(eval_data) + >>> epochs=1, + >>> epi_uncer_model_path=None, + >>> ale_uncer_model_path=None, + >>> save_model=False) + >>> epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_data) + >>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data) """ - def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=None, - uncertainty_model_path=None): + def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1, + epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False): self.model = model self.train_dataset = train_dataset self.task_type = task_type self.num_classes = check_int_positive(num_classes) self.epochs = epochs - self.uncer_model_path = uncertainty_model_path - self.uncer_model = None + self.epi_uncer_model_path = epi_uncer_model_path + self.ale_uncer_model_path = ale_uncer_model_path + self.save_model = save_model + self.epi_uncer_model = None + self.ale_uncer_model = None self.concat = P.Concat(axis=0) self.sum = P.ReduceSum() self.pow = P.Pow() @@ -78,6 +88,10 @@ class UncertaintyEvaluation: if self.task_type == 'classification': if self.num_classes is None: raise ValueError("Classification task needs to input labels.") + if self.save_model: + if self.epi_uncer_model_path is None or self.ale_uncer_model_path is None: + raise ValueError("If save_model is True, the epi_uncer_model_path and " + "ale_uncer_model_path should not be None.") def _uncertainty_normalize(self, data): area = np.max(data) - np.min(data) @@ -87,31 +101,38 @@ class UncertaintyEvaluation: """ Get the model which can obtain the epistemic uncertainty. """ - if self.uncer_model and self.uncer_model_path is None: - self.uncer_model = EpistemicUncertaintyModel(self.model) - if self.uncer_model.drop_count == 0: + if self.epi_uncer_model is None: + self.epi_uncer_model = EpistemicUncertaintyModel(self.model) + if self.epi_uncer_model.drop_count == 0: if self.task_type == 'classification': net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") - net_opt = Adam(self.uncer_model.trainable_params()) - model = Model(self.uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) + net_opt = Adam(self.epi_uncer_model.trainable_params()) + model = Model(self.epi_uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) else: net_loss = MSELoss() - net_opt = Adam(self.uncer_model.trainable_params()) - model = Model(self.uncer_model, net_loss, net_opt, metrics={"MSE": MSE()}) - model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()]) - elif self.uncer_model is None: - uncer_param_dict = load_checkpoint(self.uncer_model_path) - load_param_into_net(self.uncer_model, uncer_param_dict) + net_opt = Adam(self.epi_uncer_model.trainable_params()) + model = Model(self.epi_uncer_model, net_loss, net_opt, metrics={"MSE": MSE()}) + if self.save_model: + config_ck = CheckpointConfig(keep_checkpoint_max=self.epochs) + ckpoint_cb = ModelCheckpoint(prefix='checkpoint_epi_uncer_model', + directory=self.epi_uncer_model_path, + config=config_ck) + model.train(self.epochs, self.train_dataset, callbacks=[ckpoint_cb, LossMonitor()]) + elif self.epi_uncer_model_path is None: + model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()]) + else: + uncer_param_dict = load_checkpoint(self.epi_uncer_model_path) + load_param_into_net(self.epi_uncer_model, uncer_param_dict) def _eval_epistemic_uncertainty(self, eval_data, mc=10): """ Evaluate the epistemic uncertainty of classification and regression models using MC dropout. """ self._get_epistemic_uncertainty_model() - self.uncer_model.set_train(True) + self.epi_uncer_model.set_train(True) outputs = [None] * mc for i in range(mc): - pred = self.uncer_model(eval_data) + pred = self.epi_uncer_model(eval_data) outputs[i] = pred.asnumpy() if self.task_type == 'classification': outputs = np.stack(outputs, axis=2) @@ -126,30 +147,37 @@ class UncertaintyEvaluation: """ Get the model which can obtain the aleatoric uncertainty. """ - if self.uncer_model and self.uncer_model_path is None: - self.uncer_model = AleatoricUncertaintyModel(self.model, self.num_classes, self.task_type) + if self.ale_uncer_model is None: + self.ale_uncer_model = AleatoricUncertaintyModel(self.model, self.num_classes, self.task_type) net_loss = AleatoricLoss(self.task_type) - net_opt = Adam(self.uncer_model.trainable_params()) + net_opt = Adam(self.ale_uncer_model.trainable_params()) if self.task_type == 'classification': - model = Model(self.uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) + model = Model(self.ale_uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) + else: + model = Model(self.ale_uncer_model, net_loss, net_opt, metrics={"MSE": MSE()}) + if self.save_model: + config_ck = CheckpointConfig(keep_checkpoint_max=self.epochs) + ckpoint_cb = ModelCheckpoint(prefix='checkpoint_ale_uncer_model', + directory=self.ale_uncer_model_path, + config=config_ck) + model.train(self.epochs, self.train_dataset, callbacks=[ckpoint_cb, LossMonitor()]) + elif self.ale_uncer_model_path is None: + model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()]) else: - model = Model(self.uncer_model, net_loss, net_opt, metrics={"MSE": MSE()}) - model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()]) - elif self.uncer_model is None: - uncer_param_dict = load_checkpoint(self.uncer_model_path) - load_param_into_net(self.uncer_model, uncer_param_dict) + uncer_param_dict = load_checkpoint(self.ale_uncer_model_path) + load_param_into_net(self.ale_uncer_model, uncer_param_dict) def _eval_aleatoric_uncertainty(self, eval_data): """ Evaluate the aleatoric uncertainty of classification and regression models. """ self._get_aleatoric_uncertainty_model() - _, var = self.uncer_model(eval_data) + _, var = self.ale_uncer_model(eval_data) ale_uncertainty = self.sum(self.pow(var, 2), 1) ale_uncertainty = self._uncertainty_normalize(ale_uncertainty.asnumpy()) return ale_uncertainty - def eval_epistemic(self, eval_data): + def eval_epistemic_uncertainty(self, eval_data): """ Evaluate the epistemic uncertainty of inference results, which also called model uncertainty. @@ -159,10 +187,10 @@ class UncertaintyEvaluation: Returns: numpy.dtype, the epistemic uncertainty of inference results of data samples. """ - uncertainty = self._eval_aleatoric_uncertainty(eval_data) + uncertainty = self._eval_epistemic_uncertainty(eval_data) return uncertainty - def eval_aleatoric(self, eval_data): + def eval_aleatoric_uncertainty(self, eval_data): """ Evaluate the aleatoric uncertainty of inference results, which also called data uncertainty. @@ -172,7 +200,7 @@ class UncertaintyEvaluation: Returns: numpy.dtype, the aleatoric uncertainty of inference results of data samples. """ - uncertainty = self._eval_epistemic_uncertainty(eval_data) + uncertainty = self._eval_aleatoric_uncertainty(eval_data) return uncertainty diff --git a/tests/st/probability/test_gpu_svi_cvae.py b/tests/st/probability/test_gpu_svi_cvae.py index 08c645179e..44f6c040fa 100644 --- a/tests/st/probability/test_gpu_svi_cvae.py +++ b/tests/st/probability/test_gpu_svi_cvae.py @@ -107,7 +107,7 @@ if __name__ == "__main__": # define the cvae model cvae = ConditionalVAE(encoder, decoder, hidden_size=400, latent_size=20, num_classes=10) # define the loss function - net_loss = ELBO(latent_prior='Normal', output_dis='Normal') + net_loss = ELBO(latent_prior='Normal', output_prior='Normal') # define the optimizer optimizer = nn.Adam(params=cvae.trainable_params(), learning_rate=0.001) # define the training dataset diff --git a/tests/st/probability/test_gpu_svi_vae.py b/tests/st/probability/test_gpu_svi_vae.py index c18a290e51..a175a4ae4c 100644 --- a/tests/st/probability/test_gpu_svi_vae.py +++ b/tests/st/probability/test_gpu_svi_vae.py @@ -95,7 +95,7 @@ if __name__ == "__main__": # define the vae model vae = VAE(encoder, decoder, hidden_size=400, latent_size=20) # define the loss function - net_loss = ELBO(latent_prior='Normal', output_dis='Normal') + net_loss = ELBO(latent_prior='Normal', output_prior='Normal') # define the optimizer optimizer = nn.Adam(params=vae.trainable_params(), learning_rate=0.001) # define the training dataset diff --git a/tests/st/probability/test_uncertainty.py b/tests/st/probability/test_uncertainty.py index 1037669f92..92850141eb 100644 --- a/tests/st/probability/test_uncertainty.py +++ b/tests/st/probability/test_uncertainty.py @@ -125,9 +125,11 @@ if __name__ == '__main__': train_dataset=ds_train, task_type='classification', num_classes=10, - epochs=5, - uncertainty_model_path=None) + epochs=1, + epi_uncer_model_path=None, + ale_uncer_model_path=None, + save_model=False) for eval_data in ds_eval.create_dict_iterator(): eval_data = Tensor(eval_data['image'], mstype.float32) - epistemic_uncertainty = evaluation.eval_epistemic(eval_data) - aleatoric_uncertainty = evaluation.eval_aleatoric(eval_data) + epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_data) + aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)