diff --git a/examples/hed/utils.py b/examples/hed/utils.py index cf35eaf..42b7316 100644 --- a/examples/hed/utils.py +++ b/examples/hed/utils.py @@ -1,6 +1,21 @@ import torch import torch.nn as nn import numpy as np +import torch.utils.data.sampler as sampler + + +class InfiniteSampler(sampler.Sampler): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __iter__(self): + while True: + order = np.random.permutation(self.num_samples) + for i in range(self.num_samples): + yield order[i] + + def __len__(self): + return None def gen_mappings(chars, symbs):