| @@ -98,6 +98,25 @@ def test_shuffle_04(): | |||||
| save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) | save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) | ||||
| def test_shuffle_05(): | |||||
| """ | |||||
| Test shuffle: buffer_size > number-of-rows-in-dataset | |||||
| """ | |||||
| logger.info("test_shuffle_05") | |||||
| # define parameters | |||||
| buffer_size = 13 | |||||
| seed = 1 | |||||
| parameters = {"params": {'buffer_size': buffer_size, "seed": seed}} | |||||
| # apply dataset operations | |||||
| data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES) | |||||
| ds.config.set_seed(seed) | |||||
| data1 = data1.shuffle(buffer_size=buffer_size) | |||||
| filename = "shuffle_05_result.npz" | |||||
| save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) | |||||
| def test_shuffle_exception_01(): | def test_shuffle_exception_01(): | ||||
| """ | """ | ||||
| Test shuffle exception: buffer_size<0 | Test shuffle exception: buffer_size<0 | ||||
| @@ -152,24 +171,6 @@ def test_shuffle_exception_03(): | |||||
| assert "buffer_size" in str(e) | assert "buffer_size" in str(e) | ||||
| def test_shuffle_exception_04(): | |||||
| """ | |||||
| Test shuffle exception: buffer_size > number-of-rows-in-dataset | |||||
| """ | |||||
| logger.info("test_shuffle_exception_04") | |||||
| # apply dataset operations | |||||
| data1 = ds.TFRecordDataset(DATA_DIR) | |||||
| ds.config.set_seed(1) | |||||
| try: | |||||
| data1 = data1.shuffle(buffer_size=13) | |||||
| sum([1 for _ in data1]) | |||||
| except BaseException as e: | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | |||||
| assert "buffer_size" in str(e) | |||||
| def test_shuffle_exception_05(): | def test_shuffle_exception_05(): | ||||
| """ | """ | ||||
| Test shuffle exception: Missing mandatory buffer_size input parameter | Test shuffle exception: Missing mandatory buffer_size input parameter | ||||
| @@ -229,10 +230,10 @@ if __name__ == '__main__': | |||||
| test_shuffle_02() | test_shuffle_02() | ||||
| test_shuffle_03() | test_shuffle_03() | ||||
| test_shuffle_04() | test_shuffle_04() | ||||
| test_shuffle_05() | |||||
| test_shuffle_exception_01() | test_shuffle_exception_01() | ||||
| test_shuffle_exception_02() | test_shuffle_exception_02() | ||||
| test_shuffle_exception_03() | test_shuffle_exception_03() | ||||
| test_shuffle_exception_04() | |||||
| test_shuffle_exception_05() | test_shuffle_exception_05() | ||||
| test_shuffle_exception_06() | test_shuffle_exception_06() | ||||
| test_shuffle_exception_07() | test_shuffle_exception_07() | ||||