
SSTLoader.py

from typing import Iterable

from nltk import Tree

from fastNLP import DataSet, Instance
from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
from fastNLP.io.base_loader import DataInfo, DataSetLoader
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader

class SSTLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`

    Reads the SST dataset. The resulting DataSet contains the fields::

        words: list(str), the text to classify
        target: str, the label of the text

    Data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip

    :param subtree: whether to expand each tree into all of its subtrees,
        enlarging the dataset. Default: ``False``
    :param fine_grained: whether to keep the five-class SST-5 labels; if
        ``False``, collapse them to SST-2. Default: ``False``
    """
    URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
    DATA_DIR = 'sst/'

    def __init__(self, subtree=False, fine_grained=False):
        self.subtree = subtree

        tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
                 '3': 'positive', '4': 'very positive'}
        if not fine_grained:
            # Collapse the five-class scheme to SST-2 by mapping the
            # 'very negative'/'very positive' tags onto 'negative'/'positive'.
            tag_v['0'] = tag_v['1']
            tag_v['4'] = tag_v['3']
        self.tag_v = tag_v

    def _load(self, path):
        """
        :param str path: path to the data file
        :return: a :class:`~fastNLP.DataSet` object
        """
        datas = []
        with open(path, 'r', encoding='utf-8') as f:
            for l in f:
                datas.extend([(s, self.tag_v[t])
                              for s, t in self._get_one(l, self.subtree)])
        ds = DataSet()
        for words, tag in datas:
            ds.append(Instance(words=words, target=tag))
        return ds

    @staticmethod
    def _get_one(data, subtree):
        # Each line of the SST files is a PTB-style bracketed tree whose
        # node labels are sentiment tags and whose leaves are tokens.
        tree = Tree.fromstring(data)
        if subtree:
            return [(t.leaves(), t.label()) for t in tree.subtrees()]
        return [(tree.leaves(), tree.label())]

    def process(self,
                paths,
                train_ds: Iterable[str] = None,
                src_vocab_op: VocabularyOption = None,
                tgt_vocab_op: VocabularyOption = None,
                src_embed_op: EmbeddingOption = None):
        input_name, target_name = 'words', 'target'
        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

        info = DataInfo(datasets=self.load(paths))
        # Build the vocabularies from the named training splits only;
        # fall back to all loaded datasets when train_ds is not given.
        _train_ds = [info.datasets[name]
                     for name in train_ds] if train_ds else info.datasets.values()
        src_vocab.from_dataset(*_train_ds, field_name=input_name)
        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
        # Index every loaded dataset (train, dev, test, ...) in place.
        src_vocab.index_dataset(
            *info.datasets.values(),
            field_name=input_name, new_field_name=input_name)
        tgt_vocab.index_dataset(
            *info.datasets.values(),
            field_name=target_name, new_field_name=target_name)

        info.vocabs = {
            input_name: src_vocab,
            target_name: tgt_vocab
        }

        if src_embed_op is not None:
            # Optionally load pretrained word vectors aligned to src_vocab.
            src_embed_op.vocab = src_vocab
            init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
            info.embeddings[input_name] = init_emb

        for name, dataset in info.datasets.items():
            dataset.set_input(input_name)
            dataset.set_target(target_name)

        return info
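
To see what the `subtree` option expands, here is a small sketch of what `_get_one` does on a hypothetical two-token SST line (the tree string below is invented for illustration; real SST lines have the same bracketed shape):

from nltk import Tree

line = "(3 (2 It) (4 works))"        # hypothetical SST-style line
tree = Tree.fromstring(line)
print(tree.leaves(), tree.label())   # ['It', 'works'] '3'

# With subtree=True, every inner node becomes an extra example:
for t in tree.subtrees():
    print(t.leaves(), t.label())
# ['It', 'works'] '3'
# ['It'] '2'
# ['works'] '4'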
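
End to end, a minimal usage sketch. It assumes the archive from `URL` above has been unpacked so the PTB-format split files sit at trees/train.txt, trees/dev.txt, and trees/test.txt, and that the inherited `load` accepts a name-to-path dict (`process` indexes `info.datasets` by those names):

loader = SSTLoader(subtree=False, fine_grained=False)
info = loader.process(
    paths={'train': 'trees/train.txt',
           'dev': 'trees/dev.txt',
           'test': 'trees/test.txt'},
    train_ds=['train'])  # vocabularies come from the training split only

train_set = info.datasets['train']
src_vocab = info.vocabs['words']
print(len(train_set), len(src_vocab))

Passing `src_embed_op` (an `EmbeddingOption` pointing at a pretrained vector file) would additionally fill `info.embeddings['words']` via `EmbedLoader.load_with_vocab`.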