You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

MTL16Loader.py 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
  2. from fastNLP.core.vocabulary import VocabularyOption
  3. from fastNLP.io.base_loader import DataSetLoader, DataBundle
  4. from typing import Union, Dict, List, Iterator
  5. from fastNLP import DataSet
  6. from fastNLP import Instance
  7. from fastNLP import Vocabulary
  8. from fastNLP import Const
  9. from reproduction.utils import check_dataloader_paths
  10. from functools import partial
  11. class MTL16Loader(DataSetLoader):
  12. """
  13. 读取MTL16数据集,DataSet包含以下fields:
  14. words: list(str), 需要分类的文本
  15. target: str, 文本的标签
  16. 数据来源:https://pan.baidu.com/s/1c2L6vdA
  17. """
  18. def __init__(self):
  19. super(MTL16Loader, self).__init__()
  20. def _load(self, path):
  21. dataset = DataSet()
  22. with open(path, 'r', encoding="utf-8") as f:
  23. for line in f:
  24. line = line.strip()
  25. if not line:
  26. continue
  27. parts = line.split('\t')
  28. target = parts[0]
  29. words = parts[1].lower().split()
  30. dataset.append(Instance(words=words, target=target))
  31. if len(dataset)==0:
  32. raise RuntimeError(f"{path} has no valid data.")
  33. return dataset
  34. def process(self,
  35. paths: Union[str, Dict[str, str]],
  36. src_vocab_opt: VocabularyOption = None,
  37. tgt_vocab_opt: VocabularyOption = None,
  38. src_embed_opt: EmbeddingOption = None):
  39. paths = check_dataloader_paths(paths)
  40. datasets = {}
  41. info = DataBundle()
  42. for name, path in paths.items():
  43. dataset = self.load(path)
  44. datasets[name] = dataset
  45. src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
  46. src_vocab.from_dataset(datasets['train'], field_name='words')
  47. src_vocab.index_dataset(*datasets.values(), field_name='words')
  48. tgt_vocab = Vocabulary(unknown=None, padding=None) \
  49. if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
  50. tgt_vocab.from_dataset(datasets['train'], field_name='target')
  51. tgt_vocab.index_dataset(*datasets.values(), field_name='target')
  52. info.vocabs = {
  53. "words": src_vocab,
  54. "target": tgt_vocab
  55. }
  56. info.datasets = datasets
  57. if src_embed_opt is not None:
  58. embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
  59. info.embeddings['words'] = embed
  60. for name, dataset in info.datasets.items():
  61. dataset.set_input("words")
  62. dataset.set_target("target")
  63. return info