|
|
|
@@ -218,6 +218,11 @@ class DistributedSampler(BuiltinSampler): |
|
|
|
if not isinstance(shuffle, bool): |
|
|
|
raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle)) |
|
|
|
|
|
|
|
if num_samples is not None: |
|
|
|
if num_samples <= 0: |
|
|
|
raise ValueError("num_samples should be a positive integer " |
|
|
|
"value, but got num_samples={}".format(num_samples)) |
|
|
|
|
|
|
|
self.num_shards = num_shards |
|
|
|
self.shard_id = shard_id |
|
|
|
self.shuffle = shuffle |
|
|
|
@@ -282,6 +287,11 @@ class PKSampler(BuiltinSampler): |
|
|
|
if not isinstance(shuffle, bool): |
|
|
|
raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle)) |
|
|
|
|
|
|
|
if num_samples is not None: |
|
|
|
if num_samples <= 0: |
|
|
|
raise ValueError("num_samples should be a positive integer " |
|
|
|
"value, but got num_samples={}".format(num_samples)) |
|
|
|
|
|
|
|
self.num_val = num_val |
|
|
|
self.shuffle = shuffle |
|
|
|
self.class_column = class_column # work for minddataset |
|
|
|
@@ -385,6 +395,16 @@ class SequentialSampler(BuiltinSampler): |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, start_index=None, num_samples=None): |
|
|
|
if num_samples is not None: |
|
|
|
if num_samples <= 0: |
|
|
|
raise ValueError("num_samples should be a positive integer " |
|
|
|
"value, but got num_samples={}".format(num_samples)) |
|
|
|
|
|
|
|
if start_index is not None: |
|
|
|
if start_index < 0: |
|
|
|
raise ValueError("start_index should be a positive integer " |
|
|
|
"value or 0, but got start_index={}".format(start_index)) |
|
|
|
|
|
|
|
self.start_index = start_index |
|
|
|
super().__init__(num_samples) |
|
|
|
|
|
|
|
@@ -430,6 +450,11 @@ class SubsetRandomSampler(BuiltinSampler): |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, indices, num_samples=None): |
|
|
|
if num_samples is not None: |
|
|
|
if num_samples <= 0: |
|
|
|
raise ValueError("num_samples should be a positive integer " |
|
|
|
"value, but got num_samples={}".format(num_samples)) |
|
|
|
|
|
|
|
if not isinstance(indices, list): |
|
|
|
indices = [indices] |
|
|
|
|
|
|
|
|