
support minddataset for tinybert

tags/v1.0.0
dengyutao committed 5 years ago · commit 4e0447b05e
4 changed files with 40 additions and 11 deletions
1. model_zoo/official/nlp/tinybert/run_general_distill.py (+10, -2)
2. model_zoo/official/nlp/tinybert/run_task_distill.py (+11, -3)
3. model_zoo/official/nlp/tinybert/scripts/run_standalone_gd_ascend.sh (+2, -1)
4. model_zoo/official/nlp/tinybert/src/dataset.py (+17, -5)

model_zoo/official/nlp/tinybert/run_general_distill.py (+10, -2)

@@ -28,7 +28,7 @@ from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.optim import AdamWeightDecay
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore import log as logger
-from src.dataset import create_tinybert_dataset
+from src.dataset import create_tinybert_dataset, DataType
 from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
 from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
 from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell
@@ -55,6 +55,7 @@ def run_general_distill():
     parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
     parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
+    parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord")
     args_opt = parser.parse_args()

     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
@@ -99,8 +100,15 @@ def run_general_distill():
                                          student_config=bert_student_net_cfg,
                                          is_training=True, use_one_hot_embeddings=False)

+    if args_opt.dataset_type == "tfrecord":
+        dataset_type = DataType.TFRECORD
+    elif args_opt.dataset_type == "mindrecord":
+        dataset_type = DataType.MINDRECORD
+    else:
+        raise Exception("dataset format is not supported yet")
     dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank,
-                                      args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
+                                      args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir,
+                                      data_type=dataset_type)
     dataset_size = dataset.get_dataset_size()
     print('dataset size: ', dataset_size)
     print("dataset repeatcount: ", dataset.get_repeat_count())


model_zoo/official/nlp/tinybert/run_task_distill.py (+11, -3)

@@ -27,7 +27,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.nn.optim import AdamWeightDecay
 from mindspore import log as logger
-from src.dataset import create_tinybert_dataset
+from src.dataset import create_tinybert_dataset, DataType
 from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
 from src.assessment_method import Accuracy
 from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg
@@ -68,7 +68,8 @@ def parse_args():
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
     parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
                         help="The name of the task to train.")
+    parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord")
     args = parser.parse_args()
     return args


@@ -119,9 +119,17 @@ def run_predistill():

     rank = 0
     device_num = 1
+
+    if args_opt.dataset_type == "tfrecord":
+        dataset_type = DataType.TFRECORD
+    elif args_opt.dataset_type == "mindrecord":
+        dataset_type = DataType.MINDRECORD
+    else:
+        raise Exception("dataset format is not supported yet")
     dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                       device_num, rank, args_opt.do_shuffle,
-                                      args_opt.train_data_dir, args_opt.schema_dir)
+                                      args_opt.train_data_dir, args_opt.schema_dir,
+                                      data_type=dataset_type)

     dataset_size = dataset.get_dataset_size()
     print('td1 dataset size: ', dataset_size)
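The same string-to-enum dispatch now appears verbatim in both entry scripts; a small shared helper could remove the duplication. A hypothetical sketch (parse_dataset_type is not part of this commit):

    from src.dataset import DataType

    def parse_dataset_type(name):
        """Map a --dataset_type CLI string onto the DataType enum."""
        if name == "tfrecord":
            return DataType.TFRECORD
        if name == "mindrecord":
            return DataType.MINDRECORD
        raise Exception("dataset format is not supported yet")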


model_zoo/official/nlp/tinybert/scripts/run_standalone_gd_ascend.sh (+2, -1)

@@ -39,4 +39,5 @@ python ${PROJECT_DIR}/../run_general_distill.py \
     --save_ckpt_path="" \
     --load_teacher_ckpt_path="" \
     --data_dir="" \
-    --schema_dir="" > log.txt 2>&1 &
+    --schema_dir="" \
+    --dataset_type="tfrecord" > log.txt 2>&1 &

model_zoo/official/nlp/tinybert/src/dataset.py (+17, -5)

@@ -16,26 +16,38 @@
 """create tinybert dataset"""

 import os
+from enum import Enum
 import mindspore.common.dtype as mstype
 import mindspore.dataset.engine.datasets as de
 import mindspore.dataset.transforms.c_transforms as C

+
+class DataType(Enum):
+    """Enumerate supported dataset format"""
+    TFRECORD = 1
+    MINDRECORD = 2
+
 def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
-                            do_shuffle="true", data_dir=None, schema_dir=None):
+                            do_shuffle="true", data_dir=None, schema_dir=None,
+                            data_type=DataType.TFRECORD):
     """create tinybert dataset"""
     files = os.listdir(data_dir)
     data_files = []
     for file_name in files:
-        if "record" in file_name:
+        if "record" in file_name and "db" not in file_name:
             data_files.append(os.path.join(data_dir, file_name))
     if task == "td":
         columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
     else:
         columns_list = ["input_ids", "input_mask", "segment_ids"]

-    ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
-                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
-                            shard_equal_rows=True)
+    if data_type == DataType.MINDRECORD:
+        ds = de.MindDataset(data_files, columns_list=columns_list,
+                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank)
+    else:
+        ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
+                                shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
+                                shard_equal_rows=True)

     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
     ds = ds.map(input_columns="input_mask", operations=type_cast_op)

