|
|
|
@@ -501,15 +501,15 @@ def check_batch(method): |
|
|
|
for k, v in param_dict.get('pad_info').items(): |
|
|
|
check_pad_info(k, v) |
|
|
|
|
|
|
|
if (per_batch_map is None) != (input_columns is None): |
|
|
|
# These two parameters appear together. |
|
|
|
raise ValueError("per_batch_map and input_columns need to be passed in together.") |
|
|
|
|
|
|
|
if input_columns is not None: |
|
|
|
check_columns(input_columns, "input_columns") |
|
|
|
if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1): |
|
|
|
raise ValueError("the signature of per_batch_map should match with input columns") |
|
|
|
|
|
|
|
if (per_batch_map is None) != (input_columns is None): |
|
|
|
# These two parameters appear together. |
|
|
|
raise ValueError("per_batch_map and input_columns need to be passed in together.") |
|
|
|
|
|
|
|
if output_columns is not None: |
|
|
|
raise ValueError("output_columns is currently not implemented.") |
|
|
|
|
|
|
|
|