| @@ -6,91 +6,33 @@ from copy import deepcopy | |||||
| from fastNLP.core.field import TextField, LabelField | from fastNLP.core.field import TextField, LabelField | ||||
| from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
| from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
| from fastNLP.loader.dataset_loader import POSDataSetLoader, ClassDataSetLoader | |||||
| def create_dataset_from_lists(str_lists: list, word_vocab: dict, has_target: bool = False, label_vocab: dict = None): | |||||
| if has_target is True: | |||||
| if label_vocab is None: | |||||
| raise RuntimeError("Must provide label vocabulary to transform labels.") | |||||
| return create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab) | |||||
| else: | |||||
| return create_unlabeled_dataset_from_lists(str_lists, word_vocab) | |||||
| def create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab): | |||||
| """Create an DataSet instance that contains labels. | |||||
| :param str_lists: list of list of strings, [num_examples, 2, *]. | |||||
| :: | |||||
| [ | |||||
| [[word_11, word_12, ...], [label_11, label_12, ...]], | |||||
| ... | |||||
| ] | |||||
| :param word_vocab: dict of (str: int), which means (word: index). | |||||
| :param label_vocab: dict of (str: int), which means (word: index). | |||||
| :return data_set: a DataSet instance. | |||||
| """ | |||||
| data_set = DataSet() | |||||
| for example in str_lists: | |||||
| word_seq, label_seq = example[0], example[1] | |||||
| x = TextField(word_seq, is_target=False) | |||||
| y = TextField(label_seq, is_target=True) | |||||
| data_set.append(Instance(word_seq=x, label_seq=y)) | |||||
| data_set.index_field("word_seq", word_vocab) | |||||
| data_set.index_field("label_seq", label_vocab) | |||||
| return data_set | |||||
| def create_unlabeled_dataset_from_lists(str_lists, word_vocab): | |||||
| """Create an DataSet instance that contains no labels. | |||||
| :param str_lists: list of list of strings, [num_examples, *]. | |||||
| :: | |||||
| [ | |||||
| [word_11, word_12, ...], | |||||
| ... | |||||
| ] | |||||
| :param word_vocab: dict of (str: int), which means (word: index). | |||||
| :return data_set: a DataSet instance. | |||||
| """ | |||||
| data_set = DataSet() | |||||
| for word_seq in str_lists: | |||||
| x = TextField(word_seq, is_target=False) | |||||
| data_set.append(Instance(word_seq=x)) | |||||
| data_set.index_field("word_seq", word_vocab) | |||||
| return data_set | |||||
| class DataSet(list): | class DataSet(list): | ||||
| """A DataSet object is a list of Instance objects. | """A DataSet object is a list of Instance objects. | ||||
| """ | """ | ||||
| def __init__(self, name="", instances=None, load_func=None): | |||||
| def __init__(self, name="", instances=None): | |||||
| """ | """ | ||||
| :param name: str, the name of the dataset. (default: "") | :param name: str, the name of the dataset. (default: "") | ||||
| :param instances: list of Instance objects. (default: None) | :param instances: list of Instance objects. (default: None) | ||||
| :param load_func: a function that takes the dataset path (string) as input and returns multi-level lists. | |||||
| """ | """ | ||||
| list.__init__([]) | list.__init__([]) | ||||
| self.name = name | self.name = name | ||||
| self.origin_len = None | |||||
| if instances is not None: | if instances is not None: | ||||
| self.extend(instances) | self.extend(instances) | ||||
| self.data_set_load_func = load_func | |||||
| def index_all(self, vocab): | def index_all(self, vocab): | ||||
| for ins in self: | for ins in self: | ||||
| ins.index_all(vocab) | ins.index_all(vocab) | ||||
| return self | |||||
| def index_field(self, field_name, vocab): | def index_field(self, field_name, vocab): | ||||
| for ins in self: | for ins in self: | ||||
| ins.index_field(field_name, vocab) | ins.index_field(field_name, vocab) | ||||
| return self | |||||
| def to_tensor(self, idx: int, padding_length: dict): | def to_tensor(self, idx: int, padding_length: dict): | ||||
| """Convert an instance in a dataset to tensor. | """Convert an instance in a dataset to tensor. | ||||
| @@ -102,7 +44,7 @@ class DataSet(list): | |||||
| """ | """ | ||||
| ins = self[idx] | ins = self[idx] | ||||
| return ins.to_tensor(padding_length) | |||||
| return ins.to_tensor(padding_length, self.origin_len) | |||||
| def get_length(self): | def get_length(self): | ||||
| """Fetch lengths of all fields in all instances in a dataset. | """Fetch lengths of all fields in all instances in a dataset. | ||||
| @@ -117,42 +59,9 @@ class DataSet(list): | |||||
| lengths[field_name].append(field_length) | lengths[field_name].append(field_length) | ||||
| return lengths | return lengths | ||||
| def convert(self, data): | |||||
| """Convert lists of strings into Instances with Fields, creating Vocabulary for labeled data. Used in Training.""" | |||||
| raise NotImplementedError | |||||
| def convert_with_vocabs(self, data, vocabs): | |||||
| """Convert lists of strings into Instances with Fields, using existing Vocabulary, with labels. Used in Testing.""" | |||||
| raise NotImplementedError | |||||
| def convert_for_infer(self, data, vocabs): | |||||
| """Convert lists of strings into Instances with Fields, using existing Vocabulary, without labels. Used in predicting.""" | |||||
| def load(self, data_path, vocabs=None, infer=False): | |||||
| """Load data from the given files. | |||||
| :param data_path: str, the path to the data | |||||
| :param infer: bool. If True, there is no label information in the data. Default: False. | |||||
| :param vocabs: dict of (name: Vocabulary object), used to index data. If not provided, a new vocabulary will be constructed. | |||||
| """ | |||||
| raw_data = self.data_set_load_func(data_path) | |||||
| if infer is True: | |||||
| self.convert_for_infer(raw_data, vocabs) | |||||
| else: | |||||
| if vocabs is not None: | |||||
| self.convert_with_vocabs(raw_data, vocabs) | |||||
| else: | |||||
| self.convert(raw_data) | |||||
| def load_raw(self, raw_data, vocabs): | |||||
| """Load raw data without loader. Used in FastNLP class. | |||||
| :param raw_data: | |||||
| :param vocabs: | |||||
| :return: | |||||
| """ | |||||
| self.convert_for_infer(raw_data, vocabs) | |||||
| def shuffle(self): | |||||
| random.shuffle(self) | |||||
| return self | |||||
| def split(self, ratio, shuffle=True): | def split(self, ratio, shuffle=True): | ||||
| """Train/dev splitting | """Train/dev splitting | ||||
| @@ -165,7 +74,7 @@ class DataSet(list): | |||||
| """ | """ | ||||
| assert 0 < ratio < 1 | assert 0 < ratio < 1 | ||||
| if shuffle: | if shuffle: | ||||
| random.shuffle(self) | |||||
| self.shuffle() | |||||
| split_idx = int(len(self) * ratio) | split_idx = int(len(self) * ratio) | ||||
| dev_set = deepcopy(self) | dev_set = deepcopy(self) | ||||
| train_set = deepcopy(self) | train_set = deepcopy(self) | ||||
| @@ -173,134 +82,32 @@ class DataSet(list): | |||||
| del dev_set[split_idx:] | del dev_set[split_idx:] | ||||
| return train_set, dev_set | return train_set, dev_set | ||||
| class SeqLabelDataSet(DataSet): | |||||
| def __init__(self, instances=None, load_func=POSDataSetLoader().load): | |||||
| super(SeqLabelDataSet, self).__init__(name="", instances=instances, load_func=load_func) | |||||
| self.word_vocab = Vocabulary() | |||||
| self.label_vocab = Vocabulary() | |||||
| def convert(self, data): | |||||
| """Convert lists of strings into Instances with Fields. | |||||
| :param data: 3-level lists. Entries are strings. | |||||
| def rename_field(self, old_name, new_name): | |||||
| """rename a field | |||||
| """ | """ | ||||
| bar = ProgressBar(total=len(data)) | |||||
| for example in data: | |||||
| word_seq, label_seq = example[0], example[1] | |||||
| # list, list | |||||
| self.word_vocab.update(word_seq) | |||||
| self.label_vocab.update(label_seq) | |||||
| x = TextField(word_seq, is_target=False) | |||||
| x_len = LabelField(len(word_seq), is_target=False) | |||||
| y = TextField(label_seq, is_target=False) | |||||
| instance = Instance() | |||||
| instance.add_field("word_seq", x) | |||||
| instance.add_field("truth", y) | |||||
| instance.add_field("word_seq_origin_len", x_len) | |||||
| self.append(instance) | |||||
| bar.move() | |||||
| self.index_field("word_seq", self.word_vocab) | |||||
| self.index_field("truth", self.label_vocab) | |||||
| # no need to index "word_seq_origin_len" | |||||
| def convert_with_vocabs(self, data, vocabs): | |||||
| for example in data: | |||||
| word_seq, label_seq = example[0], example[1] | |||||
| # list, list | |||||
| x = TextField(word_seq, is_target=False) | |||||
| x_len = LabelField(len(word_seq), is_target=False) | |||||
| y = TextField(label_seq, is_target=False) | |||||
| instance = Instance() | |||||
| instance.add_field("word_seq", x) | |||||
| instance.add_field("truth", y) | |||||
| instance.add_field("word_seq_origin_len", x_len) | |||||
| self.append(instance) | |||||
| self.index_field("word_seq", vocabs["word_vocab"]) | |||||
| self.index_field("truth", vocabs["label_vocab"]) | |||||
| # no need to index "word_seq_origin_len" | |||||
| def convert_for_infer(self, data, vocabs): | |||||
| for word_seq in data: | |||||
| # list | |||||
| x = TextField(word_seq, is_target=False) | |||||
| x_len = LabelField(len(word_seq), is_target=False) | |||||
| instance = Instance() | |||||
| instance.add_field("word_seq", x) | |||||
| instance.add_field("word_seq_origin_len", x_len) | |||||
| self.append(instance) | |||||
| self.index_field("word_seq", vocabs["word_vocab"]) | |||||
| # no need to index "word_seq_origin_len" | |||||
| class TextClassifyDataSet(DataSet): | |||||
| def __init__(self, instances=None, load_func=ClassDataSetLoader().load): | |||||
| super(TextClassifyDataSet, self).__init__(name="", instances=instances, load_func=load_func) | |||||
| self.word_vocab = Vocabulary() | |||||
| self.label_vocab = Vocabulary(need_default=False) | |||||
| def convert(self, data): | |||||
| for example in data: | |||||
| word_seq, label = example[0], example[1] | |||||
| # list, str | |||||
| self.word_vocab.update(word_seq) | |||||
| self.label_vocab.update(label) | |||||
| x = TextField(word_seq, is_target=False) | |||||
| y = LabelField(label, is_target=True) | |||||
| instance = Instance() | |||||
| instance.add_field("word_seq", x) | |||||
| instance.add_field("label", y) | |||||
| self.append(instance) | |||||
| self.index_field("word_seq", self.word_vocab) | |||||
| self.index_field("label", self.label_vocab) | |||||
| def convert_with_vocabs(self, data, vocabs): | |||||
| for example in data: | |||||
| word_seq, label = example[0], example[1] | |||||
| # list, str | |||||
| x = TextField(word_seq, is_target=False) | |||||
| y = LabelField(label, is_target=True) | |||||
| instance = Instance() | |||||
| instance.add_field("word_seq", x) | |||||
| instance.add_field("label", y) | |||||
| self.append(instance) | |||||
| self.index_field("word_seq", vocabs["word_vocab"]) | |||||
| self.index_field("label", vocabs["label_vocab"]) | |||||
| def convert_for_infer(self, data, vocabs): | |||||
| for word_seq in data: | |||||
| # list | |||||
| x = TextField(word_seq, is_target=False) | |||||
| instance = Instance() | |||||
| instance.add_field("word_seq", x) | |||||
| self.append(instance) | |||||
| self.index_field("word_seq", vocabs["word_vocab"]) | |||||
| def change_field_is_target(data_set, field_name, new_target): | |||||
| """Change the flag of is_target in a field. | |||||
| :param data_set: a DataSet object | |||||
| :param field_name: str, the name of the field | |||||
| :param new_target: one of (True, False, None), representing this field is batch_x / is batch_y / neither. | |||||
| """ | |||||
| for inst in data_set: | |||||
| inst.fields[field_name].is_target = new_target | |||||
| class ProgressBar: | |||||
| for ins in self: | |||||
| ins.rename_field(old_name, new_name) | |||||
| return self | |||||
| def __init__(self, count=0, total=0, width=100): | |||||
| self.count = count | |||||
| self.total = total | |||||
| self.width = width | |||||
| def set_target(self, **fields): | |||||
| """Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged. | |||||
| def move(self): | |||||
| self.count += 1 | |||||
| progress = self.width * self.count // self.total | |||||
| sys.stdout.write('{0:3}/{1:3}: '.format(self.count, self.total)) | |||||
| sys.stdout.write('#' * progress + '-' * (self.width - progress) + '\r') | |||||
| if progress == self.width: | |||||
| sys.stdout.write('\n') | |||||
| sys.stdout.flush() | |||||
| :param key-value pairs for field-name and `is_target` value(True, False or None). | |||||
| """ | |||||
| for ins in self: | |||||
| ins.set_target(**fields) | |||||
| return self | |||||
| def update_vocab(self, **name_vocab): | |||||
| for field_name, vocab in name_vocab.items(): | |||||
| for ins in self: | |||||
| vocab.update(ins[field_name].contents()) | |||||
| return self | |||||
| def set_origin_len(self, origin_field, origin_len_name=None): | |||||
| if origin_field is None: | |||||
| self.origin_len = None | |||||
| else: | |||||
| self.origin_len = (origin_field + "_origin_len", origin_field) \ | |||||
| if origin_len_name is None else (origin_len_name, origin_field) | |||||
| return self | |||||
| @@ -18,6 +18,8 @@ class Field(object): | |||||
| def to_tensor(self, padding_length): | def to_tensor(self, padding_length): | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| def contents(self): | |||||
| raise NotImplementedError | |||||
| class TextField(Field): | class TextField(Field): | ||||
| def __init__(self, text, is_target): | def __init__(self, text, is_target): | ||||
| @@ -57,6 +59,8 @@ class TextField(Field): | |||||
| pads = [0] * (padding_length - self.get_length()) | pads = [0] * (padding_length - self.get_length()) | ||||
| return torch.LongTensor(self._index + pads) | return torch.LongTensor(self._index + pads) | ||||
| def contents(self): | |||||
| return self.text.copy() | |||||
| class LabelField(Field): | class LabelField(Field): | ||||
| """The Field representing a single label. Can be a string or integer. | """The Field representing a single label. Can be a string or integer. | ||||
| @@ -92,6 +96,8 @@ class LabelField(Field): | |||||
| else: | else: | ||||
| return torch.LongTensor([self._index]) | return torch.LongTensor([self._index]) | ||||
| def contents(self): | |||||
| return [self.label] | |||||
| class SeqLabelField(Field): | class SeqLabelField(Field): | ||||
| def __init__(self, label_seq, is_target=True): | def __init__(self, label_seq, is_target=True): | ||||
| @@ -122,6 +128,8 @@ class SeqLabelField(Field): | |||||
| else: | else: | ||||
| return torch.LongTensor(self._index + pads) | return torch.LongTensor(self._index + pads) | ||||
| def contents(self): | |||||
| return self.label_seq.copy() | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| tf = TextField("test the code".split(), is_target=False) | tf = TextField("test the code".split(), is_target=False) | ||||
| @@ -1,3 +1,5 @@ | |||||
| import torch | |||||
| class Instance(object): | class Instance(object): | ||||
| """An instance which consists of Fields is an example in the DataSet. | """An instance which consists of Fields is an example in the DataSet. | ||||
| @@ -10,6 +12,28 @@ class Instance(object): | |||||
| def add_field(self, field_name, field): | def add_field(self, field_name, field): | ||||
| self.fields[field_name] = field | self.fields[field_name] = field | ||||
| return self | |||||
| def rename_field(self, old_name, new_name): | |||||
| if old_name in self.fields: | |||||
| self.fields[new_name] = self.fields.pop(old_name) | |||||
| if old_name in self.indexes: | |||||
| self.indexes[new_name] = self.indexes.pop(old_name) | |||||
| else: | |||||
| print("error, no such field: {}".format(old_name)) | |||||
| return self | |||||
| def set_target(self, **fields): | |||||
| for name, val in fields.items(): | |||||
| if name in self.fields: | |||||
| self.fields[name].is_target = val | |||||
| return self | |||||
| def __getitem__(self, name): | |||||
| if name in self.fields: | |||||
| return self.fields[name] | |||||
| else: | |||||
| raise KeyError("{} not found".format(name)) | |||||
| def get_length(self): | def get_length(self): | ||||
| """Fetch the length of all fields in the instance. | """Fetch the length of all fields in the instance. | ||||
| @@ -24,6 +48,7 @@ class Instance(object): | |||||
| """use `vocab` to index certain field | """use `vocab` to index certain field | ||||
| """ | """ | ||||
| self.indexes[field_name] = self.fields[field_name].index(vocab) | self.indexes[field_name] = self.fields[field_name].index(vocab) | ||||
| return self | |||||
| def index_all(self, vocab): | def index_all(self, vocab): | ||||
| """use `vocab` to index all fields | """use `vocab` to index all fields | ||||
| @@ -35,7 +60,7 @@ class Instance(object): | |||||
| self.indexes = indexes | self.indexes = indexes | ||||
| return indexes | return indexes | ||||
| def to_tensor(self, padding_length: dict): | |||||
| def to_tensor(self, padding_length: dict, origin_len=None): | |||||
| """Convert instance to tensor. | """Convert instance to tensor. | ||||
| :param padding_length: dict of (str: int), which means (field name: padding_length of this field) | :param padding_length: dict of (str: int), which means (field name: padding_length of this field) | ||||
| @@ -53,4 +78,7 @@ class Instance(object): | |||||
| else: | else: | ||||
| # is_target is None | # is_target is None | ||||
| continue | continue | ||||
| if origin_len is not None: | |||||
| name, field_name = origin_len | |||||
| tensor_x[name] = torch.LongTensor([self.fields[field_name].get_length()]) | |||||
| return tensor_x, tensor_y | return tensor_x, tensor_y | ||||
| @@ -2,9 +2,9 @@ import numpy as np | |||||
| import torch | import torch | ||||
| from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
| from fastNLP.core.dataset import create_dataset_from_lists | |||||
| from fastNLP.core.preprocess import load_pickle | from fastNLP.core.preprocess import load_pickle | ||||
| from fastNLP.core.sampler import SequentialSampler | from fastNLP.core.sampler import SequentialSampler | ||||
| from fastNLP.loader.dataset_loader import convert_seq2seq_dataset, convert_seq2tag_dataset, convert_seq_dataset | |||||
| class Predictor(object): | class Predictor(object): | ||||
| @@ -79,7 +79,8 @@ class Predictor(object): | |||||
| :return data_set: a DataSet instance. | :return data_set: a DataSet instance. | ||||
| """ | """ | ||||
| assert isinstance(data, list) | assert isinstance(data, list) | ||||
| return create_dataset_from_lists(data, self.word_vocab, has_target=False) | |||||
| data = convert_seq_dataset(data) | |||||
| data.index_field("word_seq", self.word_vocab) | |||||
| class SeqLabelInfer(Predictor): | class SeqLabelInfer(Predictor): | ||||
| @@ -1,6 +1,7 @@ | |||||
| import os | import os | ||||
| from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet | |||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.loader.dataset_loader import convert_seq_dataset | |||||
| from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer | from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer | ||||
| from fastNLP.core.preprocess import load_pickle | from fastNLP.core.preprocess import load_pickle | ||||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
| @@ -178,13 +179,10 @@ class FastNLP(object): | |||||
| :param infer_input: 2-D lists of strings | :param infer_input: 2-D lists of strings | ||||
| :return data_set: a DataSet object | :return data_set: a DataSet object | ||||
| """ | """ | ||||
| if self.infer_type == "seq_label": | |||||
| data_set = SeqLabelDataSet() | |||||
| data_set.load_raw(infer_input, {"word_vocab": self.word_vocab}) | |||||
| return data_set | |||||
| elif self.infer_type == "text_class": | |||||
| data_set = TextClassifyDataSet() | |||||
| data_set.load_raw(infer_input, {"word_vocab": self.word_vocab}) | |||||
| if self.infer_type in ["seq_label", "text_class"]: | |||||
| data_set = convert_seq_dataset(infer_input) | |||||
| data_set.index_field("word_seq", self.word_vocab) | |||||
| data_set.set_origin_len("word_seq") | |||||
| return data_set | return data_set | ||||
| else: | else: | ||||
| raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type)) | raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type)) | ||||
| @@ -1,6 +1,71 @@ | |||||
| import os | import os | ||||
| from fastNLP.loader.base_loader import BaseLoader | from fastNLP.loader.base_loader import BaseLoader | ||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.instance import Instance | |||||
| from fastNLP.core.field import * | |||||
| def convert_seq_dataset(data): | |||||
| """Create an DataSet instance that contains no labels. | |||||
| :param data: list of list of strings, [num_examples, *]. | |||||
| :: | |||||
| [ | |||||
| [word_11, word_12, ...], | |||||
| ... | |||||
| ] | |||||
| :return: a DataSet. | |||||
| """ | |||||
| dataset = DataSet() | |||||
| for word_seq in data: | |||||
| x = TextField(word_seq, is_target=False) | |||||
| dataset.append(Instance(word_seq=x)) | |||||
| return dataset | |||||
| def convert_seq2tag_dataset(data): | |||||
| """Convert list of data into DataSet | |||||
| :param data: list of list of strings, [num_examples, *]. | |||||
| :: | |||||
| [ | |||||
| [ [word_11, word_12, ...], label_1 ], | |||||
| [ [word_21, word_22, ...], label_2 ], | |||||
| ... | |||||
| ] | |||||
| :return: a DataSet. | |||||
| """ | |||||
| dataset = DataSet() | |||||
| for sample in data: | |||||
| word_seq, label = sample[0], sample[1] | |||||
| ins = Instance() | |||||
| ins.add_field("word_seq", TextField(word_seq, is_target=False)) \ | |||||
| .add_field("label", LabelField(label, is_target=True)) | |||||
| dataset.append(ins) | |||||
| return dataset | |||||
| def convert_seq2seq_dataset(data): | |||||
| """Convert list of data into DataSet | |||||
| :param data: list of list of strings, [num_examples, *]. | |||||
| :: | |||||
| [ | |||||
| [ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||||
| [ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||||
| ... | |||||
| ] | |||||
| :return: a DataSet. | |||||
| """ | |||||
| dataset = DataSet() | |||||
| for sample in data: | |||||
| word_seq, label_seq = sample[0], sample[1] | |||||
| ins = Instance() | |||||
| ins.add_field("word_seq", TextField(word_seq, is_target=False)) \ | |||||
| .add_field("label_seq", TextField(label_seq, is_target=True)) | |||||
| dataset.append(ins) | |||||
| return dataset | |||||
| class DataSetLoader(BaseLoader): | class DataSetLoader(BaseLoader): | ||||
| @@ -48,7 +113,8 @@ class POSDataSetLoader(DataSetLoader): | |||||
| """ | """ | ||||
| with open(data_path, "r", encoding="utf-8") as f: | with open(data_path, "r", encoding="utf-8") as f: | ||||
| lines = f.readlines() | lines = f.readlines() | ||||
| return self.parse(lines) | |||||
| data = self.parse(lines) | |||||
| return self.convert(data) | |||||
| @staticmethod | @staticmethod | ||||
| def parse(lines): | def parse(lines): | ||||
| @@ -75,6 +141,10 @@ class POSDataSetLoader(DataSetLoader): | |||||
| data.append([words, labels]) | data.append([words, labels]) | ||||
| return data | return data | ||||
| def convert(self, data): | |||||
| """Convert lists of strings into Instances with Fields. | |||||
| """ | |||||
| return convert_seq2seq_dataset(data) | |||||
| class TokenizeDataSetLoader(DataSetLoader): | class TokenizeDataSetLoader(DataSetLoader): | ||||
| """ | """ | ||||
| @@ -84,8 +154,7 @@ class TokenizeDataSetLoader(DataSetLoader): | |||||
| def __init__(self): | def __init__(self): | ||||
| super(TokenizeDataSetLoader, self).__init__() | super(TokenizeDataSetLoader, self).__init__() | ||||
| @staticmethod | |||||
| def load(data_path, max_seq_len=32): | |||||
| def load(self, data_path, max_seq_len=32): | |||||
| """ | """ | ||||
| load pku dataset for Chinese word segmentation | load pku dataset for Chinese word segmentation | ||||
| CWS (Chinese Word Segmentation) pku training dataset format: | CWS (Chinese Word Segmentation) pku training dataset format: | ||||
| @@ -130,7 +199,10 @@ class TokenizeDataSetLoader(DataSetLoader): | |||||
| seq_words = words[start:end] | seq_words = words[start:end] | ||||
| seq_labels = labels[start:end] | seq_labels = labels[start:end] | ||||
| data.append([seq_words, seq_labels]) | data.append([seq_words, seq_labels]) | ||||
| return data | |||||
| return self.convert(data) | |||||
| def convert(self, data): | |||||
| return convert_seq2seq_dataset(data) | |||||
| class ClassDataSetLoader(DataSetLoader): | class ClassDataSetLoader(DataSetLoader): | ||||
| @@ -143,7 +215,8 @@ class ClassDataSetLoader(DataSetLoader): | |||||
| assert os.path.exists(data_path) | assert os.path.exists(data_path) | ||||
| with open(data_path, "r", encoding="utf-8") as f: | with open(data_path, "r", encoding="utf-8") as f: | ||||
| lines = f.readlines() | lines = f.readlines() | ||||
| return self.parse(lines) | |||||
| data = self.parse(lines) | |||||
| return self.convert(data) | |||||
| @staticmethod | @staticmethod | ||||
| def parse(lines): | def parse(lines): | ||||
| @@ -166,16 +239,18 @@ class ClassDataSetLoader(DataSetLoader): | |||||
| dataset.append(sentence) | dataset.append(sentence) | ||||
| return dataset | return dataset | ||||
| def convert(self, data): | |||||
| return convert_seq2tag_dataset(data) | |||||
| class ConllLoader(DataSetLoader): | class ConllLoader(DataSetLoader): | ||||
| """loader for conll format files""" | """loader for conll format files""" | ||||
| def __int__(self, data_path): | |||||
| def __init__(self): | |||||
| """ | """ | ||||
| :param str data_path: the path to the conll data set | :param str data_path: the path to the conll data set | ||||
| """ | """ | ||||
| super(ConllLoader, self).__init__() | super(ConllLoader, self).__init__() | ||||
| self.data_set = self.parse(self.load(data_path)) | |||||
| def load(self, data_path): | def load(self, data_path): | ||||
| """ | """ | ||||
| @@ -183,7 +258,8 @@ class ConllLoader(DataSetLoader): | |||||
| """ | """ | ||||
| with open(data_path, "r", encoding="utf-8") as f: | with open(data_path, "r", encoding="utf-8") as f: | ||||
| lines = f.readlines() | lines = f.readlines() | ||||
| return lines | |||||
| data = self.parse(lines) | |||||
| return self.convert(data) | |||||
| @staticmethod | @staticmethod | ||||
| def parse(lines): | def parse(lines): | ||||
| @@ -204,6 +280,9 @@ class ConllLoader(DataSetLoader): | |||||
| tokens.append(line.split()) | tokens.append(line.split()) | ||||
| return sentences | return sentences | ||||
| def convert(self, data): | |||||
| pass | |||||
| class LMDataSetLoader(DataSetLoader): | class LMDataSetLoader(DataSetLoader): | ||||
| """Language Model Dataset Loader | """Language Model Dataset Loader | ||||
| @@ -222,7 +301,8 @@ class LMDataSetLoader(DataSetLoader): | |||||
| with open(data_path, "r", encoding="utf=8") as f: | with open(data_path, "r", encoding="utf=8") as f: | ||||
| text = " ".join(f.readlines()) | text = " ".join(f.readlines()) | ||||
| tokens = text.strip().split() | tokens = text.strip().split() | ||||
| return self.sentence_cut(tokens) | |||||
| data = self.sentence_cut(tokens) | |||||
| return self.convert(data) | |||||
| def sentence_cut(self, tokens, sentence_length=15): | def sentence_cut(self, tokens, sentence_length=15): | ||||
| start_idx = 0 | start_idx = 0 | ||||
| @@ -236,6 +316,8 @@ class LMDataSetLoader(DataSetLoader): | |||||
| data_set.append([x, y]) | data_set.append([x, y]) | ||||
| return data_set | return data_set | ||||
| def convert(self, data): | |||||
| pass | |||||
| class PeopleDailyCorpusLoader(DataSetLoader): | class PeopleDailyCorpusLoader(DataSetLoader): | ||||
| """ | """ | ||||
| @@ -286,3 +368,5 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||||
| ner_examples.append([sent_words, sent_ner]) | ner_examples.append([sent_words, sent_ner]) | ||||
| return pos_tag_examples, ner_examples | return pos_tag_examples, ner_examples | ||||
| def convert(self, data): | |||||
| pass | |||||
| @@ -12,7 +12,7 @@ from fastNLP.loader.model_loader import ModelLoader | |||||
| from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
| from fastNLP.models.sequence_modeling import AdvSeqLabel | from fastNLP.models.sequence_modeling import AdvSeqLabel | ||||
| from fastNLP.core.predictor import SeqLabelInfer | from fastNLP.core.predictor import SeqLabelInfer | ||||
| from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target | |||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.preprocess import save_pickle | from fastNLP.core.preprocess import save_pickle | ||||
| from fastNLP.core.metrics import SeqLabelEvaluator | from fastNLP.core.metrics import SeqLabelEvaluator | ||||
| @@ -3,7 +3,7 @@ import unittest | |||||
| import torch | import torch | ||||
| from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
| from fastNLP.core.dataset import DataSet, create_dataset_from_lists | |||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.field import TextField, LabelField | from fastNLP.core.field import TextField, LabelField | ||||
| from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
| @@ -51,14 +51,3 @@ class TestCase1(unittest.TestCase): | |||||
| self.assertTrue(isinstance(batch_x["text"], torch.LongTensor)) | self.assertTrue(isinstance(batch_x["text"], torch.LongTensor)) | ||||
| self.assertTrue(isinstance(batch_y, dict)) | self.assertTrue(isinstance(batch_y, dict)) | ||||
| self.assertTrue(isinstance(batch_y["label"], torch.LongTensor)) | self.assertTrue(isinstance(batch_y["label"], torch.LongTensor)) | ||||
| class TestCase2(unittest.TestCase): | |||||
| def test(self): | |||||
| data = DataSet() | |||||
| for text in texts: | |||||
| x = TextField(text, is_target=False) | |||||
| ins = Instance(text=x) | |||||
| data.append(ins) | |||||
| data_set = create_dataset_from_lists(texts, vocab, has_target=False) | |||||
| self.assertTrue(type(data) == type(data_set)) | |||||
| @@ -1,7 +1,6 @@ | |||||
| import unittest | import unittest | ||||
| from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet | |||||
| from fastNLP.core.dataset import create_dataset_from_lists | |||||
| from fastNLP.loader.dataset_loader import convert_seq2seq_dataset, convert_seq_dataset | |||||
| class TestDataSet(unittest.TestCase): | class TestDataSet(unittest.TestCase): | ||||
| @@ -19,8 +18,9 @@ class TestDataSet(unittest.TestCase): | |||||
| label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4} | label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4} | ||||
| def test_case_1(self): | def test_case_1(self): | ||||
| data_set = create_dataset_from_lists(self.labeled_data_list, self.word_vocab, has_target=True, | |||||
| label_vocab=self.label_vocab) | |||||
| data_set = convert_seq2seq_dataset(self.labeled_data_list) | |||||
| data_set.index_field("word_seq", self.word_vocab) | |||||
| data_set.index_field("label_seq", self.label_vocab) | |||||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | self.assertEqual(len(data_set), len(self.labeled_data_list)) | ||||
| self.assertTrue(len(data_set) > 0) | self.assertTrue(len(data_set) > 0) | ||||
| self.assertTrue(hasattr(data_set[0], "fields")) | self.assertTrue(hasattr(data_set[0], "fields")) | ||||
| @@ -39,7 +39,8 @@ class TestDataSet(unittest.TestCase): | |||||
| [self.label_vocab[c] for c in self.labeled_data_list[0][1]]) | [self.label_vocab[c] for c in self.labeled_data_list[0][1]]) | ||||
| def test_case_2(self): | def test_case_2(self): | ||||
| data_set = create_dataset_from_lists(self.unlabeled_data_list, self.word_vocab, has_target=False) | |||||
| data_set = convert_seq_dataset(self.unlabeled_data_list) | |||||
| data_set.index_field("word_seq", self.word_vocab) | |||||
| self.assertEqual(len(data_set), len(self.unlabeled_data_list)) | self.assertEqual(len(data_set), len(self.unlabeled_data_list)) | ||||
| self.assertTrue(len(data_set) > 0) | self.assertTrue(len(data_set) > 0) | ||||
| @@ -51,193 +52,3 @@ class TestDataSet(unittest.TestCase): | |||||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | self.assertEqual(data_set[0].fields["word_seq"]._index, | ||||
| [self.word_vocab[c] for c in self.unlabeled_data_list[0]]) | [self.word_vocab[c] for c in self.unlabeled_data_list[0]]) | ||||
| class TestDataSetConvertion(unittest.TestCase): | |||||
| labeled_data_list = [ | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| ] | |||||
| unlabeled_data_list = [ | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"] | |||||
| ] | |||||
| word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3} | |||||
| label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4} | |||||
| def test_case_1(self): | |||||
| def loader(path): | |||||
| labeled_data_list = [ | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| ] | |||||
| return labeled_data_list | |||||
| data_set = SeqLabelDataSet(load_func=loader) | |||||
| data_set.load("any_path") | |||||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||||
| self.assertTrue(len(data_set) > 0) | |||||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||||
| self.assertTrue("word_seq" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||||
| self.assertTrue("truth" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["truth"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["truth"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1]) | |||||
| self.assertTrue("word_seq_origin_len" in data_set[0].fields) | |||||
| def test_case_2(self): | |||||
| def loader(path): | |||||
| unlabeled_data_list = [ | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"] | |||||
| ] | |||||
| return unlabeled_data_list | |||||
| data_set = SeqLabelDataSet(load_func=loader) | |||||
| data_set.load("any_path", vocabs={"word_vocab": self.word_vocab}, infer=True) | |||||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||||
| self.assertTrue(len(data_set) > 0) | |||||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||||
| self.assertTrue("word_seq" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||||
| [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||||
| self.assertTrue("word_seq_origin_len" in data_set[0].fields) | |||||
| def test_case_3(self): | |||||
| def loader(path): | |||||
| labeled_data_list = [ | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| ] | |||||
| return labeled_data_list | |||||
| data_set = SeqLabelDataSet(load_func=loader) | |||||
| data_set.load("any_path", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab}) | |||||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||||
| self.assertTrue(len(data_set) > 0) | |||||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||||
| self.assertTrue("word_seq" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||||
| [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||||
| self.assertTrue("truth" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["truth"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["truth"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1]) | |||||
| self.assertEqual(data_set[0].fields["truth"]._index, | |||||
| [self.label_vocab[c] for c in self.labeled_data_list[0][1]]) | |||||
| self.assertTrue("word_seq_origin_len" in data_set[0].fields) | |||||
| class TestDataSetConvertionHHH(unittest.TestCase): | |||||
| labeled_data_list = [ | |||||
| [["a", "b", "e", "d"], "A"], | |||||
| [["a", "b", "e", "d"], "C"], | |||||
| [["a", "b", "e", "d"], "B"], | |||||
| ] | |||||
| unlabeled_data_list = [ | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"] | |||||
| ] | |||||
| word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3} | |||||
| label_vocab = {"A": 1, "B": 2, "C": 3} | |||||
| def test_case_1(self): | |||||
| def loader(path): | |||||
| labeled_data_list = [ | |||||
| [["a", "b", "e", "d"], "A"], | |||||
| [["a", "b", "e", "d"], "C"], | |||||
| [["a", "b", "e", "d"], "B"], | |||||
| ] | |||||
| return labeled_data_list | |||||
| data_set = TextClassifyDataSet(load_func=loader) | |||||
| data_set.load("xxx") | |||||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||||
| self.assertTrue(len(data_set) > 0) | |||||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||||
| self.assertTrue("word_seq" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||||
| self.assertTrue("label" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["label"], "label")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["label"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1]) | |||||
| def test_case_2(self): | |||||
| def loader(path): | |||||
| labeled_data_list = [ | |||||
| [["a", "b", "e", "d"], "A"], | |||||
| [["a", "b", "e", "d"], "C"], | |||||
| [["a", "b", "e", "d"], "B"], | |||||
| ] | |||||
| return labeled_data_list | |||||
| data_set = TextClassifyDataSet(load_func=loader) | |||||
| data_set.load("xxx", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab}) | |||||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||||
| self.assertTrue(len(data_set) > 0) | |||||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||||
| self.assertTrue("word_seq" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||||
| [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||||
| self.assertTrue("label" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["label"], "label")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["label"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1]) | |||||
| self.assertEqual(data_set[0].fields["label"]._index, self.label_vocab[self.labeled_data_list[0][1]]) | |||||
| def test_case_3(self): | |||||
| def loader(path): | |||||
| unlabeled_data_list = [ | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"] | |||||
| ] | |||||
| return unlabeled_data_list | |||||
| data_set = TextClassifyDataSet(load_func=loader) | |||||
| data_set.load("xxx", vocabs={"word_vocab": self.word_vocab}, infer=True) | |||||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||||
| self.assertTrue(len(data_set) > 0) | |||||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||||
| self.assertTrue("word_seq" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||||
| [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||||
| @@ -1,11 +1,12 @@ | |||||
| import os | import os | ||||
| import unittest | import unittest | ||||
| from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet | |||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.predictor import Predictor | from fastNLP.core.predictor import Predictor | ||||
| from fastNLP.core.preprocess import save_pickle | from fastNLP.core.preprocess import save_pickle | ||||
| from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
| from fastNLP.loader.base_loader import BaseLoader | from fastNLP.loader.base_loader import BaseLoader | ||||
| from fastNLP.loader.dataset_loader import convert_seq_dataset | |||||
| from fastNLP.models.cnn_text_classification import CNNText | from fastNLP.models.cnn_text_classification import CNNText | ||||
| from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
| @@ -42,8 +43,8 @@ class TestPredictor(unittest.TestCase): | |||||
| predictor = Predictor("./save/", pre.text_classify_post_processor) | predictor = Predictor("./save/", pre.text_classify_post_processor) | ||||
| # Load infer data | # Load infer data | ||||
| infer_data_set = TextClassifyDataSet(load_func=BaseLoader.load) | |||||
| infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx}) | |||||
| infer_data_set = convert_seq_dataset(infer_data) | |||||
| infer_data_set.index_field("word_seq", vocab) | |||||
| results = predictor.predict(network=model, data=infer_data_set) | results = predictor.predict(network=model, data=infer_data_set) | ||||
| @@ -54,14 +55,11 @@ class TestPredictor(unittest.TestCase): | |||||
| self.assertTrue(isinstance(res, str)) | self.assertTrue(isinstance(res, str)) | ||||
| self.assertTrue(res in class_vocab.word2idx) | self.assertTrue(res in class_vocab.word2idx) | ||||
| del model, predictor, infer_data_set | |||||
| del model, predictor | |||||
| model = SeqLabeling(model_args) | model = SeqLabeling(model_args) | ||||
| predictor = Predictor("./save/", pre.seq_label_post_processor) | predictor = Predictor("./save/", pre.seq_label_post_processor) | ||||
| infer_data_set = SeqLabelDataSet(load_func=BaseLoader.load) | |||||
| infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx}) | |||||
| results = predictor.predict(network=model, data=infer_data_set) | results = predictor.predict(network=model, data=infer_data_set) | ||||
| self.assertTrue(isinstance(results, list)) | self.assertTrue(isinstance(results, list)) | ||||
| self.assertEqual(len(results), len(infer_data)) | self.assertEqual(len(results), len(infer_data)) | ||||
| @@ -1,7 +1,7 @@ | |||||
| import os | import os | ||||
| import unittest | import unittest | ||||
| from fastNLP.core.dataset import SeqLabelDataSet | |||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | from fastNLP.core.metrics import SeqLabelEvaluator | ||||
| from fastNLP.core.field import TextField, LabelField | from fastNLP.core.field import TextField, LabelField | ||||
| from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
| @@ -35,7 +35,7 @@ class TestTester(unittest.TestCase): | |||||
| vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | ||||
| label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | ||||
| data_set = SeqLabelDataSet() | |||||
| data_set = DataSet() | |||||
| for example in train_data: | for example in train_data: | ||||
| text, label = example[0], example[1] | text, label = example[0], example[1] | ||||
| x = TextField(text, False) | x = TextField(text, False) | ||||
| @@ -1,7 +1,7 @@ | |||||
| import os | import os | ||||
| import unittest | import unittest | ||||
| from fastNLP.core.dataset import SeqLabelDataSet | |||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | from fastNLP.core.metrics import SeqLabelEvaluator | ||||
| from fastNLP.core.field import TextField, LabelField | from fastNLP.core.field import TextField, LabelField | ||||
| from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
| @@ -36,7 +36,7 @@ class TestTrainer(unittest.TestCase): | |||||
| vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | ||||
| label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | ||||
| data_set = SeqLabelDataSet() | |||||
| data_set = DataSet() | |||||
| for example in train_data: | for example in train_data: | ||||
| text, label = example[0], example[1] | text, label = example[0], example[1] | ||||
| x = TextField(text, False) | x = TextField(text, False) | ||||
| @@ -1,6 +1,7 @@ | |||||
| import os | import os | ||||
| from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target | |||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.vocabulary import Vocabulary | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | from fastNLP.core.metrics import SeqLabelEvaluator | ||||
| from fastNLP.core.predictor import SeqLabelInfer | from fastNLP.core.predictor import SeqLabelInfer | ||||
| from fastNLP.core.preprocess import save_pickle, load_pickle | from fastNLP.core.preprocess import save_pickle, load_pickle | ||||
| @@ -37,8 +38,8 @@ def infer(): | |||||
| print("model loaded!") | print("model loaded!") | ||||
| # Load infer data | # Load infer data | ||||
| infer_data = SeqLabelDataSet(load_func=BaseLoader.load) | |||||
| infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True) | |||||
| infer_data = TokenizeDataSetLoader().load(data_infer_path) | |||||
| infer_data.index_field("word_seq", word2index) | |||||
| # inference | # inference | ||||
| infer = SeqLabelInfer(pickle_path) | infer = SeqLabelInfer(pickle_path) | ||||
| @@ -52,13 +53,15 @@ def train_test(): | |||||
| ConfigLoader().load_config(config_path, {"POS_infer": train_args}) | ConfigLoader().load_config(config_path, {"POS_infer": train_args}) | ||||
| # define dataset | # define dataset | ||||
| data_train = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load) | |||||
| data_train.load(cws_data_path) | |||||
| train_args["vocab_size"] = len(data_train.word_vocab) | |||||
| train_args["num_classes"] = len(data_train.label_vocab) | |||||
| data_train = TokenizeDataSetLoader().load(cws_data_path) | |||||
| word_vocab = Vocabulary() | |||||
| label_vocab = Vocabulary() | |||||
| data_train.update_vocab(word_seq=word_vocab, label_seq=label_vocab) | |||||
| train_args["vocab_size"] = len(word_vocab) | |||||
| train_args["num_classes"] = len(label_vocab) | |||||
| save_pickle(data_train.word_vocab, pickle_path, "word2id.pkl") | |||||
| save_pickle(data_train.label_vocab, pickle_path, "label2id.pkl") | |||||
| save_pickle(word_vocab, pickle_path, "word2id.pkl") | |||||
| save_pickle(label_vocab, pickle_path, "label2id.pkl") | |||||
| # Trainer | # Trainer | ||||
| trainer = SeqLabelTrainer(**train_args.data) | trainer = SeqLabelTrainer(**train_args.data) | ||||
| @@ -90,7 +93,7 @@ def train_test(): | |||||
| tester = SeqLabelTester(**test_args.data) | tester = SeqLabelTester(**test_args.data) | ||||
| # Start testing | # Start testing | ||||
| change_field_is_target(data_train, "truth", True) | |||||
| data_train.set_target(truth=True) | |||||
| tester.test(model, data_train) | tester.test(model, data_train) | ||||
| @@ -1,6 +1,6 @@ | |||||
| import os | import os | ||||
| from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target | |||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | from fastNLP.core.metrics import SeqLabelEvaluator | ||||
| from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
| from fastNLP.core.preprocess import save_pickle | from fastNLP.core.preprocess import save_pickle | ||||
| @@ -25,8 +25,8 @@ def test_training(): | |||||
| ConfigLoader().load_config(config_dir, { | ConfigLoader().load_config(config_dir, { | ||||
| "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | ||||
| data_set = SeqLabelDataSet() | |||||
| data_set.load(data_path) | |||||
| data_set = DataSet() | |||||
| word_vocab = V | |||||
| data_train, data_dev = data_set.split(0.3, shuffle=True) | data_train, data_dev = data_set.split(0.3, shuffle=True) | ||||
| model_args["vocab_size"] = len(data_set.word_vocab) | model_args["vocab_size"] = len(data_set.word_vocab) | ||||
| model_args["num_classes"] = len(data_set.label_vocab) | model_args["num_classes"] = len(data_set.label_vocab) | ||||
| @@ -76,5 +76,5 @@ def test_training(): | |||||
| ) | ) | ||||
| # Start testing with validation data | # Start testing with validation data | ||||
| change_field_is_target(data_dev, "truth", True) | |||||
| data_dev.set_target(truth=True) | |||||
| tester.test(model, data_dev) | tester.test(model, data_dev) | ||||