From 8b1a8a6bc1a07b53210ba72bc0928f5d2c624a41 Mon Sep 17 00:00:00 2001 From: shibeiji Date: Wed, 27 Jan 2021 20:12:07 +0800 Subject: [PATCH] bert ner for msra dataset --- model_zoo/official/nlp/bert/README.md | 28 +++++- model_zoo/official/nlp/bert/README_CN.md | 27 +++++- model_zoo/official/nlp/bert/run_ner.py | 16 ++-- .../official/nlp/bert/scripts/run_ner.sh | 3 +- .../nlp/bert/src/assessment_method.py | 68 ++++++++++++++ model_zoo/official/nlp/bert/src/dataset.py | 31 ++++--- ...process.py => finetune_data_preprocess.py} | 92 +++++++++++++++++++ 7 files changed, 243 insertions(+), 22 deletions(-) rename model_zoo/official/nlp/bert/src/{clue_classification_dataset_process.py => finetune_data_preprocess.py} (62%) diff --git a/model_zoo/official/nlp/bert/README.md b/model_zoo/official/nlp/bert/README.md index e91dbbda2d..9e202cc185 100644 --- a/model_zoo/official/nlp/bert/README.md +++ b/model_zoo/official/nlp/bert/README.md @@ -27,6 +27,7 @@ - [Evaluation](#evaluation) - [evaluation on cola dataset when running on Ascend](#evaluation-on-cola-dataset-when-running-on-ascend) - [evaluation on cluener dataset when running on Ascend](#evaluation-on-cluener-dataset-when-running-on-ascend) + - [evaluation on msra dataset when running on Ascend](#evaluation-on-msra-dataset-when-running-on-ascend) - [evaluation on squad v1.1 dataset when running on Ascend](#evaluation-on-squad-v11-dataset-when-running-on-ascend) - [Model Description](#model-description) - [Performance](#performance) @@ -215,7 +216,7 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as fol ├─bert_for_finetune.py # backbone code of network ├─bert_for_pre_training.py # backbone code of network ├─bert_model.py # backbone code of network - ├─clue_classification_dataset_precess.py # data preprocessing + ├─finetune_data_preprocess.py # data preprocessing ├─cluner_evaluation.py # evaluation for cluner ├─config.py # parameter configuration for pretraining ├─CRF.py # assessment method for 
clue dataset @@ -301,6 +302,7 @@ options: --load_finetune_checkpoint_path give a finetuning checkpoint path if only do eval --train_data_file_path ner tfrecord for training. E.g., train.tfrecord --eval_data_file_path ner tfrecord for predictions if f1 is used to evaluate result, ner json for predictions if clue_benchmark is used to evaluate result + --dataset_format dataset format, support mindrecord or tfrecord --schema_file_path path to datafile schema file usage: run_squad.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN] [----do_eval DO_EVAL] @@ -544,6 +546,30 @@ Recall 0.948683 F1 0.920507 ``` +#### evaluation on msra dataset when running on Ascend + +For preprocessing, you can first convert the original txt format of the MSRA dataset into mindrecord by running the command below: + +```python +python src/finetune_data_preprocess.py --data_dir=/path/msra_dataset.txt --vocab_file=/path/vocab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len +``` + +For finetuning and evaluation, just run: + +```bash +bash scripts/run_ner.sh +``` + +The command above will run in the background; you can view training logs in ner_log.txt. 
+ +If you choose SpanF1 as the assessment method and use_crf is set to "true", evaluation after finetuning for 10 epochs gives the following result: + +```text +Precision 0.953826 +Recall 0.957749 +F1 0.955784 +``` + #### evaluation on squad v1.1 dataset when running on Ascend ```bash diff --git a/model_zoo/official/nlp/bert/README_CN.md b/model_zoo/official/nlp/bert/README_CN.md index 8ae3d576bb..9972bf6184 100644 --- a/model_zoo/official/nlp/bert/README_CN.md +++ b/model_zoo/official/nlp/bert/README_CN.md @@ -28,6 +28,7 @@ - [用法](#用法-1) - [Ascend处理器上运行后评估cola数据集](#ascend处理器上运行后评估cola数据集) - [Ascend处理器上运行后评估cluener数据集](#ascend处理器上运行后评估cluener数据集) + - [Ascend处理器上运行后评估msra数据集](#ascend处理器上运行后评估msra数据集) - [Ascend处理器上运行后评估squad v1.1数据集](#ascend处理器上运行后评估squad-v11数据集) - [模型描述](#模型描述) - [性能](#性能) @@ -215,7 +216,7 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as fol ├─bert_for_finetune.py # 网络骨干编码 ├─bert_for_pre_training.py # 网络骨干编码 ├─bert_model.py # 网络骨干编码 - ├─clue_classification_dataset_precess.py # 数据预处理 + ├─finetune_data_preprocess.py # 数据预处理 ├─cluner_evaluation.py # 评估线索生成工具 ├─config.py # 预训练参数配置 ├─CRF.py # 线索数据集评估方法 @@ -299,6 +300,7 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as fol --load_finetune_checkpoint_path 如仅执行评估,提供微调检查点保存路径 --train_data_file_path 用于保存训练数据的TFRecord文件,如train.tfrecord文件 --eval_data_file_path 如采用f1来评估结果,则为TFRecord文件保存预测;如采用clue_benchmark来评估结果,则为JSON文件保存预测 + --dataset_format 数据集格式,支持tfrecord和mindrecord格式 --schema_file_path 模式文件保存路径 用法:run_squad.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN] [----do_eval DO_EVAL] @@ -508,6 +510,29 @@ Recall 0.948683 F1 0.920507 ``` +#### Ascend处理器上运行后评估msra数据集 + +您可以采用如下方式,先将MSRA数据集的原始格式在预处理流程中转换为mindrecord格式以提升性能: + +```python +python src/finetune_data_preprocess.py --data_dir=/path/msra_dataset.txt --vocab_file=/path/vocab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len 
+``` + +此后,您可以进行微调再训练和推理流程, + +```bash +bash scripts/run_ner.sh +``` + +以上命令后台运行,您可以在ner_log.txt中查看训练日志。 +如您选择SpanF1作为评估方法并且模型结构中配置CRF模式,在微调训练10个epoch之后进行推理,可得到如下结果: + +```text +Precision 0.953826 +Recall 0.957749 +F1 0.955784 +``` + #### Ascend处理器上运行后评估squad v1.1数据集 ```bash diff --git a/model_zoo/official/nlp/bert/run_ner.py b/model_zoo/official/nlp/bert/run_ner.py index d7742d1c19..ce1c3188ba 100644 --- a/model_zoo/official/nlp/bert/run_ner.py +++ b/model_zoo/official/nlp/bert/run_ner.py @@ -23,7 +23,7 @@ from src.bert_for_finetune import BertFinetuneCell, BertNER from src.finetune_eval_config import optimizer_cfg, bert_net_cfg from src.dataset import create_ner_dataset from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate, convert_labels_to_index -from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation +from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation, SpanF1 import mindspore.common.dtype as mstype from mindspore import context from mindspore import log as logger @@ -86,7 +86,7 @@ def eval_result_print(assessment_method="accuracy", callback=None): if assessment_method == "accuracy": print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, callback.acc_num / callback.total_num)) - elif assessment_method == "f1": + elif assessment_method in ("f1", "spanf1"): print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP))) print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN))) print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN))) @@ -118,6 +118,8 @@ def do_eval(dataset=None, network=None, use_crf="", num_class=41, assessment_met callback = Accuracy() elif assessment_method == "f1": callback = F1((use_crf.lower() == "true"), num_class) + elif assessment_method == "spanf1": + callback = SpanF1((use_crf.lower() == "true"), tag_to_index) elif assessment_method == "mcc": callback = MCC() elif 
assessment_method == "spearman_correlation": @@ -143,8 +145,8 @@ def parse_args(): parser = argparse.ArgumentParser(description="run ner") parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"], help="Device type, default is Ascend") - parser.add_argument("--assessment_method", type=str, default="F1", choices=["F1", "clue_benchmark"], - help="assessment_method include: [F1, clue_benchmark], default is F1") + parser.add_argument("--assessment_method", type=str, default="F1", choices=["F1", "clue_benchmark", "SpanF1"], + help="assessment_method include: [F1, clue_benchmark, SpanF1], default is F1") parser.add_argument("--do_train", type=str, default="false", choices=["true", "false"], help="Eable train, default is false") parser.add_argument("--do_eval", type=str, default="false", choices=["true", "false"], @@ -169,6 +171,8 @@ def parse_args(): help="Data path, it is better to use absolute path") parser.add_argument("--eval_data_file_path", type=str, default="", help="Data path, it is better to use absolute path") + parser.add_argument("--dataset_format", type=str, default="mindrecord", choices=["mindrecord", "tfrecord"], + help="Dataset format, support mindrecord or tfrecord") parser.add_argument("--schema_file_path", type=str, default="", help="Schema path, it is better to use absolute path") args_opt = parser.parse_args() @@ -225,7 +229,7 @@ def run_ner(): tag_to_index=tag_to_index, dropout_prob=0.1) ds = create_ner_dataset(batch_size=args_opt.train_batch_size, repeat_count=1, assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path, - schema_file_path=args_opt.schema_file_path, + schema_file_path=args_opt.schema_file_path, dataset_format=args_opt.dataset_format, do_shuffle=(args_opt.train_data_shuffle.lower() == "true")) do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num) @@ -240,7 +244,7 @@ def run_ner(): if args_opt.do_eval.lower() == "true": ds = 
class SpanF1():
    '''
    Accumulate span-level TP / FP / FN counts for NER so that precision,
    recall and F1 can be derived from them.

    Args:
        use_crf (bool): when True, `update` receives CRF output
            (backpointers, best_tag_id) instead of raw logits.
        label2id (dict): mapping from tag string to integer id.

    Raises:
        ValueError: if `label2id` is None.
    '''
    def __init__(self, use_crf=False, label2id=None):
        # Validate first so a half-initialised metric is never left around.
        if label2id is None:
            raise ValueError("label2id info should not be empty")
        self.TP = 0
        self.FP = 0
        self.FN = 0
        self.use_crf = use_crf
        self.label2id = label2id
        self.id2label = {index: label for label, index in label2id.items()}

    @staticmethod
    def _as_numpy(tensor):
        '''
        return a numpy view of `tensor`; objects exposing `asnumpy()` are
        converted through it, anything else goes through `np.asarray`
        '''
        if hasattr(tensor, "asnumpy"):
            return tensor.asnumpy()
        return np.asarray(tensor)

    def tag2span(self, ids):
        '''
        convert a flat sequence of tag ids into a list of spans

        Tags are lower-cased; the first character is the position marker
        (b / m / e / s / o) and everything from index 2 on is the entity type.
        Returns a list of (entity_type, (start, end)) with `end` exclusive.
        '''
        spans = []
        prev_marker = None
        for idx, tag_id in enumerate(ids):
            tag = self.id2label[tag_id].lower()
            marker, entity = tag[:1], tag[2:]
            if marker in ('b', 's'):
                spans.append((entity, [idx, idx]))
            elif marker in ('m', 'e') and prev_marker in ('b', 'm') and entity == spans[-1][0]:
                # continuation of the currently open span
                spans[-1][1][1] = idx
            elif marker == 'o':
                pass
            else:
                # dangling m/e (or an unknown marker) opens a span of its own
                spans.append((entity, [idx, idx]))
            prev_marker = marker
        return [(entity, (bounds[0], bounds[1] + 1)) for entity, bounds in spans]

    def update(self, logits, labels):
        '''
        update span TP / FP / FN counters with one batch

        Args:
            logits: raw scores (argmax is taken over the last axis), or, when
                `use_crf` is set, the (backpointers, best_tag_id) pair.
            labels: gold tag ids; flattened before comparison.
        '''
        labels = np.reshape(self._as_numpy(labels), -1)
        if self.use_crf:
            backpointers, best_tag_id = logits
            best_path = postprocess(backpointers, best_tag_id)
            logit_id = []
            for ele in best_path:
                logit_id.extend(ele)
        else:
            logit_id = np.argmax(self._as_numpy(logits), axis=-1)
            logit_id = np.reshape(logit_id, -1)

        # Every span has a distinct start index, so spans are unique within a
        # sequence and set intersection gives the same counts as pairwise
        # matching, in linear time.
        label_spans = set(self.tag2span(labels))
        pred_spans = set(self.tag2span(logit_id))
        matched = len(pred_spans & label_spans)
        self.TP += matched
        self.FP += len(pred_spans) - matched
        self.FN += len(label_spans) - matched
- data_file_path=None, schema_file_path=None, do_shuffle=True): +def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", data_file_path=None, + dataset_format="mindrecord", schema_file_path=None, do_shuffle=True): """create finetune or evaluation dataset""" type_cast_op = C.TypeCast(mstype.int32) - data_set = ds.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, - columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"], - shuffle=do_shuffle) + if dataset_format == "mindrecord": + dataset = ds.MindDataset([data_file_path], + columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"], + shuffle=do_shuffle) + else: + dataset = ds.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"], + shuffle=do_shuffle) if assessment_method == "Spearman_correlation": type_cast_op_float = C.TypeCast(mstype.float32) - data_set = data_set.map(operations=type_cast_op_float, input_columns="label_ids") + dataset = dataset.map(operations=type_cast_op_float, input_columns="label_ids") else: - data_set = data_set.map(operations=type_cast_op, input_columns="label_ids") - data_set = data_set.map(operations=type_cast_op, input_columns="segment_ids") - data_set = data_set.map(operations=type_cast_op, input_columns="input_mask") - data_set = data_set.map(operations=type_cast_op, input_columns="input_ids") - data_set = data_set.repeat(repeat_count) + dataset = dataset.map(operations=type_cast_op, input_columns="label_ids") + dataset = dataset.map(operations=type_cast_op, input_columns="segment_ids") + dataset = dataset.map(operations=type_cast_op, input_columns="input_mask") + dataset = dataset.map(operations=type_cast_op, input_columns="input_ids") + dataset = dataset.repeat(repeat_count) # apply batch operations - data_set = data_set.batch(batch_size, drop_remainder=True) - return data_set + dataset = 
def _normalize_msra_label(label):
    """Rewrite a raw tag such as "B-PER" / "I-PER" as "B_PER" / "M_PER"; tags starting with "O" are kept."""
    if label[0] == "O":
        return label
    label = label[0] + "_" + label[2:]
    if label[0] == "I":
        label = "M" + label[1:]
    return label


def _bio_to_bmes(labels):
    """Return a copy of `labels` where a B not followed by M becomes S and an M not followed by M becomes E."""
    converted = list(labels)
    count = len(converted)
    for i, label in enumerate(converted):
        # converted[i + 1] has not been rewritten yet, so this reads the original tag.
        next_is_inside = i + 1 < count and converted[i + 1][0] == "M"
        if next_is_inside:
            continue
        if label[0] == "B":
            converted[i] = "S" + label[1:]
        elif label[0] == "M":
            converted[i] = "E" + label[1:]
    return converted


def process_cluener_msra(data_file):
    """Yield (text, labels) pairs from an MSRA file with one "char<TAB>tag" per line.

    Sentences are separated by blank lines. A line without a tag column (or
    with an empty one) is labelled "O". Tags are normalised to the
    B_/M_/E_/S_ scheme.

    Args:
        data_file (str): path of the MSRA text file (read as UTF-8).

    Yields:
        tuple: (np.ndarray holding the sentence string,
                np.ndarray of per-character tag strings).
    """
    content = []
    labels = []
    with open(data_file, encoding="utf-8") as reader:
        for raw_line in reader:
            line = raw_line.strip()
            if line:
                fields = line.split("\t")
                raw_label = fields[1] if len(fields) > 1 and fields[1] else "O"
                content.append(fields[0])
                labels.append(_normalize_msra_label(raw_label))
            elif content:
                # Blank line closes the sentence; repeated blank lines are ignored.
                yield (np.array("".join(content)), np.array(_bio_to_bmes(labels)))
                content = []
                labels = []
    if content:
        # File did not end with a blank line: flush the last sentence.
        yield (np.array("".join(content)), np.array(_bio_to_bmes(labels)))


def process_msra_clue_dataset(data_dir, label_list, bert_vocab_path, max_seq_len=128):
    """Build the MSRA dataset pipeline producing BERT NER inputs.

    Args:
        data_dir (str): path handed to `process_cluener_msra`.
        label_list (list): tag strings used to build the label vocabulary.
        bert_vocab_path (str): BERT vocabulary file.
        max_seq_len (int): length every output column is truncated/padded to.

    Returns:
        Dataset with columns input_ids, input_mask, segment_ids, label_ids.
    """
    ### Loading MSRA through a python generator
    dataset = ds.GeneratorDataset(process_cluener_msra(data_dir), column_names=['text', 'label'])

    ### Processing label
    label_vocab = text.Vocab.from_list(label_list)
    label_lookup = text.Lookup(label_vocab)
    dataset = dataset.map(operations=label_lookup, input_columns="label", output_columns="label_ids")
    # Prepend id 0 (presumably to line labels up with the [CLS] token added below),
    # then truncate and zero-pad to max_seq_len.
    dataset = dataset.map(operations=ops.Concatenate(prepend=np.array([0], dtype='i')),
                          input_columns=["label_ids"])
    dataset = dataset.map(operations=ops.Slice(slice(0, max_seq_len)), input_columns=["label_ids"])
    dataset = dataset.map(operations=ops.PadEnd([max_seq_len], 0), input_columns=["label_ids"])
    ### Processing sentence
    vocab = text.Vocab.from_file(bert_vocab_path)
    lookup = text.Lookup(vocab, unknown_token='[UNK]')
    unicode_char_tokenizer = text.UnicodeCharTokenizer()
    dataset = dataset.map(operations=unicode_char_tokenizer, input_columns=["text"], output_columns=["sentence"])
    # Leave room for [CLS] and [SEP].
    dataset = dataset.map(operations=ops.Slice(slice(0, max_seq_len-2)), input_columns=["sentence"])
    dataset = dataset.map(operations=ops.Concatenate(prepend=np.array(["[CLS]"], dtype='S'),
                                                     append=np.array(["[SEP]"], dtype='S')), input_columns=["sentence"])
    dataset = dataset.map(operations=lookup, input_columns=["sentence"], output_columns=["input_ids"])
    dataset = dataset.map(operations=ops.PadEnd([max_seq_len], 0), input_columns=["input_ids"])
    # input_mask marks the positions whose input id is not the pad value 0.
    dataset = dataset.map(operations=ops.Duplicate(), input_columns=["input_ids"],
                          output_columns=["input_ids", "input_mask"],
                          column_order=["input_ids", "input_mask", "label_ids"])
    dataset = dataset.map(operations=ops.Mask(ops.Relational.NE, 0, mstype.int32), input_columns=["input_mask"])
    # segment_ids has the shape of input_ids and is filled with 0.
    dataset = dataset.map(operations=ops.Duplicate(), input_columns=["input_ids"],
                          output_columns=["input_ids", "segment_ids"],
                          column_order=["input_ids", "input_mask", "segment_ids", "label_ids"])
    dataset = dataset.map(operations=ops.Fill(0), input_columns=["segment_ids"])
    return dataset


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="create mindrecord")
    parser.add_argument("--data_dir", type=str, default="", help="dataset path")
    parser.add_argument("--vocab_file", type=str, default="", help="Vocab file path")
    parser.add_argument("--max_seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--save_path", type=str, default="./my.mindrecord", help="Path to save mindrecord")
    parser.add_argument("--label2id", type=str, default="",
                        help="Label2id file path, must be set for cluener2020 task")
    args_opt = parser.parse_args()
    if args_opt.label2id == "":
        raise ValueError("label2id should not be empty")
    with open(args_opt.label2id, encoding="utf-8") as label_file:
        labels_list = [tag.strip() for tag in label_file]
    tag_to_index = list(convert_labels_to_index(labels_list).keys())
    # Use a dedicated name so the `ds` module alias is not rebound.
    msra_dataset = process_msra_clue_dataset(args_opt.data_dir, tag_to_index, args_opt.vocab_file,
                                             args_opt.max_seq_len)
    msra_dataset.save(args_opt.save_path)