Browse Source

add gcn training scripts

tags/v0.5.0-beta
chentingting 5 years ago
parent
commit
e801d48906
7 changed files with 396 additions and 0 deletions
  1. +0
    -0
      tests/st/gnn/gcn/__init__.py
  2. +0
    -0
      tests/st/gnn/gcn/src/__init__.py
  3. +22
    -0
      tests/st/gnn/gcn/src/config.py
  4. +60
    -0
      tests/st/gnn/gcn/src/dataset.py
  5. +163
    -0
      tests/st/gnn/gcn/src/gcn.py
  6. +68
    -0
      tests/st/gnn/gcn/src/metrics.py
  7. +83
    -0
      tests/st/gnn/gcn/test_gcn.py

+ 0
- 0
tests/st/gnn/gcn/__init__.py View File


+ 0
- 0
tests/st/gnn/gcn/src/__init__.py View File


+ 22
- 0
tests/st/gnn/gcn/src/config.py View File

@@ -0,0 +1,22 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

class ConfigGCN():
learning_rate = 0.01
epochs = 200
hidden1 = 16
dropout = 0.0
weight_decay = 5e-4
early_stopping = 10

+ 60
- 0
tests/st/gnn/gcn/src/dataset.py View File

@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import scipy.sparse as sp
import mindspore.dataset as ds


def normalize_adj(adj):
rowsum = np.array(adj.sum(1))
d_inv_sqrt = np.power(rowsum, -0.5).flatten()
d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()


def get_adj_features_labels(data_dir):
g = ds.GraphData(data_dir)
nodes = g.get_all_nodes(0)
nodes_list = nodes.tolist()
row_tensor = g.get_node_feature(nodes_list, [1, 2])
features = row_tensor[0]
labels = row_tensor[1]

nodes_num = labels.shape[0]
class_num = labels.max() + 1
labels_onehot = np.eye(nodes_num, class_num)[labels].astype(np.float32)

neighbor = g.get_all_neighbors(nodes_list, 0)
node_map = {node_id: index for index, node_id in enumerate(nodes_list)}
adj = np.zeros([nodes_num, nodes_num], dtype=np.float32)
for index, value in np.ndenumerate(neighbor):
# The first column of neighbor is node_id, second column to last column are neighbors of the first column.
# So we only care index[1] > 1.
# If the node does not have that many neighbors, -1 is padded. So if value < 0, we will not deal with it.
if value >= 0 and index[1] > 0:
adj[node_map[neighbor[index[0], 0]], node_map[value]] = 1
adj = sp.coo_matrix(adj)
adj = adj + adj.T.multiply(adj.T > adj) + sp.eye(nodes_num)
nor_adj = normalize_adj(adj)
nor_adj = np.array(nor_adj.todense())
return nor_adj, features, labels_onehot


def get_mask(total, begin, end):
mask = np.zeros([total]).astype(np.float32)
mask[begin:end] = 1
return mask

+ 163
- 0
tests/st/gnn/gcn/src/gcn.py View File

@@ -0,0 +1,163 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
from mindspore import nn
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.nn.layer.activation import get_activation
from src.metrics import Loss, Accuracy


def glorot(shape):
init_range = np.sqrt(6.0/(shape[0]+shape[1]))
initial = np.random.uniform(-init_range, init_range, shape).astype(np.float32)
return Tensor(initial)


class GraphConvolution(nn.Cell):
def __init__(self,
feature_in_dim,
feature_out_dim,
dropout_ratio=None,
activation=None):
super(GraphConvolution, self).__init__()
self.in_dim = feature_in_dim
self.out_dim = feature_out_dim
self.weight_init = glorot([self.out_dim, self.in_dim])
self.fc = nn.Dense(self.in_dim,
self.out_dim,
weight_init=self.weight_init,
has_bias=False)
self.dropout_ratio = dropout_ratio
if self.dropout_ratio is not None:
self.dropout = nn.Dropout(keep_prob=1-self.dropout_ratio)
self.dropout_flag = self.dropout_ratio is not None
self.activation = get_activation(activation)
self.activation_flag = self.activation is not None
self.matmul = P.MatMul()

def construct(self, adj, input_feature):
dropout = input_feature
if self.dropout_flag:
dropout = self.dropout(dropout)

fc = self.fc(dropout)
output_feature = self.matmul(adj, fc)

if self.activation_flag:
output_feature = self.activation(output_feature)
return output_feature


class GCN(nn.Cell):
def __init__(self, config, adj, feature, output_dim):
super(GCN, self).__init__()
self.adj = Tensor(adj)
self.feature = Tensor(feature)
input_dim = feature.shape[1]
self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)

def construct(self):
output0 = self.layer0(self.adj, self.feature)
output1 = self.layer1(self.adj, output0)
return output1


class LossAccuracyWrapper(nn.Cell):
def __init__(self, network, label, mask, weight_decay):
super(LossAccuracyWrapper, self).__init__()
self.network = network
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
self.accuracy = Accuracy(label, mask)

def construct(self):
preds = self.network()
loss = self.loss(preds)
accuracy = self.accuracy(preds)
return loss, accuracy


class LossWrapper(nn.Cell):
def __init__(self, network, label, mask, weight_decay):
super(LossWrapper, self).__init__()
self.network = network
self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])

def construct(self):
preds = self.network()
loss = self.loss(preds)
return loss


class TrainOneStepCell(nn.Cell):
r"""
Network training package class.

Wraps the network with an optimizer. The resulting Cell be trained without inputs.
Backward graph will be created in the construct function to do parameter updating. Different
parallel modes are available to run the training.

Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.

Outputs:
Tensor, a scalar Tensor with shape :math:`()`.

Examples:
>>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> loss_net = nn.WithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
"""

def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.sens = sens

def construct(self):
weights = self.weights
loss = self.network()
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(sens)
return F.depend(loss, self.optimizer(grads))


class TrainNetWrapper(nn.Cell):
def __init__(self, network, label, mask, config):
super(TrainNetWrapper, self).__init__(auto_prefix=True)
self.network = network
loss_net = LossWrapper(network, label, mask, config.weight_decay)
optimizer = nn.Adam(loss_net.trainable_params(),
learning_rate=config.learning_rate)
self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
self.accuracy = Accuracy(label, mask)

def construct(self):
loss = self.loss_train_net()
accuracy = self.accuracy(self.network())
return loss, accuracy

+ 68
- 0
tests/st/gnn/gcn/src/metrics.py View File

@@ -0,0 +1,68 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore import nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P


class Loss(nn.Cell):
def __init__(self, label, mask, weight_decay, param):
super(Loss, self).__init__()
self.label = Tensor(label)
self.mask = Tensor(mask)
self.loss = P.SoftmaxCrossEntropyWithLogits()
self.one = Tensor(1.0, mstype.float32)
self.zero = Tensor(0.0, mstype.float32)
self.mean = P.ReduceMean()
self.cast = P.Cast()
self.l2_loss = P.L2Loss()
self.reduce_sum = P.ReduceSum()
self.weight_decay = weight_decay
self.param = param

def construct(self, preds):
param = self.l2_loss(self.param)
loss = self.weight_decay * param
preds = self.cast(preds, mstype.float32)
loss = loss + self.loss(preds, self.label)[0]
mask = self.cast(self.mask, mstype.float32)
mask_reduce = self.mean(mask)
mask = mask / mask_reduce
loss = loss * mask
loss = self.mean(loss)
return loss


class Accuracy(nn.Cell):
def __init__(self, label, mask):
super(Accuracy, self).__init__()
self.label = Tensor(label)
self.mask = Tensor(mask)
self.equal = P.Equal()
self.argmax = P.Argmax()
self.cast = P.Cast()
self.mean = P.ReduceMean()

def construct(self, preds):
preds = self.cast(preds, mstype.float32)
correct_prediction = self.equal(self.argmax(preds), self.argmax(self.label))
accuracy_all = self.cast(correct_prediction, mstype.float32)
mask = self.cast(self.mask, mstype.float32)
mask_reduce = self.mean(mask)
mask = mask / mask_reduce
accuracy_all *= mask
return self.mean(accuracy_all)

+ 83
- 0
tests/st/gnn/gcn/test_gcn.py View File

@@ -0,0 +1,83 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import time
import pytest
import numpy as np
from mindspore import context
from src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper
from src.config import ConfigGCN
from src.dataset import get_adj_features_labels, get_mask


DATA_DIR = '/home/workspace/mindspore_dataset/cora/cora_mr/cora_mr'
TRAIN_NODE_NUM = 140
EVAL_NODE_NUM = 500
TEST_NODE_NUM = 1000
SEED = 20


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gcn():
print("test_gcn begin")
np.random.seed(SEED)
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", save_graphs=True)
config = ConfigGCN()
adj, feature, label = get_adj_features_labels(DATA_DIR)

nodes_num = label.shape[0]
train_mask = get_mask(nodes_num, 0, TRAIN_NODE_NUM)
eval_mask = get_mask(nodes_num, TRAIN_NODE_NUM, TRAIN_NODE_NUM + EVAL_NODE_NUM)
test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)

class_num = label.shape[1]
gcn_net = GCN(config, adj, feature, class_num)
gcn_net.add_flags_recursive(fp16=True)

eval_net = LossAccuracyWrapper(gcn_net, label, eval_mask, config.weight_decay)
test_net = LossAccuracyWrapper(gcn_net, label, test_mask, config.weight_decay)
train_net = TrainNetWrapper(gcn_net, label, train_mask, config)

loss_list = []
for epoch in range(config.epochs):
t = time.time()

train_result = train_net()
train_loss = train_result[0].asnumpy()
train_accuracy = train_result[1].asnumpy()

eval_result = eval_net()
eval_loss = eval_result[0].asnumpy()
eval_accuracy = eval_result[1].asnumpy()

loss_list.append(eval_loss)
print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
"train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
"val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))

if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
print("Early stopping...")
break

test_result = test_net()
test_loss = test_result[0].asnumpy()
test_accuracy = test_result[1].asnumpy()
print("Test set results:", "loss=", "{:.5f}".format(test_loss),
"accuracy=", "{:.5f}".format(test_accuracy))
assert test_accuracy > 0.812

Loading…
Cancel
Save