|
|
|
@@ -152,6 +152,7 @@ class PKSampler(BuiltinSampler): |
|
|
|
num_val (int): Number of elements to sample for each class. |
|
|
|
num_class (int, optional): Number of classes to sample (default=None, all classes). |
|
|
|
shuffle (bool, optional): If true, the class IDs are shuffled (default=False). |
|
|
|
class_column (str, optional): Name of column to classify dataset(default='label'), for MindDataset. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> import mindspore.dataset as ds |
|
|
|
@@ -168,7 +169,7 @@ class PKSampler(BuiltinSampler): |
|
|
|
ValueError: If shuffle is not boolean. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, num_val, num_class=None, shuffle=False): |
|
|
|
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label'): |
|
|
|
if num_val <= 0: |
|
|
|
raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val)) |
|
|
|
|
|
|
|
@@ -180,12 +181,16 @@ class PKSampler(BuiltinSampler): |
|
|
|
|
|
|
|
self.num_val = num_val |
|
|
|
self.shuffle = shuffle |
|
|
|
self.class_column = class_column # work for minddataset |
|
|
|
|
|
|
|
def create(self): |
|
|
|
return cde.PKSampler(self.num_val, self.shuffle) |
|
|
|
|
|
|
|
def _create_for_minddataset(self): |
|
|
|
return cde.MindrecordPkSampler(self.num_val, self.shuffle) |
|
|
|
if not self.class_column or not isinstance(self.class_column, str): |
|
|
|
raise ValueError("class_column should be a not empty string value, \ |
|
|
|
but got class_column={}".format(class_column)) |
|
|
|
return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle) |
|
|
|
|
|
|
|
class RandomSampler(BuiltinSampler): |
|
|
|
""" |
|
|
|
|