@@ -97,7 +97,7 @@
 # Split each sentence into words; see the DataSet.apply() method for details
 dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words')
 # Or use DataSet.apply_field()
-dataset.apply(lambda sent:sent.split(), field_name='sentence', new_field_name='words')
+dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words')
 # Besides anonymous functions, a named function can also be defined and passed in
 def get_words(instance):
     sentence = instance['sentence']
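
The hunk above switches the tutorial from DataSet.apply() to DataSet.apply_field(). A minimal sketch of the difference, using a toy two-sentence DataSet (the sample data is illustrative):

    from fastNLP import DataSet

    ds = DataSet({'sentence': ["this is a test", "hello world"]})

    # apply() receives the whole Instance, so the field must be indexed explicitly
    ds.apply(lambda ins: ins['sentence'].split(), new_field_name='words')

    # apply_field() receives only the value of field_name, which keeps the lambda shorter
    ds.apply_field(lambda sent: sent.split(), field_name='sentence', new_field_name='words')
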
@@ -14,7 +14,7 @@ class DotAttention(nn.Module):
     """
     TODO
     """
-    def __init__(self, key_size, value_size, dropout=0.1):
+    def __init__(self, key_size, value_size, dropout=0):
         super(DotAttention, self).__init__()
         self.key_size = key_size
         self.value_size = value_size
@@ -25,14 +25,14 @@ class DotAttention(nn.Module):
     def forward(self, Q, K, V, mask_out=None):
         """
-        :param Q: [batch, seq_len, key_size]
-        :param K: [batch, seq_len, key_size]
-        :param V: [batch, seq_len, value_size]
-        :param mask_out: [batch, seq_len]
+        :param Q: [batch, seq_len_q, key_size]
+        :param K: [batch, seq_len_k, key_size]
+        :param V: [batch, seq_len_k, value_size]
+        :param mask_out: [batch, 1, seq_len] or [batch, seq_len_q, seq_len_k]
         """
         output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
         if mask_out is not None:
-            output.masked_fill_(mask_out, -float('inf'))
+            output.masked_fill_(mask_out, -1e8)
         output = self.softmax(output)
         output = self.drop(output)
         return torch.matmul(output, V)
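
The fill-value change above (-float('inf') to -1e8) guards against rows in which every key position is masked: softmax over a row of -inf produces NaN, while a large negative constant keeps the attention weights finite. A standalone illustration in plain PyTorch, independent of the DotAttention class:

    import torch

    scores = torch.zeros(1, 1, 3)                 # [batch, seq_len_q, seq_len_k]
    mask = torch.ones(1, 1, 3, dtype=torch.bool)  # every key position masked

    print(scores.masked_fill(mask, -float('inf')).softmax(dim=-1))  # tensor([[[nan, nan, nan]]])
    print(scores.masked_fill(mask, -1e8).softmax(dim=-1))           # tensor([[[0.3333, 0.3333, 0.3333]]])
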
@@ -58,7 +58,8 @@ class MultiHeadAttention(nn.Module):
         self.q_in = nn.Linear(input_size, in_size)
         self.k_in = nn.Linear(input_size, in_size)
         self.v_in = nn.Linear(input_size, in_size)
-        self.attention = DotAttention(key_size=key_size, value_size=value_size)
+        # following the paper, do not apply dropout inside the scaled dot-product attention
+        self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=0)
         self.out = nn.Linear(value_size * num_head, input_size)
         self.drop = TimestepDropout(dropout)
         self.reset_parameters()
@@ -73,28 +74,29 @@ class MultiHeadAttention(nn.Module):
     def forward(self, Q, K, V, atte_mask_out=None):
         """
-        :param Q: [batch, seq_len, model_size]
-        :param K: [batch, seq_len, model_size]
-        :param V: [batch, seq_len, model_size]
+        :param Q: [batch, seq_len_q, model_size]
+        :param K: [batch, seq_len_k, model_size]
+        :param V: [batch, seq_len_k, model_size]
         :param seq_mask: [batch, seq_len]
         """
-        batch, seq_len, _ = Q.size()
+        batch, sq, _ = Q.size()
+        sk = K.size(1)
         d_k, d_v, n_head = self.key_size, self.value_size, self.num_head
         # input linear
-        q = self.q_in(Q).view(batch, seq_len, n_head, d_k)
-        k = self.k_in(K).view(batch, seq_len, n_head, d_k)
-        v = self.v_in(V).view(batch, seq_len, n_head, d_k)
+        q = self.q_in(Q).view(batch, sq, n_head, d_k)
+        k = self.k_in(K).view(batch, sk, n_head, d_k)
+        v = self.v_in(V).view(batch, sk, n_head, d_v)
         # transpose q, k and v to do batch attention
-        q = q.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k)
-        k = k.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k)
-        v = v.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_v)
+        q = q.permute(2, 0, 1, 3).contiguous().view(-1, sq, d_k)
+        k = k.permute(2, 0, 1, 3).contiguous().view(-1, sk, d_k)
+        v = v.permute(2, 0, 1, 3).contiguous().view(-1, sk, d_v)
         if atte_mask_out is not None:
             atte_mask_out = atte_mask_out.repeat(n_head, 1, 1)
-        atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, seq_len, d_v)
+        atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, sq, d_v)
         # concat all heads, do output linear
-        atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, seq_len, -1)
+        atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1)
         output = self.drop(self.out(atte))
         return output
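
With seq_len_q and seq_len_k tracked separately, the layer now supports cross-attention over sequences of different lengths. A usage sketch, assuming the constructor takes (input_size, key_size, value_size, num_head, dropout) as suggested by the hunks above, and that the import path below matches the file being patched:

    import torch
    from fastNLP.modules.encoder.attention import MultiHeadAttention  # import path assumed

    batch, len_q, len_k, model_size = 2, 5, 7, 64
    mha = MultiHeadAttention(input_size=model_size, key_size=16, value_size=16, num_head=4)

    Q = torch.randn(batch, len_q, model_size)
    K = torch.randn(batch, len_k, model_size)
    V = torch.randn(batch, len_k, model_size)
    # mask out the last two key positions for every query (use torch.uint8 on older PyTorch)
    mask = torch.zeros(batch, 1, len_k, dtype=torch.bool)
    mask[:, :, -2:] = True

    out = mha(Q, K, V, atte_mask_out=mask)
    print(out.shape)  # torch.Size([2, 5, 64]) -- the output keeps the query length
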
@@ -7,7 +7,6 @@ import torch
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.predictor import Predictor
-from fastNLP.modules.encoder.linear import Linear


 def prepare_fake_dataset():
@@ -27,7 +26,7 @@ def prepare_fake_dataset():
 class LinearModel(torch.nn.Module):
     def __init__(self):
         super(LinearModel, self).__init__()
-        self.linear = Linear(2, 1)
+        self.linear = torch.nn.Linear(2, 1)

     def forward(self, x):
         return {"predict": self.linear(x)}
@@ -1,7 +1,7 @@
 import os
 import unittest

-from fastNLP.io import ConfigSection, ConfigLoader, ConfigSaver
+# from fastNLP.io import ConfigSection, ConfigLoader, ConfigSaver


 class TestConfigSaver(unittest.TestCase):
@@ -24,7 +24,8 @@ Example::
         RUNNER.run_model(model, data=get_mydata(),
                          loss=Myloss(), metrics=Mymetric())
 """
-from fastNLP import Trainer, Tester, DataSet
+import numpy as np  # needed for the finite-loss check in Checker below
+from fastNLP import Trainer, Tester, DataSet, Callback
 from fastNLP import AccuracyMetric
 from fastNLP import CrossEntropyLoss
 from fastNLP.core.const import Const as C
@@ -42,6 +42,11 @@ POS_TAGGING = 'pos_tagging'
 NLI = 'nli'


 class ModelRunner():
+    class Checker(Callback):
+        def on_backward_begin(self, loss):
+            # the loss must stay finite (no NaN/Inf); detach before converting to numpy
+            assert np.isfinite(loss.detach().to('cpu').numpy()).all()
+
     def gen_seq(self, length, vocab_size):
         """generate fake sequence indexes with given length"""
         # reserve 0 for padding
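
To exercise the new Checker callback, the runner's training call would pass an instance through the Trainer's callbacks argument. A hypothetical wiring sketch (the actual Trainer arguments used by ModelRunner are not shown in this diff; everything except callbacks=[...] is illustrative):

    trainer = Trainer(train_data=data, model=model, loss=loss, metrics=metrics,
                      dev_data=data, n_epochs=1, batch_size=4,
                      callbacks=[ModelRunner.Checker()], use_tqdm=False)
    trainer.train()
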
@@ -25,10 +25,24 @@ def prepare_parser_data():
                   is_input=True, is_target=True)
     return ds


 class TestBiaffineParser(unittest.TestCase):
     def test_train(self):
-        model = BiaffineParser(init_embed=(VOCAB_SIZE, 30),
-                               pos_vocab_size=VOCAB_SIZE, pos_emb_dim=30,
+        model = BiaffineParser(init_embed=(VOCAB_SIZE, 10),
+                               pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10,
+                               rnn_hidden_size=10,
+                               arc_mlp_size=10,
+                               label_mlp_size=10,
                                num_label=NUM_CLS, encoder='var-lstm')
         ds = prepare_parser_data()
         RUNNER.run_model(model, ds, loss=ParserLoss(), metrics=ParserMetric())

+    def test_train2(self):
+        model = BiaffineParser(init_embed=(VOCAB_SIZE, 10),
+                               pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10,
+                               rnn_hidden_size=16,
+                               arc_mlp_size=10,
+                               label_mlp_size=10,
+                               num_label=NUM_CLS, encoder='transformer')
+        ds = prepare_parser_data()
+        RUNNER.run_model(model, ds, loss=ParserLoss(), metrics=ParserMetric())
@@ -4,13 +4,13 @@ from fastNLP.models.star_transformer import STNLICls, STSeqCls, STSeqLabel
 # add star-transformer tests, for 3 kinds of tasks.
 def test_cls():
-    model = STSeqCls((VOCAB_SIZE, 100), NUM_CLS, dropout=0)
+    model = STSeqCls((VOCAB_SIZE, 10), NUM_CLS, dropout=0)
     RUNNER.run_model_with_task(TEXT_CLS, model)


 def test_nli():
-    model = STNLICls((VOCAB_SIZE, 100), NUM_CLS, dropout=0)
+    model = STNLICls((VOCAB_SIZE, 10), NUM_CLS, dropout=0)
     RUNNER.run_model_with_task(NLI, model)


 def test_seq_label():
-    model = STSeqLabel((VOCAB_SIZE, 100), NUM_CLS, dropout=0)
+    model = STSeqLabel((VOCAB_SIZE, 10), NUM_CLS, dropout=0)
     RUNNER.run_model_with_task(POS_TAGGING, model)
@@ -2,7 +2,7 @@ import unittest
 import torch

-from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear, BiAffine
+# from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear, BiAffine
 from fastNLP.modules.encoder.star_transformer import StarTransformer