| @@ -31,7 +31,10 @@ class DataModule(pl.LightningDataModule): | |||||
| if not os.path.exists(self.dataset_path + '/dataset_list.txt'): | if not os.path.exists(self.dataset_path + '/dataset_list.txt'): | ||||
| x = torch.randn(self.config['dataset_len'], self.config['dim_in']) | x = torch.randn(self.config['dataset_len'], self.config['dim_in']) | ||||
| noise = torch.randn(self.config['dataset_len']) | noise = torch.randn(self.config['dataset_len']) | ||||
| y = torch.cos(1.5 * x[:, 0]) * (x[:, 1] ** 2.0) + noise | |||||
| y = torch.cos(1.5 * x[:, 0]) * (x[:, 1] ** 2.0) + torch.cos(torch.sin(x[:, 2] ** 3)) + torch.arctan( | |||||
| x[:, 4]) + noise | |||||
| assert (x[torch.isnan(x)].shape[0] == 0) | |||||
| assert (y[torch.isnan(y)].shape[0] == 0) | |||||
| with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: | with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: | ||||
| for line in range(self.config['dataset_len']): | for line in range(self.config['dataset_len']): | ||||
| f.write(' '.join([str(temp) for temp in x[line].tolist()]) + ' ' + str(y[line].item()) + '\n') | f.write(' '.join([str(temp) for temp in x[line].tolist()]) + ' ' + str(y[line].item()) + '\n') | ||||
| @@ -54,11 +54,11 @@ def main(stage, | |||||
| # TODO 获得最优的batch size | # TODO 获得最优的batch size | ||||
| num_workers = cpu_count() | num_workers = cpu_count() | ||||
| # 获得非通用参数 | # 获得非通用参数 | ||||
| config = {'dim_in': 2, | |||||
| config = {'dim_in': 5, | |||||
| 'dim': 10, | 'dim': 10, | ||||
| 'res_coef': 0.5, | 'res_coef': 0.5, | ||||
| 'dropout_p': 0.1, | 'dropout_p': 0.1, | ||||
| 'n_layers': 2, | |||||
| 'n_layers': 3, | |||||
| 'dataset_len': 100000} | 'dataset_len': 100000} | ||||
| for kth_fold in range(kth_fold_start, k_fold): | for kth_fold in range(kth_fold_start, k_fold): | ||||
| load_checkpoint_path = get_ckpt_path(version_nth, kth_fold) | load_checkpoint_path = get_ckpt_path(version_nth, kth_fold) | ||||
| @@ -66,6 +66,31 @@ class SaveCheckpoint(ModelCheckpoint): | |||||
| if self.check_monitor_top_k(trainer, current): | if self.check_monitor_top_k(trainer, current): | ||||
| self._update_best_and_save(current, trainer, monitor_candidates) | self._update_best_and_save(current, trainer, monitor_candidates) | ||||
| best_model_value = max([float(item) for item in list(self.best_k_models.values())]) | |||||
| # 保存版本信息(准确率等)到txt中 | |||||
| if not os.path.exists('./logs/default/version_info.txt'): | |||||
| with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: | |||||
| f.write(self.dirpath.split('\\')[1] + ' ' + str(best_model_value) + '\n') | |||||
| else: | |||||
| with open('./logs/default/version_info.txt', 'r', encoding='utf-8') as f: | |||||
| info_list = f.readlines() | |||||
| info_list = [item.strip('\n').split(' ') for item in info_list] | |||||
| # 对list进行转置, 现在行为版本号和其数据, 列为不同的版本 | |||||
| info_list = list(map(list, zip(*info_list))) | |||||
| if self.dirpath.split('\\')[1] in info_list[0]: | |||||
| for cou in range(len(info_list[0])): | |||||
| if self.dirpath.split('\\')[1] == info_list[0][cou]: | |||||
| info_list[1][cou] = str(best_model_value) | |||||
| else: | |||||
| info_list[0].append(self.dirpath.split('\\')[1]) | |||||
| info_list[1].append(str(best_model_value)) | |||||
| # 对list进行转置 | |||||
| info_list = list(map(list, zip(*info_list))) | |||||
| with open('./logs/default/version_info.txt', 'w', encoding='utf-8') as f: | |||||
| for line in info_list: | |||||
| line = " ".join(line) | |||||
| f.write(line + '\n') | |||||
| # 每次更新ckpt文件后, 将其存放到另一个位置 | |||||
| if self.path_final_save is not None: | if self.path_final_save is not None: | ||||
| zip_dir('./logs', './logs.zip') | zip_dir('./logs', './logs.zip') | ||||
| if os.path.exists(self.path_final_save + '/logs.zip'): | if os.path.exists(self.path_final_save + '/logs.zip'): | ||||
| @@ -2,10 +2,8 @@ | |||||
| import glob | import glob | ||||
| import os | import os | ||||
| import random | import random | ||||
| import string | |||||
| import zipfile | import zipfile | ||||
| import cv2 | import cv2 | ||||
| import numpy | |||||
| import torch | import torch | ||||