You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_tutorials.py 6.6 kB

7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import unittest
  2. from fastNLP import DataSet
  3. from fastNLP import Instance
  4. from fastNLP import Vocabulary
  5. from fastNLP.core.losses import CrossEntropyLoss
  6. from fastNLP.core.metrics import AccuracyMetric
class TestTutorial(unittest.TestCase):
    """Smoke tests that walk through the fastNLP tutorial notebooks end to end.

    Each test mirrors one tutorial: load a sample CSV, preprocess it with
    DataSet/Vocabulary, then train and evaluate a small CNNText model.
    NOTE(review): these tests assume they run from the repository root so the
    relative ``test/data_for_tests/...`` path resolves — confirm CI working dir.
    """

    def test_fastnlp_10min_tutorial(self):
        """Replicates the 10-minute tutorial: full preprocess/train/test cycle."""
        # Read CSV data into a DataSet (tab-separated: sentence, label).
        sample_path = "test/data_for_tests/tutorial_sample_dataset.csv"
        dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'),
                                   sep='\t')
        print(len(dataset))
        print(dataset[0])
        print(dataset[-3])

        dataset.append(Instance(raw_sentence='fake data', label='0'))
        # Lowercase every sentence (original comment said "digits"; the code
        # lowercases letters — digits are unaffected by str.lower()).
        dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')
        # Convert the label field from str to int.
        dataset.apply(lambda x: int(x['label']), new_field_name='label')

        # Tokenize by splitting sentences on whitespace.
        def split_sent(ins):
            return ins['raw_sentence'].split()

        dataset.apply(split_sent, new_field_name='words')
        # Add a sequence-length field derived from the token list.
        dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')
        print(len(dataset))
        print(dataset[0])

        # DataSet.drop(func) removes instances matching the predicate;
        # here: discard sentences of 3 tokens or fewer.
        dataset.drop(lambda x: x['seq_len'] <= 3, inplace=True)
        print(len(dataset))

        # Declare which fields are converted to tensors.
        # target: golden labels consumed by the loss and by evaluation.
        dataset.set_target("label")
        # input: fields fed to the model's forward().
        dataset.set_input("words", "seq_len")

        # Split into test and train sets (note order: test first at ratio 0.5).
        test_data, train_data = dataset.split(0.5)
        print(len(test_data))
        print(len(train_data))

        # Build the vocabulary via Vocabulary.add(word); min_freq=2 prunes
        # words seen only once.
        vocab = Vocabulary(min_freq=2)
        train_data.apply(lambda x: [vocab.add(word) for word in x['words']])
        vocab.build_vocab()

        # Index sentences with Vocabulary.to_index(word), replacing the
        # 'words' field in place with integer ids.
        train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')
        test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')
        print(test_data[0])

        # The same preprocessing utilities also suit RL or GAN-style projects:
        # Batch + a sampler yields (batch_x, batch_y) dicts directly.
        from fastNLP.core.batch import Batch
        from fastNLP.core.sampler import RandomSampler

        batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler())
        for batch_x, batch_y in batch_iterator:
            print("batch_x has: ", batch_x)
            print("batch_y has: ", batch_y)
            break  # one batch is enough to demonstrate iteration

        from fastNLP.models import CNNText
        model = CNNText((len(vocab), 50), num_classes=5, padding=2, dropout=0.1)

        from fastNLP import Trainer
        from copy import deepcopy

        # Rename fields so they match the parameter names expected by the
        # model's forward() / the loss and metric constructors below.
        train_data.rename_field('label', 'label_seq')
        test_data.rename_field('label', 'label_seq')
        loss = CrossEntropyLoss(pred="output", target="label_seq")
        metric = AccuracyMetric(pred="predict", target="label_seq")

        # Sanity check: overfit a copy of the model on test_data first to
        # confirm the model implementation can fit at all.
        copy_model = deepcopy(model)
        overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,
                                  loss=loss,
                                  metrics=metric,
                                  save_path=None,
                                  batch_size=32,
                                  n_epochs=5)
        overfit_trainer.train()

        # Real run: train on train_data, validate on test_data.
        trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,
                          loss=CrossEntropyLoss(pred="output", target="label_seq"),
                          metrics=AccuracyMetric(pred="predict", target="label_seq"),
                          save_path=None,
                          batch_size=32,
                          n_epochs=5)
        trainer.train()
        print('Train finished!')

        # Final evaluation on test_data via Tester.
        from fastNLP import Tester
        tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"),
                        batch_size=4)
        acc = tester.test()
        print(acc)

    def test_fastnlp_1min_tutorial(self):
        """Replicates tutorials/fastnlp_1min_tutorial.ipynb: the minimal pipeline."""
        data_path = "test/data_for_tests/tutorial_sample_dataset.csv"
        ds = DataSet.read_csv(data_path, headers=('raw_sentence', 'label'), sep='\t')
        print(ds[1])

        # Lowercase every sentence (see note in the 10-min test: the original
        # comment mentioned digits, but str.lower() affects letters).
        ds.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')
        # Convert label to int; is_target=True marks the field for loss/metrics.
        ds.apply(lambda x: int(x['label']), new_field_name='target', is_target=True)

        # Tokenize on whitespace; is_input=True feeds the field to forward().
        def split_sent(ins):
            return ins['raw_sentence'].split()

        ds.apply(split_sent, new_field_name='words', is_input=True)

        # Train/dev split (note order: train first at ratio 0.3).
        train_data, dev_data = ds.split(0.3)
        print("Train size: ", len(train_data))
        print("Test size: ", len(dev_data))

        from fastNLP import Vocabulary
        vocab = Vocabulary(min_freq=2)
        # Build the vocabulary from the training split only.
        train_data.apply(lambda x: [vocab.add(word) for word in x['words']])

        # Index sentences with Vocabulary.to_index(word).
        train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words',
                         is_input=True)
        dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words',
                       is_input=True)

        from fastNLP.models import CNNText
        model = CNNText((len(vocab), 50), num_classes=5, padding=2, dropout=0.1)

        from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam
        trainer = Trainer(model=model,
                          train_data=train_data,
                          dev_data=dev_data,
                          loss=CrossEntropyLoss(),
                          optimizer=Adam(),
                          metrics=AccuracyMetric(target='target')
                          )
        trainer.train()
        print('Train finished!')

    def setUp(self):
        # Remember the working directory; training utilities may chdir.
        import os
        self._init_wd = os.path.abspath(os.curdir)

    def tearDown(self):
        # Restore the working directory so later tests see consistent paths.
        import os
        os.chdir(self._init_wd)