| @@ -17,7 +17,7 @@ d2 = DataSet({'x': [[101, 201], [201, 301, 401], [100]] * 10, 'y': [20, 10, 10] | |||
| d3 = DataSet({'x': [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100, 'y': [100, 100, 200] * 100}) | |||
| def test_pad_val(tensor, val=0): | |||
| def _test_pad_val(tensor, val=0): | |||
| if isinstance(tensor, torch.Tensor): | |||
| tensor = tensor.tolist() | |||
| for item in tensor: | |||
| @@ -45,7 +45,7 @@ class TestMixDataLoader: | |||
| if idx > 1: | |||
| # d3 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| assert _test_pad_val(batch['x'], val=0) | |||
| # collate_fn = Callable | |||
| def collate_batch(batch): | |||
| @@ -74,13 +74,13 @@ class TestMixDataLoader: | |||
| dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_fns, drop_last=True) | |||
| for idx, batch in enumerate(dl2): | |||
| if idx == 0: | |||
| assert test_pad_val(batch['x'], val=-1) | |||
| assert _test_pad_val(batch['x'], val=-1) | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| if idx == 1: | |||
| assert test_pad_val(batch['x'], val=-2) | |||
| assert _test_pad_val(batch['x'], val=-2) | |||
| assert batch['x'].shape == torch.Size([16, 3]) | |||
| if idx > 1: | |||
| assert test_pad_val(batch['x'], val=-3) | |||
| assert _test_pad_val(batch['x'], val=-3) | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| # sampler 为 str | |||
| @@ -101,7 +101,7 @@ class TestMixDataLoader: | |||
| if idx > 1: | |||
| # d3 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| assert _test_pad_val(batch['x'], val=0) | |||
| for idx, batch in enumerate(dl4): | |||
| if idx == 0: | |||
| @@ -118,7 +118,7 @@ class TestMixDataLoader: | |||
| if idx > 1: | |||
| # d3 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| assert _test_pad_val(batch['x'], val=0) | |||
| # sampler 为 Dict | |||
| samplers = {'d1': SequentialSampler(d1), | |||
| @@ -137,7 +137,7 @@ class TestMixDataLoader: | |||
| if idx > 1: | |||
| # d3 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| assert _test_pad_val(batch['x'], val=0) | |||
| # ds_ratio 为 'truncate_to_least' | |||
| dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least', drop_last=True) | |||
| @@ -154,7 +154,7 @@ class TestMixDataLoader: | |||
| # d3 | |||
| assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| assert _test_pad_val(batch['x'], val=0) | |||
| if idx > 2: | |||
| raise ValueError(f"ds_ratio: 'truncate_to_least' error") | |||
| @@ -170,7 +170,7 @@ class TestMixDataLoader: | |||
| if 36 <= idx < 54: | |||
| # d3 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| assert _test_pad_val(batch['x'], val=0) | |||
| if idx >= 54: | |||
| raise ValueError(f"ds_ratio: 'pad_to_most' error") | |||
| @@ -187,7 +187,7 @@ class TestMixDataLoader: | |||
| if 4 <= idx < 41: | |||
| # d3 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| assert _test_pad_val(batch['x'], val=0) | |||
| if idx >= 41: | |||
| raise ValueError(f"ds_ratio: 'pad_to_most' error") | |||
| @@ -201,7 +201,7 @@ class TestMixDataLoader: | |||
| # d3 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| assert _test_pad_val(batch['x'], val=0) | |||
| if idx >= 19: | |||
| raise ValueError(f"ds_ratio: 'pad_to_most' error") | |||
| @@ -209,7 +209,7 @@ class TestMixDataLoader: | |||
| datasets = {'d1': d1, 'd2': d2, 'd3': d3} | |||
| dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True) | |||
| for idx, batch in enumerate(dl): | |||
| assert test_pad_val(batch['x'], val=0) | |||
| assert _test_pad_val(batch['x'], val=0) | |||
| if idx >= 22: | |||
| raise ValueError(f"out of range") | |||
| @@ -224,7 +224,7 @@ class TestMixDataLoader: | |||
| dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_batch, drop_last=True) | |||
| for idx, batch in enumerate(dl1): | |||
| assert isinstance(batch['x'], list) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| assert _test_pad_val(batch['x'], val=0) | |||
| if idx >= 22: | |||
| raise ValueError(f"out of range") | |||
| @@ -237,12 +237,12 @@ class TestMixDataLoader: | |||
| # sampler 为 str | |||
| dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True) | |||
| for idx, batch in enumerate(dl3): | |||
| assert test_pad_val(batch['x'], val=0) | |||
| assert _test_pad_val(batch['x'], val=0) | |||
| if idx >= 22: | |||
| raise ValueError(f"out of range") | |||
| dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True) | |||
| for idx, batch in enumerate(dl4): | |||
| assert test_pad_val(batch['x'], val=0) | |||
| assert _test_pad_val(batch['x'], val=0) | |||
| if idx >= 22: | |||
| raise ValueError(f"out of range") | |||
| # sampler 为 Dict | |||
| @@ -251,7 +251,7 @@ class TestMixDataLoader: | |||
| 'd3': RandomSampler(d3)} | |||
| dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=samplers, drop_last=True) | |||
| for idx, batch in enumerate(dl5): | |||
| assert test_pad_val(batch['x'], val=0) | |||
| assert _test_pad_val(batch['x'], val=0) | |||
| if idx >= 22: | |||
| raise ValueError(f"out of range") | |||
| # ds_ratio 为 'truncate_to_least' | |||
| @@ -333,7 +333,7 @@ class TestMixDataLoader: | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx > 20: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||
| _test_pad_val(batch['x'], val=0) | |||
| # collate_fn = Callable | |||
| def collate_batch(batch): | |||
| @@ -361,16 +361,16 @@ class TestMixDataLoader: | |||
| dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_fns, batch_size=18) | |||
| for idx, batch in enumerate(dl1): | |||
| if idx == 0 or idx == 3: | |||
| assert test_pad_val(batch['x'], val=-1) | |||
| assert _test_pad_val(batch['x'], val=-1) | |||
| assert batch['x'][:3].tolist() == [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]] | |||
| assert batch['x'].shape[1] == 4 | |||
| elif idx == 1 or idx == 4: | |||
| # d2 | |||
| assert test_pad_val(batch['x'], val=-2) | |||
| assert _test_pad_val(batch['x'], val=-2) | |||
| assert batch['x'][:3].tolist() == [[101, 201, -2], [201, 301, 401], [100, -2, -2]] | |||
| assert batch['x'].shape[1] == 3 | |||
| elif idx == 2 or 4 < idx <= 20: | |||
| assert test_pad_val(batch['x'], val=-3) | |||
| assert _test_pad_val(batch['x'], val=-3) | |||
| assert batch['x'][:3].tolist() == [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]] | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx > 20: | |||
| @@ -392,7 +392,7 @@ class TestMixDataLoader: | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx > 20: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||
| _test_pad_val(batch['x'], val=0) | |||
| for idx, batch in enumerate(dl3): | |||
| if idx == 0 or idx == 3: | |||
| assert batch['x'].shape[1] == 4 | |||
| @@ -403,7 +403,7 @@ class TestMixDataLoader: | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx > 20: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||
| _test_pad_val(batch['x'], val=0) | |||
| # sampler 为 Dict | |||
| samplers = {'d1': SequentialSampler(d1), | |||
| 'd2': SequentialSampler(d2), | |||
| @@ -421,7 +421,7 @@ class TestMixDataLoader: | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx > 20: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||
| _test_pad_val(batch['x'], val=0) | |||
| # ds_ratio 为 'truncate_to_least' | |||
| dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18) | |||
| @@ -438,7 +438,7 @@ class TestMixDataLoader: | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx > 5: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||
| _test_pad_val(batch['x'], val=0) | |||
| # ds_ratio 为 'pad_to_most' | |||
| dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18) | |||
| @@ -457,7 +457,7 @@ class TestMixDataLoader: | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx >= 51: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||
| _test_pad_val(batch['x'], val=0) | |||
| # ds_ratio 为 Dict[str, float] | |||
| ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} | |||
| @@ -475,7 +475,7 @@ class TestMixDataLoader: | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx > 39: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||
| _test_pad_val(batch['x'], val=0) | |||
| ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} | |||
| dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18) | |||
| @@ -493,4 +493,4 @@ class TestMixDataLoader: | |||
| if idx > 18: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||
| _test_pad_val(batch['x'], val=0) | |||