2. Add from_dataset() and index_dataset() to Vocabulary, so indexing a dataset no longer takes several lines of boilerplate.
3. Add a cache_results() decorator in utils for caching a function's return value.
4. Add an update_every property to Callback.
(tag: v0.4.10)
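For orientation, a minimal usage sketch that combines the additions (the cache path, data path, and field name are placeholders; the top-level imports assume the re-exports added to fastNLP/core/__init__.py below):

    from fastNLP import DataSet, Instance, cache_results
    from fastNLP.core.vocabulary import Vocabulary

    @cache_results('caches/preprocess.pkl')  # hypothetical cache file, created on the first run
    def preprocess(train_path):
        ds = DataSet()
        with open(train_path, 'r', encoding='utf-8') as f:
            for line in f:
                words = line.strip().split()
                if words:
                    ds.append(Instance(words=words))
        vocab = Vocabulary()
        vocab.from_dataset(ds, field_name='words')   # build the vocabulary in one call
        vocab.index_dataset(ds, field_name='words')  # convert the words to indices in place
        return ds, vocab

    ds, vocab = preprocess('data/train.txt')  # a second call with the same cache file reads the pickle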
@@ -1,5 +1,5 @@
from .batch import Batch
# from .dataset import DataSet
from .dataset import DataSet
from .fieldarray import FieldArray
from .instance import Instance
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
@@ -9,5 +9,5 @@ from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSample
from .tester import Tester
from .trainer import Trainer
from .vocabulary import Vocabulary
from ..io.dataset_loader import DataSet
from .callback import Callback
from .utils import cache_results
@@ -61,6 +61,10 @@ class Callback(object):
        """If use_tqdm, return trainer's tqdm print bar, else return None."""
        return self._trainer.pbar

    @property
    def update_every(self):
        """The trainer updates the model parameters once every `update_every` batches."""
        return self._trainer.update_every

    def on_train_begin(self):
        # before the main training loop
        pass
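As a hedged illustration of the new property (not part of the diff), a callback can read the accumulation setting from the trainer; only the `on_train_begin` hook shown above is relied on, and the top-level `Callback` import assumes the usual re-export:

    from fastNLP import Callback

    class UpdateEveryReporter(Callback):
        """Minimal sketch: report how often the trainer updates parameters."""

        def on_train_begin(self):
            # `update_every` is forwarded from the Trainer: parameters are updated
            # once every `update_every` batches (gradient accumulation).
            print("Parameters will be updated every {} batch(es).".format(self.update_every))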
@@ -6,7 +6,6 @@ from fastNLP.core.fieldarray import AutoPadder
from fastNLP.core.fieldarray import FieldArray
from fastNLP.core.instance import Instance
from fastNLP.core.utils import get_func_signature
from fastNLP.io.base_loader import DataLoaderRegister


class DataSet(object):
@@ -105,11 +104,6 @@ class DataSet(object):
            raise AttributeError
        if isinstance(item, str) and item in self.field_arrays:
            return self.field_arrays[item]
        try:
            reader = DataLoaderRegister.get_reader(item)
            return reader
        except AttributeError:
            raise

    def __setstate__(self, state):
        self.__dict__ = state
@@ -369,7 +363,7 @@ class DataSet(object):
        :return dataset: the read data set
        """
        with open(csv_path, "r") as f:
        with open(csv_path, "r", encoding='utf-8') as f:
            start_idx = 0
            if headers is None:
                headers = f.readline().rstrip('\r\n')
@@ -11,6 +11,64 @@ import torch
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                   'varargs'])


def _prepare_cache_filepath(filepath):
    """
    Check whether `filepath` can be used as a cache file. If the parent directory does not exist,
    it is created automatically.

    :param filepath: str.
    :return: None; raises an error if the path is not usable.
    """
    _cache_filepath = os.path.abspath(filepath)
    if os.path.isdir(_cache_filepath):
        raise RuntimeError("The cache_file_path must be a file, not a directory.")
    cache_dir = os.path.dirname(_cache_filepath)
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)


def cache_results(cache_filepath, refresh=False, verbose=1):
    def wrapper_(func):
        signature = inspect.signature(func)
        for key, _ in signature.parameters.items():
            if key in ('cache_filepath', 'refresh', 'verbose'):
                raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))

        def wrapper(*args, **kwargs):
            if 'cache_filepath' in kwargs:
                _cache_filepath = kwargs.pop('cache_filepath')
                assert isinstance(_cache_filepath, str), "cache_filepath can only be str."
            else:
                _cache_filepath = cache_filepath
            if 'refresh' in kwargs:
                _refresh = kwargs.pop('refresh')
                assert isinstance(_refresh, bool), "refresh can only be bool."
            else:
                _refresh = refresh
            if 'verbose' in kwargs:
                _verbose = kwargs.pop('verbose')
                assert isinstance(_verbose, int), "verbose can only be integer."
            else:
                _verbose = verbose
            refresh_flag = True

            if _cache_filepath is not None and _refresh is False:
                # load data
                if os.path.exists(_cache_filepath):
                    with open(_cache_filepath, 'rb') as f:
                        results = _pickle.load(f)
                    if _verbose == 1:
                        print("Read cache from {}.".format(_cache_filepath))
                    refresh_flag = False

            if refresh_flag:
                results = func(*args, **kwargs)
                if _cache_filepath is not None:
                    if results is None:
                        raise RuntimeError("The return value is None. Delete the decorator.")
                    _prepare_cache_filepath(_cache_filepath)
                    with open(_cache_filepath, 'wb') as f:
                        _pickle.dump(results, f)
                    print("Save cache to {}.".format(_cache_filepath))

            return results

        return wrapper

    return wrapper_


def save_pickle(obj, pickle_path, file_name):
    """Save an object into a pickle file.
@@ -1,5 +1,5 @@
from collections import Counter

from fastNLP.core.dataset import DataSet


def check_build_vocab(func):
    """A decorator to make sure the indexing is built before it is used.

@@ -151,6 +151,68 @@ class Vocabulary(object):
        else:
            raise ValueError("word {} not in vocabulary".format(w))
    @check_build_vocab
    def index_dataset(self, *datasets, field_name, new_field_name=None):
        """
        Index the given field of one or more DataSets with this vocabulary, in place.

        Example::

            # remember to pass `field_name`
            vocab.index_dataset(tr_data, dev_data, te_data, field_name='words')

        :param datasets: one or more fastNLP DataSet objects.
        :param field_name: str, the field to index. Only fields holding a str, a 1-d list or a 2-d list are supported.
        :param new_field_name: str, the name of the indexed field; defaults to overwriting `field_name`.
        :return:
        """
        def index_instance(ins):
            """
            The field may be a str, a 1-d list, or a 2-d list.
            :param ins:
            :return:
            """
            field = ins[field_name]
            if isinstance(field, str):
                return self.to_index(field)
            elif isinstance(field, list):
                if not isinstance(field[0], list):
                    return [self.to_index(w) for w in field]
                else:
                    if isinstance(field[0][0], list):
                        raise RuntimeError("Only support field with 2 dimensions.")
                    return [[self.to_index(c) for c in w] for w in field]

        if new_field_name is None:
            new_field_name = field_name
        for dataset in datasets:
            if isinstance(dataset, DataSet):
                dataset.apply(index_instance, new_field_name=new_field_name)
            else:
                raise RuntimeError("Only DataSet type is allowed.")

    def from_dataset(self, *datasets, field_name):
        """
        Construct the vocabulary from the given field of one or more DataSets.

        :param datasets: one or more fastNLP DataSet objects.
        :param field_name: str, the field used to build the vocabulary. Only fields holding a str, a 1-d list or a 2-d list are supported.
        :return:
        """
        def construct_vocab(ins):
            field = ins[field_name]
            if isinstance(field, str):
                self.add_word(field)
            elif isinstance(field, list):
                if not isinstance(field[0], list):
                    self.add_word_lst(field)
                else:
                    if isinstance(field[0][0], list):
                        raise RuntimeError("Only support field with 2 dimensions.")
                    [self.add_word_lst(w) for w in field]

        for dataset in datasets:
            if isinstance(dataset, DataSet):
                dataset.apply(construct_vocab)
            else:
                raise RuntimeError("Only DataSet type is allowed.")

    def to_index(self, w):
        """ Turn a word to an index. If w is not in Vocabulary, return the unknown label.
@@ -0,0 +1 @@
from .embed_loader import EmbedLoader
@@ -1,3 +1,5 @@
import os

import numpy as np
import torch
@@ -124,3 +126,97 @@ class EmbedLoader(BaseLoader):
                                              size=(len(vocab) - np.sum(hit_flags), emb_dim))
        embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
        return embedding_matrix
    @staticmethod
    def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True):
        """
        Load pretrained embeddings from `embed_filepath` for the words in `vocab`. Words that are in `vocab` but not
        in the pretrained file are initialized from a normal distribution with the per-dimension mean and std of the
        found word vectors. The file format is detected automatically; both glove and word2vec (whose first line holds
        only the vocabulary size and the dimension) are supported.

        :param embed_filepath: str, where to read the pretrained embedding from.
        :param vocab: Vocabulary.
        :param dtype: the dtype of the embedding matrix.
        :param normalize: bool, whether to normalize each word vector so that every vector has norm 1.
        :return: np.ndarray of shape [len(vocab), dimension]; the dimension is determined by the pretrained embedding.
        """
        assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
        if not os.path.exists(embed_filepath):
            raise FileNotFoundError("`{}` does not exist.".format(embed_filepath))
        with open(embed_filepath, 'r', encoding='utf-8') as f:
            hit_flags = np.zeros(len(vocab), dtype=bool)
            line = f.readline().strip()
            parts = line.split()
            if len(parts) == 2:
                # word2vec-style header: "<vocab_size> <dim>"
                dim = int(parts[1])
            else:
                dim = len(parts) - 1
                f.seek(0)
            matrix = np.random.randn(len(vocab), dim).astype(dtype)
            for line in f:
                parts = line.strip().split()
                if parts[0] in vocab:
                    index = vocab.to_index(parts[0])
                    matrix[index] = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
                    hit_flags[index] = True
            total_hits = sum(hit_flags)
            print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
            found_vectors = matrix[hit_flags]
            if len(found_vectors) != 0:
                mean = np.mean(found_vectors, axis=0, keepdims=True)
                std = np.std(found_vectors, axis=0, keepdims=True)
                unfound_vec_num = len(vocab) - total_hits
                r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype) * std + mean
                matrix[hit_flags == False] = r_vecs

            if normalize:
                matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)

            return matrix
    @staticmethod
    def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True):
        """
        Load pretrained embeddings from `embed_filepath` and construct a Vocabulary from the words found there.
        The file format is detected automatically; both glove and word2vec (whose first line holds only the
        vocabulary size and the dimension) are supported.

        :param embed_filepath: str, where to read the pretrained embedding from.
        :param dtype: the dtype of the embedding matrix.
        :param padding: the padding tag for the vocabulary.
        :param unknown: the unknown tag for the vocabulary.
        :param normalize: bool, whether to normalize each word vector so that every vector has norm 1.
        :return: np.ndarray, whose shape is determined by the pretrained embedding;
            Vocabulary, containing all pretrained words plus the two special tags [<pad>, <unk>].
        """
        vocab = Vocabulary(padding=padding, unknown=unknown)
        vec_dict = {}

        with open(embed_filepath, 'r', encoding='utf-8') as f:
            line = f.readline()
            start = 1
            dim = -1
            if len(line.strip().split()) != 2:
                # no word2vec-style header; the first line is already a word vector
                f.seek(0)
                start = 0
            for idx, line in enumerate(f, start=start):
                parts = line.strip().split()
                word = parts[0]
                if dim == -1:
                    dim = len(parts) - 1
                vec = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim)
                vec_dict[word] = vec
                vocab.add_word(word)
            if dim == -1:
                raise RuntimeError("{} is an empty file.".format(embed_filepath))
            matrix = np.random.randn(len(vocab), dim).astype(dtype)
            for key, vec in vec_dict.items():
                index = vocab.to_index(key)
                matrix[index] = vec

            if normalize:
                matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)

            return matrix, vocab
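A brief sketch of the two loaders side by side (the file paths are the repository's test fixtures; importing `Vocabulary` from `fastNLP.core.vocabulary` follows the test modules):

    import numpy as np
    from fastNLP.core.vocabulary import Vocabulary
    from fastNLP.io.embed_loader import EmbedLoader

    # build the vocabulary and the matrix directly from a pretrained file
    matrix, vocab = EmbedLoader.load_without_vocab("test/data_for_tests/glove.6B.50d_test.txt")
    print(matrix.shape, len(vocab))  # (number of pretrained words + 2 special tags, 50)

    # or fill an existing vocabulary; words missing from the file are sampled around the found vectors
    my_vocab = Vocabulary()
    my_vocab.add_word_lst(["the", "of", "neverseenword"])
    m = EmbedLoader.load_with_vocab("test/data_for_tests/glove.6B.50d_test.txt", my_vocab, normalize=True)
    assert np.allclose(np.linalg.norm(m, axis=1), 1.0)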
@@ -1,35 +0,0 @@
import logging
import os


def create_logger(logger_name, log_path, log_format=None, log_level=logging.INFO):
    """Create a logger.

    :param str logger_name:
    :param str log_path:
    :param log_format:
    :param log_level:
    :return: logger

    To use a logger::

        logger.debug("this is a debug message")
        logger.info("this is a info message")
        logger.warning("this is a warning message")
        logger.error("this is an error message")

    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(log_level)
    if log_path is None:
        handler = logging.StreamHandler()
    else:
        os.stat(os.path.dirname(os.path.abspath(log_path)))
        handler = logging.FileHandler(log_path)
    handler.setLevel(log_level)
    if log_format is None:
        log_format = "[%(asctime)s %(name)-13s %(levelname)s %(process)d %(thread)d " \
                     "%(filename)s:%(lineno)-5d] %(message)s"
    formatter = logging.Formatter(log_format)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return logger
@@ -0,0 +1,115 @@
import unittest
import _pickle
from fastNLP import cache_results
from fastNLP.io.embed_loader import EmbedLoader
from fastNLP import DataSet
from fastNLP import Instance
import time
import os


@cache_results('test/demo1.pkl')
def process_data_1(embed_file, cws_train):
    embed, vocab = EmbedLoader.load_without_vocab(embed_file)
    time.sleep(1)  # slow this down so the test can tell a cache read from a recomputation
    with open(cws_train, 'r', encoding='utf-8') as f:
        d = DataSet()
        for line in f:
            line = line.strip()
            if len(line) > 0:
                d.append(Instance(raw=line))
    return embed, vocab, d
class TestCache(unittest.TestCase):
    def test_cache_save(self):
        try:
            start_time = time.time()
            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train')
            end_time = time.time()
            pre_time = end_time - start_time
            with open('test/demo1.pkl', 'rb') as f:
                _embed, _vocab, _d = _pickle.load(f)
            self.assertEqual(embed.shape, _embed.shape)
            for i in range(embed.shape[0]):
                self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
            start_time = time.time()
            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train')
            end_time = time.time()
            read_time = end_time - start_time
            print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
            self.assertGreater(pre_time - 0.5, read_time)
        finally:
            os.remove('test/demo1.pkl')

    def test_cache_save_overwrite_path(self):
        try:
            start_time = time.time()
            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
                                             cache_filepath='test/demo_overwrite.pkl')
            end_time = time.time()
            pre_time = end_time - start_time
            with open('test/demo_overwrite.pkl', 'rb') as f:
                _embed, _vocab, _d = _pickle.load(f)
            self.assertEqual(embed.shape, _embed.shape)
            for i in range(embed.shape[0]):
                self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
            start_time = time.time()
            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
                                             cache_filepath='test/demo_overwrite.pkl')
            end_time = time.time()
            read_time = end_time - start_time
            print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
            self.assertGreater(pre_time - 0.5, read_time)
        finally:
            os.remove('test/demo_overwrite.pkl')
    def test_cache_refresh(self):
        try:
            start_time = time.time()
            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
                                             refresh=True)
            end_time = time.time()
            pre_time = end_time - start_time
            with open('test/demo1.pkl', 'rb') as f:
                _embed, _vocab, _d = _pickle.load(f)
            self.assertEqual(embed.shape, _embed.shape)
            for i in range(embed.shape[0]):
                self.assertListEqual(embed[i].tolist(), _embed[i].tolist())
            start_time = time.time()
            embed, vocab, d = process_data_1('test/data_for_tests/word2vec_test.txt', 'test/data_for_tests/cws_train',
                                             refresh=True)
            end_time = time.time()
            read_time = end_time - start_time
            print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
            self.assertGreater(0.1, pre_time - read_time)
        finally:
            os.remove('test/demo1.pkl')

    def test_duplicate_keyword(self):
        with self.assertRaises(RuntimeError):
            @cache_results(None)
            def func_verbose(a, verbose):
                pass

            func_verbose(0, 1)
        with self.assertRaises(RuntimeError):
            @cache_results(None)
            def func_cache(a, cache_filepath):
                pass

            func_cache(1, 2)
        with self.assertRaises(RuntimeError):
            @cache_results(None)
            def func_refresh(a, refresh):
                pass

            func_refresh(1, 2)

    def test_create_cache_dir(self):
        @cache_results('test/demo1/demo.pkl')
        def cache():
            return 1, 2

        try:
            results = cache()
            print(results)
        finally:
            os.remove('test/demo1/demo.pkl')
            os.rmdir('test/demo1')
@@ -2,6 +2,8 @@ import unittest
from collections import Counter

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
        "works", "well", "in", "most", "cases", "scales", "well"]
@@ -31,6 +33,42 @@ class TestAdd(unittest.TestCase):
        vocab.update(text)
        self.assertEqual(vocab.word_count, counter)

    def test_from_dataset(self):
        start_char = 65
        num_samples = 10

        # 0 dim
        dataset = DataSet()
        for i in range(num_samples):
            ins = Instance(char=chr(start_char + i))
            dataset.append(ins)
        vocab = Vocabulary()
        vocab.from_dataset(dataset, field_name='char')
        for i in range(num_samples):
            self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2)
        vocab.index_dataset(dataset, field_name='char')

        # 1 dim
        dataset = DataSet()
        for i in range(num_samples):
            ins = Instance(char=[chr(start_char + i)] * 6)
            dataset.append(ins)
        vocab = Vocabulary()
        vocab.from_dataset(dataset, field_name='char')
        for i in range(num_samples):
            self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2)
        vocab.index_dataset(dataset, field_name='char')

        # 2 dim
        dataset = DataSet()
        for i in range(num_samples):
            ins = Instance(char=[[chr(start_char + i) for _ in range(6)] for _ in range(6)])
            dataset.append(ins)
        vocab = Vocabulary()
        vocab.from_dataset(dataset, field_name='char')
        for i in range(num_samples):
            self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2)
        vocab.index_dataset(dataset, field_name='char')


class TestIndexing(unittest.TestCase):
    def test_len(self):
@@ -1,4 +1,5 @@
import unittest

import numpy as np

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader
@@ -10,3 +11,34 @@ class TestEmbedLoader(unittest.TestCase):
        vocab.update(["the", "in", "I", "to", "of", "hahaha"])
        embedding = EmbedLoader().fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab)
        self.assertEqual(tuple(embedding.shape), (len(vocab), 50))
    def test_load_with_vocab(self):
        vocab = Vocabulary()
        glove = "test/data_for_tests/glove.6B.50d_test.txt"
        word2vec = "test/data_for_tests/word2vec_test.txt"
        vocab.add_word('the')
        g_m = EmbedLoader.load_with_vocab(glove, vocab)
        self.assertEqual(g_m.shape, (3, 50))
        w_m = EmbedLoader.load_with_vocab(word2vec, vocab, normalize=True)
        self.assertEqual(w_m.shape, (3, 50))
        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 3)

    def test_load_without_vocab(self):
        words = ['the', 'of', 'in', 'a', 'to', 'and']
        glove = "test/data_for_tests/glove.6B.50d_test.txt"
        word2vec = "test/data_for_tests/word2vec_test.txt"
        g_m, vocab = EmbedLoader.load_without_vocab(glove)
        self.assertEqual(g_m.shape, (8, 50))
        for word in words:
            self.assertIn(word, vocab)
        w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True)
        self.assertEqual(w_m.shape, (8, 50))
        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 8)
        for word in words:
            self.assertIn(word, vocab)

        # no unk
        w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True, unknown=None)
        self.assertEqual(w_m.shape, (7, 50))
        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 7)
        for word in words:
            self.assertIn(word, vocab)