You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SSTLoader.py 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. from typing import Iterable
  2. from nltk import Tree
  3. from fastNLP.io.base_loader import DataInfo, DataSetLoader
  4. from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
  5. from fastNLP import DataSet
  6. from fastNLP import Instance
  7. from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
  8. import csv
  9. from typing import Union, Dict
  10. class SSTLoader(DataSetLoader):
  11. URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
  12. DATA_DIR = 'sst/'
  13. """
  14. 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`
  15. 读取SST数据集, DataSet包含fields::
  16. words: list(str) 需要分类的文本
  17. target: str 文本的标签
  18. 数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
  19. :param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False``
  20. :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``
  21. """
  22. def __init__(self, subtree=False, fine_grained=False):
  23. self.subtree = subtree
  24. tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
  25. '3': 'positive', '4': 'very positive'}
  26. if not fine_grained:
  27. tag_v['0'] = tag_v['1']
  28. tag_v['4'] = tag_v['3']
  29. self.tag_v = tag_v
  30. def _load(self, path):
  31. """
  32. :param str path: 存储数据的路径
  33. :return: 一个 :class:`~fastNLP.DataSet` 类型的对象
  34. """
  35. datalist = []
  36. with open(path, 'r', encoding='utf-8') as f:
  37. datas = []
  38. for l in f:
  39. datas.extend([(s, self.tag_v[t])
  40. for s, t in self._get_one(l, self.subtree)])
  41. ds = DataSet()
  42. for words, tag in datas:
  43. ds.append(Instance(words=words, target=tag))
  44. return ds
  45. @staticmethod
  46. def _get_one(data, subtree):
  47. tree = Tree.fromstring(data)
  48. if subtree:
  49. return [(t.leaves(), t.label()) for t in tree.subtrees()]
  50. return [(tree.leaves(), tree.label())]
  51. def process(self,
  52. paths,
  53. train_ds: Iterable[str] = None,
  54. src_vocab_op: VocabularyOption = None,
  55. tgt_vocab_op: VocabularyOption = None,
  56. src_embed_op: EmbeddingOption = None):
  57. input_name, target_name = 'words', 'target'
  58. src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
  59. tgt_vocab = Vocabulary(unknown=None, padding=None) \
  60. if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
  61. info = DataInfo(datasets=self.load(paths))
  62. _train_ds = [info.datasets[name]
  63. for name in train_ds] if train_ds else info.datasets.values()
  64. src_vocab.from_dataset(*_train_ds, field_name=input_name)
  65. tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
  66. src_vocab.index_dataset(
  67. *info.datasets.values(),
  68. field_name=input_name, new_field_name=input_name)
  69. tgt_vocab.index_dataset(
  70. *info.datasets.values(),
  71. field_name=target_name, new_field_name=target_name)
  72. info.vocabs = {
  73. input_name: src_vocab,
  74. target_name: tgt_vocab
  75. }
  76. if src_embed_op is not None:
  77. src_embed_op.vocab = src_vocab
  78. init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
  79. info.embeddings[input_name] = init_emb
  80. for name, dataset in info.datasets.items():
  81. dataset.set_input(input_name)
  82. dataset.set_target(target_name)
  83. return info
  84. class sst2Loader(DataSetLoader):
  85. '''
  86. 数据来源"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
  87. '''
  88. def __init__(self):
  89. super(sst2Loader, self).__init__()
  90. def _load(self, path: str) -> DataSet:
  91. ds = DataSet()
  92. all_count=0
  93. csv_reader = csv.reader(open(path, encoding='utf-8'),delimiter='\t')
  94. skip_row = 0
  95. for idx,row in enumerate(csv_reader):
  96. if idx<=skip_row:
  97. continue
  98. target = row[1]
  99. words = row[0].split()
  100. ds.append(Instance(words=words,target=target))
  101. all_count+=1
  102. print("all count:", all_count)
  103. return ds
  104. def process(self,
  105. paths: Union[str, Dict[str, str]],
  106. src_vocab_opt: VocabularyOption = None,
  107. tgt_vocab_opt: VocabularyOption = None,
  108. src_embed_opt: EmbeddingOption = None,
  109. char_level_op=False):
  110. paths = check_dataloader_paths(paths)
  111. datasets = {}
  112. info = DataInfo()
  113. for name, path in paths.items():
  114. dataset = self.load(path)
  115. datasets[name] = dataset
  116. def wordtochar(words):
  117. chars=[]
  118. for word in words:
  119. word=word.lower()
  120. for char in word:
  121. chars.append(char)
  122. return chars
  123. input_name, target_name = 'words', 'target'
  124. info.vocabs={}
  125. # 就分隔为char形式
  126. if char_level_op:
  127. for dataset in datasets.values():
  128. dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')
  129. src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
  130. src_vocab.from_dataset(datasets['train'], field_name='words')
  131. src_vocab.index_dataset(*datasets.values(), field_name='words')
  132. tgt_vocab = Vocabulary(unknown=None, padding=None) \
  133. if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
  134. tgt_vocab.from_dataset(datasets['train'], field_name='target')
  135. tgt_vocab.index_dataset(*datasets.values(), field_name='target')
  136. info.vocabs = {
  137. "words": src_vocab,
  138. "target": tgt_vocab
  139. }
  140. info.datasets = datasets
  141. if src_embed_opt is not None:
  142. embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
  143. info.embeddings['words'] = embed
  144. return info
  145. if __name__=="__main__":
  146. datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv",
  147. "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"}
  148. datainfo=sst2Loader().process(datapath,char_level_op=True)
  149. #print(datainfo.datasets["train"])
  150. len_count = 0
  151. for instance in datainfo.datasets["train"]:
  152. len_count += len(instance["chars"])
  153. ave_len = len_count / len(datainfo.datasets["train"])
  154. print(ave_len)