|
|
|
@@ -19,8 +19,8 @@ SequentialSampler, SubsetRandomSampler, WeightedRandomSampler. |
|
|
|
User can also define custom sampler by extending from Sampler class. |
|
|
|
""" |
|
|
|
|
|
|
|
import mindspore._c_dataengine as cde |
|
|
|
import numpy as np |
|
|
|
import mindspore._c_dataengine as cde |
|
|
|
|
|
|
|
|
|
|
|
class Sampler: |
|
|
|
@@ -137,6 +137,7 @@ class DistributedSampler(BuiltinSampler): |
|
|
|
self.shard_id = shard_id |
|
|
|
self.shuffle = shuffle |
|
|
|
self.seed = 0 |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
def create(self): |
|
|
|
# each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle |
|
|
|
@@ -182,6 +183,7 @@ class PKSampler(BuiltinSampler): |
|
|
|
self.num_val = num_val |
|
|
|
self.shuffle = shuffle |
|
|
|
self.class_column = class_column # work for minddataset |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
def create(self): |
|
|
|
return cde.PKSampler(self.num_val, self.shuffle) |
|
|
|
@@ -192,6 +194,7 @@ class PKSampler(BuiltinSampler): |
|
|
|
but got class_column={}".format(class_column)) |
|
|
|
return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle) |
|
|
|
|
|
|
|
|
|
|
|
class RandomSampler(BuiltinSampler): |
|
|
|
""" |
|
|
|
Samples the elements randomly. |
|
|
|
@@ -225,6 +228,7 @@ class RandomSampler(BuiltinSampler): |
|
|
|
|
|
|
|
self.replacement = replacement |
|
|
|
self.num_samples = num_samples |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
def create(self): |
|
|
|
# If num_samples is not specified, then call constructor #2 |
|
|
|
@@ -275,6 +279,7 @@ class SubsetRandomSampler(BuiltinSampler): |
|
|
|
indices = [indices] |
|
|
|
|
|
|
|
self.indices = indices |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
def create(self): |
|
|
|
return cde.SubsetRandomSampler(self.indices) |
|
|
|
@@ -322,6 +327,7 @@ class WeightedRandomSampler(BuiltinSampler): |
|
|
|
self.weights = weights |
|
|
|
self.num_samples = num_samples |
|
|
|
self.replacement = replacement |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
def create(self): |
|
|
|
return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement) |