@@ -0,0 +1,22 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
class ConfigGCN:
    """Hyperparameter configuration for GCN training."""
    learning_rate = 0.01
    epochs = 200
    hidden1 = 16          # size of the hidden graph-convolution layer
    dropout = 0.0
    weight_decay = 5e-4   # L2 regularization factor on the first layer's weights
    early_stopping = 10
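

# A minimal usage sketch. The hyperparameters are plain class attributes, read
# directly off an instance (the src.config module path matches the imports in
# the accompanying test file):
#
# >>> from src.config import ConfigGCN
# >>> config = ConfigGCN()
# >>> config.learning_rate, config.hidden1
# (0.01, 16)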
@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import scipy.sparse as sp

import mindspore.dataset as ds


def normalize_adj(adj):
    """Symmetrically normalize an adjacency matrix: D^-1/2 * A * D^-1/2."""
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    # Isolated nodes have zero degree, which makes d_inv_sqrt infinite; zero those out.
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
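

# A small worked example of the normalization above, assuming a hypothetical
# two-node undirected graph with self-loops already added (A + I): the row sums
# are [2, 2], so D^-1/2 = diag(1/sqrt(2)) and every entry becomes 0.5.
#
# >>> a = sp.coo_matrix(np.ones((2, 2), dtype=np.float32))
# >>> normalize_adj(a).todense()
# matrix([[0.5, 0.5],
#         [0.5, 0.5]], dtype=float32)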


def get_adj_features_labels(data_dir):
    """Load a MindSpore GraphData dataset and return the normalized adjacency
    matrix, the node features, and the one-hot labels."""
    g = ds.GraphData(data_dir)
    nodes = g.get_all_nodes(0)
    nodes_list = nodes.tolist()
    row_tensor = g.get_node_feature(nodes_list, [1, 2])
    features = row_tensor[0]
    labels = row_tensor[1]
    nodes_num = labels.shape[0]
    class_num = labels.max() + 1
    labels_onehot = np.eye(nodes_num, class_num)[labels].astype(np.float32)

    neighbor = g.get_all_neighbors(nodes_list, 0)
    node_map = {node_id: index for index, node_id in enumerate(nodes_list)}
    adj = np.zeros([nodes_num, nodes_num], dtype=np.float32)
    for index, value in np.ndenumerate(neighbor):
        # The first column of neighbor holds the node_id itself; the remaining
        # columns hold that node's neighbors, so only entries with index[1] > 0
        # matter. Rows are padded with -1 when a node has fewer neighbors, so
        # negative values are skipped as well.
        if value >= 0 and index[1] > 0:
            adj[node_map[neighbor[index[0], 0]], node_map[value]] = 1
    adj = sp.coo_matrix(adj)
    # Symmetrize the adjacency matrix and add self-loops before normalizing.
    adj = adj + adj.T.multiply(adj.T > adj) + sp.eye(nodes_num)
    nor_adj = normalize_adj(adj)
    nor_adj = np.array(nor_adj.todense())
    return nor_adj, features, labels_onehot


def get_mask(total, begin, end):
    """Return a float32 mask of length `total` with ones on [begin, end)."""
    mask = np.zeros([total]).astype(np.float32)
    mask[begin:end] = 1
    return mask
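

# A short sketch of how these masks are meant to be used (the node counts below
# mirror the Cora split from the accompanying test and are assumptions here):
#
# >>> train_mask = get_mask(2708, 0, 140)          # nodes 0..139 for training
# >>> eval_mask = get_mask(2708, 140, 140 + 500)   # nodes 140..639 for validation
# >>> int(train_mask.sum()), int(eval_mask.sum())
# (140, 500)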
@@ -0,0 +1,163 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

from mindspore import Tensor
from mindspore import nn
from mindspore.common.parameter import ParameterTuple
from mindspore.nn.layer.activation import get_activation
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

from src.metrics import Loss, Accuracy


def glorot(shape):
    """Glorot (Xavier) uniform initializer: U(-r, r) with r = sqrt(6 / (fan_in + fan_out))."""
    init_range = np.sqrt(6.0 / (shape[0] + shape[1]))
    initial = np.random.uniform(-init_range, init_range, shape).astype(np.float32)
    return Tensor(initial)
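

# A quick numeric check of the Glorot bound (the layer sizes are illustrative
# assumptions): for a 1433 -> 16 layer, r = sqrt(6 / (16 + 1433)) ~= 0.0644, so
# every entry of the returned Tensor lies in (-0.0644, 0.0644).
#
# >>> w = glorot([16, 1433])
# >>> w.shape
# (16, 1433)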


class GraphConvolution(nn.Cell):
    """Single graph-convolution layer: activation(adj @ dense(dropout(x))),
    with dropout applied to the input features only when a ratio is given."""

    def __init__(self,
                 feature_in_dim,
                 feature_out_dim,
                 dropout_ratio=None,
                 activation=None):
        super(GraphConvolution, self).__init__()
        self.in_dim = feature_in_dim
        self.out_dim = feature_out_dim
        self.weight_init = glorot([self.out_dim, self.in_dim])
        self.fc = nn.Dense(self.in_dim,
                           self.out_dim,
                           weight_init=self.weight_init,
                           has_bias=False)
        self.dropout_ratio = dropout_ratio
        if self.dropout_ratio is not None:
            self.dropout = nn.Dropout(keep_prob=1 - self.dropout_ratio)
        self.dropout_flag = self.dropout_ratio is not None
        self.activation = get_activation(activation)
        self.activation_flag = self.activation is not None
        self.matmul = P.MatMul()

    def construct(self, adj, input_feature):
        dropout = input_feature
        if self.dropout_flag:
            dropout = self.dropout(dropout)
        fc = self.fc(dropout)
        output_feature = self.matmul(adj, fc)
        if self.activation_flag:
            output_feature = self.activation(output_feature)
        return output_feature
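

# A minimal shape sketch for this layer (adj_np and feat_np are hypothetical
# numpy arrays): with a normalized adjacency of shape [N, N] and input features
# of shape [N, in_dim], the output has shape [N, out_dim].
#
# >>> layer = GraphConvolution(feature_in_dim=1433, feature_out_dim=16, activation="relu")
# >>> out = layer(Tensor(adj_np), Tensor(feat_np))  # adj_np: [N, N], feat_np: [N, 1433]
# >>> out.shape
# (N, 16)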


class GCN(nn.Cell):
    """Two-layer GCN: adj @ relu(adj @ X @ W0) @ W1, where adj is the
    normalized adjacency matrix and X the node-feature matrix."""

    def __init__(self, config, adj, feature, output_dim):
        super(GCN, self).__init__()
        self.adj = Tensor(adj)
        self.feature = Tensor(feature)
        input_dim = feature.shape[1]
        self.layer0 = GraphConvolution(input_dim, config.hidden1, activation="relu", dropout_ratio=config.dropout)
        self.layer1 = GraphConvolution(config.hidden1, output_dim, dropout_ratio=None)

    def construct(self):
        output0 = self.layer0(self.adj, self.feature)
        output1 = self.layer1(self.adj, output0)
        return output1


class LossAccuracyWrapper(nn.Cell):
    """Evaluation wrapper returning both the masked loss and the masked accuracy.
    Weight decay is applied only to the first layer's weights
    (network.trainable_params()[0]), following the original GCN paper."""

    def __init__(self, network, label, mask, weight_decay):
        super(LossAccuracyWrapper, self).__init__()
        self.network = network
        self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])
        self.accuracy = Accuracy(label, mask)

    def construct(self):
        preds = self.network()
        loss = self.loss(preds)
        accuracy = self.accuracy(preds)
        return loss, accuracy


class LossWrapper(nn.Cell):
    """Training wrapper returning only the masked loss."""

    def __init__(self, network, label, mask, weight_decay):
        super(LossWrapper, self).__init__()
        self.network = network
        self.loss = Loss(label, mask, weight_decay, network.trainable_params()[0])

    def construct(self):
        preds = self.network()
        loss = self.loss(preds)
        return loss


class TrainOneStepCell(nn.Cell):
    r"""
    Network training wrapper class.

    Wraps the network with an optimizer. The resulting Cell can be trained without inputs.
    The backward graph is created in the construct function to update the parameters;
    different parallel modes are available for training.

    Args:
        network (Cell): The training network.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.

    Outputs:
        Tensor, a scalar Tensor with shape :math:`()`.

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> loss_net = nn.WithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
        self.sens = sens

    def construct(self):
        weights = self.weights
        loss = self.network()
        # Fill the sensitivity (initial output gradient) with self.sens, then backpropagate.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(sens)
        # F.depend ensures the optimizer update runs before the loss is returned.
        return F.depend(loss, self.optimizer(grads))


class TrainNetWrapper(nn.Cell):
    """Training wrapper that performs one optimization step and reports the
    training accuracy of the updated network."""

    def __init__(self, network, label, mask, config):
        super(TrainNetWrapper, self).__init__(auto_prefix=True)
        self.network = network
        loss_net = LossWrapper(network, label, mask, config.weight_decay)
        optimizer = nn.Adam(loss_net.trainable_params(),
                            learning_rate=config.learning_rate)
        self.loss_train_net = TrainOneStepCell(loss_net, optimizer)
        self.accuracy = Accuracy(label, mask)

    def construct(self):
        loss = self.loss_train_net()
        accuracy = self.accuracy(self.network())
        return loss, accuracy
@@ -0,0 +1,68 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P


class Loss(nn.Cell):
    """Masked softmax cross-entropy loss with L2 weight decay on `param`."""

    def __init__(self, label, mask, weight_decay, param):
        super(Loss, self).__init__()
        self.label = Tensor(label)
        self.mask = Tensor(mask)
        self.loss = P.SoftmaxCrossEntropyWithLogits()
        self.one = Tensor(1.0, mstype.float32)
        self.zero = Tensor(0.0, mstype.float32)
        self.mean = P.ReduceMean()
        self.cast = P.Cast()
        self.l2_loss = P.L2Loss()
        self.reduce_sum = P.ReduceSum()
        self.weight_decay = weight_decay
        self.param = param

    def construct(self, preds):
        param = self.l2_loss(self.param)
        loss = self.weight_decay * param
        preds = self.cast(preds, mstype.float32)
        loss = loss + self.loss(preds, self.label)[0]
        # Normalize the mask so the global mean below equals the mean over masked nodes only.
        mask = self.cast(self.mask, mstype.float32)
        mask_reduce = self.mean(mask)
        mask = mask / mask_reduce
        loss = loss * mask
        loss = self.mean(loss)
        return loss


class Accuracy(nn.Cell):
    """Masked classification accuracy: argmax agreement averaged over masked nodes."""

    def __init__(self, label, mask):
        super(Accuracy, self).__init__()
        self.label = Tensor(label)
        self.mask = Tensor(mask)
        self.equal = P.Equal()
        self.argmax = P.Argmax()
        self.cast = P.Cast()
        self.mean = P.ReduceMean()

    def construct(self, preds):
        preds = self.cast(preds, mstype.float32)
        correct_prediction = self.equal(self.argmax(preds), self.argmax(self.label))
        accuracy_all = self.cast(correct_prediction, mstype.float32)
        # Same mask-normalization trick as in Loss: scaling by 1 / mean(mask)
        # turns the global mean into a mean over the masked nodes only.
        mask = self.cast(self.mask, mstype.float32)
        mask_reduce = self.mean(mask)
        mask = mask / mask_reduce
        accuracy_all *= mask
        return self.mean(accuracy_all)
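

# A numpy sketch of the mask-normalization identity used by both cells above
# (values are illustrative): for any per-node vector v,
# mean(v * mask / mean(mask)) == sum(v * mask) / sum(mask),
# i.e. the global mean becomes a mean over the masked nodes only.
#
# >>> import numpy as np
# >>> v = np.array([1., 2., 3., 4.])
# >>> mask = np.array([1., 1., 0., 0.])
# >>> float(np.mean(v * mask / np.mean(mask)))
# 1.5
# >>> float(np.sum(v * mask) / np.sum(mask))
# 1.5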
@@ -0,0 +1,83 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time

import numpy as np
import pytest

from mindspore import context

from src.config import ConfigGCN
from src.dataset import get_adj_features_labels, get_mask
from src.gcn import GCN, LossAccuracyWrapper, TrainNetWrapper

DATA_DIR = '/home/workspace/mindspore_dataset/cora/cora_mr/cora_mr'
TRAIN_NODE_NUM = 140
EVAL_NODE_NUM = 500
TEST_NODE_NUM = 1000
SEED = 20


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gcn():
    print("test_gcn begin")
    np.random.seed(SEED)
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend", save_graphs=True)
    config = ConfigGCN()
    adj, feature, label = get_adj_features_labels(DATA_DIR)

    nodes_num = label.shape[0]
    train_mask = get_mask(nodes_num, 0, TRAIN_NODE_NUM)
    eval_mask = get_mask(nodes_num, TRAIN_NODE_NUM, TRAIN_NODE_NUM + EVAL_NODE_NUM)
    test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)

    class_num = label.shape[1]
    gcn_net = GCN(config, adj, feature, class_num)
    gcn_net.add_flags_recursive(fp16=True)

    eval_net = LossAccuracyWrapper(gcn_net, label, eval_mask, config.weight_decay)
    test_net = LossAccuracyWrapper(gcn_net, label, test_mask, config.weight_decay)
    train_net = TrainNetWrapper(gcn_net, label, train_mask, config)

    loss_list = []
    for epoch in range(config.epochs):
        t = time.time()

        train_result = train_net()
        train_loss = train_result[0].asnumpy()
        train_accuracy = train_result[1].asnumpy()

        eval_result = eval_net()
        eval_loss = eval_result[0].asnumpy()
        eval_accuracy = eval_result[1].asnumpy()

        loss_list.append(eval_loss)
        print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
              "train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
              "val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))

        # Stop early once the current validation loss exceeds the mean of the
        # previous `early_stopping` validation losses.
        if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping + 1):-1]):
            print("Early stopping...")
            break

    test_result = test_net()
    test_loss = test_result[0].asnumpy()
    test_accuracy = test_result[1].asnumpy()
    print("Test set results:", "loss=", "{:.5f}".format(test_loss),
          "accuracy=", "{:.5f}".format(test_accuracy))
    assert test_accuracy > 0.812
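

# Assuming an Ascend environment with the Cora MindRecord dataset available at
# DATA_DIR, the test can also be run standalone with pytest, e.g.:
#
#     pytest -s test_gcn.py::test_gcn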