import os.path as osp
import sys
import urllib.request

import numpy as np
import torch
import torch.nn as nn
from scipy import io as sio

import dgl
from dgl.data.utils import download, get_download_dir, _get_dgl_url

def get_binary_mask(total_size, indices):
    """Return a boolean mask of length ``total_size`` that is True at ``indices``."""
    mask = torch.zeros(total_size, dtype=torch.bool)
    mask[indices] = True
    return mask
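
# Illustrative example (hypothetical values): selecting indices [0, 2]
# out of 5 nodes gives
#   get_binary_mask(5, torch.tensor([0, 2]))
#   -> tensor([ True, False,  True, False, False])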


class BaseHeteroDataset:
    r"""Base class for heterogeneous graph datasets.

    Attributes
    ----------
    g : dgl.DGLHeteroGraph
        The DGL heterogeneous graph.
    num_classes : int
        Number of classes for target nodes.
    num_features : int
        Dimension of the input node features.
    metapaths : list[list[str]]
        Metapaths for the target node type, each given as a list of
        edge-type names.
    """

    def __init__(self):
        self.num_classes = None
        self.metapaths = None
        self.num_features = None
        self.g = None


class HeteroData(BaseHeteroDataset):
    def __init__(self, name, **kwargs):
        super().__init__()
        self.name = name
        if name == 'acm_raw':
            self.g, self.num_classes, self.num_features = self.load_acm_raw()
        elif name == 'acm':
            self.g, self.num_classes, self.num_features = self.load_hgt_acm(random_init_fea=True)
        else:
            raise ValueError(f'Unsupported dataset name: {name}')

    def load_acm_raw(self):
        # Paper-author-paper and paper-field-paper metapaths.
        self.metapaths = [['pa', 'ap'], ['pf', 'fp']]
        filename = 'ACM.mat'
        url = 'dataset/' + filename
        data_path = get_download_dir() + '/' + filename
        if osp.exists(data_path):
            print(f'Using existing file {filename}', file=sys.stderr)
        else:
            download(_get_dgl_url(url), path=data_path)
        data = sio.loadmat(data_path)
        p_vs_l = data['PvsL']  # paper-field
        p_vs_a = data['PvsA']  # paper-author
        p_vs_t = data['PvsT']  # paper-term, bag of words
        p_vs_c = data['PvsC']  # paper-conference, labels come from that

        # We assign
        # (1) KDD papers as class 0 (data mining),
        # (2) SIGMOD and VLDB papers as class 1 (database),
        # (3) SIGCOMM and MOBICOMM papers as class 2 (communication)
        conf_ids = [0, 1, 9, 10, 13]
        label_ids = [0, 1, 2, 2, 1]

        # Keep only papers published at one of the selected conferences.
        p_vs_c_filter = p_vs_c[:, conf_ids]
        p_selected = (p_vs_c_filter.sum(1) != 0).A1.nonzero()[0]
        p_vs_l = p_vs_l[p_selected]
        p_vs_a = p_vs_a[p_selected]
        p_vs_t = p_vs_t[p_selected]
        p_vs_c = p_vs_c[p_selected]

        hg = dgl.heterograph({
            ('paper', 'pa', 'author'): p_vs_a.nonzero(),
            ('author', 'ap', 'paper'): p_vs_a.transpose().nonzero(),
            ('paper', 'pf', 'field'): p_vs_l.nonzero(),
            ('field', 'fp', 'paper'): p_vs_l.transpose().nonzero()
        })
        # Bag-of-words term counts serve as paper features.
        hg.nodes['paper'].data['feat'] = torch.FloatTensor(p_vs_t.toarray())

        pc_p, pc_c = p_vs_c.nonzero()
        labels = np.zeros(len(p_selected), dtype=np.int64)
        for conf_id, label_id in zip(conf_ids, label_ids):
            labels[pc_p[pc_c == conf_id]] = label_id
        hg.nodes['paper'].data['label'] = torch.LongTensor(labels)
        num_classes = 3

        # Per-conference stratified split: roughly 20% train, 10% val, 70% test.
        float_mask = np.zeros(len(pc_p))
        for conf_id in conf_ids:
            pc_c_mask = (pc_c == conf_id)
            float_mask[pc_c_mask] = np.random.permutation(np.linspace(0, 1, pc_c_mask.sum()))
        train_idx = np.where(float_mask <= 0.2)[0]
        val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0]
        test_idx = np.where(float_mask > 0.3)[0]

        num_nodes = hg.number_of_nodes('paper')
        hg.nodes['paper'].data['train_mask'] = get_binary_mask(num_nodes, train_idx)
        hg.nodes['paper'].data['val_mask'] = get_binary_mask(num_nodes, val_idx)
        hg.nodes['paper'].data['test_mask'] = get_binary_mask(num_nodes, test_idx)

        num_features = hg.nodes['paper'].data['feat'].size(1)
        return hg, num_classes, num_features
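
    # Note: the metapaths set above are meant to be consumed by a
    # metapath-based model such as HAN, e.g. (illustrative, not called here)
    #   g_pap = dgl.metapath_reachable_graph(hg, ['pa', 'ap'])
    # which derives the homogeneous paper-author-paper graph.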

    def load_hgt_acm(self, random_init_fea=True):
        data_url = 'https://data.dgl.ai/dataset/ACM.mat'
        data_file_path = '/tmp/ACM.mat'
        urllib.request.urlretrieve(data_url, data_file_path)
        data = sio.loadmat(data_file_path)
        hg = dgl.heterograph({
            ('paper', 'written-by', 'author'): data['PvsA'].nonzero(),
            ('author', 'writing', 'paper'): data['PvsA'].transpose().nonzero(),
            ('paper', 'citing', 'paper'): data['PvsP'].nonzero(),
            ('paper', 'cited', 'paper'): data['PvsP'].transpose().nonzero(),
            ('paper', 'is-about', 'subject'): data['PvsL'].nonzero(),
            ('subject', 'has', 'paper'): data['PvsL'].transpose().nonzero(),
        })
        pvc = data['PvsC'].tocsr()
        p_selected = pvc.tocoo()
        # Generate labels: the conference index of each paper.
        labels = pvc.indices
        hg.nodes['paper'].data['label'] = torch.tensor(labels).long()

        # Generate a random train/val/test split over the labeled papers.
        pid = p_selected.row
        shuffle = np.random.permutation(pid)
        train_idx = torch.tensor(shuffle[0:800]).long()
        val_idx = torch.tensor(shuffle[800:900]).long()
        test_idx = torch.tensor(shuffle[900:]).long()
        num_nodes = hg.number_of_nodes('paper')
        hg.nodes['paper'].data['train_mask'] = get_binary_mask(num_nodes, train_idx)
        hg.nodes['paper'].data['val_mask'] = get_binary_mask(num_nodes, val_idx)
        hg.nodes['paper'].data['test_mask'] = get_binary_mask(num_nodes, test_idx)

        # Assign an integer id to every node type and edge type (used by HGT).
        hg.node_dict = {}
        hg.edge_dict = {}
        for ntype in hg.ntypes:
            hg.node_dict[ntype] = len(hg.node_dict)
        for etype in hg.etypes:
            hg.edge_dict[etype] = len(hg.edge_dict)
        for etype in hg.etypes:
            hg.edges[etype].data['id'] = torch.ones(hg.number_of_edges(etype), dtype=torch.long) * hg.edge_dict[etype]

        # Randomly initialize input features (frozen Xavier-initialized embeddings).
        if random_init_fea:
            for ntype in hg.ntypes:
                emb = nn.Parameter(torch.Tensor(hg.number_of_nodes(ntype), 256), requires_grad=False)
                nn.init.xavier_uniform_(emb)
                hg.nodes[ntype].data['feat'] = emb
            num_features = 256
        else:
            # No features are attached in this case; the caller must provide them.
            num_features = None
        num_classes = labels.max().item() + 1
        return hg, num_classes, num_features
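

# Minimal usage sketch, assuming the DGL-hosted ACM.mat downloads
# successfully and using the attributes set in __init__ above.
if __name__ == '__main__':
    dataset = HeteroData('acm_raw')
    g = dataset.g
    print(g)
    print('num_classes:', dataset.num_classes)
    print('num_features:', dataset.num_features)
    print('metapaths:', dataset.metapaths)
    print('train papers:', int(g.nodes['paper'].data['train_mask'].sum().item()))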