Changelog (tags/v0.4.10):
2. Fix bugs in embedding.py.
3. ConllReader skips all -DOCSTART- tags by default.
4. Move BERT's heavy lifting into _bert and expose the BERT encoder from bert.py.
5. Change the include_start_end default of allowed_transitions in the CRF module to False, matching the CRF default.
6. allowed_transitions and the span metric now support BIOES-style tags.
7. Add formatted printing (__repr__) to DataInfo.
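The headline API change is that the old `Batch` class is split into `DataSetIter`, `BatchIter`, and `TorchLoaderIter` (see the batch.py hunks below), with `num_workers` replacing the old `prefetch` flag. A minimal sketch of the new iteration API, with made-up field names::

    from fastNLP import DataSet, DataSetIter
    from fastNLP.core.sampler import SequentialSampler

    # a toy two-field dataset; 'words' is the input, 'target' the label
    ds = DataSet({'words': [[1, 2, 3], [4, 5]], 'target': [0, 1]})
    ds.set_input('words')
    ds.set_target('target')

    # DataSetIter replaces Batch; variable-length fields are padded per batch
    for batch_x, batch_y in DataSetIter(ds, batch_size=2, sampler=SequentialSampler()):
        print(batch_x['words'].shape, batch_y['target'])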
@@ -12,7 +12,11 @@ The most commonly used fastNLP components can be imported directly from the fastNLP package; their
__all__ = [
"Instance",
"FieldArray",
"Batch",
"DataSetIter",
"BatchIter",
"TorchLoaderIter",
"Vocabulary",
"DataSet",
"Const",
@@ -14,7 +14,7 @@ The core module implements fastNLP's core framework; common functionality can be imported from fa
Introducing the division of labor among core's submodules does not seem all that necessary
"""
from .batch import Batch
from .batch import DataSetIter, BatchIter, TorchLoaderIter
from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC
from .const import Const
from .dataset import DataSet
@@ -3,7 +3,9 @@ The batch module implements the Batch class required by fastNLP.
"""
__all__ = [
"Batch"
"BatchIter",
"DataSetIter",
"TorchLoaderIter",
]
import atexit
@@ -12,9 +14,11 @@ from queue import Empty, Full
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.utils.data
from numbers import Number
from .sampler import RandomSampler
from .sampler import SequentialSampler
from .dataset import DataSet
_python_is_exit = False
@@ -27,162 +31,157 @@ def _set_python_is_exit():
atexit.register(_set_python_is_exit)
class Batch(object):
"""
Alias: :class:`fastNLP.Batch` :class:`fastNLP.core.batch.Batch`
Batch takes data out of a `DataSet` in a given order, ``batch_size`` items at a time,
and assembles it into `x` and `y`::
batch = Batch(data_set, batch_size=16, sampler=SequentialSampler())
num_batch = len(batch)
for batch_x, batch_y in batch:
# do stuff ...
:param dataset: a :class:`~fastNLP.DataSet` object, the dataset
:param int batch_size: size of each batch
:param sampler: the :class:`~fastNLP.Sampler` to use. If ``None``, use :class:`~fastNLP.RandomSampler`.
Default: ``None``
:param bool as_numpy: if ``True``, output batches as numpy.array, otherwise as :class:`torch.Tensor`.
Default: ``False``
:param bool prefetch: if ``True``, use a separate process to prefetch the next batch.
Default: ``False``
"""
def __init__(self, dataset, batch_size, sampler=None, as_numpy=False, prefetch=False):
class DataSetGetter:
def __init__(self, dataset: DataSet, as_numpy=False):
self.dataset = dataset
self.batch_size = batch_size
if sampler is None:
sampler = RandomSampler()
self.sampler = sampler
self.inputs = {n: f for n, f in dataset.get_all_fields().items() if f.is_input}
self.targets = {n: f for n, f in dataset.get_all_fields().items() if f.is_target}
self.as_numpy = as_numpy
self.idx_list = None
self.curidx = 0
self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
self.cur_batch_indices = None
self.prefetch = prefetch
self.lengths = 0
def fetch_one(self):
if self.curidx >= len(self.idx_list):
return None
self.idx_list = list(range(len(dataset)))
def __getitem__(self, idx: int):
# mapping idx to sampled idx
idx = self.idx_list[idx]
inputs = {n:f.get(idx) for n, f in self.inputs.items()}
targets = {n:f.get(idx) for n, f in self.targets.items()}
return idx, inputs, targets
def __len__(self):
return len(self.dataset)
def collate_fn(self, batch: list):
batch_x = {n:[] for n in self.inputs.keys()}
batch_y = {n:[] for n in self.targets.keys()}
indices = []
for idx, x, y in batch:
indices.append(idx)
for n, v in x.items():
batch_x[n].append(v)
for n, v in y.items():
batch_y[n].append(v)
def pad_batch(batch_dict, field_array):
for n, vlist in batch_dict.items():
f = field_array[n]
if f.padder is None:
batch_dict[n] = np.array(vlist)
else:
data = f.pad(vlist)
if not self.as_numpy:
data, flag = _to_tensor(data, f.dtype)
batch_dict[n] = data
return batch_dict
return (indices,
pad_batch(batch_x, self.inputs),
pad_batch(batch_y, self.targets))
def set_idx_list(self, idx_list):
if len(idx_list) != len(self.idx_list):
raise ValueError
self.idx_list = idx_list
def __getattr__(self, item):
if hasattr(self.dataset, item):
return getattr(self.dataset, item)
else:
endidx = min(self.curidx + self.batch_size, len(self.idx_list))
batch_x, batch_y = {}, {}
indices = self.idx_list[self.curidx:endidx]
self.cur_batch_indices = indices
for field_name, field in self.dataset.get_all_fields().items():
if field.is_target or field.is_input:
batch = field.get(indices)
if not self.as_numpy and \
field.dtype is not None and \
issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor):
batch = _to_tensor(batch)
if field.is_target:
batch_y[field_name] = batch
if field.is_input:
batch_x[field_name] = batch
self.curidx = endidx
return batch_x, batch_y
raise AttributeError("'DataSetGetter' object has no attribute '{}'".format(item))
class SamplerAdapter(torch.utils.data.Sampler):
def __init__(self, sampler, dataset):
self.sampler = sampler
self.dataset = dataset
def __iter__(self):
"""
Iterate over the dataset and fetch batch data. The fetch process doesn't block the iteration.
:return:
"""
if self.prefetch:
return self._run_batch_iter(self)
def batch_iter():
self.init_iter()
while 1:
res = self.fetch_one()
if res is None:
break
yield res
return batch_iter()
return iter(self.sampler(self.dataset))
class BatchIter:
def __init__(self):
self.dataiter = None
self.num_batches = None
self.cur_batch_indices = None
self.batch_size = None
def init_iter(self):
self.idx_list = self.sampler(self.dataset)
self.curidx = 0
self.lengths = self.dataset.get_length()
pass
@staticmethod
def get_num_batches(num_samples, batch_size, drop_last):
num_batches = num_samples // batch_size
if not drop_last and (num_samples % batch_size > 0):
num_batches += 1
return num_batches
def __iter__(self):
self.init_iter()
for indices, batch_x, batch_y in self.dataiter:
self.cur_batch_indices = indices
yield batch_x, batch_y
def get_batch_indices(self):
return self.cur_batch_indices
def __len__(self):
return self.num_batches
def get_batch_indices(self):
"""
Get the indices of the current batch within the DataSet.
:return list(int) indexes: the index sequence
"""
return self.cur_batch_indices
@staticmethod
def _run_fetch(batch, q):
try:
global _python_is_exit
batch.init_iter()
# print('start fetch')
while 1:
res = batch.fetch_one()
# print('fetch one')
while 1:
try:
q.put(res, timeout=3)
break
except Full:
if _python_is_exit:
return
if res is None:
# print('fetch done, waiting processing')
break
# print('fetch exit')
except Exception as e:
q.put(e)
finally:
q.join()
@staticmethod
def _run_batch_iter(batch):
q = mp.JoinableQueue(maxsize=10)
fetch_p = mp.Process(target=Batch._run_fetch, args=(batch, q))
fetch_p.daemon = True
fetch_p.start()
# print('fork fetch process')
while 1:
try:
res = q.get(timeout=1)
q.task_done()
# print('get fetched')
if res is None:
break
elif isinstance(res, Exception):
raise res
yield res
except Empty as e:
if fetch_p.is_alive():
continue
else:
break
fetch_p.terminate()
fetch_p.join()
# print('iter done')
@property
def dataset(self):
return self.dataiter.dataset
class DataSetIter(BatchIter):
def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False,
num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
super().__init__()
assert isinstance(dataset, DataSet)
sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
dataset = DataSetGetter(dataset, as_numpy)
collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None
self.dataiter = torch.utils.data.DataLoader(
dataset=dataset, batch_size=batch_size, sampler=sampler,
collate_fn=collate_fn, num_workers=num_workers,
pin_memory=pin_memory, drop_last=drop_last,
timeout=timeout, worker_init_fn=worker_init_fn)
self.num_batches = self.get_num_batches(len(dataset), batch_size, drop_last)
self.batch_size = batch_size
class TorchLoaderIter(BatchIter):
def __init__(self, dataset):
super().__init__()
assert isinstance(dataset, torch.utils.data.DataLoader)
self.dataiter = dataset
self.num_batches = self.get_num_batches(len(dataset), dataset.batch_size, dataset.drop_last)
self.batch_size = dataset.batch_size
def _to_tensor(batch):
class OnlineDataGetter:
# TODO
pass
class OnlineDataIter(BatchIter):
# TODO
def __init__(self, dataset, batch_size=1, buffer_size=10000, sampler=None, as_numpy=False,
num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None, **kwargs):
super().__init__()
def _to_tensor(batch, field_dtype):
try:
if issubclass(batch.dtype.type, np.floating):
batch = torch.as_tensor(batch).float()  # float32 by default
if field_dtype is not None \
and issubclass(field_dtype, Number) \
and not isinstance(batch, torch.Tensor):
if issubclass(batch.dtype.type, np.floating):
new_batch = torch.as_tensor(batch).float()  # float32 by default
else:
new_batch = torch.as_tensor(batch)  # reuse the memory, avoid copying
return new_batch, True
else:
batch = torch.as_tensor(batch)  # reuse the memory, avoid copying
return batch, False
except:
pass
return batch
return batch, False
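TorchLoaderIter wraps an existing torch DataLoader as a BatchIter. Note that BatchIter.__iter__ unpacks (indices, batch_x, batch_y) triples, so the DataLoader's collate function must yield dicts in that shape; a hedged sketch::

    import torch
    from torch.utils.data import DataLoader
    from fastNLP.core.batch import TorchLoaderIter

    data = [([1, 2, 3], 0), ([4, 5, 6], 1)]

    def collate(batch):
        # BatchIter expects (indices, batch_x dict, batch_y dict)
        xs = torch.tensor([x for x, _ in batch])
        ys = torch.tensor([y for _, y in batch])
        return list(range(len(batch))), {'words': xs}, {'target': ys}

    loader = DataLoader(data, batch_size=2, collate_fn=collate)
    for batch_x, batch_y in TorchLoaderIter(loader):
        print(batch_x['words'].shape, batch_y['target'])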
@@ -176,7 +176,10 @@ class FieldArray:
if self.padder is None or pad is False:
return np.array(contents)
else:
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)
return self.pad(contents)
def pad(self, contents):
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)
def set_padder(self, padder):
"""
@@ -34,14 +34,23 @@ class LossBase(object):
"""
def __init__(self):
self.param_map = {}
self._param_map = {}  # keys are parameters of the function; values are the keys used to fetch them from the input dict
self._checked = False
@property
def param_map(self):
if len(self._param_map) == 0:  # empty means it has not been initialized yet
func_spect = inspect.getfullargspec(self.get_loss)
func_args = [arg for arg in func_spect.args if arg != 'self']
for arg in func_args:
self._param_map[arg] = arg
return self._param_map
def get_loss(self, *args, **kwargs):
raise NotImplementedError
def _init_param_map(self, key_map=None, **kwargs):
"""Check key_map and the other parameter maps, and add these mappings to self.param_map
"""Check key_map and the other parameter maps, and add these mappings to self._param_map
:param dict key_map: the key mapping
:param kwargs: each key-value pair in the keyword args is turned into a mapping
@@ -53,30 +62,30 @@ class LossBase(object):
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map)))
for key, value in key_map.items():
if value is None:
self.param_map[key] = key
self._param_map[key] = key
continue
if not isinstance(key, str):
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
if not isinstance(value, str):
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
self.param_map[key] = value
self._param_map[key] = value
value_counter[value].add(key)
for key, value in kwargs.items():
if value is None:
self.param_map[key] = key
self._param_map[key] = key
continue
if not isinstance(value, str):
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
self.param_map[key] = value
self._param_map[key] = value
value_counter[value].add(key)
for value, key_set in value_counter.items():
if len(key_set) > 1:
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.")
# check consistency between signature and param_map
# check consistency between signature and _param_map
func_spect = inspect.getfullargspec(self.get_loss)
func_args = [arg for arg in func_spect.args if arg != 'self']
for func_param, input_param in self.param_map.items():
for func_param, input_param in self._param_map.items():
if func_param not in func_args:
raise NameError(
f"Parameter `{func_param}` is not in {_get_func_signature(self.get_loss)}. Please check the "
@@ -96,7 +105,7 @@ class LossBase(object):
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
"""
fast_param = {}
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
fast_param['pred'] = list(pred_dict.values())[0]
fast_param['target'] = list(target_dict.values())[0]
return fast_param
@@ -115,19 +124,19 @@ class LossBase(object):
return loss
if not self._checked:
# 1. check consistency between signature and param_map
# 1. check consistency between signature and _param_map
func_spect = inspect.getfullargspec(self.get_loss)
func_args = set([arg for arg in func_spect.args if arg != 'self'])
for func_arg, input_arg in self.param_map.items():
for func_arg, input_arg in self._param_map.items():
if func_arg not in func_args:
raise NameError(f"`{func_arg}` not in {_get_func_signature(self.get_loss)}.")
# 2. only part of the param_map are passed, left are not
# 2. only part of the _param_map are passed, left are not
for arg in func_args:
if arg not in self.param_map:
self.param_map[arg] = arg  # This param does not need mapping.
if arg not in self._param_map:
self._param_map[arg] = arg  # This param does not need mapping.
self._evaluate_args = func_args
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()}
mapped_pred_dict = {}
mapped_target_dict = {}
@@ -149,7 +158,7 @@ class LossBase(object):
replaced_missing = list(missing)
for idx, func_arg in enumerate(missing):
# Don't delete `` in this information, nor add ``
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \
f"in `{self.__class__.__name__}`)"
check_res = _CheckRes(missing=replaced_missing,
@@ -162,6 +171,8 @@ class LossBase(object):
if check_res.missing or check_res.duplicated:
raise _CheckError(check_res=check_res,
func_signature=_get_func_signature(self.get_loss))
self._checked = True
refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict)
loss = self.get_loss(**refined_args)
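Since param_map is now a lazy property derived from the get_loss signature, a subclass only needs _init_param_map when field names differ from parameter names. A hedged sketch of a custom loss (names illustrative)::

    import torch.nn.functional as F
    from fastNLP.core.losses import LossBase

    class MyLoss(LossBase):
        def __init__(self, pred=None, target=None):
            super().__init__()
            # map get_loss's parameters to dataset field names
            self._init_param_map(pred=pred, target=target)

        def get_loss(self, pred, target):
            return F.cross_entropy(pred, target)

    loss = MyLoss(pred='output', target='label')
    print(loss.param_map)  # {'pred': 'output', 'target': 'label'}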
@@ -115,9 +115,18 @@ class MetricBase(object):
"""
def __init__(self):
self.param_map = {}  # key is param in function, value is input param.
self._param_map = {}  # key is param in function, value is input param.
self._checked = False
@property
def param_map(self):
if len(self._param_map) == 0:  # empty means it has not been initialized yet
func_spect = inspect.getfullargspec(self.evaluate)
func_args = [arg for arg in func_spect.args if arg != 'self']
for arg in func_args:
self._param_map[arg] = arg
return self._param_map
@abstractmethod
def evaluate(self, *args, **kwargs):
raise NotImplementedError
@@ -127,7 +136,7 @@ class MetricBase(object):
raise NotImplemented
def _init_param_map(self, key_map=None, **kwargs):
"""Check key_map and the other parameter maps, and add these mappings to self.param_map
"""Check key_map and the other parameter maps, and add these mappings to self._param_map
:param dict key_map: the key mapping
:param kwargs: each key-value pair in the keyword args is turned into a mapping
@@ -139,30 +148,30 @@ class MetricBase(object):
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map)))
for key, value in key_map.items():
if value is None:
self.param_map[key] = key
self._param_map[key] = key
continue
if not isinstance(key, str):
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
if not isinstance(value, str):
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
self.param_map[key] = value
self._param_map[key] = value
value_counter[value].add(key)
for key, value in kwargs.items():
if value is None:
self.param_map[key] = key
self._param_map[key] = key
continue
if not isinstance(value, str):
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
self.param_map[key] = value
self._param_map[key] = value
value_counter[value].add(key)
for value, key_set in value_counter.items():
if len(key_set) > 1:
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.")
# check consistency between signature and param_map
# check consistency between signature and _param_map
func_spect = inspect.getfullargspec(self.evaluate)
func_args = [arg for arg in func_spect.args if arg != 'self']
for func_param, input_param in self.param_map.items():
for func_param, input_param in self._param_map.items():
if func_param not in func_args:
raise NameError(
f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the "
@@ -177,7 +186,7 @@ class MetricBase(object):
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
"""
fast_param = {}
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
fast_param['pred'] = list(pred_dict.values())[0]
fast_param['target'] = list(target_dict.values())[0]
return fast_param
@@ -206,19 +215,19 @@ class MetricBase(object):
if not self._checked:
if not callable(self.evaluate):
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")
# 1. check consistency between signature and param_map
# 1. check consistency between signature and _param_map
func_spect = inspect.getfullargspec(self.evaluate)
func_args = set([arg for arg in func_spect.args if arg != 'self'])
for func_arg, input_arg in self.param_map.items():
for func_arg, input_arg in self._param_map.items():
if func_arg not in func_args:
raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.")
# 2. only part of the param_map are passed, left are not
# 2. only part of the _param_map are passed, left are not
for arg in func_args:
if arg not in self.param_map:
self.param_map[arg] = arg  # This param does not need mapping.
if arg not in self._param_map:
self._param_map[arg] = arg  # This param does not need mapping.
self._evaluate_args = func_args
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()}
# need to wrap inputs in dict.
mapped_pred_dict = {}
@@ -242,7 +251,7 @@ class MetricBase(object):
replaced_missing = list(missing)
for idx, func_arg in enumerate(missing):
# Don't delete `` in this information, nor add ``
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \
f"in `{self.__class__.__name__}`)"
check_res = _CheckRes(missing=replaced_missing,
@@ -255,10 +264,10 @@ class MetricBase(object):
if check_res.missing or check_res.duplicated:
raise _CheckError(check_res=check_res,
func_signature=_get_func_signature(self.evaluate))
self._checked = True
refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)
self.evaluate(**refined_args)
self._checked = True
return
@@ -416,19 +425,19 @@ def _bioes_tag_to_spans(tags, ignore_labels=None):
ignore_labels = set(ignore_labels) if ignore_labels else set()
spans = []
prev_bmes_tag = None
prev_bioes_tag = None
for idx, tag in enumerate(tags):
tag = tag.lower()
bmes_tag, label = tag[:1], tag[2:]
if bmes_tag in ('b', 's'):
bioes_tag, label = tag[:1], tag[2:]
if bioes_tag in ('b', 's'):
spans.append((label, [idx, idx]))
elif bmes_tag in ('i', 'e') and prev_bmes_tag in ('b', 'i') and label == spans[-1][0]:
elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]:
spans[-1][1][1] = idx
elif bmes_tag == 'o':
elif bioes_tag == 'o':
pass
else:
spans.append((label, [idx, idx]))
prev_bmes_tag = bmes_tag
prev_bioes_tag = bioes_tag
return [(span[0], (span[1][0], span[1][1] + 1))
for span in spans
if span[0] not in ignore_labels
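After this change _bioes_tag_to_spans decodes BIOES sequences into labeled spans with exclusive end offsets; an illustrative call (the function is private, so the import path may differ)::

    from fastNLP.core.metrics import _bioes_tag_to_spans

    tags = ['B-PER', 'E-PER', 'O', 'S-LOC']
    print(_bioes_tag_to_spans(tags))
    # [('per', (0, 2)), ('loc', (3, 4))]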
@@ -6,7 +6,7 @@ from collections import defaultdict
import torch
from . import Batch
from . import DataSetIter
from . import DataSet
from . import SequentialSampler
from .utils import _build_args
@@ -44,8 +44,7 @@ class Predictor(object):
self.network.eval()
batch_output = defaultdict(list)
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False,
prefetch=False)
data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
if hasattr(self.network, "predict"):
predict_func = self.network.predict
@@ -37,7 +37,7 @@ import warnings
import torch
import torch.nn as nn
from .batch import Batch
from .batch import BatchIter, DataSetIter
from .dataset import DataSet
from .metrics import _prepare_metrics
from .sampler import SequentialSampler
@@ -82,7 +82,7 @@ class Tester(object):
:param int verbose: if 0, print nothing; if 1, print the evaluation result.
"""
def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1):
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1):
super(Tester, self).__init__()
if not isinstance(data, DataSet):
@@ -96,6 +96,14 @@ class Tester(object):
self._model = _move_model_to_device(model, device=device)
self.batch_size = batch_size
self.verbose = verbose
if isinstance(data, DataSet):
self.data_iterator = DataSetIter(
dataset=data, batch_size=batch_size, num_workers=num_workers)
elif isinstance(data, BatchIter):
self.data_iterator = data
else:
raise TypeError("data type {} not support".format(type(data)))
# with DataParallel the predict method cannot be used
if isinstance(self._model, nn.DataParallel):
@@ -124,7 +132,7 @@ class Tester(object):
self._model_device = _get_model_device(self._model)
network = self._model
self._mode(network, is_test=True)
data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)
data_iterator = self.data_iterator
eval_results = {}
try:
with torch.no_grad():
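Tester now builds its DataSetIter once in __init__ (or accepts a ready-made BatchIter) instead of constructing a Batch inside _test, and gains a num_workers argument. A sketch, with model, dataset and metric as placeholders::

    from fastNLP import Tester, AccuracyMetric

    tester = Tester(data=dev_data, model=model, metrics=AccuracyMetric(),
                    batch_size=16, num_workers=2)  # num_workers is new
    tester.test()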
@@ -311,8 +311,9 @@ try:
from tqdm.auto import tqdm
except:
from .utils import _pseudo_tqdm as tqdm
import warnings
from .batch import Batch
from .batch import DataSetIter, BatchIter
from .callback import CallbackManager, CallbackException
from .dataset import DataSet
from .losses import _prepare_losser
@@ -320,7 +321,6 @@ from .metrics import _prepare_metrics
from .optimizer import Optimizer
from .sampler import Sampler
from .sampler import RandomSampler
from .sampler import SequentialSampler
from .tester import Tester
from .utils import _CheckError
from .utils import _build_args
@@ -351,6 +351,8 @@ class Trainer(object):
:param int batch_size: batch size used for training and validation.
:param loss: the :class:`~fastNLP.core.losses.LossBase` object to use. When None, defaults to :class:`~fastNLP.LossInForward`
:param sampler: the order in which batches are generated, of type :class:`~fastNLP.Sampler`. If None, defaults to :class:`~fastNLP.RandomSampler`
:param drop_last: if the last batch does not contain exactly batch_size samples, drop it
:param num_workers: int, how many workers to use for padding and collating the data.
:param update_every: int, update the gradients once every this many steps. Intended for gradient accumulation: e.g. if a batch_size of 128 is needed but setting it directly to 128
runs out of memory, set batch_size=32 and update_every=4 instead. Has no effect when optimizer is None.
:param int n_epochs: how many epochs to optimize for.
@@ -367,7 +369,6 @@ class Trainer(object):
:param int validate_every: validate on the dev set every this many steps; if -1, validate once at the end of each epoch. Only effective when dev_data is passed.
:param str,None save_path: path for saving the model. If None, the model is not saved. If dev_data is None, the model of the last iteration is saved.
Both the parameters and the model structure are saved. Even when using DataParallel, only the model itself is saved here.
:param prefetch: bool, whether to use an extra process to produce the batch data. In theory this makes batch iteration faster.
:param bool use_tqdm: whether to use tqdm to show training progress; if False, the loss is printed to the terminal.
:param str,int,torch.device,list(int) device: which device to load the model onto. Defaults to None, meaning the Trainer does not manage
where the model computes. The following inputs are supported:
@@ -394,16 +395,17 @@ class Trainer(object):
"""
def __init__(self, train_data, model, optimizer=None, loss=None,
batch_size=32, sampler=None, update_every=1,
n_epochs=10, print_every=5,
batch_size=32, sampler=None, drop_last=False, update_every=1,
num_workers=0, n_epochs=10, print_every=5,
dev_data=None, metrics=None, metric_key=None,
validate_every=-1, save_path=None,
prefetch=False, use_tqdm=True, device=None,
callbacks=None,
check_code_level=0):
validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False,
callbacks=None, check_code_level=0):
if prefetch and num_workers==0:
num_workers = 1
if prefetch:
warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.")
super(Trainer, self).__init__()
if not isinstance(train_data, DataSet):
raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.")
if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")
@@ -430,17 +432,27 @@ class Trainer(object):
if metric_key is not None:
self.increase_better = False if metric_key[0] == "-" else True
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
elif len(metrics) > 0:
self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')
else:
self.metric_key = None
# prepare loss
losser = _prepare_losser(loss)
# sampler check
if sampler is not None and not isinstance(sampler, Sampler):
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))
if sampler is None:
sampler = RandomSampler()
if isinstance(train_data, DataSet):
self.data_iterator = DataSetIter(
dataset=train_data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, drop_last=drop_last)
elif isinstance(train_data, BatchIter):
self.data_iterator = train_data
else:
raise TypeError("train_data type {} not support".format(type(train_data)))
if check_code_level > -1:
if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter):
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
metric_key=metric_key, check_level=check_code_level,
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE))
@@ -460,8 +472,6 @@ class Trainer(object):
self.best_dev_epoch = None
self.best_dev_step = None
self.best_dev_perf = None
self.sampler = sampler if sampler is not None else RandomSampler()
self.prefetch = prefetch
self.n_steps = (len(self.train_data) // self.batch_size + int(
len(self.train_data) % self.batch_size != 0)) * self.n_epochs
@@ -493,7 +503,7 @@ class Trainer(object):
self.callback_manager = CallbackManager(env={"trainer": self},
callbacks=callbacks)
def train(self, load_best_model=True, on_exception='auto'):
"""
Call this function to make the Trainer start training.
@@ -572,8 +582,7 @@ class Trainer(object):
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
self.pbar = pbar
avg_loss = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
data_iterator = self.data_iterator
self.batch_per_epoch = data_iterator.num_batches
for epoch in range(1, self.n_epochs + 1):
self.epoch = epoch
@@ -746,7 +755,9 @@ class Trainer(object):
:return bool value: True means current results on dev set is the best.
"""
indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics)
indicator, indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics)
if self.metric_key is None:
self.metric_key = indicator
is_better = True
if self.best_metric_indicator is None:
# first-time validation
@@ -785,15 +796,34 @@ def _get_value_info(_dict):
strs.append(_str)
return strs
from numbers import Number
from .batch import _to_tensor
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
dev_data=None, metric_key=None,
check_level=0):
# check the get_loss method
model_device = model.parameters().__next__().device
model_device = _get_model_device(model=model)
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch):
def _iter():
start_idx = 0
while start_idx<len(dataset):
batch_x = {}
batch_y = {}
for field_name, field in dataset.get_all_fields().items():
indices = list(range(start_idx, min(start_idx+batch_size, len(dataset))))
if field.is_target or field.is_input:
batch = field.get(indices)
if field.dtype is not None and \
issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor):
batch, _ = _to_tensor(batch, field.dtype)
if field.is_target:
batch_y[field_name] = batch
if field.is_input:
batch_x[field_name] = batch
yield (batch_x, batch_y)
start_idx += batch_size
for batch_count, (batch_x, batch_y) in enumerate(_iter()):
_move_dict_value_to_device(batch_x, batch_y, device=model_device)
# forward check
if batch_count == 0:
@@ -861,26 +891,16 @@ def _check_eval_results(metrics, metric_key, metric_list):
loss, metrics = metrics
if isinstance(metrics, dict):
if len(metrics) == 1:
# only single metric, just use it
metric_dict = list(metrics.values())[0]
metrics_name = list(metrics.keys())[0]
else:
metrics_name = metric_list[0].__class__.__name__
if metrics_name not in metrics:
raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
metric_dict = metrics[metrics_name]
metric_dict = list(metrics.values())[0]  # take the first metric
if len(metric_dict) == 1:
if metric_key is None:
indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
elif len(metric_dict) > 1 and metric_key is None:
raise RuntimeError(
f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
else:
# metric_key is set
if metric_key not in metric_dict:
raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}")
indicator_val = metric_dict[metric_key]
indicator = metric_key
else:
raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics)))
return indicator_val
return indicator, indicator_val
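_check_eval_results now returns the chosen metric name along with its value, which is what lets the Trainer fill in metric_key on the first validation. An illustrative call against an internal function::

    from fastNLP.core.trainer import _check_eval_results

    # metrics in the shape Tester.test() returns
    indicator, indicator_val = _check_eval_results(
        {'AccuracyMetric': {'acc': 0.92}}, metric_key=None, metric_list=None)
    print(indicator, indicator_val)  # acc 0.92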
@@ -124,6 +124,14 @@ class DataInfo:
self.embeddings = embeddings or {}
self.datasets = datasets or {}
def __repr__(self):
_str = 'In total {} datasets:\n'.format(len(self.datasets))
for name, dataset in self.datasets.items():
_str += '\t{} has {} instances.\n'.format(name, len(dataset))
_str += 'In total {} vocabs:\n'.format(len(self.vocabs))
for name, vocab in self.vocabs.items():
_str += '\t{} has {} entries.\n'.format(name, len(vocab))
return _str
class DataSetLoader:
"""
@@ -115,7 +115,8 @@ class ConllLoader(DataSetLoader):
"""
Alias: :class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.dataset_loader.ConllLoader`
Reads data in CoNLL format. See http://conll.cemantix.org/2012/data.html for details of the format.
Reads data in CoNLL format. See http://conll.cemantix.org/2012/data.html for details of the format. Lines starting
with "-DOCSTART-" are ignored, because that symbol is used as a document separator in CoNLL 2003.
Column numbers start from 0; the contents of each column are::
@@ -90,11 +90,12 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
return sample
with open(path, 'r', encoding=encoding) as f:
sample = []
start = next(f)
if '-DOCSTART-' not in start:
start = next(f).strip()
if '-DOCSTART-' not in start and start!='':
sample.append(start.split())
for line_idx, line in enumerate(f, 1):
if line.startswith('\n'):
line = line.strip()
if line=='':
if len(sample):
try:
res = parse_conll(sample)
@@ -107,7 +108,8 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
elif line.startswith('#'):
continue
else:
sample.append(line.split())
if not line.startswith('-DOCSTART-'):
sample.append(line.split())
if len(sample) > 0:
try:
res = parse_conll(sample)
@@ -115,4 +117,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
except Exception as e:
if dropna:
return
raise ValueError('invalid instance at line: {}'.format(line_idx))
print('invalid instance at line: {}'.format(line_idx))
raise e
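A sketch of the new loader behavior: -DOCSTART- separator lines are silently dropped during reading, so they never become instances (path and column indexes are hypothetical)::

    from fastNLP.io import ConllLoader

    # read words (column 0) and NER tags (column 3) from a CoNLL-2003 file;
    # '-DOCSTART- -X- -X- O' lines are skipped automatically
    loader = ConllLoader(headers=['words', 'ner'], indexes=[0, 3])
    dataset = loader.load('path/to/conll2003/train.txt')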
@@ -9,7 +9,7 @@ from torch import nn
from ..utils import initial_parameter
def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
def allowed_transitions(id2target, encoding_type='bio', include_start_end=False):
"""
Alias: :class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.crf.allowed_transitions`
@@ -17,7 +17,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
:param dict id2target: keys are label indices, values are str tags or tag-labels. A value may be a bare tag such as "B" or "M",
or a tag-label such as "B-NN" or "M-NN"; tag and label must be separated by "-". id2label can usually be obtained from Vocabulary.idx2word.
:param str encoding_type: supports "bio", "bmes", "bmeso".
:param str encoding_type: supports "bio", "bmes", "bmeso", "bioes".
:param bool include_start_end: whether to include transitions from the start and to the end. For example in bio, b/o may appear at the beginning but i may not;
if True, the result includes (start_idx, b_idx) and (start_idx, o_idx) but not (start_idx, i_idx),
where start_idx=len(id2label) and end_idx=len(id2label)+1. If False, the result contains nothing related to start or end.
@@ -58,7 +58,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
"""
:param str encoding_type: supports "BIO", "BMES", "BMESO".
:param str encoding_type: supports "BIO", "BMES", "BMESO", "BIOES".
:param str from_tag: a tag such as "B" or "M", including the two special tags start and end
:param str from_label: a label such as "PER" or "LOC"
:param str to_tag: a tag such as "B" or "M", including the two special tags start and end
@@ -134,9 +134,19 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
return to_tag in ['b', 's', 'end', 'o']
else:
raise ValueError("Unexpected tag type {}. Expected only 'B', 'M', 'E', 'S', 'O'.".format(from_tag))
elif encoding_type == 'bioes':
if from_tag == 'start':
return to_tag in ['b', 's', 'o']
elif from_tag == 'b':
return to_tag in ['i', 'e'] and from_label == to_label
elif from_tag == 'i':
return to_tag in ['i', 'e'] and from_label == to_label
elif from_tag in ['e', 's', 'o']:
return to_tag in ['b', 's', 'end', 'o']
else:
raise ValueError("Unexpected tag type {}. Expected only 'B', 'I', 'E', 'S', 'O'.".format(from_tag))
else:
raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type))
raise ValueError("Only support BIO, BMES, BMESO, BIOES encoding type, got {}.".format(encoding_type))
class ConditionalRandomField(nn.Module):
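With BIOES support and the new include_start_end=False default, building a transition-constrained CRF might look like this (tag vocabulary illustrative)::

    from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions

    id2target = {0: 'b-per', 1: 'i-per', 2: 'e-per', 3: 's-loc', 4: 'o'}
    # include_start_end now defaults to False, matching the CRF default
    trans = allowed_transitions(id2target, encoding_type='bioes')
    crf = ConditionalRandomField(num_tags=len(id2target), allowed_transitions=trans)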
@@ -18,7 +18,8 @@ __all__ = [
"VarLSTM",
"VarGRU"
]
from .bert import BertModel
from ._bert import BertModel
from .bert import BertWordPieceEncoder
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding
@@ -6,18 +6,399 @@
"""
import torch
from torch import nn
from ... import Vocabulary
import collections
import os
import unicodedata
from ...io.file_utils import _get_base_url, cached_path
from .bert import BertModel
import numpy as np
from itertools import chain
import copy
import json
import math
import os
import torch
from torch import nn
CONFIG_FILE = 'bert_config.json'
MODEL_WEIGHTS = 'pytorch_model.bin'
def gelu(x):
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
super(BertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
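For reference, BertLayerNorm above is plain layer normalization over the last axis with learnable scale and shift:

    y = \gamma \odot (x - \mu) / \sqrt{\sigma^2 + \epsilon} + \beta

where \mu and \sigma^2 are the mean and (biased) variance of x along its last dimension, \gamma is weight, \beta is bias, and \epsilon is variance_epsilon.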
class BertEmbeddings(nn.Module):
def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
self.dropout = nn.Dropout(hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = words_embeddings + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertSelfAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
super(BertSelfAttention, self).__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, num_attention_heads))
self.num_attention_heads = num_attention_heads
self.attention_head_size = int(hidden_size / num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(hidden_size, self.all_head_size)
self.key = nn.Linear(hidden_size, self.all_head_size)
self.value = nn.Linear(hidden_size, self.all_head_size)
self.dropout = nn.Dropout(attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
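The block above is standard multi-head scaled dot-product attention; per head,

    \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(QK^{\top} / \sqrt{d_k} + M\right) V

where d_k is attention_head_size and M is the additive mask prepared in BertModel.forward (0 for visible positions, -10000 for masked ones).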
class BertSelfOutput(nn.Module):
def __init__(self, hidden_size, hidden_dropout_prob):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
self.dropout = nn.Dropout(hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
super(BertAttention, self).__init__()
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob)
def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
attention_output = self.output(self_output, input_tensor)
return attention_output
class BertIntermediate(nn.Module):
def __init__(self, hidden_size, intermediate_size, hidden_act):
super(BertIntermediate, self).__init__()
self.dense = nn.Linear(hidden_size, intermediate_size)
self.intermediate_act_fn = ACT2FN[hidden_act] \
if isinstance(hidden_act, str) else hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob):
super(BertOutput, self).__init__()
self.dense = nn.Linear(intermediate_size, hidden_size)
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
self.dropout = nn.Dropout(hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertLayer(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob,
intermediate_size, hidden_act):
super(BertLayer, self).__init__()
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob,
hidden_dropout_prob)
self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act)
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob)
def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BertEncoder(nn.Module):
def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob,
hidden_dropout_prob,
intermediate_size, hidden_act):
super(BertEncoder, self).__init__()
layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob,
intermediate_size, hidden_act)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)])
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
all_encoder_layers = []
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, attention_mask)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
return all_encoder_layers
class BertPooler(nn.Module):
def __init__(self, hidden_size):
super(BertPooler, self).__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertModel(nn.Module):
| """BERT(Bidirectional Embedding Representations from Transformers). | |||
| 如果你想使用预训练好的权重矩阵,请在以下网址下载. | |||
| sources:: | |||
| 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", | |||
| 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", | |||
| 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", | |||
| 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", | |||
| 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", | |||
| 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", | |||
| 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", | |||
| 用预训练权重矩阵来建立BERT模型:: | |||
| model = BertModel.from_pretrained("path/to/weights/directory") | |||
| 用随机初始化权重矩阵来建立BERT模型:: | |||
| model = BertModel() | |||
| :param int vocab_size: 词表大小,默认值为30522,为BERT English uncase版本的词表大小 | |||
| :param int hidden_size: 隐层大小,默认值为768,为BERT base的版本 | |||
| :param int num_hidden_layers: 隐藏层数,默认值为12,为BERT base的版本 | |||
| :param int num_attention_heads: 多头注意力头数,默认值为12,为BERT base的版本 | |||
| :param int intermediate_size: FFN隐藏层大小,默认值是3072,为BERT base的版本 | |||
| :param str hidden_act: FFN隐藏层激活函数,默认值为``gelu`` | |||
| :param float hidden_dropout_prob: FFN隐藏层dropout,默认值为0.1 | |||
| :param float attention_probs_dropout_prob: Attention层的dropout,默认值为0.1 | |||
| :param int max_position_embeddings: 最大的序列长度,默认值为512, | |||
| :param int type_vocab_size: 最大segment数量,默认值为2 | |||
| :param int initializer_range: 初始化权重范围,默认值为0.02 | |||
| """ | |||
def __init__(self, vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02):
super(BertModel, self).__init__()
self.hidden_size = hidden_size
self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings,
type_vocab_size, hidden_dropout_prob)
self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads,
attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size,
hidden_act)
self.pooler = BertPooler(hidden_size)
self.initializer_range = initializer_range
self.apply(self.init_bert_weights)
def init_bert_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(input_ids, token_type_ids)
encoded_layers = self.encoder(embedding_output,
extended_attention_mask,
output_all_encoded_layers=output_all_encoded_layers)
sequence_output = encoded_layers[-1]
pooled_output = self.pooler(sequence_output)
if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
return encoded_layers, pooled_output
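A quick shape check of the forward pass above, with arbitrary small sizes::

    import torch

    model = BertModel(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                      num_attention_heads=4, intermediate_size=64)
    input_ids = torch.randint(0, 100, (2, 7))   # batch_size x seq_len
    encoded_layers, pooled = model(input_ids)
    print(len(encoded_layers))        # 2, one tensor per layer
    print(encoded_layers[-1].shape)   # torch.Size([2, 7, 32])
    print(pooled.shape)               # torch.Size([2, 32])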
@classmethod
def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs):
# Load config
config_file = os.path.join(pretrained_model_dir, CONFIG_FILE)
config = json.load(open(config_file, "r"))
# config = BertConfig.from_json_file(config_file)
# logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(*inputs, **config, **kwargs)
if state_dict is None:
weights_path = os.path.join(pretrained_model_dir, MODEL_WEIGHTS)
state_dict = torch.load(weights_path)
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
if len(missing_keys) > 0:
print("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
print("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
return model
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
@@ -547,79 +928,3 @@ class _WordPieceBertModel(nn.Module):
outputs[l_index] = bert_outputs[l]
return outputs
class BertWordPieceEncoder(nn.Module):
| """ | |||
| A BERT encoder driven by a vocabulary. Pass in a vocab, then call the index_datasets method to generate the word-piece representation of the vocabulary. | |||
| :param vocab: Vocabulary. | |||
| :param model_dir_or_name: directory containing the model, or the name of a downloadable model. | |||
| :param layers: which layers form the final representation; a comma-separated string of layer indices, where negative indices count from the last layer. | |||
| :param requires_grad: whether the BERT parameters receive gradients. | |||
| """ | |||
| def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en-base', layers:str='-1', | |||
| requires_grad:bool=False): | |||
| super().__init__() | |||
| PRETRAIN_URL = _get_base_url('bert') | |||
| # TODO: update -- the 'cn' entry below still points at an ELMo archive | |||
| PRETRAINED_BERT_MODEL_DIR = {'en-base': 'bert_en-80f95ea7.tar.gz', | |||
| 'cn': 'elmo_cn.zip'} | |||
| if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: | |||
| model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||
| model_url = PRETRAIN_URL + model_name | |||
| model_dir = cached_path(model_url) | |||
| # otherwise, check whether a local directory exists | |||
| elif os.path.isdir(model_dir_or_name): | |||
| model_dir = model_dir_or_name | |||
| else: | |||
| raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
| self.model = _WordPieceBertModel(model_dir=model_dir, vocab=vocab, layers=layers) | |||
| self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||
| self.requires_grad = requires_grad | |||
| @property | |||
| def requires_grad(self): | |||
| """ | |||
| Whether the embedding parameters may be optimized. True: all parameters may be optimized; False: no parameter may be optimized; None: some may, some may not | |||
| :return: | |||
| """ | |||
| requires_grads = set([param.requires_grad for name, param in self.named_parameters()]) | |||
| if len(requires_grads)==1: | |||
| return requires_grads.pop() | |||
| else: | |||
| return None | |||
| @requires_grad.setter | |||
| def requires_grad(self, value): | |||
| for name, param in self.named_parameters(): | |||
| param.requires_grad = value | |||
| @property | |||
| def embed_size(self): | |||
| return self._embed_size | |||
| def index_datasets(self, *datasets): | |||
| """ | |||
| Index the word-piece representation of the given datasets. | |||
| :param datasets: one or more DataSet objects to index | |||
| :return: | |||
| """ | |||
| self.model.index_dataset(*datasets) | |||
| def forward(self, words, token_type_ids=None): | |||
| """ | |||
| Compute the BERT embedding of words. Before the computation, [CLS] is prepended and [SEP] is appended to every sentence, | |||
| and include_cls_sep decides whether these two representations are kept. | |||
| :param words: batch_size x max_len | |||
| :param token_type_ids: batch_size x max_len, distinguishes the first sentence from the second | |||
| :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) | |||
| """ | |||
| outputs = self.model(words, token_type_ids) | |||
| outputs = torch.cat([*outputs], dim=-1) | |||
| return outputs | |||
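| For context, a minimal usage sketch of the encoder above; the model directory, dataset and tensor names are illustrative assumptions, not part of this patch: | |||
| vocab = Vocabulary()                   # hypothetical vocabulary | |||
| vocab.add_word_lst(["hello", "world"]) | |||
| encoder = BertWordPieceEncoder(vocab, model_dir_or_name='path/to/bert-dir', layers='-1') | |||
| encoder.index_datasets(train_data)     # adds word-piece indices to the (hypothetical) dataset | |||
| reps = encoder(words)                  # batch_size x max_len x embed_size | |||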
| @@ -1,378 +1,95 @@ | |||
| """ | |||
| bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. | |||
| """ | |||
| import copy | |||
| import json | |||
| import math | |||
| import os | |||
| import torch | |||
| from torch import nn | |||
| import torch | |||
| from ...core import Vocabulary | |||
| from ...io.file_utils import _get_base_url, cached_path | |||
| from ._bert import _WordPieceBertModel | |||
| CONFIG_FILE = 'bert_config.json' | |||
| MODEL_WEIGHTS = 'pytorch_model.bin' | |||
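| # exact GELU computed with the Gaussian error function, as in the original BERT implementation | |||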
| def gelu(x): | |||
| return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | |||
| def swish(x): | |||
| return x * torch.sigmoid(x) | |||
| ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} | |||
| class BertLayerNorm(nn.Module): | |||
| def __init__(self, hidden_size, eps=1e-12): | |||
| super(BertLayerNorm, self).__init__() | |||
| self.weight = nn.Parameter(torch.ones(hidden_size)) | |||
| self.bias = nn.Parameter(torch.zeros(hidden_size)) | |||
| self.variance_epsilon = eps | |||
| def forward(self, x): | |||
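| # normalize the last dimension to zero mean and unit variance, then apply the learned affine transform | |||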
| u = x.mean(-1, keepdim=True) | |||
| s = (x - u).pow(2).mean(-1, keepdim=True) | |||
| x = (x - u) / torch.sqrt(s + self.variance_epsilon) | |||
| return self.weight * x + self.bias | |||
| class BertEmbeddings(nn.Module): | |||
| def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): | |||
| super(BertEmbeddings, self).__init__() | |||
| self.word_embeddings = nn.Embedding(vocab_size, hidden_size) | |||
| self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) | |||
| self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) | |||
| # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load | |||
| # any TensorFlow checkpoint file | |||
| self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||
| self.dropout = nn.Dropout(hidden_dropout_prob) | |||
| def forward(self, input_ids, token_type_ids=None): | |||
| seq_length = input_ids.size(1) | |||
| position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) | |||
| position_ids = position_ids.unsqueeze(0).expand_as(input_ids) | |||
| if token_type_ids is None: | |||
| token_type_ids = torch.zeros_like(input_ids) | |||
| words_embeddings = self.word_embeddings(input_ids) | |||
| position_embeddings = self.position_embeddings(position_ids) | |||
| token_type_embeddings = self.token_type_embeddings(token_type_ids) | |||
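| # BERT sums the three embedding lookups, then applies LayerNorm and dropout | |||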
| embeddings = words_embeddings + position_embeddings + token_type_embeddings | |||
| embeddings = self.LayerNorm(embeddings) | |||
| embeddings = self.dropout(embeddings) | |||
| return embeddings | |||
| class BertSelfAttention(nn.Module): | |||
| def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): | |||
| super(BertSelfAttention, self).__init__() | |||
| if hidden_size % num_attention_heads != 0: | |||
| raise ValueError( | |||
| "The hidden size (%d) is not a multiple of the number of attention " | |||
| "heads (%d)" % (hidden_size, num_attention_heads)) | |||
| self.num_attention_heads = num_attention_heads | |||
| self.attention_head_size = int(hidden_size / num_attention_heads) | |||
| self.all_head_size = self.num_attention_heads * self.attention_head_size | |||
| self.query = nn.Linear(hidden_size, self.all_head_size) | |||
| self.key = nn.Linear(hidden_size, self.all_head_size) | |||
| self.value = nn.Linear(hidden_size, self.all_head_size) | |||
| self.dropout = nn.Dropout(attention_probs_dropout_prob) | |||
| def transpose_for_scores(self, x): | |||
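| # reshape [batch, seq_len, all_head_size] into [batch, num_heads, seq_len, head_size] | |||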
| new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) | |||
| x = x.view(*new_x_shape) | |||
| return x.permute(0, 2, 1, 3) | |||
| def forward(self, hidden_states, attention_mask): | |||
| mixed_query_layer = self.query(hidden_states) | |||
| mixed_key_layer = self.key(hidden_states) | |||
| mixed_value_layer = self.value(hidden_states) | |||
| query_layer = self.transpose_for_scores(mixed_query_layer) | |||
| key_layer = self.transpose_for_scores(mixed_key_layer) | |||
| value_layer = self.transpose_for_scores(mixed_value_layer) | |||
| # Take the dot product between "query" and "key" to get the raw attention scores. | |||
| attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) | |||
| attention_scores = attention_scores / math.sqrt(self.attention_head_size) | |||
| # Apply the attention mask is (precomputed for all layers in BertModel forward() function) | |||
| attention_scores = attention_scores + attention_mask | |||
| # Normalize the attention scores to probabilities. | |||
| attention_probs = nn.Softmax(dim=-1)(attention_scores) | |||
| # This is actually dropping out entire tokens to attend to, which might | |||
| # seem a bit unusual, but is taken from the original Transformer paper. | |||
| attention_probs = self.dropout(attention_probs) | |||
| context_layer = torch.matmul(attention_probs, value_layer) | |||
| context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |||
| new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) | |||
| context_layer = context_layer.view(*new_context_layer_shape) | |||
| return context_layer | |||
| class BertSelfOutput(nn.Module): | |||
| def __init__(self, hidden_size, hidden_dropout_prob): | |||
| super(BertSelfOutput, self).__init__() | |||
| self.dense = nn.Linear(hidden_size, hidden_size) | |||
| self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||
| self.dropout = nn.Dropout(hidden_dropout_prob) | |||
| def forward(self, hidden_states, input_tensor): | |||
| hidden_states = self.dense(hidden_states) | |||
| hidden_states = self.dropout(hidden_states) | |||
| hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||
| return hidden_states | |||
| class BertAttention(nn.Module): | |||
| def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): | |||
| super(BertAttention, self).__init__() | |||
| self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) | |||
| self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) | |||
| def forward(self, input_tensor, attention_mask): | |||
| self_output = self.self(input_tensor, attention_mask) | |||
| attention_output = self.output(self_output, input_tensor) | |||
| return attention_output | |||
| class BertIntermediate(nn.Module): | |||
| def __init__(self, hidden_size, intermediate_size, hidden_act): | |||
| super(BertIntermediate, self).__init__() | |||
| self.dense = nn.Linear(hidden_size, intermediate_size) | |||
| self.intermediate_act_fn = ACT2FN[hidden_act] \ | |||
| if isinstance(hidden_act, str) else hidden_act | |||
| def forward(self, hidden_states): | |||
| hidden_states = self.dense(hidden_states) | |||
| hidden_states = self.intermediate_act_fn(hidden_states) | |||
| return hidden_states | |||
| class BertOutput(nn.Module): | |||
| def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): | |||
| super(BertOutput, self).__init__() | |||
| self.dense = nn.Linear(intermediate_size, hidden_size) | |||
| self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||
| self.dropout = nn.Dropout(hidden_dropout_prob) | |||
| def forward(self, hidden_states, input_tensor): | |||
| hidden_states = self.dense(hidden_states) | |||
| hidden_states = self.dropout(hidden_states) | |||
| hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||
| return hidden_states | |||
| class BertLayer(nn.Module): | |||
| def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||
| intermediate_size, hidden_act): | |||
| super(BertLayer, self).__init__() | |||
| self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||
| hidden_dropout_prob) | |||
| self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) | |||
| self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) | |||
| def forward(self, hidden_states, attention_mask): | |||
| attention_output = self.attention(hidden_states, attention_mask) | |||
| intermediate_output = self.intermediate(attention_output) | |||
| layer_output = self.output(intermediate_output, attention_output) | |||
| return layer_output | |||
| class BertEncoder(nn.Module): | |||
| def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||
| hidden_dropout_prob, | |||
| intermediate_size, hidden_act): | |||
| super(BertEncoder, self).__init__() | |||
| layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||
| intermediate_size, hidden_act) | |||
| self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) | |||
| def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): | |||
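| # collect every layer's output when output_all_encoded_layers is True; otherwise keep only the final layer | |||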
| all_encoder_layers = [] | |||
| for layer_module in self.layer: | |||
| hidden_states = layer_module(hidden_states, attention_mask) | |||
| if output_all_encoded_layers: | |||
| all_encoder_layers.append(hidden_states) | |||
| if not output_all_encoded_layers: | |||
| all_encoder_layers.append(hidden_states) | |||
| return all_encoder_layers | |||
| class BertPooler(nn.Module): | |||
| def __init__(self, hidden_size): | |||
| super(BertPooler, self).__init__() | |||
| self.dense = nn.Linear(hidden_size, hidden_size) | |||
| self.activation = nn.Tanh() | |||
| def forward(self, hidden_states): | |||
| # We "pool" the model by simply taking the hidden state corresponding | |||
| # to the first token. | |||
| first_token_tensor = hidden_states[:, 0] | |||
| pooled_output = self.dense(first_token_tensor) | |||
| pooled_output = self.activation(pooled_output) | |||
| return pooled_output | |||
| class BertModel(nn.Module): | |||
| """BERT(Bidirectional Embedding Representations from Transformers). | |||
| 如果你想使用预训练好的权重矩阵,请在以下网址下载. | |||
| sources:: | |||
| 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", | |||
| 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", | |||
| 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", | |||
| 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", | |||
| 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", | |||
| 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", | |||
| 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", | |||
| Build a BERT model from pretrained weights:: | |||
| model = BertModel.from_pretrained("path/to/weights/directory") | |||
| Build a BERT model with randomly initialized weights:: | |||
| model = BertModel() | |||
| :param int vocab_size: vocabulary size; the default 30522 is the size of the English uncased BERT vocabulary | |||
| :param int hidden_size: hidden size; the default 768 matches BERT base | |||
| :param int num_hidden_layers: number of hidden layers; the default 12 matches BERT base | |||
| :param int num_attention_heads: number of attention heads; the default 12 matches BERT base | |||
| :param int intermediate_size: hidden size of the FFN; the default 3072 matches BERT base | |||
| :param str hidden_act: activation function of the FFN; defaults to ``gelu`` | |||
| :param float hidden_dropout_prob: dropout of the FFN hidden layer; defaults to 0.1 | |||
| :param float attention_probs_dropout_prob: dropout of the attention layers; defaults to 0.1 | |||
| :param int max_position_embeddings: maximum sequence length; defaults to 512 | |||
| :param int type_vocab_size: maximum number of segments; defaults to 2 | |||
| :param float initializer_range: standard deviation used to initialize the weights; defaults to 0.02 | |||
| class BertWordPieceEncoder(nn.Module): | |||
| """ | |||
| A BERT encoder driven by a vocabulary. Pass in a vocab, then call the index_datasets method to generate the word-piece representation of the vocabulary. | |||
| def __init__(self, vocab_size=30522, | |||
| hidden_size=768, | |||
| num_hidden_layers=12, | |||
| num_attention_heads=12, | |||
| intermediate_size=3072, | |||
| hidden_act="gelu", | |||
| hidden_dropout_prob=0.1, | |||
| attention_probs_dropout_prob=0.1, | |||
| max_position_embeddings=512, | |||
| type_vocab_size=2, | |||
| initializer_range=0.02): | |||
| super(BertModel, self).__init__() | |||
| self.hidden_size = hidden_size | |||
| self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, | |||
| type_vocab_size, hidden_dropout_prob) | |||
| self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, | |||
| attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, | |||
| hidden_act) | |||
| self.pooler = BertPooler(hidden_size) | |||
| self.initializer_range = initializer_range | |||
| self.apply(self.init_bert_weights) | |||
| def init_bert_weights(self, module): | |||
| if isinstance(module, (nn.Linear, nn.Embedding)): | |||
| # Slightly different from the TF version which uses truncated_normal for initialization | |||
| # cf https://github.com/pytorch/pytorch/pull/5617 | |||
| module.weight.data.normal_(mean=0.0, std=self.initializer_range) | |||
| elif isinstance(module, BertLayerNorm): | |||
| module.bias.data.zero_() | |||
| module.weight.data.fill_(1.0) | |||
| if isinstance(module, nn.Linear) and module.bias is not None: | |||
| module.bias.data.zero_() | |||
| def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): | |||
| if attention_mask is None: | |||
| attention_mask = torch.ones_like(input_ids) | |||
| if token_type_ids is None: | |||
| token_type_ids = torch.zeros_like(input_ids) | |||
| # We create a 3D attention mask from a 2D tensor mask. | |||
| # Sizes are [batch_size, 1, 1, to_seq_length] | |||
| # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] | |||
| # this attention mask is more simple than the triangular masking of causal attention | |||
| # used in OpenAI GPT, we just need to prepare the broadcast dimension here. | |||
| extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | |||
| # Since attention_mask is 1.0 for positions we want to attend and 0.0 for | |||
| # masked positions, this operation will create a tensor which is 0.0 for | |||
| # positions we want to attend and -10000.0 for masked positions. | |||
| # Since we are adding it to the raw scores before the softmax, this is | |||
| # effectively the same as removing these entirely. | |||
| extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility | |||
| extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 | |||
| embedding_output = self.embeddings(input_ids, token_type_ids) | |||
| encoded_layers = self.encoder(embedding_output, | |||
| extended_attention_mask, | |||
| output_all_encoded_layers=output_all_encoded_layers) | |||
| sequence_output = encoded_layers[-1] | |||
| pooled_output = self.pooler(sequence_output) | |||
| if not output_all_encoded_layers: | |||
| encoded_layers = encoded_layers[-1] | |||
| return encoded_layers, pooled_output | |||
| @classmethod | |||
| def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): | |||
| # Load config | |||
| config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) | |||
| with open(config_file, "r") as f: config = json.load(f) | |||
| # config = BertConfig.from_json_file(config_file) | |||
| # logger.info("Model config {}".format(config)) | |||
| # Instantiate model. | |||
| model = cls(*inputs, **config, **kwargs) | |||
| if state_dict is None: | |||
| weights_path = os.path.join(pretrained_model_dir, MODEL_WEIGHTS) | |||
| state_dict = torch.load(weights_path) | |||
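| # older TF-style checkpoints name the LayerNorm parameters gamma/beta; rename them to weight/bias below | |||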
| old_keys = [] | |||
| new_keys = [] | |||
| for key in state_dict.keys(): | |||
| new_key = None | |||
| if 'gamma' in key: | |||
| new_key = key.replace('gamma', 'weight') | |||
| if 'beta' in key: | |||
| new_key = key.replace('beta', 'bias') | |||
| if new_key: | |||
| old_keys.append(key) | |||
| new_keys.append(new_key) | |||
| for old_key, new_key in zip(old_keys, new_keys): | |||
| state_dict[new_key] = state_dict.pop(old_key) | |||
| missing_keys = [] | |||
| unexpected_keys = [] | |||
| error_msgs = [] | |||
| # copy state_dict so _load_from_state_dict can modify it | |||
| metadata = getattr(state_dict, '_metadata', None) | |||
| state_dict = state_dict.copy() | |||
| if metadata is not None: | |||
| state_dict._metadata = metadata | |||
| def load(module, prefix=''): | |||
| local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | |||
| module._load_from_state_dict( | |||
| state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) | |||
| for name, child in module._modules.items(): | |||
| if child is not None: | |||
| load(child, prefix + name + '.') | |||
| load(model, prefix='' if hasattr(model, 'bert') else 'bert.') | |||
| if len(missing_keys) > 0: | |||
| print("Weights of {} not initialized from pretrained model: {}".format( | |||
| model.__class__.__name__, missing_keys)) | |||
| if len(unexpected_keys) > 0: | |||
| print("Weights from pretrained model not used in {}: {}".format( | |||
| model.__class__.__name__, unexpected_keys)) | |||
| return model | |||
| :param fastNLP.Vocabulary vocab: the vocabulary | |||
| :param str model_dir_or_name: directory containing the model, or the name of the model. Defaults to ``en-base-uncased`` | |||
| :param str layers: which layers form the final representation; a comma-separated string of layer indices, where negative indices count from the last layer | |||
| :param bool requires_grad: whether gradients are required. | |||
| """ | |||
| def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en-base-uncased', layers:str='-1', | |||
| requires_grad:bool=False): | |||
| super().__init__() | |||
| PRETRAIN_URL = _get_base_url('bert') | |||
| PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | |||
| 'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | |||
| 'en-base-cased': 'bert-base-cased-f89bfe08.zip', | |||
| 'en-large-uncased': 'bert-large-uncased-20939f45.zip', | |||
| 'en-large-cased': 'bert-large-cased-e0cf90fc.zip', | |||
| 'cn': 'bert-base-chinese-29d0a84a.zip', | |||
| 'cn-base': 'bert-base-chinese-29d0a84a.zip', | |||
| 'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
| 'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', | |||
| 'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
| } | |||
| if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: | |||
| model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||
| model_url = PRETRAIN_URL + model_name | |||
| model_dir = cached_path(model_url) | |||
| # otherwise, check whether a local directory exists | |||
| elif os.path.isdir(model_dir_or_name): | |||
| model_dir = model_dir_or_name | |||
| else: | |||
| raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
| self.model = _WordPieceBertModel(model_dir=model_dir, vocab=vocab, layers=layers) | |||
| self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||
| self.requires_grad = requires_grad | |||
| @property | |||
| def requires_grad(self): | |||
| """ | |||
| Whether the embedding parameters may be optimized. True: all parameters may be optimized; False: no parameter may be optimized; None: some may, some may not | |||
| :return: | |||
| """ | |||
| requires_grads = set([param.requires_grad for name, param in self.named_parameters()]) | |||
| if len(requires_grads)==1: | |||
| return requires_grads.pop() | |||
| else: | |||
| return None | |||
| @requires_grad.setter | |||
| def requires_grad(self, value): | |||
| for name, param in self.named_parameters(): | |||
| param.requires_grad = value | |||
| @property | |||
| def embed_size(self): | |||
| return self._embed_size | |||
| def index_datasets(self, *datasets): | |||
| """ | |||
| Index the word-piece representation of the datasets based on their 'words' field. | |||
| :param datasets: one or more DataSet objects whose 'words' field will be indexed | |||
| :return: | |||
| """ | |||
| self.model.index_dataset(*datasets) | |||
| def forward(self, words, token_type_ids=None): | |||
| """ | |||
| Compute the BERT embedding of words. Before the computation, [CLS] is prepended and [SEP] is appended to every sentence, | |||
| and include_cls_sep decides whether these two representations are kept. | |||
| :param words: batch_size x max_len | |||
| :param token_type_ids: batch_size x max_len, distinguishes the first sentence from the second | |||
| :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) | |||
| """ | |||
| outputs = self.model(words, token_type_ids) | |||
| outputs = torch.cat([*outputs], dim=-1) | |||
| return outputs | |||
| @@ -15,7 +15,7 @@ from ...io.file_utils import cached_path, _get_base_url | |||
| from ._bert import _WordBertModel | |||
| from typing import List | |||
| from ... import DataSet, Batch, SequentialSampler | |||
| from ... import DataSet, DataSetIter, SequentialSampler | |||
| from ...core.utils import _move_model_to_device, _get_model_device | |||
| @@ -157,7 +157,6 @@ class StaticEmbedding(TokenEmbedding): | |||
| super(StaticEmbedding, self).__init__(vocab) | |||
| # First decide which static embeddings should be downloadable. A dedicated server will probably be needed here, | |||
| PRETRAIN_URL = _get_base_url('static') | |||
| PRETRAIN_STATIC_FILES = { | |||
| 'en': 'glove.840B.300d-cc1ad5e1.tar.gz', | |||
| 'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz', | |||
| @@ -170,6 +169,7 @@ class StaticEmbedding(TokenEmbedding): | |||
| # resolve the cache path | |||
| if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: | |||
| PRETRAIN_URL = _get_base_url('static') | |||
| model_name = PRETRAIN_STATIC_FILES[model_dir_or_name] | |||
| model_url = PRETRAIN_URL + model_name | |||
| model_path = cached_path(model_url) | |||
| @@ -234,7 +234,7 @@ class ContextualEmbedding(TokenEmbedding): | |||
| with torch.no_grad(): | |||
| for index, dataset in enumerate(datasets): | |||
| try: | |||
| batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), prefetch=False) | |||
| batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
| for batch_x, batch_y in batch: | |||
| words = batch_x['words'].to(device) | |||
| words_list = words.tolist() | |||
| @@ -325,11 +325,11 @@ class ElmoEmbedding(ContextualEmbedding): | |||
| self.layers = layers | |||
| # check whether model_dir_or_name exists locally, downloading it if necessary | |||
| PRETRAIN_URL = _get_base_url('elmo') | |||
| PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz', | |||
| 'cn': 'elmo_cn-5e9b34e2.tar.gz'} | |||
| if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: | |||
| PRETRAIN_URL = _get_base_url('elmo') | |||
| model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name] | |||
| model_url = PRETRAIN_URL + model_name | |||
| model_dir = cached_path(model_url) | |||
| @@ -383,7 +383,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||
| def requires_grad(self, value): | |||
| for name, param in self.named_parameters(): | |||
| if 'words_to_chars_embedding' in name:  # this buffer must stay excluded from requires_grad | |||
| pass | |||
| continue | |||
| param.requires_grad = value | |||
| @@ -411,7 +411,6 @@ class BertEmbedding(ContextualEmbedding): | |||
| pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False): | |||
| super(BertEmbedding, self).__init__(vocab) | |||
| # check whether model_dir_or_name exists locally, downloading it if necessary | |||
| PRETRAIN_URL = _get_base_url('bert') | |||
| PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | |||
| 'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | |||
| 'en-base-cased': 'bert-base-cased-f89bfe08.zip', | |||
| @@ -427,6 +426,7 @@ class BertEmbedding(ContextualEmbedding): | |||
| } | |||
| if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | |||
| PRETRAIN_URL = _get_base_url('bert') | |||
| model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||
| model_url = PRETRAIN_URL + model_name | |||
| model_dir = cached_path(model_url) | |||
| @@ -478,7 +478,7 @@ class BertEmbedding(ContextualEmbedding): | |||
| def requires_grad(self, value): | |||
| for name, param in self.named_parameters(): | |||
| if 'word_pieces_lengths' in name:  # this buffer must stay excluded from requires_grad | |||
| pass | |||
| continue | |||
| param.requires_grad = value | |||
| @@ -566,6 +566,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||
| for i in range(len(kernel_sizes))]) | |||
| self._embed_size = embed_size | |||
| self.fc = nn.Linear(sum(filter_nums), embed_size) | |||
| self.init_param() | |||
| def forward(self, words): | |||
| """ | |||
| @@ -618,9 +619,17 @@ class CNNCharEmbedding(TokenEmbedding): | |||
| def requires_grad(self, value): | |||
| for name, param in self.named_parameters(): | |||
| if 'words_to_chars_embedding' in name or 'word_lengths' in name:  # these buffers must stay excluded from requires_grad | |||
| pass | |||
| continue | |||
| param.requires_grad = value | |||
| def init_param(self): | |||
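| # Xavier-normal for weight matrices, uniform(-1, 1) for vectors; the char-lookup buffers are skipped | |||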
| for name, param in self.named_parameters(): | |||
| if 'words_to_chars_embedding' in name or 'word_lengths' in name:  # these buffers must not be reset | |||
| continue | |||
| if param.data.dim()>1: | |||
| nn.init.xavier_normal_(param, 1) | |||
| else: | |||
| nn.init.uniform_(param, -1, 1) | |||
| class LSTMCharEmbedding(TokenEmbedding): | |||
| """ | |||
| @@ -744,7 +753,7 @@ class LSTMCharEmbedding(TokenEmbedding): | |||
| def requires_grad(self, value): | |||
| for name, param in self.named_parameters(): | |||
| if 'words_to_chars_embedding' in name or 'word_lengths' in name:  # these buffers must stay excluded from requires_grad | |||
| pass | |||
| continue | |||
| param.requires_grad = value | |||
| @@ -35,8 +35,18 @@ class LSTM(nn.Module): | |||
| self.batch_first = batch_first | |||
| self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | |||
| dropout=dropout, bidirectional=bidirectional) | |||
| self.init_param() | |||
| initial_parameter(self, initial_method) | |||
| def init_param(self): | |||
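| # bias_ih is filled with 1 and bias_hh with 0, so every gate starts with a total bias of 1; | |||
| # all weight matrices get Xavier-normal initialization | |||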
| for name, param in self.named_parameters(): | |||
| if 'bias_i' in name: | |||
| param.data.fill_(1) | |||
| elif 'bias_h' in name: | |||
| param.data.fill_(0) | |||
| else: | |||
| nn.init.xavier_normal_(param) | |||
| def forward(self, x, seq_len=None, h0=None, c0=None): | |||
| """ | |||
| @@ -184,11 +184,8 @@ def train(path): | |||
| m.weight.requires_grad = True | |||
| # Trainer | |||
| trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data, | |||
| loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | |||
| **train_args.data, | |||
| optimizer=fastNLP.Adam(**optim_args.data), | |||
| save_path=path, | |||
| trainer = Trainer(train_data=train_data, model=model, optimizer=fastNLP.Adam(**optim_args.data), loss=ParserLoss(), | |||
| dev_data=dev_data, metrics=ParserMetric(), metric_key='UAS', save_path=path, | |||
| callbacks=[MyCallback()]) | |||
| # Start training | |||
| @@ -89,11 +89,11 @@ def train(train_data_path, dev_data_path, checkpoint=None, save=None): | |||
| model = torch.load(checkpoint) | |||
| # call trainer to train | |||
| trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", | |||
| target="truth", | |||
| seq_lens="word_seq_origin_len"), | |||
| dev_data=dev_data, metric_key="f", | |||
| use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path=save) | |||
| trainer = Trainer(dataset, model, loss=None, n_epochs=20, print_every=10, dev_data=dev_data, | |||
| metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", | |||
| target="truth", | |||
| seq_lens="word_seq_origin_len"), metric_key="f", save_path=save, | |||
| use_tqdm=True) | |||
| trainer.train(load_best_model=True) | |||
| # save model & pipeline | |||
| @@ -149,14 +149,10 @@ def train(): | |||
| ) if x.requires_grad and x.size(0) != len(word_v)] | |||
| optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1}, | |||
| {'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ] | |||
| trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data, | |||
| loss=loss, metrics=metric, metric_key=metric_key, | |||
| optimizer=torch.optim.Adam(optim_cfg), | |||
| n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=10, validate_every=3000, | |||
| device=device, | |||
| use_tqdm=False, prefetch=False, | |||
| save_path=g_args.log, | |||
| callbacks=[MyCallback()]) | |||
| trainer = FN.Trainer(train_data=train_data, model=model, optimizer=torch.optim.Adam(optim_cfg), loss=loss, | |||
| batch_size=g_args.bsz, n_epochs=g_args.ep, print_every=10, dev_data=dev_data, metrics=metric, | |||
| metric_key=metric_key, validate_every=3000, save_path=g_args.log, use_tqdm=False, | |||
| device=device, callbacks=[MyCallback()]) | |||
| trainer.train() | |||
| tester = FN.Tester(data=test_data, model=model, metrics=metric, | |||
| @@ -70,19 +70,10 @@ test_data = preprocess_data(test_data, bert_dirs) | |||
| model = BertForNLI(bert_dir=bert_dirs) | |||
| trainer = Trainer( | |||
| train_data=train_data, | |||
| model=model, | |||
| optimizer=Adam(lr=2e-5, model_params=model.parameters()), | |||
| batch_size=torch.cuda.device_count() * 12, | |||
| n_epochs=4, | |||
| print_every=-1, | |||
| dev_data=dev_data, | |||
| metrics=AccuracyMetric(), | |||
| metric_key='acc', | |||
| device=[i for i in range(torch.cuda.device_count())], | |||
| check_code_level=-1 | |||
| ) | |||
| trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(lr=2e-5, model_params=model.parameters()), | |||
| batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, dev_data=dev_data, | |||
| metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], | |||
| check_code_level=-1) | |||
| trainer.train(load_best_model=True) | |||
| tester = Tester( | |||
| @@ -13,7 +13,8 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||
| } | |||
| If paths is invalid, the corresponding error is raised directly. | |||
| :param paths: the path | |||
| :param paths: the path(s). May be a single file path (the file is then taken to be the training file); a directory, in which train.txt, | |||
| test.txt and dev.txt are looked up; or a dict whose keys are user-defined file names and whose values are the file paths. | |||
| :return: | |||
| """ | |||
| if isinstance(paths, str): | |||
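| For context, a minimal sketch of the three accepted forms (all file names below are hypothetical): | |||
| check_dataloader_paths('data/train.txt')                    # single file -> {'train': 'data/train.txt'} | |||
| check_dataloader_paths('data/')                             # directory containing train.txt/dev.txt/test.txt | |||
| check_dataloader_paths({'train': 'a.txt', 'dev': 'b.txt'})  # explicit name-to-path mapping | |||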
| @@ -3,7 +3,7 @@ import unittest | |||
| import numpy as np | |||
| import torch | |||
| from fastNLP import Batch | |||
| from fastNLP import DataSetIter | |||
| from fastNLP import DataSet | |||
| from fastNLP import Instance | |||
| from fastNLP import SequentialSampler | |||
| @@ -57,7 +57,7 @@ class TestCase1(unittest.TestCase): | |||
| dataset = construct_dataset( | |||
| [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | |||
| dataset.set_target() | |||
| batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
| batch = DataSetIter(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
| cnt = 0 | |||
| for _, _ in batch: | |||
| @@ -68,7 +68,7 @@ class TestCase1(unittest.TestCase): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| ds.set_input("x") | |||
| ds.set_target("y") | |||
| iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
| iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
| for x, y in iter: | |||
| self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray)) | |||
| self.assertEqual(len(x["x"]), 4) | |||
| @@ -81,7 +81,7 @@ class TestCase1(unittest.TestCase): | |||
| "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | |||
| ds.set_input("x") | |||
| ds.set_target("y") | |||
| iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
| iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
| for x, y in iter: | |||
| self.assertEqual(x["x"].shape, (4, 4)) | |||
| self.assertEqual(y["y"].shape, (4, 4)) | |||
| @@ -91,7 +91,7 @@ class TestCase1(unittest.TestCase): | |||
| "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | |||
| ds.set_input("x") | |||
| ds.set_target("y") | |||
| iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
| iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
| for x, y in iter: | |||
| self.assertEqual(x["x"].shape, (4, 4)) | |||
| self.assertEqual(y["y"].shape, (4, 4)) | |||
| @@ -101,7 +101,7 @@ class TestCase1(unittest.TestCase): | |||
| "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | |||
| ds.set_input("x") | |||
| ds.set_target("y") | |||
| iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
| iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
| for x, y in iter: | |||
| self.assertTrue(isinstance(x["x"], torch.Tensor)) | |||
| self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
| @@ -113,7 +113,7 @@ class TestCase1(unittest.TestCase): | |||
| "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | |||
| ds.set_input("x") | |||
| ds.set_target("y") | |||
| iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
| iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
| for x, y in iter: | |||
| self.assertTrue(isinstance(x["x"], torch.Tensor)) | |||
| self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
| @@ -125,7 +125,7 @@ class TestCase1(unittest.TestCase): | |||
| [Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)]) | |||
| ds.set_input("x") | |||
| ds.set_target("y") | |||
| iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
| iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
| for x, y in iter: | |||
| self.assertTrue(isinstance(x["x"], torch.Tensor)) | |||
| self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
| @@ -137,7 +137,7 @@ class TestCase1(unittest.TestCase): | |||
| [Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)]) | |||
| ds.set_input("x") | |||
| ds.set_target("y") | |||
| iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
| iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
| for x, y in iter: | |||
| print(x, y) | |||
| @@ -146,7 +146,7 @@ class TestCase1(unittest.TestCase): | |||
| num_samples = 1000 | |||
| dataset = generate_fake_dataset(num_samples) | |||
| batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
| batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
| for batch_x, batch_y in batch: | |||
| pass | |||
| @@ -40,89 +40,50 @@ class TestCallback(unittest.TestCase): | |||
| def test_gradient_clip(self): | |||
| data_set, model = prepare_env() | |||
| trainer = Trainer(data_set, model, | |||
| loss=BCELoss(pred="predict", target="y"), | |||
| n_epochs=20, | |||
| batch_size=32, | |||
| print_every=50, | |||
| optimizer=SGD(lr=0.1), | |||
| check_code_level=2, | |||
| use_tqdm=False, | |||
| dev_data=data_set, | |||
| metrics=AccuracyMetric(pred="predict", target="y"), | |||
| callbacks=[GradientClipCallback(model.parameters(), clip_value=2)]) | |||
| trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
| batch_size=32, n_epochs=20, print_every=50, dev_data=data_set, | |||
| metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, | |||
| callbacks=[GradientClipCallback(model.parameters(), clip_value=2)], check_code_level=2) | |||
| trainer.train() | |||
| def test_early_stop(self): | |||
| data_set, model = prepare_env() | |||
| trainer = Trainer(data_set, model, | |||
| loss=BCELoss(pred="predict", target="y"), | |||
| n_epochs=20, | |||
| batch_size=32, | |||
| print_every=50, | |||
| optimizer=SGD(lr=0.01), | |||
| check_code_level=2, | |||
| use_tqdm=False, | |||
| dev_data=data_set, | |||
| metrics=AccuracyMetric(pred="predict", target="y"), | |||
| callbacks=[EarlyStopCallback(5)]) | |||
| trainer = Trainer(data_set, model, optimizer=SGD(lr=0.01), loss=BCELoss(pred="predict", target="y"), | |||
| batch_size=32, n_epochs=20, print_every=50, dev_data=data_set, | |||
| metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, | |||
| callbacks=[EarlyStopCallback(5)], check_code_level=2) | |||
| trainer.train() | |||
| def test_lr_scheduler(self): | |||
| data_set, model = prepare_env() | |||
| optimizer = torch.optim.SGD(model.parameters(), lr=0.01) | |||
| trainer = Trainer(data_set, model, | |||
| loss=BCELoss(pred="predict", target="y"), | |||
| n_epochs=5, | |||
| batch_size=32, | |||
| print_every=50, | |||
| optimizer=optimizer, | |||
| check_code_level=2, | |||
| use_tqdm=False, | |||
| dev_data=data_set, | |||
| metrics=AccuracyMetric(pred="predict", target="y"), | |||
| callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))]) | |||
| trainer = Trainer(data_set, model, optimizer=optimizer, loss=BCELoss(pred="predict", target="y"), batch_size=32, | |||
| n_epochs=5, print_every=50, dev_data=data_set, | |||
| metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, | |||
| callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))], | |||
| check_code_level=2) | |||
| trainer.train() | |||
| def test_KeyBoardInterrupt(self): | |||
| data_set, model = prepare_env() | |||
| trainer = Trainer(data_set, model, | |||
| loss=BCELoss(pred="predict", target="y"), | |||
| n_epochs=5, | |||
| batch_size=32, | |||
| print_every=50, | |||
| optimizer=SGD(lr=0.1), | |||
| check_code_level=2, | |||
| use_tqdm=False, | |||
| callbacks=[ControlC(False)]) | |||
| trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
| batch_size=32, n_epochs=5, print_every=50, use_tqdm=False, callbacks=[ControlC(False)], | |||
| check_code_level=2) | |||
| trainer.train() | |||
| def test_LRFinder(self): | |||
| data_set, model = prepare_env() | |||
| trainer = Trainer(data_set, model, | |||
| loss=BCELoss(pred="predict", target="y"), | |||
| n_epochs=5, | |||
| batch_size=32, | |||
| print_every=50, | |||
| optimizer=SGD(lr=0.1), | |||
| check_code_level=2, | |||
| use_tqdm=False, | |||
| callbacks=[LRFinder(len(data_set) // 32)]) | |||
| trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
| batch_size=32, n_epochs=5, print_every=50, use_tqdm=False, | |||
| callbacks=[LRFinder(len(data_set) // 32)], check_code_level=2) | |||
| trainer.train() | |||
| def test_TensorboardCallback(self): | |||
| data_set, model = prepare_env() | |||
| trainer = Trainer(data_set, model, | |||
| loss=BCELoss(pred="predict", target="y"), | |||
| n_epochs=5, | |||
| batch_size=32, | |||
| print_every=50, | |||
| optimizer=SGD(lr=0.1), | |||
| check_code_level=2, | |||
| use_tqdm=False, | |||
| dev_data=data_set, | |||
| metrics=AccuracyMetric(pred="predict", target="y"), | |||
| callbacks=[TensorboardCallback("loss", "metric")]) | |||
| trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
| batch_size=32, n_epochs=5, print_every=50, dev_data=data_set, | |||
| metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, | |||
| callbacks=[TensorboardCallback("loss", "metric")], check_code_level=2) | |||
| trainer.train() | |||
| def test_readonly_property(self): | |||
| @@ -141,16 +102,9 @@ class TestCallback(unittest.TestCase): | |||
| print(self.optimizer) | |||
| data_set, model = prepare_env() | |||
| trainer = Trainer(data_set, model, | |||
| loss=BCELoss(pred="predict", target="y"), | |||
| n_epochs=total_epochs, | |||
| batch_size=32, | |||
| print_every=50, | |||
| optimizer=SGD(lr=0.1), | |||
| check_code_level=2, | |||
| use_tqdm=False, | |||
| dev_data=data_set, | |||
| metrics=AccuracyMetric(pred="predict", target="y"), | |||
| callbacks=[MyCallback()]) | |||
| trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
| batch_size=32, n_epochs=total_epochs, print_every=50, dev_data=data_set, | |||
| metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, callbacks=[MyCallback()], | |||
| check_code_level=2) | |||
| trainer.train() | |||
| assert passed_epochs == list(range(1, total_epochs + 1)) | |||
| @@ -161,7 +161,15 @@ class TestAccuracyMetric(unittest.TestCase): | |||
| print(e) | |||
| return | |||
| self.fail("No exception was caught.") | |||
| def test_duplicate(self): | |||
| # potential bug in 0.4.1: duplicate parameter names must not break the call | |||
| metric = AccuracyMetric(pred='predictions', target='targets') | |||
| pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(4) * 3, 'pred':0} | |||
| target_dict = {'targets':torch.zeros(4, 3), 'target': 0} | |||
| metric(pred_dict=pred_dict, target_dict=target_dict) | |||
| def test_seq_len(self): | |||
| N = 256 | |||
| seq_len = torch.zeros(N).long() | |||
| @@ -46,18 +46,10 @@ class TrainerTestGround(unittest.TestCase): | |||
| model = NaiveClassifier(2, 1) | |||
| trainer = Trainer(train_set, model, | |||
| loss=BCELoss(pred="predict", target="y"), | |||
| metrics=AccuracyMetric(pred="predict", target="y"), | |||
| n_epochs=10, | |||
| batch_size=32, | |||
| print_every=50, | |||
| validate_every=-1, | |||
| dev_data=dev_set, | |||
| optimizer=SGD(lr=0.1), | |||
| check_code_level=2, | |||
| use_tqdm=True, | |||
| save_path=None) | |||
| trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
| batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set, | |||
| metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None, | |||
| use_tqdm=True, check_code_level=2) | |||
| trainer.train() | |||
| """ | |||
| # should run correctly | |||
| @@ -83,10 +75,7 @@ class TrainerTestGround(unittest.TestCase): | |||
| model = Model() | |||
| with self.assertRaises(RuntimeError): | |||
| trainer = Trainer( | |||
| train_data=dataset, | |||
| model=model | |||
| ) | |||
| trainer = Trainer(train_data=dataset, model=model) | |||
| """ | |||
| # the error message that should be produced | |||
| NameError: | |||
| @@ -116,12 +105,7 @@ class TrainerTestGround(unittest.TestCase): | |||
| return {'loss': loss} | |||
| model = Model() | |||
| trainer = Trainer( | |||
| train_data=dataset, | |||
| model=model, | |||
| use_tqdm=False, | |||
| print_every=2 | |||
| ) | |||
| trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False) | |||
| trainer.train() | |||
| """ | |||
| # should run correctly | |||
| @@ -147,12 +131,7 @@ class TrainerTestGround(unittest.TestCase): | |||
| model = Model() | |||
| with self.assertRaises(NameError): | |||
| trainer = Trainer( | |||
| train_data=dataset, | |||
| model=model, | |||
| use_tqdm=False, | |||
| print_every=2 | |||
| ) | |||
| trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False) | |||
| trainer.train() | |||
| def test_trainer_suggestion4(self): | |||
| @@ -175,12 +154,7 @@ class TrainerTestGround(unittest.TestCase): | |||
| model = Model() | |||
| with self.assertRaises(NameError): | |||
| trainer = Trainer( | |||
| train_data=dataset, | |||
| model=model, | |||
| use_tqdm=False, | |||
| print_every=2 | |||
| ) | |||
| trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False) | |||
| def test_trainer_suggestion5(self): | |||
| # check that the error message correctly points the user at the problem | |||
| @@ -203,12 +177,7 @@ class TrainerTestGround(unittest.TestCase): | |||
| return {'loss': loss} | |||
| model = Model() | |||
| trainer = Trainer( | |||
| train_data=dataset, | |||
| model=model, | |||
| use_tqdm=False, | |||
| print_every=2 | |||
| ) | |||
| trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False) | |||
| def test_trainer_suggestion6(self): | |||
| # check that the error message correctly points the user at the problem | |||
| @@ -233,14 +202,8 @@ class TrainerTestGround(unittest.TestCase): | |||
| model = Model() | |||
| with self.assertRaises(NameError): | |||
| trainer = Trainer( | |||
| train_data=dataset, | |||
| model=model, | |||
| dev_data=dataset, | |||
| loss=CrossEntropyLoss(), | |||
| metrics=AccuracyMetric(), | |||
| use_tqdm=False, | |||
| print_every=2) | |||
| trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(), print_every=2, dev_data=dataset, | |||
| metrics=AccuracyMetric(), use_tqdm=False) | |||
| """ | |||
| def test_trainer_multiprocess(self): | |||
| @@ -130,11 +130,8 @@ class ModelRunner(): | |||
| tester = Tester(data=data, model=model, metrics=metrics, | |||
| batch_size=BATCH_SIZE, verbose=0) | |||
| before_train = tester.test() | |||
| trainer = Trainer(model=model, train_data=data, dev_data=None, | |||
| n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, | |||
| loss=loss, | |||
| save_path=None, | |||
| use_tqdm=False) | |||
| trainer = Trainer(train_data=data, model=model, loss=loss, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS, | |||
| dev_data=None, save_path=None, use_tqdm=False) | |||
| trainer.train(load_best_model=False) | |||
| after_train = tester.test() | |||
| for metric_name, v1 in before_train.items(): | |||
| @@ -1,6 +1,5 @@ | |||
| import unittest | |||
| import fastNLP | |||
| from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric | |||
| from .model_runner import * | |||
| @@ -10,14 +10,14 @@ class TestCRF(unittest.TestCase): | |||
| id2label = {0: 'B', 1: 'I', 2:'O'} | |||
| expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||
| (2, 4), (3, 0), (3, 2)} | |||
| self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||
| self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True))) | |||
| id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} | |||
| expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||
| self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||
| self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) | |||
| id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"} | |||
| allowed_transitions(id2label) | |||
| allowed_transitions(id2label, include_start_end=True) | |||
| labels = ['O'] | |||
| for label in ['X', 'Y']: | |||
| @@ -27,7 +27,7 @@ class TestCRF(unittest.TestCase): | |||
| expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||
| (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||
| (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||
| self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||
| self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True))) | |||
| labels = [] | |||
| for label in ['X', 'Y']: | |||
| @@ -37,7 +37,7 @@ class TestCRF(unittest.TestCase): | |||
| expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||
| (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||
| (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||
| self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||
| self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) | |||
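| Since this release also adds BIOES support to allowed_transitions and the span metric (changelog item 6), a minimal sketch of that call follows; the exact import path is an assumption: | |||
| from fastNLP.modules.decoder.crf import allowed_transitions  # assumed location | |||
| id2label = {0: 'B', 1: 'I', 2: 'O', 3: 'E', 4: 'S'} | |||
| # include_start_end now defaults to False, so only tag-to-tag transitions are returned | |||
| transitions = allowed_transitions(id2label, encoding_type='BIOES') | |||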
| def test_case2(self): | |||
| # test that the CRF avoids decoding illegal transitions; verified against allennlp. | |||
| @@ -60,10 +60,10 @@ class TestTutorial(unittest.TestCase): | |||
| print(test_data[0]) | |||
| # if you need to build reinforcement-learning or GAN-style projects, you can also use these preprocessing tools | |||
| from fastNLP.core.batch import Batch | |||
| from fastNLP.core.batch import DataSetIter | |||
| from fastNLP.core.sampler import RandomSampler | |||
| batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler()) | |||
| batch_iterator = DataSetIter(dataset=train_data, batch_size=2, sampler=RandomSampler()) | |||
| for batch_x, batch_y in batch_iterator: | |||
| print("batch_x has: ", batch_x) | |||
| print("batch_y has: ", batch_y) | |||
| @@ -85,12 +85,8 @@ class TestTutorial(unittest.TestCase): | |||
| # instantiate a Trainer, pass in the model and the data, and start training | |||
| # first overfit on test_data (to make sure the model implementation is correct) | |||
| copy_model = deepcopy(model) | |||
| overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data, | |||
| loss=loss, | |||
| metrics=metric, | |||
| save_path=None, | |||
| batch_size=32, | |||
| n_epochs=5) | |||
| overfit_trainer = Trainer(train_data=test_data, model=copy_model, loss=loss, batch_size=32, n_epochs=5, | |||
| dev_data=test_data, metrics=metric, save_path=None) | |||
| overfit_trainer.train() | |||
| # train on train_data, validate on test_data | |||
| @@ -147,13 +143,8 @@ class TestTutorial(unittest.TestCase): | |||
| from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam | |||
| trainer = Trainer(model=model, | |||
| train_data=train_data, | |||
| dev_data=dev_data, | |||
| loss=CrossEntropyLoss(), | |||
| optimizer= Adam(), | |||
| metrics=AccuracyMetric(target='target') | |||
| ) | |||
| trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(), loss=CrossEntropyLoss(), | |||
| dev_data=dev_data, metrics=AccuracyMetric(target='target')) | |||
| trainer.train() | |||
| print('Train finished!') | |||