
Switch to pin_memory=True; add a note about watching GPU performance during training; automatically detect the CPU core count and use it as num_workers; simplify the per-epoch random-seed rotation logic

master
shenyan 4 years ago
parent commit b114979ea5
3 changed files with 22 additions and 20 deletions:
  1. data_module.py (+6, -3)
  2. main.py (+9, -8)
  3. save_checkpoint.py (+7, -9)

data_module.py (+6, -3)

@@ -42,13 +42,16 @@ class DataModule(pl.LightningDataModule):
         self.test_dataset = CustomDataset(self.x, self.y, self.config)
 
     def train_dataloader(self):
-        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
+        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers,
+                          pin_memory=True)
 
     def val_dataloader(self):
-        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
+        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers,
+                          pin_memory=True)
 
     def test_dataloader(self):
-        return DataLoader(self.test_dataset, batch_size=1, shuffle=False, num_workers=self.num_workers)
+        return DataLoader(self.test_dataset, batch_size=1, shuffle=False, num_workers=self.num_workers,
+                          pin_memory=True)
 
 
 class CustomDataset(Dataset):
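
Why pin_memory matters: with pin_memory=True the DataLoader collates each batch into page-locked (pinned) host memory, which lets the subsequent host-to-GPU copy run asynchronously. The overlap only materializes if the transfer is issued with non_blocking=True. A minimal sketch of the pattern, using a stand-in TensorDataset rather than this repository's CustomDataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset; the real data comes from CustomDataset in data_module.py.
dataset = TensorDataset(torch.randn(256, 8), torch.randn(256, 1))
loader = DataLoader(dataset, batch_size=32, num_workers=2, pin_memory=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for x, y in loader:
    # Pinned batches allow these copies to run asynchronously and overlap
    # with GPU compute on the previous batch.
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)

(PyTorch Lightning issues the device transfer itself, so in this repository only the pin_memory flag is needed; the loop above just illustrates what the flag buys.)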


main.py (+9, -8)

@@ -3,10 +3,10 @@ from data_module import DataModule
 from pytorch_lightning import loggers as pl_loggers
 import pytorch_lightning as pl
 from train_model import TrainModule
+from multiprocessing import cpu_count
 
 
 def main(stage,
-         num_workers,
          max_epochs,
          batch_size,
          precision,
@@ -24,7 +24,6 @@ def main(stage,
     The parameters of this function are the ones that change frequently during training.
 
     :param stage: whether we are training or testing; 'fit' means training, 'test' means testing
-    :param num_workers:
     :param max_epochs:
     :param batch_size:
     :param precision: training precision; full precision is 32, half precision is 16, and 64 is also possible. The precision is the number of bits occupied by each parameter's dtype.
@@ -50,7 +49,7 @@ def main(stage,
                   'res_coef': 0.5,
                   'dropout_p': 0.1,
                   'n_layers': 2,
-                  'dataset_len': 10000,
+                  'dataset_len': 100000,
                   'flag': True}
     else:
         config = {'dataset_path': dataset_path,
@@ -59,20 +58,21 @@ def main(stage,
                   'res_coef': 0.5,
                   'dropout_p': 0.1,
                   'n_layers': 20,
-                  'dataset_len': 10000,
+                  'dataset_len': 100000,
                   'flag': False}
     # TODO find the optimal batch size
-    # TODO automatically detect the CPU core count and set num_workers
+    num_workers = cpu_count()
     precision = 32 if (gpus is None and tpu_cores is None) else precision
     dm = DataModule(batch_size=batch_size, num_workers=num_workers, config=config)
     logger = pl_loggers.TensorBoardLogger('logs/')
     if stage == 'fit':
-        training_module = TrainModule(config=config)
+        # SaveCheckpoint must be created before TrainModule so that the
+        # network's parameter initialization stays deterministic.
         save_checkpoint = SaveCheckpoint(seed=seed, max_epochs=max_epochs,
                                          save_name=save_name, path_final_save=path_final_save,
                                          every_n_epochs=every_n_epochs, verbose=True,
                                          monitor='Validation loss', save_top_k=save_top_k,
                                          mode='min')
+        training_module = TrainModule(config=config)
         if load_checkpoint_path is None:
             print('Starting initial training')
             trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
@@ -83,6 +83,7 @@ def main(stage,
             trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores,
                                  resume_from_checkpoint='./logs/default' + load_checkpoint_path,
                                  logger=logger, precision=precision, callbacks=[save_checkpoint])
+        print('During training, keep an eye on GPU utilization and related metrics')
         trainer.fit(training_module, datamodule=dm)
     if stage == 'test':
         if load_checkpoint_path is None:
@@ -98,7 +99,7 @@ def main(stage,


if __name__ == "__main__":
main('test', num_workers=8, max_epochs=5, batch_size=32, precision=16, seed=1234,
main('fit', max_epochs=5, batch_size=32, precision=16, seed=1234,
# gpus=1,
load_checkpoint_path='/version_4/checkpoints/epoch=4-step=7814.ckpt',
# load_checkpoint_path='/version_4/checkpoints/epoch=4-step=7814.ckpt',
)
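
A caveat on num_workers = cpu_count(): multiprocessing.cpu_count() reports every logical core on the machine, which can oversubscribe a shared host or exceed what the data pipeline actually benefits from. A common hedge (not part of this commit) is to cap the value:

from multiprocessing import cpu_count

MAX_WORKERS = 8  # hypothetical cap; tune per machine
num_workers = min(cpu_count(), MAX_WORKERS)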

save_checkpoint.py (+7, -9)

@@ -1,6 +1,7 @@
 import os
 
+import numpy.random
 from pytorch_lightning.callbacks import ModelCheckpoint
-import pytorch_lightning
+import pytorch_lightning as pl
 import shutil
 import random
@@ -35,12 +36,9 @@ class SaveCheckpoint(ModelCheckpoint):
         :param no_save_before_epoch:
         """
         super().__init__(every_n_epochs=every_n_epochs, verbose=verbose, mode=mode)
-        random.seed(seed)
-        self.seeds = []
-        for i in range(max_epochs):
-            self.seeds.append(random.randint(0, 2000))
-        self.seeds.append(0)
-        pytorch_lightning.seed_everything(seed)
+        numpy.random.seed(seed)
+        self.seeds = numpy.random.randint(0, 2000, max_epochs)
+        pl.seed_everything(seed)
         self.save_name = save_name
         self.path_final_save = path_final_save
         self.monitor = monitor
@@ -57,11 +55,11 @@ class SaveCheckpoint(ModelCheckpoint):
         :param pl_module:
         :return:
         """
-        # The first epoch uses the original input seed; each later epoch uses the (epoch-1)-th entry of seeds
-        if self.flag_sanity_check == 0:
-            pytorch_lightning.seed_everything(self.seeds[trainer.current_epoch])
-            self.flag_sanity_check = 1
-        else:
-            pytorch_lightning.seed_everything(self.seeds[trainer.current_epoch + 1])
+        pl.seed_everything(self.seeds[trainer.current_epoch])
         super().on_validation_end(trainer, pl_module)
 
     def _save_top_k_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates) -> None:
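
The seed handling now boils down to one deterministic sequence: numpy.random is seeded once with the input seed, a fixed array of per-epoch seeds is drawn up front, and every on_validation_end reseeds with seeds[current_epoch], replacing the old sanity-check branching. A standalone sketch of the idea (not the repository's actual callback):

import numpy as np
import pytorch_lightning as pl

seed, max_epochs = 1234, 5

np.random.seed(seed)                             # derive one reproducible sequence
epoch_seeds = np.random.randint(0, 2000, max_epochs)

pl.seed_everything(seed)                         # epoch 0 runs under the input seed
for epoch in range(max_epochs):
    # ... train and validate one epoch ...
    pl.seed_everything(int(epoch_seeds[epoch]))  # reseed for the next epoch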

