
test_gcn.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time

import pytest
import numpy as np

from mindspore import context
from mindspore import Tensor

from model_zoo.official.gnn.gcn.src.gcn import GCN
from model_zoo.official.gnn.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper
from model_zoo.official.gnn.gcn.src.config import ConfigGCN
from model_zoo.official.gnn.gcn.src.dataset import get_adj_features_labels, get_mask

DATA_DIR = '/home/workspace/mindspore_dataset/cora/cora_mr/cora_mr'
TRAIN_NODE_NUM = 140
EVAL_NODE_NUM = 500
TEST_NODE_NUM = 1000
SEED = 20


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gcn():
    """Train GCN on the Cora dataset and check the final test accuracy."""
    print("test_gcn begin")
    np.random.seed(SEED)
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend", save_graphs=False)
    config = ConfigGCN()
    config.dropout = 0.0

    # Load the graph data: adjacency matrix, node features and one-hot labels.
    adj, feature, label_onehot, _ = get_adj_features_labels(DATA_DIR)

    # Build train/eval/test node masks over the full node set.
    nodes_num = label_onehot.shape[0]
    train_mask = get_mask(nodes_num, 0, TRAIN_NODE_NUM)
    eval_mask = get_mask(nodes_num, TRAIN_NODE_NUM, TRAIN_NODE_NUM + EVAL_NODE_NUM)
    test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)

    class_num = label_onehot.shape[1]
    input_dim = feature.shape[1]
    gcn_net = GCN(config, input_dim, class_num)
    gcn_net.add_flags_recursive(fp16=True)

    adj = Tensor(adj)
    feature = Tensor(feature)

    eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
    test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay)
    train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)

    loss_list = []
    for epoch in range(config.epochs):
        t = time.time()

        train_net.set_train()
        train_result = train_net(adj, feature)
        train_loss = train_result[0].asnumpy()
        train_accuracy = train_result[1].asnumpy()

        eval_net.set_train(False)
        eval_result = eval_net(adj, feature)
        eval_loss = eval_result[0].asnumpy()
        eval_accuracy = eval_result[1].asnumpy()
        loss_list.append(eval_loss)

        print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
              "train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
              "val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))

        # Early stopping: stop once the validation loss rises above the mean
        # of the last `early_stopping` validation losses.
        if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
            print("Early stopping...")
            break

    test_net.set_train(False)
    test_result = test_net(adj, feature)
    test_loss = test_result[0].asnumpy()
    test_accuracy = test_result[1].asnumpy()
    print("Test set results:", "loss=", "{:.5f}".format(test_loss),
          "accuracy=", "{:.5f}".format(test_accuracy))
    assert test_accuracy > 0.812
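

# The __main__ guard below is an added sketch, not part of the original test:
# it shows one way this test could be invoked directly through pytest's
# Python API. It assumes pytest is installed and that the Ascend device and
# Cora dataset implied by the markers and DATA_DIR above are available.
if __name__ == "__main__":
    import sys
    sys.exit(pytest.main(["-v", __file__ + "::test_gcn"]))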