from typing import Union, Dict

from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import Vocabulary
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataBundle
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
from reproduction.utils import check_dataloader_paths, get_tokenizer


class IMDBLoader(DataSetLoader):
    """
    Reads the IMDB dataset. The resulting DataSet contains the following fields:

        words: list(str), the text to classify
        target: str, the label of the text
    """

    def __init__(self):
        super(IMDBLoader, self).__init__()
        self.tokenizer = get_tokenizer()

    def _load(self, path):
        # Each line is expected to be "<label>\t<text>".
        dataset = DataSet()
        with open(path, 'r', encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split('\t')
                target = parts[0]
                words = self.tokenizer(parts[1].lower())
                dataset.append(Instance(words=words, target=target))
        if len(dataset) == 0:
            raise RuntimeError(f"{path} has no valid data.")
        return dataset

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,
                src_embed_opt: EmbeddingOption = None,
                char_level_op=False):
        datasets = {}
        info = DataBundle()
        for name, path in paths.items():
            dataset = self.load(path)
            datasets[name] = dataset

        def wordtochar(words):
            # Flatten a list of words into a list of characters, inserting an
            # empty-string separator between words (the trailing one is popped).
            chars = []
            for word in words:
                word = word.lower()
                for char in word:
                    chars.append(char)
                chars.append('')
            chars.pop()
            return chars

        if char_level_op:
            for dataset in datasets.values():
                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')

        # Hold out part of the training data as a dev split.
        datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False)

        # Build the source (word) vocabulary from the training set and index all splits.
        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name='words')
        src_vocab.index_dataset(*datasets.values(), field_name='words')

        # Build the target (label) vocabulary; labels need no unknown/padding tokens.
        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name='target')
        tgt_vocab.index_dataset(*datasets.values(), field_name='target')

        info.vocabs = {
            "words": src_vocab,
            "target": tgt_vocab
        }

        info.datasets = datasets

        if src_embed_opt is not None:
            embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
            info.embeddings['words'] = embed

        for name, dataset in info.datasets.items():
            dataset.set_input("words")
            dataset.set_target("target")

        return info


if __name__ == "__main__":
    datapath = {"train": "/remote-home/ygwang/IMDB_data/train.csv",
                "test": "/remote-home/ygwang/IMDB_data/test.csv"}
    datainfo = IMDBLoader().process(datapath, char_level_op=True)
    # print(datainfo.datasets["train"])

    # Report the average character-level length of training instances.
    len_count = 0
    for instance in datainfo.datasets["train"]:
        len_count += len(instance["chars"])
    ave_len = len_count / len(datainfo.datasets["train"])
    print(ave_len)