|
|
@@ -12,6 +12,7 @@ |
|
|
# See the License for the specific language governing permissions and |
|
|
# See the License for the specific language governing permissions and |
|
|
# limitations under the License. |
|
|
# limitations under the License. |
|
|
# ============================================================================== |
|
|
# ============================================================================== |
|
|
|
|
|
import pytest |
|
|
import mindspore.dataset as ds |
|
|
import mindspore.dataset as ds |
|
|
from mindspore import log as logger |
|
|
from mindspore import log as logger |
|
|
|
|
|
|
|
|
@@ -382,6 +383,35 @@ def test_weighted_random_sampler(): |
|
|
logger.info("Number of data in data1: {}".format(num_iter)) |
|
|
logger.info("Number of data in data1: {}".format(num_iter)) |
|
|
assert num_iter == 11 |
|
|
assert num_iter == 11 |
|
|
|
|
|
|
|
|
|
|
|
def test_weighted_random_sampler_exception(): |
|
|
|
|
|
""" |
|
|
|
|
|
Test error cases for WeightedRandomSampler |
|
|
|
|
|
""" |
|
|
|
|
|
logger.info("Test error cases for WeightedRandomSampler") |
|
|
|
|
|
error_msg_1 = "type of weights element should be number" |
|
|
|
|
|
with pytest.raises(TypeError, match=error_msg_1): |
|
|
|
|
|
weights = "" |
|
|
|
|
|
ds.WeightedRandomSampler(weights) |
|
|
|
|
|
|
|
|
|
|
|
error_msg_2 = "type of weights element should be number" |
|
|
|
|
|
with pytest.raises(TypeError, match=error_msg_2): |
|
|
|
|
|
weights = (0.9, 0.8, 1.1) |
|
|
|
|
|
ds.WeightedRandomSampler(weights) |
|
|
|
|
|
|
|
|
|
|
|
error_msg_3 = "weights size should not be 0" |
|
|
|
|
|
with pytest.raises(ValueError, match=error_msg_3): |
|
|
|
|
|
weights = [] |
|
|
|
|
|
ds.WeightedRandomSampler(weights) |
|
|
|
|
|
|
|
|
|
|
|
error_msg_4 = "weights should not contain negative numbers" |
|
|
|
|
|
with pytest.raises(ValueError, match=error_msg_4): |
|
|
|
|
|
weights = [1.0, 0.1, 0.02, 0.3, -0.4] |
|
|
|
|
|
ds.WeightedRandomSampler(weights) |
|
|
|
|
|
|
|
|
|
|
|
error_msg_5 = "elements of weights should not be all zero" |
|
|
|
|
|
with pytest.raises(ValueError, match=error_msg_5): |
|
|
|
|
|
weights = [0, 0, 0, 0, 0] |
|
|
|
|
|
ds.WeightedRandomSampler(weights) |
|
|
|
|
|
|
|
|
def test_imagefolder_rename(): |
|
|
def test_imagefolder_rename(): |
|
|
logger.info("Test Case rename") |
|
|
logger.info("Test Case rename") |
|
|
@@ -465,6 +495,9 @@ if __name__ == '__main__': |
|
|
test_weighted_random_sampler() |
|
|
test_weighted_random_sampler() |
|
|
logger.info('test_weighted_random_sampler Ended.\n') |
|
|
logger.info('test_weighted_random_sampler Ended.\n') |
|
|
|
|
|
|
|
|
|
|
|
test_weighted_random_sampler_exception() |
|
|
|
|
|
logger.info('test_weighted_random_sampler_exception Ended.\n') |
|
|
|
|
|
|
|
|
test_imagefolder_numshards() |
|
|
test_imagefolder_numshards() |
|
|
logger.info('test_imagefolder_numshards Ended.\n') |
|
|
logger.info('test_imagefolder_numshards Ended.\n') |
|
|
|
|
|
|
|
|
|