
test_concatenate_op.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing the Concatenate op: prepending and appending tensors to input tensors along axis 0.
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans
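

# Illustrative sketch only (not part of the original test suite): Concatenate(axis,
# prepend, append) joins optional 1D `prepend`/`append` tensors onto each input
# tensor along axis 0, as exercised by the tests below. The helper name
# `_concatenate_usage_sketch` is hypothetical and exists purely for illustration.
def _concatenate_usage_sketch():
    def gen():
        yield (np.array([2, 3], dtype=np.int32),)

    dataset = ds.GeneratorDataset(gen, column_names=["col"])
    op = data_trans.Concatenate(0, prepend=np.array([1], dtype=np.int32),
                                append=np.array([4], dtype=np.int32))
    dataset = dataset.map(operations=op, input_columns=["col"])
    for row in dataset.create_tuple_iterator(output_numpy=True):
        print(row[0])  # expected output: [1 2 3 4]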


def test_concatenate_op_all():
    """Prepend and append 1D float tensors along the default axis."""
    def gen():
        yield (np.array([5., 6., 7., 8.], dtype=np.float),)

    prepend_tensor = np.array([1.4, 2., 3., 4., 4.5], dtype=np.float)
    append_tensor = np.array([9., 10.3, 11., 12.], dtype=np.float)
    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate(0, prepend_tensor, append_tensor)
    data = data.map(operations=concatenate_op, input_columns=["col"])
    expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3,
                         11., 12.])
    for data_row in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(data_row[0], expected)


def test_concatenate_op_none():
    """Concatenate with no prepend/append tensors should leave the input unchanged."""
    def gen():
        yield (np.array([5., 6., 7., 8.], dtype=np.float),)

    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate()
    data = data.map(operations=concatenate_op, input_columns=["col"])
    for data_row in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(data_row[0], np.array([5., 6., 7., 8.], dtype=np.float))


def test_concatenate_op_string():
    """Prepend and append string tensors."""
    def gen():
        yield (np.array(["ss", "ad"], dtype='S'),)

    prepend_tensor = np.array(["dw", "df"], dtype='S')
    append_tensor = np.array(["dwsdf", "df"], dtype='S')
    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate(0, prepend_tensor, append_tensor)
    data = data.map(operations=concatenate_op, input_columns=["col"])
    expected = np.array(["dw", "df", "ss", "ad", "dwsdf", "df"], dtype='S')
    for data_row in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(data_row[0], expected)


def test_concatenate_op_multi_input_string():
    """Concatenate multiple string input columns into a single output column."""
    prepend_tensor = np.array(["dw", "df"], dtype='S')
    append_tensor = np.array(["dwsdf", "df"], dtype='S')

    data = ([["1", "2", "d"]], [["3", "4", "e"]])
    data = ds.NumpySlicesDataset(data, column_names=["col1", "col2"])
    concatenate_op = data_trans.Concatenate(0, prepend=prepend_tensor, append=append_tensor)
    data = data.map(operations=concatenate_op, input_columns=["col1", "col2"], column_order=["out1"],
                    output_columns=["out1"])
    expected = np.array(["dw", "df", "1", "2", "d", "3", "4", "e", "dwsdf", "df"], dtype='S')
    for data_row in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(data_row[0], expected)


def test_concatenate_op_multi_input_numeric():
    """Concatenate multiple numeric input columns into a single output column."""
    prepend_tensor = np.array([3, 5])

    data = ([[1, 2]], [[3, 4]])
    data = ds.NumpySlicesDataset(data, column_names=["col1", "col2"])
    concatenate_op = data_trans.Concatenate(0, prepend=prepend_tensor)
    data = data.map(operations=concatenate_op, input_columns=["col1", "col2"], column_order=["out1"],
                    output_columns=["out1"])
    expected = np.array([3, 5, 1, 2, 3, 4])
    for data_row in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(data_row[0], expected)


def test_concatenate_op_type_mismatch():
    """Prepending a string tensor to a numeric input should raise a type error."""
    def gen():
        yield (np.array([3, 4], dtype=np.float),)

    prepend_tensor = np.array(["ss", "ad"], dtype='S')
    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate(0, prepend_tensor)
    data = data.map(operations=concatenate_op, input_columns=["col"])
    with pytest.raises(RuntimeError) as error_info:
        for _ in data:
            pass
    assert "Tensor types do not match" in str(error_info.value)


def test_concatenate_op_type_mismatch2():
    """Prepending a numeric tensor to a string input should raise a type error."""
    def gen():
        yield (np.array(["ss", "ad"], dtype='S'),)

    prepend_tensor = np.array([3, 5], dtype=np.float)
    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate(0, prepend_tensor)
    data = data.map(operations=concatenate_op, input_columns=["col"])
    with pytest.raises(RuntimeError) as error_info:
        for _ in data:
            pass
    assert "Tensor types do not match" in str(error_info.value)


def test_concatenate_op_incorrect_dim():
    """Concatenate only supports 1D input tensors; a 2D input should raise an error."""
    def gen():
        yield (np.array([["ss", "ad"], ["ss", "ad"]], dtype='S'),)

    prepend_tensor = np.array(["ss", "ss"], dtype='S')
    concatenate_op = data_trans.Concatenate(0, prepend_tensor)
    data = ds.GeneratorDataset(gen, column_names=["col"])
    data = data.map(operations=concatenate_op, input_columns=["col"])
    with pytest.raises(RuntimeError) as error_info:
        for _ in data:
            pass
    assert "Only 1D tensors supported" in str(error_info.value)


def test_concatenate_op_wrong_axis():
    """Axis values other than 0 and -1 are rejected at construction time."""
    with pytest.raises(ValueError) as error_info:
        data_trans.Concatenate(2)
    assert "only 1D concatenation supported." in str(error_info.value)


def test_concatenate_op_negative_axis():
    """axis=-1 behaves the same as axis=0 for 1D inputs."""
    def gen():
        yield (np.array([5., 6., 7., 8.], dtype=np.float),)

    prepend_tensor = np.array([1.4, 2., 3., 4., 4.5], dtype=np.float)
    append_tensor = np.array([9., 10.3, 11., 12.], dtype=np.float)
    data = ds.GeneratorDataset(gen, column_names=["col"])
    concatenate_op = data_trans.Concatenate(-1, prepend_tensor, append_tensor)
    data = data.map(operations=concatenate_op, input_columns=["col"])
    expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3,
                         11., 12.])
    for data_row in data.create_tuple_iterator(output_numpy=True):
        np.testing.assert_array_equal(data_row[0], expected)


def test_concatenate_op_incorrect_input_dim():
    """Prepending a 2D tensor is rejected at construction time."""
    prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S')
    with pytest.raises(ValueError) as error_info:
        data_trans.Concatenate(0, prepend_tensor)
    assert "can only prepend 1D arrays." in str(error_info.value)


if __name__ == "__main__":
    test_concatenate_op_all()
    test_concatenate_op_none()
    test_concatenate_op_string()
    test_concatenate_op_multi_input_string()
    test_concatenate_op_multi_input_numeric()
    test_concatenate_op_type_mismatch()
    test_concatenate_op_type_mismatch2()
    test_concatenate_op_incorrect_dim()
    test_concatenate_op_negative_axis()
    test_concatenate_op_wrong_axis()
    test_concatenate_op_incorrect_input_dim()