
yelp.py 4.6 kB

import csv
import re
from typing import Dict, Iterable, Union

from ...core.const import Const
from ...core import DataSet, Instance, Vocabulary
from ...core.vocabulary import VocabularyOption
from ..base_loader import DataInfo, DataSetLoader
from ..utils import check_dataloader_paths, get_tokenizer

class YelpLoader(DataSetLoader):
    """
    Reads the Yelp_full/Yelp_polarity dataset. The resulting DataSet contains the fields:

        words: list(str), the text to classify
        target: str, the label of the text
        chars: list(str), the un-indexed list of characters

    Datasets: yelp_full/yelp_polarity

    :param fine_grained: whether to use five-class (SST-5 style) labels; if ``False``,
        collapse to two classes (SST-2 style). Default: ``False``
    :param lower: whether to lowercase the text automatically. Default: ``False``
    """

    def __init__(self, fine_grained=False, lower=False):
        super(YelpLoader, self).__init__()
        tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral',
                 '4.0': 'positive', '5.0': 'very positive'}
        if not fine_grained:
            # Collapse the five ratings to two polarity classes:
            # 1/2 stars -> 'negative', 4/5 stars -> 'positive'
            tag_v['1.0'] = tag_v['2.0']
            tag_v['5.0'] = tag_v['4.0']
        self.fine_grained = fine_grained
        self.tag_v = tag_v
        self.lower = lower
        self.tokenizer = get_tokenizer()

    def _load(self, path):
        ds = DataSet()
        all_count = 0
        real_count = 0
        with open(path, encoding='utf-8') as f:
            csv_reader = csv.reader(f)
            for row in csv_reader:
                all_count += 1
                # Each valid row is a (rating, review text) pair
                if len(row) == 2:
                    target = self.tag_v[row[0] + ".0"]
                    words = clean_str(row[1], self.tokenizer, self.lower)
                    if len(words) != 0:
                        ds.append(Instance(words=words, target=target))
                        real_count += 1
        print("all count:", all_count)
        print("real count:", real_count)
        return ds
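
    # A sketch of the expected input, for illustration only (the Yelp CSVs carry no
    # header row; exact tokens depend on the tokenizer returned by get_tokenizer()):
    #   row = ["5", "Great food"]
    #   -> Instance(words=['great', 'food'], target='very positive')  # fine_grained=True, lower=True
    #   -> Instance(words=['great', 'food'], target='positive')       # fine_grained=False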

    def process(self, paths: Union[str, Dict[str, str]],
                train_ds: Iterable[str] = None,
                src_vocab_op: VocabularyOption = None,
                tgt_vocab_op: VocabularyOption = None,
                char_level_op=False):
        paths = check_dataloader_paths(paths)
        info = DataInfo(datasets=self.load(paths))
        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
        _train_ds = [info.datasets[name]
                     for name in train_ds] if train_ds else info.datasets.values()

        def wordtochar(words):
            chars = []
            for word in words:
                word = word.lower()
                for char in word:
                    chars.append(char)
                chars.append('')  # empty string marks the boundary between words
            chars.pop()  # drop the trailing boundary marker
            return chars
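        # For example (hypothetical input): wordtochar(['Ab', 'c']) -> ['a', 'b', '', 'c']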

        input_name, target_name = Const.INPUT, Const.TARGET
        info.vocabs = {}
        # Either split the input into characters, or build and apply a word vocabulary
        if char_level_op:
            for dataset in info.datasets.values():
                dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT)
        else:
            src_vocab.from_dataset(*_train_ds, field_name=input_name)
            src_vocab.index_dataset(*info.datasets.values(), field_name=input_name, new_field_name=input_name)
            info.vocabs[input_name] = src_vocab
        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
        tgt_vocab.index_dataset(
            *info.datasets.values(),
            field_name=target_name, new_field_name=target_name)
        info.vocabs[target_name] = tgt_vocab
        info.datasets['train'], info.datasets['dev'] = info.datasets['train'].split(0.1, shuffle=False)
        for name, dataset in info.datasets.items():
            dataset.set_input(Const.INPUT)
            dataset.set_target(Const.TARGET)
        return info
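
# In sketch form, process() hands back a DataInfo where (as set in the code above):
#   info.datasets -> {'train': DataSet, 'dev': DataSet, ...}; 'dev' is split off
#                    from 'train' by split(0.1, shuffle=False)
#   info.vocabs   -> {Const.INPUT: Vocabulary, Const.TARGET: Vocabulary}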

def clean_str(sentence, tokenizer, char_lower=False):
    """
    Heavily borrowed from
    https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb

    :param sentence: a str to be tokenized and cleaned
    :param tokenizer: a callable that splits the sentence into tokens
    :param char_lower: whether to lowercase the sentence first
    :return: list(str), the cleaned tokens
    """
    if char_lower:
        sentence = sentence.lower()
    nonalpnum = re.compile(r"[^0-9a-zA-Z?!']+")
    words = tokenizer(sentence)
    words_collection = []
    for word in words:
        # Drop PTB-style bracket tokens and sentence separators
        if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']:
            continue
        # Strip every character that is not alphanumeric or one of ? ! '
        tt = nonalpnum.split(word)
        t = ''.join(tt)
        if t != '':
            words_collection.append(t)
    return words_collection
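
# A minimal usage sketch, commented out so the module stays import-only.
# The paths below are hypothetical; they should point at Yelp CSVs in the
# (rating, review) format that _load() expects:
#
#   loader = YelpLoader(fine_grained=True, lower=True)
#   info = loader.process({'train': 'yelp/train.csv', 'test': 'yelp/test.csv'})
#   print(info.datasets['train'][0])        # an indexed Instance
#   print(len(info.vocabs[Const.INPUT]))    # source vocabulary size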