@@ -28,7 +28,7 @@ from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.optim import AdamWeightDecay
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore import log as logger
-from src.dataset import create_tinybert_dataset
+from src.dataset import create_tinybert_dataset, DataType
 from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
 from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
 from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell
@@ -55,6 +55,7 @@ def run_general_distill():
     parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
     parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
+    parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord")
     args_opt = parser.parse_args()

     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
@@ -99,8 +100,15 @@ def run_general_distill():
                                          student_config=bert_student_net_cfg,
                                          is_training=True, use_one_hot_embeddings=False)

+    if args_opt.dataset_type == "tfrecord":
+        dataset_type = DataType.TFRECORD
+    elif args_opt.dataset_type == "mindrecord":
+        dataset_type = DataType.MINDRECORD
+    else:
+        raise Exception("dataset format is not supported yet")
     dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank,
-                                      args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
+                                      args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir,
+                                      data_type=dataset_type)
     dataset_size = dataset.get_dataset_size()
     print('dataset size: ', dataset_size)
     print("dataset repeatcount: ", dataset.get_repeat_count())
@@ -27,7 +27,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.nn.optim import AdamWeightDecay
 from mindspore import log as logger
-from src.dataset import create_tinybert_dataset
+from src.dataset import create_tinybert_dataset, DataType
 from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
 from src.assessment_method import Accuracy
 from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg
@@ -68,7 +68,8 @@ def parse_args():
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
     parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
                         help="The name of the task to train.")
+    parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord")
     args = parser.parse_args()
     return args
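
A possible hardening, not part of this change: argparse can reject unsupported formats at parse time via choices, matching the style the file already uses for --task_name above. A sketch:

    import argparse

    parser = argparse.ArgumentParser()
    # Drop-in replacement for the --dataset_type flag above: unknown values
    # now fail at parse time instead of deep inside run_predistill().
    parser.add_argument("--dataset_type", type=str, default="tfrecord",
                        choices=["tfrecord", "mindrecord"],
                        help="dataset type, one of: tfrecord, mindrecord")
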
@@ -119,9 +119,17 @@ def run_predistill():
     rank = 0
     device_num = 1
+
+    if args_opt.dataset_type == "tfrecord":
+        dataset_type = DataType.TFRECORD
+    elif args_opt.dataset_type == "mindrecord":
+        dataset_type = DataType.MINDRECORD
+    else:
+        raise Exception("dataset format is not supported yet")
     dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
                                       device_num, rank, args_opt.do_shuffle,
-                                      args_opt.train_data_dir, args_opt.schema_dir)
+                                      args_opt.train_data_dir, args_opt.schema_dir,
+                                      data_type=dataset_type)
     dataset_size = dataset.get_dataset_size()
     print('td1 dataset size: ', dataset_size)

@@ -39,4 +39,5 @@ python ${PROJECT_DIR}/../run_general_distill.py \
     --save_ckpt_path="" \
     --load_teacher_ckpt_path="" \
     --data_dir="" \
-    --schema_dir="" > log.txt 2>&1 &
+    --schema_dir="" \
+    --dataset_type="tfrecord" > log.txt 2>&1 &

@@ -16,26 +16,38 @@
 """create tinybert dataset"""
 import os
+from enum import Enum
 import mindspore.common.dtype as mstype
 import mindspore.dataset.engine.datasets as de
 import mindspore.dataset.transforms.c_transforms as C


+class DataType(Enum):
+    """Enumerate supported dataset format"""
+    TFRECORD = 1
+    MINDRECORD = 2
+
+
 def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
-                            do_shuffle="true", data_dir=None, schema_dir=None):
+                            do_shuffle="true", data_dir=None, schema_dir=None,
+                            data_type=DataType.TFRECORD):
     """create tinybert dataset"""
     files = os.listdir(data_dir)
     data_files = []
     for file_name in files:
-        if "record" in file_name:
+        if "record" in file_name and "db" not in file_name:
             data_files.append(os.path.join(data_dir, file_name))
     if task == "td":
         columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
     else:
         columns_list = ["input_ids", "input_mask", "segment_ids"]
-    ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
-                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
-                            shard_equal_rows=True)
+    if data_type == DataType.MINDRECORD:
+        ds = de.MindDataset(data_files, columns_list=columns_list,
+                            shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank)
+    else:
+        ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
+                                shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
+                                shard_equal_rows=True)

     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
     ds = ds.map(input_columns="input_mask", operations=type_cast_op)