Browse Source

!8072 Revert changes to weighted_random_sampler in PR7866

Merge pull request !8072 from luoyang/pylint
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
7e43eaf5e8
2 changed files with 48 additions and 0 deletions
  1. +15
    -0
      mindspore/dataset/engine/samplers.py
  2. +33
    -0
      tests/ut/python/dataset/test_datasets_imagefolder.py

+ 15
- 0
mindspore/dataset/engine/samplers.py View File

@@ -19,6 +19,7 @@ SequentialSampler, SubsetRandomSampler, and WeightedRandomSampler.
Users can also define a custom sampler by extending from the Sampler class. Users can also define a custom sampler by extending from the Sampler class.
""" """


import numbers
import numpy as np import numpy as np
import mindspore._c_dataengine as cde import mindspore._c_dataengine as cde
import mindspore.dataset as ds import mindspore.dataset as ds
@@ -591,6 +592,20 @@ class WeightedRandomSampler(BuiltinSampler):
if not isinstance(weights, list): if not isinstance(weights, list):
weights = [weights] weights = [weights]


for ind, w in enumerate(weights):
if not isinstance(w, numbers.Number):
raise TypeError("type of weights element should be number, "
"but got w[{}]={}, type={}".format(ind, w, type(w)))

if weights == []:
raise ValueError("weights size should not be 0")

if list(filter(lambda x: x < 0, weights)) != []:
raise ValueError("weights should not contain negative numbers")

if list(filter(lambda x: x == 0, weights)) == weights:
raise ValueError("elements of weights should not be all zero")

if num_samples is not None: if num_samples is not None:
if num_samples <= 0: if num_samples <= 0:
raise ValueError("num_samples should be a positive integer " raise ValueError("num_samples should be a positive integer "


+ 33
- 0
tests/ut/python/dataset/test_datasets_imagefolder.py View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import pytest
import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore import log as logger from mindspore import log as logger


@@ -382,6 +383,35 @@ def test_weighted_random_sampler():
logger.info("Number of data in data1: {}".format(num_iter)) logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 11 assert num_iter == 11


def test_weighted_random_sampler_exception():
"""
Test error cases for WeightedRandomSampler
"""
logger.info("Test error cases for WeightedRandomSampler")
error_msg_1 = "type of weights element should be number"
with pytest.raises(TypeError, match=error_msg_1):
weights = ""
ds.WeightedRandomSampler(weights)

error_msg_2 = "type of weights element should be number"
with pytest.raises(TypeError, match=error_msg_2):
weights = (0.9, 0.8, 1.1)
ds.WeightedRandomSampler(weights)

error_msg_3 = "weights size should not be 0"
with pytest.raises(ValueError, match=error_msg_3):
weights = []
ds.WeightedRandomSampler(weights)

error_msg_4 = "weights should not contain negative numbers"
with pytest.raises(ValueError, match=error_msg_4):
weights = [1.0, 0.1, 0.02, 0.3, -0.4]
ds.WeightedRandomSampler(weights)

error_msg_5 = "elements of weights should not be all zero"
with pytest.raises(ValueError, match=error_msg_5):
weights = [0, 0, 0, 0, 0]
ds.WeightedRandomSampler(weights)


def test_imagefolder_rename(): def test_imagefolder_rename():
logger.info("Test Case rename") logger.info("Test Case rename")
@@ -465,6 +495,9 @@ if __name__ == '__main__':
test_weighted_random_sampler() test_weighted_random_sampler()
logger.info('test_weighted_random_sampler Ended.\n') logger.info('test_weighted_random_sampler Ended.\n')


test_weighted_random_sampler_exception()
logger.info('test_weighted_random_sampler_exception Ended.\n')

test_imagefolder_numshards() test_imagefolder_numshards()
logger.info('test_imagefolder_numshards Ended.\n') logger.info('test_imagefolder_numshards Ended.\n')




Loading…
Cancel
Save