You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

validators.py 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Validators for TensorOps.
  17. """
  18. from functools import wraps
  19. from mindspore.dataset.core.validator_helpers import check_float32, check_float32_not_zero, check_int32,\
  20. check_int32_not_zero, check_list_same_size, check_non_negative_float32, check_non_negative_int32, \
  21. check_pos_float32, check_pos_int32, check_value, INT32_MAX, parse_user_args, type_check
  22. from .utils import BorderType, FadeShape, GainType, Interpolation, Modulation, ScaleType
  23. def check_amplitude_to_db(method):
  24. """Wrapper method to check the parameters of AmplitudeToDB."""
  25. @wraps(method)
  26. def new_method(self, *args, **kwargs):
  27. [stype, ref_value, amin, top_db], _ = parse_user_args(method, *args, **kwargs)
  28. type_check(stype, (ScaleType,), "stype")
  29. type_check(ref_value, (int, float), "ref_value")
  30. if ref_value is not None:
  31. check_pos_float32(ref_value, "ref_value")
  32. type_check(amin, (int, float), "amin")
  33. if amin is not None:
  34. check_pos_float32(amin, "amin")
  35. type_check(top_db, (int, float), "top_db")
  36. if top_db is not None:
  37. check_pos_float32(top_db, "top_db")
  38. return method(self, *args, **kwargs)
  39. return new_method
  40. def check_biquad_sample_rate(sample_rate):
  41. """Wrapper method to check the parameters of sample_rate."""
  42. type_check(sample_rate, (int,), "sample_rate")
  43. check_int32_not_zero(sample_rate, "sample_rate")
  44. def check_biquad_central_freq(central_freq):
  45. """Wrapper method to check the parameters of central_freq."""
  46. type_check(central_freq, (float, int), "central_freq")
  47. check_float32(central_freq, "central_freq")
  48. def check_biquad_Q(Q):
  49. """Wrapper method to check the parameters of Q."""
  50. type_check(Q, (float, int), "Q")
  51. check_value(Q, [0, 1], "Q", True)
  52. def check_biquad_noise(noise):
  53. """Wrapper method to check the parameters of noise."""
  54. type_check(noise, (bool,), "noise")
  55. def check_biquad_const_skirt_gain(const_skirt_gain):
  56. """Wrapper method to check the parameters of const_skirt_gain."""
  57. type_check(const_skirt_gain, (bool,), "const_skirt_gain")
  58. def check_biquad_gain(gain):
  59. """Wrapper method to check the parameters of gain."""
  60. type_check(gain, (float, int), "gain")
  61. check_float32(gain, "gain")
  62. def check_band_biquad(method):
  63. """Wrapper method to check the parameters of BandBiquad."""
  64. @wraps(method)
  65. def new_method(self, *args, **kwargs):
  66. [sample_rate, central_freq, Q, noise], _ = parse_user_args(
  67. method, *args, **kwargs)
  68. check_biquad_sample_rate(sample_rate)
  69. check_biquad_central_freq(central_freq)
  70. check_biquad_Q(Q)
  71. check_biquad_noise(noise)
  72. return method(self, *args, **kwargs)
  73. return new_method
  74. def check_biquad_cutoff_freq(cutoff_freq):
  75. """Wrapper method to check the parameters of cutoff_freq."""
  76. type_check(cutoff_freq, (float, int), "cutoff_freq")
  77. check_float32(cutoff_freq, "cutoff_freq")
  78. def check_highpass_biquad(method):
  79. """Wrapper method to check the parameters of HighpassBiquad."""
  80. @wraps(method)
  81. def new_method(self, *args, **kwargs):
  82. [sample_rate, cutoff_freq, Q], _ = parse_user_args(method, *args, **kwargs)
  83. check_biquad_sample_rate(sample_rate)
  84. check_biquad_cutoff_freq(cutoff_freq)
  85. check_biquad_Q(Q)
  86. return method(self, *args, **kwargs)
  87. return new_method
  88. def check_allpass_biquad(method):
  89. """Wrapper method to check the parameters of AllpassBiquad."""
  90. @wraps(method)
  91. def new_method(self, *args, **kwargs):
  92. [sample_rate, central_freq, Q], _ = parse_user_args(
  93. method, *args, **kwargs)
  94. check_biquad_sample_rate(sample_rate)
  95. check_biquad_central_freq(central_freq)
  96. check_biquad_Q(Q)
  97. return method(self, *args, **kwargs)
  98. return new_method
  99. def check_bandpass_biquad(method):
  100. """Wrapper method to check the parameters of BandpassBiquad."""
  101. @wraps(method)
  102. def new_method(self, *args, **kwargs):
  103. [sample_rate, central_freq, Q, const_skirt_gain], _ = parse_user_args(
  104. method, *args, **kwargs)
  105. check_biquad_sample_rate(sample_rate)
  106. check_biquad_central_freq(central_freq)
  107. check_biquad_Q(Q)
  108. check_biquad_const_skirt_gain(const_skirt_gain)
  109. return method(self, *args, **kwargs)
  110. return new_method
  111. def check_bandreject_biquad(method):
  112. """Wrapper method to check the parameters of BandrejectBiquad."""
  113. @wraps(method)
  114. def new_method(self, *args, **kwargs):
  115. [sample_rate, central_freq, Q], _ = parse_user_args(
  116. method, *args, **kwargs)
  117. check_biquad_sample_rate(sample_rate)
  118. check_biquad_central_freq(central_freq)
  119. check_biquad_Q(Q)
  120. return method(self, *args, **kwargs)
  121. return new_method
  122. def check_bass_biquad(method):
  123. """Wrapper method to check the parameters of BassBiquad."""
  124. @wraps(method)
  125. def new_method(self, *args, **kwargs):
  126. [sample_rate, gain, central_freq, Q], _ = parse_user_args(
  127. method, *args, **kwargs)
  128. check_biquad_sample_rate(sample_rate)
  129. check_biquad_gain(gain)
  130. check_biquad_central_freq(central_freq)
  131. check_biquad_Q(Q)
  132. return method(self, *args, **kwargs)
  133. return new_method
  134. def check_contrast(method):
  135. """Wrapper method to check the parameters of Contrast."""
  136. @wraps(method)
  137. def new_method(self, *args, **kwargs):
  138. [enhancement_amount], _ = parse_user_args(method, *args, **kwargs)
  139. type_check(enhancement_amount, (float, int), "enhancement_amount")
  140. check_value(enhancement_amount, [0, 100], "enhancement_amount")
  141. return method(self, *args, **kwargs)
  142. return new_method
  143. def check_db_to_amplitude(method):
  144. """Wrapper method to check the parameters of db_to_amplitude."""
  145. @wraps(method)
  146. def new_method(self, *args, **kwargs):
  147. [ref, power], _ = parse_user_args(method, *args, **kwargs)
  148. type_check(ref, (float, int), "ref")
  149. check_float32(ref, "ref")
  150. type_check(power, (float, int), "power")
  151. check_float32(power, "power")
  152. return method(self, *args, **kwargs)
  153. return new_method
  154. def check_dc_shift(method):
  155. """Wrapper method to check the parameters of DCShift."""
  156. @wraps(method)
  157. def new_method(self, *args, **kwargs):
  158. [shift, limiter_gain], _ = parse_user_args(method, *args, **kwargs)
  159. type_check(shift, (float, int), "shift")
  160. check_value(shift, [-2.0, 2.0], "shift")
  161. if limiter_gain is not None:
  162. type_check(limiter_gain, (float, int), "limiter_gain")
  163. return method(self, *args, **kwargs)
  164. return new_method
  165. def check_deemph_biquad(method):
  166. """Wrapper method to check the parameters of CutMixBatch."""
  167. @wraps(method)
  168. def new_method(self, *args, **kwargs):
  169. [sample_rate], _ = parse_user_args(method, *args, **kwargs)
  170. type_check(sample_rate, (int,), "sample_rate")
  171. if sample_rate not in (44100, 48000):
  172. raise ValueError("Input sample_rate should be 44100 or 48000, but got {0}.".format(sample_rate))
  173. return method(self, *args, **kwargs)
  174. return new_method
  175. def check_equalizer_biquad(method):
  176. """Wrapper method to check the parameters of EqualizerBiquad."""
  177. @wraps(method)
  178. def new_method(self, *args, **kwargs):
  179. [sample_rate, center_freq, gain, Q], _ = parse_user_args(method, *args, **kwargs)
  180. check_biquad_sample_rate(sample_rate)
  181. check_biquad_central_freq(center_freq)
  182. check_biquad_gain(gain)
  183. check_biquad_Q(Q)
  184. return method(self, *args, **kwargs)
  185. return new_method
  186. def check_lfilter(method):
  187. """Wrapper method to check the parameters of LFilter."""
  188. @wraps(method)
  189. def new_method(self, *args, **kwargs):
  190. [a_coeffs, b_coeffs, clamp], _ = parse_user_args(method, *args, **kwargs)
  191. type_check(a_coeffs, (list, tuple), "a_coeffs")
  192. type_check(b_coeffs, (list, tuple), "b_coeffs")
  193. for i, value in enumerate(a_coeffs):
  194. type_check(value, (float, int), "a_coeffs[{0}]".format(i))
  195. check_float32(value, "a_coeffs[{0}]".format(i))
  196. for i, value in enumerate(b_coeffs):
  197. type_check(value, (float, int), "b_coeffs[{0}]".format(i))
  198. check_float32(value, "b_coeffs[{0}]".format(i))
  199. check_list_same_size(a_coeffs, b_coeffs, "a_coeffs", "b_coeffs")
  200. type_check(clamp, (bool,), "clamp")
  201. return method(self, *args, **kwargs)
  202. return new_method
  203. def check_lowpass_biquad(method):
  204. """Wrapper method to check the parameters of LowpassBiquad."""
  205. @wraps(method)
  206. def new_method(self, *args, **kwargs):
  207. [sample_rate, cutoff_freq, Q], _ = parse_user_args(method, *args, **kwargs)
  208. check_biquad_sample_rate(sample_rate)
  209. check_biquad_cutoff_freq(cutoff_freq)
  210. check_biquad_Q(Q)
  211. return method(self, *args, **kwargs)
  212. return new_method
  213. def check_mu_law_coding(method):
  214. """Wrapper method to check the parameters of MuLawDecoding and MuLawEncoding"""
  215. @wraps(method)
  216. def new_method(self, *args, **kwargs):
  217. [quantization_channels], _ = parse_user_args(method, *args, **kwargs)
  218. check_pos_int32(quantization_channels, "quantization_channels")
  219. return method(self, *args, **kwargs)
  220. return new_method
  221. def check_overdrive(method):
  222. """Wrapper method to check the parameters of Overdrive."""
  223. @wraps(method)
  224. def new_method(self, *args, **kwargs):
  225. [gain, color], _ = parse_user_args(method, *args, **kwargs)
  226. type_check(gain, (float, int), "gain")
  227. check_value(gain, [0, 100], "gain")
  228. type_check(color, (float, int), "color")
  229. check_value(color, [0, 100], "color")
  230. return method(self, *args, **kwargs)
  231. return new_method
  232. def check_phaser(method):
  233. """Wrapper method to check the parameters of Phaser."""
  234. @wraps(method)
  235. def new_method(self, *args, **kwargs):
  236. [sample_rate, gain_in, gain_out, delay_ms, decay,
  237. mod_speed, sinusoidal], _ = parse_user_args(method, *args, **kwargs)
  238. type_check(sample_rate, (int,), "sample_rate")
  239. check_int32(sample_rate, "sample_rate")
  240. type_check(gain_in, (float, int), "gain_in")
  241. check_value(gain_in, [0, 1], "gain_in")
  242. type_check(gain_out, (float, int), "gain_out")
  243. check_value(gain_out, [0, 1e9], "gain_out")
  244. type_check(delay_ms, (float, int), "delay_ms")
  245. check_value(delay_ms, [0, 5.0], "delay_ms")
  246. type_check(decay, (float, int), "decay")
  247. check_value(decay, [0, 0.99], "decay")
  248. type_check(mod_speed, (float, int), "mod_speed")
  249. check_value(mod_speed, [0.1, 2], "mod_speed")
  250. type_check(sinusoidal, (bool,), "sinusoidal")
  251. return method(self, *args, **kwargs)
  252. return new_method
  253. def check_riaa_biquad(method):
  254. """Wrapper method to check the parameters of RiaaBiquad."""
  255. @wraps(method)
  256. def new_method(self, *args, **kwargs):
  257. [sample_rate], _ = parse_user_args(method, *args, **kwargs)
  258. type_check(sample_rate, (int,), "sample_rate")
  259. if sample_rate not in (44100, 48000, 88200, 96000):
  260. raise ValueError("sample_rate should be one of [44100, 48000, 88200, 96000], but got {0}.".format(
  261. sample_rate))
  262. return method(self, *args, **kwargs)
  263. return new_method
  264. def check_time_stretch(method):
  265. """Wrapper method to check the parameters of TimeStretch."""
  266. @wraps(method)
  267. def new_method(self, *args, **kwargs):
  268. [hop_length, n_freq, fixed_rate], _ = parse_user_args(method, *args, **kwargs)
  269. if hop_length is not None:
  270. type_check(hop_length, (int,), "hop_length")
  271. check_pos_int32(hop_length, "hop_length")
  272. type_check(n_freq, (int,), "n_freq")
  273. check_pos_int32(n_freq, "n_freq")
  274. if fixed_rate is not None:
  275. type_check(fixed_rate, (int, float), "fixed_rate")
  276. check_pos_float32(fixed_rate, "fixed_rate")
  277. return method(self, *args, **kwargs)
  278. return new_method
  279. def check_treble_biquad(method):
  280. """Wrapper method to check the parameters of TrebleBiquad."""
  281. @wraps(method)
  282. def new_method(self, *args, **kwargs):
  283. [sample_rate, gain, central_freq, Q], _ = parse_user_args(
  284. method, *args, **kwargs)
  285. check_biquad_sample_rate(sample_rate)
  286. check_biquad_gain(gain)
  287. check_biquad_central_freq(central_freq)
  288. check_biquad_Q(Q)
  289. return method(self, *args, **kwargs)
  290. return new_method
  291. def check_masking(method):
  292. """Wrapper method to check the parameters of TimeMasking and FrequencyMasking"""
  293. @wraps(method)
  294. def new_method(self, *args, **kwargs):
  295. [iid_masks, mask_param, mask_start, mask_value], _ = parse_user_args(
  296. method, *args, **kwargs)
  297. type_check(iid_masks, (bool,), "iid_masks")
  298. type_check(mask_param, (int,), "mask_param")
  299. check_non_negative_float32(mask_param, "mask_param")
  300. type_check(mask_start, (int,), "mask_start")
  301. check_non_negative_float32(mask_start, "mask_start")
  302. type_check(mask_value, (int, float), "mask_value")
  303. check_non_negative_float32(mask_value, "mask_value")
  304. return method(self, *args, **kwargs)
  305. return new_method
  306. def check_power(power):
  307. """Wrapper method to check the parameters of power."""
  308. type_check(power, (int, float), "power")
  309. check_non_negative_float32(power, "power")
  310. def check_complex_norm(method):
  311. """Wrapper method to check the parameters of ComplexNorm."""
  312. @wraps(method)
  313. def new_method(self, *args, **kwargs):
  314. [power], _ = parse_user_args(method, *args, **kwargs)
  315. check_power(power)
  316. return method(self, *args, **kwargs)
  317. return new_method
  318. def check_magphase(method):
  319. """Wrapper method to check the parameters of Magphase."""
  320. @wraps(method)
  321. def new_method(self, *args, **kwargs):
  322. [power], _ = parse_user_args(method, *args, **kwargs)
  323. check_power(power)
  324. return method(self, *args, **kwargs)
  325. return new_method
  326. def check_biquad_coeff(coeff, arg_name):
  327. """Wrapper method to check the parameters of coeff."""
  328. type_check(coeff, (float, int), arg_name)
  329. check_float32(coeff, arg_name)
  330. def check_biquad(method):
  331. """Wrapper method to check the parameters of Biquad."""
  332. @wraps(method)
  333. def new_method(self, *args, **kwargs):
  334. [b0, b1, b2, a0, a1, a2], _ = parse_user_args(
  335. method, *args, **kwargs)
  336. check_biquad_coeff(b0, "b0")
  337. check_biquad_coeff(b1, "b1")
  338. check_biquad_coeff(b2, "b2")
  339. type_check(a0, (float, int), "a0")
  340. check_float32_not_zero(a0, "a0")
  341. check_biquad_coeff(a1, "a1")
  342. check_biquad_coeff(a2, "a2")
  343. return method(self, *args, **kwargs)
  344. return new_method
  345. def check_fade(method):
  346. """Wrapper method to check the parameters of Fade."""
  347. @wraps(method)
  348. def new_method(self, *args, **kwargs):
  349. [fade_in_len, fade_out_len, fade_shape], _ = parse_user_args(method, *args, **kwargs)
  350. type_check(fade_in_len, (int,), "fade_in_len")
  351. check_non_negative_int32(fade_in_len, "fade_in_len")
  352. type_check(fade_out_len, (int,), "fade_out_len")
  353. check_non_negative_int32(fade_out_len, "fade_out_len")
  354. type_check(fade_shape, (FadeShape,), "fade_shape")
  355. return method(self, *args, **kwargs)
  356. return new_method
  357. def check_vol(method):
  358. """Wrapper method to check the parameters of Vol."""
  359. @wraps(method)
  360. def new_method(self, *args, **kwargs):
  361. [gain, gain_type], _ = parse_user_args(method, *args, **kwargs)
  362. type_check(gain, (int, float), "gain")
  363. type_check(gain_type, (GainType,), "gain_type")
  364. if gain_type == GainType.AMPLITUDE:
  365. check_non_negative_float32(gain, "gain")
  366. elif gain_type == GainType.POWER:
  367. check_pos_float32(gain, "gain")
  368. else:
  369. check_float32(gain, "gain")
  370. return method(self, *args, **kwargs)
  371. return new_method
  372. def check_detect_pitch_frequency(method):
  373. """Wrapper method to check the parameters of DetectPitchFrequency."""
  374. @wraps(method)
  375. def new_method(self, *args, **kwargs):
  376. [sample_rate, frame_time, win_length, freq_low, freq_high], _ = parse_user_args(
  377. method, *args, **kwargs)
  378. type_check(sample_rate, (int,), "sample_rate")
  379. check_int32_not_zero(sample_rate, "sample_rate")
  380. type_check(frame_time, (float, int), "frame_time")
  381. check_pos_float32(frame_time, "frame_time")
  382. type_check(win_length, (int,), "win_length")
  383. check_pos_int32(win_length, "win_length")
  384. type_check(freq_low, (int, float), "freq_low")
  385. check_pos_float32(freq_low, "freq_low")
  386. type_check(freq_high, (int, float), "freq_high")
  387. check_pos_float32(freq_high, "freq_high")
  388. return method(self, *args, **kwargs)
  389. return new_method
  390. def check_flanger(method):
  391. """Wrapper method to check the parameters of Flanger."""
  392. @wraps(method)
  393. def new_method(self, *args, **kwargs):
  394. [sample_rate, delay, depth, regen, width, speed, phase, modulation, interpolation], _ = parse_user_args(
  395. method, *args, **kwargs)
  396. type_check(sample_rate, (int,), "sample_rate")
  397. check_int32_not_zero(sample_rate, "sample_rate")
  398. type_check(delay, (float, int), "delay")
  399. check_value(delay, [0, 30], "delay")
  400. type_check(depth, (float, int), "depth")
  401. check_value(depth, [0, 10], "depth")
  402. type_check(regen, (float, int), "regen")
  403. check_value(regen, [-95, 95], "regen")
  404. type_check(width, (float, int), "width")
  405. check_value(width, [0, 100], "width")
  406. type_check(speed, (float, int), "speed")
  407. check_value(speed, [0.1, 10], "speed")
  408. type_check(phase, (float, int), "phase")
  409. check_value(phase, [0, 100], "phase")
  410. type_check(modulation, (Modulation,), "modulation")
  411. type_check(interpolation, (Interpolation,), "interpolation")
  412. return method(self, *args, **kwargs)
  413. return new_method
  414. def check_sliding_window_cmn(method):
  415. """Wrapper method to check the parameters of SlidingWidowCmn."""
  416. @wraps(method)
  417. def new_method(self, *args, **kwargs):
  418. [cmn_window, min_cmn_window, center, norm_vars], _ = parse_user_args(method, *args, **kwargs)
  419. type_check(cmn_window, (int,), "cmn_window")
  420. check_non_negative_int32(cmn_window, "cmn_window")
  421. type_check(min_cmn_window, (int,), "min_cmn_window")
  422. check_non_negative_int32(min_cmn_window, "min_cmn_window")
  423. type_check(center, (bool,), "center")
  424. type_check(norm_vars, (bool,), "norm_vars")
  425. return method(self, *args, **kwargs)
  426. return new_method
  427. def check_compute_deltas(method):
  428. """Wrapper method to check the parameter of ComputeDeltas."""
  429. @wraps(method)
  430. def new_method(self, *args, **kwargs):
  431. [win_length, pad_mode], _ = parse_user_args(method, *args, **kwargs)
  432. type_check(pad_mode, (BorderType,), "pad_mode")
  433. type_check(win_length, (int,), "win_length")
  434. check_value(win_length, (3, INT32_MAX), "win_length")
  435. return method(self, *args, **kwargs)
  436. return new_method