Browse Source

改变产生拟合函数数据的方式;

master
shenyan 4 years ago
parent
commit
6e1eecfc30
2 changed files with 25 additions and 27 deletions
  1. +25
    -27
      data_module.py
  2. BIN
      requirements.txt

+ 25
- 27
data_module.py View File

@@ -18,45 +18,44 @@ class DataModule(pl.LightningDataModule):

def setup(self, stage=None) -> None:
    """Build the datasets needed for the requested stage.

    The complete sample table is obtained once from
    ``self.get_dataset_list()`` and then wrapped in ``CustomDataset``
    objects.

    Args:
        stage: ``'fit'`` creates ``self.train_dataset`` and
            ``self.val_dataset``; ``'test'`` creates ``self.test_dataset``;
            ``None`` creates all three.
    """
    dataset_list = self.get_dataset_list()
    if stage == 'fit' or stage is None:
        # Training and validation rows come from the split performed by
        # self.get_dataset_lists.
        dataset_train, dataset_val = self.get_dataset_lists(dataset_list)
        self.train_dataset = CustomDataset(dataset_train, self.config)
        self.val_dataset = CustomDataset(dataset_val, self.config)
    if stage == 'test' or stage is None:
        # The test dataset wraps the complete, unsplit table.
        self.test_dataset = CustomDataset(dataset_list, self.config)

def get_dataset_list(self):
    """Return the sample table, generating and saving it when it is missing.

    The table is stored in ``dataset_list.txt`` inside ``self.dataset_path``,
    one sample per line with whitespace-separated values. Each row holds
    ``config['dim_in']`` input features followed by the target value.

    When the file does not exist, ``config['dataset_len']`` rows are sampled
    with ``torch.randn`` and the target column is overwritten with
    ``cos(1.5 * x0) * x1 ** 2 + cos(sin(x2 ** 3)) + arctan(x4) + noise``.
    The formula reads columns 0, 1, 2 and 4, so it assumes ``dim_in`` is at
    least 5 for column 4 to be an input feature.

    Returns:
        A 2-D tensor with one row per sample (float32 when read from file).

    Raises:
        ValueError: If the freshly generated table contains NaN values.
    """
    list_path = os.path.join(self.dataset_path, 'dataset_list.txt')
    if not os.path.exists(list_path):
        dim_in = self.config['dim_in']
        num_rows = self.config['dataset_len']
        # Sampling order (table first, noise second) is kept so that seeded
        # runs stay reproducible.
        dataset = torch.randn(num_rows, dim_in + 1)
        noise = torch.randn(num_rows)
        dataset[:, dim_in] = (
            torch.cos(1.5 * dataset[:, 0]) * (dataset[:, 1] ** 2.0)
            + torch.cos(torch.sin(dataset[:, 2] ** 3))
            + torch.arctan(dataset[:, 4])
            + noise
        )
        # Explicit check instead of `assert`, which is stripped under -O.
        if torch.isnan(dataset).any():
            raise ValueError('generated dataset contains NaN values')
        with open(list_path, 'w', encoding='utf-8') as f:
            f.writelines(
                ' '.join(str(value) for value in row) + '\n'
                for row in dataset.tolist()
            )
        print('已生成新的数据list')
        return dataset

    # Read with the same encoding used for writing and close the handle
    # deterministically; blank lines are ignored.
    with open(list_path, encoding='utf-8') as f:
        rows = [[float(value) for value in line.split()] for line in f if line.strip()]
    return torch.tensor(rows, dtype=torch.float32)

def get_dataset_lists(self, dataset_list: Tensor):
    """Split the sample table into training and validation rows for one fold.

    The table is divided into ``self.k_fold`` folds of equal size. The
    validation window for fold ``self.kth_fold`` starts at that fold's first
    row and extends ``remainder`` rows past its nominal end, where
    ``remainder`` is the number of rows left over after forming the folds.
    Consequently every fold validates on the same number of rows, the window
    of the last fold reaches the end of the table, and earlier windows
    overlap the following fold by ``remainder`` rows.

    Args:
        dataset_list: 2-D tensor with one sample per row.

    Returns:
        ``(dataset_train, dataset_val)``: the rows outside and inside the
        validation window, both in their original order.
    """
    # Fold sizes follow the actual number of rows so the slice and the mask
    # always agree, even if the stored table differs from the configured size.
    num_rows = dataset_list.shape[0]
    num_1fold, remainder = divmod(num_rows, self.k_fold)
    val_start = num_1fold * self.kth_fold
    val_stop = num_1fold * (self.kth_fold + 1) + remainder

    is_val = torch.zeros(num_rows, dtype=torch.bool)
    is_val[val_start:val_stop] = True

    dataset_val = dataset_list[val_start:val_stop, :]
    dataset_train = dataset_list[~is_val]
    return dataset_train, dataset_val

def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers,
@@ -72,11 +71,10 @@ class DataModule(pl.LightningDataModule):


class CustomDataset(Dataset):
    """Dataset over a sample table whose last used column is the target."""

    def __init__(self, dataset, config):
        """Split the table into feature and target tensors.

        Args:
            dataset: 2-D tensor; columns ``0 .. config['dim_in'] - 1`` are the
                input features and column ``config['dim_in']`` is the target.
            config: Mapping that provides at least the ``'dim_in'`` entry.
        """
        super().__init__()
        self.config = config
        dim_in = config['dim_in']
        self.x = dataset[:, 0:dim_in]
        self.y = dataset[:, dim_in]

    def __getitem__(self, idx):
        """Return the ``(features, target)`` pair stored at row ``idx``."""
        return self.x[idx, :], self.y[idx]


BIN
requirements.txt View File


Loading…
Cancel
Save