Browse Source

hetero_data

tags/v0.3.1
Beini Frozenmad 4 years ago
parent
commit
0979b03cc3
1 changed files with 171 additions and 0 deletions
  1. +171
    -0
      autogl/module/model/dgl/hetero/hetero_data.py

+ 171
- 0
autogl/module/model/dgl/hetero/hetero_data.py View File

@@ -0,0 +1,171 @@
import dgl
from dgl.data.utils import download, get_download_dir, _get_dgl_url
from scipy import io as sio
import os.path as osp
import sys
import torch
import numpy as np
import urllib.request
import torch.nn as nn

def get_binary_mask(total_size, indices):
    """Return a length-``total_size`` byte (uint8) mask with ones at ``indices``."""
    selected = torch.zeros(total_size)
    selected[indices] = 1
    return selected.byte()

class BaseHeteroDataset():
    r"""
    Description
    -----------
    Base container for heterogeneous graph datasets; subclasses populate
    every attribute.

    Attributes
    -------------
    g : dgl.DGLHeteroGraph
        The dgl heterogeneous graph.
    num_classes : int
        Number of classes for target nodes.
    metapaths : List[[List[str]]]
    """

    def __init__(self,):
        # Placeholders only — a concrete loader subclass fills these in.
        self.g = None
        self.num_classes = None
        self.num_features = None
        self.metapaths = None


class HeteroData(BaseHeteroDataset):
    """Loader for the heterogeneous ACM citation datasets.

    Parameters
    ----------
    name : str
        Dataset variant: ``'acm_raw'`` (HAN-style preprocessed ACM) or
        ``'acm'`` (HGT-style ACM with randomly initialized node features).

    Raises
    ------
    ValueError
        If ``name`` is not one of the supported variants.
    """

    def __init__(self, name, **kwargs):
        # BUG FIX: original called ``super.__init__()`` — missing the call
        # parentheses on ``super``, which raises at runtime.
        super().__init__()
        self.name = name

        # BUG FIX: both loaders return (graph, num_classes, num_features),
        # but the original unpacked only two values (ValueError) and never
        # populated ``self.num_features`` declared on the base class.
        if name == 'acm_raw':
            self.g, self.num_classes, self.num_features = self.load_acm_raw()
        elif name == 'acm':
            self.g, self.num_classes, self.num_features = self.load_hgt_acm(random_init_fea=True)
        else:
            raise ValueError(f'Unsupported heterogeneous dataset: {name}')

    def load_acm_raw(self):
        """Download (if needed) and build the HAN-preprocessed ACM graph.

        Returns
        -------
        (dgl.DGLHeteroGraph, int, int)
            The graph, the number of paper classes (3), and the paper
            feature dimensionality (bag-of-words size).
        """
        self.metapaths = [['pa', 'ap'], ['pf', 'fp']]
        filename = 'ACM.mat'
        url = 'dataset/' + filename
        data_path = get_download_dir() + '/' + filename
        if osp.exists(data_path):
            # FIX: the original f-string had no placeholder, so the message
            # never said which file was being reused.
            print(f'Using existing file {filename}', file=sys.stderr)
        else:
            download(_get_dgl_url(url), path=data_path)

        data = sio.loadmat(data_path)
        p_vs_l = data['PvsL']  # paper-field?
        p_vs_a = data['PvsA']  # paper-author
        p_vs_t = data['PvsT']  # paper-term, bag of words
        p_vs_c = data['PvsC']  # paper-conference, labels come from that

        # We assign
        # (1) KDD papers as class 0 (data mining),
        # (2) SIGMOD and VLDB papers as class 1 (database),
        # (3) SIGCOMM and MOBICOMM papers as class 2 (communication)
        conf_ids = [0, 1, 9, 10, 13]
        label_ids = [0, 1, 2, 2, 1]

        # Keep only papers published at the selected conferences.
        p_vs_c_filter = p_vs_c[:, conf_ids]
        p_selected = (p_vs_c_filter.sum(1) != 0).A1.nonzero()[0]
        p_vs_l = p_vs_l[p_selected]
        p_vs_a = p_vs_a[p_selected]
        p_vs_t = p_vs_t[p_selected]
        p_vs_c = p_vs_c[p_selected]

        hg = dgl.heterograph({
            ('paper', 'pa', 'author'): p_vs_a.nonzero(),
            ('author', 'ap', 'paper'): p_vs_a.transpose().nonzero(),
            ('paper', 'pf', 'field'): p_vs_l.nonzero(),
            ('field', 'fp', 'paper'): p_vs_l.transpose().nonzero()
        })

        # Bag-of-words term counts serve as the paper features.
        hg.nodes['paper'].data['feat'] = torch.FloatTensor(p_vs_t.toarray())

        # Map each selected paper to a class label via its conference.
        pc_p, pc_c = p_vs_c.nonzero()
        labels = np.zeros(len(p_selected), dtype=np.int64)
        for conf_id, label_id in zip(conf_ids, label_ids):
            labels[pc_p[pc_c == conf_id]] = label_id
        hg.nodes['paper'].data['label'] = torch.LongTensor(labels)

        num_classes = 3

        # Per-conference random split: ~20% train / 10% val / 70% test.
        float_mask = np.zeros(len(pc_p))
        for conf_id in conf_ids:
            pc_c_mask = (pc_c == conf_id)
            float_mask[pc_c_mask] = np.random.permutation(np.linspace(0, 1, pc_c_mask.sum()))
        train_idx = np.where(float_mask <= 0.2)[0]
        val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0]
        test_idx = np.where(float_mask > 0.3)[0]

        num_nodes = hg.number_of_nodes('paper')
        hg.nodes['paper'].data['train_mask'] = get_binary_mask(num_nodes, train_idx)
        hg.nodes['paper'].data['val_mask'] = get_binary_mask(num_nodes, val_idx)
        hg.nodes['paper'].data['test_mask'] = get_binary_mask(num_nodes, test_idx)

        num_features = hg.nodes['paper'].data['feat'].size(1)

        return hg, num_classes, num_features

    def load_hgt_acm(self, random_init_fea=True):
        """Download and build the HGT-style ACM graph.

        Parameters
        ----------
        random_init_fea : bool
            When True, give every node type a 256-d Xavier-initialized
            (non-trainable) feature matrix.

        Returns
        -------
        (dgl.DGLHeteroGraph, int, int)
            The graph, the number of paper classes, and the feature
            dimensionality (256).
        """
        data_url = 'https://data.dgl.ai/dataset/ACM.mat'
        data_file_path = '/tmp/ACM.mat'

        urllib.request.urlretrieve(data_url, data_file_path)
        data = sio.loadmat(data_file_path)

        hg = dgl.heterograph({
            ('paper', 'written-by', 'author'): data['PvsA'].nonzero(),
            ('author', 'writing', 'paper'): data['PvsA'].transpose().nonzero(),
            ('paper', 'citing', 'paper'): data['PvsP'].nonzero(),
            ('paper', 'cited', 'paper'): data['PvsP'].transpose().nonzero(),
            ('paper', 'is-about', 'subject'): data['PvsL'].nonzero(),
            ('subject', 'has', 'paper'): data['PvsL'].transpose().nonzero(),
        })

        pvc = data['PvsC'].tocsr()
        p_selected = pvc.tocoo()
        # generate labels (conference index of each paper)
        labels = pvc.indices
        hg.nodes['paper'].data['label'] = torch.tensor(labels).long()

        # generate train/val/test split (800 / 100 / rest, shuffled)
        pid = p_selected.row
        shuffle = np.random.permutation(pid)
        train_idx = torch.tensor(shuffle[0:800]).long()
        val_idx = torch.tensor(shuffle[800:900]).long()
        test_idx = torch.tensor(shuffle[900:]).long()
        num_nodes = hg.number_of_nodes('paper')
        hg.nodes['paper'].data['train_mask'] = get_binary_mask(num_nodes, train_idx)
        hg.nodes['paper'].data['val_mask'] = get_binary_mask(num_nodes, val_idx)
        hg.nodes['paper'].data['test_mask'] = get_binary_mask(num_nodes, test_idx)

        # Integer ids for node/edge types, as consumed by HGT-style models.
        hg.node_dict = {}
        hg.edge_dict = {}
        for ntype in hg.ntypes:
            hg.node_dict[ntype] = len(hg.node_dict)
        for etype in hg.etypes:
            hg.edge_dict[etype] = len(hg.edge_dict)

        # BUG FIX: original referenced the undefined name ``edge_dict``
        # (NameError) — and ``len(...)`` would have tagged every edge with
        # the same constant; each edge must carry its own edge-type id.
        for etype in hg.etypes:
            hg.edges[etype].data['id'] = torch.ones(hg.number_of_edges(etype), dtype=torch.long) * hg.edge_dict[etype]

        # Random initialize input feature
        if random_init_fea:
            for ntype in hg.ntypes:
                emb = nn.Parameter(torch.Tensor(hg.number_of_nodes(ntype), 256), requires_grad=False)
                nn.init.xavier_uniform_(emb)
                hg.nodes[ntype].data['feat'] = emb

        num_features = 256
        num_classes = labels.max().item() + 1

        return hg, num_classes, num_features



Loading…
Cancel
Save