| @@ -420,11 +420,11 @@ class _WordBertModel(nn.Module): | |||||
| if self.pool_method == 'first': | if self.pool_method == 'first': | ||||
| batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()] | batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()] | ||||
| batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) | batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) | ||||
| batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) | |||||
| _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) | |||||
| elif self.pool_method == 'last': | elif self.pool_method == 'last': | ||||
| batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max()+1] - 1 | batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max()+1] - 1 | ||||
| batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) | batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) | ||||
| batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) | |||||
| _batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) | |||||
| for l_index, l in enumerate(self.layers): | for l_index, l in enumerate(self.layers): | ||||
| output_layer = bert_outputs[l] | output_layer = bert_outputs[l] | ||||
| @@ -437,12 +437,12 @@ class _WordBertModel(nn.Module): | |||||
| # 从word_piece collapse到word的表示 | # 从word_piece collapse到word的表示 | ||||
| truncate_output_layer = output_layer[:, 1:-1] # 删除[CLS]与[SEP] batch_size x len x hidden_size | truncate_output_layer = output_layer[:, 1:-1] # 删除[CLS]与[SEP] batch_size x len x hidden_size | ||||
| if self.pool_method == 'first': | if self.pool_method == 'first': | ||||
| tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length] | |||||
| tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length] | |||||
| tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0) | tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0) | ||||
| outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp | outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp | ||||
| elif self.pool_method == 'last': | elif self.pool_method == 'last': | ||||
| tmp = truncate_output_layer[batch_indexes, batch_word_pieces_cum_length] | |||||
| tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length] | |||||
| tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0) | tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0) | ||||
| outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp | outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp | ||||
| elif self.pool_method == 'max': | elif self.pool_method == 'max': | ||||