From: @zhao_ting_v
Reviewed-by: @wuxuejian, @c_34
Signed-off-by: @wuxuejian
pull/13903/MERGE
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
 # Contents
 - [Contents](#contents)
 - [TinyBERT Description](#tinybert-description)
@@ -197,8 +197,9 @@ usage: run_general_task.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN]
                            [--load_gd_ckpt_path LOAD_GD_CKPT_PATH]
                            [--load_td1_ckpt_path LOAD_TD1_CKPT_PATH]
                            [--train_data_dir TRAIN_DATA_DIR]
-                           [--eval_data_dir EVAL_DATA_DIR]
+                           [--eval_data_dir EVAL_DATA_DIR] [--task_type TASK_TYPE]
                            [--task_name TASK_NAME] [--schema_dir SCHEMA_DIR] [--dataset_type DATASET_TYPE]
+                           [--assessment_method ASSESSMENT_METHOD]
 options:
     --device_target            device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
@@ -217,7 +218,9 @@ options:
     --load_td1_ckpt_path       path to load checkpoint files produced by task distill phase 1: PATH, default is ""
     --train_data_dir           path to train dataset directory: PATH, default is ""
     --eval_data_dir            path to eval dataset directory: PATH, default is ""
-    --task_name                classification task: "SST-2" | "QNLI" | "MNLI", default is ""
+    --task_type                task type: "classification" | "ner", default is "classification"
+    --task_name                classification or ner task: "SST-2" | "QNLI" | "MNLI" | "TNEWS" | "CLUENER", default is ""
+    --assessment_method        assessment method to do evaluation: acc | f1
     --schema_dir               path to schema.json file, PATH, default is ""
     --dataset_type             the dataset type which can be tfrecord/mindrecord, default is tfrecord
 ```
@@ -249,6 +252,7 @@ Parameters for optimizer:
 Parameters for bert network:
     seq_length                      length of input sequence: N, default is 128
-    vocab_size                      size of each embedding vector: N, must be consistent with the dataset you use. Default is 30522
+    vocab_size                      size of each embedding vector: N, must be consistent with the dataset you use.
+                                    Usually, we use 21128 for CN vocabs and 30522 for EN vocabs according to the original paper. Default is 30522
     hidden_size                     size of bert encoder layers: N
     num_hidden_layers               number of hidden layers: N
     num_attention_heads             number of attention heads: N, default is 12
--- a/export.py
+++ b/export.py
@@ -22,7 +22,7 @@ from mindspore import Tensor, context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
 from src.td_config import td_student_net_cfg
-from src.tinybert_model import BertModelCLS
+from src.tinybert_model import BertModelCLS, BertModelNER

 parser = argparse.ArgumentParser(description='tinybert task distill')
 parser.add_argument("--device_id", type=int, default=0, help="Device id")
@@ -31,7 +31,10 @@ parser.add_argument("--file_name", type=str, default="tinybert", help="output fi
 parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
 parser.add_argument("--device_target", type=str, default="Ascend",
                     choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
-parser.add_argument('--task_name', type=str, default='SST-2', choices=['SST-2', 'QNLI', 'MNLI'], help='task name')
+parser.add_argument("--task_type", type=str, default="classification", choices=["classification", "ner"],
+                    help="The type of the task to train.")
+parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
+                    help="The name of the task to train.")
 args = parser.parse_args()

 context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
@@ -43,7 +46,9 @@ DEFAULT_SEQ_LENGTH = 128
 DEFAULT_BS = 32
 task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
                "QNLI": {"num_labels": 2, "seq_length": 128},
-               "MNLI": {"num_labels": 3, "seq_length": 128}}
+               "MNLI": {"num_labels": 3, "seq_length": 128},
+               "TNEWS": {"num_labels": 15, "seq_length": 128},
+               "CLUENER": {"num_labels": 10, "seq_length": 128}}

 class Task:
     """
@@ -68,8 +73,13 @@ if __name__ == '__main__':
     task = Task(args.task_name)
     td_student_net_cfg.seq_length = task.seq_length
     td_student_net_cfg.batch_size = DEFAULT_BS
-    eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+    if args.task_type == "classification":
+        eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+    elif args.task_type == "ner":
+        eval_model = BertModelNER(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+    else:
+        raise ValueError(f"Unsupported task type: {args.task_type}")
     param_dict = load_checkpoint(args.ckpt_file)
     new_param_dict = {}
     for key, value in param_dict.items():
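The tail of export.py falls outside this hunk; for context, a minimal sketch of how the export presumably completes once `eval_model` is loaded. The zero-filled dummy inputs are an assumption based on the task's sequence length; `export` is the `mindspore.train.serialization` API already imported above:

```python
import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

# Assumed tail of export.py: dummy inputs fix the graph's input shapes,
# then the loaded student network is serialized to AIR/ONNX/MINDIR.
input_ids = Tensor(np.zeros((DEFAULT_BS, task.seq_length)), mstype.int32)
token_type_id = Tensor(np.zeros((DEFAULT_BS, task.seq_length)), mstype.int32)
input_mask = Tensor(np.zeros((DEFAULT_BS, task.seq_length)), mstype.int32)
export(eval_model, input_ids, token_type_id, input_mask,
       file_name=args.file_name, file_format=args.file_format)
```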
--- a/run_general_distill.py
+++ b/run_general_distill.py
@@ -33,14 +33,10 @@ from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
 from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
 from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell

-def run_general_distill():
-    """
-    run general distill
-    """
+def get_argument():
+    """Tinybert general distill argument parser."""
     parser = argparse.ArgumentParser(description='tinybert general distill')
-    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
+    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
                         help='device where the code will be implemented. (Default: Ascend)')
     parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
                         help="Run distribute, default is false.")
@@ -61,20 +57,21 @@ def run_general_distill():
     parser.add_argument("--dataset_type", type=str, default="tfrecord",
                         help="dataset type tfrecord/mindrecord, default is tfrecord")
     args_opt = parser.parse_args()
+    return args_opt
+
+def run_general_distill():
+    """
+    run general distill
+    """
+    args_opt = get_argument()
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
+                        reserve_class_name_in_scope=False)
     if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    else:
-        raise Exception("Target error, GPU or Ascend is supported.")
-    context.set_context(reserve_class_name_in_scope=False)
+        context.set_context(device_id=args_opt.device_id)

     save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
                                  datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
     if args_opt.distribute == "true":
         if args_opt.device_target == 'Ascend':
             D.init()
@@ -104,6 +101,14 @@ def run_general_distill():
         # and the loss scale is not necessary
         enable_loss_scale = False

+    if args_opt.device_target == "CPU":
+        logger.warning('CPU only supports float32 for now, running with float32.')
+        bert_teacher_net_cfg.dtype = mstype.float32
+        bert_teacher_net_cfg.compute_type = mstype.float32
+        bert_student_net_cfg.dtype = mstype.float32
+        bert_student_net_cfg.compute_type = mstype.float32
+        enable_loss_scale = False
+
     netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
                                          teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                          student_config=bert_student_net_cfg,
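Dropping the loss scale on CPU is safe because everything then runs in float32, where gradient underflow is not a concern. A minimal sketch of how `enable_loss_scale` presumably selects the training cell later in `run_general_distill()`, using the `BertTrainWithLossScaleCell`/`BertTrainCell` classes named in the imports (the scale-update parameters here are assumptions):

```python
from mindspore import nn

if enable_loss_scale:
    # fp16 runs need dynamic loss scaling to keep small gradients from underflowing
    update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 16,
                                                scale_factor=2, scale_window=1000)
    netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer,
                                              scale_update_cell=update_cell)
else:
    # pure float32 (GPU/CPU): a plain train cell is enough
    netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)
```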
--- a/run_task_distill.py
+++ b/run_task_distill.py
@@ -28,10 +28,10 @@ from mindspore.nn.optim import AdamWeightDecay
 from mindspore import log as logger
 from src.dataset import create_tinybert_dataset, DataType
 from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
-from src.assessment_method import Accuracy
+from src.assessment_method import Accuracy, F1
 from src.td_config import phase1_cfg, phase2_cfg, eval_cfg, td_teacher_net_cfg, td_student_net_cfg
 from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
-from src.tinybert_model import BertModelCLS
+from src.tinybert_model import BertModelCLS, BertModelNER

 _cur_dir = os.getcwd()
 td_phase1_save_ckpt_dir = os.path.join(_cur_dir, 'tinybert_td_phase1_save_ckpt')
@@ -46,7 +46,7 @@ def parse_args():
     parse args
     """
     parser = argparse.ArgumentParser(description='tinybert task distill')
-    parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'],
+    parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                         help='device where the code will be implemented. (Default: Ascend)')
     parser.add_argument("--do_train", type=str, default="true", choices=["true", "false"],
                         help="Do train task, default is true.")
@@ -69,21 +69,46 @@ def parse_args():
     parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
-    parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
+    parser.add_argument("--task_type", type=str, default="classification", choices=["classification", "ner"],
+                        help="The type of the task to train.")
+    parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI", "TNEWS", "CLUENER"],
                         help="The name of the task to train.")
+    parser.add_argument("--assessment_method", type=str, default="accuracy", choices=["accuracy", "bf1", "mf1"],
+                        help="assessment_method include: [accuracy, bf1, mf1], default is accuracy")
     parser.add_argument("--dataset_type", type=str, default="tfrecord",
                         help="dataset type tfrecord/mindrecord, default is tfrecord")
     args = parser.parse_args()
+    if args.do_train.lower() != "true" and args.do_eval.lower() != "true":
+        raise ValueError("At least one of do_train and do_eval must be true, please check your config")
+    if args.task_name in ["SST-2", "QNLI", "MNLI", "TNEWS"] and args.task_type != "classification":
+        raise ValueError(f"{args.task_name} is a classification dataset, please set --task_type=classification")
+    if args.task_name in ["CLUENER"] and args.task_type != "ner":
+        raise ValueError(f"{args.task_name} is a ner dataset, please set --task_type=ner")
+    if args.task_name in ["SST-2", "QNLI", "MNLI"] and \
+            (td_teacher_net_cfg.vocab_size != 30522 or td_student_net_cfg.vocab_size != 30522):
+        logger.warning(f"{args.task_name} is an English dataset. Usually, we use 21128 for CN vocabs and 30522 "
+                       "for EN vocabs according to the original paper.")
+    if args.task_name in ["TNEWS", "CLUENER"] and \
+            (td_teacher_net_cfg.vocab_size != 21128 or td_student_net_cfg.vocab_size != 21128):
+        logger.warning(f"{args.task_name} is a Chinese dataset. Usually, we use 21128 for CN vocabs and 30522 "
+                       "for EN vocabs according to the original paper.")
     return args

 args_opt = parse_args()
+
+if args_opt.dataset_type == "tfrecord":
+    dataset_type = DataType.TFRECORD
+elif args_opt.dataset_type == "mindrecord":
+    dataset_type = DataType.MINDRECORD
+else:
+    raise ValueError("dataset format is not supported yet")
+
 DEFAULT_NUM_LABELS = 2
 DEFAULT_SEQ_LENGTH = 128
 task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
                "QNLI": {"num_labels": 2, "seq_length": 128},
-               "MNLI": {"num_labels": 3, "seq_length": 128}}
+               "MNLI": {"num_labels": 3, "seq_length": 128},
+               "TNEWS": {"num_labels": 15, "seq_length": 128},
+               "CLUENER": {"num_labels": 43, "seq_length": 128}}

 class Task:
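The Task class itself is unchanged by this PR and elided from the diff; a minimal sketch of what it presumably provides, inferred from the `task.num_labels`/`task.seq_length` accesses below (the fallback to `DEFAULT_NUM_LABELS`/`DEFAULT_SEQ_LENGTH` for tasks missing from `task_params` is an assumption):

```python
class Task:
    """Lookup of per-task hyperparameters, with defaults for unknown tasks."""
    def __init__(self, task_name):
        self.task_name = task_name

    @property
    def num_labels(self):
        if self.task_name in task_params and "num_labels" in task_params[self.task_name]:
            return task_params[self.task_name]["num_labels"]
        return DEFAULT_NUM_LABELS

    @property
    def seq_length(self):
        if self.task_name in task_params and "seq_length" in task_params[self.task_name]:
            return task_params[self.task_name]["seq_length"]
        return DEFAULT_SEQ_LENGTH
```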
@@ -112,29 +137,15 @@ def run_predistill():
     run predistill
     """
     cfg = phase1_cfg
-    if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    else:
-        raise Exception("Target error, GPU or Ascend is supported.")
-    context.set_context(reserve_class_name_in_scope=False)
     load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
     load_student_checkpoint_path = args_opt.load_gd_ckpt_path
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
-                                         is_training=True, task_type='classification',
+                                         is_training=True, task_type=args_opt.task_type,
                                          num_labels=task.num_labels, is_predistill=True)
     rank = 0
     device_num = 1
-    if args_opt.dataset_type == "tfrecord":
-        dataset_type = DataType.TFRECORD
-    elif args_opt.dataset_type == "mindrecord":
-        dataset_type = DataType.MINDRECORD
-    else:
-        raise Exception("dataset format is not supported yet")
     dataset = create_tinybert_dataset('td', cfg.batch_size,
                                       device_num, rank, args_opt.do_shuffle,
                                       args_opt.train_data_dir, args_opt.schema_dir,
@@ -190,25 +201,19 @@ def run_task_distill(ckpt_file):
         raise ValueError("Student ckpt file should not be None")
     cfg = phase2_cfg
-    if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    else:
-        raise Exception("Target error, GPU or Ascend is supported.")
     load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
     load_student_checkpoint_path = ckpt_file
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
-                                         is_training=True, task_type='classification',
+                                         is_training=True, task_type=args_opt.task_type,
                                          num_labels=task.num_labels, is_predistill=False)
     rank = 0
     device_num = 1
     train_dataset = create_tinybert_dataset('td', cfg.batch_size,
                                             device_num, rank, args_opt.do_shuffle,
-                                            args_opt.train_data_dir, args_opt.schema_dir)
+                                            args_opt.train_data_dir, args_opt.schema_dir,
+                                            data_type=dataset_type)
     dataset_size = train_dataset.get_dataset_size()
     print('td2 train dataset size: ', dataset_size)
@@ -238,7 +243,8 @@ def run_task_distill(ckpt_file):
     eval_dataset = create_tinybert_dataset('td', eval_cfg.batch_size,
                                            device_num, rank, args_opt.do_shuffle,
-                                           args_opt.eval_data_dir, args_opt.schema_dir)
+                                           args_opt.eval_data_dir, args_opt.schema_dir,
+                                           data_type=dataset_type)
     print('td2 eval dataset size: ', eval_dataset.get_dataset_size())
     if args_opt.do_eval.lower() == "true":
@@ -263,6 +269,19 @@ def run_task_distill(ckpt_file):
                     dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
                     sink_size=args_opt.data_sink_steps)

+def eval_result_print(assessment_method="accuracy", callback=None):
+    """print eval result"""
+    if assessment_method == "accuracy":
+        print("============== acc is {}".format(callback.acc_num / callback.total_num))
+    elif assessment_method == "bf1":
+        print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
+        print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
+        print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN)))
+    elif assessment_method == "mf1":
+        print("F1 {:.6f} ".format(callback.eval()))
+    else:
+        raise ValueError("Assessment method not supported, support: [accuracy, bf1, mf1]")
+
 def do_eval_standalone():
     """
     do eval standalone
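As a quick check of the bf1 arithmetic above on hypothetical counts, confirming the printed F1 agrees with the usual 2PR/(P+R) form:

```python
# Hypothetical confusion counts from one eval run.
TP, FP, FN = 80, 20, 10
precision = TP / (TP + FP)              # 0.800000
recall = TP / (TP + FN)                 # 0.888889
f1 = 2 * TP / (2 * TP + FP + FN)        # 0.842105 == 2*precision*recall/(precision+recall)
print("Precision {:.6f} Recall {:.6f} F1 {:.6f}".format(precision, recall, f1))
```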
@@ -270,13 +289,12 @@ def do_eval_standalone():
     ckpt_file = args_opt.load_td1_ckpt_path
     if ckpt_file == '':
         raise ValueError("Student ckpt file should not be None")
-    if args_opt.device_target == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    elif args_opt.device_target == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
+    if args_opt.task_type == "classification":
+        eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+    elif args_opt.task_type == "ner":
+        eval_model = BertModelNER(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
     else:
-        raise Exception("Target error, GPU or Ascend is supported.")
-    eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
+        raise ValueError(f"Unsupported task type: {args_opt.task_type}")
     param_dict = load_checkpoint(ckpt_file)
     new_param_dict = {}
     for key, value in param_dict.items():
@@ -289,11 +307,18 @@ def do_eval_standalone():
     eval_dataset = create_tinybert_dataset('td', batch_size=eval_cfg.batch_size,
                                            device_num=1, rank=0, do_shuffle="false",
                                            data_dir=args_opt.eval_data_dir,
-                                           schema_dir=args_opt.schema_dir)
+                                           schema_dir=args_opt.schema_dir,
+                                           data_type=dataset_type)
     print('eval dataset size: ', eval_dataset.get_dataset_size())
     print('eval dataset batch size: ', eval_dataset.get_batch_size())
-    callback = Accuracy()
+    if args_opt.assessment_method == "accuracy":
+        callback = Accuracy()
+    elif args_opt.assessment_method == "bf1":
+        callback = F1(num_labels=task.num_labels)
+    elif args_opt.assessment_method == "mf1":
+        callback = F1(num_labels=task.num_labels, mode="MultiLabel")
+    else:
+        raise ValueError("Assessment method not supported, support: [accuracy, bf1, mf1]")
     columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
     for data in eval_dataset.create_dict_iterator(num_epochs=1):
         input_data = []
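For reference, the Accuracy callback this dispatch falls back to is not part of the diff; judging from the `acc_num`/`total_num` attributes used by `eval_result_print` and the `update()` tail visible in the assessment_method.py hunk below, it presumably looks like this:

```python
import numpy as np

class Accuracy():
    """Running count of correct argmax predictions (sketch, not from the diff)."""
    def __init__(self):
        self.acc_num = 0
        self.total_num = 0

    def update(self, logits, labels):
        labels = labels.asnumpy().reshape(-1)
        logit_id = np.argmax(logits.asnumpy(), axis=-1).reshape(-1)
        self.acc_num += np.sum(labels == logit_id)
        self.total_num += len(labels)
```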
@@ -302,16 +327,16 @@ def do_eval_standalone():
         input_ids, input_mask, token_type_id, label_ids = input_data
         logits = eval_model(input_ids, token_type_id, input_mask)
         callback.update(logits, label_ids)
-    acc = callback.acc_num / callback.total_num
-    print("======================================")
-    print("============== acc is {}".format(acc))
-    print("======================================")
+    print("==============================================================")
+    eval_result_print(args_opt.assessment_method, callback)
+    print("==============================================================")

 if __name__ == '__main__':
-    if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true":
-        raise ValueError("do_train or do eval must have one be true, please confirm your config")
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
+                        reserve_class_name_in_scope=False)
+    if args_opt.device_target == "Ascend":
+        context.set_context(device_id=args_opt.device_id)
     enable_loss_scale = True
     if args_opt.device_target == "GPU":
         if td_student_net_cfg.compute_type != mstype.float32:
@@ -321,6 +346,14 @@ if __name__ == '__main__':
         # and the loss scale is not necessary
         enable_loss_scale = False

+    if args_opt.device_target == "CPU":
+        logger.warning('CPU only supports float32 for now, running with float32.')
+        td_teacher_net_cfg.dtype = mstype.float32
+        td_teacher_net_cfg.compute_type = mstype.float32
+        td_student_net_cfg.dtype = mstype.float32
+        td_student_net_cfg.compute_type = mstype.float32
+        enable_loss_scale = False
+
     td_teacher_net_cfg.seq_length = task.seq_length
     td_student_net_cfg.seq_length = task.seq_length
@@ -32,7 +32,6 @@ python ${PROJECT_DIR}/../run_task_distill.py \
     --do_eval="true" \
     --td_phase1_epoch_size=10 \
     --td_phase2_epoch_size=3 \
-    --task_name="" \
     --do_shuffle="true" \
     --enable_data_sink="true" \
     --data_sink_steps=100 \
@@ -44,5 +43,7 @@ python ${PROJECT_DIR}/../run_task_distill.py \
     --train_data_dir="" \
     --eval_data_dir="" \
     --schema_dir="" \
-    --dataset_type="tfrecord" > log.txt 2>&1 &
+    --dataset_type="tfrecord" \
+    --task_type="classification" \
+    --task_name="" \
+    --assessment_method="accuracy" > log.txt 2>&1 &
--- a/src/assessment_method.py
+++ b/src/assessment_method.py
@@ -32,23 +32,56 @@ class Accuracy():
         self.total_num += len(labels)

 class F1():
-    """F1"""
-    def __init__(self):
+    '''
+    calculate F1 score
+    '''
+    def __init__(self, num_labels=2, mode="Binary"):
         self.TP = 0
         self.FP = 0
         self.FN = 0
+        self.num_labels = num_labels
+        self.P = 0
+        self.AP = 0
+        self.mode = mode
+        if self.mode.lower() not in ("binary", "multilabel"):
+            raise ValueError("Assessment mode not supported, support: [Binary, MultiLabel]")

     def update(self, logits, labels):
-        """Update F1 score"""
+        '''
+        update F1 score
+        '''
         labels = labels.asnumpy()
         labels = np.reshape(labels, -1)
         logits = logits.asnumpy()
         logit_id = np.argmax(logits, axis=-1)
         logit_id = np.reshape(logit_id, -1)
-        pos_eva = np.isin(logit_id, [2, 3, 4, 5, 6, 7])
-        pos_label = np.isin(labels, [2, 3, 4, 5, 6, 7])
-        self.TP += np.sum(pos_eva & pos_label)
-        self.FP += np.sum(pos_eva & (~pos_label))
-        self.FN += np.sum((~pos_eva) & pos_label)
-        print("-----------------precision is ", self.TP / (self.TP + self.FP))
-        print("-----------------recall is ", self.TP / (self.TP + self.FN))
+        if self.mode.lower() == "binary":
+            # any label other than 0 ("O") counts as a positive
+            pos_eva = np.isin(logit_id, list(range(1, self.num_labels)))
+            pos_label = np.isin(labels, list(range(1, self.num_labels)))
+            self.TP += np.sum(pos_eva & pos_label)
+            self.FP += np.sum(pos_eva & (~pos_label))
+            self.FN += np.sum((~pos_eva) & pos_label)
+        else:
+            # one-hot encode targets and predictions per label
+            target = np.zeros((len(labels), self.num_labels), dtype=np.int32)
+            pred = np.zeros((len(logit_id), self.num_labels), dtype=np.int32)
+            for i, label in enumerate(labels):
+                target[i][label] = 1
+            for i, label in enumerate(logit_id):
+                pred[i][label] = 1
+            positives = pred.sum(axis=0)
+            actual_positives = target.sum(axis=0)
+            true_positives = (target * pred).sum(axis=0)
+            self.TP += true_positives
+            self.P += positives
+            self.AP += actual_positives
+
+    def eval(self):
+        '''compute the F1 score from the accumulated counts'''
+        if self.mode.lower() == "binary":
+            f1 = 2 * self.TP / (2 * self.TP + self.FP + self.FN)
+        else:
+            # micro-averaged F1: 2 * TP / (predicted positives + actual positives)
+            tp = np.sum(self.TP)
+            p = np.sum(self.P)
+            ap = np.sum(self.AP)
+            f1 = 2 * tp / (ap + p)
+        return f1
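A short usage sketch of the reworked F1 class (the toy logits and labels are hypothetical; mindspore Tensors are used because `update()` calls `asnumpy()`):

```python
import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

# Three tokens over three labels; the model gets two of three right.
logits = Tensor(np.array([[0.1, 2.0, 0.3],     # argmax -> 1 (correct)
                          [1.5, 0.2, 0.1],     # argmax -> 0 (wrong, truth is 2)
                          [0.1, 0.2, 3.0]]),   # argmax -> 2 (correct)
                mstype.float32)
labels = Tensor(np.array([1, 2, 2]), mstype.int32)

metric = F1(num_labels=3, mode="MultiLabel")
metric.update(logits, labels)
print("micro F1: {:.6f}".format(metric.eval()))  # 2*2/(3+3) = 0.666667
```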
--- a/src/tinybert_for_gd_td.py
+++ b/src/tinybert_for_gd_td.py
@@ -28,7 +28,7 @@ from mindspore.communication.management import get_group_size
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from .tinybert_model import BertModel, TinyBertModel, BertModelCLS
+from .tinybert_model import BertModel, TinyBertModel, BertModelCLS, BertModelNER

 GRADIENT_CLIP_TYPE = 1
@@ -362,8 +362,18 @@ class BertNetworkWithLoss_td(nn.Cell):
                  temperature=1.0, dropout_prob=0.1):
         super(BertNetworkWithLoss_td, self).__init__()
         # load teacher model
-        self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob,
-                                    use_one_hot_embeddings, "teacher")
+        if task_type == "classification":
+            self.teacher = BertModelCLS(teacher_config, False, num_labels, dropout_prob,
+                                        use_one_hot_embeddings, "teacher")
+            self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
+                                     use_one_hot_embeddings, "student")
+        elif task_type == "ner":
+            self.teacher = BertModelNER(teacher_config, False, num_labels, dropout_prob,
+                                        use_one_hot_embeddings, "teacher")
+            self.bert = BertModelNER(student_config, is_training, num_labels, dropout_prob,
+                                     use_one_hot_embeddings, "student")
+        else:
+            raise ValueError(f"Unsupported task type: {task_type}")
         param_dict = load_checkpoint(teacher_ckpt)
         new_param_dict = {}
         for key, value in param_dict.items():
@@ -377,8 +387,6 @@ class BertNetworkWithLoss_td(nn.Cell):
         for param in params:
             param.requires_grad = False
         # load student model
-        self.bert = BertModelCLS(student_config, is_training, num_labels, dropout_prob,
-                                 use_one_hot_embeddings, "student")
         param_dict = load_checkpoint(student_ckpt)
         if is_predistill:
             new_param_dict = {}
@@ -401,7 +409,7 @@ class BertNetworkWithLoss_td(nn.Cell):
         self.is_predistill = is_predistill
         self.is_att_fit = is_att_fit
         self.is_rep_fit = is_rep_fit
-        self.task_type = task_type
+        self.use_soft_cross_entropy = task_type in ["classification", "ner"]
         self.temperature = temperature
         self.loss_mse = nn.MSELoss()
         self.select = P.Select()
@@ -456,7 +464,7 @@ class BertNetworkWithLoss_td(nn.Cell):
                 rep_loss += self.loss_mse(student_rep, teacher_rep)
             total_loss += rep_loss
         else:
-            if self.task_type == "classification":
+            if self.use_soft_cross_entropy:
                 cls_loss = self.soft_cross_entropy(student_logits / self.temperature, teacher_logits / self.temperature)
             else:
                 cls_loss = self.loss_mse(student_logits[len(student_logits) - 1], label_ids[len(label_ids) - 1])
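Both task types now distill through the same soft cross entropy between temperature-scaled teacher and student logits. A minimal numpy sketch of what `soft_cross_entropy` presumably computes (inputs arrive already divided by the temperature, as in the call above; the exact reduction used in the repo may differ):

```python
import numpy as np

def soft_cross_entropy_sketch(student_logits, teacher_logits):
    """Cross entropy between softened teacher and student distributions."""
    # log-softmax of the student predictions
    s_log_prob = student_logits - np.log(np.exp(student_logits).sum(axis=-1, keepdims=True))
    # softmax of the teacher targets
    t_prob = np.exp(teacher_logits) / np.exp(teacher_logits).sum(axis=-1, keepdims=True)
    # batch mean of the per-sample cross entropies
    return -(t_prob * s_log_prob).sum(axis=-1).mean()
```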
--- a/src/tinybert_model.py
+++ b/src/tinybert_model.py
@@ -926,3 +926,40 @@ class BertModelCLS(nn.Cell):
         if self._phase == 'train' or self.phase_type == "teacher":
             return seq_output, att_output, logits, log_probs
         return log_probs
+
+class BertModelNER(nn.Cell):
+    """
+    This class is responsible for sequence labeling task evaluation, i.e. NER (num_labels=11).
+    The returned output represents the final logits, as the result of log_softmax is proportional to that of softmax.
+    """
+    def __init__(self, config, is_training, num_labels=11, dropout_prob=0.0,
+                 use_one_hot_embeddings=False, phase_type="student"):
+        super(BertModelNER, self).__init__()
+        if not is_training:
+            config.hidden_dropout_prob = 0.0
+            config.hidden_probs_dropout_prob = 0.0
+        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
+        self.cast = P.Cast()
+        self.weight_init = TruncatedNormal(config.initializer_range)
+        self.log_softmax = P.LogSoftmax(axis=-1)
+        self.dtype = config.dtype
+        self.num_labels = num_labels
+        self.phase_type = phase_type
+        self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
+                                has_bias=True).to_float(config.compute_type)
+        self.dropout = nn.ReLU()
+        self.reshape = P.Reshape()
+        self.shape = (-1, config.hidden_size)
+        self.origin_shape = (-1, config.seq_length, self.num_labels)
+
+    def construct(self, input_ids, input_mask, token_type_id):
+        """Return the final logits as the results of log_softmax."""
+        sequence_output, _, _, encoder_outputs, attention_outputs = \
+            self.bert(input_ids, token_type_id, input_mask)
+        seq = self.dropout(sequence_output)
+        seq = self.reshape(seq, self.shape)
+        logits = self.dense_1(seq)
+        logits = self.cast(logits, self.dtype)
+        return_value = self.log_softmax(logits)
+        if self._phase == 'train' or self.phase_type == "teacher":
+            return encoder_outputs, attention_outputs, logits, return_value
+        return return_value
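To make the tensor flow in `construct()` concrete, a small numpy sketch of the flatten → dense → unflatten path for token classification. The sizes are illustrative (384 is the TinyBERT student hidden size, 43 the CLUENER label count used above), and the random weights stand in for `dense_1`:

```python
import numpy as np

batch, seq_length, hidden, num_labels = 2, 8, 384, 43
sequence_output = np.random.randn(batch, seq_length, hidden)

# flatten tokens so a single dense layer scores every token position
flat = sequence_output.reshape(-1, hidden)               # (16, 384)
weight = np.random.randn(hidden, num_labels) * 0.02      # stand-in for dense_1
logits = flat @ weight                                   # (16, 43), one row per token

# origin_shape recovers the per-sentence layout when it is needed
per_token = logits.reshape(-1, seq_length, num_labels)   # (2, 8, 43)
print(flat.shape, logits.shape, per_token.shape)
```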