From: @zhao_ting_v
Reviewed-by: @wuxuejian, @c_34
Signed-off-by: @wuxuejian
PR: pull/13903/MERGE
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# Contents
+# Contents
 - [Contents](#contents)
 - [TinyBERT Description](#tinybert-description)
@@ -197,8 +197,9 @@ usage: run_general_task.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN
                            [--load_gd_ckpt_path LOAD_GD_CKPT_PATH]
                            [--load_td1_ckpt_path LOAD_TD1_CKPT_PATH]
                            [--train_data_dir TRAIN_DATA_DIR]
-                           [--eval_data_dir EVAL_DATA_DIR]
+                           [--eval_data_dir EVAL_DATA_DIR] [--task_type TASK_TYPE]
                            [--task_name TASK_NAME] [--schema_dir SCHEMA_DIR] [--dataset_type DATASET_TYPE]
+                           [--assessment_method ASSESSMENT_METHOD]
 options:
     --device_target            device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
@@ -217,7 +218,9 @@ options:
     --load_td1_ckpt_path       path to load checkpoint files which produced by task distill phase 1: PATH, default is ""
     --train_data_dir           path to train dataset directory: PATH, default is ""
     --eval_data_dir            path to eval dataset directory: PATH, default is ""
-    --task_name                classification task: "SST-2" | "QNLI" | "MNLI", default is ""
+    --task_type                task type: "classification" | "ner", default is "classification"
+    --task_name                classification or ner task: "SST-2" | "QNLI" | "MNLI" | "TNEWS" | "CLUENER", default is ""
+    --assessment_method        assessment method to do evaluation: acc | f1
     --schema_dir               path to schema.json file, PATH, default is ""
     --dataset_type             the dataset type which can be tfrecord/mindrecord, default is tfrecord
 ```
@@ -249,6 +252,7 @@ Parameters for optimizer:
 Parameters for bert network:
     seq_length              length of input sequence: N, default is 128
     vocab_size              size of each embedding vector: N, must be consistent with the dataset you use. Default is 30522
+                            Usually, we use 21128 for CN vocabs and 30522 for EN vocabs according to the origin paper. Default is 30522
     hidden_size             size of bert encoder layers: N
     num_hidden_layers       number of hidden layers: N
     num_attention_heads     number of attention heads: N, default is 12

--- a/export.py
+++ b/export.py
@@ -22,7 +22,7 @@ from mindspore import Tensor, context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
 from src.td_config import td_student_net_cfg
-from src.tinybert_model import BertModelCLS
+from src.tinybert_model import BertModelCLS, BertModelNER

 parser = argparse.ArgumentParser(description='tinybert task distill')
 parser.add_argument("--device_id", type=int, default=0, help="Device id")
@@ -31,7 +31,10 @@ parser.add_argument("--file_name", type=str, default="tinybert", help="output fi
 parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
 parser.add_argument("--device_target", type=str, default="Ascend",
                     choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
-parser.add_argument('--task_name', type=str, default='SST-2', choices=['SST-2', 'QNLI', 'MNLI'], help='task name')
+parser.add_argument("--task_type", type=str, default="classification", choices=["classification", "ner"],
+                    help="The type of the task to train.")
+parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
+                    help="The name of the task to train.")
 args = parser.parse_args()

 context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
@@ -43,7 +46,9 @@ DEFAULT_SEQ_LENGTH = 128
 DEFAULT_BS = 32
 task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
                "QNLI": {"num_labels": 2, "seq_length": 128},
-               "MNLI": {"num_labels": 3, "seq_length": 128}}
+               "MNLI": {"num_labels": 3, "seq_length": 128},
+               "TNEWS": {"num_labels": 15, "seq_length": 128},
+               "CLUENER": {"num_labels": 10, "seq_length": 128}}

 class Task:
     """
@@ -68,8 +73,13 @@ if __name__ == '__main__':
     task = Task(args.task_name)
     td_student_net_cfg.seq_length = task.seq_length
     td_student_net_cfg.batch_size = DEFAULT_BS

-    eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+    if args.task_type == "classification":
+        eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+    elif args.task_type == "ner":
+        eval_model = BertModelNER(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+    else:
+        raise ValueError(f"Not support task type: {args.task_type}")
     param_dict = load_checkpoint(args.ckpt_file)
     new_param_dict = {}
     for key, value in param_dict.items():

--- a/run_general_distill.py
+++ b/run_general_distill.py
@@ -33,14 +33,10 @@ from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
 from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
 from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell

-def run_general_distill():
-    """
-    run general distill
-    """
+def get_argument():
+    """Tinybert general distill argument parser."""
     parser = argparse.ArgumentParser(description='tinybert general distill')
-    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
+    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
                         help='device where the code will be implemented. (Default: Ascend)')
     parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
                         help="Run distribute, default is false.")
@@ -61,20 +57,21 @@ def run_general_distill():
     parser.add_argument("--dataset_type", type=str, default="tfrecord",
                         help="dataset type tfrecord/mindrecord, default is tfrecord")
     args_opt = parser.parse_args()
+    return args_opt
+
+
+def run_general_distill():
+    """
+    run general distill
+    """
+    args_opt = get_argument()
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
+                        reserve_class_name_in_scope=False)
-    if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    else:
-        raise Exception("Target error, GPU or Ascend is supported.")
-    context.set_context(reserve_class_name_in_scope=False)
+    context.set_context(device_id=args_opt.device_id)

     save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
                                  datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

     if args_opt.distribute == "true":
         if args_opt.device_target == 'Ascend':
             D.init()
@@ -104,6 +101,14 @@ def run_general_distill():
             # and the loss scale is not necessary
             enable_loss_scale = False

+    if args_opt.device_target == "CPU":
+        logger.warning('CPU only support float32 temporarily, run with float32.')
+        bert_teacher_net_cfg.dtype = mstype.float32
+        bert_teacher_net_cfg.compute_type = mstype.float32
+        bert_student_net_cfg.dtype = mstype.float32
+        bert_student_net_cfg.compute_type = mstype.float32
+        enable_loss_scale = False
+
     netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
                                          teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                          student_config=bert_student_net_cfg,

--- a/run_task_distill.py
+++ b/run_task_distill.py
@@ -28,10 +28,10 @@ from mindspore.nn.optim import AdamWeightDecay
 from mindspore import log as logger
 from src.dataset import create_tinybert_dataset, DataType
 from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
-from src.assessment_method import Accuracy
+from src.assessment_method import Accuracy, F1
 from src.td_config import phase1_cfg, phase2_cfg, eval_cfg, td_teacher_net_cfg, td_student_net_cfg
 from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
-from src.tinybert_model import BertModelCLS
+from src.tinybert_model import BertModelCLS, BertModelNER

 _cur_dir = os.getcwd()
 td_phase1_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase1_save_ckpt')
@@ -46,7 +46,7 @@ def parse_args():
     parse args
     """
     parser = argparse.ArgumentParser(description='tinybert task distill')
-    parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'],
+    parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                         help='device where the code will be implemented. (Default: Ascend)')
     parser.add_argument("--do_train", type=str, default="true", choices=["true", "false"],
                         help="Do train task, default is true.")
@@ -69,21 +69,46 @@ def parse_args():
     parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
-    parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
+    parser.add_argument("--task_type", type=str, default="classification", choices=["classification", "ner"],
+                        help="The type of the task to train.")
+    parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
                         help="The name of the task to train.")
+    parser.add_argument("--assessment_method", type=str, default="accuracy", choices=["accuracy", "bf1", "mf1"],
+                        help="assessment_method include: [accuracy, bf1, mf1], default is accuracy")
     parser.add_argument("--dataset_type", type=str, default="tfrecord",
                         help="dataset type tfrecord/mindrecord, default is tfrecord")
     args = parser.parse_args()
+    if args.do_train.lower() != "true" and args.do_eval.lower() != "true":
+        raise ValueError("do train or do eval must have one be true, please confirm your config")
+    if args.task_name in ["SST-2", "QNLI", "MNLI", "TNEWS"] and args.task_type != "classification":
+        raise ValueError(f"{args.task_name} is a classification dataset, please set --task_type=classification")
+    if args.task_name in ["CLUENER"] and args.task_type != "ner":
+        raise ValueError(f"{args.task_name} is a ner dataset, please set --task_type=ner")
+    if args.task_name in ["SST-2", "QNLI", "MNLI"] and \
+            (td_teacher_net_cfg.vocab_size != 30522 or td_student_net_cfg.vocab_size != 30522):
+        logger.warning(f"{args.task_name} is an English dataset. Usually, we use 21128 for CN vocabs and 30522 for "
+                       "EN vocabs according to the origin paper.")
+    if args.task_name in ["TNEWS", "CLUENER"] and \
+            (td_teacher_net_cfg.vocab_size != 21128 or td_student_net_cfg.vocab_size != 21128):
+        logger.warning(f"{args.task_name} is a Chinese dataset. Usually, we use 21128 for CN vocabs and 30522 for "
+                       "EN vocabs according to the origin paper.")
     return args

 args_opt = parse_args()
+
+if args_opt.dataset_type == "tfrecord":
+    dataset_type = DataType.TFRECORD
+elif args_opt.dataset_type == "mindrecord":
+    dataset_type = DataType.MINDRECORD
+else:
+    raise Exception("dataset format is not supported yet")
+
 DEFAULT_NUM_LABELS = 2
 DEFAULT_SEQ_LENGTH = 128
 task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
                "QNLI": {"num_labels": 2, "seq_length": 128},
-               "MNLI": {"num_labels": 3, "seq_length": 128}}
+               "MNLI": {"num_labels": 3, "seq_length": 128},
+               "TNEWS": {"num_labels": 15, "seq_length": 128},
+               "CLUENER": {"num_labels": 43, "seq_length": 128}}

 class Task:
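
The hunk above cuts off at the `Task` wrapper, whose body is unchanged by this patch and therefore elided from the diff. For orientation only, a minimal self-contained sketch of how such a wrapper is assumed to resolve per-task settings from `task_params`, falling back to the module defaults for unregistered names:

```python
DEFAULT_NUM_LABELS = 2
DEFAULT_SEQ_LENGTH = 128
task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
               "QNLI": {"num_labels": 2, "seq_length": 128},
               "MNLI": {"num_labels": 3, "seq_length": 128},
               "TNEWS": {"num_labels": 15, "seq_length": 128},
               "CLUENER": {"num_labels": 43, "seq_length": 128}}

class Task:
    """Resolve num_labels/seq_length for a task name (assumed sketch, not the patch's code)."""
    def __init__(self, task_name):
        self.task_name = task_name

    @property
    def num_labels(self):
        # Unknown task names fall back to the binary-classification default.
        return task_params.get(self.task_name, {}).get("num_labels", DEFAULT_NUM_LABELS)

    @property
    def seq_length(self):
        return task_params.get(self.task_name, {}).get("seq_length", DEFAULT_SEQ_LENGTH)

print(Task("CLUENER").num_labels)   # 43
print(Task("unknown").seq_length)   # 128
```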
@@ -112,29 +137,15 @@ def run_predistill():
     run predistill
     """
     cfg = phase1_cfg
-    if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    else:
-        raise Exception("Target error, GPU or Ascend is supported.")
-    context.set_context(reserve_class_name_in_scope=False)
     load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
     load_student_checkpoint_path = args_opt.load_gd_ckpt_path
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
-                                         is_training=True, task_type='classification',
+                                         is_training=True, task_type=args_opt.task_type,
                                          num_labels=task.num_labels, is_predistill=True)

     rank = 0
     device_num = 1
-    if args_opt.dataset_type == "tfrecord":
-        dataset_type = DataType.TFRECORD
-    elif args_opt.dataset_type == "mindrecord":
-        dataset_type = DataType.MINDRECORD
-    else:
-        raise Exception("dataset format is not supported yet")
     dataset = create_tinybert_dataset('td', cfg.batch_size,
                                       device_num, rank, args_opt.do_shuffle,
                                       args_opt.train_data_dir, args_opt.schema_dir,
@@ -190,25 +201,19 @@ def run_task_distill(ckpt_file):
         raise ValueError("Student ckpt file should not be None")

     cfg = phase2_cfg
-    if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    else:
-        raise Exception("Target error, GPU or Ascend is supported.")
     load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
     load_student_checkpoint_path = ckpt_file
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
-                                         is_training=True, task_type='classification',
+                                         is_training=True, task_type=args_opt.task_type,
                                          num_labels=task.num_labels, is_predistill=False)

     rank = 0
     device_num = 1
     train_dataset = create_tinybert_dataset('td', cfg.batch_size,
                                             device_num, rank, args_opt.do_shuffle,
-                                            args_opt.train_data_dir, args_opt.schema_dir)
+                                            args_opt.train_data_dir, args_opt.schema_dir,
+                                            data_type=dataset_type)

     dataset_size = train_dataset.get_dataset_size()
     print('td2 train dataset size: ', dataset_size)
@@ -238,7 +243,8 @@ def run_task_distill(ckpt_file):
     eval_dataset = create_tinybert_dataset('td', eval_cfg.batch_size,
                                            device_num, rank, args_opt.do_shuffle,
-                                           args_opt.eval_data_dir, args_opt.schema_dir)
+                                           args_opt.eval_data_dir, args_opt.schema_dir,
+                                           data_type=dataset_type)
     print('td2 eval dataset size: ', eval_dataset.get_dataset_size())

     if args_opt.do_eval.lower() == "true":
@@ -263,6 +269,19 @@ def run_task_distill(ckpt_file):
                   dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                   sink_size=args_opt.data_sink_steps)

+
+def eval_result_print(assessment_method="accuracy", callback=None):
+    """print eval result"""
+    if assessment_method == "accuracy":
+        print("============== acc is {}".format(callback.acc_num / callback.total_num))
+    elif assessment_method == "bf1":
+        print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
+        print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
+        print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN)))
+    elif assessment_method == "mf1":
+        print("F1 {:.6f} ".format(callback.eval()))
+    else:
+        raise ValueError("Assessment method not supported, support: [accuracy, bf1, mf1]")
+
 def do_eval_standalone():
     """
     do eval standalone
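
As a numeric sanity check of the `bf1` branch in `eval_result_print` above, with assumed counts rather than output from a real run: precision is TP/(TP+FP), recall is TP/(TP+FN), and F1 is 2·TP/(2·TP+FP+FN).

```python
# Assumed counts, for illustration only (not from an actual evaluation).
TP, FP, FN = 8, 2, 4
print("Precision {:.6f}".format(TP / (TP + FP)))        # 0.800000
print("Recall {:.6f}".format(TP / (TP + FN)))           # 0.666667
print("F1 {:.6f}".format(2 * TP / (2 * TP + FP + FN)))  # 0.727273
```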
@@ -270,13 +289,12 @@ do_eval_standalone():
     ckpt_file = args_opt.load_td1_ckpt_path
     if ckpt_file == '':
         raise ValueError("Student ckpt file should not be None")
-    if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
+    if args_opt.task_type == "classification":
+        eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+    elif args_opt.task_type == "ner":
+        eval_model = BertModelNER(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
     else:
-        raise Exception("Target error, GPU or Ascend is supported.")
-    eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+        raise ValueError(f"Not support the task type {args_opt.task_type}")
     param_dict = load_checkpoint(ckpt_file)
     new_param_dict = {}
     for key, value in param_dict.items():
@@ -289,11 +307,18 @@ def do_eval_standalone():
     eval_dataset = create_tinybert_dataset('td', batch_size=eval_cfg.batch_size,
                                            device_num=1, rank=0, do_shuffle="false",
                                            data_dir=args_opt.eval_data_dir,
-                                           schema_dir=args_opt.schema_dir)
+                                           schema_dir=args_opt.schema_dir,
+                                           data_type=dataset_type)
     print('eval dataset size: ', eval_dataset.get_dataset_size())
     print('eval dataset batch size: ', eval_dataset.get_batch_size())

-    callback = Accuracy()
+    if args_opt.assessment_method == "accuracy":
+        callback = Accuracy()
+    elif args_opt.assessment_method == "bf1":
+        callback = F1(num_labels=task.num_labels)
+    elif args_opt.assessment_method == "mf1":
+        callback = F1(num_labels=task.num_labels, mode="MultiLabel")
+    else:
+        raise ValueError("Assessment method not supported, support: [accuracy, bf1, mf1]")
     columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
     for data in eval_dataset.create_dict_iterator(num_epochs=1):
         input_data = []
@@ -302,16 +327,16 @@ def do_eval_standalone():
         input_ids, input_mask, token_type_id, label_ids = input_data
         logits = eval_model(input_ids, token_type_id, input_mask)
         callback.update(logits, label_ids)
-    acc = callback.acc_num / callback.total_num
-    print("======================================")
-    print("============== acc is {}".format(acc))
-    print("======================================")
+    print("==============================================================")
+    eval_result_print(args_opt.assessment_method, callback)
+    print("==============================================================")


 if __name__ == '__main__':
-    if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true":
-        raise ValueError("do_train or do eval must have one be true, please confirm your config")
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
+                        reserve_class_name_in_scope=False)
+    if args_opt.device_target == "Ascend":
+        context.set_context(device_id=args_opt.device_id)

     enable_loss_scale = True
     if args_opt.device_target == "GPU":
         if td_student_net_cfg.compute_type != mstype.float32:
@@ -321,6 +346,14 @@ if __name__ == '__main__':
             # and the loss scale is not necessary
             enable_loss_scale = False

+    if args_opt.device_target == "CPU":
+        logger.warning('CPU only support float32 temporarily, run with float32.')
+        td_teacher_net_cfg.dtype = mstype.float32
+        td_teacher_net_cfg.compute_type = mstype.float32
+        td_student_net_cfg.dtype = mstype.float32
+        td_student_net_cfg.compute_type = mstype.float32
+        enable_loss_scale = False
+
     td_teacher_net_cfg.seq_length = task.seq_length
     td_student_net_cfg.seq_length = task.seq_length
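
The same four-line float32 override now appears twice: in run_general_distill.py for `bert_*_net_cfg` and here for `td_*_net_cfg`. A hypothetical helper that captures the pattern, offered purely as a sketch and not part of the patch:

```python
import mindspore.common.dtype as mstype

def force_fp32(*net_cfgs):
    """Force the given network configs to float32 and return the new
    enable_loss_scale value. Loss scaling exists to keep float16 gradients
    from underflowing, so it is unnecessary once everything runs in float32."""
    for cfg in net_cfgs:
        cfg.dtype = mstype.float32
        cfg.compute_type = mstype.float32
    return False

# Hypothetical usage at the CPU branch above:
# enable_loss_scale = force_fp32(td_teacher_net_cfg, td_student_net_cfg)
```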

--- a/scripts/run_standalone_td.sh
+++ b/scripts/run_standalone_td.sh
@@ -32,7 +32,6 @@ python ${PROJECT_DIR}/../run_task_distill.py \
     --do_eval="true" \
     --td_phase1_epoch_size=10 \
     --td_phase2_epoch_size=3 \
-    --task_name="" \
     --do_shuffle="true" \
     --enable_data_sink="true" \
     --data_sink_steps=100 \
@@ -44,5 +43,7 @@ python ${PROJECT_DIR}/../run_task_distill.py \
     --train_data_dir="" \
     --eval_data_dir="" \
     --schema_dir="" \
-    --dataset_type="tfrecord" > log.txt 2>&1 &
+    --dataset_type="tfrecord" \
+    --task_type="classification" \
+    --task_name="" \
+    --assessment_method="accuracy" > log.txt 2>&1 &

--- a/src/assessment_method.py
+++ b/src/assessment_method.py
@@ -32,23 +32,56 @@ class Accuracy():
         self.total_num += len(labels)


 class F1():
-    """F1"""
-    def __init__(self):
+    '''
+    calculate F1 score
+    '''
+    def __init__(self, num_labels=2, mode="Binary"):
         self.TP = 0
         self.FP = 0
         self.FN = 0
+        self.num_labels = num_labels
+        self.P = 0
+        self.AP = 0
+        self.mode = mode
+        if self.mode.lower() not in ("binary", "multilabel"):
+            raise ValueError("Assessment mode not supported, support: [Binary, MultiLabel]")

     def update(self, logits, labels):
-        """Update F1 score"""
+        '''
+        update F1 score
+        '''
         labels = labels.asnumpy()
         labels = np.reshape(labels, -1)
         logits = logits.asnumpy()
         logit_id = np.argmax(logits, axis=-1)
         logit_id = np.reshape(logit_id, -1)
-        pos_eva = np.isin(logit_id, [2, 3, 4, 5, 6, 7])
-        pos_label = np.isin(labels, [2, 3, 4, 5, 6, 7])
-        self.TP += np.sum(pos_eva & pos_label)
-        self.FP += np.sum(pos_eva & (~pos_label))
-        self.FN += np.sum((~pos_eva) & pos_label)
-        print("-----------------precision is ", self.TP / (self.TP + self.FP))
-        print("-----------------recall is ", self.TP / (self.TP + self.FN))
+        if self.mode.lower() == "binary":
+            pos_eva = np.isin(logit_id, [i for i in range(1, self.num_labels)])
+            pos_label = np.isin(labels, [i for i in range(1, self.num_labels)])
+            self.TP += np.sum(pos_eva & pos_label)
+            self.FP += np.sum(pos_eva & (~pos_label))
+            self.FN += np.sum((~pos_eva) & pos_label)
+        else:
+            target = np.zeros((len(labels), self.num_labels), dtype=int)
+            pred = np.zeros((len(logit_id), self.num_labels), dtype=int)
+            for i, label in enumerate(labels):
+                target[i][label] = 1
+            for i, label in enumerate(logit_id):
+                pred[i][label] = 1
+            positives = pred.sum(axis=0)
+            actual_positives = target.sum(axis=0)
+            true_positives = (target * pred).sum(axis=0)
+            self.TP += true_positives
+            self.P += positives
+            self.AP += actual_positives
+
+    def eval(self):
+        if self.mode.lower() == "binary":
+            f1 = 2 * self.TP / (2 * self.TP + self.FP + self.FN)
+        else:
+            tp = np.sum(self.TP)
+            p = np.sum(self.P)
+            ap = np.sum(self.AP)
+            f1 = 2 * tp / (ap + p)
+        return f1
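
The MultiLabel branch accumulates per-class counts in `update()` and reduces them to a micro-averaged F1 in `eval()`. A self-contained numpy sketch of that computation, using toy values (assumed, for illustration only):

```python
import numpy as np

# Toy batch: 6 flattened token labels over 4 classes (values assumed).
labels = np.array([0, 1, 2, 2, 3, 0])    # ground truth
logit_id = np.array([0, 1, 2, 3, 3, 1])  # argmax of the logits
num_labels = 4

# One-hot accumulation, mirroring F1.update() with mode="MultiLabel".
target = np.zeros((len(labels), num_labels), dtype=int)
pred = np.zeros((len(logit_id), num_labels), dtype=int)
target[np.arange(len(labels)), labels] = 1
pred[np.arange(len(logit_id)), logit_id] = 1

TP = (target * pred).sum(axis=0)  # true positives per class
P = pred.sum(axis=0)              # predicted positives per class
AP = target.sum(axis=0)           # actual positives per class

# Micro-averaged F1, mirroring F1.eval(): 2 * sum(TP) / (sum(AP) + sum(P)).
print(2 * TP.sum() / (AP.sum() + P.sum()))  # 0.666... (4 of 6 tokens correct)
```

For single-label token classification, micro F1 computed this way equals plain token accuracy (every token contributes exactly one predicted and one gold label), so the per-class accumulation only pays off when the counts are inspected class by class.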

--- a/src/tinybert_for_gd_td.py
+++ b/src/tinybert_for_gd_td.py
@@ -28,7 +28,7 @@ from mindspore.communication.management import get_group_size
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from .tinybert_model import BertModel, TinyBertModel, BertModelCLS
+from .tinybert_model import BertModel, TinyBertModel, BertModelCLS, BertModelNER

 GRADIENT_CLIP_TYPE = 1
@@ -362,8 +362,18 @@ class BertNetworkWithLoss_td(nn.Cell):
                  temperature=1.0, dropout_prob=0.1):
         super(BertNetworkWithLoss_td, self).__init__()
         # load teacher model
-        self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob,
-                                    use_one_hot_embeddings, "teacher")
+        if task_type == "classification":
+            self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob,
+                                        use_one_hot_embeddings, "teacher")
+            self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
+                                     use_one_hot_embeddings, "student")
+        elif task_type == "ner":
+            self.teacher = BertModelNER(teacher_config, False, num_labels, dropout_prob,
+                                        use_one_hot_embeddings, "teacher")
+            self.bert = BertModelNER(student_config, is_training, num_labels, dropout_prob,
+                                     use_one_hot_embeddings, "student")
+        else:
+            raise ValueError(f"Not support task type: {task_type}")
         param_dict = load_checkpoint(teacher_ckpt)
         new_param_dict = {}
         for key, value in param_dict.items():
@@ -377,8 +387,6 @@ class BertNetworkWithLoss_td(nn.Cell):
         for param in params:
             param.requires_grad = False
         # load student model
-        self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
-                                 use_one_hot_embeddings, "student")
         param_dict = load_checkpoint(student_ckpt)
         if is_predistill:
             new_param_dict = {}
@@ -401,7 +409,7 @@ class BertNetworkWithLoss_td(nn.Cell):
         self.is_predistill = is_predistill
         self.is_att_fit = is_att_fit
         self.is_rep_fit = is_rep_fit
-        self.task_type = task_type
+        self.use_soft_cross_entropy = task_type in ["classification", "ner"]
         self.temperature = temperature
         self.loss_mse = nn.MSELoss()
         self.select = P.Select()
@@ -456,7 +464,7 @@ class BertNetworkWithLoss_td(nn.Cell):
                 rep_loss += self.loss_mse(student_rep, teacher_rep)
             total_loss += rep_loss
         else:
-            if self.task_type == "classification":
+            if self.use_soft_cross_entropy:
                 cls_loss = self.soft_cross_entropy(student_logits / self.temperature, teacher_logits / self.temperature)
             else:
                 cls_loss = self.loss_mse(student_logits[len(student_logits) - 1], label_ids[len(label_ids) - 1])
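
The `soft_cross_entropy` cell itself is defined elsewhere in the file and is untouched by this patch; renaming the flag to `use_soft_cross_entropy` just makes it explicit that both classification and ner now distill through it. A numpy sketch of the standard soft-target cross entropy it is assumed to implement:

```python
import numpy as np

def soft_cross_entropy(student_logits, teacher_logits):
    """Cross entropy between teacher and student output distributions
    (assumed form of the network's SoftCrossEntropy cell)."""
    def log_softmax(x):
        x = x - x.max(axis=-1, keepdims=True)          # numerically stable
        return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))
    teacher_prob = np.exp(log_softmax(teacher_logits))
    # Batch mean of -sum_c p_teacher(c) * log p_student(c).
    return -(teacher_prob * log_softmax(student_logits)).sum(axis=-1).mean()
```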

--- a/src/tinybert_model.py
+++ b/src/tinybert_model.py
@@ -926,3 +926,40 @@ class BertModelCLS(nn.Cell):
         if self._phase == 'train' or self.phase_type == "teacher":
             return seq_output, att_output, logits, log_probs
         return log_probs
+
+
+class BertModelNER(nn.Cell):
+    """
+    This class is responsible for sequence labeling task evaluation, i.e. NER(num_labels=11).
+    The returned output represents the final logits as the results of log_softmax is proportional to that of softmax.
+    """
+    def __init__(self, config, is_training, num_labels=11, dropout_prob=0.0,
+                 use_one_hot_embeddings=False, phase_type="student"):
+        super(BertModelNER, self).__init__()
+        if not is_training:
+            config.hidden_dropout_prob = 0.0
+            config.hidden_probs_dropout_prob = 0.0
+        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
+        self.cast = P.Cast()
+        self.weight_init = TruncatedNormal(config.initializer_range)
+        self.log_softmax = P.LogSoftmax(axis=-1)
+        self.dtype = config.dtype
+        self.num_labels = num_labels
+        self.phase_type = phase_type
+        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
+                                has_bias=True).to_float(config.compute_type)
+        self.dropout = nn.ReLU()
+        self.reshape = P.Reshape()
+        self.shape = (-1, config.hidden_size)
+        self.origin_shape = (-1, config.seq_length, self.num_labels)
+
+    def construct(self, input_ids, input_mask, token_type_id):
+        """Return the final logits as the results of log_softmax."""
+        sequence_output, _, _, encoder_outputs, attention_outputs = \
+            self.bert(input_ids, token_type_id, input_mask)
+        seq = self.dropout(sequence_output)
+        seq = self.reshape(seq, self.shape)
+        logits = self.dense_1(seq)
+        logits = self.cast(logits, self.dtype)
+        return_value = self.log_softmax(logits)
+        if self._phase == 'train' or self.phase_type == "teacher":
+            return encoder_outputs, attention_outputs, logits, return_value
+        return return_value
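
To make the reshaping in `construct` concrete, a numpy walk-through of the tensor shapes with assumed CLUENER-sized dimensions (the matrix product stands in for `dense_1`):

```python
import numpy as np

batch_size, seq_length, hidden_size, num_labels = 32, 128, 384, 43  # assumed sizes
sequence_output = np.zeros((batch_size, seq_length, hidden_size))

seq = sequence_output.reshape(-1, hidden_size)   # (4096, 384): one row per token
weight = np.zeros((num_labels, hidden_size))     # stand-in for dense_1's weight matrix
logits = seq @ weight.T                          # (4096, 43): per-token label scores
# origin_shape = (-1, seq_length, num_labels) would restore the batch layout:
print(logits.reshape(-1, seq_length, num_labels).shape)  # (32, 128, 43)
```

Note that `construct` returns the flattened `(batch * seq, num_labels)` log-probabilities; `self.origin_shape` is defined but not applied in the code shown, so callers wanting the per-sequence layout would reshape the output themselves.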