From f23c6ce8541add580dd96ccdd2d4b41186c45120 Mon Sep 17 00:00:00 2001 From: shenyan <23357320@qq.com> Date: Fri, 15 Oct 2021 19:24:56 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=BE=97=E5=88=B0=E6=95=B0?= =?UTF-8?q?=E6=8D=AElist=E7=9A=84=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_module.py | 5 +++-- utils.py | 7 +++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/data_module.py b/data_module.py index 38ba325..831a695 100644 --- a/data_module.py +++ b/data_module.py @@ -34,10 +34,11 @@ class DataModule(pl.LightningDataModule): dataset[:, self.config['dim_in']] = torch.cos(1.5 * dataset[:, 0]) * (dataset[:, 1] ** 2.0) + torch.cos( torch.sin(dataset[:, 2] ** 3)) + torch.arctan(dataset[:, 4]) + noise assert (dataset[torch.isnan(dataset)].shape[0] == 0) + written = [' '.join([str(temp) for temp in dataset[cou, :].tolist()]) for cou in range(dataset.shape[0])] with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: - for line in range(self.config['dataset_len']): - f.write(' '.join([str(temp) for temp in dataset[line].tolist()]) + '\n') + for line in written: + f.write(line + '\n') print('已生成新的数据list') else: dataset_list = open(self.dataset_path + '/dataset_list.txt').readlines() diff --git a/utils.py b/utils.py index f3b2ffa..41e8499 100644 --- a/utils.py +++ b/utils.py @@ -11,9 +11,12 @@ def get_dataset_list(dataset_path): if not os.path.exists(dataset_path + '/dataset_list.txt'): all_list = glob.glob(dataset_path + '/labels' + '/*.png') random.shuffle(all_list) + all_list = [os.path.basename(item.replace('\\', '/')) for item in all_list] + written = all_list + with open(dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: - for line in all_list: - f.write(os.path.basename(line.replace('\\', '/')) + '\n') + for line in written: + f.write(line + '\n') print('已生成新的数据list') return all_list else: