You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 5.5 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Test train gat"""
  16. import argparse
  17. import os
  18. import numpy as np
  19. import mindspore.context as context
  20. from mindspore.train.serialization import save_checkpoint, load_checkpoint
  21. from mindspore.common import set_seed
  22. from src.config import GatConfig
  23. from src.dataset import load_and_process
  24. from src.gat import GAT
  25. from src.utils import LossAccuracyWrapper, TrainGAT
  26. set_seed(0)
  27. def train():
  28. """Train GAT model."""
  29. parser = argparse.ArgumentParser()
  30. parser.add_argument('--data_dir', type=str, default='./data/cora/cora_mr', help='Data dir')
  31. parser.add_argument('--train_nodes_num', type=int, default=140, help='Nodes numbers for training')
  32. parser.add_argument('--eval_nodes_num', type=int, default=500, help='Nodes numbers for evaluation')
  33. parser.add_argument('--test_nodes_num', type=int, default=1000, help='Nodes numbers for test')
  34. args = parser.parse_args()
  35. if not os.path.exists("ckpts"):
  36. os.mkdir("ckpts")
  37. context.set_context(mode=context.GRAPH_MODE,
  38. device_target="Ascend",
  39. save_graphs=False)
  40. # train parameters
  41. hid_units = GatConfig.hid_units
  42. n_heads = GatConfig.n_heads
  43. early_stopping = GatConfig.early_stopping
  44. lr = GatConfig.lr
  45. l2_coeff = GatConfig.l2_coeff
  46. num_epochs = GatConfig.num_epochs
  47. feature, biases, y_train, train_mask, y_val, eval_mask, y_test, test_mask = load_and_process(args.data_dir,
  48. args.train_nodes_num,
  49. args.eval_nodes_num,
  50. args.test_nodes_num)
  51. feature_size = feature.shape[2]
  52. num_nodes = feature.shape[1]
  53. num_class = y_train.shape[2]
  54. gat_net = GAT(feature,
  55. biases,
  56. feature_size,
  57. num_class,
  58. num_nodes,
  59. hid_units,
  60. n_heads,
  61. attn_drop=GatConfig.attn_dropout,
  62. ftr_drop=GatConfig.feature_dropout)
  63. gat_net.add_flags_recursive(fp16=True)
  64. eval_net = LossAccuracyWrapper(gat_net,
  65. num_class,
  66. y_val,
  67. eval_mask,
  68. l2_coeff)
  69. train_net = TrainGAT(gat_net,
  70. num_class,
  71. y_train,
  72. train_mask,
  73. lr,
  74. l2_coeff)
  75. train_net.set_train(True)
  76. val_acc_max = 0.0
  77. val_loss_min = np.inf
  78. for _epoch in range(num_epochs):
  79. train_result = train_net()
  80. train_loss = train_result[0].asnumpy()
  81. train_acc = train_result[1].asnumpy()
  82. eval_result = eval_net()
  83. eval_loss = eval_result[0].asnumpy()
  84. eval_acc = eval_result[1].asnumpy()
  85. print("Epoch:{}, train loss={:.5f}, train acc={:.5f} | val loss={:.5f}, val acc={:.5f}".format(
  86. _epoch, train_loss, train_acc, eval_loss, eval_acc))
  87. if eval_acc >= val_acc_max or eval_loss < val_loss_min:
  88. if eval_acc >= val_acc_max and eval_loss < val_loss_min:
  89. val_acc_model = eval_acc
  90. val_loss_model = eval_loss
  91. if os.path.exists("ckpts/gat.ckpt"):
  92. os.remove("ckpts/gat.ckpt")
  93. save_checkpoint(train_net.network, "ckpts/gat.ckpt")
  94. val_acc_max = np.max((val_acc_max, eval_acc))
  95. val_loss_min = np.min((val_loss_min, eval_loss))
  96. curr_step = 0
  97. else:
  98. curr_step += 1
  99. if curr_step == early_stopping:
  100. print("Early Stop Triggered!, Min loss: {}, Max accuracy: {}".format(val_loss_min, val_acc_max))
  101. print("Early stop model validation loss: {}, accuracy{}".format(val_loss_model, val_acc_model))
  102. break
  103. gat_net_test = GAT(feature,
  104. biases,
  105. feature_size,
  106. num_class,
  107. num_nodes,
  108. hid_units,
  109. n_heads,
  110. attn_drop=0.0,
  111. ftr_drop=0.0)
  112. load_checkpoint("ckpts/gat.ckpt", net=gat_net_test)
  113. gat_net_test.add_flags_recursive(fp16=True)
  114. test_net = LossAccuracyWrapper(gat_net_test,
  115. num_class,
  116. y_test,
  117. test_mask,
  118. l2_coeff)
  119. test_result = test_net()
  120. print("Test loss={}, test acc={}".format(test_result[0], test_result[1]))
  121. if __name__ == "__main__":
  122. train()