You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

uncertainty_evaluation.py 16 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Toolbox for Uncertainty Evaluation."""
  16. from copy import deepcopy
  17. import numpy as np
  18. from mindspore._checkparam import Validator
  19. from mindspore.ops import composite as C
  20. from mindspore.ops import operations as P
  21. from mindspore.train import Model
  22. from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig
  23. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  24. from ...cell import Cell
  25. from ...layer.basic import Dense, Flatten, Dropout
  26. from ...layer.container import SequentialCell
  27. from ...layer.conv import Conv2d
  28. from ...loss import SoftmaxCrossEntropyWithLogits, MSELoss
  29. from ...metrics import Accuracy, MSE
  30. from ...optim import Adam
  31. class UncertaintyEvaluation:
  32. r"""
  33. Toolbox for Uncertainty Evaluation.
  34. Args:
  35. model (Cell): The model for uncertainty evaluation.
  36. train_dataset (Dataset): A dataset iterator to train model.
  37. task_type (str): Option for the task types of model
  38. - regression: A regression model.
  39. - classification: A classification model.
  40. num_classes (int): The number of labels of classification.
  41. If the task type is classification, it must be set; otherwise, it is not needed.
  42. Default: None.
  43. epochs (int): Total number of iterations on the data. Default: 1.
  44. epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model. Default: None.
  45. ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model. Default: None.
  46. save_model (bool): Whether to save the uncertainty model or not, if true, the epi_uncer_model_path
  47. and ale_uncer_model_path must not be None. If false, the model to evaluate will be loaded from
  48. the the path of the uncertainty model; if the path is not given , it will not save or load the
  49. uncertainty model. Default: False.
  50. Supported Platforms:
  51. ``Ascend`` ``GPU``
  52. Examples:
  53. >>> network = LeNet()
  54. >>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')
  55. >>> load_param_into_net(network, param_dict)
  56. >>> ds_train = create_dataset('workspace/mnist/train')
  57. >>> evaluation = UncertaintyEvaluation(model=network,
  58. ... train_dataset=ds_train,
  59. ... task_type='classification',
  60. ... num_classes=10,
  61. ... epochs=1,
  62. ... epi_uncer_model_path=None,
  63. ... ale_uncer_model_path=None,
  64. ... save_model=False)
  65. >>> epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_data)
  66. >>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
  67. >>> output = epistemic_uncertainty.shape
  68. >>> print(output)
  69. (32, 10)
  70. >>> output = aleatoric_uncertainty.shape
  71. >>> print(output)
  72. (32,)
  73. """
  74. def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1,
  75. epi_uncer_model_path=None, ale_uncer_model_path=None, save_model=False):
  76. self.epi_model = model
  77. self.ale_model = deepcopy(model)
  78. self.epi_train_dataset = train_dataset
  79. self.ale_train_dataset = train_dataset
  80. self.task_type = task_type
  81. self.epochs = Validator.check_positive_int(epochs)
  82. self.epi_uncer_model_path = epi_uncer_model_path
  83. self.ale_uncer_model_path = ale_uncer_model_path
  84. self.save_model = Validator.check_bool(save_model)
  85. self.epi_uncer_model = None
  86. self.ale_uncer_model = None
  87. self.concat = P.Concat(axis=0)
  88. self.sum = P.ReduceSum()
  89. self.pow = P.Pow()
  90. if not isinstance(model, Cell):
  91. raise TypeError('The model should be Cell type.')
  92. if task_type not in ('regression', 'classification'):
  93. raise ValueError('The task should be regression or classification.')
  94. if task_type == 'classification':
  95. self.num_classes = Validator.check_positive_int(num_classes)
  96. else:
  97. self.num_classes = num_classes
  98. if save_model:
  99. if epi_uncer_model_path is None or ale_uncer_model_path is None:
  100. raise ValueError("If save_model is True, the epi_uncer_model_path and "
  101. "ale_uncer_model_path should not be None.")
  102. def _get_epistemic_uncertainty_model(self):
  103. """
  104. Get the model which can obtain the epistemic uncertainty.
  105. """
  106. if self.epi_uncer_model is None:
  107. self.epi_uncer_model = EpistemicUncertaintyModel(self.epi_model)
  108. if self.epi_uncer_model.drop_count == 0 and self.epi_train_dataset is not None:
  109. if self.task_type == 'classification':
  110. net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  111. net_opt = Adam(self.epi_uncer_model.trainable_params())
  112. model = Model(self.epi_uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
  113. else:
  114. net_loss = MSELoss()
  115. net_opt = Adam(self.epi_uncer_model.trainable_params())
  116. model = Model(self.epi_uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
  117. if self.save_model:
  118. config_ck = CheckpointConfig(keep_checkpoint_max=self.epochs)
  119. ckpoint_cb = ModelCheckpoint(prefix='checkpoint_epi_uncer_model',
  120. directory=self.epi_uncer_model_path,
  121. config=config_ck)
  122. model.train(self.epochs, self.epi_train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
  123. elif self.epi_uncer_model_path is None:
  124. model.train(self.epochs, self.epi_train_dataset, callbacks=[LossMonitor()])
  125. else:
  126. uncer_param_dict = load_checkpoint(self.epi_uncer_model_path)
  127. load_param_into_net(self.epi_uncer_model, uncer_param_dict)
  128. def _eval_epistemic_uncertainty(self, eval_data, mc=10):
  129. """
  130. Evaluate the epistemic uncertainty of classification and regression models using MC dropout.
  131. """
  132. self._get_epistemic_uncertainty_model()
  133. self.epi_uncer_model.set_train(True)
  134. outputs = [None] * mc
  135. for i in range(mc):
  136. pred = self.epi_uncer_model(eval_data)
  137. outputs[i] = pred.asnumpy()
  138. if self.task_type == 'classification':
  139. outputs = np.stack(outputs, axis=2)
  140. epi_uncertainty = outputs.var(axis=2)
  141. else:
  142. outputs = np.stack(outputs, axis=1)
  143. epi_uncertainty = outputs.var(axis=1)
  144. epi_uncertainty = np.array(epi_uncertainty)
  145. return epi_uncertainty
  146. def _get_aleatoric_uncertainty_model(self):
  147. """
  148. Get the model which can obtain the aleatoric uncertainty.
  149. """
  150. if self.ale_train_dataset is None:
  151. raise ValueError('The train dataset should not be None when evaluating aleatoric uncertainty.')
  152. if self.ale_uncer_model is None:
  153. self.ale_uncer_model = AleatoricUncertaintyModel(self.ale_model, self.num_classes, self.task_type)
  154. net_loss = AleatoricLoss(self.task_type)
  155. net_opt = Adam(self.ale_uncer_model.trainable_params())
  156. if self.task_type == 'classification':
  157. model = Model(self.ale_uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
  158. else:
  159. model = Model(self.ale_uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
  160. if self.save_model:
  161. config_ck = CheckpointConfig(keep_checkpoint_max=self.epochs)
  162. ckpoint_cb = ModelCheckpoint(prefix='checkpoint_ale_uncer_model',
  163. directory=self.ale_uncer_model_path,
  164. config=config_ck)
  165. model.train(self.epochs, self.ale_train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
  166. elif self.ale_uncer_model_path is None:
  167. model.train(self.epochs, self.ale_train_dataset, callbacks=[LossMonitor()])
  168. else:
  169. uncer_param_dict = load_checkpoint(self.ale_uncer_model_path)
  170. load_param_into_net(self.ale_uncer_model, uncer_param_dict)
  171. def _eval_aleatoric_uncertainty(self, eval_data):
  172. """
  173. Evaluate the aleatoric uncertainty of classification and regression models.
  174. """
  175. self._get_aleatoric_uncertainty_model()
  176. _, var = self.ale_uncer_model(eval_data)
  177. ale_uncertainty = self.sum(self.pow(var, 2), 1)
  178. ale_uncertainty = ale_uncertainty.asnumpy()
  179. return ale_uncertainty
  180. def eval_epistemic_uncertainty(self, eval_data):
  181. """
  182. Evaluate the epistemic uncertainty of inference results, which also called model uncertainty.
  183. Args:
  184. eval_data (Tensor): The data samples to be evaluated, the shape must be (N,C,H,W).
  185. Returns:
  186. numpy.dtype, the epistemic uncertainty of inference results of data samples.
  187. """
  188. uncertainty = self._eval_epistemic_uncertainty(eval_data)
  189. return uncertainty
  190. def eval_aleatoric_uncertainty(self, eval_data):
  191. """
  192. Evaluate the aleatoric uncertainty of inference results, which also called data uncertainty.
  193. Args:
  194. eval_data (Tensor): The data samples to be evaluated, the shape must be (N,C,H,W).
  195. Returns:
  196. numpy.dtype, the aleatoric uncertainty of inference results of data samples.
  197. """
  198. uncertainty = self._eval_aleatoric_uncertainty(eval_data)
  199. return uncertainty
  200. class EpistemicUncertaintyModel(Cell):
  201. """
  202. Using dropout during training and eval time which is approximate bayesian inference. In this way,
  203. we can obtain the epistemic uncertainty (also called model uncertainty).
  204. If the original model has Dropout layer, just use dropout when eval time, if not, add dropout layer
  205. after Dense layer or Conv layer, then use dropout during train and eval time.
  206. See more details in `Dropout as a Bayesian Approximation: Representing Model uncertainty in Deep Learning
  207. <https://arxiv.org/abs/1506.02142>`_.
  208. """
  209. def __init__(self, epi_model):
  210. super(EpistemicUncertaintyModel, self).__init__()
  211. self.drop_count = 0
  212. self.epi_model = self._make_epistemic(epi_model)
  213. def construct(self, x):
  214. x = self.epi_model(x)
  215. return x
  216. def _make_epistemic(self, epi_model, dropout_rate=0.5):
  217. """
  218. The dropout rate is set to 0.5 by default.
  219. """
  220. for (name, layer) in epi_model.name_cells().items():
  221. if isinstance(layer, (Conv2d, Dense, Dropout)):
  222. if isinstance(layer, Dropout):
  223. self.drop_count += 1
  224. return epi_model
  225. uncertainty_layer = layer
  226. uncertainty_name = name
  227. drop = Dropout(keep_prob=dropout_rate)
  228. bnn_drop = SequentialCell([uncertainty_layer, drop])
  229. setattr(epi_model, uncertainty_name, bnn_drop)
  230. return epi_model
  231. self._make_epistemic(layer)
  232. raise ValueError("The model has not Dense Layer or Convolution Layer, "
  233. "it can not evaluate epistemic uncertainty so far.")
  234. class AleatoricUncertaintyModel(Cell):
  235. """
  236. The aleatoric uncertainty (also called data uncertainty) is caused by input data, to obtain this
  237. uncertainty, the loss function must be modified in order to add variance into loss.
  238. See more details in `What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?
  239. <https://arxiv.org/abs/1703.04977>`_.
  240. """
  241. def __init__(self, ale_model, num_classes, task):
  242. super(AleatoricUncertaintyModel, self).__init__()
  243. self.task = task
  244. if task == 'classification':
  245. self.ale_model = ale_model
  246. self.var_layer = Dense(num_classes, num_classes)
  247. else:
  248. self.ale_model, self.var_layer, self.pred_layer = self._make_aleatoric(ale_model)
  249. def construct(self, x):
  250. if self.task == 'classification':
  251. pred = self.ale_model(x)
  252. var = self.var_layer(pred)
  253. else:
  254. x = self.ale_model(x)
  255. pred = self.pred_layer(x)
  256. var = self.var_layer(x)
  257. return pred, var
  258. def _make_aleatoric(self, ale_model):
  259. """
  260. In order to add variance into original loss, add var Layer after the original network.
  261. """
  262. dense_layer = dense_name = None
  263. for (name, layer) in ale_model.name_cells().items():
  264. if isinstance(layer, Dense):
  265. dense_layer = layer
  266. dense_name = name
  267. if dense_layer is None:
  268. raise ValueError("The model has not Dense Layer, "
  269. "it can not evaluate aleatoric uncertainty so far.")
  270. setattr(ale_model, dense_name, Flatten())
  271. var_layer = Dense(dense_layer.in_channels, dense_layer.out_channels)
  272. return ale_model, var_layer, dense_layer
  273. class AleatoricLoss(Cell):
  274. """
  275. The loss function of aleatoric model, different modification methods are adopted for
  276. classification and regression.
  277. """
  278. def __init__(self, task):
  279. super(AleatoricLoss, self).__init__()
  280. self.task = task
  281. if self.task == 'classification':
  282. self.sum = P.ReduceSum()
  283. self.exp = P.Exp()
  284. self.normal = C.normal
  285. self.to_tensor = P.ScalarToArray()
  286. self.entropy = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
  287. else:
  288. self.mean = P.ReduceMean()
  289. self.exp = P.Exp()
  290. self.pow = P.Pow()
  291. def construct(self, data_pred, y):
  292. y_pred, var = data_pred
  293. if self.task == 'classification':
  294. sample_times = 10
  295. epsilon = self.normal((1, sample_times), self.to_tensor(0.0), self.to_tensor(1.0), 0)
  296. total_loss = 0
  297. for i in range(sample_times):
  298. y_pred_i = y_pred + epsilon[0][i] * var
  299. loss = self.entropy(y_pred_i, y)
  300. total_loss += loss
  301. avg_loss = total_loss / sample_times
  302. return avg_loss
  303. loss = self.mean(0.5 * self.exp(-var) * self.pow(y - y_pred, 2) + 0.5 * var)
  304. return loss