| @@ -28,7 +28,7 @@ from mindspore.train.parallel_utils import ParallelMode | |||
| from mindspore.nn.optim import AdamWeightDecay | |||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||
| from mindspore import log as logger | |||
| from src.dataset import create_tinybert_dataset | |||
| from src.dataset import create_tinybert_dataset, DataType | |||
| from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate | |||
| from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg | |||
| from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell | |||
| @@ -55,6 +55,7 @@ def run_general_distill(): | |||
| parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path") | |||
| parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path") | |||
| parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") | |||
| parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord") | |||
| args_opt = parser.parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) | |||
| @@ -99,8 +100,15 @@ def run_general_distill(): | |||
| student_config=bert_student_net_cfg, | |||
| is_training=True, use_one_hot_embeddings=False) | |||
| if args_opt.dataset_type == "tfrecord": | |||
| dataset_type = DataType.TFRECORD | |||
| elif arg_opt.dataset_type == "mindrecord": | |||
| dataset_type = DataType.MINDRECORD | |||
| else: | |||
| raise Exception("dataset format is not supported yet") | |||
| dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank, | |||
| args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir) | |||
| args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir, | |||
| data_type=dataset_type) | |||
| dataset_size = dataset.get_dataset_size() | |||
| print('dataset size: ', dataset_size) | |||
| print("dataset repeatcount: ", dataset.get_repeat_count()) | |||
| @@ -27,7 +27,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||
| from mindspore.nn.optim import AdamWeightDecay | |||
| from mindspore import log as logger | |||
| from src.dataset import create_tinybert_dataset | |||
| from src.dataset import create_tinybert_dataset, DataType | |||
| from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate | |||
| from src.assessment_method import Accuracy | |||
| from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg | |||
| @@ -68,7 +68,7 @@ def parse_args(): | |||
| parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") | |||
| parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"], | |||
| help="The name of the task to train.") | |||
| parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord") | |||
| args = parser.parse_args() | |||
| return args | |||
| @@ -119,9 +119,17 @@ def run_predistill(): | |||
| rank = 0 | |||
| device_num = 1 | |||
| if arg_opt.dataset_type == "tfrecord": | |||
| dataset_type = DataType.TFRECORD | |||
| elif arg_opt.dataset_type == "mindrecord": | |||
| dataset_type = DataType.MINDRECORD | |||
| else: | |||
| raise Exception("dataset format is not supported yet") | |||
| dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size, | |||
| device_num, rank, args_opt.do_shuffle, | |||
| args_opt.train_data_dir, args_opt.schema_dir) | |||
| args_opt.train_data_dir, args_opt.schema_dir, | |||
| data_tpye=dataset_type) | |||
| dataset_size = dataset.get_dataset_size() | |||
| print('td1 dataset size: ', dataset_size) | |||
| @@ -39,4 +39,5 @@ python ${PROJECT_DIR}/../run_general_distill.py \ | |||
| --save_ckpt_path="" \ | |||
| --load_teacher_ckpt_path="" \ | |||
| --data_dir="" \ | |||
| --schema_dir="" > log.txt 2>&1 & | |||
| --schema_dir="" \ | |||
| --dataset_type="tfrecord" > log.txt 2>&1 & | |||
| @@ -16,26 +16,38 @@ | |||
| """create tinybert dataset""" | |||
| import os | |||
| from enum import Enum | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset.engine.datasets as de | |||
| import mindspore.dataset.transforms.c_transforms as C | |||
| class DataType(Enum): | |||
| """Enumerate supported dataset format""" | |||
| TFRECORD = 1 | |||
| MINDRECORD = 2 | |||
| def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0, | |||
| do_shuffle="true", data_dir=None, schema_dir=None): | |||
| do_shuffle="true", data_dir=None, schema_dir=None, | |||
| data_type=DataType.TFRECORD): | |||
| """create tinybert dataset""" | |||
| files = os.listdir(data_dir) | |||
| data_files = [] | |||
| for file_name in files: | |||
| if "record" in file_name: | |||
| if "record" in file_name and "db" not in file_name: | |||
| data_files.append(os.path.join(data_dir, file_name)) | |||
| if task == "td": | |||
| columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] | |||
| else: | |||
| columns_list = ["input_ids", "input_mask", "segment_ids"] | |||
| ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list, | |||
| shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank, | |||
| shard_equal_rows=True) | |||
| if data_type == DataType.MINDRECORD: | |||
| ds = de.MindDataset(data_files, columns_list=columns_list, | |||
| shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank) | |||
| else: | |||
| ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list, | |||
| shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank, | |||
| shard_equal_rows=True) | |||
| type_cast_op = C.TypeCast(mstype.int32) | |||
| ds = ds.map(input_columns="segment_ids", operations=type_cast_op) | |||
| ds = ds.map(input_columns="input_mask", operations=type_cast_op) | |||