@@ -46,7 +46,7 @@ usage: run_standalone_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_T
 options:
     --distribute               whether to run distributely: "true" | "false"
-    --device_target            target device to run, currently only support "Ascend"
+    --device_target            targeted device to run task: "Ascend" | "GPU"
     --epoch_size               epoch size: N, default is 1
     --device_id                device id: N, default is 0
     --enable_data_sink         enable data sink: "true" | "false", default is "true"
@@ -64,7 +64,7 @@ usage: run_distribute_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_T
 options:
     --distribute               whether to run distributely: "true" | "false"
-    --device_target            target device to run, currently only support "Ascend"
+    --device_target            targeted device to run task: "Ascend" | "GPU"
     --epoch_size               epoch size: N, default is 1
     --device_id                device id: N, default is 0
     --device_num               device id to run task
@@ -20,16 +20,20 @@ import argparse
 import datetime
 import numpy
 import mindspore.communication.management as D
+import mindspore.common.dtype as mstype
 from mindspore import context
 from mindspore.train.model import Model
 from mindspore.train.callback import TimeMonitor
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.optim import AdamWeightDecay
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
+from mindspore import log as logger
 from src.dataset import create_tinybert_dataset
 from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
 from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
-from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd
+from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell

 def run_general_distill():
     """
@@ -53,7 +57,6 @@ def run_general_distill():
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
     args_opt = parser.parse_args()
-    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
     context.set_context(reserve_class_name_in_scope=False)
     context.set_context(variable_memory_max_size="30GB")
@@ -61,13 +64,17 @@ def run_general_distill():
     save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
                                  datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
-    if not os.path.exists(save_ckpt_dir):
-        os.makedirs(save_ckpt_dir)
     if args_opt.distribute == "true":
-        D.init('hccl')
-        device_num = args_opt.device_num
-        rank = args_opt.device_id % device_num
+        if args_opt.device_target == 'Ascend':
+            D.init('hccl')
+            device_num = args_opt.device_num
+            rank = args_opt.device_id % device_num
+        else:
+            D.init('nccl')
+            device_num = D.get_group_size()
+            rank = D.get_rank()
+        save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                           device_num=device_num)
@@ -75,6 +82,21 @@ def run_general_distill():
         rank = 0
         device_num = 1

+    if not os.path.exists(save_ckpt_dir):
+        os.makedirs(save_ckpt_dir)
+
+    enable_loss_scale = True
+    if args_opt.device_target == "GPU":
+        if bert_teacher_net_cfg.compute_type != mstype.float32:
+            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            bert_teacher_net_cfg.compute_type = mstype.float32
+        if bert_student_net_cfg.compute_type != mstype.float32:
+            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            bert_student_net_cfg.compute_type = mstype.float32
+        # Both the forward and backward of the network are calculated using fp32,
+        # and the loss scale is not necessary
+        enable_loss_scale = False
+
     netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
                                          teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                          student_config=bert_student_net_cfg,
@@ -82,11 +104,11 @@ def run_general_distill():
     dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank,
                                       args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
     dataset_size = dataset.get_dataset_size()
     print('dataset size: ', dataset_size)
+    print("dataset repeatcount: ", dataset.get_repeat_count())
     if args_opt.enable_data_sink == "true":
-        repeat_count = args_opt.epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
+        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
         time_monitor_steps = args_opt.data_sink_steps
     else:
         repeat_count = args_opt.epoch_size
@@ -110,12 +132,13 @@ def run_general_distill():
                                    args_opt.save_ckpt_step,
                                    args_opt.max_ckpt_num,
                                    save_ckpt_dir)]
-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
-                                             scale_factor=common_cfg.scale_factor,
-                                             scale_window=common_cfg.scale_window)
-    netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    if enable_loss_scale:
+        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
+                                                 scale_factor=common_cfg.scale_factor,
+                                                 scale_window=common_cfg.scale_window)
+        netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    else:
+        netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)
     model = Model(netwithgrads)
     model.train(repeat_count, dataset, callbacks=callback,
                 dataset_sink_mode=(args_opt.enable_data_sink == "true"),
@@ -18,6 +18,7 @@
 import os
 import re
 import argparse
+import mindspore.common.dtype as mstype
 from mindspore import Tensor
 from mindspore import context
 from mindspore.train.model import Model
@@ -25,11 +26,12 @@ from mindspore.train.callback import TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.nn.optim import AdamWeightDecay
+from mindspore import log as logger
 from src.dataset import create_tinybert_dataset
 from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
 from src.assessment_method import Accuracy
 from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg
-from src.tinybert_for_gd_td import BertEvaluationCell, BertNetworkWithLoss_td
+from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
 from src.tinybert_model import BertModelCLS

 _cur_dir = os.getcwd()
@@ -45,14 +47,14 @@ def parse_args():
     parse args
     """
     parser = argparse.ArgumentParser(description='tinybert task distill')
-    parser.add_argument("--device_target", type=str, default="Ascend", help="NPU device, default is Ascend.")
+    parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'],
+                        help='device where the code will be implemented. (Default: Ascend)')
     parser.add_argument("--do_train", type=str, default="true", help="Do train task, default is true.")
     parser.add_argument("--do_eval", type=str, default="true", help="Do eval task, default is true.")
     parser.add_argument("--td_phase1_epoch_size", type=int, default=10,
                         help="Epoch size for td phase 1, default is 10.")
     parser.add_argument("--td_phase2_epoch_size", type=int, default=3, help="Epoch size for td phase 2, default is 3.")
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
-    parser.add_argument("--num_labels", type=int, default=2, help="Classfication task, support SST2, QNLI, MNLI.")
     parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
     parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
     parser.add_argument("--save_ckpt_step", type=int, default=100, help="Enable data sink, default is true.")
@@ -64,11 +66,43 @@ def parse_args():
     parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
+    parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
+                        help="The name of the task to train.")
     args = parser.parse_args()
     return args

 args_opt = parse_args()
+
+DEFAULT_NUM_LABELS = 2
+DEFAULT_SEQ_LENGTH = 128
+task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
+               "QNLI": {"num_labels": 2, "seq_length": 128},
+               "MNLI": {"num_labels": 3, "seq_length": 128}}
+
+
+class Task:
+    """
+    Encapsulation class of get the task parameter.
+    """
+    def __init__(self, task_name):
+        self.task_name = task_name
+
+    @property
+    def num_labels(self):
+        if self.task_name in task_params and "num_labels" in task_params[self.task_name]:
+            return task_params[self.task_name]["num_labels"]
+        return DEFAULT_NUM_LABELS
+
+    @property
+    def seq_length(self):
+        if self.task_name in task_params and "seq_length" in task_params[self.task_name]:
+            return task_params[self.task_name]["seq_length"]
+        return DEFAULT_SEQ_LENGTH
+
+task = Task(args_opt.task_name)
+
 def run_predistill():
     """
     run predistill
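For readers unfamiliar with the new helper, here is a small usage sketch (illustrative only, not part of the patch) of the Task class added above; the values come straight from task_params, and an empty or unknown --task_name falls back to the module defaults:

    # Illustrative sketch, assuming the Task class and defaults defined above are in scope.
    task = Task("MNLI")
    print(task.num_labels)    # 3
    print(task.seq_length)    # 128
    task = Task("SST-2")
    print(task.seq_length)    # 64
    task = Task("")           # empty --task_name: fall back to defaults
    print(task.num_labels)    # DEFAULT_NUM_LABELS == 2
    print(task.seq_length)    # DEFAULT_SEQ_LENGTH == 128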
@@ -81,7 +115,7 @@ def run_predistill():
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                          is_training=True, task_type='classification',
-                                         num_labels=args_opt.num_labels, is_predistill=True)
+                                         num_labels=task.num_labels, is_predistill=True)

     rank = 0
     device_num = 1
@@ -91,8 +125,9 @@ def run_predistill():
     dataset_size = dataset.get_dataset_size()
     print('td1 dataset size: ', dataset_size)
+    print('td1 dataset repeatcount: ', dataset.get_repeat_count())
     if args_opt.enable_data_sink == 'true':
-        repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
+        repeat_count = args_opt.td_phase1_epoch_size * dataset_size // args_opt.data_sink_steps
         time_monitor_steps = args_opt.data_sink_steps
     else:
         repeat_count = args_opt.td_phase1_epoch_size
@@ -117,10 +152,14 @@ def run_predistill():
                                    args_opt.save_ckpt_step,
                                    args_opt.max_ckpt_num,
                                    td_phase1_save_ckpt_dir)]
-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
-                                             scale_factor=cfg.scale_factor,
-                                             scale_window=cfg.scale_window)
-    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    if enable_loss_scale:
+        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
+                                                 scale_factor=cfg.scale_factor,
+                                                 scale_window=cfg.scale_window)
+        netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    else:
+        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)
     model = Model(netwithgrads)
     model.train(repeat_count, dataset, callbacks=callback,
                 dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
@@ -139,7 +178,7 @@ def run_task_distill(ckpt_file):
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                          is_training=True, task_type='classification',
-                                         num_labels=args_opt.num_labels, is_predistill=False)
+                                         num_labels=task.num_labels, is_predistill=False)

     rank = 0
     device_num = 1
@@ -149,6 +188,7 @@ def run_task_distill(ckpt_file):
     dataset_size = train_dataset.get_dataset_size()
     print('td2 train dataset size: ', dataset_size)
+    print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count())
     if args_opt.enable_data_sink == 'true':
         repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps
         time_monitor_steps = args_opt.data_sink_steps
@@ -175,6 +215,7 @@ def run_task_distill(ckpt_file):
     eval_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                            device_num, rank, args_opt.do_shuffle,
                                            args_opt.eval_data_dir, args_opt.schema_dir)
+    print('td2 eval dataset size: ', eval_dataset.get_dataset_size())

     if args_opt.do_eval.lower() == "true":
         callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
@@ -185,11 +226,14 @@ def run_task_distill(ckpt_file):
                                    args_opt.save_ckpt_step,
                                    args_opt.max_ckpt_num,
                                    td_phase2_save_ckpt_dir)]
-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
-                                             scale_factor=cfg.scale_factor,
-                                             scale_window=cfg.scale_window)
-    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    if enable_loss_scale:
+        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
+                                                 scale_factor=cfg.scale_factor,
+                                                 scale_window=cfg.scale_window)
+        netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    else:
+        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)
     model = Model(netwithgrads)
     model.train(repeat_count, train_dataset, callbacks=callback,
                 dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
@@ -203,7 +247,7 @@ def do_eval_standalone():
     if ckpt_file == '':
         raise ValueError("Student ckpt file should not be None")
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    eval_model = BertModelCLS(td_student_net_cfg, False, args_opt.num_labels, 0.0, phase_type="student")
+    eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
     param_dict = load_checkpoint(ckpt_file)
     new_param_dict = {}
     for key, value in param_dict.items():
@@ -213,10 +257,13 @@ def do_eval_standalone():
     load_param_into_net(eval_model, new_param_dict)
     eval_model.set_train(False)

-    eval_dataset = create_tinybert_dataset('td', batch_size=1,
+    eval_dataset = create_tinybert_dataset('td', batch_size=td_student_net_cfg.batch_size,
                                            device_num=1, rank=0, do_shuffle="false",
                                            data_dir=args_opt.eval_data_dir,
                                            schema_dir=args_opt.schema_dir)
+    print('eval dataset size: ', eval_dataset.get_dataset_size())
+    print('eval dataset batch size: ', eval_dataset.get_batch_size())
     callback = Accuracy()
     columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
     for data in eval_dataset.create_dict_iterator():
@@ -231,9 +278,26 @@ def do_eval_standalone():
     print("============== acc is {}".format(acc))
     print("======================================")

 if __name__ == '__main__':
     if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true":
         raise ValueError("do_train or do eval must have one be true, please confirm your config")
+
+    enable_loss_scale = True
+    if args_opt.device_target == "GPU":
+        if td_teacher_net_cfg.compute_type != mstype.float32:
+            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            td_teacher_net_cfg.compute_type = mstype.float32
+        if td_student_net_cfg.compute_type != mstype.float32:
+            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            td_student_net_cfg.compute_type = mstype.float32
+        # Both the forward and backward of the network are calculated using fp32,
+        # and the loss scale is not necessary
+        enable_loss_scale = False
+
+    td_teacher_net_cfg.seq_length = task.seq_length
+    td_student_net_cfg.seq_length = task.seq_length
+
     if args_opt.do_train == "true":
         # run predistill
         run_predistill()
@@ -0,0 +1,40 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+echo "=============================================================================================================="
+echo "Please run the scipt as: "
+echo "bash run_distribute_gd_for_gpu.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR TEACHER_CKPT_PATH"
+echo "for example: bash run_distribute_gd_for_gpu.sh 8 3 /path/data/ /path/datasetSchema.json /path/bert_base.ckpt"
+echo "It is better to use absolute path."
+echo "=============================================================================================================="
+
+RANK_SIZE=$1
+EPOCH_SIZE=$2
+DATA_DIR=$3
+SCHEMA_DIR=$4
+TEACHER_CKPT_PATH=$5
+
+PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
+
+mpirun --allow-run-as-root -n $RANK_SIZE \
+    python ${PROJECT_DIR}/../run_general_distill.py \
+    --distribute="true" \
+    --device_target="GPU" \
+    --epoch_size=$EPOCH_SIZE \
+    --save_ckpt_path="" \
+    --data_dir=$DATA_DIR \
+    --schema_dir=$SCHEMA_DIR \
+    --load_teacher_ckpt_path=$TEACHER_CKPT_PATH > log.txt 2>&1 &
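The launcher above spawns RANK_SIZE worker processes with mpirun; each worker then takes the GPU branch added to run_general_distill.py earlier in this change, so the rank and world size come from the NCCL communicator rather than from --device_id/--device_num. A minimal per-worker sketch (illustrative only, using the same APIs the script already imports):

    # What each mpirun-spawned worker ends up doing on GPU (mirrors run_general_distill.py).
    import mindspore.communication.management as D
    from mindspore import context
    from mindspore.train.parallel_utils import ParallelMode

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    D.init('nccl')                    # join the NCCL group created by mpirun
    device_num = D.get_group_size()   # equals the RANK_SIZE passed to the launcher
    rank = D.get_rank()               # this worker's index in [0, device_num)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                      mirror_mean=True, device_num=device_num)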
@@ -32,7 +32,7 @@ python ${PROJECT_DIR}/../run_task_distill.py \
     --do_eval="true" \
     --td_phase1_epoch_size=10 \
     --td_phase2_epoch_size=3 \
-    --num_labels=2 \
+    --task_name="" \
     --do_shuffle="true" \
     --enable_data_sink="true" \
     --data_sink_steps=100 \
@@ -19,7 +19,6 @@ import os
 import mindspore.common.dtype as mstype
 import mindspore.dataset.engine.datasets as de
 import mindspore.dataset.transforms.c_transforms as C
-from mindspore import log as logger

 def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
                             do_shuffle="true", data_dir=None, schema_dir=None):
@@ -45,7 +44,5 @@ def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
     ds = ds.map(input_columns="label_ids", operations=type_cast_op)
     # apply batch operations
     ds = ds.batch(batch_size, drop_remainder=True)
-    logger.info("data size: {}".format(ds.get_dataset_size()))
-    logger.info("repeatcount: {}".format(ds.get_repeat_count()))
     return ds
@@ -292,6 +292,60 @@ class BertTrainWithLossScaleCell(nn.Cell):
         ret = (loss, cond, scaling_sens)
         return F.depend(ret, succ)

+
+class BertTrainCell(nn.Cell):
+    """
+    Encapsulation class of bert network training.
+
+    Append an optimizer to the training network after that the construct
+    function can be called to create the backward graph.
+
+    Args:
+        network (Cell): The training network. Note that loss function should have been added.
+        optimizer (Optimizer): Optimizer for updating the weights.
+        sens (Number): The adjust parameter. Default: 1.0.
+    """
+    def __init__(self, network, optimizer, sens=1.0):
+        super(BertTrainCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.weights = optimizer.parameters
+        self.optimizer = optimizer
+        self.sens = sens
+        self.grad = C.GradOperation('grad',
+                                    get_by_list=True,
+                                    sens_param=True)
+        self.reducer_flag = False
+        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
+        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
+            self.reducer_flag = True
+        self.grad_reducer = F.identity
+        self.degree = 1
+        if self.reducer_flag:
+            mean = context.get_auto_parallel_context("mirror_mean")
+            self.degree = get_group_size()
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree)
+        self.cast = P.Cast()
+        self.hyper_map = C.HyperMap()
+
+    def construct(self,
+                  input_ids,
+                  input_mask,
+                  token_type_id):
+        """Defines the computation performed."""
+        weights = self.weights
+        loss = self.network(input_ids,
+                            input_mask,
+                            token_type_id)
+        grads = self.grad(self.network, weights)(input_ids,
+                                                 input_mask,
+                                                 token_type_id,
+                                                 self.cast(F.tuple_to_array((self.sens,)),
+                                                           mstype.float32))
+        # apply grad reducer on grads
+        grads = self.grad_reducer(grads)
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        succ = self.optimizer(grads)
+        return F.depend(loss, succ)
+
+
 class BertNetworkWithLoss_td(nn.Cell):
     """
     Provide bert pre-training loss through network.
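Context for the new cell: BertTrainCell is the plain-fp32 counterpart of BertTrainWithLossScaleCell. It computes, reduces, and clips gradients the same way, but backpropagates with a fixed sensitivity (sens, default 1.0) and performs no overflow check, which is sufficient once the whole network runs in fp32 on GPU. A minimal wiring sketch, reusing the names built in run_general_distill.py above (illustrative only):

    # Chosen when enable_loss_scale is False (fp32-only GPU path).
    netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)   # sens defaults to 1.0
    model = Model(netwithgrads)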
@@ -411,12 +465,12 @@ class BertNetworkWithLoss_td(nn.Cell):
             total_loss += cls_loss
         return self.cast(total_loss, mstype.float32)

-class BertEvaluationCell(nn.Cell):
+class BertEvaluationWithLossScaleCell(nn.Cell):
     """
     Especifically defined for finetuning where only four inputs tensor are needed.
     """
     def __init__(self, network, optimizer, scale_update_cell=None):
-        super(BertEvaluationCell, self).__init__(auto_prefix=False)
+        super(BertEvaluationWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer
@@ -496,3 +550,54 @@ class BertEvaluationCell(nn.Cell):
         succ = self.optimizer(grads)
         ret = (loss, cond, scaling_sens)
         return F.depend(ret, succ)
+
+
+class BertEvaluationCell(nn.Cell):
+    """
+    Especifically defined for finetuning where only four inputs tensor are needed.
+    """
+    def __init__(self, network, optimizer, sens=1.0):
+        super(BertEvaluationCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.weights = optimizer.parameters
+        self.optimizer = optimizer
+        self.sens = sens
+        self.grad = C.GradOperation('grad',
+                                    get_by_list=True,
+                                    sens_param=True)
+        self.reducer_flag = False
+        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
+        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
+            self.reducer_flag = True
+        self.grad_reducer = F.identity
+        self.degree = 1
+        if self.reducer_flag:
+            mean = context.get_auto_parallel_context("mirror_mean")
+            self.degree = get_group_size()
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree)
+        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        self.cast = P.Cast()
+        self.hyper_map = C.HyperMap()
+
+    def construct(self,
+                  input_ids,
+                  input_mask,
+                  token_type_id,
+                  label_ids):
+        """Defines the computation performed."""
+        weights = self.weights
+        loss = self.network(input_ids,
+                            input_mask,
+                            token_type_id,
+                            label_ids)
+        grads = self.grad(self.network, weights)(input_ids,
+                                                 input_mask,
+                                                 token_type_id,
+                                                 label_ids,
+                                                 self.cast(F.tuple_to_array((self.sens,)),
+                                                           mstype.float32))
+        # apply grad reducer on grads
+        grads = self.grad_reducer(grads)
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        succ = self.optimizer(grads)
+        return F.depend(loss, succ)
@@ -110,7 +110,10 @@ class EvalCallBack(Callback):
             if acc > self.global_acc:
                 self.global_acc = acc
                 print("The best acc is {}".format(acc))
-                _exec_save_checkpoint(self.network, "eval_model.ckpt")
+                eval_model_ckpt_file = "eval_model.ckpt"
+                if os.path.exists(eval_model_ckpt_file):
+                    os.remove(eval_model_ckpt_file)
+                _exec_save_checkpoint(self.network, eval_model_ckpt_file)

 class BertLearningRate(LearningRateSchedule):
     """