Browse Source

Fix CI warning of samplers.py

tags/v0.3.0-alpha
Junhan Hu 5 years ago
parent
commit
1e904ddcee
1 changed files with 7 additions and 1 deletions
  1. +7
    -1
      mindspore/dataset/engine/samplers.py

+ 7
- 1
mindspore/dataset/engine/samplers.py View File

@@ -19,8 +19,8 @@ SequentialSampler, SubsetRandomSampler, WeightedRandomSampler.
User can also define custom sampler by extending from Sampler class. User can also define custom sampler by extending from Sampler class.
""" """


import mindspore._c_dataengine as cde
import numpy as np import numpy as np
import mindspore._c_dataengine as cde




class Sampler: class Sampler:
@@ -137,6 +137,7 @@ class DistributedSampler(BuiltinSampler):
self.shard_id = shard_id self.shard_id = shard_id
self.shuffle = shuffle self.shuffle = shuffle
self.seed = 0 self.seed = 0
super().__init__()


def create(self): def create(self):
# each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
@@ -182,6 +183,7 @@ class PKSampler(BuiltinSampler):
self.num_val = num_val self.num_val = num_val
self.shuffle = shuffle self.shuffle = shuffle
self.class_column = class_column # work for minddataset self.class_column = class_column # work for minddataset
super().__init__()


def create(self): def create(self):
return cde.PKSampler(self.num_val, self.shuffle) return cde.PKSampler(self.num_val, self.shuffle)
@@ -192,6 +194,7 @@ class PKSampler(BuiltinSampler):
but got class_column={}".format(class_column)) but got class_column={}".format(class_column))
return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle) return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)



class RandomSampler(BuiltinSampler): class RandomSampler(BuiltinSampler):
""" """
Samples the elements randomly. Samples the elements randomly.
@@ -225,6 +228,7 @@ class RandomSampler(BuiltinSampler):


self.replacement = replacement self.replacement = replacement
self.num_samples = num_samples self.num_samples = num_samples
super().__init__()


def create(self): def create(self):
# If num_samples is not specified, then call constructor #2 # If num_samples is not specified, then call constructor #2
@@ -275,6 +279,7 @@ class SubsetRandomSampler(BuiltinSampler):
indices = [indices] indices = [indices]


self.indices = indices self.indices = indices
super().__init__()


def create(self): def create(self):
return cde.SubsetRandomSampler(self.indices) return cde.SubsetRandomSampler(self.indices)
@@ -322,6 +327,7 @@ class WeightedRandomSampler(BuiltinSampler):
self.weights = weights self.weights = weights
self.num_samples = num_samples self.num_samples = num_samples
self.replacement = replacement self.replacement = replacement
super().__init__()


def create(self): def create(self):
return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement) return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)

Loading…
Cancel
Save