diff --git a/model_zoo/official/nlp/tinybert/run_general_distill.py b/model_zoo/official/nlp/tinybert/run_general_distill.py
index 8fdc86b8bc..dd4b36cd09 100644
--- a/model_zoo/official/nlp/tinybert/run_general_distill.py
+++ b/model_zoo/official/nlp/tinybert/run_general_distill.py
@@ -28,7 +28,7 @@ from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.optim import AdamWeightDecay
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore import log as logger
-from src.dataset import create_tinybert_dataset
+from src.dataset import create_tinybert_dataset, DataType
 from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
 from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
 from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell
@@ -55,6 +55,7 @@ def run_general_distill():
     parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
     parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
+    parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord")
     args_opt = parser.parse_args()
 
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
@@ -99,8 +100,15 @@ def run_general_distill():
                                          student_config=bert_student_net_cfg,
                                          is_training=True,
                                          use_one_hot_embeddings=False)
+    if args_opt.dataset_type == "tfrecord":
+        dataset_type = DataType.TFRECORD
+    elif args_opt.dataset_type == "mindrecord":
+        dataset_type = DataType.MINDRECORD
+    else:
+        raise Exception("dataset format is not supported yet")
     dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank,
-                                      args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
+                                      args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir,
+                                      data_type=dataset_type)
     dataset_size = dataset.get_dataset_size()
     print('dataset size: ', dataset_size)
     print("dataset repeatcount: ", dataset.get_repeat_count())
diff --git a/model_zoo/official/nlp/tinybert/run_task_distill.py b/model_zoo/official/nlp/tinybert/run_task_distill.py
index 517f652d8a..bc3266d265 100644
--- a/model_zoo/official/nlp/tinybert/run_task_distill.py
+++ b/model_zoo/official/nlp/tinybert/run_task_distill.py
@@ -27,7 +27,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.nn.optim import AdamWeightDecay
 from mindspore import log as logger
-from src.dataset import create_tinybert_dataset
+from src.dataset import create_tinybert_dataset, DataType
 from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
 from src.assessment_method import Accuracy
 from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg
@@ -68,7 +68,7 @@ def parse_args():
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
     parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
                         help="The name of the task to train.")
-
+    parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord")
     args = parser.parse_args()
     return args
 
@@ -119,9 +119,17 @@ def run_predistill():
 
     rank = 0
     device_num = 1
+
+    if args_opt.dataset_type == "tfrecord":
+        dataset_type = DataType.TFRECORD
+    elif args_opt.dataset_type == "mindrecord":
+        dataset_type = DataType.MINDRECORD
+    else:
+        raise Exception("dataset format is not supported yet")
     dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size, device_num, rank, args_opt.do_shuffle,
-                                      args_opt.train_data_dir, args_opt.schema_dir)
+                                      args_opt.train_data_dir, args_opt.schema_dir,
+                                      data_type=dataset_type)
 
     dataset_size = dataset.get_dataset_size()
     print('td1 dataset size: ', dataset_size)
 
diff --git a/model_zoo/official/nlp/tinybert/scripts/run_standalone_gd_ascend.sh b/model_zoo/official/nlp/tinybert/scripts/run_standalone_gd_ascend.sh
index 343d1ed7ca..7d4b241865 100644
--- a/model_zoo/official/nlp/tinybert/scripts/run_standalone_gd_ascend.sh
+++ b/model_zoo/official/nlp/tinybert/scripts/run_standalone_gd_ascend.sh
@@ -39,4 +39,5 @@ python ${PROJECT_DIR}/../run_general_distill.py \
     --save_ckpt_path="" \
     --load_teacher_ckpt_path="" \
     --data_dir="" \
-    --schema_dir="" > log.txt 2>&1 &
+    --schema_dir="" \
+    --dataset_type="tfrecord" > log.txt 2>&1 &
diff --git a/model_zoo/official/nlp/tinybert/src/dataset.py b/model_zoo/official/nlp/tinybert/src/dataset.py
index d4af8ed603..e632e02fe2 100644
--- a/model_zoo/official/nlp/tinybert/src/dataset.py
+++ b/model_zoo/official/nlp/tinybert/src/dataset.py
@@ -16,26 +16,38 @@
 """create tinybert dataset"""
 import os
+from enum import Enum
 import mindspore.common.dtype as mstype
 import mindspore.dataset.engine.datasets as de
 import mindspore.dataset.transforms.c_transforms as C
 
+class DataType(Enum):
+    """Enumerate supported dataset formats"""
+    TFRECORD = 1
+    MINDRECORD = 2
+
 def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
-                            do_shuffle="true", data_dir=None, schema_dir=None):
+                            do_shuffle="true", data_dir=None, schema_dir=None,
+                            data_type=DataType.TFRECORD):
     """create tinybert dataset"""
     files = os.listdir(data_dir)
     data_files = []
     for file_name in files:
-        if "record" in file_name:
+        if "record" in file_name and "db" not in file_name:
             data_files.append(os.path.join(data_dir, file_name))
 
     if task == "td":
         columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
     else:
         columns_list = ["input_ids", "input_mask", "segment_ids"]
 
-    ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
-                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
-                            shard_equal_rows=True)
+    if data_type == DataType.MINDRECORD:
+        ds = de.MindDataset(data_files, columns_list=columns_list,
+                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank)
+    else:
+        ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
+                                shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
+                                shard_equal_rows=True)
+
     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
     ds = ds.map(input_columns="input_mask", operations=type_cast_op)
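
Note on the flag wiring: the patch repeats the same string-to-enum mapping for `--dataset_type` in both run_general_distill.py and run_task_distill.py. A minimal sketch of how that block could be factored into one shared helper; the function name `parse_dataset_type` is hypothetical and not part of this patch, while `DataType` comes from the patched src/dataset.py:

    from src.dataset import DataType

    def parse_dataset_type(name):
        """Map a --dataset_type string onto the DataType enum."""
        if name == "tfrecord":
            return DataType.TFRECORD
        if name == "mindrecord":
            return DataType.MINDRECORD
        # Mirror the patch's behavior for unrecognized formats.
        raise Exception("dataset format is not supported yet")

    # Usage in either entry script, e.g.:
    #     dataset_type = parse_dataset_type(args_opt.dataset_type)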