| @@ -19,8 +19,8 @@ SequentialSampler, SubsetRandomSampler, WeightedRandomSampler. | |||||
| User can also define custom sampler by extending from Sampler class. | User can also define custom sampler by extending from Sampler class. | ||||
| """ | """ | ||||
| import mindspore._c_dataengine as cde | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore._c_dataengine as cde | |||||
| class Sampler: | class Sampler: | ||||
| @@ -137,6 +137,7 @@ class DistributedSampler(BuiltinSampler): | |||||
| self.shard_id = shard_id | self.shard_id = shard_id | ||||
| self.shuffle = shuffle | self.shuffle = shuffle | ||||
| self.seed = 0 | self.seed = 0 | ||||
| super().__init__() | |||||
| def create(self): | def create(self): | ||||
| # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle | # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle | ||||
| @@ -182,6 +183,7 @@ class PKSampler(BuiltinSampler): | |||||
| self.num_val = num_val | self.num_val = num_val | ||||
| self.shuffle = shuffle | self.shuffle = shuffle | ||||
| self.class_column = class_column # work for minddataset | self.class_column = class_column # work for minddataset | ||||
| super().__init__() | |||||
| def create(self): | def create(self): | ||||
| return cde.PKSampler(self.num_val, self.shuffle) | return cde.PKSampler(self.num_val, self.shuffle) | ||||
| @@ -192,6 +194,7 @@ class PKSampler(BuiltinSampler): | |||||
| but got class_column={}".format(class_column)) | but got class_column={}".format(class_column)) | ||||
| return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle) | return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle) | ||||
| class RandomSampler(BuiltinSampler): | class RandomSampler(BuiltinSampler): | ||||
| """ | """ | ||||
| Samples the elements randomly. | Samples the elements randomly. | ||||
| @@ -225,6 +228,7 @@ class RandomSampler(BuiltinSampler): | |||||
| self.replacement = replacement | self.replacement = replacement | ||||
| self.num_samples = num_samples | self.num_samples = num_samples | ||||
| super().__init__() | |||||
| def create(self): | def create(self): | ||||
| # If num_samples is not specified, then call constructor #2 | # If num_samples is not specified, then call constructor #2 | ||||
| @@ -275,6 +279,7 @@ class SubsetRandomSampler(BuiltinSampler): | |||||
| indices = [indices] | indices = [indices] | ||||
| self.indices = indices | self.indices = indices | ||||
| super().__init__() | |||||
| def create(self): | def create(self): | ||||
| return cde.SubsetRandomSampler(self.indices) | return cde.SubsetRandomSampler(self.indices) | ||||
| @@ -322,6 +327,7 @@ class WeightedRandomSampler(BuiltinSampler): | |||||
| self.weights = weights | self.weights = weights | ||||
| self.num_samples = num_samples | self.num_samples = num_samples | ||||
| self.replacement = replacement | self.replacement = replacement | ||||
| super().__init__() | |||||
| def create(self): | def create(self): | ||||
| return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement) | return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement) | ||||