Browse Source

change gnn input position

tags/v1.1.0
zhanke 5 years ago
parent
commit
4384ca214f
7 changed files with 65 additions and 63 deletions
  1. +13
    -18
      model_zoo/official/gnn/gat/src/gat.py
  2. +10
    -10
      model_zoo/official/gnn/gat/src/utils.py
  3. +9
    -9
      model_zoo/official/gnn/gat/train.py
  4. +4
    -7
      model_zoo/official/gnn/gcn/src/gcn.py
  5. +10
    -10
      model_zoo/official/gnn/gcn/src/metrics.py
  6. +10
    -5
      model_zoo/official/gnn/gcn/train.py
  7. +9
    -4
      tests/st/gnn/gcn/test_gcn.py

+ 13
- 18
model_zoo/official/gnn/gat/src/gat.py View File

@@ -19,7 +19,6 @@ from mindspore.ops import functional as F
from mindspore._extends import cell_attr_register
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore._checkparam import Validator
from mindspore.nn.layer.activation import get_activation


@@ -72,9 +71,9 @@ class GNNFeatureTransform(nn.Cell):
bias_init='zeros',
has_bias=True):
super(GNNFeatureTransform, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
self.has_bias = Validator.check_bool(has_bias)
self.in_channels = in_channels
self.out_channels = out_channels
self.has_bias = has_bias

if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
@@ -259,8 +258,8 @@ class AttentionHead(nn.Cell):
coef_activation=nn.LeakyReLU(),
activation=nn.ELU()):
super(AttentionHead, self).__init__()
self.in_channel = Validator.check_positive_int(in_channel)
self.out_channel = Validator.check_positive_int(out_channel)
self.in_channel = in_channel
self.out_channel = out_channel
self.in_drop_ratio = in_drop_ratio
self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio)
self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio)
@@ -284,7 +283,7 @@ class AttentionHead(nn.Cell):
self.matmul = P.MatMul()
self.bias_add = P.BiasAdd()
self.bias = Parameter(initializer('zeros', self.out_channel), name='bias')
self.residual = Validator.check_bool(residual)
self.residual = residual
if self.residual:
if in_channel != out_channel:
self.residual_transform_flag = True
@@ -436,8 +435,6 @@ class GAT(nn.Cell):
"""

def __init__(self,
features,
biases,
ftr_dims,
num_class,
num_nodes,
@@ -448,17 +445,15 @@ class GAT(nn.Cell):
activation=nn.ELU(),
residual=False):
super(GAT, self).__init__()
self.features = Tensor(features)
self.biases = Tensor(biases)
self.ftr_dims = Validator.check_positive_int(ftr_dims)
self.num_class = Validator.check_positive_int(num_class)
self.num_nodes = Validator.check_positive_int(num_nodes)
self.ftr_dims = ftr_dims
self.num_class = num_class
self.num_nodes = num_nodes
self.hidden_units = hidden_units
self.num_heads = num_heads
self.attn_drop = attn_drop
self.ftr_drop = ftr_drop
self.activation = activation
self.residual = Validator.check_bool(residual)
self.residual = residual
self.layers = []
# first layer
self.layers.append(AttentionAggregator(
@@ -491,9 +486,9 @@ class GAT(nn.Cell):
output_transform='sum'))
self.layers = nn.layer.CellList(self.layers)

def construct(self, training=True):
input_data = self.features
bias_mat = self.biases
def construct(self, feature, biases, training=True):
input_data = feature
bias_mat = biases
for cell in self.layers:
input_data = cell(input_data, bias_mat, training)
return input_data/self.num_heads[-1]

+ 10
- 10
model_zoo/official/gnn/gat/src/utils.py View File

@@ -103,8 +103,8 @@ class LossAccuracyWrapper(nn.Cell):
self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, self.network.trainable_params())
self.acc_func = MaskedAccuracy(num_class, label, mask)

def construct(self):
logits = self.network(training=False)
def construct(self, feature, biases):
logits = self.network(feature, biases, training=False)
loss = self.loss_func(logits)
accuracy = self.acc_func(logits)
return loss, accuracy
@@ -120,8 +120,8 @@ class LossNetWrapper(nn.Cell):
params = list(param for param in self.network.trainable_params() if param.name[-4:] != 'bias')
self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, params)

def construct(self):
logits = self.network()
def construct(self, feature, biases):
logits = self.network(feature, biases)
loss = self.loss_func(logits)
return loss

@@ -145,11 +145,11 @@ class TrainOneStepCell(nn.Cell):
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens

def construct(self):
def construct(self, feature, biases):
weights = self.weights
loss = self.network()
loss = self.network(feature, biases)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(sens)
grads = self.grad(self.network, weights)(feature, biases, sens)
return F.depend(loss, self.optimizer(grads))


@@ -174,7 +174,7 @@ class TrainGAT(nn.Cell):
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
self.accuracy_func = MaskedAccuracy(num_class, label, mask)

def construct(self):
loss = self.loss_train_net()
accuracy = self.accuracy_func(self.network())
def construct(self, feature, biases):
loss = self.loss_train_net(feature, biases)
accuracy = self.accuracy_func(self.network(feature, biases))
return loss, accuracy

+ 9
- 9
model_zoo/official/gnn/gat/train.py View File

@@ -20,6 +20,7 @@ import numpy as np
import mindspore.context as context
from mindspore.train.serialization import save_checkpoint, load_checkpoint
from mindspore.common import set_seed
from mindspore import Tensor

from src.config import GatConfig
from src.dataset import load_and_process
@@ -56,9 +57,7 @@ def train():
num_nodes = feature.shape[1]
num_class = y_train.shape[2]

gat_net = GAT(feature,
biases,
feature_size,
gat_net = GAT(feature_size,
num_class,
num_nodes,
hid_units,
@@ -67,6 +66,9 @@ def train():
ftr_drop=GatConfig.feature_dropout)
gat_net.add_flags_recursive(fp16=True)

feature = Tensor(feature)
biases = Tensor(biases)

eval_net = LossAccuracyWrapper(gat_net,
num_class,
y_val,
@@ -84,11 +86,11 @@ def train():
val_acc_max = 0.0
val_loss_min = np.inf
for _epoch in range(num_epochs):
train_result = train_net()
train_result = train_net(feature, biases)
train_loss = train_result[0].asnumpy()
train_acc = train_result[1].asnumpy()

eval_result = eval_net()
eval_result = eval_net(feature, biases)
eval_loss = eval_result[0].asnumpy()
eval_acc = eval_result[1].asnumpy()

@@ -110,9 +112,7 @@ def train():
print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".format(val_loss_min, val_acc_max))
print("Early stop model validation loss: {}, accuracy{}".format(val_loss_model, val_acc_model))
break
gat_net_test = GAT(feature,
biases,
feature_size,
gat_net_test = GAT(feature_size,
num_class,
num_nodes,
hid_units,
@@ -127,7 +127,7 @@ def train():
y_test,
test_mask,
l2_coeff)
test_result = test_net()
test_result = test_net(feature, biases)
print("Test loss={}, test acc={}".format(test_result[0], test_result[1]))




+ 4
- 7
model_zoo/official/gnn/gcn/src/gcn.py View File

@@ -92,15 +92,12 @@ class GCN(nn.Cell):
output_dim (int): The number of output channels, equal to classes num.
"""

def __init__(self, config, adj, feature, output_dim):
def __init__(self, config, input_dim, output_dim):
super(GCN, self).__init__()
self.adj = Tensor(adj)
self.feature = Tensor(feature)
input_dim = feature.shape[1]
self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)

def construct(self):
output0 = self.layer0(self.adj, self.feature)
output1 = self.layer1(self.adj, output0)
def construct(self, adj, feature):
output0 = self.layer0(adj, feature)
output1 = self.layer1(adj, output0)
return output1

+ 10
- 10
model_zoo/official/gnn/gcn/src/metrics.py View File

@@ -91,8 +91,8 @@ class LossAccuracyWrapper(nn.Cell):
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
self.accuracy = Accuracy(label, mask)

def construct(self):
preds = self.network()
def construct(self, adj, feature):
preds = self.network(adj, feature)
loss = self.loss(preds)
accuracy = self.accuracy(preds)
return loss, accuracy
@@ -114,8 +114,8 @@ class LossWrapper(nn.Cell):
self.network = network
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])

def construct(self):
preds = self.network()
def construct(self, adj, feature):
preds = self.network(adj, feature)
loss = self.loss(preds)
return loss

@@ -154,11 +154,11 @@ class TrainOneStepCell(nn.Cell):
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens

def construct(self):
def construct(self, adj, feature):
weights = self.weights
loss = self.network()
loss = self.network(adj, feature)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(sens)
grads = self.grad(self.network, weights)(adj, feature, sens)
return F.depend(loss, self.optimizer(grads))


@@ -182,7 +182,7 @@ class TrainNetWrapper(nn.Cell):
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
self.accuracy = Accuracy(label, mask)

def construct(self):
loss = self.loss_train_net()
accuracy = self.accuracy(self.network())
def construct(self, adj, feature):
loss = self.loss_train_net(adj, feature)
accuracy = self.accuracy(self.network(adj, feature))
return loss, accuracy

+ 10
- 5
model_zoo/official/gnn/gcn/train.py View File

@@ -26,6 +26,7 @@ from matplotlib import pyplot as plt
from matplotlib import animation
from sklearn import manifold
from mindspore import context
from mindspore import Tensor
from mindspore.common import set_seed
from mindspore.train.serialization import save_checkpoint, load_checkpoint

@@ -71,9 +72,13 @@ def train():
test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num)

class_num = label_onehot.shape[1]
gcn_net = GCN(config, adj, feature, class_num)
input_dim = feature.shape[1]
gcn_net = GCN(config, input_dim, class_num)
gcn_net.add_flags_recursive(fp16=True)

adj = Tensor(adj)
feature = Tensor(feature)

eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)

@@ -92,12 +97,12 @@ def train():
t = time.time()

train_net.set_train()
train_result = train_net()
train_result = train_net(adj, feature)
train_loss = train_result[0].asnumpy()
train_accuracy = train_result[1].asnumpy()

eval_net.set_train(False)
eval_result = eval_net()
eval_result = eval_net(adj, feature)
eval_loss = eval_result[0].asnumpy()
eval_accuracy = eval_result[1].asnumpy()

@@ -115,14 +120,14 @@ def train():
print("Early stopping...")
break
save_checkpoint(gcn_net, "ckpts/gcn.ckpt")
gcn_net_test = GCN(config, adj, feature, class_num)
gcn_net_test = GCN(config, input_dim, class_num)
load_checkpoint("ckpts/gcn.ckpt", net=gcn_net_test)
gcn_net_test.add_flags_recursive(fp16=True)

test_net = LossAccuracyWrapper(gcn_net_test, label_onehot, test_mask, config.weight_decay)
t_test = time.time()
test_net.set_train(False)
test_result = test_net()
test_result = test_net(adj, feature)
test_loss = test_result[0].asnumpy()
test_accuracy = test_result[1].asnumpy()
print("Test set results:", "loss=", "{:.5f}".format(test_loss),


+ 9
- 4
tests/st/gnn/gcn/test_gcn.py View File

@@ -17,6 +17,7 @@ import time
import pytest
import numpy as np
from mindspore import context
from mindspore import Tensor
from model_zoo.official.gnn.gcn.src.gcn import GCN
from model_zoo.official.gnn.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper
from model_zoo.official.gnn.gcn.src.config import ConfigGCN
@@ -49,9 +50,13 @@ def test_gcn():
test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)

class_num = label_onehot.shape[1]
gcn_net = GCN(config, adj, feature, class_num)
input_dim = feature.shape[1]
gcn_net = GCN(config, input_dim, class_num)
gcn_net.add_flags_recursive(fp16=True)

adj = Tensor(adj)
feature = Tensor(feature)

eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay)
train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
@@ -61,12 +66,12 @@ def test_gcn():
t = time.time()

train_net.set_train()
train_result = train_net()
train_result = train_net(adj, feature)
train_loss = train_result[0].asnumpy()
train_accuracy = train_result[1].asnumpy()

eval_net.set_train(False)
eval_result = eval_net()
eval_result = eval_net(adj, feature)
eval_loss = eval_result[0].asnumpy()
eval_accuracy = eval_result[1].asnumpy()

@@ -80,7 +85,7 @@ def test_gcn():
break

test_net.set_train(False)
test_result = test_net()
test_result = test_net(adj, feature)
test_loss = test_result[0].asnumpy()
test_accuracy = test_result[1].asnumpy()
print("Test set results:", "loss=", "{:.5f}".format(test_loss),


Loading…
Cancel
Save