You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

utils.py 9.3 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. """
  2. Utilities used by the solver
  3. * LeaderBoard: The LeaderBoard that maintains the performance of models.
  4. """
  5. import random
  6. import typing as _typing
  7. import torch
  8. import torch.backends.cudnn
  9. import numpy as np
  10. import pandas as pd
  11. from ..backend import DependentBackend
  12. from ..data import Dataset
  13. from ..data.graph import GeneralStaticGraph
  14. from ..utils import get_logger
  # Logger used for the LeaderBoard warnings emitted in this module.
  LOGGER = get_logger("LeaderBoard")
  # Name of the active backend; the helpers in this module compare it against 'dgl' and 'pyg'.
  BACKEND = DependentBackend.get_backend_name()
  # Bind `_convert_dataset` to the converter matching the backend:
  # the DGL converter for 'dgl', the PyG converter for anything else.
  if BACKEND == 'dgl':
      from autogl.datasets.utils.conversion import to_dgl_dataset as _convert_dataset
  else:
      from autogl.datasets.utils.conversion import to_pyg_dataset as _convert_dataset
  class LeaderBoard:
      """
      The leaderboard that stores and sorts model performance records.

      Records are kept in ``self.perform_dict``, a ``pandas.DataFrame`` with one
      row per model and the columns ``["name"] + fields``.

      Parameters
      ----------
      fields: list of `str`
          A list of field names that describe the model performance. The first
          field is used as the major field for sorting the model performances.

      is_higher_better: `dict` of *field* -> `bool`
          A mapping that tells, for each field, whether a higher value is better.
      """

      def __init__(self, fields, is_higher_better):
          assert isinstance(fields, list)
          self.keys = ["name"] + fields
          self.perform_dict = pd.DataFrame(columns=self.keys)
          self.is_higher_better = is_higher_better
          self.major_field = fields[0]

      def set_major_field(self, field) -> None:
          """
          Set the major field of current LeaderBoard.

          Parameters
          ----------
          field: `str`
              The major field, should be one of the fields when initialized.
              ``"name"`` and unknown fields are ignored with a warning.

          Returns
          -------
          None
          """
          if field in self.keys and field != "name":
              self.major_field = field
          else:
              LOGGER.warning(
                  f"Field [{field}] NOT found in the current LeaderBoard, will ignore."
              )

      def insert_model_performance(self, name, performance) -> None:
          """
          Add/Override a record of model performance. If the given name is already
          in the leaderboard, the existing record is replaced.

          Parameters
          ----------
          name: `str`
              The model name/identifier that identifies the model.

          performance: `dict`
              The performance dict. The keys should be the fields given when
              initialized, the values the corresponding scores. The dict is not
              modified.

          Returns
          -------
          None
          """
          # Compare against the stored values: ``name in series`` would test the
          # index labels and therefore never detect an existing model.
          if (self.perform_dict["name"] == name).any():
              LOGGER.warning(
                  "model already in the leaderboard, will override current result."
              )
              self.remove_model_performance(name)

          # Work on a copy so the caller's dict does not gain a "name" key.
          record = dict(performance)
          record["name"] = name
          new = pd.DataFrame(record, index=[0])

          if self.perform_dict.empty:
              # Keep the declared column order (plus any extra keys) and let the
              # first record define the column dtypes.
              columns = self.perform_dict.columns.tolist()
              columns.extend(c for c in new.columns if c not in columns)
              self.perform_dict = new.reindex(columns=columns)
          else:
              # ``DataFrame.append`` no longer exists in pandas >= 2.0.
              self.perform_dict = pd.concat(
                  [self.perform_dict, new], ignore_index=True
              )

      def remove_model_performance(self, name) -> None:
          """
          Remove the record of the given model.

          Parameters
          ----------
          name: `str`
              The model name/identifier that needs to be removed.

          Returns
          -------
          None
          """
          matches = self.perform_dict["name"] == name
          if not matches.any():
              LOGGER.warning(
                  "no model detected in current leaderboard, will ignore removing action."
              )
              return
          # Filter by value and renumber the rows so labels and positions stay
          # aligned after a removal.
          self.perform_dict = self.perform_dict[~matches].reset_index(drop=True)

      def get_best_model(self, index=0) -> str:
          """
          Get the best model according to the performance of the major field.
          The record named ``"ensemble"`` is excluded from the ranking.

          Parameters
          ----------
          index: `int`
              The index of the model (from good to bad). Default `0`.

          Returns
          -------
          name: `str`
              The name/identifier of the required model.

          Raises
          ------
          IndexError
              If ``index`` is out of range of the ranked models.
          """
          sorted_df = self.perform_dict.sort_values(
              by=self.major_field, ascending=not self.is_higher_better[self.major_field]
          )
          name_list = sorted_df["name"].tolist()
          if "ensemble" in name_list:
              name_list.remove("ensemble")
          return name_list[index]

      def show(self, top_k=0) -> None:
          """
          Show current LeaderBoard (from best model to worst).

          Parameters
          ----------
          top_k: `int`
              Controls the number of models shown.
              If less than or equal to `0`, will show all the models. Default to `0`.

          Returns
          -------
          None
          """
          limit = top_k if top_k > 0 else len(self.perform_dict)

          # Reorder the stored frame so that the "name" column is the left-most one.
          _columns = self.perform_dict.columns.tolist()
          if "name" in _columns:
              _columns.remove("name")
              _columns.insert(0, "name")
              self.perform_dict = self.perform_dict[_columns]

          sorted_performance_df = self.perform_dict.sort_values(
              self.major_field, ascending=not self.is_higher_better[self.major_field]
          ).head(limit)

          # Imported lazily so the module can be loaded without tabulate installed.
          from tabulate import tabulate

          print(
              tabulate(
                  list(sorted_performance_df.itertuples(index=False, name=None)),
                  headers=sorted_performance_df.columns.tolist(),
                  tablefmt="grid",
              )
          )
  def get_graph_from_dataset(dataset, graph_id=0):
      """
      Fetch the graph stored at position ``graph_id`` of ``dataset``.

      A ``Dataset`` instance, or any dataset under the 'pyg' backend, is indexed
      directly. Under the 'dgl' backend the indexed item is returned as-is when
      it is a ``DGLGraph``, otherwise its first element is returned.
      Any other backend yields ``None``.
      """
      if isinstance(dataset, Dataset) or BACKEND == 'pyg':
          return dataset[graph_id]
      if BACKEND != 'dgl':
          return None
      from dgl import DGLGraph
      item = dataset[graph_id]
      return item if isinstance(item, DGLGraph) else item[0]
  def get_graph_node_number(graph):
      """
      Return the number of nodes of ``graph``.

      For a ``GeneralStaticGraph`` the count is the first dimension of the node
      feature entry ('x' under 'pyg', 'feat' otherwise). For other graphs it is
      ``graph.x.shape[0]`` under 'pyg' and ``graph.num_nodes()`` otherwise.
      """
      # FIXME: if the feature is None, this will throw an error
      if isinstance(graph, GeneralStaticGraph):
          feature_key = 'x' if BACKEND == 'pyg' else 'feat'
          return graph.nodes.data[feature_key].size(0)
      return graph.x.shape[0] if BACKEND == 'pyg' else graph.num_nodes()
  def get_graph_node_features(graph):
      """
      Return the node features of ``graph``, or ``None`` when they are absent.

      The features are looked up under 'feat' for the 'dgl' backend and under
      'x' for the 'pyg' backend.
      """
      if isinstance(graph, GeneralStaticGraph):
          if BACKEND == 'dgl':
              return graph.nodes.data['feat'] if 'feat' in graph.nodes.data else None
          if BACKEND == 'pyg':
              return graph.nodes.data['x'] if 'x' in graph.nodes.data else None
          return None

      if BACKEND == 'pyg':
          return graph.x if hasattr(graph, 'x') else None
      if BACKEND == 'dgl':
          return graph.ndata['feat'] if 'feat' in graph.ndata else None
      return None
  def get_graph_masks(graph, mask='train'):
      """
      Return the ``'<mask>_mask'`` entry of ``graph``, or ``None`` if it is missing.

      ``mask`` is the prefix of the mask name; the default ``'train'`` selects
      ``'train_mask'``.
      """
      field = f'{mask}_mask'
      if isinstance(graph, GeneralStaticGraph):
          return graph.nodes.data[field] if field in graph.nodes.data else None

      if BACKEND == 'pyg':
          return getattr(graph, field) if hasattr(graph, field) else None
      if BACKEND == 'dgl':
          return graph.ndata[field] if field in graph.ndata else None
      return None
  def get_graph_labels(graph):
      """
      Return the node labels of ``graph``, or ``None`` when they are absent.

      Labels are read from 'label' under the 'dgl' backend and from 'y' under
      the 'pyg' backend.
      """
      if isinstance(graph, GeneralStaticGraph):
          # Membership is tested before the backend, as in the lookup order below.
          for key, backend in (('label', 'dgl'), ('y', 'pyg')):
              if key in graph.nodes.data and BACKEND == backend:
                  return graph.nodes.data[key]
          return None

      if BACKEND == 'pyg':
          return graph.y if hasattr(graph, 'y') else None
      if BACKEND == 'dgl':
          return graph.ndata['label'] if 'label' in graph.ndata else None
      return None
  def get_dataset_labels(dataset):
      """
      Collect the labels of every item in ``dataset``.

      When the first item is a ``GeneralStaticGraph``, each item's ``data`` entry
      ('label' under 'dgl', 'y' otherwise) is gathered into a ``LongTensor``.
      Otherwise 'pyg' returns ``dataset.data.y`` and any other backend gathers
      the second element of every item into a ``LongTensor``.
      """
      if isinstance(dataset[0], GeneralStaticGraph):
          label_key = 'label' if BACKEND == 'dgl' else 'y'
          return torch.LongTensor([graph.data[label_key] for graph in dataset])
      if BACKEND != 'pyg':
          return torch.LongTensor([sample[1] for sample in dataset])
      return dataset.data.y
  def convert_dataset(dataset):
      """
      Convert ``dataset`` with the backend converter when its first item exposes
      an ``edges`` attribute; otherwise return it unchanged.
      """
      # todo: replace the trick by re-implementing the convert_dataset in utils
      if not hasattr(dataset[0], "edges"):
          return dataset
      return _convert_dataset(dataset)
  def set_seed(seed=None):
      """
      Seed the random generators of ``random``, ``numpy`` and ``torch``.

      When ``seed`` is ``None`` a seed is drawn with ``random.randint(0, 5000)``.
      If CUDA is available, all CUDA devices are seeded as well and cuDNN is
      switched to deterministic mode with benchmarking disabled.
      """
      if seed is None:
          seed = random.randint(0, 5000)

      # Same order as before: python, numpy, then torch.
      for seeder in (random.seed, np.random.seed, torch.manual_seed):
          seeder(seed)

      if not torch.cuda.is_available():
          return
      torch.cuda.manual_seed_all(seed)
      torch.backends.cudnn.deterministic = True
      torch.backends.cudnn.benchmark = False
  def get_graph_labels_hetero(graph, target_node_type):
      """
      Return the 'label' entry of node type ``target_node_type`` in ``graph``.

      Only the 'dgl' backend is handled; ``None`` is returned when the backend
      differs or the label entry is missing.
      """
      if isinstance(graph, GeneralStaticGraph):
          # Membership is checked before the backend, so a bad node type still raises.
          if 'label' in graph.nodes[target_node_type].data and BACKEND == 'dgl':
              return graph.nodes[target_node_type].data['label']
          return None

      if BACKEND != 'dgl':
          return None
      if 'label' in graph.ndata[target_node_type]:
          return graph.ndata[target_node_type]['label']
      return None