| @@ -0,0 +1,46 @@ | |||
| class Const(): | |||
| """fastNLP中field命名常量。 | |||
| 具体列表:: | |||
| INPUT 模型的序列输入 words(复数words1, words2) | |||
| CHAR_INPUT 模型character输入 chars(复数chars1, chars2) | |||
| INPUT_LEN 序列长度 seq_len(复数seq_len1,seq_len2) | |||
| OUTPUT 模型输出 pred(复数pred1, pred2) | |||
| TARGET 真实目标 target(复数target1,target2) | |||
| """ | |||
| INPUT = 'words' | |||
| CHAR_INPUT = 'chars' | |||
| INPUT_LEN = 'seq_len' | |||
| OUTPUT = 'pred' | |||
| TARGET = 'target' | |||
| @staticmethod | |||
| def INPUTS(i): | |||
| """得到第 i 个 ``INPUT`` 的命名""" | |||
| i = int(i) + 1 | |||
| return Const.INPUT + str(i) | |||
| @staticmethod | |||
| def CHAR_INPUTS(i): | |||
| """得到第 i 个 ``CHAR_INPUT`` 的命名""" | |||
| i = int(i) + 1 | |||
| return Const.CHAR_INPUT + str(i) | |||
| @staticmethod | |||
| def INPUT_LENS(i): | |||
| """得到第 i 个 ``INPUT_LEN`` 的命名""" | |||
| i = int(i) + 1 | |||
| return Const.INPUT_LEN + str(i) | |||
| @staticmethod | |||
| def OUTPUTS(i): | |||
| """得到第 i 个 ``OUTPUT`` 的命名""" | |||
| i = int(i) + 1 | |||
| return Const.OUTPUT + str(i) | |||
| @staticmethod | |||
| def TARGETS(i): | |||
| """得到第 i 个 ``TARGET`` 的命名""" | |||
| i = int(i) + 1 | |||
| return Const.TARGET + str(i) | |||
| @@ -193,9 +193,9 @@ class ConllLoader(DataSetLoader): | |||
| :param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexs`` 一一对应 | |||
| :param indexs: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||
| :param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | |||
| :param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` | |||
| """ | |||
| def __init__(self, headers, indexs=None, dropna=True): | |||
| def __init__(self, headers, indexs=None, dropna=False): | |||
| super(ConllLoader, self).__init__() | |||
| if not isinstance(headers, (list, tuple)): | |||
| raise TypeError('invalid headers: {}, should be list of strings'.format(headers)) | |||
| @@ -314,7 +314,7 @@ class JsonLoader(DataSetLoader): | |||
| `value`也可为 ``None`` , 这时读入后的`field_name`与json对象对应属性同名 | |||
| ``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` | |||
| :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||
| Default: ``True`` | |||
| Default: ``False`` | |||
| """ | |||
| def __init__(self, fields=None, dropna=False): | |||
| super(JsonLoader, self).__init__() | |||
| @@ -375,9 +375,9 @@ class CSVLoader(DataSetLoader): | |||
| 若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` | |||
| :param str sep: CSV文件中列与列之间的分隔符. Default: "," | |||
| :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||
| Default: ``True`` | |||
| Default: ``False`` | |||
| """ | |||
| def __init__(self, headers=None, sep=",", dropna=True): | |||
| def __init__(self, headers=None, sep=",", dropna=False): | |||
| self.headers = headers | |||
| self.sep = sep | |||
| self.dropna = dropna | |||
| @@ -1,8 +1,9 @@ | |||
| """Star-Transformer 的 一个 Pytorch 实现. | |||
| """ | |||
| from fastNLP.modules.encoder.star_transformer import StarTransformer | |||
| from fastNLP.core.utils import seq_lens_to_masks | |||
| from ..modules.encoder.star_transformer import StarTransformer | |||
| from ..core.utils import seq_lens_to_masks | |||
| from ..modules.utils import get_embeddings | |||
| from ..core.const import Const | |||
| import torch | |||
| from torch import nn | |||
| @@ -139,7 +140,7 @@ class STSeqLabel(nn.Module): | |||
| nodes, _ = self.enc(words, mask) | |||
| output = self.cls(nodes) | |||
| output = output.transpose(1,2) # make hidden to be dim 1 | |||
| return {'output': output} # [bsz, n_cls, seq_len] | |||
| return {Const.OUTPUT: output} # [bsz, n_cls, seq_len] | |||
| def predict(self, words, seq_len): | |||
| """ | |||
| @@ -149,8 +150,8 @@ class STSeqLabel(nn.Module): | |||
| :return output: [batch, seq_len] 输出序列中每个元素的分类 | |||
| """ | |||
| y = self.forward(words, seq_len) | |||
| _, pred = y['output'].max(1) | |||
| return {'output': pred} | |||
| _, pred = y[Const.OUTPUT].max(1) | |||
| return {Const.OUTPUT: pred} | |||
| class STSeqCls(nn.Module): | |||
| @@ -201,7 +202,7 @@ class STSeqCls(nn.Module): | |||
| nodes, relay = self.enc(words, mask) | |||
| y = 0.5 * (relay + nodes.max(1)[0]) | |||
| output = self.cls(y) # [bsz, n_cls] | |||
| return {'output': output} | |||
| return {Const.OUTPUT: output} | |||
| def predict(self, words, seq_len): | |||
| """ | |||
| @@ -211,8 +212,8 @@ class STSeqCls(nn.Module): | |||
| :return output: [batch, num_cls] 输出序列的分类 | |||
| """ | |||
| y = self.forward(words, seq_len) | |||
| _, pred = y['output'].max(1) | |||
| return {'output': pred} | |||
| _, pred = y[Const.OUTPUT].max(1) | |||
| return {Const.OUTPUT: pred} | |||
| class STNLICls(nn.Module): | |||
| @@ -269,7 +270,7 @@ class STNLICls(nn.Module): | |||
| y1 = enc(words1, mask1) | |||
| y2 = enc(words2, mask2) | |||
| output = self.cls(y1, y2) # [bsz, n_cls] | |||
| return {'output': output} | |||
| return {Const.OUTPUT: output} | |||
| def predict(self, words1, words2, seq_len1, seq_len2): | |||
| """ | |||
| @@ -281,5 +282,5 @@ class STNLICls(nn.Module): | |||
| :return output: [batch, num_cls] 输出分类的概率 | |||
| """ | |||
| y = self.forward(words1, words2, seq_len1, seq_len2) | |||
| _, pred = y['output'].max(1) | |||
| return {'output': pred} | |||
| _, pred = y[Const.OUTPUT].max(1) | |||
| return {Const.OUTPUT: pred} | |||
| @@ -3,7 +3,7 @@ import unittest | |||
| import numpy as np | |||
| import torch | |||
| from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ | |||
| from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ | |||
| LRFinder, \ | |||
| TensorboardCallback | |||
| from fastNLP.core.dataset import DataSet | |||
| @@ -1,7 +1,7 @@ | |||
| import unittest | |||
| from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, \ | |||
| CSVLoader, SNLILoader | |||
| CSVLoader, SNLILoader, JsonLoader | |||
| class TestDatasetLoader(unittest.TestCase): | |||
| @@ -24,3 +24,8 @@ class TestDatasetLoader(unittest.TestCase): | |||
| def test_SNLILoader(self): | |||
| ds = SNLILoader().load('test/data_for_tests/sample_snli.jsonl') | |||
| assert len(ds) == 3 | |||
| def test_JsonLoader(self): | |||
| ds = JsonLoader().load('test/data_for_tests/sample_snli.jsonl') | |||
| assert len(ds) == 3 | |||