import dgl
from dgl.data.utils import download, get_download_dir, _get_dgl_url
from scipy import io as sio
import os.path as osp
import sys
import torch
import numpy as np
import urllib.request
import torch.nn as nn


def get_binary_mask(total_size, indices):
    """Return a 1-D byte mask of length ``total_size`` with 1s at ``indices``."""
    mask = torch.zeros(total_size)
    mask[indices] = 1
    return mask.byte()


class BaseHeteroDataset():
    r"""Base container for heterogeneous-graph datasets.

    Attributes
    ----------
    g : dgl.DGLHeteroGraph
        The DGL heterogeneous graph.
    num_classes : int
        Number of classes for the target node type.
    num_features : int
        Dimensionality of the input node features.
    metapaths : List[List[str]]
        Metapaths (sequences of edge-type names) for metapath-based models;
        may stay ``None`` for datasets that do not define them.
    """

    def __init__(self):
        self.num_classes = None
        self.metapaths = None
        self.num_features = None
        self.g = None


class HeteroData(BaseHeteroDataset):
    """Loader for the ACM heterogeneous-graph benchmarks.

    Parameters
    ----------
    name : str
        Which variant to load: ``'acm_raw'`` (metapath-style split used by
        HAN) or ``'acm'`` (HGT-style ACM with optional random features).
    """

    def __init__(self, name, **kwargs):
        # BUG FIX: original code called `super.__init__()` (missing call
        # parentheses on `super`), which raises TypeError at runtime.
        super().__init__()
        self.name = name

        # BUG FIX: both loaders return a 3-tuple
        # (graph, num_classes, num_features); the original unpacked into
        # only two names, raising "too many values to unpack".
        if name == 'acm_raw':
            self.g, self.num_classes, self.num_features = self.load_acm_raw()
        elif name == 'acm':
            self.g, self.num_classes, self.num_features = self.load_hgt_acm(random_init_fea=True)
        else:
            # Fail loudly instead of silently leaving g/num_classes as None.
            raise ValueError("Unknown dataset name: %r" % (name,))

    def load_acm_raw(self):
        """Download (if needed) and build the raw ACM heterograph.

        Returns
        -------
        (dgl.DGLHeteroGraph, int, int)
            The graph with 'feat'/'label'/mask data on 'paper' nodes,
            the number of classes (3), and the feature dimensionality.
        """
        self.metapaths = [['pa', 'ap'], ['pf', 'fp']]
        filename = 'ACM.mat'
        url = 'dataset/' + filename
        data_path = get_download_dir() + '/' + filename
        if osp.exists(data_path):
            print(f'Using existing file (unknown)', file=sys.stderr)
        else:
            download(_get_dgl_url(url), path=data_path)

        data = sio.loadmat(data_path)
        p_vs_l = data['PvsL']  # paper-field adjacency (scipy sparse)
        p_vs_a = data['PvsA']  # paper-author
        p_vs_t = data['PvsT']  # paper-term, bag of words
        p_vs_c = data['PvsC']  # paper-conference; labels are derived from it

        # We assign
        # (1) KDD papers as class 0 (data mining),
        # (2) SIGMOD and VLDB papers as class 1 (database),
        # (3) SIGCOMM and MOBICOMM papers as class 2 (communication)
        conf_ids = [0, 1, 9, 10, 13]
        label_ids = [0, 1, 2, 2, 1]

        # Keep only papers published at the selected conferences.
        p_vs_c_filter = p_vs_c[:, conf_ids]
        p_selected = (p_vs_c_filter.sum(1) != 0).A1.nonzero()[0]
        p_vs_l = p_vs_l[p_selected]
        p_vs_a = p_vs_a[p_selected]
        p_vs_t = p_vs_t[p_selected]
        p_vs_c = p_vs_c[p_selected]

        hg = dgl.heterograph({
            ('paper', 'pa', 'author'): p_vs_a.nonzero(),
            ('author', 'ap', 'paper'): p_vs_a.transpose().nonzero(),
            ('paper', 'pf', 'field'): p_vs_l.nonzero(),
            ('field', 'fp', 'paper'): p_vs_l.transpose().nonzero()
        })

        # Bag-of-words term counts serve as paper features.
        hg.nodes['paper'].data['feat'] = torch.FloatTensor(p_vs_t.toarray())

        pc_p, pc_c = p_vs_c.nonzero()
        labels = np.zeros(len(p_selected), dtype=np.int64)
        for conf_id, label_id in zip(conf_ids, label_ids):
            labels[pc_p[pc_c == conf_id]] = label_id
        hg.nodes['paper'].data['label'] = torch.LongTensor(labels)

        num_classes = 3

        # Per-conference random split: <=0.2 train, (0.2, 0.3] val, rest test.
        float_mask = np.zeros(len(pc_p))
        for conf_id in conf_ids:
            pc_c_mask = (pc_c == conf_id)
            float_mask[pc_c_mask] = np.random.permutation(np.linspace(0, 1, pc_c_mask.sum()))
        train_idx = np.where(float_mask <= 0.2)[0]
        val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0]
        test_idx = np.where(float_mask > 0.3)[0]

        num_nodes = hg.number_of_nodes('paper')
        hg.nodes['paper'].data['train_mask'] = get_binary_mask(num_nodes, train_idx)
        hg.nodes['paper'].data['val_mask'] = get_binary_mask(num_nodes, val_idx)
        hg.nodes['paper'].data['test_mask'] = get_binary_mask(num_nodes, test_idx)

        num_features = hg.nodes['paper'].data['feat'].size(1)

        return hg, num_classes, num_features

    def load_hgt_acm(self, random_init_fea=True):
        """Build the HGT-style ACM heterograph.

        Parameters
        ----------
        random_init_fea : bool
            If True, attach a fixed 256-dim Xavier-initialized feature
            matrix to every node type.

        Returns
        -------
        (dgl.DGLHeteroGraph, int, int)
            The graph, the number of classes, and the feature
            dimensionality (256 when ``random_init_fea`` is True).
        """
        data_url = 'https://data.dgl.ai/dataset/ACM.mat'
        data_file_path = '/tmp/ACM.mat'

        urllib.request.urlretrieve(data_url, data_file_path)
        data = sio.loadmat(data_file_path)

        hg = dgl.heterograph({
            ('paper', 'written-by', 'author'): data['PvsA'].nonzero(),
            ('author', 'writing', 'paper'): data['PvsA'].transpose().nonzero(),
            ('paper', 'citing', 'paper'): data['PvsP'].nonzero(),
            ('paper', 'cited', 'paper'): data['PvsP'].transpose().nonzero(),
            ('paper', 'is-about', 'subject'): data['PvsL'].nonzero(),
            ('subject', 'has', 'paper'): data['PvsL'].transpose().nonzero(),
        })

        pvc = data['PvsC'].tocsr()
        p_selected = pvc.tocoo()
        # Generate labels. NOTE(review): `pvc.indices` gives the conference
        # column index per paper row — assumes each paper has exactly one
        # conference entry; verify against the .mat schema.
        labels = pvc.indices
        hg.nodes['paper'].data['label'] = torch.tensor(labels).long()

        # Generate train/val/test split (800 / 100 / rest, shuffled).
        pid = p_selected.row
        shuffle = np.random.permutation(pid)
        train_idx = torch.tensor(shuffle[0:800]).long()
        val_idx = torch.tensor(shuffle[800:900]).long()
        test_idx = torch.tensor(shuffle[900:]).long()
        num_nodes = hg.number_of_nodes('paper')
        hg.nodes['paper'].data['train_mask'] = get_binary_mask(num_nodes, train_idx)
        hg.nodes['paper'].data['val_mask'] = get_binary_mask(num_nodes, val_idx)
        hg.nodes['paper'].data['test_mask'] = get_binary_mask(num_nodes, test_idx)

        # Integer ids per node/edge type, as used by the DGL HGT example.
        hg.node_dict = {}
        hg.edge_dict = {}
        for ntype in hg.ntypes:
            hg.node_dict[ntype] = len(hg.node_dict)
        for etype in hg.etypes:
            hg.edge_dict[etype] = len(hg.edge_dict)

        for etype in hg.etypes:
            # BUG FIX: original used `len(edge_dict)` — `edge_dict` is not a
            # local name (NameError), and even `len(hg.edge_dict)` would tag
            # every edge with the same id. Per the DGL HGT example each edge
            # carries its own edge-type id, `hg.edge_dict[etype]`.
            hg.edges[etype].data['id'] = \
                torch.ones(hg.number_of_edges(etype), dtype=torch.long) * hg.edge_dict[etype]

        # Randomly initialize input features (frozen, Xavier-uniform).
        if random_init_fea:
            for ntype in hg.ntypes:
                emb = nn.Parameter(torch.Tensor(hg.number_of_nodes(ntype), 256), requires_grad=False)
                nn.init.xavier_uniform_(emb)
                hg.nodes[ntype].data['feat'] = emb

        num_features = 256
        num_classes = labels.max().item() + 1

        return hg, num_classes, num_features