From b114979ea5df8a5213901c4476a5c2a15f555966 Mon Sep 17 00:00:00 2001 From: shenyan <23357320@qq.com> Date: Thu, 14 Oct 2021 15:32:17 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=B9=E7=94=A8pin=5Fmemory=3DTrue;=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E5=AF=B9gpu=E6=80=A7=E8=83=BD=E7=9A=84=E6=B3=A8?= =?UTF-8?q?=E6=84=8F=E8=AF=B4=E6=98=8E;=E8=87=AA=E5=8A=A8=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=E6=A0=B8=E5=BF=83=E6=95=B0=E5=B9=B6=E8=AE=BE=E7=BD=AE?= =?UTF-8?q?=E5=85=B6=E4=B8=BAnum=5Fworkers;=E7=AE=80=E5=8C=96=E6=AF=8F?= =?UTF-8?q?=E4=B8=AAepoch=E7=9A=84=E9=9A=8F=E6=9C=BA=E6=95=B0=E6=9B=B4?= =?UTF-8?q?=E6=8D=A2=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_module.py | 9 ++++++--- main.py | 17 +++++++++-------- save_checkpoint.py | 16 +++++++--------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/data_module.py b/data_module.py index bab3e3f..e0f0327 100644 --- a/data_module.py +++ b/data_module.py @@ -42,13 +42,16 @@ class DataModule(pl.LightningDataModule): self.test_dataset = CustomDataset(self.x, self.y, self.config) def train_dataloader(self): - return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers) + return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, + pin_memory=True) def val_dataloader(self): - return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers) + return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, + pin_memory=True) def test_dataloader(self): - return DataLoader(self.test_dataset, batch_size=1, shuffle=False, num_workers=self.num_workers) + return DataLoader(self.test_dataset, batch_size=1, shuffle=False, num_workers=self.num_workers, + pin_memory=True) class CustomDataset(Dataset): diff --git a/main.py b/main.py index 9968b36..580003a 100644 --- a/main.py +++ b/main.py @@ -3,10 +3,10 @@ from data_module import DataModule from pytorch_lightning import loggers as pl_loggers import pytorch_lightning as pl from train_model import TrainModule +from multiprocessing import cpu_count def main(stage, - num_workers, max_epochs, batch_size, precision, @@ -24,7 +24,6 @@ def main(stage, 该函数的参数为训练过程中需要经常改动的参数 :param stage: 表示处于训练阶段还是测试阶段, fit表示训练, test表示测试 - :param num_workers: :param max_epochs: :param batch_size: :param precision: 训练精度, 正常精度为32, 半精度为16, 也可以是64. 精度代表每个参数的类型所占的位数 @@ -50,7 +49,7 @@ def main(stage, 'res_coef': 0.5, 'dropout_p': 0.1, 'n_layers': 2, - 'dataset_len': 10000, + 'dataset_len': 100000, 'flag': True} else: config = {'dataset_path': dataset_path, @@ -59,20 +58,21 @@ def main(stage, 'res_coef': 0.5, 'dropout_p': 0.1, 'n_layers': 20, - 'dataset_len': 10000, + 'dataset_len': 100000, 'flag': False} # TODO 获得最优的batch size - # TODO 自动获取CPU核心数并设置num workers + num_workers = cpu_count() precision = 32 if (gpus is None and tpu_cores is None) else precision dm = DataModule(batch_size=batch_size, num_workers=num_workers, config=config) logger = pl_loggers.TensorBoardLogger('logs/') if stage == 'fit': - training_module = TrainModule(config=config) + # SaveCheckpoint的创建需要在TrainModule之前, 以保证网络参数初始化的确定性 save_checkpoint = SaveCheckpoint(seed=seed, max_epochs=max_epochs, save_name=save_name, path_final_save=path_final_save, every_n_epochs=every_n_epochs, verbose=True, monitor='Validation loss', save_top_k=save_top_k, mode='min') + training_module = TrainModule(config=config) if load_checkpoint_path is None: print('进行初始训练') trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores, @@ -83,6 +83,7 @@ def main(stage, trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores, resume_from_checkpoint='./logs/default' + load_checkpoint_path, logger=logger, precision=precision, callbacks=[save_checkpoint]) + print('训练过程中请注意gpu利用率等情况') trainer.fit(training_module, datamodule=dm) if stage == 'test': if load_checkpoint_path is None: @@ -98,7 +99,7 @@ def main(stage, if __name__ == "__main__": - main('test', num_workers=8, max_epochs=5, batch_size=32, precision=16, seed=1234, + main('fit', max_epochs=5, batch_size=32, precision=16, seed=1234, # gpus=1, - load_checkpoint_path='/version_4/checkpoints/epoch=4-step=7814.ckpt', + # load_checkpoint_path='/version_4/checkpoints/epoch=4-step=7814.ckpt', ) diff --git a/save_checkpoint.py b/save_checkpoint.py index ba28633..31ff475 100644 --- a/save_checkpoint.py +++ b/save_checkpoint.py @@ -1,6 +1,7 @@ import os + +import numpy.random from pytorch_lightning.callbacks import ModelCheckpoint -import pytorch_lightning import pytorch_lightning as pl import shutil import random @@ -35,12 +36,9 @@ class SaveCheckpoint(ModelCheckpoint): :param no_save_before_epoch: """ super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode) - random.seed(seed) - self.seeds = [] - for i in range(max_epochs): - self.seeds.append(random.randint(0, 2000)) - self.seeds.append(0) - pytorch_lightning.seed_everything(seed) + numpy.random.seed(seed) + self.seeds = numpy.random.randint(0, 2000, max_epochs) + pl.seed_everything(seed) self.save_name = save_name self.path_final_save = path_final_save self.monitor = monitor @@ -57,11 +55,11 @@ class SaveCheckpoint(ModelCheckpoint): :param pl_module: :return: """ + # 第一个epoch使用原始输入seed作为种子, 后续的epoch使用seeds中的第epoch-1个作为种子 if self.flag_sanity_check == 0: - pytorch_lightning.seed_everything(self.seeds[trainer.current_epoch]) self.flag_sanity_check = 1 else: - pytorch_lightning.seed_everything(self.seeds[trainer.current_epoch + 1]) + pl.seed_everything(self.seeds[trainer.current_epoch]) super().on_validation_end(trainer, pl_module) def _save_top_k_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates) -> None: