You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

util.py 2.5 kB

7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. class ConllxDataLoader(object):
  2. def load(self, path):
  3. datalist = []
  4. with open(path, 'r', encoding='utf-8') as f:
  5. sample = []
  6. for line in f:
  7. if line.startswith('\n'):
  8. datalist.append(sample)
  9. sample = []
  10. elif line.startswith('#'):
  11. continue
  12. else:
  13. sample.append(line.split('\t'))
  14. if len(sample) > 0:
  15. datalist.append(sample)
  16. data = [self.get_one(sample) for sample in datalist]
  17. return list(filter(lambda x: x is not None, data))
  18. def get_one(self, sample):
  19. sample = list(map(list, zip(*sample)))
  20. if len(sample) == 0:
  21. return None
  22. for w in sample[7]:
  23. if w == '_':
  24. print('Error Sample {}'.format(sample))
  25. return None
  26. # return word_seq, pos_seq, head_seq, head_tag_seq
  27. return sample[1], sample[3], list(map(int, sample[6])), sample[7]
  28. class MyDataloader:
  29. def load(self, data_path):
  30. with open(data_path, "r", encoding="utf-8") as f:
  31. lines = f.readlines()
  32. data = self.parse(lines)
  33. return data
  34. def parse(self, lines):
  35. """
  36. [
  37. [word], [pos], [head_index], [head_tag]
  38. ]
  39. """
  40. sample = []
  41. data = []
  42. for i, line in enumerate(lines):
  43. line = line.strip()
  44. if len(line) == 0 or i + 1 == len(lines):
  45. data.append(list(map(list, zip(*sample))))
  46. sample = []
  47. else:
  48. sample.append(line.split())
  49. if len(sample) > 0:
  50. data.append(list(map(list, zip(*sample))))
  51. return data
  52. def add_seg_tag(data):
  53. """
  54. :param data: list of ([word], [pos], [heads], [head_tags])
  55. :return: list of ([word], [pos])
  56. """
  57. _processed = []
  58. for word_list, pos_list, _, _ in data:
  59. new_sample = []
  60. for word, pos in zip(word_list, pos_list):
  61. if len(word) == 1:
  62. new_sample.append((word, 'S-' + pos))
  63. else:
  64. new_sample.append((word[0], 'B-' + pos))
  65. for c in word[1:-1]:
  66. new_sample.append((c, 'M-' + pos))
  67. new_sample.append((word[-1], 'E-' + pos))
  68. _processed.append(list(map(list, zip(*new_sample))))
  69. return _processed