You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_gcn.py 3.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import time
  16. import pytest
  17. import numpy as np
  18. from mindspore import context
  19. from model_zoo.gcn.src.gcn import GCN
  20. from model_zoo.gcn.src.metrics import LossAccuracyWrapper, TrainNetWrapper
  21. from model_zoo.gcn.src.config import ConfigGCN
  22. from model_zoo.gcn.src.dataset import get_adj_features_labels, get_mask
  23. DATA_DIR = '/home/workspace/mindspore_dataset/cora/cora_mr/cora_mr'
  24. TRAIN_NODE_NUM = 140
  25. EVAL_NODE_NUM = 500
  26. TEST_NODE_NUM = 1000
  27. SEED = 20
  28. @pytest.mark.level0
  29. @pytest.mark.platform_arm_ascend_training
  30. @pytest.mark.platform_x86_ascend_training
  31. @pytest.mark.env_onecard
  32. def test_gcn():
  33. print("test_gcn begin")
  34. np.random.seed(SEED)
  35. context.set_context(mode=context.GRAPH_MODE,
  36. device_target="Ascend", save_graphs=False)
  37. config = ConfigGCN()
  38. config.dropout = 0.0
  39. adj, feature, label_onehot, _ = get_adj_features_labels(DATA_DIR)
  40. nodes_num = label_onehot.shape[0]
  41. train_mask = get_mask(nodes_num, 0, TRAIN_NODE_NUM)
  42. eval_mask = get_mask(nodes_num, TRAIN_NODE_NUM, TRAIN_NODE_NUM + EVAL_NODE_NUM)
  43. test_mask = get_mask(nodes_num, nodes_num - TEST_NODE_NUM, nodes_num)
  44. class_num = label_onehot.shape[1]
  45. gcn_net = GCN(config, adj, feature, class_num)
  46. gcn_net.add_flags_recursive(fp16=True)
  47. eval_net = LossAccuracyWrapper(gcn_net, label_onehot, eval_mask, config.weight_decay)
  48. test_net = LossAccuracyWrapper(gcn_net, label_onehot, test_mask, config.weight_decay)
  49. train_net = TrainNetWrapper(gcn_net, label_onehot, train_mask, config)
  50. loss_list = []
  51. for epoch in range(config.epochs):
  52. t = time.time()
  53. train_net.set_train()
  54. train_result = train_net()
  55. train_loss = train_result[0].asnumpy()
  56. train_accuracy = train_result[1].asnumpy()
  57. eval_net.set_train(False)
  58. eval_result = eval_net()
  59. eval_loss = eval_result[0].asnumpy()
  60. eval_accuracy = eval_result[1].asnumpy()
  61. loss_list.append(eval_loss)
  62. print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
  63. "train_acc=", "{:.5f}".format(train_accuracy), "val_loss=", "{:.5f}".format(eval_loss),
  64. "val_acc=", "{:.5f}".format(eval_accuracy), "time=", "{:.5f}".format(time.time() - t))
  65. if epoch > config.early_stopping and loss_list[-1] > np.mean(loss_list[-(config.early_stopping+1):-1]):
  66. print("Early stopping...")
  67. break
  68. test_net.set_train(False)
  69. test_result = test_net()
  70. test_loss = test_result[0].asnumpy()
  71. test_accuracy = test_result[1].asnumpy()
  72. print("Test set results:", "loss=", "{:.5f}".format(test_loss),
  73. "accuracy=", "{:.5f}".format(test_accuracy))
  74. assert test_accuracy > 0.812