diff --git a/data_module.py b/data_module.py index 9a32542..76a841d 100644 --- a/data_module.py +++ b/data_module.py @@ -31,7 +31,10 @@ class DataModule(pl.LightningDataModule): if not os.path.exists(self.dataset_path + '/dataset_list.txt'): x = torch.randn(self.config['dataset_len'], self.config['dim_in']) noise = torch.randn(self.config['dataset_len']) - y = torch.cos(1.5 * x[:, 0]) * (x[:, 1] ** 2.0) + noise + y = torch.cos(1.5 * x[:, 0]) * (x[:, 1] ** 2.0) + torch.cos(torch.sin(x[:, 2] ** 3)) + torch.arctan( + x[:, 4]) + noise + assert (x[torch.isnan(x)].shape[0] == 0) + assert (y[torch.isnan(y)].shape[0] == 0) with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: for line in range(self.config['dataset_len']): f.write(' '.join([str(temp) for temp in x[line].tolist()]) + ' ' + str(y[line].item()) + '\n') diff --git a/main.py b/main.py index de437ca..7a24c35 100644 --- a/main.py +++ b/main.py @@ -54,11 +54,11 @@ def main(stage, # TODO 获得最优的batch size num_workers = cpu_count() # 获得非通用参数 - config = {'dim_in': 2, + config = {'dim_in': 5, 'dim': 10, 'res_coef': 0.5, 'dropout_p': 0.1, - 'n_layers': 2, + 'n_layers': 3, 'dataset_len': 100000} for kth_fold in range(kth_fold_start, k_fold): load_checkpoint_path = get_ckpt_path(version_nth, kth_fold) diff --git a/save_checkpoint.py b/save_checkpoint.py index 6baa33d..adb05de 100644 --- a/save_checkpoint.py +++ b/save_checkpoint.py @@ -66,6 +66,31 @@ class SaveCheckpoint(ModelCheckpoint): if self.check_monitor_top_k(trainer, current): self._update_best_and_save(current, trainer, monitor_candidates) + best_model_value = max([float(item) for item in list(self.best_k_models.values())]) + # 保存版本信息(准确率等)到txt中 + if not os.path.exists('./logs/default/version_info.txt'): + with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: + f.write(self.dirpath.split('\\')[1] + ' ' + str(best_model_value) + '\n') + else: + with open('./logs/default/version_info.txt', 'r', encoding='utf-8') as f: + info_list = f.readlines() + info_list = [item.strip('\n').split(' ') for item in info_list] + # 对list进行转置, 现在行为版本号和其数据, 列为不同的版本 + info_list = list(map(list, zip(*info_list))) + if self.dirpath.split('\\')[1] in info_list[0]: + for cou in range(len(info_list[0])): + if self.dirpath.split('\\')[1] == info_list[0][cou]: + info_list[1][cou] = str(best_model_value) + else: + info_list[0].append(self.dirpath.split('\\')[1]) + info_list[1].append(str(best_model_value)) + # 对list进行转置 + info_list = list(map(list, zip(*info_list))) + with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: + for line in info_list: + line = " ".join(line) + f.write(line + '\n') + # 每次更新ckpt文件后, 将其存放到另一个位置 if self.path_final_save is not None: zip_dir('./logs', './logs.zip') if os.path.exists(self.path_final_save + '/logs.zip'): diff --git a/utils.py b/utils.py index 8e43915..f3b2ffa 100644 --- a/utils.py +++ b/utils.py @@ -2,10 +2,8 @@ import glob import os import random -import string import zipfile import cv2 -import numpy import torch