| @@ -69,6 +69,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||
| else: | |||
| raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
| self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs) | |||
| num_layers = self.model.encoder.num_layers | |||
| if layers == 'mix': | |||
| self.layer_weights = nn.Parameter(torch.zeros(self.model.config['lstm']['n_layers'] + 1), | |||
| @@ -78,9 +79,9 @@ class ElmoEmbedding(ContextualEmbedding): | |||
| self._embed_size = self.model.config['lstm']['projection_dim'] * 2 | |||
| else: | |||
| layers = list(map(int, layers.split(','))) | |||
| assert len(layers) > 0, "Must choose one output" | |||
| assert len(layers) > 0, "Must choose at least one output, but got None." | |||
| for layer in layers: | |||
| assert 0 <= layer <= 2, "Layer index should be in range [0, 2]." | |||
| assert 0 <= layer <= num_layers, f"Layer index should be in range [0, {num_layers}], but got {layer}." | |||
| self.layers = layers | |||
| self._get_outputs = self._get_layer_outputs | |||
| self._embed_size = len(self.layers) * self.model.config['lstm']['projection_dim'] * 2 | |||
| @@ -241,7 +241,7 @@ class BertForQuestionAnswering(BaseModel): | |||
| def forward(self, words): | |||
| """ | |||
| :param torch.LongTensor words: [batch_size, seq_len] | |||
| :return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size, seq_len] | |||
| :return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size, seq_len + 2] | |||
| """ | |||
| sequence_output = self.bert(words) | |||
| logits = self.qa_outputs(sequence_output) # [batch_size, seq_len, num_labels] | |||
| @@ -0,0 +1,229 @@ | |||
| ! 33 | |||
| " 34 | |||
| # 35 | |||
| $ 36 | |||
| % 37 | |||
| & 38 | |||
| ' 39 | |||
| ( 40 | |||
| ) 41 | |||
| * 42 | |||
| + 43 | |||
| , 44 | |||
| - 45 | |||
| . 46 | |||
| / 47 | |||
| 0 48 | |||
| 1 49 | |||
| 2 50 | |||
| 3 51 | |||
| 4 52 | |||
| 5 53 | |||
| 6 54 | |||
| 7 55 | |||
| 8 56 | |||
| 9 57 | |||
| : 58 | |||
| ; 59 | |||
| < 60 | |||
| = 61 | |||
| > 62 | |||
| ? 63 | |||
| @ 64 | |||
| A 65 | |||
| B 66 | |||
| C 67 | |||
| D 68 | |||
| E 69 | |||
| F 70 | |||
| G 71 | |||
| H 72 | |||
| I 73 | |||
| J 74 | |||
| K 75 | |||
| L 76 | |||
| M 77 | |||
| N 78 | |||
| O 79 | |||
| P 80 | |||
| Q 81 | |||
| R 82 | |||
| S 83 | |||
| T 84 | |||
| U 85 | |||
| V 86 | |||
| W 87 | |||
| X 88 | |||
| Y 89 | |||
| Z 90 | |||
| [ 91 | |||
| \ 92 | |||
| ] 93 | |||
| ^ 94 | |||
| _ 95 | |||
| ` 96 | |||
| a 97 | |||
| b 98 | |||
| c 99 | |||
| d 100 | |||
| e 101 | |||
| f 102 | |||
| g 103 | |||
| h 104 | |||
| i 105 | |||
| j 106 | |||
| k 107 | |||
| l 108 | |||
| m 109 | |||
| n 110 | |||
| o 111 | |||
| p 112 | |||
| q 113 | |||
| r 114 | |||
| s 115 | |||
| t 116 | |||
| u 117 | |||
| v 118 | |||
| w 119 | |||
| x 120 | |||
| y 121 | |||
| z 122 | |||
| { 123 | |||
| | 124 | |||
| } 125 | |||
| ~ 126 | |||
| 127 | |||
| 128 | |||
| 129 | |||
| 130 | |||
| 131 | |||
| 132 | |||
| 134 | |||
| 135 | |||
| 136 | |||
| 137 | |||
| 138 | |||
| 139 | |||
| 140 | |||
| 141 | |||
| 142 | |||
| 143 | |||
| 144 | |||
| 145 | |||
| 146 | |||
| 147 | |||
| 148 | |||
| 149 | |||
| 150 | |||
| 151 | |||
| 152 | |||
| 153 | |||
| 154 | |||
| 155 | |||
| 156 | |||
| 157 | |||
| 158 | |||
| 159 | |||
| 160 | |||
| ¡ 161 | |||
| ¢ 162 | |||
| £ 163 | |||
| ¤ 164 | |||
| ¥ 165 | |||
| ¦ 166 | |||
| § 167 | |||
| ¨ 168 | |||
| © 169 | |||
| ª 170 | |||
| « 171 | |||
| ¬ 172 | |||
| 173 | |||
| ® 174 | |||
| ¯ 175 | |||
| ° 176 | |||
| ± 177 | |||
| ² 178 | |||
| ³ 179 | |||
| ´ 180 | |||
| µ 181 | |||
| ¶ 182 | |||
| · 183 | |||
| ¸ 184 | |||
| ¹ 185 | |||
| º 186 | |||
| » 187 | |||
| ¼ 188 | |||
| ½ 189 | |||
| ¾ 190 | |||
| ¿ 191 | |||
| À 192 | |||
| Á 193 | |||
| Â 194 | |||
| Ã 195 | |||
| Ä 196 | |||
| Å 197 | |||
| Æ 198 | |||
| Ç 199 | |||
| È 200 | |||
| É 201 | |||
| Ê 202 | |||
| Ë 203 | |||
| Ì 204 | |||
| Í 205 | |||
| Î 206 | |||
| Ï 207 | |||
| Ð 208 | |||
| Ñ 209 | |||
| Ò 210 | |||
| Ó 211 | |||
| Ô 212 | |||
| Õ 213 | |||
| Ö 214 | |||
| × 215 | |||
| Ø 216 | |||
| Ù 217 | |||
| Ú 218 | |||
| Û 219 | |||
| Ü 220 | |||
| Ý 221 | |||
| Þ 222 | |||
| ß 223 | |||
| à 224 | |||
| á 225 | |||
| â 226 | |||
| ã 227 | |||
| ä 228 | |||
| å 229 | |||
| æ 230 | |||
| ç 231 | |||
| è 232 | |||
| é 233 | |||
| ê 234 | |||
| ë 235 | |||
| ì 236 | |||
| í 237 | |||
| î 238 | |||
| ï 239 | |||
| ð 240 | |||
| ñ 241 | |||
| ò 242 | |||
| ó 243 | |||
| ô 244 | |||
| õ 245 | |||
| ö 246 | |||
| ÷ 247 | |||
| ø 248 | |||
| ù 249 | |||
| ú 250 | |||
| û 251 | |||
| ü 252 | |||
| ý 253 | |||
| þ 254 | |||
| ÿ 255 | |||
| <bos> 256 | |||
| <eos> 257 | |||
| <bow> 258 | |||
| <eow> 259 | |||
| <char_pad> 260 | |||
| <oov> 1 | |||
| <pad> -1 | |||
| @@ -0,0 +1,29 @@ | |||
| { | |||
| "lstm": { | |||
| "use_skip_connections": true, | |||
| "projection_dim": 16, | |||
| "cell_clip": 3, | |||
| "proj_clip": 3, | |||
| "dim": 16, | |||
| "n_layers": 1 | |||
| }, | |||
| "char_cnn": { | |||
| "activation": "relu", | |||
| "filters": [ | |||
| [ | |||
| 1, | |||
| 16 | |||
| ], | |||
| [ | |||
| 2, | |||
| 16 | |||
| ] | |||
| ], | |||
| "n_highway": 1, | |||
| "embedding": { | |||
| "dim": 4 | |||
| }, | |||
| "n_characters": 262, | |||
| "max_characters_per_token": 50 | |||
| } | |||
| } | |||
| @@ -29,8 +29,11 @@ class TestDownload(unittest.TestCase): | |||
| class TestBertEmbedding(unittest.TestCase): | |||
| def test_bert_embedding_1(self): | |||
| vocab = Vocabulary().add_word_lst("this is a test .".split()) | |||
| embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert') | |||
| vocab = Vocabulary().add_word_lst("this is a test . [SEP]".split()) | |||
| embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1) | |||
| requires_grad = embed.requires_grad | |||
| embed.requires_grad = not requires_grad | |||
| embed.train() | |||
| words = torch.LongTensor([[2, 3, 4, 0]]) | |||
| result = embed(words) | |||
| self.assertEqual(result.size(), (1, 4, 16)) | |||
| @@ -18,4 +18,19 @@ class TestDownload(unittest.TestCase): | |||
| # 首先保证所有权重可以加载;上传权重;验证可以下载 | |||
| class TestRunElmo(unittest.TestCase): | |||
| def test_elmo_embedding(self): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo', layers='0,1') | |||
| words = torch.LongTensor([[0, 1, 2]]) | |||
| hidden = elmo_embed(words) | |||
| print(hidden.size()) | |||
| def test_elmo_embedding_layer_assertion(self): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| try: | |||
| elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo', | |||
| layers='0,1,2') | |||
| except AssertionError as e: | |||
| print(e) | |||