You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_tedlium.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.dataset as ds
  18. import mindspore.dataset.audio.transforms as audio
  19. DATA_DIR_TEDLIUM_RELEASE12 = "../data/dataset/testTedliumData/TEDLIUM_release1"
  20. DATA_DIR_TEDLIUM_RELEASE3 = "../data/dataset/testTedliumData/TEDLIUM_release3"
  21. RELEASE1 = "release1"
  22. RELEASE2 = "release2"
  23. RELEASE3 = "release3"
  24. NO_SPH_DIR_TEDLIUM12 = "../data/dataset/testTedliumData/else"
  25. def test_tedlium_basic():
  26. """
  27. Feature: TedliumDataset
  28. Description: use different data to test the functions of different versions
  29. Expectation: num_samples
  30. set 1 2 4
  31. get 1 2 4
  32. num_parallel_workers
  33. set 1 2 4(num_samples=4)
  34. get 4 4 4
  35. num repeat
  36. set 3(num_samples=5)
  37. get 15
  38. """
  39. # case1 test num_samples
  40. data11 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, num_samples=1)
  41. data12 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, num_samples=2)
  42. data13 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, num_samples=4)
  43. num_iter11 = 0
  44. num_iter12 = 0
  45. num_iter13 = 0
  46. for _ in data11.create_dict_iterator(num_epochs=1, output_numpy=True):
  47. num_iter11 += 1
  48. for _ in data12.create_dict_iterator(num_epochs=1, output_numpy=True):
  49. num_iter12 += 1
  50. for _ in data13.create_dict_iterator(num_epochs=1, output_numpy=True):
  51. num_iter13 += 1
  52. assert num_iter11 == 1
  53. assert num_iter12 == 2
  54. assert num_iter13 == 4
  55. # case2 test num_parallel_workers
  56. data21 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, num_samples=4, num_parallel_workers=1)
  57. data22 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, num_samples=4, num_parallel_workers=2)
  58. data23 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, num_samples=4, num_parallel_workers=4)
  59. num_iter21 = 0
  60. num_iter22 = 0
  61. num_iter23 = 0
  62. for _ in data21.create_dict_iterator(num_epochs=1, output_numpy=True):
  63. num_iter21 += 1
  64. for _ in data22.create_dict_iterator(num_epochs=1, output_numpy=True):
  65. num_iter22 += 1
  66. for _ in data23.create_dict_iterator(num_epochs=1, output_numpy=True):
  67. num_iter23 += 1
  68. assert num_iter21 == 4
  69. assert num_iter22 == 4
  70. assert num_iter23 == 4
  71. # case3 test repeat
  72. data3 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, num_samples=5)
  73. data3 = data3.repeat(3)
  74. num_iter3 = 0
  75. for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
  76. num_iter3 += 1
  77. assert num_iter3 == 15
  78. def test_tedlium_content_check():
  79. """
  80. Feature: TedliumDataset
  81. Description: Check content of the first sample
  82. Expectation: correct content
  83. """
  84. data1 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, num_samples=1, shuffle=False)
  85. data3 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, num_samples=1, shuffle=False)
  86. num_iter1 = 0
  87. num_iter3 = 0
  88. for d in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  89. waveform = d["waveform"]
  90. sample_rate = d["sample_rate"]
  91. transcript = d["transcript"]
  92. talk_id = d["talk_id"]
  93. speaker_id = d["speaker_id"]
  94. identifier = d["identifier"]
  95. assert waveform.dtype == np.float32
  96. assert waveform.shape == (1, 480)
  97. assert sample_rate == 16000
  98. assert sample_rate.dtype == np.int32
  99. assert talk_id.item().decode("utf8") == "test1"
  100. assert speaker_id.item().decode("utf8") == "test1"
  101. assert transcript.item().decode("utf8") == "this is record 1 of test1."
  102. assert identifier.item().decode("utf8") == "<o,f0,female>"
  103. num_iter1 += 1
  104. for d in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
  105. waveform = d["waveform"]
  106. sample_rate = d["sample_rate"]
  107. transcript = d["transcript"]
  108. talk_id = d["talk_id"]
  109. speaker_id = d["speaker_id"]
  110. identifier = d["identifier"]
  111. assert waveform.dtype == np.float32
  112. assert waveform.shape == (1, 160)
  113. assert sample_rate == 16000
  114. assert sample_rate.dtype == np.int32
  115. assert talk_id.item().decode("utf8") == "test3"
  116. assert speaker_id.item().decode("utf8") == "test3"
  117. assert transcript.item().decode("utf8") == "this is record 1 of test3."
  118. assert identifier.item().decode("utf8") == "<o,f0,female>"
  119. num_iter3 += 1
  120. assert num_iter1 == 1
  121. assert num_iter3 == 1
  122. def test_tedlium_exceptions():
  123. """
  124. Feature: TedliumDataset
  125. Description: send error when error occur
  126. Expectation: send error
  127. """
  128. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  129. with pytest.raises(RuntimeError, match=error_msg_1):
  130. ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, shuffle=False, sampler=ds.PKSampler(3))
  131. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  132. with pytest.raises(RuntimeError, match=error_msg_2):
  133. ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
  134. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  135. with pytest.raises(RuntimeError, match=error_msg_3):
  136. ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, num_shards=10)
  137. error_msg_4 = "shard_id is specified but num_shards is not"
  138. with pytest.raises(RuntimeError, match=error_msg_4):
  139. ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, shard_id=0)
  140. error_msg_5 = "Input shard_id is not within the required interval"
  141. with pytest.raises(ValueError, match=error_msg_5):
  142. ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, num_shards=2, shard_id=-1)
  143. with pytest.raises(ValueError, match=error_msg_5):
  144. ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, num_shards=2, shard_id=5)
  145. error_msg_6 = "num_parallel_workers exceeds"
  146. with pytest.raises(ValueError, match=error_msg_6):
  147. ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, shuffle=False, num_parallel_workers=0)
  148. with pytest.raises(ValueError, match=error_msg_6):
  149. ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, shuffle=False, num_parallel_workers=256)
  150. error_msg_7 = "Invalid data, no valid data matching the dataset API TedliumDataset"
  151. with pytest.raises(RuntimeError, match=error_msg_7):
  152. ds1 = ds.TedliumDataset(NO_SPH_DIR_TEDLIUM12, RELEASE1, "train")
  153. for _ in ds1.__iter__():
  154. pass
  155. def test_tedlium_exception_file_path():
  156. """
  157. Feature: TedliumDataset
  158. Description: error test
  159. Expectation: throw error
  160. """
  161. def exception_func(item):
  162. raise Exception("Error occur!")
  163. try:
  164. data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1)
  165. data = data.map(operations=exception_func, input_columns=["waveform"], num_parallel_workers=1)
  166. num_rows = 0
  167. for _ in data.create_dict_iterator():
  168. num_rows += 1
  169. assert False
  170. except RuntimeError as e:
  171. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  172. try:
  173. data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1)
  174. data = data.map(operations=exception_func, input_columns=["sample_rate"], num_parallel_workers=1)
  175. num_rows = 0
  176. for _ in data.create_dict_iterator():
  177. num_rows += 1
  178. assert False
  179. except RuntimeError as e:
  180. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  181. try:
  182. data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2)
  183. data = data.map(operations=exception_func, input_columns=["transcript"], num_parallel_workers=1)
  184. num_rows = 0
  185. for _ in data.create_dict_iterator():
  186. num_rows += 1
  187. assert False
  188. except RuntimeError as e:
  189. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  190. try:
  191. data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2)
  192. data = data.map(operations=exception_func, input_columns=["talk_id"], num_parallel_workers=1)
  193. num_rows = 0
  194. for _ in data.create_dict_iterator():
  195. num_rows += 1
  196. assert False
  197. except RuntimeError as e:
  198. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  199. try:
  200. data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3)
  201. data = data.map(operations=exception_func, input_columns=["speaker_id"], num_parallel_workers=1)
  202. num_rows = 0
  203. for _ in data.create_dict_iterator():
  204. num_rows += 1
  205. assert False
  206. except RuntimeError as e:
  207. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  208. try:
  209. data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3)
  210. data = data.map(operations=exception_func, input_columns=["identifier"], num_parallel_workers=1)
  211. num_rows = 0
  212. for _ in data.create_dict_iterator():
  213. num_rows += 1
  214. assert False
  215. except RuntimeError as e:
  216. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  217. def test_tedlium_extensions():
  218. """
  219. Feature: TedliumDataset
  220. Description: test extensions of tedlium
  221. Expectation: extensions
  222. set invalid data
  223. get throw error
  224. """
  225. try:
  226. data = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, "train", "invalid")
  227. for _ in data.create_dict_iterator(output_numpy=True):
  228. pass
  229. assert False
  230. except RuntimeError as e:
  231. assert "is not supported." in str(e)
  232. def test_tedlium_release():
  233. """
  234. Feature: TedliumDataset
  235. Description: test release of tedlium
  236. Expectation: release
  237. set invalid data
  238. get throw error
  239. """
  240. def test_config(release):
  241. try:
  242. ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, release)
  243. except (ValueError, TypeError, RuntimeError) as e:
  244. return str(e)
  245. return None
  246. # test the release
  247. assert "release is not within the valid set of ['release1', 'release2', 'release3']" in test_config("invalid")
  248. assert "Argument release with value None is not of type [<class 'str'>]" in test_config(None)
  249. assert "Argument release with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
  250. def test_tedlium_sequential_sampler():
  251. """
  252. Feature: TedliumDataset
  253. Description: test tedlium sequential sampler
  254. Expectation: correct data
  255. """
  256. num_samples = 3
  257. sampler = ds.SequentialSampler(num_samples=num_samples)
  258. data21 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, sampler=sampler)
  259. data22 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, shuffle=False, num_samples=num_samples)
  260. num_iter2 = 0
  261. for item1, item2 in zip(data21.create_dict_iterator(num_epochs=1, output_numpy=True),
  262. data22.create_dict_iterator(num_epochs=1, output_numpy=True)):
  263. np.testing.assert_equal(item1["waveform"], item2["waveform"])
  264. num_iter2 += 1
  265. assert num_iter2 == num_samples
  266. def test_tedlium_sampler_get_dataset_size():
  267. """
  268. Feature: TedliumDataset
  269. Description: test TedliumDataset with SequentialSampler and get_dataset_size
  270. Expectation: num_samples
  271. set 5
  272. get 5
  273. """
  274. sampler = ds.SequentialSampler(start_index=0, num_samples=5)
  275. data3 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE3, RELEASE3, sampler=sampler)
  276. num_iter3 = 0
  277. ds_sz3 = data3.get_dataset_size()
  278. for _ in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
  279. num_iter3 += 1
  280. assert ds_sz3 == num_iter3 == 5
  281. def test_tedlium_usage():
  282. """
  283. Feature: TedliumDataset
  284. Description: test usage of tedlium
  285. Expectation: usage
  286. set valid data invalid data
  287. get correct data throw error
  288. """
  289. def test_config_tedlium12(usage):
  290. try:
  291. data1 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, usage=usage)
  292. data2 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE2, usage=usage)
  293. num_rows = 0
  294. for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  295. num_rows += 1
  296. for _ in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
  297. num_rows += 1
  298. except (ValueError, TypeError, RuntimeError) as e:
  299. return str(e)
  300. return num_rows
  301. # test the usage of TEDLIUM
  302. assert test_config_tedlium12("dev") == 1 + 1
  303. assert test_config_tedlium12("test") == 2 + 2
  304. assert test_config_tedlium12("train") == 3 + 3
  305. assert test_config_tedlium12("all") == 1 + 1 + 2 + 2 + 3 + 3
  306. assert "usage is not within the valid set of ['train', 'test', 'dev', 'all']" in test_config_tedlium12("invalid")
  307. assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config_tedlium12(["list"])
  308. def test_tedlium_with_chained_sampler_get_dataset_size():
  309. """
  310. Feature: TedliumDataset
  311. Description: test TedliumDataset with RandomSampler chained with a SequentialSampler and get_dataset_size
  312. Expectation: num_samples
  313. set 2
  314. get 2
  315. """
  316. sampler = ds.SequentialSampler(start_index=0, num_samples=2)
  317. child_sampler = ds.RandomSampler()
  318. sampler.add_child(child_sampler)
  319. data1 = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, sampler=sampler)
  320. num_iter1 = 0
  321. ds_sz1 = data1.get_dataset_size()
  322. for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  323. num_iter1 += 1
  324. assert ds_sz1 == num_iter1 == 2
  325. def test_tedlium_pipeline():
  326. """
  327. Feature: TedliumDataset
  328. Description: Read a sample
  329. Expectation: The amount of each function are equal
  330. """
  331. # Original waveform
  332. dataset = ds.TedliumDataset(DATA_DIR_TEDLIUM_RELEASE12, RELEASE1, num_samples=1)
  333. band_biquad_op = audio.BandBiquad(8000, 200.0)
  334. # Filtered waveform by bandbiquad
  335. dataset = dataset.map(input_columns=["waveform"], operations=band_biquad_op, num_parallel_workers=2)
  336. i = 0
  337. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  338. i += 1
  339. assert i == 1
  340. if __name__ == '__main__':
  341. test_tedlium_basic()
  342. test_tedlium_content_check()
  343. test_tedlium_exceptions()
  344. test_tedlium_exception_file_path()
  345. test_tedlium_extensions()
  346. test_tedlium_release()
  347. test_tedlium_sequential_sampler()
  348. test_tedlium_sampler_get_dataset_size()
  349. test_tedlium_usage()
  350. test_tedlium_with_chained_sampler_get_dataset_size()
  351. test_tedlium_pipeline()