You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 2.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import os
  2. import
  3. import
  4. import torch
  5. import torch.nn as nn
  6. .dataset as dst
  7. from .model import CNN_text
  8. from torch.autograd import Variable
  # Hyper-parameters for the training run.
  batch_size = 50
  learning_rate = 0.0001
  num_epochs = 20
  # Use the GPU only when one is actually available, so the script also runs
  # on CPU-only machines instead of failing at cnn.cuda().
  cuda = torch.cuda.is_available()

  # Split the dataset: the first 90% of the items are used for training and
  # the remaining 10% for testing. The split index is computed once so both
  # slices are guaranteed to share the same boundary.
  dataset = dst.MRDataset()
  length = len(dataset)
  split_index = int(0.9 * length)
  train_split = dataset[:split_index]
  test_split = dataset[split_index:]
  train_dataset = dst.train_set(train_split)
  test_dataset = dst.test_set(test_split)

  # Data loaders: training batches are shuffled, test batches keep their order.
  train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                             batch_size=batch_size,
                                             shuffle=True)
  test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                            batch_size=batch_size,
                                            shuffle=False)

  # Model: the embedding table size and the pretrained embeddings both come
  # from the full dataset object.
  cnn = CNN_text(embed_num=len(dataset.word2id()),
                 pretrained_embeddings=dataset.word_embeddings())
  if cuda:
      cnn.cuda()

  # Loss and optimizer.
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
  # Train for num_epochs epochs and evaluate on the test loader after each one.
  # Whenever the test accuracy improves on the best seen so far, the model
  # weights are saved to models/cnn.pkl; otherwise the learning rate is
  # multiplied by 0.8.
  best_acc = None
  for epoch in range(num_epochs):
      # Train the model
      cnn.train()
      for i, (sents, labels) in enumerate(train_loader):
          # Tensors are used directly: the Variable wrapper has been a no-op
          # since PyTorch 0.4.
          if cuda:
              sents = sents.cuda()
              labels = labels.cuda()

          optimizer.zero_grad()
          outputs = cnn(sents)
          loss = criterion(outputs, labels)
          loss.backward()
          optimizer.step()

          if (i + 1) % 100 == 0:
              # loss.item() replaces loss.data[0], which raises on the 0-dim
              # loss tensor returned by current PyTorch versions.
              print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
                    % (epoch + 1, num_epochs, i + 1,
                       len(train_dataset) // batch_size, loss.item()))

      # Test the model
      cnn.eval()
      correct = 0
      total = 0
      # No gradients are needed for evaluation; skipping them saves memory.
      with torch.no_grad():
          for sents, labels in test_loader:
              if cuda:
                  sents = sents.cuda()
                  labels = labels.cuda()
              outputs = cnn(sents)
              _, predicted = torch.max(outputs, 1)
              total += labels.size(0)
              # .item() keeps `correct` a plain Python int so `acc` is a float.
              correct += (predicted == labels).sum().item()
      acc = 100. * correct / total
      print('Test Accuracy: %f %%' % (acc))

      if best_acc is None or acc > best_acc:
          best_acc = acc
          os.makedirs("models", exist_ok=True)
          torch.save(cnn.state_dict(), 'models/cnn.pkl')
      else:
          # Rebinding learning_rate alone does not affect the optimizer; the
          # decayed value must be written into its parameter groups.
          learning_rate = learning_rate * 0.8
          for param_group in optimizer.param_groups:
              param_group['lr'] = learning_rate

  print("Best Accuracy: %f %%" % best_acc)
  print("Best Model: models/cnn.pkl")

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等