@@ -217,11 +217,11 @@ class ModelProcessor(Processor):
            tmp_batch = []
            value = value.cpu().numpy()
            if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
                batch_output[key].extend(value.tolist())
            else:
                for idx, seq_len in enumerate(seq_lens):
                    tmp_batch.append(value[idx, :seq_len])
                batch_output[key].extend(tmp_batch)
        else:
            batch_output[key].extend(value.tolist())

    batch_output[self.seq_len_field_name].extend(seq_lens)
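For reference, a minimal sketch of what the trimming branch above does, on a hypothetical 2-sample batch: rows of a padded [batch, max_len] array are cut back to their true lengths before being accumulated.

import numpy as np

# hypothetical padded output: batch of 2 sequences, max_len = 4
value = np.array([[7, 8, 9, 0],
                  [5, 6, 0, 0]])
seq_lens = [3, 2]

tmp_batch = [value[idx, :seq_len] for idx, seq_len in enumerate(seq_lens)]
# -> [array([7, 8, 9]), array([5, 6])]: the padding positions are dropped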
@@ -53,6 +53,54 @@ class SeqLabelEvaluator(Evaluator):
        accuracy = total_correct / total_count
        return {"accuracy": float(accuracy)}


class SeqLabelEvaluator2(Evaluator):
    # The evaluator above is likely incorrect; this one scores exact span matches instead.
    def __init__(self, seq_lens_field_name='word_seq_origin_len'):
        super(SeqLabelEvaluator2, self).__init__()
        self.end_tagidx_set = set()
        self.seq_lens_field_name = seq_lens_field_name

    def __call__(self, predict, truth, **kwargs):
        """
        :param predict: list of batch, the network outputs from all batches.
        :param truth: list of dict, the ground truths from all batch_y.
        :return: dict with precision P, recall R and F1 F.
        """
        seq_lens = kwargs[self.seq_lens_field_name]
        corr_count = 0
        pred_count = 0
        truth_count = 0
        for x, y, seq_len in zip(predict, truth, seq_lens):
            x = x.cpu().numpy()
            y = y.cpu().numpy()
            for idx, s_l in enumerate(seq_len):
                x_ = x[idx][:s_l]
                y_ = y[idx][:s_l]
                flag = True
                start = 0
                for idx_i, (x_i, y_i) in enumerate(zip(x_, y_)):
                    if x_i in self.end_tagidx_set:
                        truth_count += 1
                        for j in range(start, idx_i + 1):
                            if y_[j] != x_[j]:
                                flag = False
                                break
                        if flag:
                            corr_count += 1
                        flag = True
                        start = idx_i + 1
                    if y_i in self.end_tagidx_set:
                        pred_count += 1
        P = corr_count / (float(pred_count) + 1e-6)
        R = corr_count / (float(truth_count) + 1e-6)
        F = 2 * P * R / (P + R + 1e-6)
        return {"P": P, "R": R, "F": F}
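To make the span-matching rule concrete, a hand-worked sketch (hypothetical BMES ids 0/1/2/3 = B/M/E/S, so {2, 3} are the word-ending tags): a word only counts as correct when every tag inside its gold span matches.

end_tagidx_set = {2, 3}          # hypothetical: E and S end a word
gold = [0, 2, 3, 0, 2]           # B E S B E  -> words of length 2, 1, 2
pred = [0, 2, 0, 2, 3]           # B E B E S  -> words of length 2, 2, 1

corr, start = 0, 0
for i, g in enumerate(gold):
    if g in end_tagidx_set:                          # a gold word ends here
        if gold[start:i + 1] == pred[start:i + 1]:   # exact span match
            corr += 1
        start = i + 1
# corr == 1: only the first word is segmented and tagged identically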

class SNLIEvaluator(Evaluator):
    def __init__(self):
@@ -167,8 +167,10 @@ class AdvSeqLabel(SeqLabeling):
        x = self.Linear2(x)
        # x = x.view(batch_size, max_len, -1)
        # [batch_size, max_len, num_classes]
+        # TODO: passing seq_lens through under this key is not a clean design
        return {"loss": self._internal_loss(x, truth) if truth is not None else None,
-               "predict": self.decode(x)}
+               "predict": self.decode(x),
+               'word_seq_origin_len': word_seq_origin_len}

    def predict(self, **x):
        out = self.forward(**x)
@@ -111,7 +111,7 @@ class POSCWSReader(DataSetLoader):
                    continue
                line = ' '.join(words)
                if cut_long_sent:
-                    sents = cut_long_sent(line)
+                    sents = cut_long_sentence(line)
                else:
                    sents = [line]
                for sent in sents:
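The one-line fix above resolves a name collision: inside load(), the boolean parameter cut_long_sent shadowed the module-level helper of the same name, so the call raised TypeError at runtime. Renaming the helper to cut_long_sentence avoids the clash; a minimal reproduction of the old failure mode (hypothetical, stripped-down names):

def cut_long_sent(line):              # module-level helper (old name)
    return [line]

def load(line, cut_long_sent=False):
    if cut_long_sent:
        # the bool parameter shadows the helper here:
        return cut_long_sent(line)    # TypeError: 'bool' object is not callable
    return [line]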
@@ -127,3 +127,50 @@ class POSCWSReader(DataSetLoader):
        return dataset


class ConlluCWSReader(object):
    # Returns a DataSet in which each instance has a raw_sentence field
    # (a space-separated sentence; long sentences are optionally cut).
    def __init__(self):
        pass

    def load(self, path, cut_long_sent=False):
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if line.startswith('\n'):
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    continue
                else:
                    sample.append(line.split('\t'))
            if len(sample) > 0:
                datalist.append(sample)

        ds = DataSet()
        for sample in datalist:
            # print(sample)
            res = self.get_one(sample)
            if res is None:
                continue
            line = ' '.join(res)
            if cut_long_sent:
                sents = cut_long_sentence(line)
            else:
                sents = [line]
            for raw_sentence in sents:
                ds.append(Instance(raw_sentence=raw_sentence))
        return ds

    def get_one(self, sample):
        if len(sample) == 0:
            return None
        text = []
        for w in sample:
            t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
            if t3 == '_':
                return None
            text.append(t1)
        return text
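For reference, a sketch of the tab-separated row layout get_one() indexes into; the sample row is hypothetical, and in CoNLL-X terms the consumed columns are FORM, (C)POS, HEAD and DEPREL:

# columns: ID FORM LEMMA CPOS POS FEATS HEAD DEPREL ...
row = '1\t中国\t中国\tNR\tNR\t_\t2\tnn\t_\t_'.split('\t')
t1, t2, t3, t4 = row[1], row[3], row[6], row[7]
# t1 = '中国' (surface form), t3 = '2' (head index)
# get_one() treats HEAD == '_' as an unannotated sample and skips it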
@@ -117,3 +117,56 @@ class CWSBiLSTMSegApp(BaseModel):
        pred_probs = pred_dict['pred_probs']
        _, pred_tags = pred_probs.max(dim=-1)
        return {'pred_tags': pred_tags}


from fastNLP.modules.decoder.CRF import ConditionalRandomField


class CWSBiLSTMCRF(BaseModel):
    def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100,
                 num_bigram_per_char=None, hidden_size=200, bidirectional=True, embed_drop_p=None,
                 num_layers=1, tag_size=4):
        super(CWSBiLSTMCRF, self).__init__()
        self.tag_size = tag_size
        self.encoder_model = CWSBiLSTMEncoder(vocab_num, embed_dim, bigram_vocab_num, bigram_embed_dim,
                                              num_bigram_per_char, hidden_size, bidirectional,
                                              embed_drop_p, num_layers)
        size_layer = [hidden_size, 200, tag_size]
        self.decoder_model = MLP(size_layer)
        self.crf = ConditionalRandomField(tag_size=tag_size, include_start_end_trans=False)

    def forward(self, chars, tags, seq_lens, bigrams=None):
        device = self.parameters().__next__().device
        chars = chars.to(device).long()
        if bigrams is not None:
            bigrams = bigrams.to(device).long()
        seq_lens = seq_lens.to(device).long()
        masks = seq_lens_to_mask(seq_lens)
        feats = self.encoder_model(chars, bigrams, seq_lens)
        feats = self.decoder_model(feats)
        losses = self.crf(feats, tags, masks)

        pred_dict = {}
        pred_dict['seq_lens'] = seq_lens
        pred_dict['loss'] = torch.mean(losses)
        return pred_dict

    def predict(self, chars, seq_lens, bigrams=None):
        device = self.parameters().__next__().device
        chars = chars.to(device).long()
        if bigrams is not None:
            bigrams = bigrams.to(device).long()
        seq_lens = seq_lens.to(device).long()
        masks = seq_lens_to_mask(seq_lens)
        feats = self.encoder_model(chars, bigrams, seq_lens)
        feats = self.decoder_model(feats)
        probs = self.crf.viterbi_decode(feats, masks, get_score=False)
        return {'pred_tags': probs}
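A hypothetical smoke test for the new model, assuming CWSBiLSTMEncoder's default output width matches hidden_size and that unigram-only input (bigrams=None) is supported end to end:

import torch

model = CWSBiLSTMCRF(vocab_num=100, embed_dim=64, hidden_size=128, tag_size=4)

chars = torch.randint(0, 100, (2, 7))    # [batch=2, max_len=7] character ids
tags = torch.randint(0, 4, (2, 7))       # BMES tag ids
seq_lens = torch.tensor([7, 5])          # true lengths, second sample padded

out = model(chars, tags, seq_lens)       # {'seq_lens': ..., 'loss': scalar}
preds = model.predict(chars, seq_lens)   # {'pred_tags': Viterbi paths}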
@@ -118,6 +118,23 @@ class CWSTagProcessor(Processor):
    def _tags_from_word_len(self, word_len):
        raise NotImplementedError


class CWSBMESTagProcessor(CWSTagProcessor):
    def __init__(self, field_name, new_added_field_name=None):
        super(CWSBMESTagProcessor, self).__init__(field_name, new_added_field_name)
        self.tag_size = 4

    def _tags_from_word_len(self, word_len):
        tag_list = []
        if word_len == 1:
            tag_list.append(3)
        else:
            tag_list.append(0)
            for _ in range(word_len - 2):
                tag_list.append(1)
            tag_list.append(2)
        return tag_list
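With the id convention used throughout these processors (0/1/2/3 = B/M/E/S), word lengths map to tag runs as below; the 'chars' field name is only a placeholder:

proc = CWSBMESTagProcessor('chars')
assert proc._tags_from_word_len(1) == [3]             # S
assert proc._tags_from_word_len(2) == [0, 2]          # B E
assert proc._tags_from_word_len(4) == [0, 1, 1, 2]    # B M M E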

class CWSSegAppTagProcessor(CWSTagProcessor):
    def __init__(self, field_name, new_added_field_name=None):
@@ -239,3 +256,29 @@ class SegApp2OutputProcessor(Processor):
                    start_idx = idx + 1
            ins[self.new_added_field_name] = ' '.join(words)


class BMES2OutputProcessor(Processor):
    def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'):
        super(BMES2OutputProcessor, self).__init__(None, None)

        self.chars_field_name = chars_field_name
        self.tag_field_name = tag_field_name
        self.new_added_field_name = new_added_field_name

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            pred_tags = ins[self.tag_field_name]
            chars = ins[self.chars_field_name]
            words = []
            start_idx = 0
            for idx, tag in enumerate(pred_tags):
                if tag == 3:
                    # S: single-character word; mapping back to the original
                    # (untransformed) text is not handled yet
                    words.extend(chars[start_idx:idx + 1])
                    start_idx = idx + 1
                elif tag == 2:
                    # E: close the current multi-character word
                    words.append(''.join(chars[start_idx:idx + 1]))
                    start_idx = idx + 1
            ins[self.new_added_field_name] = ' '.join(words)
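Tracing the loop above on a tiny hypothetical instance (recall 2 = E closes a multi-character word, 3 = S is a single-character word):

chars = ['中', '国', '人']
tags = [0, 2, 3]                # B E S
words, start_idx = [], 0
for idx, tag in enumerate(tags):
    if tag == 3:                                    # S
        words.extend(chars[start_idx:idx + 1])
        start_idx = idx + 1
    elif tag == 2:                                  # E
        words.append(''.join(chars[start_idx:idx + 1]))
        start_idx = idx + 1
print(' '.join(words))          # -> '中国 人'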
@@ -24,37 +24,52 @@ def refine_ys_on_seq_len(ys, seq_lens):

def flat_nested_list(nested_list):
    return list(chain(*nested_list))


-def calculate_pre_rec_f1(model, batcher):
+def calculate_pre_rec_f1(model, batcher, type='segapp'):
    true_ys, pred_ys = decode_iterator(model, batcher)
    true_ys = flat_nested_list(true_ys)
    pred_ys = flat_nested_list(pred_ys)

    cor_num = 0
-    yp_wordnum = pred_ys.count(1)
-    yt_wordnum = true_ys.count(1)
    start = 0
-    if true_ys[0] == 1 and pred_ys[0] == 1:
-        cor_num += 1
-        start = 1
-    for i in range(1, len(true_ys)):
-        if true_ys[i] == 1:
-            flag = True
-            if true_ys[start - 1] != pred_ys[start - 1]:
-                flag = False
-            else:
+    if type == 'segapp':
+        yp_wordnum = pred_ys.count(1)
+        yt_wordnum = true_ys.count(1)
+        if true_ys[0] == 1 and pred_ys[0] == 1:
+            cor_num += 1
+            start = 1
+        for i in range(1, len(true_ys)):
+            if true_ys[i] == 1:
+                flag = True
+                if true_ys[start - 1] != pred_ys[start - 1]:
+                    flag = False
+                else:
+                    for j in range(start, i + 1):
+                        if true_ys[j] != pred_ys[j]:
+                            flag = False
+                            break
+                if flag:
+                    cor_num += 1
+                start = i + 1
+    elif type == 'bmes':
+        yp_wordnum = pred_ys.count(2) + pred_ys.count(3)
+        yt_wordnum = true_ys.count(2) + true_ys.count(3)
+        for i in range(len(true_ys)):
+            if true_ys[i] == 2 or true_ys[i] == 3:
+                flag = True
+                for j in range(start, i + 1):
+                    if true_ys[j] != pred_ys[j]:
+                        flag = False
+                        break
+                if flag:
+                    cor_num += 1
+                start = i + 1
-        if flag:
-            cor_num += 1
-        start = i + 1
    P = cor_num / (float(yp_wordnum) + 1e-6)
    R = cor_num / (float(yt_wordnum) + 1e-6)
    F = 2 * P * R / (P + R + 1e-6)
-    print(cor_num, yt_wordnum, yp_wordnum)
+    # print(cor_num, yt_wordnum, yp_wordnum)
    return P, R, F
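A hand-worked example of the new 'bmes' branch (words end on tag 2 = E or 3 = S; a word is correct only if every position in its gold span matches):

true_ys = [0, 2, 3]   # B E S -> 2 gold words: chars 0-1 and char 2
pred_ys = [3, 3, 3]   # S S S -> 3 predicted words
# word over chars 0-1: pred_ys[0] != true_ys[0] -> wrong
# word over char 2:    pred_ys[2] == true_ys[2] -> correct
# P = 1/3, R = 1/2, F = 2PR / (P + R) = 0.4 (ignoring the 1e-6 smoothing)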
@@ -0,0 +1,89 @@
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance


def cut_long_sentence(sent, max_sample_length=200):
    sent_no_space = sent.replace(' ', '')
    cutted_sentence = []
    if len(sent_no_space) > max_sample_length:
        parts = sent.strip().split()
        new_line = ''
        length = 0
        for part in parts:
            length += len(part)
            new_line += part + ' '
            if length > max_sample_length:
                new_line = new_line[:-1]
                cutted_sentence.append(new_line)
                length = 0
                new_line = ''
        if new_line != '':
            cutted_sentence.append(new_line[:-1])
    else:
        cutted_sentence.append(sent)
    return cutted_sentence
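Usage sketch: length is counted over non-space characters, and the split only happens at existing word boundaries, so a chunk may overshoot max_sample_length by up to one word:

sent = 'AAAA BBBB CCCC'                              # 12 non-space characters
print(cut_long_sentence(sent, max_sample_length=7))
# -> ['AAAA BBBB', 'CCCC']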

class ConlluPOSReader(object):
    # Returns a DataSet with two fields: words (a list of lists, where the inner
    # list holds characters) and tag (a list of str, BMES-style POS tags).
    def __init__(self):
        pass

    def load(self, path):
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
            sample = []
            for line in f:
                if line.startswith('\n'):
                    datalist.append(sample)
                    sample = []
                elif line.startswith('#'):
                    continue
                else:
                    sample.append(line.split('\t'))
            if len(sample) > 0:
                datalist.append(sample)

        ds = DataSet()
        for sample in datalist:
            # print(sample)
            res = self.get_one(sample)
            if res is None:
                continue
            char_seq = []
            pos_seq = []
            for word, tag in zip(res[0], res[1]):
                if len(word) == 1:
                    char_seq.append(word)
                    pos_seq.append('S-{}'.format(tag))
                elif len(word) > 1:
                    pos_seq.append('B-{}'.format(tag))
                    for _ in range(len(word) - 2):
                        pos_seq.append('M-{}'.format(tag))
                    pos_seq.append('E-{}'.format(tag))
                    char_seq.extend(list(word))
                else:
                    raise ValueError("Zero length of word detected.")

            ds.append(Instance(words=char_seq,
                               tag=pos_seq))
        return ds

    def get_one(self, sample):
        if len(sample) == 0:
            return None
        text = []
        pos_tags = []
        for w in sample:
            t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
            if t3 == '_':
                return None
            text.append(t1)
            pos_tags.append(t2)
        return text, pos_tags


if __name__ == '__main__':
    reader = ConlluPOSReader()
    d = reader.load('/home/hyan/train.conllx')
    print('reader')
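For reference, how load() expands (word, tag) pairs into character-level BMES-POS tags, on a hypothetical two-word sample:

# ('中国', 'NR') and ('人', 'NN') become
#   words: ['中', '国', '人']
#   tag:   ['B-NR', 'E-NR', 'S-NN']
# single-character words get S-<tag>; longer words get B-, M-..., E-<tag>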
@@ -1,16 +1,18 @@
[train]
-epochs = 300
+epochs = 6
batch_size = 32
pickle_path = "./save/"
-validate = false
+validate = true
save_best_dev = true
model_saved_path = "./save/"
+valid_step = 250
+eval_sort_key = 'accuracy'

[model]
-rnn_hidden_units = 100
-word_emb_dim = 100
+rnn_hidden_units = 300
+word_emb_dim = 300
dropout = 0.5
use_crf = true
use_cuda = true
print_every_step = 10

[test]
@@ -34,4 +36,4 @@ pickle_path = "./save/"
use_crf = true
use_cuda = true
rnn_hidden_units = 100
word_emb_dim = 100
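The section names and keys above can be read with the standard library as a stand-in; fastNLP ships its own config loader, whose exact API is not shown in this diff, so this is only an approximation:

import configparser

cfg = configparser.ConfigParser()
cfg.read('train.cfg')                                # hypothetical file name
epochs = cfg.getint('train', 'epochs')               # -> 6 after this change
validate = cfg.getboolean('train', 'validate')       # -> True
hidden = cfg.getint('model', 'rnn_hidden_units')     # -> 300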
@@ -78,7 +78,7 @@ class PosOutputStrProcessor(Processor):
            word_pos_list = []
            for word, pos in zip(word_list, pos_list):
                word_pos_list.append(word + self.sep + pos)
            # TODO: the output field name should be configurable
            ins['word_pos_output'] = ' '.join(word_pos_list)
        return dataset
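With a hypothetical separator of '/', the joined field produced above looks like this:

word_list = ['中国', '人']
pos_list = ['NR', 'NN']
sep = '/'                        # stand-in for self.sep
print(' '.join(w + sep + p for w, p in zip(word_list, pos_list)))
# -> '中国/NR 人/NN'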