You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_CRF.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import unittest
  class TestCRF(unittest.TestCase):
      """Unit tests for the helpers in ``fastNLP.modules.decoder.CRF``."""

      def test_case1(self):
          """Check that ``allowed_transitions()`` returns the expected id pairs.

          Each scenario builds an ``id2label`` mapping and compares the set of
          ``(from_id, to_id)`` pairs returned by ``allowed_transitions()`` with a
          hand-written expectation.

          NOTE(review): in every expected set the ids ``len(id2label)`` and
          ``len(id2label) + 1`` appear only as a source and only as a target
          respectively -- presumably the implicit start/end tags; confirm against
          the ``allowed_transitions`` implementation.
          """
          from fastNLP.modules.decoder.CRF import allowed_transitions

          # Plain BIO tags without entity types (default encoding type).
          id2label = {0: 'B', 1: 'I', 2: 'O'}
          expected_res = {
              (0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4),
              (2, 0), (2, 2), (2, 4), (3, 0), (3, 2),
          }
          self.assertSetEqual(expected_res, set(allowed_transitions(id2label)))

          # Plain BMES tags without entity types.
          id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'}
          expected_res = {
              (0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0),
              (3, 3), (3, 5), (4, 0), (4, 3),
          }
          self.assertSetEqual(
              expected_res,
              set(allowed_transitions(id2label, encoding_type='BMES')))

          # Vocabulary containing '<pad>' and '<unk>': the result is not
          # asserted, so this only checks that the call does not raise.
          id2label = {0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"}
          allowed_transitions(id2label)

          # Typed BIO tags: 'O', 'B-X', 'I-X', 'B-Y', 'I-Y' (ids 0..4).
          labels = ['O'] + ['{}-{}'.format(tag, label)
                            for label in ['X', 'Y'] for tag in 'BI']
          id2label = dict(enumerate(labels))
          expected_res = {
              (0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3),
              (1, 6), (2, 0), (2, 1), (2, 2), (2, 3), (2, 6), (3, 0), (3, 1),
              (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), (4, 4), (4, 6),
              (5, 0), (5, 1), (5, 3),
          }
          self.assertSetEqual(expected_res, set(allowed_transitions(id2label)))

          # Typed BMES tags: 'B-X', 'M-X', 'E-X', 'S-X', 'B-Y', ... (ids 0..7).
          labels = ['{}-{}'.format(tag, label)
                    for label in ['X', 'Y'] for tag in 'BMES']
          id2label = dict(enumerate(labels))
          expected_res = {
              (0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7),
              (2, 9), (3, 0), (3, 3), (3, 4), (3, 7), (3, 9), (4, 5), (4, 6),
              (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0),
              (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7),
          }
          self.assertSetEqual(
              expected_res,
              set(allowed_transitions(id2label, encoding_type='BMES')))

      # The body of this test is commented out, so it used to "pass" without
      # checking anything. Mark it as skipped so test reports show that the
      # comparison is not actually being run.
      @unittest.skip("comparison against allennlp is disabled (test body is commented out)")
      def test_case2(self):
          """Test that the CRF avoids decoding illegal transitions.

          The original author verified this against allennlp; the comparison
          code is kept below, commented out, for reference.
          """
          # import torch
          # from fastNLP.modules.decoder.CRF import seq_len_to_byte_mask
          #
          # labels = ['O']
          # for label in ['X', 'Y']:
          #     for tag in 'BI':
          #         labels.append('{}-{}'.format(tag, label))
          # id2label = {idx: label for idx, label in enumerate(labels)}
          # num_tags = len(id2label)
          #
          # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions
          # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label),
          #                                    include_start_end_transitions=False)
          # batch_size = 3
          # logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log()
          # trans_m = allen_CRF.transitions
          # seq_lens = torch.randint(1, 20, size=(batch_size,))
          # seq_lens[-1] = 20
          # mask = seq_len_to_byte_mask(seq_lens)
          # allen_res = allen_CRF.viterbi_tags(logits, mask)
          #
          # from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions
          # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label))
          # fast_CRF.trans_m = trans_m
          # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True)
          # # score equal
          # self.assertListEqual([score for _, score in allen_res], fast_res[1])
          # # seq equal
          # self.assertListEqual([_ for _, score in allen_res], fast_res[0])
          #
          #
          # labels = []
          # for label in ['X', 'Y']:
          #     for tag in 'BMES':
          #         labels.append('{}-{}'.format(tag, label))
          # id2label = {idx: label for idx, label in enumerate(labels)}
          # num_tags = len(id2label)
          #
          # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions
          # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label),
          #                                    include_start_end_transitions=False)
          # batch_size = 3
          # logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log()
          # trans_m = allen_CRF.transitions
          # seq_lens = torch.randint(1, 20, size=(batch_size,))
          # seq_lens[-1] = 20
          # mask = seq_len_to_byte_mask(seq_lens)
          # allen_res = allen_CRF.viterbi_tags(logits, mask)
          #
          # from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions
          # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
          #                                                                          encoding_type='BMES'))
          # fast_CRF.trans_m = trans_m
          # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True)
          # # score equal
          # self.assertListEqual([score for _, score in allen_res], fast_res[1])
          # # seq equal
          # self.assertListEqual([_ for _, score in allen_res], fast_res[0])