
yelpLoader.py 7.2 kB

import ast
import csv
from typing import Iterable, Union, Dict

from fastNLP import DataSet, Instance, Vocabulary
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io import JsonLoader
from fastNLP.io.base_loader import DataBundle, DataSetLoader
from fastNLP.io.embed_loader import EmbeddingOption
from fastNLP.io.file_reader import _read_json
from reproduction.utils import check_dataloader_paths, get_tokenizer

def clean_str(sentence, tokenizer, char_lower=False):
    """
    Heavily borrowed from
    https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb

    :param sentence: a raw review string
    :param tokenizer: callable that splits the string into tokens
    :param char_lower: whether to lowercase the sentence first
    :return: list of cleaned tokens
    """
    if char_lower:
        sentence = sentence.lower()
    import re
    nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
    words = tokenizer(sentence)
    words_collection = []
    for word in words:
        # drop tokenizer artifacts such as bracket placeholders
        if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']:
            continue
        # strip every character that is not alphanumeric, '?', '!' or "'"
        tt = nonalpnum.split(word)
        t = ''.join(tt)
        if t != '':
            words_collection.append(t)
    return words_collection
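
# Illustrative example of what clean_str produces (not part of the original script;
# str.split stands in here for the tokenizer returned by get_tokenizer()):
#   clean_str("The food -lrb- pizza -rrb- was GREAT!!", str.split, char_lower=True)
#   -> ['the', 'food', 'pizza', 'was', 'great!!']
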
class yelpLoader(DataSetLoader):
    """
    Reads the yelp_full / yelp_polarity dataset. The DataSet contains the fields:

        words: list(str), the text to be classified
        target: str, the label of the text
        chars: list(str), the un-indexed character list

    Datasets: yelp_full / yelp_polarity

    :param fine_grained: whether to keep the five sentiment classes; if ``False``,
        the labels are collapsed to two classes. Default: ``False``
    """

    def __init__(self, fine_grained=False, lower=False):
        super(yelpLoader, self).__init__()
        tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral',
                 '4.0': 'positive', '5.0': 'very positive'}
        if not fine_grained:
            tag_v['1.0'] = tag_v['2.0']
            tag_v['5.0'] = tag_v['4.0']
        self.fine_grained = fine_grained
        self.tag_v = tag_v
        self.lower = lower
        self.tokenizer = get_tokenizer()

    '''
    Reads the raw Yelp dataset. The DataSet contains the fields:

        review_id: str, 22 character unique review id
        user_id: str, 22 character unique user id
        business_id: str, 22 character business id
        useful: int, number of useful votes received
        funny: int, number of funny votes received
        cool: int, number of cool votes received
        date: str, date formatted YYYY-MM-DD
        words: list(str), the text to be classified
        target: str, the label of the text

    Data source: https://www.yelp.com/dataset/download

    def _load_json(self, path):
        ds = DataSet()
        for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
            d = ast.literal_eval(d)
            d["words"] = d.pop("text").split()
            d["target"] = self.tag_v[str(d.pop("stars"))]
            ds.append(Instance(**d))
        return ds

    def _load_yelp2015_broken(self, path):
        ds = DataSet()
        with open(path, encoding='ISO 8859-1') as f:
            row = f.readline()
            all_count = 0
            exp_count = 0
            while row:
                row = row.split("\t\t")
                all_count += 1
                if len(row) >= 3:
                    words = row[-1].split()
                    try:
                        target = self.tag_v[str(row[-2]) + ".0"]
                        ds.append(Instance(words=words, target=target))
                    except KeyError:
                        exp_count += 1
                else:
                    exp_count += 1
                row = f.readline()
        print("error sample count:", exp_count)
        print("all count:", all_count)
        return ds
    '''

    def _load(self, path):
        ds = DataSet()
        csv_reader = csv.reader(open(path, encoding='utf-8'))
        all_count = 0
        real_count = 0
        for row in csv_reader:
            all_count += 1
            if len(row) == 2:
                target = self.tag_v[row[0] + ".0"]
                words = clean_str(row[1], self.tokenizer, self.lower)
                if len(words) != 0:
                    ds.append(Instance(words=words, target=target))
                    real_count += 1
        print("all count:", all_count)
        print("real count:", real_count)
        return ds
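
    # Note on the CSV layout read by _load above: each row is expected to hold two
    # columns, label then text, e.g. (illustrative rows, not taken from the dataset):
    #   "1","Terrible service and the food was cold."
    #   "2","Great atmosphere, we would come again."
    # Rows of any other shape are counted in all_count but otherwise skipped.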

    def process(self, paths: Union[str, Dict[str, str]],
                train_ds: Iterable[str] = None,
                src_vocab_op: VocabularyOption = None,
                tgt_vocab_op: VocabularyOption = None,
                embed_opt: EmbeddingOption = None,
                char_level_op=False):
        paths = check_dataloader_paths(paths)
        datasets = {}
        info = DataBundle(datasets=self.load(paths))
        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
        _train_ds = [info.datasets[name]
                     for name in train_ds] if train_ds else info.datasets.values()

        def wordtochar(words):
            chars = []
            for word in words:
                word = word.lower()
                for char in word:
                    chars.append(char)
                chars.append('')
            chars.pop()
            return chars

        input_name, target_name = 'words', 'target'
        info.vocabs = {}
        # if requested, split the text into characters instead of building a word vocabulary
        if char_level_op:
            for dataset in info.datasets.values():
                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')
            # if embed_opt is not None:
            #     embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab)
            #     info.embeddings['words'] = embed
        else:
            src_vocab.from_dataset(*_train_ds, field_name=input_name)
            src_vocab.index_dataset(*info.datasets.values(), field_name=input_name, new_field_name=input_name)
            info.vocabs[input_name] = src_vocab

        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
        tgt_vocab.index_dataset(
            *info.datasets.values(),
            field_name=target_name, new_field_name=target_name)
        info.vocabs[target_name] = tgt_vocab

        info.datasets['train'], info.datasets['dev'] = info.datasets['train'].split(0.1, shuffle=False)

        for name, dataset in info.datasets.items():
            dataset.set_input("words")
            dataset.set_target("target")
        return info

if __name__ == "__main__":
    testloader = yelpLoader()
    # datapath = {"train": "/remote-home/ygwang/yelp_full/train.csv",
    #             "test": "/remote-home/ygwang/yelp_full/test.csv"}
    # datapath = {"train": "/remote-home/ygwang/yelp_full/test.csv"}
    datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv",
                "test": "/remote-home/ygwang/yelp_polarity/test.csv"}
    datainfo = testloader.process(datapath, char_level_op=True)

    len_count = 0
    for instance in datainfo.datasets["train"]:
        len_count += len(instance["chars"])

    ave_len = len_count / len(datainfo.datasets["train"])
    print(ave_len)
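
    # A rough sketch of further inspecting the returned DataBundle (illustrative,
    # not output of the original script). With char_level_op=True only the 'target'
    # vocabulary is built, and 'dev' is the 10% split taken from 'train' above:
    #   print(datainfo.datasets.keys())        # e.g. dict_keys(['train', 'test', 'dev'])
    #   print(len(datainfo.vocabs['target']))  # number of distinct labels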