|
|
|
@@ -1,6 +1,21 @@ |
|
|
|
import torch |
|
|
|
import torch.nn as nn |
|
|
|
import numpy as np |
|
|
|
import torch.utils.data.sampler as sampler |
|
|
|
|
|
|
|
|
|
|
|
class InfiniteSampler(sampler.Sampler): |
|
|
|
def __init__(self, num_samples): |
|
|
|
self.num_samples = num_samples |
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
while True: |
|
|
|
order = np.random.permutation(self.num_samples) |
|
|
|
for i in range(self.num_samples): |
|
|
|
yield order[i] |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
def gen_mappings(chars, symbs): |
|
|
|
|