You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Test train gat"""
  16. import argparse
  17. import os
  18. import numpy as np
  19. import mindspore.context as context
  20. from mindspore.train.serialization import _exec_save_checkpoint, load_checkpoint
  21. from src.config import GatConfig
  22. from src.dataset import load_and_process
  23. from src.gat import GAT
  24. from src.utils import LossAccuracyWrapper, TrainGAT
  25. def train():
  26. """Train GAT model."""
  27. parser = argparse.ArgumentParser()
  28. parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Data dir')
  29. parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
  30. parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
  31. parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
  32. args = parser.parse_args()
  33. if not os.path.exists("ckpts"):
  34. os.mkdir("ckpts")
  35. context.set_context(mode=context.GRAPH_MODE,
  36. device_target="Ascend",
  37. save_graphs=False)
  38. # train parameters
  39. hid_units = GatConfig.hid_units
  40. n_heads = GatConfig.n_heads
  41. early_stopping = GatConfig.early_stopping
  42. lr = GatConfig.lr
  43. l2_coeff = GatConfig.l2_coeff
  44. num_epochs = GatConfig.num_epochs
  45. feature, biases, y_train, train_mask, y_val, eval_mask, y_test, test_mask = load_and_process(args.data_dir,
  46. args.train_nodes_num,
  47. args.eval_nodes_num,
  48. args.test_nodes_num)
  49. feature_size = feature.shape[2]
  50. num_nodes = feature.shape[1]
  51. num_class = y_train.shape[2]
  52. gat_net = GAT(feature,
  53. biases,
  54. feature_size,
  55. num_class,
  56. num_nodes,
  57. hid_units,
  58. n_heads,
  59. attn_drop=GatConfig.attn_dropout,
  60. ftr_drop=GatConfig.feature_dropout)
  61. gat_net.add_flags_recursive(fp16=True)
  62. eval_net = LossAccuracyWrapper(gat_net,
  63. num_class,
  64. y_val,
  65. eval_mask,
  66. l2_coeff)
  67. train_net = TrainGAT(gat_net,
  68. num_class,
  69. y_train,
  70. train_mask,
  71. lr,
  72. l2_coeff)
  73. train_net.set_train(True)
  74. val_acc_max = 0.0
  75. val_loss_min = np.inf
  76. for _epoch in range(num_epochs):
  77. train_result = train_net()
  78. train_loss = train_result[0].asnumpy()
  79. train_acc = train_result[1].asnumpy()
  80. eval_result = eval_net()
  81. eval_loss = eval_result[0].asnumpy()
  82. eval_acc = eval_result[1].asnumpy()
  83. print("Epoch:{}, train loss={:.5f}, train acc={:.5f} | val loss={:.5f}, val acc={:.5f}".format(
  84. _epoch, train_loss, train_acc, eval_loss, eval_acc))
  85. if eval_acc >= val_acc_max or eval_loss < val_loss_min:
  86. if eval_acc >= val_acc_max and eval_loss < val_loss_min:
  87. val_acc_model = eval_acc
  88. val_loss_model = eval_loss
  89. _exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt")
  90. val_acc_max = np.max((val_acc_max, eval_acc))
  91. val_loss_min = np.min((val_loss_min, eval_loss))
  92. curr_step = 0
  93. else:
  94. curr_step += 1
  95. if curr_step == early_stopping:
  96. print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".format(val_loss_min, val_acc_max))
  97. print("Early stop model validation loss: {}, accuracy{}".format(val_loss_model, val_acc_model))
  98. break
  99. gat_net_test = GAT(feature,
  100. biases,
  101. feature_size,
  102. num_class,
  103. num_nodes,
  104. hid_units,
  105. n_heads,
  106. attn_drop=0.0,
  107. ftr_drop=0.0)
  108. load_checkpoint("ckpts/gat.ckpt", net=gat_net_test)
  109. gat_net_test.add_flags_recursive(fp16=True)
  110. test_net = LossAccuracyWrapper(gat_net_test,
  111. num_class,
  112. y_test,
  113. test_mask,
  114. l2_coeff)
  115. test_result = test_net()
  116. print("Test loss={}, test acc={}".format(test_result[0], test_result[1]))
  117. if __name__ == "__main__":
  118. train()