From: @shibeiji Reviewed-by: @guoqi1024,@c_34,@linqingke Signed-off-by:tags/v1.2.0-rc1
| @@ -27,6 +27,7 @@ | |||||
| - [Evaluation](#evaluation) | - [Evaluation](#evaluation) | ||||
| - [evaluation on cola dataset when running on Ascend](#evaluation-on-cola-dataset-when-running-on-ascend) | - [evaluation on cola dataset when running on Ascend](#evaluation-on-cola-dataset-when-running-on-ascend) | ||||
| - [evaluation on cluener dataset when running on Ascend](#evaluation-on-cluener-dataset-when-running-on-ascend) | - [evaluation on cluener dataset when running on Ascend](#evaluation-on-cluener-dataset-when-running-on-ascend) | ||||
| - [evaluation on msra dataset when running on Ascend](#evaluation-on-msra-dataset-when-running-on-ascend) | |||||
| - [evaluation on squad v1.1 dataset when running on Ascend](#evaluation-on-squad-v11-dataset-when-running-on-ascend) | - [evaluation on squad v1.1 dataset when running on Ascend](#evaluation-on-squad-v11-dataset-when-running-on-ascend) | ||||
| - [Model Description](#model-description) | - [Model Description](#model-description) | ||||
| - [Performance](#performance) | - [Performance](#performance) | ||||
| @@ -215,7 +216,7 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as fol | |||||
| ├─bert_for_finetune.py # backbone code of network | ├─bert_for_finetune.py # backbone code of network | ||||
| ├─bert_for_pre_training.py # backbone code of network | ├─bert_for_pre_training.py # backbone code of network | ||||
| ├─bert_model.py # backbone code of network | ├─bert_model.py # backbone code of network | ||||
| ├─clue_classification_dataset_precess.py # data preprocessing | |||||
| ├─finetune_data_preprocess.py # data preprocessing | |||||
| ├─cluner_evaluation.py # evaluation for cluner | ├─cluner_evaluation.py # evaluation for cluner | ||||
| ├─config.py # parameter configuration for pretraining | ├─config.py # parameter configuration for pretraining | ||||
| ├─CRF.py # assessment method for clue dataset | ├─CRF.py # assessment method for clue dataset | ||||
| @@ -301,6 +302,7 @@ options: | |||||
| --load_finetune_checkpoint_path give a finetuning checkpoint path if only do eval | --load_finetune_checkpoint_path give a finetuning checkpoint path if only do eval | ||||
| --train_data_file_path ner tfrecord for training. E.g., train.tfrecord | --train_data_file_path ner tfrecord for training. E.g., train.tfrecord | ||||
| --eval_data_file_path ner tfrecord for predictions if f1 is used to evaluate result, ner json for predictions if clue_benchmark is used to evaluate result | --eval_data_file_path ner tfrecord for predictions if f1 is used to evaluate result, ner json for predictions if clue_benchmark is used to evaluate result | ||||
| --dataset_format dataset format, support mindrecord or tfrecord | |||||
| --schema_file_path path to datafile schema file | --schema_file_path path to datafile schema file | ||||
| usage: run_squad.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN] [--do_eval DO_EVAL] | usage: run_squad.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN] [--do_eval DO_EVAL] | ||||
| @@ -544,6 +546,30 @@ Recall 0.948683 | |||||
| F1 0.920507 | F1 0.920507 | ||||
| ``` | ``` | ||||
| #### evaluation on msra dataset when running on Ascend | |||||
| For preprocessing, first convert the original txt format of the MSRA dataset into mindrecord by running the command below: | |||||
| ```bash | |||||
| python src/finetune_data_preprocess.py --data_dir=/path/msra_dataset.txt --vocab_file=/path/vocab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len | |||||
| ``` | |||||
| For fine-tuning and evaluation, run: | |||||
| ```bash | |||||
| bash scripts/ner.sh | |||||
| ``` | |||||
| The command above runs in the background; you can view the training logs in ner_log.txt. | |||||
| If you choose SpanF1 as the assessment method and set use_crf to "true", evaluation after 10 epochs of fine-tuning gives the following result: | |||||
| ```text | |||||
| Precision 0.953826 | |||||
| Recall 0.957749 | |||||
| F1 0.955784 | |||||
| ``` | |||||
| #### evaluation on squad v1.1 dataset when running on Ascend | #### evaluation on squad v1.1 dataset when running on Ascend | ||||
| ```bash | ```bash | ||||
| @@ -28,6 +28,7 @@ | |||||
| - [用法](#用法-1) | - [用法](#用法-1) | ||||
| - [Ascend处理器上运行后评估cola数据集](#ascend处理器上运行后评估cola数据集) | - [Ascend处理器上运行后评估cola数据集](#ascend处理器上运行后评估cola数据集) | ||||
| - [Ascend处理器上运行后评估cluener数据集](#ascend处理器上运行后评估cluener数据集) | - [Ascend处理器上运行后评估cluener数据集](#ascend处理器上运行后评估cluener数据集) | ||||
| - [Ascend处理器上运行后评估msra数据集](#ascend处理器上运行后评估msra数据集) | |||||
| - [Ascend处理器上运行后评估squad v1.1数据集](#ascend处理器上运行后评估squad-v11数据集) | - [Ascend处理器上运行后评估squad v1.1数据集](#ascend处理器上运行后评估squad-v11数据集) | ||||
| - [模型描述](#模型描述) | - [模型描述](#模型描述) | ||||
| - [性能](#性能) | - [性能](#性能) | ||||
| @@ -215,7 +216,7 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as fol | |||||
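| ├─bert_for_finetune.py # backbone code of network | ├─bert_for_finetune.py # backbone code of network | ||||
| ├─bert_for_pre_training.py # backbone code of network | ├─bert_for_pre_training.py # backbone code of network | ||||
| ├─bert_model.py # backbone code of network | ├─bert_model.py # backbone code of network | ||||
| ├─clue_classification_dataset_precess.py # data preprocessing | |||||
| ├─finetune_data_preprocess.py # data preprocessing | |||||
| ├─cluner_evaluation.py # evaluation for cluner | ├─cluner_evaluation.py # evaluation for cluner | ||||
| ├─config.py # parameter configuration for pretraining | ├─config.py # parameter configuration for pretraining | ||||
| ├─CRF.py # assessment method for clue dataset | ├─CRF.py # assessment method for clue dataset | ||||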
| @@ -299,6 +300,7 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as fol | |||||
| --load_finetune_checkpoint_path path to a fine-tuning checkpoint, required if only evaluation is run | --load_finetune_checkpoint_path path to a fine-tuning checkpoint, required if only evaluation is run | ||||
| --train_data_file_path TFRecord file holding the training data, e.g. train.tfrecord | --train_data_file_path TFRecord file holding the training data, e.g. train.tfrecord | ||||
| --eval_data_file_path TFRecord file for predictions if f1 is used to evaluate the result; JSON file for predictions if clue_benchmark is used | --eval_data_file_path TFRecord file for predictions if f1 is used to evaluate the result; JSON file for predictions if clue_benchmark is used | ||||
| --dataset_format dataset format, supports tfrecord and mindrecord | |||||
| --schema_file_path path to the schema file | --schema_file_path path to the schema file | ||||
| usage: run_squad.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN] [--do_eval DO_EVAL] | usage: run_squad.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN] [--do_eval DO_EVAL] | ||||
| @@ -508,6 +510,29 @@ Recall 0.948683 | |||||
| F1 0.920507 | F1 0.920507 | ||||
| ``` | ``` | ||||
| #### Ascend处理器上运行后评估msra数据集 | |||||
| You can first convert the original format of the MSRA dataset to mindrecord during preprocessing to improve performance, as follows: | |||||
| ```bash | |||||
| python src/finetune_data_preprocess.py --data_dir=/path/msra_dataset.txt --vocab_file=/path/vocab_file --save_path=/path/msra_dataset.mindrecord --label2id=/path/label2id_file --max_seq_len=seq_len | |||||
| ``` | |||||
| After that, you can run fine-tuning and inference: | |||||
| ```bash | |||||
| bash scripts/ner.sh | |||||
| ``` | |||||
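| The command above runs in the background; you can view the training logs in ner_log.txt. | |||||
| If you choose SpanF1 as the assessment method and enable CRF mode in the model structure, running inference after 10 epochs of fine-tuning gives the following result: | |||||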
| ```text | |||||
| Precision 0.953826 | |||||
| Recall 0.957749 | |||||
| F1 0.955784 | |||||
| ``` | |||||
| #### Ascend处理器上运行后评估squad v1.1数据集 | #### Ascend处理器上运行后评估squad v1.1数据集 | ||||
| ```bash | ```bash | ||||
| @@ -23,7 +23,7 @@ from src.bert_for_finetune import BertFinetuneCell, BertNER | |||||
| from src.finetune_eval_config import optimizer_cfg, bert_net_cfg | from src.finetune_eval_config import optimizer_cfg, bert_net_cfg | ||||
| from src.dataset import create_ner_dataset | from src.dataset import create_ner_dataset | ||||
| from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate, convert_labels_to_index | from src.utils import make_directory, LossCallBack, LoadNewestCkpt, BertLearningRate, convert_labels_to_index | ||||
| from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation | |||||
| from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation, SpanF1 | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| @@ -86,7 +86,7 @@ def eval_result_print(assessment_method="accuracy", callback=None): | |||||
| if assessment_method == "accuracy": | if assessment_method == "accuracy": | ||||
| print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, | print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, | ||||
| callback.acc_num / callback.total_num)) | callback.acc_num / callback.total_num)) | ||||
| elif assessment_method == "f1": | |||||
| elif assessment_method in ("f1", "spanf1"): | |||||
| print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP))) | print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP))) | ||||
| print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN))) | print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN))) | ||||
| print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN))) | print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN))) | ||||
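The three values printed by this branch follow directly from the TP/FP/FN counters that the callback accumulates. A minimal sketch of that arithmetic, using a hypothetical `prf_from_counts` helper that is not part of the repository; the zero-denominator guard is an addition that the diffed code does not have:

```python
def prf_from_counts(tp, fp, fn):
    """Return (precision, recall, f1) from raw counts.

    Hypothetical helper for illustration. Returns 0.0 for any score whose
    denominator is zero, which the diffed print statements do not guard against.
    """
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 0.0
    return precision, recall, f1


if __name__ == "__main__":
    # Toy counts, not results from the PR.
    p, r, f = prf_from_counts(tp=8, fp=2, fn=2)
    print("Precision {:.6f}".format(p))
    print("Recall {:.6f}".format(r))
    print("F1 {:.6f}".format(f))
```

Because "f1" and "spanf1" share this branch, the only difference between the two methods is how TP, FP and FN are counted: per tag position or per entity span.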
| @@ -118,6 +118,8 @@ def do_eval(dataset=None, network=None, use_crf="", num_class=41, assessment_met | |||||
| callback = Accuracy() | callback = Accuracy() | ||||
| elif assessment_method == "f1": | elif assessment_method == "f1": | ||||
| callback = F1((use_crf.lower() == "true"), num_class) | callback = F1((use_crf.lower() == "true"), num_class) | ||||
| elif assessment_method == "spanf1": | |||||
| callback = SpanF1((use_crf.lower() == "true"), tag_to_index) | |||||
| elif assessment_method == "mcc": | elif assessment_method == "mcc": | ||||
| callback = MCC() | callback = MCC() | ||||
| elif assessment_method == "spearman_correlation": | elif assessment_method == "spearman_correlation": | ||||
| @@ -143,8 +145,8 @@ def parse_args(): | |||||
| parser = argparse.ArgumentParser(description="run ner") | parser = argparse.ArgumentParser(description="run ner") | ||||
| parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"], | parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"], | ||||
| help="Device type, default is Ascend") | help="Device type, default is Ascend") | ||||
| parser.add_argument("--assessment_method", type=str, default="F1", choices=["F1", "clue_benchmark"], | |||||
| help="assessment_method include: [F1, clue_benchmark], default is F1") | |||||
| parser.add_argument("--assessment_method", type=str, default="F1", choices=["F1", "clue_benchmark", "SpanF1"], | |||||
| help="assessment_method include: [F1, clue_benchmark, SpanF1], default is F1") | |||||
| parser.add_argument("--do_train", type=str, default="false", choices=["true", "false"], | parser.add_argument("--do_train", type=str, default="false", choices=["true", "false"], | ||||
| help="Enable train, default is false") | help="Enable train, default is false") | ||||
| parser.add_argument("--do_eval", type=str, default="false", choices=["true", "false"], | parser.add_argument("--do_eval", type=str, default="false", choices=["true", "false"], | ||||
| @@ -169,6 +171,8 @@ def parse_args(): | |||||
| help="Data path, it is better to use absolute path") | help="Data path, it is better to use absolute path") | ||||
| parser.add_argument("--eval_data_file_path", type=str, default="", | parser.add_argument("--eval_data_file_path", type=str, default="", | ||||
| help="Data path, it is better to use absolute path") | help="Data path, it is better to use absolute path") | ||||
| parser.add_argument("--dataset_format", type=str, default="mindrecord", choices=["mindrecord", "tfrecord"], | |||||
| help="Dataset format, support mindrecord or tfrecord") | |||||
| parser.add_argument("--schema_file_path", type=str, default="", | parser.add_argument("--schema_file_path", type=str, default="", | ||||
| help="Schema path, it is better to use absolute path") | help="Schema path, it is better to use absolute path") | ||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
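How the two new options behave on the command line can be checked in isolation. The sketch below reproduces only the `--assessment_method` and `--dataset_format` arguments from this hunk; `build_parser` is a hypothetical name, and the lower-casing of the method is an assumption inferred from the comparisons against "spanf1" elsewhere in the diff:

```python
import argparse


def build_parser():
    """Hypothetical stand-in for parse_args(), limited to the two options touched here."""
    parser = argparse.ArgumentParser(description="run ner (sketch)")
    parser.add_argument("--assessment_method", type=str, default="F1",
                        choices=["F1", "clue_benchmark", "SpanF1"],
                        help="assessment_method include: [F1, clue_benchmark, SpanF1], default is F1")
    parser.add_argument("--dataset_format", type=str, default="mindrecord",
                        choices=["mindrecord", "tfrecord"],
                        help="Dataset format, support mindrecord or tfrecord")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--assessment_method", "SpanF1",
                                      "--dataset_format", "tfrecord"])
    # Assumption: the script lower-cases the method before comparing it.
    print(args.assessment_method.lower(), args.dataset_format)
```

Any value outside `choices` makes argparse exit with a usage error, so a misspelled format fails before any dataset is opened.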
| @@ -225,7 +229,7 @@ def run_ner(): | |||||
| tag_to_index=tag_to_index, dropout_prob=0.1) | tag_to_index=tag_to_index, dropout_prob=0.1) | ||||
| ds = create_ner_dataset(batch_size=args_opt.train_batch_size, repeat_count=1, | ds = create_ner_dataset(batch_size=args_opt.train_batch_size, repeat_count=1, | ||||
| assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path, | assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path, | ||||
| schema_file_path=args_opt.schema_file_path, | |||||
| schema_file_path=args_opt.schema_file_path, dataset_format=args_opt.dataset_format, | |||||
| do_shuffle=(args_opt.train_data_shuffle.lower() == "true")) | do_shuffle=(args_opt.train_data_shuffle.lower() == "true")) | ||||
| do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num) | do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num) | ||||
| @@ -240,7 +244,7 @@ def run_ner(): | |||||
| if args_opt.do_eval.lower() == "true": | if args_opt.do_eval.lower() == "true": | ||||
| ds = create_ner_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1, | ds = create_ner_dataset(batch_size=args_opt.eval_batch_size, repeat_count=1, | ||||
| assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path, | assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path, | ||||
| schema_file_path=args_opt.schema_file_path, | |||||
| schema_file_path=args_opt.schema_file_path, dataset_format=args_opt.dataset_format, | |||||
| do_shuffle=(args_opt.eval_data_shuffle.lower() == "true")) | do_shuffle=(args_opt.eval_data_shuffle.lower() == "true")) | ||||
| do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, | do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, | ||||
| args_opt.eval_data_file_path, load_finetune_checkpoint_path, args_opt.vocab_file_path, | args_opt.eval_data_file_path, load_finetune_checkpoint_path, args_opt.vocab_file_path, | ||||
| @@ -18,7 +18,7 @@ echo "========================================================================== | |||||
| echo "Please run the script as: " | echo "Please run the script as: " | ||||
| echo "bash scripts/run_ner.sh" | echo "bash scripts/run_ner.sh" | ||||
| echo "for example: bash scripts/run_ner.sh" | echo "for example: bash scripts/run_ner.sh" | ||||
| echo "assessment_method include: [F1, clue_benchmark]" | |||||
| echo "assessment_method include: [F1, SpanF1, clue_benchmark]" | |||||
| echo "==============================================================================================================" | echo "==============================================================================================================" | ||||
| mkdir -p ms_log | mkdir -p ms_log | ||||
| @@ -46,4 +46,5 @@ python ${PROJECT_DIR}/../run_ner.py \ | |||||
| --load_finetune_checkpoint_path="" \ | --load_finetune_checkpoint_path="" \ | ||||
| --train_data_file_path="" \ | --train_data_file_path="" \ | ||||
| --eval_data_file_path="" \ | --eval_data_file_path="" \ | ||||
| --dataset_format="tfrecord" \ | |||||
| --schema_file_path="" > ner_log.txt 2>&1 & | --schema_file_path="" > ner_log.txt 2>&1 & | ||||
| @@ -68,6 +68,74 @@ class F1(): | |||||
| self.FP += np.sum(pos_eva&(~pos_label)) | self.FP += np.sum(pos_eva&(~pos_label)) | ||||
| self.FN += np.sum((~pos_eva)&pos_label) | self.FN += np.sum((~pos_eva)&pos_label) | ||||
| class SpanF1(): | |||||
|     ''' | |||||
|     calculate F1, precision and recall scores in a span-wise manner for NER | |||||
|     ''' | |||||
|     def __init__(self, use_crf=False, label2id=None): | |||||
|         self.TP = 0 | |||||
|         self.FP = 0 | |||||
|         self.FN = 0 | |||||
|         self.use_crf = use_crf | |||||
|         self.label2id = label2id | |||||
|         if label2id is None: | |||||
|             raise ValueError("label2id info should not be empty") | |||||
|         self.id2label = {} | |||||
|         for key, value in label2id.items(): | |||||
|             self.id2label[value] = key | |||||
|     def tag2span(self, ids): | |||||
|         ''' | |||||
|         convert ids list to span mode | |||||
|         ''' | |||||
|         labels = np.array([self.id2label[id] for id in ids]) | |||||
|         spans = [] | |||||
|         prev_label = None | |||||
|         for idx, tag in enumerate(labels): | |||||
|             tag = tag.lower() | |||||
|             cur_label, label = tag[:1], tag[2:] | |||||
|             if cur_label in ('b', 's'): | |||||
|                 spans.append((label, [idx, idx])) | |||||
|             elif cur_label in ('m', 'e') and prev_label in ('b', 'm') and label == spans[-1][0]: | |||||
|                 spans[-1][1][1] = idx | |||||
|             elif cur_label == 'o': | |||||
|                 pass | |||||
|             else: | |||||
|                 spans.append((label, [idx, idx])) | |||||
|             prev_label = cur_label | |||||
|         return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans] | |||||
|     def update(self, logits, labels): | |||||
|         ''' | |||||
|         update span F1 score | |||||
|         ''' | |||||
|         labels = labels.asnumpy() | |||||
|         labels = np.reshape(labels, -1) | |||||
|         if self.use_crf: | |||||
|             backpointers, best_tag_id = logits | |||||
|             best_path = postprocess(backpointers, best_tag_id) | |||||
|             logit_id = [] | |||||
|             for ele in best_path: | |||||
|                 logit_id.extend(ele) | |||||
|         else: | |||||
|             logits = logits.asnumpy() | |||||
|             logit_id = np.argmax(logits, axis=-1) | |||||
|             logit_id = np.reshape(logit_id, -1) | |||||
|         label_spans = self.tag2span(labels) | |||||
|         pred_spans = self.tag2span(logit_id) | |||||
|         for span in pred_spans: | |||||
|             if span in label_spans: | |||||
|                 self.TP += 1 | |||||
|                 label_spans.remove(span) | |||||
|             else: | |||||
|                 self.FP += 1 | |||||
|         for span in label_spans: | |||||
|             self.FN += 1 | |||||
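The span logic of `SpanF1` can be tried without MindSpore or NumPy. The sketch below mirrors `tag2span` and the counting loop in `update` on plain tag strings; `tags_to_spans` and `count_spans` are hypothetical names, and tags are assumed to use the `B_X`/`M_X`/`E_X`/`S_X`/`O` scheme produced by the preprocessing script in this PR:

```python
def tags_to_spans(tags):
    """Convert a tag sequence into a list of (entity_type, (start, end)) spans, end exclusive."""
    spans = []
    prev_prefix = None
    for idx, tag in enumerate(tags):
        tag = tag.lower()
        prefix, label = tag[:1], tag[2:]
        if prefix in ("b", "s"):
            # A begin or single tag always opens a new span.
            spans.append((label, [idx, idx]))
        elif prefix in ("m", "e") and prev_prefix in ("b", "m") and label == spans[-1][0]:
            # A middle or end tag extends the open span of the same type.
            spans[-1][1][1] = idx
        elif prefix == "o":
            pass
        else:
            # An orphan middle/end tag is treated as its own span.
            spans.append((label, [idx, idx]))
        prev_prefix = prefix
    return [(label, (start, end + 1)) for label, (start, end) in spans]


def count_spans(pred_tags, gold_tags):
    """Return (TP, FP, FN), counting a prediction as correct only on an exact span match."""
    gold = tags_to_spans(gold_tags)
    tp = fp = 0
    for span in tags_to_spans(pred_tags):
        if span in gold:
            tp += 1
            gold.remove(span)
        else:
            fp += 1
    return tp, fp, len(gold)


if __name__ == "__main__":
    gold = ["B_PER", "E_PER", "O", "S_LOC"]
    pred = ["B_PER", "E_PER", "O", "O"]
    print(tags_to_spans(gold))      # [('per', (0, 2)), ('loc', (3, 4))]
    print(count_spans(pred, gold))  # (1, 0, 1)
```

An entity counts as a true positive only when its type and both boundaries match, so a prediction that is off by one token contributes one false positive and one false negative.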
| class MCC(): | class MCC(): | ||||
| ''' | ''' | ||||
| Calculate Matthews Correlation Coefficient | Calculate Matthews Correlation Coefficient | ||||
| @@ -52,25 +52,30 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, | |||||
| return data_set | return data_set | ||||
| def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", | |||||
| data_file_path=None, schema_file_path=None, do_shuffle=True): | |||||
| def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", data_file_path=None, | |||||
| dataset_format="mindrecord", schema_file_path=None, do_shuffle=True): | |||||
| """create finetune or evaluation dataset""" | """create finetune or evaluation dataset""" | ||||
| type_cast_op = C.TypeCast(mstype.int32) | type_cast_op = C.TypeCast(mstype.int32) | ||||
| data_set = ds.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, | |||||
| columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"], | |||||
| shuffle=do_shuffle) | |||||
| if dataset_format == "mindrecord": | |||||
| dataset = ds.MindDataset([data_file_path], | |||||
| columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"], | |||||
| shuffle=do_shuffle) | |||||
| else: | |||||
| dataset = ds.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, | |||||
| columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"], | |||||
| shuffle=do_shuffle) | |||||
| if assessment_method == "Spearman_correlation": | if assessment_method == "Spearman_correlation": | ||||
| type_cast_op_float = C.TypeCast(mstype.float32) | type_cast_op_float = C.TypeCast(mstype.float32) | ||||
| data_set = data_set.map(operations=type_cast_op_float, input_columns="label_ids") | |||||
| dataset = dataset.map(operations=type_cast_op_float, input_columns="label_ids") | |||||
| else: | else: | ||||
| data_set = data_set.map(operations=type_cast_op, input_columns="label_ids") | |||||
| data_set = data_set.map(operations=type_cast_op, input_columns="segment_ids") | |||||
| data_set = data_set.map(operations=type_cast_op, input_columns="input_mask") | |||||
| data_set = data_set.map(operations=type_cast_op, input_columns="input_ids") | |||||
| data_set = data_set.repeat(repeat_count) | |||||
| dataset = dataset.map(operations=type_cast_op, input_columns="label_ids") | |||||
| dataset = dataset.map(operations=type_cast_op, input_columns="segment_ids") | |||||
| dataset = dataset.map(operations=type_cast_op, input_columns="input_mask") | |||||
| dataset = dataset.map(operations=type_cast_op, input_columns="input_ids") | |||||
| dataset = dataset.repeat(repeat_count) | |||||
| # apply batch operations | # apply batch operations | ||||
| data_set = data_set.batch(batch_size, drop_remainder=True) | |||||
| return data_set | |||||
| dataset = dataset.batch(batch_size, drop_remainder=True) | |||||
| return dataset | |||||
| def create_classification_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", | def create_classification_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", | ||||
| @@ -18,12 +18,14 @@ sample script of processing CLUE classification dataset using mindspore.dataset. | |||||
| """ | """ | ||||
| import os | import os | ||||
| import argparse | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.text as text | import mindspore.dataset.text as text | ||||
| import mindspore.dataset.transforms.c_transforms as ops | import mindspore.dataset.transforms.c_transforms as ops | ||||
| from utils import convert_labels_to_index | |||||
| def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage='train', shuffle_dataset=False, | def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage='train', shuffle_dataset=False, | ||||
| @@ -135,3 +137,93 @@ def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage | |||||
| dataset = dataset.map(operations=ops.Mask(ops.Relational.NE, 0, mstype.int32), input_columns=["mask_ids"]) | dataset = dataset.map(operations=ops.Mask(ops.Relational.NE, 0, mstype.int32), input_columns=["mask_ids"]) | ||||
| dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) | dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) | ||||
| return dataset | return dataset | ||||
| def process_cluener_msra(data_file): | |||||
|     """process MSRA dataset for CLUE""" | |||||
|     content = [] | |||||
|     labels = [] | |||||
|     for line in open(data_file): | |||||
|         line = line.strip() | |||||
|         if line: | |||||
|             word = line.split("\t")[0] | |||||
|             if len(line.split("\t")) == 1: | |||||
|                 label = "O" | |||||
|             else: | |||||
|                 label = line.split("\t")[1].split("\n")[0] | |||||
|                 if label[0] != "O": | |||||
|                     label = label[0] + "_" + label[2:] | |||||
|                 if label[0] == "I": | |||||
|                     label = "M" + label[1:] | |||||
|             content.append(word) | |||||
|             labels.append(label) | |||||
|         else: | |||||
|             for i in range(1, len(labels) - 1): | |||||
|                 if labels[i][0] == "B" and labels[i+1][0] != "M": | |||||
|                     labels[i] = "S" + labels[i][1:] | |||||
|                 elif labels[i][0] == "M" and labels[i+1][0] != labels[i][0]: | |||||
|                     labels[i] = "E" + labels[i][1:] | |||||
|             last = len(labels) - 1 | |||||
|             if labels[last][0] == "B": | |||||
|                 labels[last] = "S" + labels[last][1:] | |||||
|             elif labels[last][0] == "M": | |||||
|                 labels[last] = "E" + labels[last][1:] | |||||
|             yield (np.array("".join(content)), np.array(list(labels))) | |||||
|             content.clear() | |||||
|             labels.clear() | |||||
|             continue | |||||
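The label rewriting in `process_cluener_msra` amounts to a BIO-to-BMES conversion applied once per sentence. A minimal sketch on an in-memory tag list, with a hypothetical `bio_to_bmes` helper; note one deliberate difference from the diffed loop, which starts at index 1: this version also converts the tag at index 0.

```python
def bio_to_bmes(tags):
    """Convert BIO tags such as "B-PER"/"I-PER" into "B_PER"/"M_PER"/"E_PER"/"S_PER" tags."""
    out = []
    for tag in tags:
        if tag[0] != "O":
            # "B-PER" -> "B_PER", "I-PER" -> "I_PER"
            tag = tag[0] + "_" + tag[2:]
        if tag[0] == "I":
            # Inside tags become middle tags.
            tag = "M" + tag[1:]
        out.append(tag)
    for i, tag in enumerate(out):
        # Treat the end of the sentence like a following "O".
        next_prefix = out[i + 1][0] if i + 1 < len(out) else "O"
        if tag[0] == "B" and next_prefix != "M":
            out[i] = "S" + tag[1:]  # single-token entity
        elif tag[0] == "M" and next_prefix != "M":
            out[i] = "E" + tag[1:]  # last token of a multi-token entity
    return out


if __name__ == "__main__":
    print(bio_to_bmes(["B-PER", "I-PER", "I-PER", "O", "B-LOC"]))
    # ['B_PER', 'M_PER', 'E_PER', 'O', 'S_LOC']
```

The resulting prefixes are exactly the ones that `SpanF1.tag2span` checks after lower-casing: `b`, `m`, `e`, `s` and `o`.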
| def process_msra_clue_dataset(data_dir, label_list, bert_vocab_path, max_seq_len=128): | |||||
| """Process MSRA dataset""" | |||||
| ### Loading MSRA from CLUEDataset | |||||
| dataset = ds.GeneratorDataset(process_cluener_msra(data_dir), column_names=['text', 'label']) | |||||
| ### Processing label | |||||
| label_vocab = text.Vocab.from_list(label_list) | |||||
| label_lookup = text.Lookup(label_vocab) | |||||
| dataset = dataset.map(operations=label_lookup, input_columns="label", output_columns="label_ids") | |||||
| dataset = dataset.map(operations=ops.Concatenate(prepend=np.array([0], dtype='i')), | |||||
| input_columns=["label_ids"]) | |||||
| dataset = dataset.map(operations=ops.Slice(slice(0, max_seq_len)), input_columns=["label_ids"]) | |||||
| dataset = dataset.map(operations=ops.PadEnd([max_seq_len], 0), input_columns=["label_ids"]) | |||||
| ### Processing sentence | |||||
| vocab = text.Vocab.from_file(bert_vocab_path) | |||||
| lookup = text.Lookup(vocab, unknown_token='[UNK]') | |||||
| unicode_char_tokenizer = text.UnicodeCharTokenizer() | |||||
| dataset = dataset.map(operations=unicode_char_tokenizer, input_columns=["text"], output_columns=["sentence"]) | |||||
| dataset = dataset.map(operations=ops.Slice(slice(0, max_seq_len-2)), input_columns=["sentence"]) | |||||
| dataset = dataset.map(operations=ops.Concatenate(prepend=np.array(["[CLS]"], dtype='S'), | |||||
| append=np.array(["[SEP]"], dtype='S')), input_columns=["sentence"]) | |||||
| dataset = dataset.map(operations=lookup, input_columns=["sentence"], output_columns=["input_ids"]) | |||||
| dataset = dataset.map(operations=ops.PadEnd([max_seq_len], 0), input_columns=["input_ids"]) | |||||
| dataset = dataset.map(operations=ops.Duplicate(), input_columns=["input_ids"], | |||||
| output_columns=["input_ids", "input_mask"], | |||||
| column_order=["input_ids", "input_mask", "label_ids"]) | |||||
| dataset = dataset.map(operations=ops.Mask(ops.Relational.NE, 0, mstype.int32), input_columns=["input_mask"]) | |||||
| dataset = dataset.map(operations=ops.Duplicate(), input_columns=["input_ids"], | |||||
| output_columns=["input_ids", "segment_ids"], | |||||
| column_order=["input_ids", "input_mask", "segment_ids", "label_ids"]) | |||||
| dataset = dataset.map(operations=ops.Fill(0), input_columns=["segment_ids"]) | |||||
| return dataset | |||||
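What the chain of `map` operations produces for a single sentence can be sketched in plain Python. The `encode_example` helper, the toy vocabulary and the toy label map below are all hypothetical; the sketch assumes that id 0 is the padding token, which is what padding with 0 and masking with "not equal to 0" imply, and it approximates `UnicodeCharTokenizer` by splitting the string into characters:

```python
def encode_example(text, labels, vocab, label2id, max_seq_len=8):
    """Return (input_ids, input_mask, segment_ids, label_ids), each of length max_seq_len."""
    # Character tokens, truncated to leave room for [CLS] and [SEP].
    tokens = ["[CLS]"] + list(text)[:max_seq_len - 2] + ["[SEP]"]
    input_ids = [vocab.get(tok, vocab["[UNK]"]) for tok in tokens]
    input_ids += [0] * (max_seq_len - len(input_ids))
    # The mask marks every non-padding position.
    input_mask = [int(i != 0) for i in input_ids]
    # A single-sentence task uses segment 0 everywhere.
    segment_ids = [0] * max_seq_len
    # Labels get a leading 0 for [CLS], then are truncated and padded like the input.
    label_ids = ([0] + [label2id[label] for label in labels])[:max_seq_len]
    label_ids += [0] * (max_seq_len - len(label_ids))
    return input_ids, input_mask, segment_ids, label_ids


if __name__ == "__main__":
    # Toy vocabularies for illustration only.
    vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "a": 4, "b": 5}
    label2id = {"O": 0, "S_LOC": 1}
    for column in encode_example("abz", ["O", "S_LOC", "O"], vocab, label2id):
        print(column)
```

The four returned lists correspond to the `input_ids`, `input_mask`, `segment_ids` and `label_ids` columns that `create_ner_dataset` later reads back from the saved mindrecord file.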
| if __name__ == "__main__": | |||||
| parser = argparse.ArgumentParser(description="create mindrecord") | |||||
| parser.add_argument("--data_dir", type=str, default="", help="dataset path") | |||||
| parser.add_argument("--vocab_file", type=str, default="", help="Vocab file path") | |||||
| parser.add_argument("--max_seq_len", type=int, default=128, help="Sequence length") | |||||
| parser.add_argument("--save_path", type=str, default="./my.mindrecord", help="Path to save mindrecord") | |||||
| parser.add_argument("--label2id", type=str, default="", | |||||
| help="Label2id file path, must be set for cluener2020 task") | |||||
| args_opt = parser.parse_args() | |||||
| if args_opt.label2id == "": | |||||
| raise ValueError("label2id should not be empty") | |||||
| labels_list = [] | |||||
| with open(args_opt.label2id) as f: | |||||
| for tag in f: | |||||
| labels_list.append(tag.strip()) | |||||
| tag_to_index = list(convert_labels_to_index(labels_list).keys()) | |||||
| ds = process_msra_clue_dataset(args_opt.data_dir, tag_to_index, args_opt.vocab_file, args_opt.max_seq_len) | |||||
| ds.save(args_opt.save_path) | |||||