
tinybert script suite for gpu

tags/v0.7.0-beta
hanhuifeng2020 committed 5 years ago
commit 1756d084ab
8 changed files with 272 additions and 40 deletions
  1. +2    -2   model_zoo/official/nlp/tinybert/README.md
  2. +38   -15  model_zoo/official/nlp/tinybert/run_general_distill.py
  3. +80   -16  model_zoo/official/nlp/tinybert/run_task_distill.py
  4. +40   -0   model_zoo/official/nlp/tinybert/scripts/run_distribute_gd_for_gpu.sh
  5. +1    -1   model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh
  6. +0    -3   model_zoo/official/nlp/tinybert/src/dataset.py
  7. +107  -2   model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py
  8. +4    -1   model_zoo/official/nlp/tinybert/src/utils.py

+2 -2   model_zoo/official/nlp/tinybert/README.md

@@ -46,7 +46,7 @@ usage: run_standalone_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_T
 
 options:
     --distribute whether to run distributely: "true" | "false"
-    --device_target target device to run, currently only support "Ascend"
+    --device_target targeted device to run task: "Ascend" | "GPU"
     --epoch_size epoch size: N, default is 1
     --device_id device id: N, default is 0
     --enable_data_sink enable data sink: "true" | "false", default is "true"
@@ -64,7 +64,7 @@ usage: run_distribute_gd.py [--distribute DISTRIBUTE] [--device_target DEVICE_T
 
 options:
     --distribute whether to run distributely: "true" | "false"
-    --device_target target device to run, currently only support "Ascend"
+    --device_target targeted device to run task: "Ascend" | "GPU"
     --epoch_size epoch size: N, default is 1
     --device_id device id: N, default is 0
     --device_num device id to run task
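
Editor's note, for illustration only: with this change a standalone general-distill run can select GPU through the same option set. A minimal sketch using just the flags listed above (the data, schema and teacher-checkpoint options are truncated in this excerpt and would still need to be supplied):

    python run_standalone_gd.py \
        --distribute="false" \
        --device_target="GPU" \
        --epoch_size=1 \
        --device_id=0 \
        --enable_data_sink="true"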


+38 -15  model_zoo/official/nlp/tinybert/run_general_distill.py

@@ -20,16 +20,20 @@ import argparse
 import datetime
 import numpy
 import mindspore.communication.management as D
+import mindspore.common.dtype as mstype
 from mindspore import context
 from mindspore.train.model import Model
 from mindspore.train.callback import TimeMonitor
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.optim import AdamWeightDecay
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
+from mindspore import log as logger
 from src.dataset import create_tinybert_dataset
 from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
 from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
-from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd
+from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell
 
 
 def run_general_distill():
     """
@@ -53,7 +57,6 @@ def run_general_distill():
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
     args_opt = parser.parse_args()
 
-    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
     context.set_context(reserve_class_name_in_scope=False)
     context.set_context(variable_memory_max_size="30GB")
@@ -61,13 +64,17 @@ def run_general_distill():
     save_ckpt_dir = os.path.join(args_opt.save_ckpt_path,
                                  datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
 
-    if not os.path.exists(save_ckpt_dir):
-        os.makedirs(save_ckpt_dir)
-
     if args_opt.distribute == "true":
-        D.init('hccl')
-        device_num = args_opt.device_num
-        rank = args_opt.device_id % device_num
+        if args_opt.device_target == 'Ascend':
+            D.init('hccl')
+            device_num = args_opt.device_num
+            rank = args_opt.device_id % device_num
+        else:
+            D.init('nccl')
+            device_num = D.get_group_size()
+            rank = D.get_rank()
+        save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                           device_num=device_num)
@@ -75,6 +82,21 @@ def run_general_distill():
         rank = 0
         device_num = 1
 
+    if not os.path.exists(save_ckpt_dir):
+        os.makedirs(save_ckpt_dir)
+
+    enable_loss_scale = True
+    if args_opt.device_target == "GPU":
+        if bert_teacher_net_cfg.compute_type != mstype.float32:
+            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            bert_teacher_net_cfg.compute_type = mstype.float32
+        if bert_student_net_cfg.compute_type != mstype.float32:
+            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            bert_student_net_cfg.compute_type = mstype.float32
+        # Both the forward and backward of the network are calculated using fp32,
+        # and the loss scale is not necessary
+        enable_loss_scale = False
+
     netwithloss = BertNetworkWithLoss_gd(teacher_config=bert_teacher_net_cfg,
                                          teacher_ckpt=args_opt.load_teacher_ckpt_path,
                                          student_config=bert_student_net_cfg,
@@ -82,11 +104,11 @@ def run_general_distill():
 
     dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank,
                                       args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
 
     dataset_size = dataset.get_dataset_size()
     print('dataset size: ', dataset_size)
+    print("dataset repeatcount: ", dataset.get_repeat_count())
     if args_opt.enable_data_sink == "true":
-        repeat_count = args_opt.epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
+        repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
         time_monitor_steps = args_opt.data_sink_steps
     else:
         repeat_count = args_opt.epoch_size
@@ -110,12 +132,13 @@ def run_general_distill():
                                     args_opt.save_ckpt_step,
                                     args_opt.max_ckpt_num,
                                     save_ckpt_dir)]
-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
-                                             scale_factor=common_cfg.scale_factor,
-                                             scale_window=common_cfg.scale_window)
-
-    netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    if enable_loss_scale:
+        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=common_cfg.loss_scale_value,
+                                                 scale_factor=common_cfg.scale_factor,
+                                                 scale_window=common_cfg.scale_window)
+        netwithgrads = BertTrainWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    else:
+        netwithgrads = BertTrainCell(netwithloss, optimizer=optimizer)
     model = Model(netwithgrads)
     model.train(repeat_count, dataset, callbacks=callback,
                 dataset_sink_mode=(args_opt.enable_data_sink == "true"),


+80 -16  model_zoo/official/nlp/tinybert/run_task_distill.py

@@ -18,6 +18,7 @@
 import os
 import re
 import argparse
+import mindspore.common.dtype as mstype
 from mindspore import Tensor
 from mindspore import context
 from mindspore.train.model import Model
@@ -25,11 +26,12 @@ from mindspore.train.callback import TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.nn.optim import AdamWeightDecay
+from mindspore import log as logger
 from src.dataset import create_tinybert_dataset
 from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
 from src.assessment_method import Accuracy
 from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg
-from src.tinybert_for_gd_td import BertEvaluationCell, BertNetworkWithLoss_td
+from src.tinybert_for_gd_td import BertEvaluationWithLossScaleCell, BertNetworkWithLoss_td, BertEvaluationCell
 from src.tinybert_model import BertModelCLS
 
 _cur_dir = os.getcwd()
@@ -45,14 +47,14 @@ def parse_args():
     parse args
     """
     parser = argparse.ArgumentParser(description='tinybert task distill')
-    parser.add_argument("--device_target", type=str, default="Ascend", help="NPU device, default is Ascend.")
+    parser.add_argument("--device_target", type=str, default="Ascend", choices=['Ascend', 'GPU'],
+                        help='device where the code will be implemented. (Default: Ascend)')
     parser.add_argument("--do_train", type=str, default="true", help="Do train task, default is true.")
     parser.add_argument("--do_eval", type=str, default="true", help="Do eval task, default is true.")
     parser.add_argument("--td_phase1_epoch_size", type=int, default=10,
                         help="Epoch size for td phase 1, default is 10.")
     parser.add_argument("--td_phase2_epoch_size", type=int, default=3, help="Epoch size for td phase 2, default is 3.")
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
-    parser.add_argument("--num_labels", type=int, default=2, help="Classfication task, support SST2, QNLI, MNLI.")
     parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
     parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
     parser.add_argument("--save_ckpt_step", type=int, default=100, help="Enable data sink, default is true.")
@@ -64,11 +66,43 @@ def parse_args():
     parser.add_argument("--train_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--eval_data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
+    parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
+                        help="The name of the task to train.")
 
     args = parser.parse_args()
     return args
 
 
 args_opt = parse_args()
+
+DEFAULT_NUM_LABELS = 2
+DEFAULT_SEQ_LENGTH = 128
+task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
+               "QNLI": {"num_labels": 2, "seq_length": 128},
+               "MNLI": {"num_labels": 3, "seq_length": 128}}
+
+
+class Task:
+    """
+    Encapsulation class of get the task parameter.
+    """
+    def __init__(self, task_name):
+        self.task_name = task_name
+
+    @property
+    def num_labels(self):
+        if self.task_name in task_params and "num_labels" in task_params[self.task_name]:
+            return task_params[self.task_name]["num_labels"]
+        return DEFAULT_NUM_LABELS
+
+    @property
+    def seq_length(self):
+        if self.task_name in task_params and "seq_length" in task_params[self.task_name]:
+            return task_params[self.task_name]["seq_length"]
+        return DEFAULT_SEQ_LENGTH
+task = Task(args_opt.task_name)
+
+
 def run_predistill():
     """
     run predistill
@@ -81,7 +115,7 @@ def run_predistill():
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                          is_training=True, task_type='classification',
-                                         num_labels=args_opt.num_labels, is_predistill=True)
+                                         num_labels=task.num_labels, is_predistill=True)
 
     rank = 0
     device_num = 1
@@ -91,8 +125,9 @@ def run_predistill():
 
     dataset_size = dataset.get_dataset_size()
     print('td1 dataset size: ', dataset_size)
+    print('td1 dataset repeatcount: ', dataset.get_repeat_count())
     if args_opt.enable_data_sink == 'true':
-        repeat_count = args_opt.td_phase1_epoch_size * dataset.get_dataset_size() // args_opt.data_sink_steps
+        repeat_count = args_opt.td_phase1_epoch_size * dataset_size // args_opt.data_sink_steps
         time_monitor_steps = args_opt.data_sink_steps
     else:
         repeat_count = args_opt.td_phase1_epoch_size
@@ -117,10 +152,14 @@ def run_predistill():
                                     args_opt.save_ckpt_step,
                                     args_opt.max_ckpt_num,
                                     td_phase1_save_ckpt_dir)]
-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
-                                             scale_factor=cfg.scale_factor,
-                                             scale_window=cfg.scale_window)
-    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    if enable_loss_scale:
+        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
+                                                 scale_factor=cfg.scale_factor,
+                                                 scale_window=cfg.scale_window)
+        netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    else:
+        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)
 
     model = Model(netwithgrads)
     model.train(repeat_count, dataset, callbacks=callback,
                 dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
@@ -139,7 +178,7 @@ def run_task_distill(ckpt_file):
     netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,
                                          student_config=td_student_net_cfg, student_ckpt=load_student_checkpoint_path,
                                          is_training=True, task_type='classification',
-                                         num_labels=args_opt.num_labels, is_predistill=False)
+                                         num_labels=task.num_labels, is_predistill=False)
 
     rank = 0
     device_num = 1
@@ -149,6 +188,7 @@ def run_task_distill(ckpt_file):
 
     dataset_size = train_dataset.get_dataset_size()
     print('td2 train dataset size: ', dataset_size)
+    print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count())
     if args_opt.enable_data_sink == 'true':
         repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps
         time_monitor_steps = args_opt.data_sink_steps
@@ -175,6 +215,7 @@ def run_task_distill(ckpt_file):
     eval_dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                            device_num, rank, args_opt.do_shuffle,
                                            args_opt.eval_data_dir, args_opt.schema_dir)
+    print('td2 eval dataset size: ', eval_dataset.get_dataset_size())
 
     if args_opt.do_eval.lower() == "true":
         callback = [TimeMonitor(time_monitor_steps), LossCallBack(),
@@ -185,11 +226,14 @@ def run_task_distill(ckpt_file):
                                         args_opt.save_ckpt_step,
                                         args_opt.max_ckpt_num,
                                         td_phase2_save_ckpt_dir)]
-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
-                                             scale_factor=cfg.scale_factor,
-                                             scale_window=cfg.scale_window)
-
-    netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    if enable_loss_scale:
+        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
+                                                 scale_factor=cfg.scale_factor,
+                                                 scale_window=cfg.scale_window)
+        netwithgrads = BertEvaluationWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    else:
+        netwithgrads = BertEvaluationCell(netwithloss, optimizer=optimizer)
     model = Model(netwithgrads)
     model.train(repeat_count, train_dataset, callbacks=callback,
                 dataset_sink_mode=(args_opt.enable_data_sink == 'true'),
@@ -203,7 +247,7 @@ def do_eval_standalone():
     if ckpt_file == '':
         raise ValueError("Student ckpt file should not be None")
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
-    eval_model = BertModelCLS(td_student_net_cfg, False, args_opt.num_labels, 0.0, phase_type="student")
+    eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
     param_dict = load_checkpoint(ckpt_file)
     new_param_dict = {}
     for key, value in param_dict.items():
@@ -213,10 +257,13 @@ def do_eval_standalone():
     load_param_into_net(eval_model, new_param_dict)
     eval_model.set_train(False)
 
-    eval_dataset = create_tinybert_dataset('td', batch_size=1,
+    eval_dataset = create_tinybert_dataset('td', batch_size=td_student_net_cfg.batch_size,
                                            device_num=1, rank=0, do_shuffle="false",
                                            data_dir=args_opt.eval_data_dir,
                                            schema_dir=args_opt.schema_dir)
+    print('eval dataset size: ', eval_dataset.get_dataset_size())
+    print('eval dataset batch size: ', eval_dataset.get_batch_size())
+
     callback = Accuracy()
     columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
     for data in eval_dataset.create_dict_iterator():
@@ -231,9 +278,26 @@ def do_eval_standalone():
     print("============== acc is {}".format(acc))
     print("======================================")
 
+
 if __name__ == '__main__':
     if args_opt.do_train.lower() != "true" and args_opt.do_eval.lower() != "true":
         raise ValueError("do_train or do eval must have one be true, please confirm your config")
+
+    enable_loss_scale = True
+    if args_opt.device_target == "GPU":
+        if td_teacher_net_cfg.compute_type != mstype.float32:
+            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            td_teacher_net_cfg.compute_type = mstype.float32
+        if td_student_net_cfg.compute_type != mstype.float32:
+            logger.warning('GPU only support fp32 temporarily, run with fp32.')
+            td_student_net_cfg.compute_type = mstype.float32
+        # Both the forward and backward of the network are calculated using fp32,
+        # and the loss scale is not necessary
+        enable_loss_scale = False
+
+    td_teacher_net_cfg.seq_length = task.seq_length
+    td_student_net_cfg.seq_length = task.seq_length
+
     if args_opt.do_train == "true":
         # run predistill
         run_predistill()
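
Editor's note: with the new --task_name option, the label count and sequence length now come from the task_params table above rather than a --num_labels flag. An illustrative sketch of a GPU launch using only arguments visible in this diff (the checkpoint-path arguments, not shown in this excerpt, are omitted):

    python run_task_distill.py \
        --device_target="GPU" \
        --device_id=0 \
        --do_train="true" \
        --do_eval="true" \
        --task_name="MNLI" \
        --train_data_dir=/path/train/ \
        --eval_data_dir=/path/eval/ \
        --schema_dir=/path/datasetSchema.json

For "MNLI" the Task wrapper resolves num_labels=3 and seq_length=128; an empty or unrecognized task name falls back to DEFAULT_NUM_LABELS=2 and DEFAULT_SEQ_LENGTH=128.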


+40 -0   model_zoo/official/nlp/tinybert/scripts/run_distribute_gd_for_gpu.sh (new file)

@@ -0,0 +1,40 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_distribute_gd_for_gpu.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR TEACHER_CKPT_PATH"
echo "for example: bash run_distribute_gd_for_gpu.sh 8 3 /path/data/ /path/datasetSchema.json /path/bert_base.ckpt"
echo "It is better to use absolute path."
echo "=============================================================================================================="

RANK_SIZE=$1
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
TEACHER_CKPT_PATH=$5

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)

mpirun --allow-run-as-root -n $RANK_SIZE \
python ${PROJECT_DIR}/../run_general_distill.py \
--distribute="true" \
--device_target="GPU" \
--epoch_size=$EPOCH_SIZE \
--save_ckpt_path="" \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR \
--load_teacher_ckpt_path=$TEACHER_CKPT_PATH > log.txt 2>&1 &
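
Editor's note: usage follows the banner echoed above, e.g. on a host with 8 GPUs:

    bash run_distribute_gd_for_gpu.sh 8 3 /path/data/ /path/datasetSchema.json /path/bert_base.ckpt

Unlike the Ascend launch, no per-process device id is passed here: when --device_target="GPU", run_general_distill.py derives the rank and device count from NCCL via D.get_rank() and D.get_group_size().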

+1 -1   model_zoo/official/nlp/tinybert/scripts/run_standalone_td.sh

@@ -32,7 +32,7 @@ python ${PROJECT_DIR}/../run_task_distill.py \
     --do_eval="true" \
     --td_phase1_epoch_size=10 \
     --td_phase2_epoch_size=3 \
-    --num_labels=2 \
+    --task_name="" \
     --do_shuffle="true" \
     --enable_data_sink="true" \
     --data_sink_steps=100 \


+0 -3   model_zoo/official/nlp/tinybert/src/dataset.py

@@ -19,7 +19,6 @@ import os
 import mindspore.common.dtype as mstype
 import mindspore.dataset.engine.datasets as de
 import mindspore.dataset.transforms.c_transforms as C
-from mindspore import log as logger
 
 
 def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
                             do_shuffle="true", data_dir=None, schema_dir=None):
@@ -45,7 +44,5 @@ def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
     ds = ds.map(input_columns="label_ids", operations=type_cast_op)
     # apply batch operations
     ds = ds.batch(batch_size, drop_remainder=True)
-    logger.info("data size: {}".format(ds.get_dataset_size()))
-    logger.info("repeatcount: {}".format(ds.get_repeat_count()))
 
     return ds

+107 -2  model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py

@@ -292,6 +292,60 @@ class BertTrainWithLossScaleCell(nn.Cell):
         ret = (loss, cond, scaling_sens)
         return F.depend(ret, succ)
 
+
+class BertTrainCell(nn.Cell):
+    """
+    Encapsulation class of bert network training.
+
+    Append an optimizer to the training network after that the construct
+    function can be called to create the backward graph.
+
+    Args:
+        network (Cell): The training network. Note that loss function should have been added.
+        optimizer (Optimizer): Optimizer for updating the weights.
+        sens (Number): The adjust parameter. Default: 1.0.
+    """
+    def __init__(self, network, optimizer, sens=1.0):
+        super(BertTrainCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.weights = optimizer.parameters
+        self.optimizer = optimizer
+        self.sens = sens
+        self.grad = C.GradOperation('grad',
+                                    get_by_list=True,
+                                    sens_param=True)
+        self.reducer_flag = False
+        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
+        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
+            self.reducer_flag = True
+        self.grad_reducer = F.identity
+        self.degree = 1
+        if self.reducer_flag:
+            mean = context.get_auto_parallel_context("mirror_mean")
+            self.degree = get_group_size()
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree)
+        self.cast = P.Cast()
+        self.hyper_map = C.HyperMap()
+
+    def construct(self,
+                  input_ids,
+                  input_mask,
+                  token_type_id):
+        """Defines the computation performed."""
+        weights = self.weights
+        loss = self.network(input_ids,
+                            input_mask,
+                            token_type_id)
+        grads = self.grad(self.network, weights)(input_ids,
+                                                 input_mask,
+                                                 token_type_id,
+                                                 self.cast(F.tuple_to_array((self.sens,)),
+                                                           mstype.float32))
+        # apply grad reducer on grads
+        grads = self.grad_reducer(grads)
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        succ = self.optimizer(grads)
+        return F.depend(loss, succ)
+
 class BertNetworkWithLoss_td(nn.Cell):
     """
     Provide bert pre-training loss through network.
@@ -411,12 +465,12 @@ class BertNetworkWithLoss_td(nn.Cell):
             total_loss += cls_loss
         return self.cast(total_loss, mstype.float32)
 
-class BertEvaluationCell(nn.Cell):
+class BertEvaluationWithLossScaleCell(nn.Cell):
     """
     Especifically defined for finetuning where only four inputs tensor are needed.
     """
     def __init__(self, network, optimizer, scale_update_cell=None):
-        super(BertEvaluationCell, self).__init__(auto_prefix=False)
+        super(BertEvaluationWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer
@@ -496,3 +550,54 @@ class BertEvaluationCell(nn.Cell):
         succ = self.optimizer(grads)
         ret = (loss, cond, scaling_sens)
         return F.depend(ret, succ)
+
+
+class BertEvaluationCell(nn.Cell):
+    """
+    Especifically defined for finetuning where only four inputs tensor are needed.
+    """
+    def __init__(self, network, optimizer, sens=1.0):
+        super(BertEvaluationCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.weights = optimizer.parameters
+        self.optimizer = optimizer
+        self.sens = sens
+        self.grad = C.GradOperation('grad',
+                                    get_by_list=True,
+                                    sens_param=True)
+        self.reducer_flag = False
+        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
+        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
+            self.reducer_flag = True
+        self.grad_reducer = F.identity
+        self.degree = 1
+        if self.reducer_flag:
+            mean = context.get_auto_parallel_context("mirror_mean")
+            self.degree = get_group_size()
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree)
+        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        self.cast = P.Cast()
+        self.hyper_map = C.HyperMap()
+
+    def construct(self,
+                  input_ids,
+                  input_mask,
+                  token_type_id,
+                  label_ids):
+        """Defines the computation performed."""
+        weights = self.weights
+        loss = self.network(input_ids,
+                            input_mask,
+                            token_type_id,
+                            label_ids)
+        grads = self.grad(self.network, weights)(input_ids,
+                                                 input_mask,
+                                                 token_type_id,
+                                                 label_ids,
+                                                 self.cast(F.tuple_to_array((self.sens,)),
+                                                           mstype.float32))
+        # apply grad reducer on grads
+        grads = self.grad_reducer(grads)
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        succ = self.optimizer(grads)
+        return F.depend(loss, succ)

+4 -1   model_zoo/official/nlp/tinybert/src/utils.py

@@ -110,7 +110,10 @@ class EvalCallBack(Callback):
             if acc > self.global_acc:
                 self.global_acc = acc
                 print("The best acc is {}".format(acc))
-                _exec_save_checkpoint(self.network, "eval_model.ckpt")
+                eval_model_ckpt_file = "eval_model.ckpt"
+                if os.path.exists(eval_model_ckpt_file):
+                    os.remove(eval_model_ckpt_file)
+                _exec_save_checkpoint(self.network, eval_model_ckpt_file)
 
 
 class BertLearningRate(LearningRateSchedule):
     """

