You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_csv.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.dataset as ds
  18. DATA_FILE = '../data/dataset/testCSV/1.csv'
  19. def test_csv_dataset_basic():
  20. """
  21. Test CSV with repeat, skip and so on
  22. """
  23. TRAIN_FILE = '../data/dataset/testCSV/1.csv'
  24. buffer = []
  25. data = ds.CSVDataset(
  26. TRAIN_FILE,
  27. field_delim=',',
  28. column_defaults=["0", 0, 0.0, "0"],
  29. column_names=['1', '2', '3', '4'],
  30. shuffle=False)
  31. data = data.repeat(2)
  32. data = data.skip(2)
  33. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  34. buffer.append(d)
  35. assert len(buffer) == 4
  36. def test_csv_dataset_one_file():
  37. data = ds.CSVDataset(
  38. DATA_FILE,
  39. column_defaults=["1", "2", "3", "4"],
  40. column_names=['col1', 'col2', 'col3', 'col4'],
  41. shuffle=False)
  42. buffer = []
  43. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  44. buffer.append(d)
  45. assert len(buffer) == 3
  46. def test_csv_dataset_all_file():
  47. APPEND_FILE = '../data/dataset/testCSV/2.csv'
  48. data = ds.CSVDataset(
  49. [DATA_FILE, APPEND_FILE],
  50. column_defaults=["1", "2", "3", "4"],
  51. column_names=['col1', 'col2', 'col3', 'col4'],
  52. shuffle=False)
  53. buffer = []
  54. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  55. buffer.append(d)
  56. assert len(buffer) == 10
  57. def test_csv_dataset_num_samples():
  58. data = ds.CSVDataset(
  59. DATA_FILE,
  60. column_defaults=["1", "2", "3", "4"],
  61. column_names=['col1', 'col2', 'col3', 'col4'],
  62. shuffle=False, num_samples=2)
  63. count = 0
  64. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  65. count += 1
  66. assert count == 2
  67. def test_csv_dataset_distribution():
  68. TEST_FILE = '../data/dataset/testCSV/1.csv'
  69. data = ds.CSVDataset(
  70. TEST_FILE,
  71. column_defaults=["1", "2", "3", "4"],
  72. column_names=['col1', 'col2', 'col3', 'col4'],
  73. shuffle=False, num_shards=2, shard_id=0)
  74. count = 0
  75. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  76. count += 1
  77. assert count == 2
  78. def test_csv_dataset_quoted():
  79. TEST_FILE = '../data/dataset/testCSV/quoted.csv'
  80. data = ds.CSVDataset(
  81. TEST_FILE,
  82. column_defaults=["", "", "", ""],
  83. column_names=['col1', 'col2', 'col3', 'col4'],
  84. shuffle=False)
  85. buffer = []
  86. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  87. buffer.extend([d['col1'].item().decode("utf8"),
  88. d['col2'].item().decode("utf8"),
  89. d['col3'].item().decode("utf8"),
  90. d['col4'].item().decode("utf8")])
  91. assert buffer == ['a', 'b', 'c', 'd']
  92. def test_csv_dataset_separated():
  93. TEST_FILE = '../data/dataset/testCSV/separated.csv'
  94. data = ds.CSVDataset(
  95. TEST_FILE,
  96. field_delim='|',
  97. column_defaults=["", "", "", ""],
  98. column_names=['col1', 'col2', 'col3', 'col4'],
  99. shuffle=False)
  100. buffer = []
  101. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  102. buffer.extend([d['col1'].item().decode("utf8"),
  103. d['col2'].item().decode("utf8"),
  104. d['col3'].item().decode("utf8"),
  105. d['col4'].item().decode("utf8")])
  106. assert buffer == ['a', 'b', 'c', 'd']
  107. def test_csv_dataset_embedded():
  108. TEST_FILE = '../data/dataset/testCSV/embedded.csv'
  109. data = ds.CSVDataset(
  110. TEST_FILE,
  111. column_defaults=["", "", "", ""],
  112. column_names=['col1', 'col2', 'col3', 'col4'],
  113. shuffle=False)
  114. buffer = []
  115. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  116. buffer.extend([d['col1'].item().decode("utf8"),
  117. d['col2'].item().decode("utf8"),
  118. d['col3'].item().decode("utf8"),
  119. d['col4'].item().decode("utf8")])
  120. assert buffer == ['a,b', 'c"d', 'e\nf', ' g ']
  121. def test_csv_dataset_chinese():
  122. TEST_FILE = '../data/dataset/testCSV/chinese.csv'
  123. data = ds.CSVDataset(
  124. TEST_FILE,
  125. column_defaults=["", "", "", "", ""],
  126. column_names=['col1', 'col2', 'col3', 'col4', 'col5'],
  127. shuffle=False)
  128. buffer = []
  129. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  130. buffer.extend([d['col1'].item().decode("utf8"),
  131. d['col2'].item().decode("utf8"),
  132. d['col3'].item().decode("utf8"),
  133. d['col4'].item().decode("utf8"),
  134. d['col5'].item().decode("utf8")])
  135. assert buffer == ['大家', '早上好', '中午好', '下午好', '晚上好']
  136. def test_csv_dataset_header():
  137. TEST_FILE = '../data/dataset/testCSV/header.csv'
  138. data = ds.CSVDataset(
  139. TEST_FILE,
  140. column_defaults=["", "", "", ""],
  141. shuffle=False)
  142. buffer = []
  143. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  144. buffer.extend([d['col1'].item().decode("utf8"),
  145. d['col2'].item().decode("utf8"),
  146. d['col3'].item().decode("utf8"),
  147. d['col4'].item().decode("utf8")])
  148. assert buffer == ['a', 'b', 'c', 'd']
  149. def test_csv_dataset_number():
  150. TEST_FILE = '../data/dataset/testCSV/number.csv'
  151. data = ds.CSVDataset(
  152. TEST_FILE,
  153. column_defaults=[0.0, 0.0, 0, 0.0],
  154. column_names=['col1', 'col2', 'col3', 'col4'],
  155. shuffle=False)
  156. buffer = []
  157. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  158. buffer.extend([d['col1'].item(),
  159. d['col2'].item(),
  160. d['col3'].item(),
  161. d['col4'].item()])
  162. assert np.allclose(buffer, [3.0, 0.3, 4, 55.5])
  163. def test_csv_dataset_field_delim_none():
  164. """
  165. Test CSV with field_delim=None
  166. """
  167. TRAIN_FILE = '../data/dataset/testCSV/1.csv'
  168. buffer = []
  169. data = ds.CSVDataset(
  170. TRAIN_FILE,
  171. field_delim=None,
  172. column_defaults=["0", 0, 0.0, "0"],
  173. column_names=['1', '2', '3', '4'],
  174. shuffle=False)
  175. data = data.repeat(2)
  176. data = data.skip(2)
  177. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  178. buffer.append(d)
  179. assert len(buffer) == 4
  180. def test_csv_dataset_size():
  181. TEST_FILE = '../data/dataset/testCSV/size.csv'
  182. data = ds.CSVDataset(
  183. TEST_FILE,
  184. column_defaults=[0.0, 0.0, 0, 0.0],
  185. column_names=['col1', 'col2', 'col3', 'col4'],
  186. shuffle=False)
  187. assert data.get_dataset_size() == 5
  188. def test_csv_dataset_type_error():
  189. TEST_FILE = '../data/dataset/testCSV/exception.csv'
  190. data = ds.CSVDataset(
  191. TEST_FILE,
  192. column_defaults=["", 0, "", ""],
  193. column_names=['col1', 'col2', 'col3', 'col4'],
  194. shuffle=False)
  195. with pytest.raises(Exception) as err:
  196. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  197. pass
  198. assert "type does not match" in str(err.value)
  199. def test_csv_dataset_exception():
  200. TEST_FILE = '../data/dataset/testCSV/exception.csv'
  201. data = ds.CSVDataset(
  202. TEST_FILE,
  203. column_defaults=["", "", "", ""],
  204. column_names=['col1', 'col2', 'col3', 'col4'],
  205. shuffle=False)
  206. with pytest.raises(Exception) as err:
  207. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  208. pass
  209. assert "failed to parse file" in str(err.value)
  210. TEST_FILE1 = '../data/dataset/testCSV/quoted.csv'
  211. def exception_func(item):
  212. raise Exception("Error occur!")
  213. try:
  214. data = ds.CSVDataset(
  215. TEST_FILE1,
  216. column_defaults=["", "", "", ""],
  217. column_names=['col1', 'col2', 'col3', 'col4'],
  218. shuffle=False)
  219. data = data.map(operations=exception_func, input_columns=["col1"], num_parallel_workers=1)
  220. for _ in data.__iter__():
  221. pass
  222. assert False
  223. except RuntimeError as e:
  224. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  225. try:
  226. data = ds.CSVDataset(
  227. TEST_FILE1,
  228. column_defaults=["", "", "", ""],
  229. column_names=['col1', 'col2', 'col3', 'col4'],
  230. shuffle=False)
  231. data = data.map(operations=exception_func, input_columns=["col2"], num_parallel_workers=1)
  232. for _ in data.__iter__():
  233. pass
  234. assert False
  235. except RuntimeError as e:
  236. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  237. try:
  238. data = ds.CSVDataset(
  239. TEST_FILE1,
  240. column_defaults=["", "", "", ""],
  241. column_names=['col1', 'col2', 'col3', 'col4'],
  242. shuffle=False)
  243. data = data.map(operations=exception_func, input_columns=["col3"], num_parallel_workers=1)
  244. for _ in data.__iter__():
  245. pass
  246. assert False
  247. except RuntimeError as e:
  248. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  249. try:
  250. data = ds.CSVDataset(
  251. TEST_FILE1,
  252. column_defaults=["", "", "", ""],
  253. column_names=['col1', 'col2', 'col3', 'col4'],
  254. shuffle=False)
  255. data = data.map(operations=exception_func, input_columns=["col4"], num_parallel_workers=1)
  256. for _ in data.__iter__():
  257. pass
  258. assert False
  259. except RuntimeError as e:
  260. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  261. def test_csv_dataset_duplicate_columns():
  262. data = ds.CSVDataset(
  263. DATA_FILE,
  264. column_defaults=["1", "2", "3", "4"],
  265. column_names=['col1', 'col2', 'col3', 'col4', 'col1', 'col2', 'col3', 'col4'],
  266. shuffle=False)
  267. with pytest.raises(RuntimeError) as info:
  268. _ = data.create_dict_iterator(num_epochs=1, output_numpy=True)
  269. assert "Invalid parameter, duplicate column names are not allowed: col1" in str(info.value)
  270. assert "column_names" in str(info.value)
  271. if __name__ == "__main__":
  272. test_csv_dataset_basic()
  273. test_csv_dataset_one_file()
  274. test_csv_dataset_all_file()
  275. test_csv_dataset_num_samples()
  276. test_csv_dataset_distribution()
  277. test_csv_dataset_quoted()
  278. test_csv_dataset_separated()
  279. test_csv_dataset_embedded()
  280. test_csv_dataset_chinese()
  281. test_csv_dataset_header()
  282. test_csv_dataset_number()
  283. test_csv_dataset_field_delim_none()
  284. test_csv_dataset_size()
  285. test_csv_dataset_type_error()
  286. test_csv_dataset_exception()
  287. test_csv_dataset_duplicate_columns()