
IMDBLoader.py

from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataBundle
from typing import Union, Dict, List, Iterator
from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import Vocabulary
from fastNLP import Const
from functools import partial
from reproduction.utils import check_dataloader_paths, get_tokenizer


class IMDBLoader(DataSetLoader):
    """
    Loads the IMDB dataset. The resulting DataSet contains the following fields:

        words: list(str), the text to be classified
        target: str, the label of the text
    """

    def __init__(self):
        super(IMDBLoader, self).__init__()
        self.tokenizer = get_tokenizer()

    def _load(self, path):
        dataset = DataSet()
        with open(path, 'r', encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # Each line holds a label and its text, separated by a tab.
                parts = line.split('\t')
                target = parts[0]
                words = self.tokenizer(parts[1].lower())
                dataset.append(Instance(words=words, target=target))
        if len(dataset) == 0:
            raise RuntimeError(f"{path} has no valid data.")
        return dataset
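
    # Illustrative note (format inferred from the tab-split parsing above, not
    # taken verbatim from the original file): _load expects one
    # "<label><TAB><text>" pair per line, for example
    #
    #     pos<TAB>this movie was surprisingly good ...
    #     neg<TAB>i gave up halfway through ...
    #
    # which _load turns into Instance(words=['this', 'movie', ...], target='pos').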

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,
                src_embed_opt: EmbeddingOption = None,
                char_level_op=False):
        # Normalize a single path / directory into the {name: path} dict used below.
        paths = check_dataloader_paths(paths)
        datasets = {}
        info = DataBundle()
        for name, path in paths.items():
            dataset = self.load(path)
            datasets[name] = dataset

        def wordtochar(words):
            # Flatten a word sequence into a character sequence, inserting an
            # empty string between words, e.g. ['it', 'is'] -> ['i', 't', '', 'i', 's'].
            chars = []
            for word in words:
                word = word.lower()
                for char in word:
                    chars.append(char)
                chars.append('')
            chars.pop()
            return chars

        if char_level_op:
            for dataset in datasets.values():
                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')

        # Hold out part of the training data as a dev set.
        datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False)

        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name='words')
        src_vocab.index_dataset(*datasets.values(), field_name='words')

        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name='target')
        tgt_vocab.index_dataset(*datasets.values(), field_name='target')

        info.vocabs = {
            "words": src_vocab,
            "target": tgt_vocab
        }
        info.datasets = datasets

        if src_embed_opt is not None:
            embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
            info.embeddings['words'] = embed

        for name, dataset in info.datasets.items():
            dataset.set_input("words")
            dataset.set_target("target")

        return info

if __name__ == "__main__":
    datapath = {"train": "/remote-home/ygwang/IMDB_data/train.csv",
                "test": "/remote-home/ygwang/IMDB_data/test.csv"}
    datainfo = IMDBLoader().process(datapath, char_level_op=True)
    # print(datainfo.datasets["train"])

    len_count = 0
    for instance in datainfo.datasets["train"]:
        len_count += len(instance["chars"])
    ave_len = len_count / len(datainfo.datasets["train"])
    print(ave_len)
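
    # A small additional inspection sketch (not part of the original script):
    # report the splits and vocabulary sizes produced by process() above.
    for name, ds in datainfo.datasets.items():
        print(name, len(ds))
    print("source vocab size:", len(datainfo.vocabs["words"]))
    print("target vocab size:", len(datainfo.vocabs["target"]))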