|
|
|
@@ -147,6 +147,21 @@ def test_concatenate_op_wrong_axis(): |
|
|
|
assert "only 1D concatenation supported." in repr(error_info.value) |
|
|
|
|
|
|
|
|
|
|
|
def test_concatenate_op_negative_axis(): |
|
|
|
def gen(): |
|
|
|
yield (np.array([5., 6., 7., 8.], dtype=np.float),) |
|
|
|
|
|
|
|
prepend_tensor = np.array([1.4, 2., 3., 4., 4.5], dtype=np.float) |
|
|
|
append_tensor = np.array([9., 10.3, 11., 12.], dtype=np.float) |
|
|
|
data = ds.GeneratorDataset(gen, column_names=["col"]) |
|
|
|
concatenate_op = data_trans.Concatenate(-1, prepend_tensor, append_tensor) |
|
|
|
data = data.map(input_columns=["col"], operations=concatenate_op) |
|
|
|
expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3, |
|
|
|
11., 12.]) |
|
|
|
for data_row in data: |
|
|
|
np.testing.assert_array_equal(data_row[0], expected) |
|
|
|
|
|
|
|
|
|
|
|
def test_concatenate_op_incorrect_input_dim(): |
|
|
|
def gen(): |
|
|
|
yield (np.array(["ss", "ad"], dtype='S'),) |
|
|
|
@@ -166,10 +181,11 @@ if __name__ == "__main__": |
|
|
|
test_concatenate_op_all() |
|
|
|
test_concatenate_op_none() |
|
|
|
test_concatenate_op_string() |
|
|
|
test_concatenate_op_multi_input_string() |
|
|
|
test_concatenate_op_multi_input_numeric() |
|
|
|
test_concatenate_op_type_mismatch() |
|
|
|
test_concatenate_op_type_mismatch2() |
|
|
|
test_concatenate_op_incorrect_dim() |
|
|
|
test_concatenate_op_incorrect_input_dim() |
|
|
|
test_concatenate_op_multi_input_numeric() |
|
|
|
test_concatenate_op_multi_input_string() |
|
|
|
test_concatenate_op_negative_axis() |
|
|
|
test_concatenate_op_wrong_axis() |
|
|
|
test_concatenate_op_incorrect_input_dim() |