
simple_bridge.py 16 kB

  1. """
  2. This module contains a simple implementation of the Bridge part.
  3. Copyright (c) 2024 LAMDA. All rights reserved.
  4. """
  5. import os.path as osp
  6. from typing import Any, List, Optional, Tuple, Union
  7. from numpy import ndarray
  8. import wandb
  9. from ..data.evaluation import BaseMetric
  10. from ..data.structures import ListData
  11. from ..learning import ABLModel
  12. from ..reasoning import Reasoner
  13. from ..utils import print_log
  14. from .base_bridge import BaseBridge, M, R


class SimpleBridge(BaseBridge[M, R]):
    """
    A basic implementation for bridging machine learning and reasoning parts.

    This class implements the typical pipeline of Abductive Learning, which involves
    the following five steps:

    - Predict class probabilities and indices for the given data examples.
    - Map indices into pseudo-labels.
    - Revise pseudo-labels based on abductive reasoning.
    - Map the revised pseudo-labels to indices.
    - Train the model.

    Parameters
    ----------
    model : M
        The machine learning model wrapped in ``ABLModel``, which is mainly used for
        prediction and model training.
    reasoner : R
        The reasoning part wrapped in ``Reasoner``, which is used for pseudo-label revision.
    metric_list : List[BaseMetric]
        A list of metrics used for evaluating the model's performance.
    """

    def __init__(
        self,
        model: M,
        reasoner: R,
        metric_list: List[BaseMetric],
    ) -> None:
        super().__init__(model, reasoner)
        self.metric_list = metric_list
        self.use_wandb = self._check_wandb_available()
        if not hasattr(model.base_model, "predict_proba") and reasoner.dist_func in [
            "confidence",
            "avg_confidence",
        ]:
            raise ValueError(
                "If the base model does not implement the predict_proba method, "
                "then the dist_func in the reasoner cannot be set to 'confidence' "
                "or 'avg_confidence', which are related to predicted probability."
            )

    def _check_wandb_available(self):
        """
        Check if wandb is available and initialized.

        Returns
        -------
        bool
            True if wandb is available and initialized, False otherwise.
        """
        try:
            return wandb.run is not None
        except ImportError:
            return False

    def predict(self, data_examples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
        """
        Predict class indices and probabilities (if ``predict_proba`` is implemented in
        ``self.model.base_model``) on the given data examples.

        Parameters
        ----------
        data_examples : ListData
            Data examples on which predictions are to be made.

        Returns
        -------
        Tuple[List[ndarray], List[ndarray]]
            A tuple containing lists of predicted indices and probabilities.
        """
        self.model.predict(data_examples)
        return data_examples.pred_idx, data_examples.pred_prob

    def abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
        """
        Revise predicted pseudo-labels of the given data examples using abduction.

        Parameters
        ----------
        data_examples : ListData
            Data examples containing predicted pseudo-labels.

        Returns
        -------
        List[List[Any]]
            A list of abduced pseudo-labels for the given data examples.
        """
        self.reasoner.batch_abduce(data_examples)
        return data_examples.abduced_pseudo_label

    def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
        """
        Map indices of data examples into pseudo-labels.

        Parameters
        ----------
        data_examples : ListData
            Data examples containing the indices.

        Returns
        -------
        List[List[Any]]
            A list of pseudo-labels converted from indices.
        """
        pred_idx = data_examples.pred_idx
        data_examples.pred_pseudo_label = [
            [self.reasoner.idx_to_label[_idx] for _idx in sub_list] for sub_list in pred_idx
        ]
        return data_examples.pred_pseudo_label

    def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]:
        """
        Map pseudo-labels of data examples into indices.

        Parameters
        ----------
        data_examples : ListData
            Data examples containing pseudo-labels.

        Returns
        -------
        List[List[Any]]
            A list of indices converted from pseudo-labels.
        """
        abduced_idx = [
            [
                self.reasoner.label_to_idx[_abduced_pseudo_label]
                for _abduced_pseudo_label in sub_list
            ]
            for sub_list in data_examples.abduced_pseudo_label
        ]
        data_examples.abduced_idx = abduced_idx
        return data_examples.abduced_idx

    def data_preprocess(
        self,
        prefix: str,
        data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
    ) -> ListData:
        """
        Transform data in the form of (X, gt_pseudo_label, Y) into ListData.

        Parameters
        ----------
        prefix : str
            A prefix indicating the type of data processing (e.g., 'train', 'test').
        data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
            Data to be preprocessed. Can be ListData or a tuple of lists.

        Returns
        -------
        ListData
            The preprocessed ListData object.
        """
        if isinstance(data, ListData):
            data_examples = data
            if not (
                hasattr(data_examples, "X")
                and hasattr(data_examples, "gt_pseudo_label")
                and hasattr(data_examples, "Y")
            ):
                raise ValueError(
                    f"{prefix}data should have X, gt_pseudo_label and Y attribute but "
                    f"only {data_examples.all_keys()} are provided."
                )
        else:
            X, gt_pseudo_label, Y = data
            data_examples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y)

        return data_examples

    def concat_data_examples(
        self, unlabel_data_examples: ListData, label_data_examples: Optional[ListData]
    ) -> ListData:
        """
        Concatenate unlabeled and labeled data examples. ``abduced_pseudo_label`` of unlabeled
        data examples and ``gt_pseudo_label`` of labeled data examples will be used to train
        the model.

        Parameters
        ----------
        unlabel_data_examples : ListData
            Unlabeled data examples to concatenate.
        label_data_examples : ListData, optional
            Labeled data examples to concatenate, if available.

        Returns
        -------
        ListData
            Concatenated data examples.
        """
        if label_data_examples is None:
            return unlabel_data_examples

        unlabel_data_examples.X = unlabel_data_examples.X + label_data_examples.X
        unlabel_data_examples.abduced_pseudo_label = (
            unlabel_data_examples.abduced_pseudo_label + label_data_examples.gt_pseudo_label
        )
        unlabel_data_examples.Y = unlabel_data_examples.Y + label_data_examples.Y
        return unlabel_data_examples

    def train(
        self,
        train_data: Union[
            ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]
        ],
        label_data: Optional[
            Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]]
        ] = None,
        val_data: Optional[
            Union[
                ListData,
                Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
            ]
        ] = None,
        loops: int = 50,
        segment_size: Union[int, float] = 1.0,
        eval_interval: int = 1,
        save_interval: Optional[int] = None,
        save_dir: Optional[str] = None,
    ):
        """
        A typical training pipeline of Abductive Learning.

        Parameters
        ----------
        train_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
            Training data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData``
            object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes.

            - ``X`` is a list of sublists representing the input data.
            - ``gt_pseudo_label`` is only used to evaluate the performance of the ``ABLModel``
              but not to train. ``gt_pseudo_label`` can be ``None``.
            - ``Y`` is a list representing the ground truth reasoning result for each sublist
              in ``X``.
        label_data : Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]], optional
            Labeled data should be in the same format as ``train_data``. The only difference is
            that the ``gt_pseudo_label`` in ``label_data`` should not be ``None`` and will be
            utilized to train the model. Defaults to None.
        val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional  # noqa: E501
            Validation data should be in the same format as ``train_data``. Both
            ``gt_pseudo_label`` and ``Y`` can be either None or not, which depends on the
            evaluation metrics in ``self.metric_list``. If ``val_data`` is None, ``train_data``
            will be used to validate the model during training time. Defaults to None.
        loops : int
            Learning part and Reasoning part will be iteratively optimized
            for ``loops`` times. Defaults to 50.
        segment_size : Union[int, float]
            Data will be split into segments of this size and data in each segment
            will be used together to train the model. Defaults to 1.0.
        eval_interval : int
            The model will be evaluated every ``eval_interval`` loops during training.
            Defaults to 1.
        save_interval : int, optional
            The model will be saved every ``save_interval`` loops during training.
            Defaults to None.
        save_dir : str, optional
            Directory to save the model. Defaults to None.
        """
        data_examples = self.data_preprocess("train", train_data)

        if label_data is not None:
            label_data_examples = self.data_preprocess("label", label_data)
        else:
            label_data_examples = None

        if val_data is not None:
            val_data_examples = self.data_preprocess("val", val_data)
        else:
            val_data_examples = data_examples

        if isinstance(segment_size, int):
            if segment_size <= 0:
                raise ValueError("segment_size should be positive.")
        elif isinstance(segment_size, float):
            if 0 < segment_size <= 1:
                segment_size = int(segment_size * len(data_examples))
            else:
                raise ValueError("segment_size should be in (0, 1].")
        else:
            raise ValueError("segment_size should be int or float.")

        for loop in range(loops):
            for seg_idx in range((len(data_examples) - 1) // segment_size + 1):
                print_log(
                    f"loop(train) [{loop + 1}/{loops}] segment(train) "
                    f"[{(seg_idx + 1)}/{(len(data_examples) - 1) // segment_size + 1}] ",
                    logger="current",
                )
                sub_data_examples = data_examples[
                    seg_idx * segment_size : (seg_idx + 1) * segment_size
                ]
                self.predict(sub_data_examples)
                self.idx_to_pseudo_label(sub_data_examples)
                self.abduce_pseudo_label(sub_data_examples)
                self.filter_pseudo_label(sub_data_examples)
                self.concat_data_examples(sub_data_examples, label_data_examples)
                self.pseudo_label_to_idx(sub_data_examples)
                if len(sub_data_examples) == 0:
                    continue
                self.model.train(sub_data_examples)

            if (loop + 1) % eval_interval == 0 or loop == loops - 1:
                print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current")
                self._valid(val_data_examples, prefix="val")

            if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1):
                print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current")
                self.model.save(
                    save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")
                )

    def _valid(self, data_examples: ListData, prefix: str = "val") -> None:
        """
        Internal method for validating the model with the given data examples.

        Parameters
        ----------
        data_examples : ListData
            Data examples to be used for validation.
        prefix : str
            Prefix attached to the reported metric names (e.g., 'val', 'test').
        """
        self.predict(data_examples)
        self.idx_to_pseudo_label(data_examples)

        for metric in self.metric_list:
            metric.prefix = prefix
        for metric in self.metric_list:
            metric.process(data_examples)

        res = dict()
        for metric in self.metric_list:
            res.update(metric.evaluate())

        msg = "Evaluation ended, "
        for k, v in res.items():
            try:
                v = float(v)
                msg += k + f": {v:.3f} "
            except (TypeError, ValueError):
                # Skip metrics whose values cannot be formatted as floats.
                pass

        if self.use_wandb:
            try:
                wandb_metrics = {k: v for k, v in res.items()}
                wandb.log(wandb_metrics)
            except Exception as e:
                print_log(f"Failed to log metrics to wandb: {e}", logger="current")

        print_log(msg, logger="current")

    def valid(
        self,
        val_data: Union[
            ListData,
            Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
        ],
    ) -> None:
        """
        Validate the model with the given validation data.

        Parameters
        ----------
        val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]]  # noqa: E501
            Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData``
            object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label``
            and ``Y`` can be either None or not, which depends on the evaluation metrics in
            ``self.metric_list``.
        """
        val_data_examples = self.data_preprocess("val", val_data)
        self._valid(val_data_examples, prefix="val")

    def test(
        self,
        test_data: Union[
            ListData,
            Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
        ],
    ) -> None:
        """
        Test the model with the given test data.

        Parameters
        ----------
        test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]]  # noqa: E501
            Test data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData``
            object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label``
            and ``Y`` can be either None or not, which depends on the evaluation metrics in
            ``self.metric_list``.
        """
        print_log("Test start:", logger="current")
        test_data_examples = self.data_preprocess("test", test_data)
        self._valid(test_data_examples, prefix="test")

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.
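
For orientation, below is a minimal usage sketch of SimpleBridge. The constructor arguments, the ``dist_func`` option, and the ``train``/``test`` signatures come from the code above; the top-level package name ``ablkit``, the metric class ``SymbolAccuracy``, the knowledge-base class ``MyKB``, and the data variables are illustrative assumptions, not part of this file.

# Minimal usage sketch. Names marked "assumed" are illustrative placeholders.
from sklearn.ensemble import RandomForestClassifier

from ablkit.bridge import SimpleBridge              # assumed package path
from ablkit.data.evaluation import SymbolAccuracy   # assumed metric class (a BaseMetric subclass)
from ablkit.learning import ABLModel
from ablkit.reasoning import Reasoner

model = ABLModel(RandomForestClassifier())           # a scikit-learn-style base model with predict_proba
reasoner = Reasoner(MyKB(), dist_func="confidence")  # MyKB: a user-defined knowledge base (assumed)
bridge = SimpleBridge(model, reasoner, metric_list=[SymbolAccuracy()])

# train_data and test_data are (X, gt_pseudo_label, Y) tuples, as described in the docstrings above.
bridge.train(train_data, loops=5, segment_size=0.1, save_interval=1, save_dir="./checkpoints")
bridge.test(test_data)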