You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

train_eval.py 5.1 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # coding: UTF-8
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from sklearn import metrics
  7. import time
  8. from utils import get_time_dif
  9. from tensorboardX import SummaryWriter
  10. # 权重初始化,默认xavier
  11. def init_network(model, method='xavier', exclude='embedding', seed=123):
  12. for name, w in model.named_parameters():
  13. if exclude not in name:
  14. if 'weight' in name:
  15. if method == 'xavier':
  16. nn.init.xavier_normal_(w)
  17. elif method == 'kaiming':
  18. nn.init.kaiming_normal_(w)
  19. else:
  20. nn.init.normal_(w)
  21. elif 'bias' in name:
  22. nn.init.constant_(w, 0)
  23. else:
  24. pass
  25. def train(config, model, train_iter, dev_iter, test_iter):
  26. start_time = time.time()
  27. model.train()
  28. optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
  29. # optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
  30. # 学习率指数衰减,每次epoch:学习率 = gamma * 学习率
  31. # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
  32. total_batch = 0 # 记录进行到多少batch
  33. dev_best_loss = float('inf')
  34. last_improve = 0 # 记录上次验证集loss下降的batch数
  35. flag = False # 记录是否很久没有效果提升
  36. writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
  37. for epoch in range(config.num_epochs):
  38. print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
  39. # scheduler.step() # 学习率衰减
  40. for i, (trains, labels) in enumerate(train_iter):
  41. outputs = model(trains)
  42. model.zero_grad()
  43. loss = F.cross_entropy(outputs, labels)
  44. loss.backward()
  45. optimizer.step()
  46. if total_batch % 100 == 0:
  47. # 每多少轮输出在训练集和验证集上的效果
  48. true = labels.data.cpu()
  49. predic = torch.max(outputs.data, 1)[1].cpu()
  50. train_acc = metrics.accuracy_score(true, predic)
  51. dev_acc, dev_loss = evaluate(config, model, dev_iter)
  52. if dev_loss < dev_best_loss:
  53. dev_best_loss = dev_loss
  54. best_model = model
  55. torch.save(model.state_dict(), config.save_path)
  56. improve = '*'
  57. last_improve = total_batch
  58. else:
  59. improve = ''
  60. time_dif = get_time_dif(start_time)
  61. msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}'
  62. print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
  63. writer.add_scalar("loss/train", loss.item(), total_batch)
  64. writer.add_scalar("loss/dev", dev_loss, total_batch)
  65. writer.add_scalar("acc/train", train_acc, total_batch)
  66. writer.add_scalar("acc/dev", dev_acc, total_batch)
  67. model.train()
  68. total_batch += 1
  69. if total_batch - last_improve > config.require_improvement:
  70. # 验证集loss超过1000batch没下降,结束训练
  71. print("No optimization for a long time, auto-stopping...")
  72. flag = True
  73. break
  74. if flag:
  75. break
  76. writer.close()
  77. test(config, model, test_iter)
  78. def test(config, model, test_iter):
  79. # test
  80. model.load_state_dict(torch.load(config.save_path))
  81. model.eval()
  82. start_time = time.time()
  83. test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
  84. msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'
  85. print(msg.format(test_loss, test_acc))
  86. print("Precision, Recall and F1-Score...")
  87. print(test_report)
  88. print("Confusion Matrix...")
  89. print(test_confusion)
  90. time_dif = get_time_dif(start_time)
  91. print("Time usage:", time_dif)
  92. def evaluate(config, model, data_iter, test=False):
  93. model.eval()
  94. loss_total = 0
  95. predict_all = np.array([], dtype=int)
  96. labels_all = np.array([], dtype=int)
  97. with torch.no_grad():
  98. for texts, labels in data_iter:
  99. outputs = model(texts)
  100. loss = F.cross_entropy(outputs, labels)
  101. loss_total += loss
  102. labels = labels.data.cpu().numpy()
  103. predic = torch.max(outputs.data, 1)[1].cpu().numpy()
  104. labels_all = np.append(labels_all, labels)
  105. predict_all = np.append(predict_all, predic)
  106. acc = metrics.accuracy_score(labels_all, predict_all)
  107. if test:
  108. report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
  109. confusion = metrics.confusion_matrix(labels_all, predict_all)
  110. return acc, loss_total / len(data_iter), report, confusion
  111. return acc, loss_total / len(data_iter)

No Description