@@ -15,7 +15,7 @@ import os.path as osp
 import urllib
 import tarfile
 from zipfile import ZipFile
-from gklearn.utils.graphfiles import loadDataset
+# from gklearn.utils.graphfiles import loadDataset
 import torch.nn.functional as F
 import networkx as nx
 import torch
@@ -152,21 +152,27 @@ class DataFetcher():
 				with tarfile.open(filename_archive, 'r:gz') as tar:
 					if self._reload and self._verbose:
 						print(filename + ' Downloaded.')
-					tar.extractall(path = path)
-				return os.path.join(path, tar.getnames()[0])
+					subpath = os.path.join(path, tar.getnames()[0])
+					if not osp.exists(subpath) or self._reload:
+						tar.extractall(path = path)
+				return subpath
 		elif filename.endswith('.tar'):
 			if tarfile.is_tarfile(filename_archive):
 				with tarfile.open(filename_archive, 'r:') as tar:
 					if self._reload and self._verbose:
 						print(filename + ' Downloaded.')
-					tar.extractall(path = path)
-				return os.path.join(path, tar.getnames()[0])
+					subpath = os.path.join(path, tar.getnames()[0])
+					if not osp.exists(subpath) or self._reload:
+						tar.extractall(path = path)
+				return subpath
 		elif filename.endswith('.zip'):
 			with ZipFile(filename_archive, 'r') as zip_ref:
 				if self._reload and self._verbose:
 					print(filename + ' Downloaded.')
-				zip_ref.extractall(path)
-			return os.path.join(path, zip_ref.namelist()[0])
+				subpath = os.path.join(path, zip_ref.namelist()[0])
+				if not osp.exists(subpath) or self._reload:
+					zip_ref.extractall(path)
+			return subpath
 		else:
 			raise ValueError(filename + ' Unsupported file.')
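Each branch above applies the same guard: resolve the directory the archive would unpack into, then call extractall only when that directory is missing or a reload is forced, so repeated fetches reuse an existing extraction. A minimal standalone sketch of that pattern, assuming a free function rather than the class method (the name extract_archive and its reload flag are illustrative, not part of DataFetcher):

import os
import os.path as osp
import tarfile
from zipfile import ZipFile

def extract_archive(filename_archive, path, reload=False):
	# Resolve the archive's first entry: the directory the dataset
	# unpacks into, and the value returned to the caller.
	if filename_archive.endswith(('.tar.gz', '.tar')):
		with tarfile.open(filename_archive) as tar:  # mode 'r' auto-detects compression
			subpath = os.path.join(path, tar.getnames()[0])
			if not osp.exists(subpath) or reload:
				tar.extractall(path=path)
	elif filename_archive.endswith('.zip'):
		with ZipFile(filename_archive, 'r') as zip_ref:
			subpath = os.path.join(path, zip_ref.namelist()[0])
			if not osp.exists(subpath) or reload:
				zip_ref.extractall(path)
	else:
		raise ValueError(filename_archive + ' Unsupported file.')
	return subpath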
@@ -261,6 +267,11 @@ class DataFetcher():
 			else:
 				geometry = geo_txt
+			# url.
+			url = td_node[11].xpath('a')[0].attrib['href'].strip()
+			pos_zip = url.rfind('.zip')
+			url = url[:pos_zip + 4]
 			infos[td_node[0].xpath('strong')[0].text.strip()] = {
 				'database': 'tudataset',
 				'reference': td_node[1].text.strip(),
@@ -274,7 +285,7 @@ class DataFetcher():
 				'node_attr_dim': node_attr_dim,
 				'geometry': geometry,
 				'edge_attr_dim': edge_attr_dim,
-				'url': td_node[11].xpath('a')[0].attrib['href'].strip(),
+				'url': url,
 				'domain': domain
 			}
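The url pre-processing added above trims anything the scraped href carries after the '.zip' suffix before the value is stored in infos. A small illustration of the slicing, detached from the lxml scraping; the sample href is made up:

url = 'https://example.org/datasets/MUTAG.zip?raw=true'  # hypothetical href
pos_zip = url.rfind('.zip')
url = url[:pos_zip + 4]  # keep everything up to and including '.zip'
print(url)  # https://example.org/datasets/MUTAG.zip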