|
|
|
@@ -3,10 +3,10 @@ from data_module import DataModule |
|
|
|
from pytorch_lightning import loggers as pl_loggers |
|
|
|
import pytorch_lightning as pl |
|
|
|
from train_model import TrainModule |
|
|
|
from multiprocessing import cpu_count |
|
|
|
|
|
|
|
|
|
|
|
def main(stage, |
|
|
|
num_workers, |
|
|
|
max_epochs, |
|
|
|
batch_size, |
|
|
|
precision, |
|
|
|
@@ -24,7 +24,6 @@ def main(stage, |
|
|
|
该函数的参数为训练过程中需要经常改动的参数 |
|
|
|
|
|
|
|
:param stage: 表示处于训练阶段还是测试阶段, fit表示训练, test表示测试 |
|
|
|
:param num_workers: |
|
|
|
:param max_epochs: |
|
|
|
:param batch_size: |
|
|
|
:param precision: 训练精度, 正常精度为32, 半精度为16, 也可以是64. 精度代表每个参数的类型所占的位数 |
|
|
|
@@ -50,7 +49,7 @@ def main(stage, |
|
|
|
'res_coef': 0.5, |
|
|
|
'dropout_p': 0.1, |
|
|
|
'n_layers': 2, |
|
|
|
'dataset_len': 10000, |
|
|
|
'dataset_len': 100000, |
|
|
|
'flag': True} |
|
|
|
else: |
|
|
|
config = {'dataset_path': dataset_path, |
|
|
|
@@ -59,20 +58,21 @@ def main(stage, |
|
|
|
'res_coef': 0.5, |
|
|
|
'dropout_p': 0.1, |
|
|
|
'n_layers': 20, |
|
|
|
'dataset_len': 10000, |
|
|
|
'dataset_len': 100000, |
|
|
|
'flag': False} |
|
|
|
# TODO 获得最优的batch size |
|
|
|
# TODO 自动获取CPU核心数并设置num workers |
|
|
|
num_workers = cpu_count() |
|
|
|
precision = 32 if (gpus is None and tpu_cores is None) else precision |
|
|
|
dm = DataModule(batch_size=batch_size, num_workers=num_workers, config=config) |
|
|
|
logger = pl_loggers.TensorBoardLogger('logs/') |
|
|
|
if stage == 'fit': |
|
|
|
training_module = TrainModule(config=config) |
|
|
|
# SaveCheckpoint的创建需要在TrainModule之前, 以保证网络参数初始化的确定性 |
|
|
|
save_checkpoint = SaveCheckpoint(seed=seed, max_epochs=max_epochs, |
|
|
|
save_name=save_name, path_final_save=path_final_save, |
|
|
|
every_n_epochs=every_n_epochs, verbose=True, |
|
|
|
monitor='Validation loss', save_top_k=save_top_k, |
|
|
|
mode='min') |
|
|
|
training_module = TrainModule(config=config) |
|
|
|
if load_checkpoint_path is None: |
|
|
|
print('进行初始训练') |
|
|
|
trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores, |
|
|
|
@@ -83,6 +83,7 @@ def main(stage, |
|
|
|
trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, tpu_cores=tpu_cores, |
|
|
|
resume_from_checkpoint='./logs/default' + load_checkpoint_path, |
|
|
|
logger=logger, precision=precision, callbacks=[save_checkpoint]) |
|
|
|
print('训练过程中请注意gpu利用率等情况') |
|
|
|
trainer.fit(training_module, datamodule=dm) |
|
|
|
if stage == 'test': |
|
|
|
if load_checkpoint_path is None: |
|
|
|
@@ -98,7 +99,7 @@ def main(stage, |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
main('test', num_workers=8, max_epochs=5, batch_size=32, precision=16, seed=1234, |
|
|
|
main('fit', max_epochs=5, batch_size=32, precision=16, seed=1234, |
|
|
|
# gpus=1, |
|
|
|
load_checkpoint_path='/version_4/checkpoints/epoch=4-step=7814.ckpt', |
|
|
|
# load_checkpoint_path='/version_4/checkpoints/epoch=4-step=7814.ckpt', |
|
|
|
) |