| @@ -46,7 +46,7 @@ usage: run_standalone_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_T | |||
| options: | |||
| --distribute whether to run distributely: "true" | "false" | |||
| --device_target target device to run, currently only support "Ascend" | |||
| --device_target targeted device to run task: "Ascend" | "GPU" | |||
| --epoch_size epoch size: N, default is 1 | |||
| --device_id device id: N, default is 0 | |||
| --enable_data_sink enable data sink: "true" | "false", default is "true" | |||
| @@ -64,7 +64,7 @@ usage: run_distribute_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_T | |||
| options: | |||
| --distribute whether to run distributely: "true" | "false" | |||
| --device_target target device to run, currently only support "Ascend" | |||
| --device_target targeted device to run task: "Ascend" | "GPU" | |||
| --epoch_size epoch size: N, default is 1 | |||
| --device_id device id: N, default is 0 | |||
| --device_num device id to run task | |||
| @@ -20,16 +20,20 @@ import argparse | |||
| import datetime | |||
| import numpy | |||
| import mindspore.communication.management as D | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import context | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.callback import TimeMonitor | |||
| from mindspore.train.parallel_utils import ParallelMode | |||
| from mindspore.nn.optim import AdamWeightDecay | |||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||
| from mindspore import log as logger | |||
| from src.dataset import create_tinybert_dataset | |||
| from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate | |||
| from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg | |||
| from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd | |||
| from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell | |||
| def run_general_distill(): | |||
| """ | |||
| @@ -53,7 +57,6 @@ def run_general_distill(): | |||
| parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") | |||
| args_opt = parser.parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) | |||
| context.set_context(reserve_class_name_in_scope=False) | |||
| context.set_context(variable_memory_max_size="30GB") | |||
| @@ -61,13 +64,17 @@ def run_general_distill(): | |||
| save_ckpt_dir = os.path.join(args_opt.save_ckpt_path, | |||
| datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) | |||
| if not os.path.exists(save_ckpt_dir): | |||
| os.makedirs(save_ckpt_dir) | |||
| if args_opt.distribute == "true": | |||
| D.init('hccl') | |||
| device_num = args_opt.device_num | |||
| rank = args_opt.device_id % device_num | |||
| if args_opt.device_target == 'Ascend': | |||
| D.init('hccl') | |||
| device_num = args_opt.device_num | |||
| rank = args_opt.device_id % device_num | |||
| else: | |||
| D.init('nccl') | |||
| device_num = D.get_group_size() | |||
| rank = D.get_rank() | |||
| save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank) | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, | |||
| device_num=device_num) | |||
| @@ -75,6 +82,21 @@ def run_general_distill(): | |||
| rank = 0 | |||
| device_num = 1 | |||
| if not os.path.exists(save_ckpt_dir): | |||
| os.makedirs(save_ckpt_dir) | |||
| enable_loss_scale = True | |||
| if args_opt.device_target == "GPU": | |||
| if bert_teacher_net_cfg.compute_type != mstype.float32: | |||
| logger.warning('GPU only support fp32 temporarily, run with fp32.') | |||
| bert_teacher_net_cfg.compute_type = mstype.float32 | |||
| if bert_student_net_cfg.compute_type != mstype.float32: | |||
| logger.warning('GPU only support fp32 temporarily, run with fp32.') | |||
| bert_student_net_cfg.compute_type = mstype.float32 | |||
| # Both the forward and backward of the network are calculated using fp32, | |||
| # and the loss scale is not necessary | |||
| enable_loss_scale = False | |||
| netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg, | |||
| teacher_ckpt=args_opt.load_teacher_ckpt_path, | |||
| student_config=bert_student_net_cfg, | |||
| @@ -82,11 +104,11 @@ def run_general_distill(): | |||
| dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank, | |||
| args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir) | |||
| dataset_size = dataset.get_dataset_size() | |||
| print('dataset size: ', dataset_size) | |||
| print("dataset repeatcount: ", dataset.get_repeat_count()) | |||
| if args_opt.enable_data_sink == "true": | |||
| repeat_count = args_opt.epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps | |||
| repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps | |||
| time_monitor_steps = args_opt.data_sink_steps | |||
| else: | |||
| repeat_count = args_opt.epoch_size | |||
| @@ -110,12 +132,13 @@ def run_general_distill(): | |||
| args_opt.save_ckpt_step, | |||
| args_opt.max_ckpt_num, | |||
| save_ckpt_dir)] | |||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value, | |||
| scale_factor=common_cfg.scale_factor, | |||
| scale_window=common_cfg.scale_window) | |||
| netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) | |||
| if enable_loss_scale: | |||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value, | |||
| scale_factor=common_cfg.scale_factor, | |||
| scale_window=common_cfg.scale_window) | |||
| netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) | |||
| else: | |||
| netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer) | |||
| model = Model(netwithgrads) | |||
| model.train(repeat_count, dataset, callbacks=callback, | |||
| dataset_sink_mode=(args_opt.enable_data_sink == "true"), | |||
| @@ -18,6 +18,7 @@ | |||
| import os | |||
| import re | |||
| import argparse | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.train.model import Model | |||
| @@ -25,11 +26,12 @@ from mindspore.train.callback import TimeMonitor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||
| from mindspore.nn.optim import AdamWeightDecay | |||
| from mindspore import log as logger | |||
| from src.dataset import create_tinybert_dataset | |||
| from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate | |||
| from src.assessment_method import Accuracy | |||
| from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg | |||
| from src.tinybert_for_gd_td import BertEvaluationCell, BertNetworkWithLoss_td | |||
| from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell | |||
| from src.tinybert_model import BertModelCLS | |||
| _cur_dir = os.getcwd() | |||
| @@ -45,14 +47,14 @@ def parse_args(): | |||
| parse args | |||
| """ | |||
| parser = argparse.ArgumentParser(description='tinybert task distill') | |||
| parser.add_argument("--device_target", type=str, default="Ascend", help="NPU device, default is Ascend.") | |||
| parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'], | |||
| help='device where the code will be implemented. (Default: Ascend)') | |||
| parser.add_argument("--do_train", type=str, default="true", help="Do train task, default is true.") | |||
| parser.add_argument("--do_eval", type=str, default="true", help="Do eval task, default is true.") | |||
| parser.add_argument("--td_phase1_epoch_size", type=int, default=10, | |||
| help="Epoch size for td phase 1, default is 10.") | |||
| parser.add_argument("--td_phase2_epoch_size", type=int, default=3, help="Epoch size for td phase 2, default is 3.") | |||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | |||
| parser.add_argument("--num_labels", type=int, default=2, help="Classfication task, support SST2, QNLI, MNLI.") | |||
| parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") | |||
| parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.") | |||
| parser.add_argument("--save_ckpt_step", type=int, default=100, help="Enable data sink, default is true.") | |||
| @@ -64,11 +66,43 @@ def parse_args(): | |||
| parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path") | |||
| parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path") | |||
| parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") | |||
| parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"], | |||
| help="The name of the task to train.") | |||
| args = parser.parse_args() | |||
| return args | |||
| args_opt = parse_args() | |||
| DEFAULT_NUM_LABELS = 2 | |||
| DEFAULT_SEQ_LENGTH = 128 | |||
| task_params = {"SST-2": {"num_labels": 2, "seq_length": 64}, | |||
| "QNLI": {"num_labels": 2, "seq_length": 128}, | |||
| "MNLI": {"num_labels": 3, "seq_length": 128}} | |||
| class Task: | |||
| """ | |||
| Encapsulation class of get the task parameter. | |||
| """ | |||
| def __init__(self, task_name): | |||
| self.task_name = task_name | |||
| @property | |||
| def num_labels(self): | |||
| if self.task_name in task_params and "num_labels" in task_params[self.task_name]: | |||
| return task_params[self.task_name]["num_labels"] | |||
| return DEFAULT_NUM_LABELS | |||
| @property | |||
| def seq_length(self): | |||
| if self.task_name in task_params and "seq_length" in task_params[self.task_name]: | |||
| return task_params[self.task_name]["seq_length"] | |||
| return DEFAULT_SEQ_LENGTH | |||
| task = Task(args_opt.task_name) | |||
| def run_predistill(): | |||
| """ | |||
| run predistill | |||
| @@ -81,7 +115,7 @@ def run_predistill(): | |||
| netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path, | |||
| student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path, | |||
| is_training=True, task_type='classification', | |||
| num_labels=args_opt.num_labels, is_predistill=True) | |||
| num_labels=task.num_labels, is_predistill=True) | |||
| rank = 0 | |||
| device_num = 1 | |||
| @@ -91,8 +125,9 @@ def run_predistill(): | |||
| dataset_size = dataset.get_dataset_size() | |||
| print('td1 dataset size: ', dataset_size) | |||
| print('td1 dataset repeatcount: ', dataset.get_repeat_count()) | |||
| if args_opt.enable_data_sink == 'true': | |||
| repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps | |||
| repeat_count = args_opt.td_phase1_epoch_size * dataset_size // args_opt.data_sink_steps | |||
| time_monitor_steps = args_opt.data_sink_steps | |||
| else: | |||
| repeat_count = args_opt.td_phase1_epoch_size | |||
| @@ -117,10 +152,14 @@ def run_predistill(): | |||
| args_opt.save_ckpt_step, | |||
| args_opt.max_ckpt_num, | |||
| td_phase1_save_ckpt_dir)] | |||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, | |||
| scale_factor=cfg.scale_factor, | |||
| scale_window=cfg.scale_window) | |||
| netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) | |||
| if enable_loss_scale: | |||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, | |||
| scale_factor=cfg.scale_factor, | |||
| scale_window=cfg.scale_window) | |||
| netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) | |||
| else: | |||
| netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer) | |||
| model = Model(netwithgrads) | |||
| model.train(repeat_count, dataset, callbacks=callback, | |||
| dataset_sink_mode=(args_opt.enable_data_sink == 'true'), | |||
| @@ -139,7 +178,7 @@ def run_task_distill(ckpt_file): | |||
| netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path, | |||
| student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path, | |||
| is_training=True, task_type='classification', | |||
| num_labels=args_opt.num_labels, is_predistill=False) | |||
| num_labels=task.num_labels, is_predistill=False) | |||
| rank = 0 | |||
| device_num = 1 | |||
| @@ -149,6 +188,7 @@ def run_task_distill(ckpt_file): | |||
| dataset_size = train_dataset.get_dataset_size() | |||
| print('td2 train dataset size: ', dataset_size) | |||
| print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count()) | |||
| if args_opt.enable_data_sink == 'true': | |||
| repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps | |||
| time_monitor_steps = args_opt.data_sink_steps | |||
| @@ -175,6 +215,7 @@ def run_task_distill(ckpt_file): | |||
| eval_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size, | |||
| device_num, rank, args_opt.do_shuffle, | |||
| args_opt.eval_data_dir, args_opt.schema_dir) | |||
| print('td2 eval dataset size: ', eval_dataset.get_dataset_size()) | |||
| if args_opt.do_eval.lower() == "true": | |||
| callback = [TimeMonitor(time_monitor_steps), LossCallBack(), | |||
| @@ -185,11 +226,14 @@ def run_task_distill(ckpt_file): | |||
| args_opt.save_ckpt_step, | |||
| args_opt.max_ckpt_num, | |||
| td_phase2_save_ckpt_dir)] | |||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, | |||
| scale_factor=cfg.scale_factor, | |||
| scale_window=cfg.scale_window) | |||
| if enable_loss_scale: | |||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value, | |||
| scale_factor=cfg.scale_factor, | |||
| scale_window=cfg.scale_window) | |||
| netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) | |||
| netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) | |||
| else: | |||
| netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer) | |||
| model = Model(netwithgrads) | |||
| model.train(repeat_count, train_dataset, callbacks=callback, | |||
| dataset_sink_mode=(args_opt.enable_data_sink == 'true'), | |||
| @@ -203,7 +247,7 @@ def do_eval_standalone(): | |||
| if ckpt_file == '': | |||
| raise ValueError("Student ckpt file should not be None") | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) | |||
| eval_model = BertModelCLS(td_student_net_cfg, False, args_opt.num_labels, 0.0, phase_type="student") | |||
| eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student") | |||
| param_dict = load_checkpoint(ckpt_file) | |||
| new_param_dict = {} | |||
| for key, value in param_dict.items(): | |||
| @@ -213,10 +257,13 @@ def do_eval_standalone(): | |||
| load_param_into_net(eval_model, new_param_dict) | |||
| eval_model.set_train(False) | |||
| eval_dataset = create_tinybert_dataset('td', batch_size=1, | |||
| eval_dataset = create_tinybert_dataset('td', batch_size=td_student_net_cfg.batch_size, | |||
| device_num=1, rank=0, do_shuffle="false", | |||
| data_dir=args_opt.eval_data_dir, | |||
| schema_dir=args_opt.schema_dir) | |||
| print('eval dataset size: ', eval_dataset.get_dataset_size()) | |||
| print('eval dataset batch size: ', eval_dataset.get_batch_size()) | |||
| callback = Accuracy() | |||
| columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] | |||
| for data in eval_dataset.create_dict_iterator(): | |||
| @@ -231,9 +278,26 @@ def do_eval_standalone(): | |||
| print("============== acc is {}".format(acc)) | |||
| print("======================================") | |||
| if __name__ == '__main__': | |||
| if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true": | |||
| raise ValueError("do_train or do eval must have one be true, please confirm your config") | |||
| enable_loss_scale = True | |||
| if args_opt.device_target == "GPU": | |||
| if td_teacher_net_cfg.compute_type != mstype.float32: | |||
| logger.warning('GPU only support fp32 temporarily, run with fp32.') | |||
| td_teacher_net_cfg.compute_type = mstype.float32 | |||
| if td_student_net_cfg.compute_type != mstype.float32: | |||
| logger.warning('GPU only support fp32 temporarily, run with fp32.') | |||
| td_student_net_cfg.compute_type = mstype.float32 | |||
| # Both the forward and backward of the network are calculated using fp32, | |||
| # and the loss scale is not necessary | |||
| enable_loss_scale = False | |||
| td_teacher_net_cfg.seq_length = task.seq_length | |||
| td_student_net_cfg.seq_length = task.seq_length | |||
| if args_opt.do_train == "true": | |||
| # run predistill | |||
| run_predistill() | |||
| @@ -0,0 +1,40 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| echo "==============================================================================================================" | |||
| echo "Please run the scipt as: " | |||
| echo "bash run_distribute_gd_for_gpu.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR TEACHER_CKPT_PATH" | |||
| echo "for example: bash run_distribute_gd_for_gpu.sh 8 3 /path/data/ /path/datasetSchema.json /path/bert_base.ckpt" | |||
| echo "It is better to use absolute path." | |||
| echo "==============================================================================================================" | |||
| RANK_SIZE=$1 | |||
| EPOCH_SIZE=$2 | |||
| DATA_DIR=$3 | |||
| SCHEMA_DIR=$4 | |||
| TEACHER_CKPT_PATH=$5 | |||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||
| mpirun --allow-run-as-root -n $RANK_SIZE \ | |||
| python ${PROJECT_DIR}/../run_general_distill.py \ | |||
| --distribute="true" \ | |||
| --device_target="GPU" \ | |||
| --epoch_size=$EPOCH_SIZE \ | |||
| --save_ckpt_path="" \ | |||
| --data_dir=$DATA_DIR \ | |||
| --schema_dir=$SCHEMA_DIR \ | |||
| --load_teacher_ckpt_path=$TEACHER_CKPT_PATH > log.txt 2>&1 & | |||
| @@ -32,7 +32,7 @@ python ${PROJECT_DIR}/../run_task_distill.py \ | |||
| --do_eval="true" \ | |||
| --td_phase1_epoch_size=10 \ | |||
| --td_phase2_epoch_size=3 \ | |||
| --num_labels=2 \ | |||
| --task_name="" \ | |||
| --do_shuffle="true" \ | |||
| --enable_data_sink="true" \ | |||
| --data_sink_steps=100 \ | |||
| @@ -19,7 +19,6 @@ import os | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset.engine.datasets as de | |||
| import mindspore.dataset.transforms.c_transforms as C | |||
| from mindspore import log as logger | |||
| def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0, | |||
| do_shuffle="true", data_dir=None, schema_dir=None): | |||
| @@ -45,7 +44,5 @@ def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0, | |||
| ds = ds.map(input_columns="label_ids", operations=type_cast_op) | |||
| # apply batch operations | |||
| ds = ds.batch(batch_size, drop_remainder=True) | |||
| logger.info("data size: {}".format(ds.get_dataset_size())) | |||
| logger.info("repeatcount: {}".format(ds.get_repeat_count())) | |||
| return ds | |||
| @@ -292,6 +292,60 @@ class BertTrainWithLossScaleCell(nn.Cell): | |||
| ret = (loss, cond, scaling_sens) | |||
| return F.depend(ret, succ) | |||
| class BertTrainCell(nn.Cell): | |||
| """ | |||
| Encapsulation class of bert network training. | |||
| Append an optimizer to the training network after that the construct | |||
| function can be called to create the backward graph. | |||
| Args: | |||
| network (Cell): The training network. Note that loss function should have been added. | |||
| optimizer (Optimizer): Optimizer for updating the weights. | |||
| sens (Number): The adjust parameter. Default: 1.0. | |||
| """ | |||
| def __init__(self, network, optimizer, sens=1.0): | |||
| super(BertTrainCell, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.weights = optimizer.parameters | |||
| self.optimizer = optimizer | |||
| self.sens = sens | |||
| self.grad = C.GradOperation('grad', | |||
| get_by_list=True, | |||
| sens_param=True) | |||
| self.reducer_flag = False | |||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||
| self.reducer_flag = True | |||
| self.grad_reducer = F.identity | |||
| self.degree = 1 | |||
| if self.reducer_flag: | |||
| mean = context.get_auto_parallel_context("mirror_mean") | |||
| self.degree = get_group_size() | |||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree) | |||
| self.cast = P.Cast() | |||
| self.hyper_map = C.HyperMap() | |||
| def construct(self, | |||
| input_ids, | |||
| input_mask, | |||
| token_type_id): | |||
| """Defines the computation performed.""" | |||
| weights = self.weights | |||
| loss = self.network(input_ids, | |||
| input_mask, | |||
| token_type_id) | |||
| grads = self.grad(self.network, weights)(input_ids, | |||
| input_mask, | |||
| token_type_id, | |||
| self.cast(F.tuple_to_array((self.sens,)), | |||
| mstype.float32)) | |||
| # apply grad reducer on grads | |||
| grads = self.grad_reducer(grads) | |||
| grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | |||
| succ = self.optimizer(grads) | |||
| return F.depend(loss, succ) | |||
| class BertNetworkWithLoss_td(nn.Cell): | |||
| """ | |||
| Provide bert pre-training loss through network. | |||
| @@ -411,12 +465,12 @@ class BertNetworkWithLoss_td(nn.Cell): | |||
| total_loss += cls_loss | |||
| return self.cast(total_loss, mstype.float32) | |||
| class BertEvaluationCell(nn.Cell): | |||
| class BertEvaluationWithLossScaleCell(nn.Cell): | |||
| """ | |||
| Especifically defined for finetuning where only four inputs tensor are needed. | |||
| """ | |||
| def __init__(self, network, optimizer, scale_update_cell=None): | |||
| super(BertEvaluationCell, self).__init__(auto_prefix=False) | |||
| super(BertEvaluationWithLossScaleCell, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.weights = optimizer.parameters | |||
| self.optimizer = optimizer | |||
| @@ -496,3 +550,54 @@ class BertEvaluationCell(nn.Cell): | |||
| succ = self.optimizer(grads) | |||
| ret = (loss, cond, scaling_sens) | |||
| return F.depend(ret, succ) | |||
| class BertEvaluationCell(nn.Cell): | |||
| """ | |||
| Especifically defined for finetuning where only four inputs tensor are needed. | |||
| """ | |||
| def __init__(self, network, optimizer, sens=1.0): | |||
| super(BertEvaluationCell, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.weights = optimizer.parameters | |||
| self.optimizer = optimizer | |||
| self.sens = sens | |||
| self.grad = C.GradOperation('grad', | |||
| get_by_list=True, | |||
| sens_param=True) | |||
| self.reducer_flag = False | |||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||
| self.reducer_flag = True | |||
| self.grad_reducer = F.identity | |||
| self.degree = 1 | |||
| if self.reducer_flag: | |||
| mean = context.get_auto_parallel_context("mirror_mean") | |||
| self.degree = get_group_size() | |||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree) | |||
| self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) | |||
| self.cast = P.Cast() | |||
| self.hyper_map = C.HyperMap() | |||
| def construct(self, | |||
| input_ids, | |||
| input_mask, | |||
| token_type_id, | |||
| label_ids): | |||
| """Defines the computation performed.""" | |||
| weights = self.weights | |||
| loss = self.network(input_ids, | |||
| input_mask, | |||
| token_type_id, | |||
| label_ids) | |||
| grads = self.grad(self.network, weights)(input_ids, | |||
| input_mask, | |||
| token_type_id, | |||
| label_ids, | |||
| self.cast(F.tuple_to_array((self.sens,)), | |||
| mstype.float32)) | |||
| # apply grad reducer on grads | |||
| grads = self.grad_reducer(grads) | |||
| grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | |||
| succ = self.optimizer(grads) | |||
| return F.depend(loss, succ) | |||
| @@ -110,7 +110,10 @@ class EvalCallBack(Callback): | |||
| if acc > self.global_acc: | |||
| self.global_acc = acc | |||
| print("The best acc is {}".format(acc)) | |||
| _exec_save_checkpoint(self.network, "eval_model.ckpt") | |||
| eval_model_ckpt_file = "eval_model.ckpt" | |||
| if os.path.exists(eval_model_ckpt_file): | |||
| os.remove(eval_model_ckpt_file) | |||
| _exec_save_checkpoint(self.network, eval_model_ckpt_file) | |||
| class BertLearningRate(LearningRateSchedule): | |||
| """ | |||