diff --git a/data_module.py b/data_module.py
index bab3e3f..e0f0327 100644
--- a/data_module.py
+++ b/data_module.py
@@ -42,13 +42,16 @@ class DataModule(pl.LightningDataModule):
         self.test_dataset = CustomDataset(self.x, self.y, self.config)
 
     def train_dataloader(self):
-        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
+        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers,
+                          pin_memory=True)
 
     def val_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
+        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers,
+                          pin_memory=True)
 
     def test_dataloader(self):
-        return DataLoader(self.test_dataset, batch_size=1, shuffle=False, num_workers=self.num_workers)
+        return DataLoader(self.test_dataset, batch_size=1, shuffle=False, num_workers=self.num_workers,
+                          pin_memory=True)
 
 
 class CustomDataset(Dataset):

diff --git a/main.py b/main.py
index 9968b36..580003a 100644
--- a/main.py
+++ b/main.py
@@ -3,10 +3,10 @@ from data_module import DataModule
 from pytorch_lightning import loggers as pl_loggers
 import pytorch_lightning as pl
 from train_model import TrainModule
+from multiprocessing import cpu_count
 
 
 def main(stage,
-         num_workers,
          max_epochs,
          batch_size,
          precision,
@@ -24,7 +24,6 @@ def main(stage,
     The parameters of this function are the ones that change frequently during training.
 
     :param stage: whether we are in the training or testing stage; 'fit' means training, 'test' means testing
-    :param num_workers:
     :param max_epochs:
     :param batch_size:
     :param precision: training precision; full precision is 32, half precision is 16, and 64 is also possible. Precision is the number of bits used to store each parameter
@@ -50,7 +49,7 @@
                   'res_coef': 0.5,
                   'dropout_p': 0.1,
                   'n_layers': 2,
-                  'dataset_len': 10000,
+                  'dataset_len': 100000,
                   'flag': True}
     else:
         config = {'dataset_path': dataset_path,
@@ -59,20 +58,21 @@
                   'res_coef': 0.5,
                   'dropout_p': 0.1,
                   'n_layers': 20,
-                  'dataset_len': 10000,
+                  'dataset_len': 100000,
                   'flag': False}
     # TODO: find the optimal batch size
-    # TODO: automatically detect the number of CPU cores and set num_workers
+    num_workers = cpu_count()
     precision = 32 if (gpus is None and tpu_cores is None) else precision
     dm = DataModule(batch_size=batch_size, num_workers=num_workers, config=config)
     logger = pl_loggers.TensorBoardLogger('logs/')
     if stage == 'fit':
-        training_module = TrainModule(config=config)
+        # SaveCheckpoint must be created before TrainModule so that network parameter initialization is deterministic
         save_checkpoint = SaveCheckpoint(seed=seed, max_epochs=max_epochs, save_name=save_name,
                                          path_final_save=path_final_save, every_n_epochs=every_n_epochs,
                                          verbose=True, monitor='Validation loss', save_top_k=save_top_k, mode='min')
+        training_module = TrainModule(config=config)
         if load_checkpoint_path is None:
             print('Starting initial training')
             trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
@@ -83,6 +83,7 @@
             trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
                                  resume_from_checkpoint='./logs/default' + load_checkpoint_path,
                                  logger=logger, precision=precision, callbacks=[save_checkpoint])
+        print('During training, keep an eye on GPU utilization and related metrics')
         trainer.fit(training_module, datamodule=dm)
     if stage == 'test':
         if load_checkpoint_path is None:
@@ -98,7 +99,7 @@
 
 
 if __name__ == "__main__":
-    main('test', num_workers=8, max_epochs=5, batch_size=32, precision=16, seed=1234,
+    main('fit', max_epochs=5, batch_size=32, precision=16, seed=1234,
          # gpus=1,
-         load_checkpoint_path='/version_4/checkpoints/epoch=4-step=7814.ckpt',
+         # load_checkpoint_path='/version_4/checkpoints/epoch=4-step=7814.ckpt',
          )
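A note on the data-loading changes above: pin_memory=True allocates each batch in page-locked host memory, which lets host-to-GPU copies run asynchronously, and num_workers = cpu_count() sizes the worker pool to the machine instead of hard-coding 8. Below is a minimal sketch of the same DataLoader configuration in isolation; the TensorDataset and its shapes are illustrative stand-ins, not part of this repo:

    from multiprocessing import cpu_count

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    if __name__ == "__main__":  # guard required when num_workers > 0 spawns subprocesses
        # Stand-in dataset; only the loader arguments mirror the diff above.
        dataset = TensorDataset(torch.randn(1024, 8), torch.randn(1024, 1))
        loader = DataLoader(dataset, batch_size=32, shuffle=True,
                            num_workers=cpu_count(),  # matches num_workers = cpu_count() in main.py
                            pin_memory=True)          # page-locked buffers, as in data_module.py

        for x, y in loader:
            # non_blocking=True only pays off when the source tensor is pinned
            if torch.cuda.is_available():
                x = x.to('cuda', non_blocking=True)
            break  # one batch is enough for the illustration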
diff --git a/save_checkpoint.py b/save_checkpoint.py
index ba28633..31ff475 100644
--- a/save_checkpoint.py
+++ b/save_checkpoint.py
@@ -1,6 +1,7 @@
 import os
+
+import numpy.random
 from pytorch_lightning.callbacks import ModelCheckpoint
-import pytorch_lightning
 import pytorch_lightning as pl
 import shutil
 import random
@@ -35,12 +36,9 @@ class SaveCheckpoint(ModelCheckpoint):
         :param no_save_before_epoch:
         """
         super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode)
-        random.seed(seed)
-        self.seeds = []
-        for i in range(max_epochs):
-            self.seeds.append(random.randint(0, 2000))
-        self.seeds.append(0)
-        pytorch_lightning.seed_everything(seed)
+        numpy.random.seed(seed)
+        self.seeds = numpy.random.randint(0, 2000, max_epochs)
+        pl.seed_everything(seed)
         self.save_name = save_name
         self.path_final_save = path_final_save
         self.monitor = monitor
@@ -57,11 +55,11 @@ class SaveCheckpoint(ModelCheckpoint):
         :param pl_module:
         :return:
         """
+        # The first epoch runs under the original input seed; epoch n (n >= 1) is seeded with seeds[n - 1]
         if self.flag_sanity_check == 0:
-            pytorch_lightning.seed_everything(self.seeds[trainer.current_epoch])
             self.flag_sanity_check = 1
         else:
-            pytorch_lightning.seed_everything(self.seeds[trainer.current_epoch + 1])
+            pl.seed_everything(self.seeds[trainer.current_epoch])
         super().on_validation_end(trainer, pl_module)
 
     def _save_top_k_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates) -> None:
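On the seeding changes above: SaveCheckpoint now draws one seed per epoch from a NumPy generator initialized with the run seed, leaves the sanity-check validation pass on the original seed, and reseeds at the end of every real validation epoch, so epoch n (n >= 1) always runs under seeds[n - 1]. A standalone sketch of that scheme outside Lightning; the on_validation_end driver and the epoch loop are illustrative stand-ins, only the seed handling mirrors the diff:

    import numpy.random
    import pytorch_lightning as pl

    seed, max_epochs = 1234, 5
    numpy.random.seed(seed)
    seeds = numpy.random.randint(0, 2000, max_epochs)  # reproducible per-epoch seeds

    pl.seed_everything(seed)  # epoch 0 runs under the original input seed
    flag_sanity_check = 0

    def on_validation_end(current_epoch):
        global flag_sanity_check
        if flag_sanity_check == 0:
            flag_sanity_check = 1  # sanity-check pass: do not consume a per-epoch seed
        else:
            # end of a real validation epoch: reseed so the next epoch is deterministic
            pl.seed_everything(int(seeds[current_epoch]))

    on_validation_end(0)          # Lightning's sanity check fires before epoch 0
    for epoch in range(max_epochs):
        # ... training and validation for `epoch` would run here ...
        on_validation_end(epoch)  # reseeds with seeds[epoch] ahead of epoch + 1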