| @@ -1,6 +1,21 @@ | |||||
| import torch | import torch | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| import numpy as np | import numpy as np | ||||
| import torch.utils.data.sampler as sampler | |||||
| class InfiniteSampler(sampler.Sampler): | |||||
| def __init__(self, num_samples): | |||||
| self.num_samples = num_samples | |||||
| def __iter__(self): | |||||
| while True: | |||||
| order = np.random.permutation(self.num_samples) | |||||
| for i in range(self.num_samples): | |||||
| yield order[i] | |||||
| def __len__(self): | |||||
| return None | |||||
| def gen_mappings(chars, symbs): | def gen_mappings(chars, symbs): | ||||