Browse Source

使用较为复杂的拟合网络; 将训练结果输出至文件;

master
shenyan 4 years ago
parent
commit
436625c6ff
4 changed files with 31 additions and 5 deletions
  1. +4
    -1
      data_module.py
  2. +2
    -2
      main.py
  3. +25
    -0
      save_checkpoint.py
  4. +0
    -2
      utils.py

+ 4
- 1
data_module.py View File

@@ -31,7 +31,10 @@ class DataModule(pl.LightningDataModule):
if not os.path.exists(self.dataset_path + '/dataset_list.txt'):
x = torch.randn(self.config['dataset_len'], self.config['dim_in'])
noise = torch.randn(self.config['dataset_len'])
y = torch.cos(1.5 * x[:, 0]) * (x[:, 1] ** 2.0) + noise
y = torch.cos(1.5 * x[:, 0]) * (x[:, 1] ** 2.0) + torch.cos(torch.sin(x[:, 2] ** 3)) + torch.arctan(
x[:, 4]) + noise
assert (x[torch.isnan(x)].shape[0] == 0)
assert (y[torch.isnan(y)].shape[0] == 0)
with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f:
for line in range(self.config['dataset_len']):
f.write(' '.join([str(temp) for temp in x[line].tolist()]) + ' ' + str(y[line].item()) + '\n')


+ 2
- 2
main.py View File

@@ -54,11 +54,11 @@ def main(stage,
# TODO 获得最优的batch size
num_workers = cpu_count()
# 获得非通用参数
config = {'dim_in': 2,
config = {'dim_in': 5,
'dim': 10,
'res_coef': 0.5,
'dropout_p': 0.1,
'n_layers': 2,
'n_layers': 3,
'dataset_len': 100000}
for kth_fold in range(kth_fold_start, k_fold):
load_checkpoint_path = get_ckpt_path(version_nth, kth_fold)


+ 25
- 0
save_checkpoint.py View File

@@ -66,6 +66,31 @@ class SaveCheckpoint(ModelCheckpoint):

if self.check_monitor_top_k(trainer, current):
self._update_best_and_save(current, trainer, monitor_candidates)
best_model_value = max([float(item) for item in list(self.best_k_models.values())])
# 保存版本信息(准确率等)到txt中
if not os.path.exists('./logs/default/version_info.txt'):
with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f:
f.write(self.dirpath.split('\\')[1] + ' ' + str(best_model_value) + '\n')
else:
with open('./logs/default/version_info.txt', 'r', encoding='utf-8') as f:
info_list = f.readlines()
info_list = [item.strip('\n').split(' ') for item in info_list]
# 对list进行转置, 现在行为版本号和其数据, 列为不同的版本
info_list = list(map(list, zip(*info_list)))
if self.dirpath.split('\\')[1] in info_list[0]:
for cou in range(len(info_list[0])):
if self.dirpath.split('\\')[1] == info_list[0][cou]:
info_list[1][cou] = str(best_model_value)
else:
info_list[0].append(self.dirpath.split('\\')[1])
info_list[1].append(str(best_model_value))
# 对list进行转置
info_list = list(map(list, zip(*info_list)))
with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f:
for line in info_list:
line = " ".join(line)
f.write(line + '\n')
# 每次更新ckpt文件后, 将其存放到另一个位置
if self.path_final_save is not None:
zip_dir('./logs', './logs.zip')
if os.path.exists(self.path_final_save + '/logs.zip'):


+ 0
- 2
utils.py View File

@@ -2,10 +2,8 @@
import glob
import os
import random
import string
import zipfile
import cv2
import numpy
import torch




Loading…
Cancel
Save