@@ -1,8 +0,0 @@
-{
-    "deviceNum":4,
-    "deviceId": 2,
-    "shardConfig":"ALL",
-    "shuffle":"ON",
-    "seed": 0,
-    "epoch": 2
-}
@@ -1,8 +0,0 @@
-{
-    "deviceNum":4,
-    "deviceId": 2,
-    "shardConfig":"RANDOM",
-    "shuffle":"ON",
-    "seed": 0,
-    "epoch": 1
-}
@@ -1,8 +0,0 @@
-{
-    "deviceNum":4,
-    "deviceId": 2,
-    "shardConfig":"UNIQUE",
-    "shuffle":"ON",
-    "seed": 0,
-    "epoch": 3
-}
@@ -1,7 +0,0 @@
-{
-    "deviceNum":1,
-    "deviceId": 0,
-    "shardConfig":"RANDOM",
-    "shuffle":"OFF",
-    "seed": 0
-}
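The four deleted JSON files above are distributed-sampling configurations: device count, device id, shard mode ("ALL", "RANDOM", "UNIQUE"), shuffle switch, seed, and epoch. This change set does not show what replaces them, but the same device/shard split can be expressed by passing sharding arguments straight to the dataset constructor; the sketch below is illustrative only (path and values are placeholders echoing the first deleted config, not code from this change):

    import mindspore.dataset as ds

    # Rough equivalent of {"deviceNum": 4, "deviceId": 2, "shuffle": "ON", "seed": 0}:
    # read shard 2 of 4 with a fixed shuffle seed. The "shardConfig" mode has no
    # direct one-to-one argument here and is omitted.
    ds.config.set_seed(0)
    data = ds.TFRecordDataset(["../data/dataset/testTFTestAllTypes/test.data"],
                              num_shards=4, shard_id=2, shuffle=True)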
@@ -12,15 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from util import save_and_check
 import mindspore.dataset as ds
 from mindspore import log as logger
+from util import save_and_check_dict
 DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
 SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
-COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
-           "col_sint16", "col_sint32", "col_sint64"]
 GENERATE_GOLDEN = False
@@ -33,9 +30,6 @@ def test_2ops_repeat_shuffle():
     repeat_count = 2
     buffer_size = 5
     seed = 0
-    parameters = {"params": {'repeat_count': repeat_count,
-                             'buffer_size': buffer_size,
-                             'seed': seed}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
@@ -44,7 +38,7 @@ def test_2ops_repeat_shuffle():
     data1 = data1.shuffle(buffer_size=buffer_size)
     filename = "test_2ops_repeat_shuffle.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_2ops_shuffle_repeat():
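The change repeated throughout the test files below is mechanical: the local `parameters` dictionary is deleted and the deprecated `save_and_check()` helper is swapped for `save_and_check_dict()`, which (see the util.py hunks at the end of this diff) only needs the dataset, the golden filename, and the `generate_golden` flag. A minimal test in the new style would look like this (illustrative sketch; the test name and golden filename are made up):

    import mindspore.dataset as ds
    from util import save_and_check_dict

    DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
    SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
    GENERATE_GOLDEN = False  # flip to True once to (re)generate the golden .npz

    def test_2ops_example():
        # build the pipeline under test
        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
        data1 = data1.repeat(2)
        data1 = data1.batch(5, drop_remainder=True)
        # compare the produced rows against the stored golden file
        save_and_check_dict(data1, "test_2ops_example.npz", generate_golden=GENERATE_GOLDEN)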
@@ -56,10 +50,6 @@ def test_2ops_shuffle_repeat():
     repeat_count = 2
     buffer_size = 5
     seed = 0
-    parameters = {"params": {'repeat_count': repeat_count,
-                             'buffer_size': buffer_size,
-                             'reshuffle_each_iteration': False,
-                             'seed': seed}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
@@ -68,7 +58,7 @@ def test_2ops_shuffle_repeat():
     data1 = data1.repeat(repeat_count)
     filename = "test_2ops_shuffle_repeat.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_2ops_repeat_batch():
@@ -79,8 +69,6 @@ def test_2ops_repeat_batch():
     # define parameters
     repeat_count = 2
     batch_size = 5
-    parameters = {"params": {'repeat_count': repeat_count,
-                             'batch_size': batch_size}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
@@ -88,7 +76,7 @@ def test_2ops_repeat_batch():
     data1 = data1.batch(batch_size, drop_remainder=True)
     filename = "test_2ops_repeat_batch.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_2ops_batch_repeat():
@@ -99,8 +87,6 @@ def test_2ops_batch_repeat():
     # define parameters
     repeat_count = 2
     batch_size = 5
-    parameters = {"params": {'repeat_count': repeat_count,
-                             'batch_size': batch_size}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
@@ -108,7 +94,7 @@ def test_2ops_batch_repeat():
     data1 = data1.repeat(repeat_count)
     filename = "test_2ops_batch_repeat.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_2ops_batch_shuffle():
@@ -120,9 +106,6 @@ def test_2ops_batch_shuffle():
     buffer_size = 5
     seed = 0
     batch_size = 2
-    parameters = {"params": {'buffer_size': buffer_size,
-                             'seed': seed,
-                             'batch_size': batch_size}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
@@ -131,7 +114,7 @@ def test_2ops_batch_shuffle():
     data1 = data1.shuffle(buffer_size=buffer_size)
     filename = "test_2ops_batch_shuffle.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_2ops_shuffle_batch():
@@ -143,9 +126,6 @@ def test_2ops_shuffle_batch():
     buffer_size = 5
     seed = 0
     batch_size = 2
-    parameters = {"params": {'buffer_size': buffer_size,
-                             'seed': seed,
-                             'batch_size': batch_size}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
@@ -154,7 +134,7 @@ def test_2ops_shuffle_batch():
     data1 = data1.batch(batch_size, drop_remainder=True)
     filename = "test_2ops_shuffle_batch.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 if __name__ == '__main__':
@@ -14,7 +14,7 @@
 # ==============================================================================
 import mindspore.dataset as ds
 from mindspore import log as logger
-from util import save_and_check
+from util import save_and_check_dict
 # Note: Number of rows in test.data dataset: 12
 DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
@@ -29,8 +29,6 @@ def test_batch_01():
     # define parameters
     batch_size = 2
     drop_remainder = True
-    parameters = {"params": {'batch_size': batch_size,
-                             'drop_remainder': drop_remainder}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
@@ -38,7 +36,7 @@ def test_batch_01():
     assert sum([1 for _ in data1]) == 6
     filename = "batch_01_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_batch_02():
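The row counts asserted in this file all follow from the 12-row test.data noted at the top: drop_remainder=True keeps only full batches (floor division), drop_remainder=False also keeps the partial tail batch (ceiling division). The arithmetic, checked against the assertions in the hunks below (plain Python, not part of the change):

    import math

    NUM_ROWS = 12  # rows in test.data, per the comment at the top of the file

    def expected_batches(batch_size, drop_remainder):
        return NUM_ROWS // batch_size if drop_remainder else math.ceil(NUM_ROWS / batch_size)

    assert expected_batches(2, True) == 6    # test_batch_01
    assert expected_batches(5, True) == 2    # test_batch_02
    assert expected_batches(3, False) == 4   # test_batch_03
    assert expected_batches(7, False) == 2   # test_batch_04
    assert expected_batches(99, True) == 0   # test_batch_10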
@@ -49,8 +47,6 @@ def test_batch_02():
     # define parameters
     batch_size = 5
     drop_remainder = True
-    parameters = {"params": {'batch_size': batch_size,
-                             'drop_remainder': drop_remainder}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
@@ -58,7 +54,7 @@ def test_batch_02():
     assert sum([1 for _ in data1]) == 2
     filename = "batch_02_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_batch_03():
@@ -69,8 +65,6 @@ def test_batch_03():
     # define parameters
     batch_size = 3
     drop_remainder = False
-    parameters = {"params": {'batch_size': batch_size,
-                             'drop_remainder': drop_remainder}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
@@ -78,7 +72,7 @@ def test_batch_03():
     assert sum([1 for _ in data1]) == 4
     filename = "batch_03_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_batch_04():
@@ -89,8 +83,6 @@ def test_batch_04():
     # define parameters
     batch_size = 7
     drop_remainder = False
-    parameters = {"params": {'batch_size': batch_size,
-                             'drop_remainder': drop_remainder}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
@@ -98,7 +90,7 @@ def test_batch_04():
     assert sum([1 for _ in data1]) == 2
     filename = "batch_04_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_batch_05():
@@ -108,7 +100,6 @@ def test_batch_05():
     logger.info("test_batch_05")
     # define parameters
     batch_size = 1
-    parameters = {"params": {'batch_size': batch_size}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
@@ -116,7 +107,7 @@ def test_batch_05():
     assert sum([1 for _ in data1]) == 12
     filename = "batch_05_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_batch_06():
@@ -127,8 +118,6 @@ def test_batch_06():
     # define parameters
     batch_size = 12
     drop_remainder = False
-    parameters = {"params": {'batch_size': batch_size,
-                             'drop_remainder': drop_remainder}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
@@ -136,7 +125,7 @@ def test_batch_06():
     assert sum([1 for _ in data1]) == 1
     filename = "batch_06_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_batch_07():
@@ -148,9 +137,6 @@ def test_batch_07():
     batch_size = 4
     drop_remainder = False
     num_parallel_workers = 2
-    parameters = {"params": {'batch_size': batch_size,
-                             'drop_remainder': drop_remainder,
-                             'num_parallel_workers': num_parallel_workers}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
@@ -159,7 +145,7 @@ def test_batch_07():
     assert sum([1 for _ in data1]) == 3
     filename = "batch_07_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_batch_08():
@@ -170,8 +156,6 @@ def test_batch_08():
     # define parameters
     batch_size = 6
     num_parallel_workers = 1
-    parameters = {"params": {'batch_size': batch_size,
-                             'num_parallel_workers': num_parallel_workers}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
@@ -179,7 +163,7 @@ def test_batch_08():
     assert sum([1 for _ in data1]) == 2
     filename = "batch_08_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_batch_09():
@@ -190,8 +174,6 @@ def test_batch_09():
     # define parameters
     batch_size = 13
     drop_remainder = False
-    parameters = {"params": {'batch_size': batch_size,
-                             'drop_remainder': drop_remainder}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
@@ -199,7 +181,7 @@ def test_batch_09():
     assert sum([1 for _ in data1]) == 1
     filename = "batch_09_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_batch_10():
@@ -210,8 +192,6 @@ def test_batch_10():
     # define parameters
     batch_size = 99
     drop_remainder = True
-    parameters = {"params": {'batch_size': batch_size,
-                             'drop_remainder': drop_remainder}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
@@ -219,7 +199,7 @@ def test_batch_10():
     assert sum([1 for _ in data1]) == 0
     filename = "batch_10_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_batch_11():
@@ -229,7 +209,6 @@ def test_batch_11():
     logger.info("test_batch_11")
     # define parameters
     batch_size = 1
-    parameters = {"params": {'batch_size': batch_size}}
     # apply dataset operations
     # Use schema file with 1 row
@@ -239,7 +218,7 @@ def test_batch_11():
     assert sum([1 for _ in data1]) == 1
     filename = "batch_11_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_batch_12():
@@ -249,7 +228,6 @@ def test_batch_12():
     logger.info("test_batch_12")
     # define parameters
     batch_size = True
-    parameters = {"params": {'batch_size': batch_size}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
@@ -257,7 +235,7 @@ def test_batch_12():
     assert sum([1 for _ in data1]) == 12
     filename = "batch_12_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_batch_exception_01():
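One detail worth flagging in test_batch_12: `batch_size = True` is not a boolean switch. In Python, bool is a subclass of int and `True == 1`, so the pipeline batches one row at a time and the 12-row dataset yields 12 batches, which is exactly what the assertion above checks:

    # True behaves as the integer 1, so batch(True) is effectively batch(1)
    assert isinstance(True, int) and True == 1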
@@ -356,9 +356,13 @@ def test_clue_to_device():
 if __name__ == "__main__":
     test_clue()
+    test_clue_num_shards()
+    test_clue_num_samples()
+    test_textline_dataset_get_datasetsize()
     test_clue_afqmc()
     test_clue_cmnli()
     test_clue_csl()
     test_clue_iflytek()
     test_clue_tnews()
     test_clue_wsc()
+    test_clue_to_device()
@@ -26,7 +26,7 @@ def generator_1d():
         yield (np.array([i]),)
-def test_case_0():
+def test_generator_0():
     """
     Test 1D Generator
     """
@@ -48,7 +48,7 @@ def generator_md():
         yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
-def test_case_1():
+def test_generator_1():
     """
     Test MD Generator
     """
@@ -70,7 +70,7 @@ def generator_mc(maxid=64):
         yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
-def test_case_2():
+def test_generator_2():
     """
     Test multi column generator
     """
@@ -88,7 +88,7 @@ def test_case_2():
         i = i + 1
-def test_case_3():
+def test_generator_3():
     """
     Test 1D Generator + repeat(4)
     """
@@ -108,7 +108,7 @@ def test_case_3():
     i = 0
-def test_case_4():
+def test_generator_4():
     """
     Test fixed size 1D Generator + batch
     """
@@ -146,7 +146,7 @@ def type_tester(t):
         i = i + 4
-def test_case_5():
+def test_generator_5():
     """
     Test 1D Generator on different data type
     """
@@ -173,7 +173,7 @@ def type_tester_with_type_check(t, c):
         i = i + 4
-def test_case_6():
+def test_generator_6():
     """
     Test 1D Generator on different data type with type check
     """
@@ -208,7 +208,7 @@ def type_tester_with_type_check_2c(t, c):
         i = i + 4
-def test_case_7():
+def test_generator_7():
     """
     Test 2 column Generator on different data type with type check
     """
@@ -223,7 +223,7 @@ def test_case_7():
         type_tester_with_type_check_2c(np_types[i], [None, de_types[i]])
-def test_case_8():
+def test_generator_8():
     """
     Test multi column generator with few mapops
     """
@@ -249,7 +249,7 @@ def test_case_8():
         i = i + 1
-def test_case_9():
+def test_generator_9():
     """
     Test map column order when len(input_columns) == len(output_columns).
     """
@@ -280,7 +280,7 @@ def test_case_9():
         i = i + 1
-def test_case_10():
+def test_generator_10():
     """
     Test map column order when len(input_columns) != len(output_columns).
     """
@@ -303,7 +303,7 @@ def test_case_10():
         i = i + 1
-def test_case_11():
+def test_generator_11():
     """
     Test map column order when len(input_columns) != len(output_columns).
     """
@@ -327,7 +327,7 @@ def test_case_11():
         i = i + 1
-def test_case_12():
+def test_generator_12():
     """
     Test map column order when input_columns and output_columns are None.
     """
@@ -361,7 +361,7 @@ def test_case_12():
         i = i + 1
-def test_case_13():
+def test_generator_13():
     """
     Test map column order when input_columns is None.
     """
@@ -391,7 +391,7 @@ def test_case_13():
         i = i + 1
-def test_case_14():
+def test_generator_14():
     """
     Test 1D Generator MP + CPP sampler
     """
@@ -408,7 +408,7 @@ def test_case_14():
     i = 0
-def test_case_15():
+def test_generator_15():
     """
     Test 1D Generator MP + Python sampler
     """
@@ -426,7 +426,7 @@ def test_case_15():
     i = 0
-def test_case_16():
+def test_generator_16():
     """
     Test multi column generator Mp + CPP sampler
     """
@@ -445,7 +445,7 @@ def test_case_16():
         i = i + 1
-def test_case_17():
+def test_generator_17():
     """
     Test multi column generator Mp + Python sampler
     """
@@ -465,7 +465,7 @@ def test_case_17():
         i = i + 1
-def test_case_error_1():
+def test_generator_error_1():
     def generator_np():
         for i in range(64):
             yield (np.array([{i}]),)
@@ -477,7 +477,7 @@ def test_case_error_1():
     assert "Invalid data type" in str(info.value)
-def test_case_error_2():
+def test_generator_error_2():
     def generator_np():
         for i in range(64):
             yield ({i},)
@@ -489,7 +489,7 @@ def test_case_error_2():
     assert "Generator should return a tuple of numpy arrays" in str(info.value)
-def test_case_error_3():
+def test_generator_error_3():
     with pytest.raises(ValueError) as info:
         # apply dataset operations
         data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
@@ -501,7 +501,7 @@ def test_case_error_3():
     assert "When (len(input_columns) != len(output_columns)), columns_order must be specified." in str(info.value)
-def test_case_error_4():
+def test_generator_error_4():
     with pytest.raises(RuntimeError) as info:
         # apply dataset operations
         data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
@@ -513,7 +513,7 @@ def test_case_error_4():
     assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value)
-def test_sequential_sampler():
+def test_generator_sequential_sampler():
     source = [(np.array([x]),) for x in range(64)]
     ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler())
     i = 0
@@ -523,14 +523,14 @@ def test_sequential_sampler():
         i = i + 1
-def test_random_sampler():
+def test_generator_random_sampler():
     source = [(np.array([x]),) for x in range(64)]
     ds1 = ds.GeneratorDataset(source, ["data"], shuffle=True)
     for _ in ds1.create_dict_iterator():  # each data is a dictionary
         pass
-def test_distributed_sampler():
+def test_generator_distributed_sampler():
     source = [(np.array([x]),) for x in range(64)]
     for sid in range(8):
         ds1 = ds.GeneratorDataset(source, ["data"], shuffle=False, num_shards=8, shard_id=sid)
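In the distributed-sampler test renamed above, each of the 8 shards receives an interleaved slice of the 64-element source: shard `sid` sees elements sid, sid + 8, sid + 16, ..., which is why the loop in the next hunk advances with `i = i + 8`. The partitioning can be sketched in plain Python (illustrative, inferred from the test's stride rather than taken from MindSpore internals):

    source = list(range(64))
    num_shards, shard_id = 8, 2

    # round-robin sharding: every num_shards-th element, starting at shard_id
    shard = source[shard_id::num_shards]
    assert shard == [2, 10, 18, 26, 34, 42, 50, 58]
    assert len(shard) == len(source) // num_shards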
@@ -541,7 +541,7 @@ def test_distributed_sampler():
             i = i + 8
-def test_num_samples():
+def test_generator_num_samples():
     source = [(np.array([x]),) for x in range(64)]
     num_samples = 32
     ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(num_samples=num_samples))
@@ -564,7 +564,7 @@ def test_num_samples():
     assert count == num_samples
-def test_num_samples_underflow():
+def test_generator_num_samples_underflow():
     source = [(np.array([x]),) for x in range(64)]
     num_samples = 256
     ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(64)], num_samples=num_samples)
@@ -600,7 +600,7 @@ def type_tester_with_type_check_2c_schema(t, c):
         i = i + 4
-def test_schema():
+def test_generator_schema():
     """
     Test 2 column Generator on different data type with type check with schema input
     """
@@ -615,9 +615,9 @@ def test_schema():
         type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]])
-def manual_test_keyborad_interrupt():
+def manual_test_generator_keyboard_interrupt():
     """
-    Test keyborad_interrupt
+    Test keyboard_interrupt
     """
     logger.info("Test 1D Generator MP : 0 - 63")
@@ -635,31 +635,31 @@ def manual_test_keyborad_interrupt():
 if __name__ == "__main__":
-    test_case_0()
-    test_case_1()
-    test_case_2()
-    test_case_3()
-    test_case_4()
-    test_case_5()
-    test_case_6()
-    test_case_7()
-    test_case_8()
-    test_case_9()
-    test_case_10()
-    test_case_11()
-    test_case_12()
-    test_case_13()
-    test_case_14()
-    test_case_15()
-    test_case_16()
-    test_case_17()
-    test_case_error_1()
-    test_case_error_2()
-    test_case_error_3()
-    test_case_error_4()
-    test_sequential_sampler()
-    test_distributed_sampler()
-    test_random_sampler()
-    test_num_samples()
-    test_num_samples_underflow()
-    test_schema()
+    test_generator_0()
+    test_generator_1()
+    test_generator_2()
+    test_generator_3()
+    test_generator_4()
+    test_generator_5()
+    test_generator_6()
+    test_generator_7()
+    test_generator_8()
+    test_generator_9()
+    test_generator_10()
+    test_generator_11()
+    test_generator_12()
+    test_generator_13()
+    test_generator_14()
+    test_generator_15()
+    test_generator_16()
+    test_generator_17()
+    test_generator_error_1()
+    test_generator_error_2()
+    test_generator_error_3()
+    test_generator_error_4()
+    test_generator_sequential_sampler()
+    test_generator_distributed_sampler()
+    test_generator_random_sampler()
+    test_generator_num_samples()
+    test_generator_num_samples_underflow()
+    test_generator_schema()
@@ -33,7 +33,7 @@ def check(project_columns):
     assert all([np.array_equal(d1, d2) for d1, d2 in zip(data_actual, data_expected)])
-def test_case_iterator():
+def test_iterator_create_tuple():
     """
     Test creating tuple iterator
     """
@@ -95,7 +95,9 @@ class MyDict(dict):
 def test_tree_copy():
-    # Testing copying the tree with a pyfunc that cannot be pickled
+    """
+    Testing copying the tree with a pyfunc that cannot be pickled
+    """
     data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS)
     data1 = data.map(operations=[MyDict()])
@@ -110,4 +112,6 @@ def test_tree_copy():
 if __name__ == '__main__':
+    test_iterator_create_tuple()
+    test_iterator_weak_ref()
     test_tree_copy()
@@ -13,10 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 import numpy as np
-from util import save_and_check
 import mindspore.dataset as ds
 from mindspore import log as logger
+from util import save_and_check_dict
 # Note: Number of rows in test.data dataset: 12
 DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
@@ -31,7 +30,6 @@ def test_shuffle_01():
     # define parameters
     buffer_size = 5
     seed = 1
-    parameters = {"params": {'buffer_size': buffer_size, "seed": seed}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
@@ -39,7 +37,7 @@ def test_shuffle_01():
     data1 = data1.shuffle(buffer_size=buffer_size)
     filename = "shuffle_01_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_shuffle_02():
@@ -50,7 +48,6 @@ def test_shuffle_02():
     # define parameters
     buffer_size = 12
     seed = 1
-    parameters = {"params": {'buffer_size': buffer_size, "seed": seed}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
@@ -58,7 +55,7 @@ def test_shuffle_02():
     data1 = data1.shuffle(buffer_size=buffer_size)
     filename = "shuffle_02_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_shuffle_03():
@@ -69,7 +66,6 @@ def test_shuffle_03():
     # define parameters
     buffer_size = 2
     seed = 1
-    parameters = {"params": {'buffer_size': buffer_size, "seed": seed}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
@@ -77,7 +73,7 @@ def test_shuffle_03():
     data1 = data1.shuffle(buffer_size)
     filename = "shuffle_03_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_shuffle_04():
@@ -88,7 +84,6 @@ def test_shuffle_04():
     # define parameters
     buffer_size = 2
     seed = 1
-    parameters = {"params": {'buffer_size': buffer_size, "seed": seed}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, num_samples=2)
@@ -96,7 +91,7 @@ def test_shuffle_04():
     data1 = data1.shuffle(buffer_size=buffer_size)
     filename = "shuffle_04_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_shuffle_05():
@@ -107,7 +102,6 @@ def test_shuffle_05():
     # define parameters
     buffer_size = 13
     seed = 1
-    parameters = {"params": {'buffer_size': buffer_size, "seed": seed}}
     # apply dataset operations
     data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
@@ -115,7 +109,7 @@ def test_shuffle_05():
     data1 = data1.shuffle(buffer_size=buffer_size)
     filename = "shuffle_05_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 def test_shuffle_06():
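The shuffle tests above vary buffer_size against the 12-row dataset: a small buffer (2) only shuffles locally, while a buffer equal to or larger than the row count (12 or 13) amounts to a global shuffle. The results stay comparable to a golden file because the shuffle order is driven by the configured seed. A reproducibility sketch along those lines (illustrative; assumes the iterator yields NumPy values, as in the tests above):

    import mindspore.dataset as ds

    DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]

    def shuffled_rows(seed, buffer_size):
        ds.config.set_seed(seed)
        data = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
        data = data.shuffle(buffer_size=buffer_size)
        return [item["col_sint64"].tolist() for item in data.create_dict_iterator()]

    # same seed and same buffer size -> same order on both runs
    assert shuffled_rows(1, 12) == shuffled_rows(1, 12)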
@@ -24,9 +24,6 @@ import numpy as np
 import mindspore.dataset as ds
 from mindspore import log as logger
-# These are the column names defined in the testTFTestAllTypes dataset
-COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
-           "col_sint16", "col_sint32", "col_sint64"]
 # These are list of plot title in different visualize modes
 PLOT_TITLE_DICT = {
     1: ["Original image", "Transformed image"],
@@ -82,39 +79,6 @@ def _save_json(filename, parameters, result_dict):
         fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))
-def save_and_check(data, parameters, filename, generate_golden=False):
-    """
-    Save the dataset dictionary and compare (as numpy array) with golden file.
-    Use create_dict_iterator to access the dataset.
-    Note: save_and_check() is deprecated; use save_and_check_dict().
-    """
-    num_iter = 0
-    result_dict = {}
-    for column_name in COLUMNS:
-        result_dict[column_name] = []
-    for item in data.create_dict_iterator():  # each data is a dictionary
-        for data_key in list(item.keys()):
-            if data_key not in result_dict:
-                result_dict[data_key] = []
-            result_dict[data_key].append(item[data_key].tolist())
-        num_iter += 1
-    logger.info("Number of data in data1: {}".format(num_iter))
-    cur_dir = os.path.dirname(os.path.realpath(__file__))
-    golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
-    if generate_golden:
-        # Save as the golden result
-        _save_golden(cur_dir, golden_ref_dir, result_dict)
-    _compare_to_golden(golden_ref_dir, result_dict)
-    if SAVE_JSON:
-        # Save result to a json file for inspection
-        _save_json(filename, parameters, result_dict)
 def save_and_check_dict(data, filename, generate_golden=False):
     """
     Save the dataset dictionary and compare (as dictionary) with golden file.
@@ -203,6 +167,29 @@ def save_and_check_tuple(data, parameters, filename, generate_golden=False):
         _save_json(filename, parameters, result_dict)
+def config_get_set_seed(seed_new):
+    """
+    Get and return the original configuration seed value.
+    Set the new configuration seed value.
+    """
+    seed_original = ds.config.get_seed()
+    ds.config.set_seed(seed_new)
+    logger.info("seed: original = {} new = {} ".format(seed_original, seed_new))
+    return seed_original
+def config_get_set_num_parallel_workers(num_parallel_workers_new):
+    """
+    Get and return the original configuration num_parallel_workers value.
+    Set the new configuration num_parallel_workers value.
+    """
+    num_parallel_workers_original = ds.config.get_num_parallel_workers()
+    ds.config.set_num_parallel_workers(num_parallel_workers_new)
+    logger.info("num_parallel_workers: original = {} new = {} ".format(num_parallel_workers_original,
+                                                                       num_parallel_workers_new))
+    return num_parallel_workers_original
 def diff_mse(in1, in2):
     mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
     return mse * 100
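The two helpers added above (and removed from their old location further down, i.e. simply moved earlier in util.py) follow a get-old/set-new pattern: they return the previous global value so a test can override the dataset configuration and restore it afterwards. Illustrative usage, not part of this change:

    import mindspore.dataset as ds
    from util import config_get_set_seed, config_get_set_num_parallel_workers

    def test_something_deterministic():
        # override the global config, remembering the previous values
        original_seed = config_get_set_seed(987)
        original_workers = config_get_set_num_parallel_workers(1)
        try:
            data = ds.TFRecordDataset(["../data/dataset/testTFTestAllTypes/test.data"]).shuffle(5)
            assert sum(1 for _ in data.create_dict_iterator()) == 12
        finally:
            # restore so later tests are unaffected
            ds.config.set_seed(original_seed)
            ds.config.set_num_parallel_workers(original_workers)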
@@ -265,29 +252,6 @@ def visualize_image(image_original, image_de, mse=None, image_lib=None):
     plt.show()
-def config_get_set_seed(seed_new):
-    """
-    Get and return the original configuration seed value.
-    Set the new configuration seed value.
-    """
-    seed_original = ds.config.get_seed()
-    ds.config.set_seed(seed_new)
-    logger.info("seed: original = {} new = {} ".format(seed_original, seed_new))
-    return seed_original
-def config_get_set_num_parallel_workers(num_parallel_workers_new):
-    """
-    Get and return the original configuration num_parallel_workers value.
-    Set the new configuration num_parallel_workers value.
-    """
-    num_parallel_workers_original = ds.config.get_num_parallel_workers()
-    ds.config.set_num_parallel_workers(num_parallel_workers_new)
-    logger.info("num_parallel_workers: original = {} new = {} ".format(num_parallel_workers_original,
-                                                                       num_parallel_workers_new))
-    return num_parallel_workers_original
 def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows=3):
     """
     Take a list of un-augmented and augmented images with "annotation" bounding boxes