|
|
|
@@ -374,6 +374,28 @@ def test_multi_col_map(): |
|
|
|
assert "col-1 doesn't exist" in batch_map_config(2, 2, split_col, ["col-1"], ["col_x", "col_y"]) |
|
|
|
|
|
|
|
|
|
|
|
def test_exceptions_2(): |
|
|
|
def gen(num): |
|
|
|
for i in range(num): |
|
|
|
yield (np.array([i]),) |
|
|
|
|
|
|
|
def simple_copy(colList, batchInfo): |
|
|
|
return ([np.copy(arr) for arr in colList],) |
|
|
|
|
|
|
|
def test_wrong_col_name(gen_num, batch_size): |
|
|
|
data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=["num1"], |
|
|
|
per_batch_map=simple_copy) |
|
|
|
try: |
|
|
|
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True): |
|
|
|
pass |
|
|
|
return "success" |
|
|
|
except RuntimeError as e: |
|
|
|
return str(e) |
|
|
|
|
|
|
|
# test exception where column name is incorrect |
|
|
|
assert "error. col:num1 doesn't exist" in test_wrong_col_name(4, 2) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
logger.info("Running test_var_batch_map.py test_batch_corner_cases() function") |
|
|
|
test_batch_corner_cases() |
|
|
|
@@ -398,3 +420,6 @@ if __name__ == '__main__': |
|
|
|
|
|
|
|
logger.info("Running test_var_batch_map.py test_multi_col_map() function") |
|
|
|
test_multi_col_map() |
|
|
|
|
|
|
|
logger.info("Running test_var_batch_map.py test_exceptions_2() function") |
|
|
|
test_exceptions_2() |