| @@ -19,7 +19,6 @@ from mindspore.ops import functional as F | |||
| from mindspore._extends import cell_attr_register | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore._checkparam import Validator | |||
| from mindspore.nn.layer.activation import get_activation | |||
| @@ -72,9 +71,9 @@ class GNNFeatureTransform(nn.Cell): | |||
| bias_init='zeros', | |||
| has_bias=True): | |||
| super(GNNFeatureTransform, self).__init__() | |||
| self.in_channels = Validator.check_positive_int(in_channels) | |||
| self.out_channels = Validator.check_positive_int(out_channels) | |||
| self.has_bias = Validator.check_bool(has_bias) | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| self.has_bias = has_bias | |||
| if isinstance(weight_init, Tensor): | |||
| if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ | |||
| @@ -259,8 +258,8 @@ class AttentionHead(nn.Cell): | |||
| coef_activation=nn.LeakyReLU(), | |||
| activation=nn.ELU()): | |||
| super(AttentionHead, self).__init__() | |||
| self.in_channel = Validator.check_positive_int(in_channel) | |||
| self.out_channel = Validator.check_positive_int(out_channel) | |||
| self.in_channel = in_channel | |||
| self.out_channel = out_channel | |||
| self.in_drop_ratio = in_drop_ratio | |||
| self.in_drop = nn.Dropout(keep_prob=1 - in_drop_ratio) | |||
| self.in_drop_2 = nn.Dropout(keep_prob=1 - in_drop_ratio) | |||
| @@ -284,7 +283,7 @@ class AttentionHead(nn.Cell): | |||
| self.matmul = P.MatMul() | |||
| self.bias_add = P.BiasAdd() | |||
| self.bias = Parameter(initializer('zeros', self.out_channel), name='bias') | |||
| self.residual = Validator.check_bool(residual) | |||
| self.residual = residual | |||
| if self.residual: | |||
| if in_channel != out_channel: | |||
| self.residual_transform_flag = True | |||
| @@ -436,8 +435,6 @@ class GAT(nn.Cell): | |||
| """ | |||
| def __init__(self, | |||
| features, | |||
| biases, | |||
| ftr_dims, | |||
| num_class, | |||
| num_nodes, | |||
| @@ -448,17 +445,15 @@ class GAT(nn.Cell): | |||
| activation=nn.ELU(), | |||
| residual=False): | |||
| super(GAT, self).__init__() | |||
| self.features = Tensor(features) | |||
| self.biases = Tensor(biases) | |||
| self.ftr_dims = Validator.check_positive_int(ftr_dims) | |||
| self.num_class = Validator.check_positive_int(num_class) | |||
| self.num_nodes = Validator.check_positive_int(num_nodes) | |||
| self.ftr_dims = ftr_dims | |||
| self.num_class = num_class | |||
| self.num_nodes = num_nodes | |||
| self.hidden_units = hidden_units | |||
| self.num_heads = num_heads | |||
| self.attn_drop = attn_drop | |||
| self.ftr_drop = ftr_drop | |||
| self.activation = activation | |||
| self.residual = Validator.check_bool(residual) | |||
| self.residual = residual | |||
| self.layers = [] | |||
| # first layer | |||
| self.layers.append(AttentionAggregator( | |||
| @@ -491,9 +486,9 @@ class GAT(nn.Cell): | |||
| output_transform='sum')) | |||
| self.layers = nn.layer.CellList(self.layers) | |||
| def construct(self, training=True): | |||
| input_data = self.features | |||
| bias_mat = self.biases | |||
| def construct(self, feature, biases, training=True): | |||
| input_data = feature | |||
| bias_mat = biases | |||
| for cell in self.layers: | |||
| input_data = cell(input_data, bias_mat, training) | |||
| return input_data/self.num_heads[-1] | |||
| @@ -103,8 +103,8 @@ class LossAccuracyWrapper(nn.Cell): | |||
| self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, self.network.trainable_params()) | |||
| self.acc_func = MaskedAccuracy(num_class, label, mask) | |||
| def construct(self): | |||
| logits = self.network(training=False) | |||
| def construct(self, feature, biases): | |||
| logits = self.network(feature, biases, training=False) | |||
| loss = self.loss_func(logits) | |||
| accuracy = self.acc_func(logits) | |||
| return loss, accuracy | |||
| @@ -120,8 +120,8 @@ class LossNetWrapper(nn.Cell): | |||
| params = list(param for param in self.network.trainable_params() if param.name[-4:] != 'bias') | |||
| self.loss_func = MaskedSoftMaxLoss(num_class, label, mask, l2_coeff, params) | |||
| def construct(self): | |||
| logits = self.network() | |||
| def construct(self, feature, biases): | |||
| logits = self.network(feature, biases) | |||
| loss = self.loss_func(logits) | |||
| return loss | |||
| @@ -145,11 +145,11 @@ class TrainOneStepCell(nn.Cell): | |||
| self.grad = C.GradOperation(get_by_list=True, sens_param=True) | |||
| self.sens = sens | |||
| def construct(self): | |||
| def construct(self, feature, biases): | |||
| weights = self.weights | |||
| loss = self.network() | |||
| loss = self.network(feature, biases) | |||
| sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||
| grads = self.grad(self.network, weights)(sens) | |||
| grads = self.grad(self.network, weights)(feature, biases, sens) | |||
| return F.depend(loss, self.optimizer(grads)) | |||
| @@ -174,7 +174,7 @@ class TrainGAT(nn.Cell): | |||
| self.loss_train_net = TrainOneStepCell(loss_net, optimizer) | |||
| self.accuracy_func = MaskedAccuracy(num_class, label, mask) | |||
| def construct(self): | |||
| loss = self.loss_train_net() | |||
| accuracy = self.accuracy_func(self.network()) | |||
| def construct(self, feature, biases): | |||
| loss = self.loss_train_net(feature, biases) | |||
| accuracy = self.accuracy_func(self.network(feature, biases)) | |||
| return loss, accuracy | |||
| @@ -20,6 +20,7 @@ import numpy as np | |||
| import mindspore.context as context | |||
| from mindspore.train.serialization import save_checkpoint, load_checkpoint | |||
| from mindspore.common import set_seed | |||
| from mindspore import Tensor | |||
| from src.config import GatConfig | |||
| from src.dataset import load_and_process | |||
| @@ -56,9 +57,7 @@ def train(): | |||
| num_nodes = feature.shape[1] | |||
| num_class = y_train.shape[2] | |||
| gat_net = GAT(feature, | |||
| biases, | |||
| feature_size, | |||
| gat_net = GAT(feature_size, | |||
| num_class, | |||
| num_nodes, | |||
| hid_units, | |||
| @@ -67,6 +66,9 @@ def train(): | |||
| ftr_drop=GatConfig.feature_dropout) | |||
| gat_net.add_flags_recursive(fp16=True) | |||
| feature = Tensor(feature) | |||
| biases = Tensor(biases) | |||
| eval_net = LossAccuracyWrapper(gat_net, | |||
| num_class, | |||
| y_val, | |||
| @@ -84,11 +86,11 @@ def train(): | |||
| val_acc_max = 0.0 | |||
| val_loss_min = np.inf | |||
| for _epoch in range(num_epochs): | |||
| train_result = train_net() | |||
| train_result = train_net(feature, biases) | |||
| train_loss = train_result[0].asnumpy() | |||
| train_acc = train_result[1].asnumpy() | |||
| eval_result = eval_net() | |||
| eval_result = eval_net(feature, biases) | |||
| eval_loss = eval_result[0].asnumpy() | |||
| eval_acc = eval_result[1].asnumpy() | |||
| @@ -110,9 +112,7 @@ def train(): | |||
| print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".format(val_loss_min, val_acc_max)) | |||
| print("Early stop model validation loss: {}, accuracy{}".format(val_loss_model, val_acc_model)) | |||
| break | |||
| gat_net_test = GAT(feature, | |||
| biases, | |||
| feature_size, | |||
| gat_net_test = GAT(feature_size, | |||
| num_class, | |||
| num_nodes, | |||
| hid_units, | |||
| @@ -127,7 +127,7 @@ def train(): | |||
| y_test, | |||
| test_mask, | |||
| l2_coeff) | |||
| test_result = test_net() | |||
| test_result = test_net(feature, biases) | |||
| print("Test loss={}, test acc={}".format(test_result[0], test_result[1])) | |||
| @@ -92,15 +92,12 @@ class GCN(nn.Cell): | |||
| output_dim (int): The number of output channels, equal to classes num. | |||
| """ | |||
| def __init__(self, config, adj, feature, output_dim): | |||
| def __init__(self, config, input_dim, output_dim): | |||
| super(GCN, self).__init__() | |||
| self.adj = Tensor(adj) | |||
| self.feature = Tensor(feature) | |||
| input_dim = feature.shape[1] | |||
| self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout) | |||
| self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None) | |||
| def construct(self): | |||
| output0 = self.layer0(self.adj, self.feature) | |||
| output1 = self.layer1(self.adj, output0) | |||
| def construct(self, adj, feature): | |||
| output0 = self.layer0(adj, feature) | |||
| output1 = self.layer1(adj, output0) | |||
| return output1 | |||
| @@ -91,8 +91,8 @@ class LossAccuracyWrapper(nn.Cell): | |||
| self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0]) | |||
| self.accuracy = Accuracy(label, mask) | |||
| def construct(self): | |||
| preds = self.network() | |||
| def construct(self, adj, feature): | |||
| preds = self.network(adj, feature) | |||
| loss = self.loss(preds) | |||
| accuracy = self.accuracy(preds) | |||
| return loss, accuracy | |||
| @@ -114,8 +114,8 @@ class LossWrapper(nn.Cell): | |||
| self.network = network | |||
| self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0]) | |||
| def construct(self): | |||
| preds = self.network() | |||
| def construct(self, adj, feature): | |||
| preds = self.network(adj, feature) | |||
| loss = self.loss(preds) | |||
| return loss | |||
| @@ -154,11 +154,11 @@ class TrainOneStepCell(nn.Cell): | |||
| self.grad = C.GradOperation(get_by_list=True, sens_param=True) | |||
| self.sens = sens | |||
| def construct(self): | |||
| def construct(self, adj, feature): | |||
| weights = self.weights | |||
| loss = self.network() | |||
| loss = self.network(adj, feature) | |||
| sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||
| grads = self.grad(self.network, weights)(sens) | |||
| grads = self.grad(self.network, weights)(adj, feature, sens) | |||
| return F.depend(loss, self.optimizer(grads)) | |||
| @@ -182,7 +182,7 @@ class TrainNetWrapper(nn.Cell): | |||
| self.loss_train_net = TrainOneStepCell(loss_net, optimizer) | |||
| self.accuracy = Accuracy(label, mask) | |||
| def construct(self): | |||
| loss = self.loss_train_net() | |||
| accuracy = self.accuracy(self.network()) | |||
| def construct(self, adj, feature): | |||
| loss = self.loss_train_net(adj, feature) | |||
| accuracy = self.accuracy(self.network(adj, feature)) | |||
| return loss, accuracy | |||
| @@ -26,6 +26,7 @@ from matplotlib import pyplot as plt | |||
| from matplotlib import animation | |||
| from sklearn import manifold | |||
| from mindspore import context | |||
| from mindspore import Tensor | |||
| from mindspore.common import set_seed | |||
| from mindspore.train.serialization import save_checkpoint, load_checkpoint | |||
| @@ -71,9 +72,13 @@ def train(): | |||
| test_mask = get_mask(nodes_num, nodes_num - args_opt.test_nodes_num, nodes_num) | |||
| class_num = label_onehot.shape[1] | |||
| gcn_net = GCN(config, adj, feature, class_num) | |||
| input_dim = feature.shape[1] | |||
| gcn_net = GCN(config, input_dim, class_num) | |||
| gcn_net.add_flags_recursive(fp16=True) | |||
| adj = Tensor(adj) | |||
| feature = Tensor(feature) | |||
| eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay) | |||
| train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config) | |||
| @@ -92,12 +97,12 @@ def train(): | |||
| t = time.time() | |||
| train_net.set_train() | |||
| train_result = train_net() | |||
| train_result = train_net(adj, feature) | |||
| train_loss = train_result[0].asnumpy() | |||
| train_accuracy = train_result[1].asnumpy() | |||
| eval_net.set_train(False) | |||
| eval_result = eval_net() | |||
| eval_result = eval_net(adj, feature) | |||
| eval_loss = eval_result[0].asnumpy() | |||
| eval_accuracy = eval_result[1].asnumpy() | |||
| @@ -115,14 +120,14 @@ def train(): | |||
| print("Early stopping...") | |||
| break | |||
| save_checkpoint(gcn_net, "ckpts/gcn.ckpt") | |||
| gcn_net_test = GCN(config, adj, feature, class_num) | |||
| gcn_net_test = GCN(config, input_dim, class_num) | |||
| load_checkpoint("ckpts/gcn.ckpt", net=gcn_net_test) | |||
| gcn_net_test.add_flags_recursive(fp16=True) | |||
| test_net = LossAccuracyWrapper(gcn_net_test, label_onehot, test_mask, config.weight_decay) | |||
| t_test = time.time() | |||
| test_net.set_train(False) | |||
| test_result = test_net() | |||
| test_result = test_net(adj, feature) | |||
| test_loss = test_result[0].asnumpy() | |||
| test_accuracy = test_result[1].asnumpy() | |||
| print("Test set results:", "loss=", "{:.5f}".format(test_loss), | |||
| @@ -17,6 +17,7 @@ import time | |||
| import pytest | |||
| import numpy as np | |||
| from mindspore import context | |||
| from mindspore import Tensor | |||
| from model_zoo.official.gnn.gcn.src.gcn import GCN | |||
| from model_zoo.official.gnn.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper | |||
| from model_zoo.official.gnn.gcn.src.config import ConfigGCN | |||
| @@ -49,9 +50,13 @@ def test_gcn(): | |||
| test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num) | |||
| class_num = label_onehot.shape[1] | |||
| gcn_net = GCN(config, adj, feature, class_num) | |||
| input_dim = feature.shape[1] | |||
| gcn_net = GCN(config, input_dim, class_num) | |||
| gcn_net.add_flags_recursive(fp16=True) | |||
| adj = Tensor(adj) | |||
| feature = Tensor(feature) | |||
| eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay) | |||
| test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay) | |||
| train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config) | |||
| @@ -61,12 +66,12 @@ def test_gcn(): | |||
| t = time.time() | |||
| train_net.set_train() | |||
| train_result = train_net() | |||
| train_result = train_net(adj, feature) | |||
| train_loss = train_result[0].asnumpy() | |||
| train_accuracy = train_result[1].asnumpy() | |||
| eval_net.set_train(False) | |||
| eval_result = eval_net() | |||
| eval_result = eval_net(adj, feature) | |||
| eval_loss = eval_result[0].asnumpy() | |||
| eval_accuracy = eval_result[1].asnumpy() | |||
| @@ -80,7 +85,7 @@ def test_gcn(): | |||
| break | |||
| test_net.set_train(False) | |||
| test_result = test_net() | |||
| test_result = test_net(adj, feature) | |||
| test_loss = test_result[0].asnumpy() | |||
| test_accuracy = test_result[1].asnumpy() | |||
| print("Test set results:", "loss=", "{:.5f}".format(test_loss), | |||