Browse Source

使用较为复杂的拟合网络; 将训练结果输出至文件;

master
shenyan 4 years ago
parent
commit
436625c6ff
4 changed files with 31 additions and 5 deletions
  1. +4
    -1
      data_module.py
  2. +2
    -2
      main.py
  3. +25
    -0
      save_checkpoint.py
  4. +0
    -2
      utils.py

+ 4
- 1
data_module.py View File

@@ -31,7 +31,10 @@ class DataModule(pl.LightningDataModule):
if not os.path.exists(self.dataset_path + '/dataset_list.txt'): if not os.path.exists(self.dataset_path + '/dataset_list.txt'):
x = torch.randn(self.config['dataset_len'], self.config['dim_in']) x = torch.randn(self.config['dataset_len'], self.config['dim_in'])
noise = torch.randn(self.config['dataset_len']) noise = torch.randn(self.config['dataset_len'])
y = torch.cos(1.5 * x[:, 0]) * (x[:, 1] ** 2.0) + noise
y = torch.cos(1.5 * x[:, 0]) * (x[:, 1] ** 2.0) + torch.cos(torch.sin(x[:, 2] ** 3)) + torch.arctan(
x[:, 4]) + noise
assert (x[torch.isnan(x)].shape[0] == 0)
assert (y[torch.isnan(y)].shape[0] == 0)
with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f:
for line in range(self.config['dataset_len']): for line in range(self.config['dataset_len']):
f.write(' '.join([str(temp) for temp in x[line].tolist()]) + ' ' + str(y[line].item()) + '\n') f.write(' '.join([str(temp) for temp in x[line].tolist()]) + ' ' + str(y[line].item()) + '\n')


+ 2
- 2
main.py View File

@@ -54,11 +54,11 @@ def main(stage,
# TODO 获得最优的batch size # TODO 获得最优的batch size
num_workers = cpu_count() num_workers = cpu_count()
# 获得非通用参数 # 获得非通用参数
config = {'dim_in': 2,
config = {'dim_in': 5,
'dim': 10, 'dim': 10,
'res_coef': 0.5, 'res_coef': 0.5,
'dropout_p': 0.1, 'dropout_p': 0.1,
'n_layers': 2,
'n_layers': 3,
'dataset_len': 100000} 'dataset_len': 100000}
for kth_fold in range(kth_fold_start, k_fold): for kth_fold in range(kth_fold_start, k_fold):
load_checkpoint_path = get_ckpt_path(version_nth, kth_fold) load_checkpoint_path = get_ckpt_path(version_nth, kth_fold)


+ 25
- 0
save_checkpoint.py View File

@@ -66,6 +66,31 @@ class SaveCheckpoint(ModelCheckpoint):


if self.check_monitor_top_k(trainer, current): if self.check_monitor_top_k(trainer, current):
self._update_best_and_save(current, trainer, monitor_candidates) self._update_best_and_save(current, trainer, monitor_candidates)
best_model_value = max([float(item) for item in list(self.best_k_models.values())])
# 保存版本信息(准确率等)到txt中
if not os.path.exists('./logs/default/version_info.txt'):
with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f:
f.write(self.dirpath.split('\\')[1] + ' ' + str(best_model_value) + '\n')
else:
with open('./logs/default/version_info.txt', 'r', encoding='utf-8') as f:
info_list = f.readlines()
info_list = [item.strip('\n').split(' ') for item in info_list]
# 对list进行转置, 现在行为版本号和其数据, 列为不同的版本
info_list = list(map(list, zip(*info_list)))
if self.dirpath.split('\\')[1] in info_list[0]:
for cou in range(len(info_list[0])):
if self.dirpath.split('\\')[1] == info_list[0][cou]:
info_list[1][cou] = str(best_model_value)
else:
info_list[0].append(self.dirpath.split('\\')[1])
info_list[1].append(str(best_model_value))
# 对list进行转置
info_list = list(map(list, zip(*info_list)))
with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f:
for line in info_list:
line = " ".join(line)
f.write(line + '\n')
# 每次更新ckpt文件后, 将其存放到另一个位置
if self.path_final_save is not None: if self.path_final_save is not None:
zip_dir('./logs', './logs.zip') zip_dir('./logs', './logs.zip')
if os.path.exists(self.path_final_save + '/logs.zip'): if os.path.exists(self.path_final_save + '/logs.zip'):


+ 0
- 2
utils.py View File

@@ -2,10 +2,8 @@
import glob import glob
import os import os
import random import random
import string
import zipfile import zipfile
import cv2 import cv2
import numpy
import torch import torch






Loading…
Cancel
Save