sstloader.py 7.0 kB

from typing import Iterable, Union, Dict
from nltk import Tree
from fastNLP.io.base_loader import DataBundle, DataSetLoader
from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
from fastNLP import DataSet
from fastNLP import Instance
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
import csv
from reproduction.utils import check_dataloader_paths, get_tokenizer


class SSTLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`

    Reads the SST dataset. The resulting DataSet contains the fields::

        words: list(str), the text to be classified
        target: str, the label of the text

    Data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip

    :param subtree: whether to expand every subtree into an extra sample, enlarging the dataset. Default: ``False``
    :param fine_grained: whether to use the SST-5 label set; if ``False``, SST-2 labels are used. Default: ``False``

    See the usage sketch after this class definition.
    """

    URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
    DATA_DIR = 'sst/'

    def __init__(self, subtree=False, fine_grained=False):
        self.subtree = subtree

        tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
                 '3': 'positive', '4': 'very positive'}
        if not fine_grained:
            # collapse the SST-5 labels into the SST-2 scheme
            tag_v['0'] = tag_v['1']
            tag_v['4'] = tag_v['3']
        self.tag_v = tag_v

    def _load(self, path):
        """
        :param str path: path to the data file
        :return: a :class:`~fastNLP.DataSet` object
        """
        with open(path, 'r', encoding='utf-8') as f:
            datas = []
            for l in f:
                datas.extend([(s, self.tag_v[t])
                              for s, t in self._get_one(l, self.subtree)])
        ds = DataSet()
        for words, tag in datas:
            ds.append(Instance(words=words, target=tag))
        return ds

    @staticmethod
    def _get_one(data, subtree):
        tree = Tree.fromstring(data)
        if subtree:
            return [(t.leaves(), t.label()) for t in tree.subtrees()]
        return [(tree.leaves(), tree.label())]

    def process(self,
                paths,
                train_ds: Iterable[str] = None,
                src_vocab_op: VocabularyOption = None,
                tgt_vocab_op: VocabularyOption = None,
                src_embed_op: EmbeddingOption = None):
        input_name, target_name = 'words', 'target'
        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

        info = DataBundle(datasets=self.load(paths))
        _train_ds = [info.datasets[name]
                     for name in train_ds] if train_ds else info.datasets.values()
        # build both vocabularies on the training split(s) only, then index every split
        src_vocab.from_dataset(*_train_ds, field_name=input_name)
        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
        src_vocab.index_dataset(
            *info.datasets.values(),
            field_name=input_name, new_field_name=input_name)
        tgt_vocab.index_dataset(
            *info.datasets.values(),
            field_name=target_name, new_field_name=target_name)
        info.vocabs = {
            input_name: src_vocab,
            target_name: tgt_vocab
        }

        if src_embed_op is not None:
            src_embed_op.vocab = src_vocab
            init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
            info.embeddings[input_name] = init_emb

        for name, dataset in info.datasets.items():
            dataset.set_input(input_name)
            dataset.set_target(target_name)

        return info
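

# A minimal usage sketch for SSTLoader (hypothetical split names and file paths;
# it assumes DataSetLoader.load accepts a dict mapping split names to PTB-style
# tree files, since process() forwards `paths` to self.load() unchanged):
#
#     loader = SSTLoader(subtree=False, fine_grained=False)
#     bundle = loader.process({'train': 'sst/train.txt', 'dev': 'sst/dev.txt'},
#                             train_ds=['train'])
#     print(bundle.datasets['train'][0])     # Instance with indexed 'words'/'target'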


class sst2Loader(DataSetLoader):
    """
    Data source ("SST"):
    https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8
    """

    def __init__(self):
        super(sst2Loader, self).__init__()
        self.tokenizer = get_tokenizer()

    def _load(self, path: str) -> DataSet:
        ds = DataSet()
        all_count = 0
        csv_reader = csv.reader(open(path, encoding='utf-8'), delimiter='\t')
        skip_row = 0
        for idx, row in enumerate(csv_reader):
            if idx <= skip_row:
                # skip the header row
                continue
            target = row[1]
            words = self.tokenizer(row[0])
            ds.append(Instance(words=words, target=target))
            all_count += 1
        print("all count:", all_count)
        return ds
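
    # For reference (inferred from _load above; the rows shown are hypothetical):
    # the input is a tab-separated file with one header row, followed by
    # "<sentence>\t<label>" pairs, e.g.
    #
    #     sentence\tlabel
    #     a genuinely funny film\t1
    #     dull and overlong\t0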

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,
                src_embed_opt: EmbeddingOption = None,
                char_level_op=False):
        paths = check_dataloader_paths(paths)
        datasets = {}
        info = DataBundle()
        for name, path in paths.items():
            dataset = self.load(path)
            datasets[name] = dataset

        def wordtochar(words):
            # lower-case each word and split it into characters,
            # inserting an empty string between consecutive words
            chars = []
            for word in words:
                word = word.lower()
                for char in word:
                    chars.append(char)
                chars.append('')
            chars.pop()
            return chars

        input_name, target_name = 'words', 'target'
        info.vocabs = {}

        # optionally split every sample into characters
        if char_level_op:
            for dataset in datasets.values():
                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')

        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name='words')
        src_vocab.index_dataset(*datasets.values(), field_name='words')

        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name='target')
        tgt_vocab.index_dataset(*datasets.values(), field_name='target')

        info.vocabs = {
            "words": src_vocab,
            "target": tgt_vocab
        }
        info.datasets = datasets

        if src_embed_opt is not None:
            embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
            info.embeddings['words'] = embed

        for name, dataset in info.datasets.items():
            dataset.set_input("words")
            dataset.set_target("target")

        return info


if __name__ == "__main__":
    datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv",
                "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"}
    datainfo = sst2Loader().process(datapath, char_level_op=True)
    # print(datainfo.datasets["train"])

    len_count = 0
    for instance in datainfo.datasets["train"]:
        len_count += len(instance["chars"])

    ave_len = len_count / len(datainfo.datasets["train"])
    print(ave_len)
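    # A minimal follow-up sketch (assuming fastNLP's Vocabulary supports len(),
    # which this file does not itself demonstrate): report the sizes of the
    # vocabularies built by process() alongside the average character length.
    print("words vocab size:", len(datainfo.vocabs["words"]))
    print("target vocab size:", len(datainfo.vocabs["target"]))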