You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 2.5 kB

  import os
  import torch
  import torch.nn as nn
  from torch.autograd import Variable
  from . import dataset as dst
  from .model import CNN_text

  # Hyper parameters
  batch_size = 50
  learning_rate = 0.0001  # initial learning rate handed to the optimizer
  num_epochs = 20
  # Use the GPU only when one is actually present; a hard-coded True made
  # every later .cuda() call fail on CPU-only machines.
  cuda = torch.cuda.is_available()
  # Split the dataset: the leading 90% of items for training, the rest for
  # testing. NOTE(review): assumes MRDataset supports slicing and is already
  # shuffled, otherwise the split may be biased - TODO confirm.
  dataset = dst.MRDataset()
  length = len(dataset)
  split_at = int(0.9 * length)
  train_part = dataset[:split_at]
  test_part = dataset[split_at:]
  train_dataset = dst.train_set(train_part)
  test_dataset = dst.test_set(test_part)

  # Data loaders: training batches are reshuffled every epoch, while test
  # batches keep a fixed order.
  train_loader = torch.utils.data.DataLoader(
      dataset=train_dataset, batch_size=batch_size, shuffle=True)
  test_loader = torch.utils.data.DataLoader(
      dataset=test_dataset, batch_size=batch_size, shuffle=False)
  # Model: CNN_text receives the vocabulary size and the dataset's word
  # embeddings (presumably used to initialise its embedding table - confirm
  # in model.py).
  vocab = dataset.word2id()
  cnn = CNN_text(embed_num=len(vocab),
                 pretrained_embeddings=dataset.word_embeddings())
  if cuda:
      cnn.cuda()

  # Loss and optimizer: cross-entropy over class logits, optimised with Adam.
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(params=cnn.parameters(), lr=learning_rate)
  # Train and test.
  # Each epoch trains on train_loader and then measures accuracy on
  # test_loader. The weights with the best test accuracy so far are saved to
  # models/cnn.pkl; an epoch that does not improve on it multiplies the
  # learning rate by 0.8.
  best_acc = None
  for epoch in range(num_epochs):
      # --- Train the model for one epoch ---
      cnn.train()
      for i, (sents, labels) in enumerate(train_loader):
          # Tensors carry autograd state themselves since PyTorch 0.4, so the
          # former Variable(...) wrappers are unnecessary.
          if cuda:
              sents = sents.cuda()
              labels = labels.cuda()
          optimizer.zero_grad()
          outputs = cnn(sents)
          loss = criterion(outputs, labels)
          loss.backward()
          optimizer.step()
          if (i + 1) % 100 == 0:
              # len(train_loader) also counts the final partial batch, and
              # loss.item() replaces loss.data[0], which raises on the 0-dim
              # loss tensor in PyTorch >= 0.4.
              print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
                    % (epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))

      # --- Test the model ---
      cnn.eval()
      correct = 0
      total = 0
      # Evaluation needs no gradients; no_grad() skips building the autograd
      # graph and the memory it would hold.
      with torch.no_grad():
          for sents, labels in test_loader:
              if cuda:
                  sents = sents.cuda()
                  labels = labels.cuda()
              outputs = cnn(sents)
              _, predicted = torch.max(outputs, 1)
              total += labels.size(0)
              # .item() keeps `correct` a Python int, so `acc` is a plain float.
              correct += (predicted == labels).sum().item()
      acc = 100. * correct / total
      print('Test Accuracy: %f %%' % (acc))

      if best_acc is None or acc > best_acc:
          best_acc = acc
          if not os.path.exists("models"):
              os.makedirs("models")
          torch.save(cnn.state_dict(), 'models/cnn.pkl')
      else:
          # Rebinding `learning_rate` alone has no effect on training: the
          # optimizer keeps its own lr in each param group, so push the
          # decayed value into them.
          learning_rate = learning_rate * 0.8
          for param_group in optimizer.param_groups:
              param_group['lr'] = learning_rate
  print("Best Accuracy: %f %%" % best_acc)
  print("Best Model: models/cnn.pkl")

一款轻量级的自然语言处理（NLP）工具包，目标是减少用户项目中的工程型代码，例如数据处理循环、训练循环、多卡运行等。