| @@ -551,7 +551,7 @@ F1 0.920507 | |||
| For preprocess, you can first convert the original txt format of MSRA dataset into mindrecord by run the command as below: | |||
| ```python | |||
| python src/finetune_data_preprocess.py ----data_dir=/path/msra_dataset.txt --vocab_file=/path/vacab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len | |||
| python src/finetune_data_preprocess.py ----data_dir=/path/msra_dataset.txt --vocab_file=/path/vacab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len --class_filter="NAMEX" --split_begin=0.0 --split_end=1.0 | |||
| ``` | |||
| For finetune and evaluation, just do | |||
| @@ -562,12 +562,10 @@ bash scripts/ner.sh | |||
| The command above will run in the background, you can view training logs in ner_log.txt. | |||
| If you choose SpanF1 as assessment method and mode use_crf is set to be "true", the result will be as follows if evaluation is done after finetuning 10 epoches: | |||
| If you choose MF1(F1 score with multi-labels) as assessment method, the result will be as follows if evaluation is done after finetuning 10 epoches: | |||
| ```text | |||
| Precision 0.953826 | |||
| Recall 0.957749 | |||
| F1 0.955784 | |||
| F1 0.931243 | |||
| ``` | |||
| #### evaluation on squad v1.1 dataset when running on Ascend | |||
| @@ -515,7 +515,7 @@ F1 0.920507 | |||
| 您可以采用如下方式,先将MSRA数据集的原始格式在预处理流程中转换为mindrecord格式以提升性能: | |||
| ```python | |||
| python src/finetune_data_preprocess.py ----data_dir=/path/msra_dataset.txt --vocab_file=/path/vacab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len | |||
| python src/finetune_data_preprocess.py ----data_dir=/path/msra_dataset.txt --vocab_file=/path/vacab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len --class_filter="NAMEX" --split_begin=0.0 --split_end=1.0 | |||
| ``` | |||
| 此后,您可以进行微调再训练和推理流程, | |||
| @@ -525,12 +525,10 @@ bash scripts/ner.sh | |||
| ``` | |||
| 以上命令后台运行,您可以在ner_log.txt中查看训练日志。 | |||
| 如您选择SpanF1作为评估方法并且模型结构中配置CRF模式,在微调训练10个epoch之后进行推理,可得到如下结果: | |||
| 如您选择MF1(多标签的F1得分)作为评估方法,在微调训练10个epoch之后进行推理,可得到如下结果: | |||
| ```text | |||
| Precision 0.953826 | |||
| Recall 0.957749 | |||
| F1 0.955784 | |||
| F1 0.931243 | |||
| ``` | |||
| #### Ascend处理器上运行后评估squad v1.1数据集 | |||
| @@ -19,11 +19,12 @@ Bert finetune and evaluation script. | |||
| import os | |||
| import argparse | |||
| import time | |||
| from src.bert_for_finetune import BertFinetuneCell, BertNER | |||
| from src.finetune_eval_config import optimizer_cfg, bert_net_cfg | |||
| from src.dataset import create_ner_dataset | |||
| from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate, convert_labels_to_index | |||
| from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation, SpanF1 | |||
| from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import context | |||
| from mindspore import log as logger | |||
| @@ -79,17 +80,22 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin | |||
| netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) | |||
| model = Model(netwithgrads) | |||
| callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(dataset.get_dataset_size()), ckpoint_cb] | |||
| train_begin = time.time() | |||
| model.train(epoch_num, dataset, callbacks=callbacks) | |||
| train_end = time.time() | |||
| print("latency: {:.6f} s".format(train_end - train_begin)) | |||
| def eval_result_print(assessment_method="accuracy", callback=None): | |||
| """print eval result""" | |||
| if assessment_method == "accuracy": | |||
| print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, | |||
| callback.acc_num / callback.total_num)) | |||
| elif assessment_method in ("f1", "spanf1"): | |||
| elif assessment_method == "bf1": | |||
| print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP))) | |||
| print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN))) | |||
| print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN))) | |||
| elif assessment_method == "mf1": | |||
| print("F1 {:.6f} ".format(callback.eval()[0])) | |||
| elif assessment_method == "mcc": | |||
| print("MCC {:.6f} ".format(callback.cal())) | |||
| elif assessment_method == "spearman_correlation": | |||
| @@ -116,10 +122,10 @@ def do_eval(dataset=None, network=None, use_crf="", num_class=41, assessment_met | |||
| else: | |||
| if assessment_method == "accuracy": | |||
| callback = Accuracy() | |||
| elif assessment_method == "f1": | |||
| elif assessment_method == "bf1": | |||
| callback = F1((use_crf.lower() == "true"), num_class) | |||
| elif assessment_method == "spanf1": | |||
| callback = SpanF1((use_crf.lower() == "true"), tag_to_index) | |||
| elif assessment_method == "mf1": | |||
| callback = F1((use_crf.lower() == "true"), num_labels=num_class, mode="MultiLabel") | |||
| elif assessment_method == "mcc": | |||
| callback = MCC() | |||
| elif assessment_method == "spearman_correlation": | |||
| @@ -145,8 +151,8 @@ def parse_args(): | |||
| parser = argparse.ArgumentParser(description="run ner") | |||
| parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"], | |||
| help="Device type, default is Ascend") | |||
| parser.add_argument("--assessment_method", type=str, default="F1", choices=["F1", "clue_benchmark", "SpanF1"], | |||
| help="assessment_method include: [F1, clue_benchmark, SpanF1], default is F1") | |||
| parser.add_argument("--assessment_method", type=str, default="BF1", choices=["BF1", "clue_benchmark", "MF1"], | |||
| help="assessment_method include: [BF1, clue_benchmark, MF1], default is BF1") | |||
| parser.add_argument("--do_train", type=str, default="false", choices=["true", "false"], | |||
| help="Eable train, default is false") | |||
| parser.add_argument("--do_eval", type=str, default="false", choices=["true", "false"], | |||
| @@ -231,6 +237,12 @@ def run_ner(): | |||
| assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path, | |||
| schema_file_path=args_opt.schema_file_path, dataset_format=args_opt.dataset_format, | |||
| do_shuffle=(args_opt.train_data_shuffle.lower() == "true")) | |||
| print("==============================================================") | |||
| print("processor_name: {}".format(args_opt.device_target)) | |||
| print("test_name: BERT Finetune Training") | |||
| print("model_name: {}".format("BERT+MLP+CRF" if args_opt.use_crf.lower() == "true" else "BERT + MLP")) | |||
| print("batch_size: {}".format(args_opt.train_batch_size)) | |||
| do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num) | |||
| if args_opt.do_eval.lower() == "true": | |||
| @@ -245,7 +257,7 @@ def run_ner(): | |||
| ds = create_ner_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1, | |||
| assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path, | |||
| schema_file_path=args_opt.schema_file_path, dataset_format=args_opt.dataset_format, | |||
| do_shuffle=(args_opt.eval_data_shuffle.lower() == "true")) | |||
| do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"), drop_remainder=False) | |||
| do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, | |||
| args_opt.eval_data_file_path, load_finetune_checkpoint_path, args_opt.vocab_file_path, | |||
| args_opt.label_file_path, tag_to_index, args_opt.eval_batch_size) | |||
| @@ -18,7 +18,7 @@ echo "========================================================================== | |||
| echo "Please run the script as: " | |||
| echo "bash scripts/run_ner.sh" | |||
| echo "for example: bash scripts/run_ner.sh" | |||
| echo "assessment_method include: [F1, SpanF1, clue_benchmark]" | |||
| echo "assessment_method include: [BF1, MF1, clue_benchmark]" | |||
| echo "==============================================================================================================" | |||
| mkdir -p ms_log | |||
| @@ -30,7 +30,7 @@ python ${PROJECT_DIR}/../run_ner.py \ | |||
| --device_target="Ascend" \ | |||
| --do_train="true" \ | |||
| --do_eval="false" \ | |||
| --assessment_method="F1" \ | |||
| --assessment_method="BF1" \ | |||
| --use_crf="false" \ | |||
| --device_id=0 \ | |||
| --epoch_num=5 \ | |||
| @@ -18,6 +18,7 @@ Bert evaluation assessment method script. | |||
| ''' | |||
| import math | |||
| import numpy as np | |||
| from mindspore.nn.metrics import ConfusionMatrixMetric | |||
| from .CRF import postprocess | |||
| class Accuracy(): | |||
| @@ -39,12 +40,18 @@ class F1(): | |||
| ''' | |||
| calculate F1 score | |||
| ''' | |||
| def __init__(self, use_crf=False, num_labels=2): | |||
| def __init__(self, use_crf=False, num_labels=2, mode="Binary"): | |||
| self.TP = 0 | |||
| self.FP = 0 | |||
| self.FN = 0 | |||
| self.use_crf = use_crf | |||
| self.num_labels = num_labels | |||
| self.mode = mode | |||
| if self.mode.lower() not in ("binary", "multilabel"): | |||
| raise ValueError("Assessment mode not supported, support: [Binary, MultiLabel]") | |||
| if self.mode.lower() != "binary": | |||
| self.metric = ConfusionMatrixMetric(skip_channel=False, metric_name=("f1 score"), | |||
| calculation_method=False, decrease="mean") | |||
| def update(self, logits, labels): | |||
| ''' | |||
| @@ -62,78 +69,24 @@ class F1(): | |||
| logits = logits.asnumpy() | |||
| logit_id = np.argmax(logits, axis=-1) | |||
| logit_id = np.reshape(logit_id, -1) | |||
| pos_eva = np.isin(logit_id, [i for i in range(1, self.num_labels)]) | |||
| pos_label = np.isin(labels, [i for i in range(1, self.num_labels)]) | |||
| self.TP += np.sum(pos_eva&pos_label) | |||
| self.FP += np.sum(pos_eva&(~pos_label)) | |||
| self.FN += np.sum((~pos_eva)&pos_label) | |||
| class SpanF1(): | |||
| ''' | |||
| calculate F1、precision and recall score in span manner for NER | |||
| ''' | |||
| def __init__(self, use_crf=False, label2id=None): | |||
| self.TP = 0 | |||
| self.FP = 0 | |||
| self.FN = 0 | |||
| self.use_crf = use_crf | |||
| self.label2id = label2id | |||
| if label2id is None: | |||
| raise ValueError("label2id info should not be empty") | |||
| self.id2label = {} | |||
| for key, value in label2id.items(): | |||
| self.id2label[value] = key | |||
| def tag2span(self, ids): | |||
| ''' | |||
| conbert ids list to span mode | |||
| ''' | |||
| labels = np.array([self.id2label[id] for id in ids]) | |||
| spans = [] | |||
| prev_label = None | |||
| for idx, tag in enumerate(labels): | |||
| tag = tag.lower() | |||
| cur_label, label = tag[:1], tag[2:] | |||
| if cur_label in ('b', 's'): | |||
| spans.append((label, [idx, idx])) | |||
| elif cur_label in ('m', 'e') and prev_label in ('b', 'm') and label == spans[-1][0]: | |||
| spans[-1][1][1] = idx | |||
| elif cur_label == 'o': | |||
| pass | |||
| else: | |||
| spans.append((label, [idx, idx])) | |||
| prev_label = cur_label | |||
| return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans] | |||
| def update(self, logits, labels): | |||
| ''' | |||
| update span F1 score | |||
| ''' | |||
| labels = labels.asnumpy() | |||
| labels = np.reshape(labels, -1) | |||
| if self.use_crf: | |||
| backpointers, best_tag_id = logits | |||
| best_path = postprocess(backpointers, best_tag_id) | |||
| logit_id = [] | |||
| for ele in best_path: | |||
| logit_id.extend(ele) | |||
| if self.mode.lower() == "binary": | |||
| pos_eva = np.isin(logit_id, [i for i in range(1, self.num_labels)]) | |||
| pos_label = np.isin(labels, [i for i in range(1, self.num_labels)]) | |||
| self.TP += np.sum(pos_eva&pos_label) | |||
| self.FP += np.sum(pos_eva&(~pos_label)) | |||
| self.FN += np.sum((~pos_eva)&pos_label) | |||
| else: | |||
| logits = logits.asnumpy() | |||
| logit_id = np.argmax(logits, axis=-1) | |||
| logit_id = np.reshape(logit_id, -1) | |||
| label_spans = self.tag2span(labels) | |||
| pred_spans = self.tag2span(logit_id) | |||
| for span in pred_spans: | |||
| if span in label_spans: | |||
| self.TP += 1 | |||
| label_spans.remove(span) | |||
| else: | |||
| self.FP += 1 | |||
| for span in label_spans: | |||
| self.FN += 1 | |||
| target = np.zeros((len(labels), self.num_labels), dtype=np.int) | |||
| pred = np.zeros((len(logit_id), self.num_labels), dtype=np.int) | |||
| for i, label in enumerate(labels): | |||
| target[i][label] = 1 | |||
| for i, label in enumerate(logit_id): | |||
| pred[i][label] = 1 | |||
| self.metric.update(pred, target) | |||
| def eval(self): | |||
| return self.metric.eval() | |||
| class MCC(): | |||
| @@ -53,7 +53,7 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, | |||
| def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", data_file_path=None, | |||
| dataset_format="mindrecord", schema_file_path=None, do_shuffle=True): | |||
| dataset_format="mindrecord", schema_file_path=None, do_shuffle=True, drop_remainder=True): | |||
| """create finetune or evaluation dataset""" | |||
| type_cast_op = C.TypeCast(mstype.int32) | |||
| if dataset_format == "mindrecord": | |||
| @@ -74,7 +74,7 @@ def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy | |||
| dataset = dataset.map(operations=type_cast_op, input_columns="input_ids") | |||
| dataset = dataset.repeat(repeat_count) | |||
| # apply batch operations | |||
| dataset = dataset.batch(batch_size, drop_remainder=True) | |||
| dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) | |||
| return dataset | |||
| @@ -20,6 +20,7 @@ sample script of processing CLUE classification dataset using mindspore.dataset. | |||
| import os | |||
| import argparse | |||
| import numpy as np | |||
| from lxml import etree | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset as ds | |||
| @@ -139,46 +140,60 @@ def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage | |||
| return dataset | |||
| def process_cluener_msra(data_file): | |||
| """process MSRA dataset for CLUE""" | |||
| content = [] | |||
| labels = [] | |||
| for line in open(data_file): | |||
| line = line.strip() | |||
| if line: | |||
| word = line.split("\t")[0] | |||
| if len(line.split("\t")) == 1: | |||
| label = "O" | |||
| def process_cluener_msra(data_file, class_filter=None, split_begin=None, split_end=None): | |||
| """ | |||
| Data pre-process for MSRA dataset | |||
| Args: | |||
| data_file (path): The original dataset file path. | |||
| class_filter (list of str): Only tags within the class_filter will be counted unless the list is None. | |||
| split_begin (float): Only data after split_begin part will be counted. Used for split dataset | |||
| into training and evaluation subsets if needed. | |||
| split_end (float): Only data before split_end part will be counted. Used for split dataset | |||
| into training and evaluation subsets if needed. | |||
| """ | |||
| tree = etree.parse(data_file) | |||
| root = tree.getroot() | |||
| print("original dataset length: ", len(root)) | |||
| dataset_size = len(root) | |||
| beg = 0 if split_begin is None or not 0 <= split_begin <= 1.0 else int(dataset_size * split_begin) | |||
| end = dataset_size if split_end is None or not 0 <= split_end <= 1.0 else int(dataset_size * split_end) | |||
| print("preporcessed dataset_size: ", end - beg) | |||
| for i in range(beg, end): | |||
| sentence = root[i] | |||
| tags = [] | |||
| content = "" | |||
| for phrases in sentence: | |||
| labeled_words = [word for word in phrases] | |||
| if labeled_words: | |||
| for words in phrases: | |||
| name = words.tag | |||
| label = words.get("TYPE") | |||
| words = words.text | |||
| if not words: | |||
| continue | |||
| content += words | |||
| if class_filter and name not in class_filter: | |||
| tags += ["O" for _ in words] | |||
| else: | |||
| length = len(words) | |||
| labels = ["S_"] if length == 1 else ["B_"] + ["M_" for i in range(length - 2)] + ["E_"] | |||
| tags += [ele + label for ele in labels] | |||
| else: | |||
| label = line.split("\t")[1].split("\n")[0] | |||
| if label[0] != "O": | |||
| label = label[0] + "_" + label[2:] | |||
| if label[0] == "I": | |||
| label = "M" + label[1:] | |||
| content.append(word) | |||
| labels.append(label) | |||
| else: | |||
| for i in range(1, len(labels) - 1): | |||
| if labels[i][0] == "B" and labels[i+1][0] != "M": | |||
| labels[i] = "S" + labels[i][1:] | |||
| elif labels[i][0] == "M" and labels[i+1][0] != labels[i][0]: | |||
| labels[i] = "E" + labels[i][1:] | |||
| last = len(labels) - 1 | |||
| if labels[last][0] == "B": | |||
| labels[last] = "S" + labels[last][1:] | |||
| elif labels[last][0] == "M": | |||
| labels[last] = "E" + labels[last][1:] | |||
| phrases = phrases.text | |||
| if phrases: | |||
| content += phrases | |||
| tags += ["O" for ele in phrases] | |||
| if len(content) != len(tags): | |||
| raise ValueError("Mismathc length of content: ", len(content), " and label: ", len(tags)) | |||
| yield (np.array("".join(content)), np.array(list(tags))) | |||
| yield (np.array("".join(content)), np.array(list(labels))) | |||
| content.clear() | |||
| labels.clear() | |||
| continue | |||
| def process_msra_clue_dataset(data_dir, label_list, bert_vocab_path, max_seq_len=128): | |||
| def process_msra_clue_dataset(data_dir, label_list, bert_vocab_path, max_seq_len=128, class_filter=None, | |||
| split_begin=None, split_end=None): | |||
| """Process MSRA dataset""" | |||
| ### Loading MSRA from CLUEDataset | |||
| dataset = ds.GeneratorDataset(process_cluener_msra(data_dir), column_names=['text', 'label']) | |||
| dataset = ds.GeneratorDataset(process_cluener_msra(data_dir, class_filter, split_begin, split_end), | |||
| column_names=['text', 'label']) | |||
| ### Processing label | |||
| label_vocab = text.Vocab.from_list(label_list) | |||
| @@ -213,10 +228,16 @@ if __name__ == "__main__": | |||
| parser = argparse.ArgumentParser(description="create mindrecord") | |||
| parser.add_argument("--data_dir", type=str, default="", help="dataset path") | |||
| parser.add_argument("--vocab_file", type=str, default="", help="Vocab file path") | |||
| parser.add_argument("--max_seq_len", type=int, default=128, help="Sequence length") | |||
| parser.add_argument("--save_path", type=str, default="./my.mindrecord", help="Path to save mindrecord") | |||
| parser.add_argument("--label2id", type=str, default="", | |||
| help="Label2id file path, must be set for cluener2020 task") | |||
| parser.add_argument("--max_seq_len", type=int, default=128, help="Sequence length") | |||
| parser.add_argument("--class_filter", nargs='*', help="Specified classes will be counted, if empty all in counted") | |||
| parser.add_argument("--split_begin", type=float, default=None, help="Specified subsets of date will be counted," | |||
| "if not None, the data will counted begin from split_begin") | |||
| parser.add_argument("--split_end", type=float, default=None, help="Specified subsets of date will be counted," | |||
| "if not None, the data will counted before split_before") | |||
| args_opt = parser.parse_args() | |||
| if args_opt.label2id == "": | |||
| raise ValueError("label2id should not be empty") | |||
| @@ -225,5 +246,6 @@ if __name__ == "__main__": | |||
| for tag in f: | |||
| labels_list.append(tag.strip()) | |||
| tag_to_index = list(convert_labels_to_index(labels_list).keys()) | |||
| ds = process_msra_clue_dataset(args_opt.data_dir, tag_to_index, args_opt.vocab_file, args_opt.max_seq_len) | |||
| ds = process_msra_clue_dataset(args_opt.data_dir, tag_to_index, args_opt.vocab_file, args_opt.max_seq_len, | |||
| args_opt.class_filter, args_opt.split_begin, args_opt.split_end) | |||
| ds.save(args_opt.save_path) | |||
| @@ -106,10 +106,11 @@ class LossCallBack(Callback): | |||
| percent = 1 | |||
| epoch_num -= 1 | |||
| print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}" | |||
| .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs))) | |||
| .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs), | |||
| flush=True)) | |||
| else: | |||
| print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, | |||
| str(cb_params.net_outputs))) | |||
| str(cb_params.net_outputs), flush=True)) | |||
| def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix): | |||
| """ | |||