You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_pyfunc.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.dataset as ds
  18. from mindspore import log as logger
# Path(s) to the TFRecord data file used by every test in this module.
DATA_DIR = ["../data/dataset/testPyfuncMap/data.data"]
# JSON schema describing the columns of the TFRecord file above.
SCHEMA_DIR = "../data/dataset/testPyfuncMap/schema.json"
# Column names declared in the schema; tests select from these as map inputs.
COLUMNS = ["col0", "col1", "col2"]
# When True, regenerate golden outputs instead of comparing (not used below).
GENERATE_GOLDEN = False
  23. def test_case_0():
  24. """
  25. Test PyFunc
  26. """
  27. logger.info("Test 1-1 PyFunc : lambda x : x + x")
  28. # apply dataset operations
  29. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  30. data1 = data1.map(input_columns="col0", output_columns="out", operations=(lambda x: x + x))
  31. i = 0
  32. for item in data1.create_dict_iterator(): # each data is a dictionary
  33. # In this test, the dataset is 2x2 sequential tensors
  34. golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
  35. assert np.array_equal(item["out"], golden)
  36. i = i + 4
  37. def test_case_1():
  38. """
  39. Test PyFunc
  40. """
  41. logger.info("Test 1-n PyFunc : lambda x : (x , x + x) ")
  42. col = "col0"
  43. # apply dataset operations
  44. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  45. data1 = data1.map(input_columns=col, output_columns=["out0", "out1"], operations=(lambda x: (x, x + x)),
  46. columns_order=["out0", "out1"])
  47. i = 0
  48. for item in data1.create_dict_iterator(): # each data is a dictionary
  49. # In this test, the dataset is 2x2 sequential tensors
  50. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  51. assert np.array_equal(item["out0"], golden)
  52. golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
  53. assert np.array_equal(item["out1"], golden)
  54. i = i + 4
  55. def test_case_2():
  56. """
  57. Test PyFunc
  58. """
  59. logger.info("Test n-1 PyFunc : lambda x, y : x + y ")
  60. col = ["col0", "col1"]
  61. # apply dataset operations
  62. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  63. data1 = data1.map(input_columns=col, output_columns="out", operations=(lambda x, y: x + y),
  64. columns_order=["out"])
  65. i = 0
  66. for item in data1.create_dict_iterator(): # each data is a dictionary
  67. # In this test, the dataset is 2x2 sequential tensors
  68. golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
  69. assert np.array_equal(item["out"], golden)
  70. i = i + 4
  71. def test_case_3():
  72. """
  73. Test PyFunc
  74. """
  75. logger.info("Test n-m PyFunc : lambda x, y : (x , x + 1, x + y)")
  76. col = ["col0", "col1"]
  77. # apply dataset operations
  78. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  79. data1 = data1.map(input_columns=col, output_columns=["out0", "out1", "out2"],
  80. operations=(lambda x, y: (x, x + y, x + y + 1)), columns_order=["out0", "out1", "out2"])
  81. i = 0
  82. for item in data1.create_dict_iterator(): # each data is a dictionary
  83. # In this test, the dataset is 2x2 sequential tensors
  84. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  85. assert np.array_equal(item["out0"], golden)
  86. golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
  87. assert np.array_equal(item["out1"], golden)
  88. golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]])
  89. assert np.array_equal(item["out2"], golden)
  90. i = i + 4
  91. def test_case_4():
  92. """
  93. Test PyFunc
  94. """
  95. logger.info("Test Parallel n-m PyFunc : lambda x, y : (x , x + 1, x + y)")
  96. col = ["col0", "col1"]
  97. # apply dataset operations
  98. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  99. data1 = data1.map(input_columns=col, output_columns=["out0", "out1", "out2"], num_parallel_workers=4,
  100. operations=(lambda x, y: (x, x + y, x + y + 1)), columns_order=["out0", "out1", "out2"])
  101. i = 0
  102. for item in data1.create_dict_iterator(): # each data is a dictionary
  103. # In this test, the dataset is 2x2 sequential tensors
  104. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  105. assert np.array_equal(item["out0"], golden)
  106. golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
  107. assert np.array_equal(item["out1"], golden)
  108. golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]])
  109. assert np.array_equal(item["out2"], golden)
  110. i = i + 4
  111. # The execution of this function will acquire GIL
  112. def func_5(x):
  113. return np.ones(x.shape, dtype=x.dtype)
  114. def test_case_5():
  115. """
  116. Test PyFunc
  117. """
  118. logger.info("Test 1-1 PyFunc : lambda x: np.ones(x.shape)")
  119. # apply dataset operations
  120. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  121. data1 = data1.map(input_columns="col0", output_columns="out", operations=func_5)
  122. for item in data1.create_dict_iterator(): # each data is a dictionary
  123. # In this test, the dataset is 2x2 sequential tensors
  124. golden = np.array([[1, 1], [1, 1]])
  125. assert np.array_equal(item["out"], golden)
  126. def test_case_6():
  127. """
  128. Test PyFunc
  129. """
  130. logger.info("Test PyFunc ComposeOp : (lambda x : x + x), (lambda x : x + x)")
  131. # apply dataset operations
  132. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  133. data1 = data1.map(input_columns="col0", output_columns="out",
  134. operations=[(lambda x: x + x), (lambda x: x + x)])
  135. i = 0
  136. for item in data1.create_dict_iterator(): # each data is a dictionary
  137. # In this test, the dataset is 2x2 sequential tensors
  138. golden = np.array([[i * 4, (i + 1) * 4], [(i + 2) * 4, (i + 3) * 4]])
  139. assert np.array_equal(item["out"], golden)
  140. i = i + 4
  141. def test_case_7():
  142. """
  143. Test PyFunc
  144. """
  145. logger.info("Test 1-1 PyFunc Multiprocess: lambda x : x + x")
  146. # apply dataset operations
  147. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  148. data1 = data1.map(input_columns="col0", output_columns="out", operations=(lambda x: x + x),
  149. num_parallel_workers=4, python_multiprocessing = True)
  150. i = 0
  151. for item in data1.create_dict_iterator(): # each data is a dictionary
  152. # In this test, the dataset is 2x2 sequential tensors
  153. golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
  154. assert np.array_equal(item["out"], golden)
  155. i = i + 4
  156. def test_case_8():
  157. """
  158. Test PyFunc
  159. """
  160. logger.info("Test Multiprocess n-m PyFunc : lambda x, y : (x , x + 1, x + y)")
  161. col = ["col0", "col1"]
  162. # apply dataset operations
  163. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  164. data1 = data1.map(input_columns=col, output_columns=["out0", "out1", "out2"], num_parallel_workers=4,
  165. operations=(lambda x, y: (x, x + y, x + y + 1)), columns_order=["out0", "out1", "out2"],
  166. python_multiprocessing=True)
  167. i = 0
  168. for item in data1.create_dict_iterator(): # each data is a dictionary
  169. # In this test, the dataset is 2x2 sequential tensors
  170. golden = np.array([[i, i + 1], [i + 2, i + 3]])
  171. assert np.array_equal(item["out0"], golden)
  172. golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
  173. assert np.array_equal(item["out1"], golden)
  174. golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]])
  175. assert np.array_equal(item["out2"], golden)
  176. i = i + 4
  177. def test_case_9():
  178. """
  179. Test PyFunc
  180. """
  181. logger.info("Test multiple 1-1 PyFunc Multiprocess: lambda x : x + x")
  182. # apply dataset operations
  183. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  184. data1 = data1.map(input_columns="col0", output_columns="out", operations=[(lambda x: x + x), (lambda x: x + 1),
  185. (lambda x: x + 2)],
  186. num_parallel_workers=4, python_multiprocessing=True)
  187. i = 0
  188. for item in data1.create_dict_iterator(): # each data is a dictionary
  189. # In this test, the dataset is 2x2 sequential tensors
  190. golden = np.array([[i * 2 + 3, (i + 1) * 2 + 3], [(i + 2) * 2 + 3, (i + 3) * 2 + 3]])
  191. assert np.array_equal(item["out"], golden)
  192. i = i + 4
  193. def test_pyfunc_execption():
  194. logger.info("Test PyFunc Execption Throw: lambda x : raise Execption()")
  195. def pyfunc(x):
  196. raise Exception("Pyfunc Throw")
  197. with pytest.raises(RuntimeError) as info:
  198. # apply dataset operations
  199. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  200. data1 = data1.map(input_columns="col0", output_columns="out", operations= pyfunc,
  201. num_parallel_workers=4)
  202. for _ in data1:
  203. pass
  204. assert "Pyfunc Throw" in str(info.value)
  205. def test_pyfunc_execption_multiprocess():
  206. logger.info("Test Multiprocess PyFunc Execption Throw: lambda x : raise Execption()")
  207. def pyfunc(x):
  208. raise Exception("MP Pyfunc Throw")
  209. with pytest.raises(RuntimeError) as info:
  210. # apply dataset operations
  211. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  212. data1 = data1.map(input_columns="col0", output_columns="out", operations= pyfunc,
  213. num_parallel_workers=4, python_multiprocessing = True)
  214. for _ in data1:
  215. pass
  216. assert "MP Pyfunc Throw" in str(info.value)
# Allow running the whole suite directly (without pytest collection).
if __name__ == "__main__":
    test_case_0()
    test_case_1()
    test_case_2()
    test_case_3()
    test_case_4()
    test_case_5()
    test_case_6()
    test_case_7()
    test_case_8()
    test_case_9()
    test_pyfunc_execption()
    test_pyfunc_execption_multiprocess()