| @@ -18,45 +18,44 @@ class DataModule(pl.LightningDataModule): | |||||
| def setup(self, stage=None) -> None: | def setup(self, stage=None) -> None: | ||||
| # 得到全部数据的list | # 得到全部数据的list | ||||
| # dataset_list = get_dataset_list(dataset_path) | |||||
| x, y = self.get_fit_dataset_list() | |||||
| dataset_list = self.get_dataset_list() | |||||
| if stage == 'fit' or stage is None: | if stage == 'fit' or stage is None: | ||||
| x_train, y_train, x_val, y_val = self.get_dataset_lists(x, y) | |||||
| self.train_dataset = CustomDataset(x_train, y_train, self.config) | |||||
| self.val_dataset = CustomDataset(x_val, y_val, self.config) | |||||
| dataset_train, dataset_val = self.get_dataset_lists(dataset_list) | |||||
| self.train_dataset = CustomDataset(dataset_train, self.config) | |||||
| self.val_dataset = CustomDataset(dataset_val, self.config) | |||||
| if stage == 'test' or stage is None: | if stage == 'test' or stage is None: | ||||
| self.test_dataset = CustomDataset(x, y, self.config) | |||||
| self.test_dataset = CustomDataset(dataset_list, self.config) | |||||
| def get_fit_dataset_list(self): | |||||
| def get_dataset_list(self): | |||||
| if not os.path.exists(self.dataset_path + '/dataset_list.txt'): | if not os.path.exists(self.dataset_path + '/dataset_list.txt'): | ||||
| x = torch.randn(self.config['dataset_len'], self.config['dim_in']) | |||||
| # 针对数据拟合获得dataset | |||||
| dataset = torch.randn(self.config['dataset_len'], self.config['dim_in'] + 1) | |||||
| noise = torch.randn(self.config['dataset_len']) | noise = torch.randn(self.config['dataset_len']) | ||||
| y = torch.cos(1.5 * x[:, 0]) * (x[:, 1] ** 2.0) + torch.cos(torch.sin(x[:, 2] ** 3)) + torch.arctan( | |||||
| x[:, 4]) + noise | |||||
| assert (x[torch.isnan(x)].shape[0] == 0) | |||||
| assert (y[torch.isnan(y)].shape[0] == 0) | |||||
| dataset[:, self.config['dim_in']] = torch.cos(1.5 * dataset[:, 0]) * (dataset[:, 1] ** 2.0) + torch.cos( | |||||
| torch.sin(dataset[:, 2] ** 3)) + torch.arctan(dataset[:, 4]) + noise | |||||
| assert (dataset[torch.isnan(dataset)].shape[0] == 0) | |||||
| with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: | with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: | ||||
| for line in range(self.config['dataset_len']): | for line in range(self.config['dataset_len']): | ||||
| f.write(' '.join([str(temp) for temp in x[line].tolist()]) + ' ' + str(y[line].item()) + '\n') | |||||
| f.write(' '.join([str(temp) for temp in dataset[line].tolist()]) + '\n') | |||||
| print('已生成新的数据list') | print('已生成新的数据list') | ||||
| else: | else: | ||||
| dataset_list = open(self.dataset_path + '/dataset_list.txt').readlines() | dataset_list = open(self.dataset_path + '/dataset_list.txt').readlines() | ||||
| # 针对数据拟合获得dataset | |||||
| dataset_list = [[float(temp) for temp in item.strip('\n').split(' ')] for item in dataset_list] | dataset_list = [[float(temp) for temp in item.strip('\n').split(' ')] for item in dataset_list] | ||||
| x = torch.from_numpy(numpy.array(dataset_list)[:, 0:self.config['dim_in']]).float() | |||||
| y = torch.from_numpy(numpy.array(dataset_list)[:, self.config['dim_in']]).float() | |||||
| return x, y | |||||
| dataset = torch.Tensor(dataset_list).float() | |||||
| return dataset | |||||
| def get_dataset_lists(self, x: Tensor, y): | |||||
| def get_dataset_lists(self, dataset_list: Tensor): | |||||
| # 得到一个fold的数据量和不够组成一个fold的剩余数据的数据量 | # 得到一个fold的数据量和不够组成一个fold的剩余数据的数据量 | ||||
| num_1fold, remainder = divmod(self.config['dataset_len'], self.k_fold) | num_1fold, remainder = divmod(self.config['dataset_len'], self.k_fold) | ||||
| # 分割全部数据, 得到训练集, 验证集, 测试集 | # 分割全部数据, 得到训练集, 验证集, 测试集 | ||||
| x_val = x[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder)] | |||||
| y_val = y[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder)] | |||||
| temp = torch.ones(x.shape[0]) | |||||
| dataset_val = dataset_list[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder), :] | |||||
| temp = torch.ones(dataset_list.shape[0]) | |||||
| temp[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder)] = 0 | temp[num_1fold * self.kth_fold:(num_1fold * (self.kth_fold + 1) + remainder)] = 0 | ||||
| x_train = x[temp == 1] | |||||
| y_train = y[temp == 1] | |||||
| return x_train, y_train, x_val, y_val | |||||
| dataset_train = dataset_list[temp == 1] | |||||
| return dataset_train, dataset_val | |||||
| def train_dataloader(self): | def train_dataloader(self): | ||||
| return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, | return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, | ||||
| @@ -72,11 +71,10 @@ class DataModule(pl.LightningDataModule): | |||||
| class CustomDataset(Dataset): | class CustomDataset(Dataset): | ||||
| def __init__(self, x, y, config): | |||||
| def __init__(self, dataset, config): | |||||
| super().__init__() | super().__init__() | ||||
| self.x = x | |||||
| self.y = y | |||||
| self.config = config | |||||
| self.x = dataset[:, 0:config['dim_in']] | |||||
| self.y = dataset[:, config['dim_in']] | |||||
| def __getitem__(self, idx): | def __getitem__(self, idx): | ||||
| return self.x[idx, :], self.y[idx] | return self.x[idx, :], self.y[idx] | ||||