from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataBundle
from typing import Union, Dict, List, Iterator
from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import Vocabulary
from fastNLP import Const
from reproduction.utils import check_dataloader_paths
from functools import partial


class MTL16Loader(DataSetLoader):
    """
    Reads the MTL16 dataset. The resulting DataSet contains the following fields:

        words: list(str), the text to classify
        target: str, the label of the text

    Data source: https://pan.baidu.com/s/1c2L6vdA
    """

    def __init__(self):
        super(MTL16Loader, self).__init__()

    def _load(self, path):
        dataset = DataSet()
        with open(path, 'r', encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # Each line is "<label>\t<text>"; the text is lowercased and
                # whitespace-tokenized.
                parts = line.split('\t')
                target = parts[0]
                words = parts[1].lower().split()
                dataset.append(Instance(words=words, target=target))
        if len(dataset) == 0:
            raise RuntimeError(f"{path} has no valid data.")
        return dataset

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,
                src_embed_opt: EmbeddingOption = None):
        paths = check_dataloader_paths(paths)
        datasets = {}
        info = DataBundle()
        for name, path in paths.items():
            dataset = self.load(path)
            datasets[name] = dataset

        # Build the source vocabulary on the training split only, then index
        # every split with it.
        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name='words')
        src_vocab.index_dataset(*datasets.values(), field_name='words')

        # The label vocabulary needs no unknown/padding tokens: every label
        # seen at prediction time must already appear in the training split.
        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name='target')
        tgt_vocab.index_dataset(*datasets.values(), field_name='target')

        info.vocabs = {
            "words": src_vocab,
            "target": tgt_vocab
        }

        info.datasets = datasets

        if src_embed_opt is not None:
            embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
            info.embeddings['words'] = embed

        for name, dataset in info.datasets.items():
            dataset.set_input("words")
            dataset.set_target("target")

        return info
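

# A minimal usage sketch, not part of the loader itself: it assumes an
# MTL16-style task directory containing tab-separated "<label>\t<text>" files.
# The directory layout and file names below are hypothetical examples.
if __name__ == '__main__':
    loader = MTL16Loader()
    data_info = loader.process({
        'train': 'data/mtl16/apparel/train.txt',  # hypothetical paths
        'dev': 'data/mtl16/apparel/dev.txt',
        'test': 'data/mtl16/apparel/test.txt',
    })
    # Inspect one indexed instance and the vocabulary sizes.
    print(data_info.datasets['train'][0])
    print(len(data_info.vocabs['words']), len(data_info.vocabs['target']))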