You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data_module.py 3.7 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import os
  2. import numpy
  3. import torch
  4. from torch import Tensor
  5. from torch.utils.data import Dataset, DataLoader
  6. import pytorch_lightning as pl
  class DataModule(pl.LightningDataModule):
      """Lightning data module serving one k-fold split of a tabular dataset.

      The data lives in ``<dataset_path>/dataset_list.txt``: one sample per line,
      space-separated floats, columns ``0..dim_in-1`` are inputs and column
      ``dim_in`` is the target.  If the file is missing, a synthetic regression
      dataset is generated and written there first.

      Args:
          batch_size: Batch size of the train and validation loaders.
          num_workers: Number of worker processes for every ``DataLoader``.
          k_fold: Number of cross-validation folds.
          kth_fold: Zero-based index of the fold used for validation.
          dataset_path: Directory that holds ``dataset_list.txt``.
          config: Dict that must provide ``'dim_in'`` (and ``'dataset_len'``
              when the data file has to be generated). Despite the ``None``
              default, every ``setup`` path indexes it.
      """

      def __init__(self, batch_size, num_workers, k_fold, kth_fold, dataset_path, config=None):
          super().__init__()
          self.batch_size = batch_size
          self.num_workers = num_workers
          self.config = config
          self.k_fold = k_fold
          self.kth_fold = kth_fold
          self.dataset_path = dataset_path

      def setup(self, stage=None) -> None:
          """Build the datasets needed for ``stage``.

          ``'fit'`` / ``'validate'`` build the train and validation datasets from
          the k-fold split, ``'test'`` builds the test dataset, ``None`` builds all.
          """
          # Load (or generate) the complete dataset.
          dataset_list = self.get_dataset_list()
          # Lightning passes stage='validate' for Trainer.validate(), which then
          # calls val_dataloader(), so the split must exist for that stage too.
          if stage in ('fit', 'validate') or stage is None:
              dataset_train, dataset_val = self.get_dataset_lists(dataset_list)
              self.train_dataset = CustomDataset(dataset_train, self.config)
              self.val_dataset = CustomDataset(dataset_val, self.config)
          if stage == 'test' or stage is None:
              # NOTE: the test set is the whole dataset, training rows included.
              self.test_dataset = CustomDataset(dataset_list, self.config)

      def _generate_dataset(self):
          """Sample a synthetic dataset of shape ``(dataset_len, dim_in + 1)``.

          Inputs are standard-normal; the last column is the target
          ``cos(1.5*x0) * x1**2 + cos(sin(x2**3)) + arctan(x4) + noise`` with
          standard-normal noise.

          Raises:
              RuntimeError: if the generated data contains NaN values.
          """
          dim_in = self.config['dim_in']
          dataset_len = self.config['dataset_len']
          # Order of the two randn calls is kept so seeded runs reproduce old data.
          dataset = torch.randn(dataset_len, dim_in + 1)
          noise = torch.randn(dataset_len)
          # Reads input column 4, so dim_in should be >= 5 (with dim_in == 4 it
          # would read the target column's random placeholder).
          dataset[:, dim_in] = (
              torch.cos(1.5 * dataset[:, 0]) * (dataset[:, 1] ** 2.0)
              + torch.cos(torch.sin(dataset[:, 2] ** 3))
              + torch.arctan(dataset[:, 4])
              + noise
          )
          # Explicit check instead of `assert`, which is stripped under `python -O`.
          if torch.isnan(dataset).any():
              raise RuntimeError('generated dataset contains NaN values')
          return dataset

      def get_dataset_list(self):
          """Return the full dataset as a float32 tensor, generating it if absent.

          Returns:
              Tensor with one row per sample; columns ``0..dim_in-1`` are the
              inputs and column ``dim_in`` is the target.
          """
          file_path = os.path.join(self.dataset_path, 'dataset_list.txt')
          if not os.path.exists(file_path):
              dataset = self._generate_dataset()
              if self.dataset_path:
                  os.makedirs(self.dataset_path, exist_ok=True)
              with open(file_path, 'w', encoding='utf-8') as f:
                  f.writelines(' '.join(str(value) for value in row) + '\n'
                               for row in dataset.tolist())
              print('已生成新的数据list')
              return dataset
          # Read with the same encoding used for writing; skip blank lines so a
          # trailing newline does not crash float('').
          with open(file_path, encoding='utf-8') as f:
              rows = [[float(value) for value in line.split()] for line in f if line.strip()]
          return torch.tensor(rows, dtype=torch.float32)

      def get_dataset_lists(self, dataset_list: Tensor):
          """Split ``dataset_list`` into ``(train, val)`` for fold ``self.kth_fold``.

          Rows are divided into ``k_fold`` contiguous, disjoint folds of
          ``len // k_fold`` rows; the ``len % k_fold`` leftover rows belong to the
          last fold, so each row is validated in exactly one fold.

          Returns:
              ``(dataset_train, dataset_val)``: the rows outside and inside the
              selected fold, in their original order.
          """
          # Use the real row count, not config['dataset_len'], so a data file whose
          # length differs from the config is still split correctly.
          num_rows = dataset_list.shape[0]
          num_1fold, remainder = divmod(num_rows, self.k_fold)
          start = num_1fold * self.kth_fold
          end = start + num_1fold
          if self.kth_fold == self.k_fold - 1:
              end += remainder
          dataset_val = dataset_list[start:end, :]
          in_train = torch.ones(num_rows, dtype=torch.bool)
          in_train[start:end] = False
          dataset_train = dataset_list[in_train]
          return dataset_train, dataset_val

      def train_dataloader(self):
          """Shuffled loader over the training split."""
          return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                            num_workers=self.num_workers, pin_memory=True)

      def val_dataloader(self):
          """Ordered loader over the validation split."""
          return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False,
                            num_workers=self.num_workers, pin_memory=True)

      def test_dataloader(self):
          """Ordered loader over the test set, one sample per batch."""
          return DataLoader(self.test_dataset, batch_size=1, shuffle=False,
                            num_workers=self.num_workers, pin_memory=True)
  class CustomDataset(Dataset):
      """Map-style dataset that splits a 2-D tensor into features and a target.

      Columns ``[0, dim_in)`` become the feature matrix ``x`` and column
      ``dim_in`` the target vector ``y``; any further columns are ignored.

      Args:
          dataset: 2-D tensor with at least ``dim_in + 1`` columns.
          config: Mapping providing ``'dim_in'``.
      """

      def __init__(self, dataset, config):
          super().__init__()
          n_features = config['dim_in']
          self.x = dataset[:, :n_features]
          self.y = dataset[:, n_features]

      def __getitem__(self, idx):
          """Return the ``(features, target)`` pair for row ``idx``."""
          features = self.x[idx, :]
          target = self.y[idx]
          return features, target

      def __len__(self):
          """Number of rows (samples)."""
          return self.x.size(0)

基于 PyTorch Lightning 的机器学习模板，用于对机器学习算法进行训练、验证、测试等，目前实现了神经网络、深度学习、k 折交叉验证、自动保存训练信息等功能。

Contributors (1)