You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validators.py 19 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. validators for text ops
  17. """
  18. from functools import wraps
  19. import mindspore.common.dtype as mstype
  20. import mindspore._c_dataengine as cde
  21. from mindspore._c_expression import typing
  22. from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \
  23. INT32_MAX, check_value, check_positive, check_pos_int32
  24. def check_unique_list_of_words(words, arg_name):
  25. """Check that words is a list and each element is a str without any duplication"""
  26. type_check(words, (list,), arg_name)
  27. words_set = set()
  28. for word in words:
  29. type_check(word, (str,), arg_name)
  30. if word in words_set:
  31. raise ValueError(arg_name + " contains duplicate word: " + word + ".")
  32. words_set.add(word)
  33. return words_set
  34. def check_lookup(method):
  35. """A wrapper that wraps a parameter checker to the original function."""
  36. @wraps(method)
  37. def new_method(self, *args, **kwargs):
  38. [vocab, unknown_token, data_type], _ = parse_user_args(method, *args, **kwargs)
  39. if unknown_token is not None:
  40. type_check(unknown_token, (str,), "unknown_token")
  41. type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.")
  42. type_check(data_type, (typing.Type,), "data_type")
  43. return method(self, *args, **kwargs)
  44. return new_method
  45. def check_from_file(method):
  46. """A wrapper that wraps a parameter checker to the original function."""
  47. @wraps(method)
  48. def new_method(self, *args, **kwargs):
  49. [file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args,
  50. **kwargs)
  51. if special_tokens is not None:
  52. check_unique_list_of_words(special_tokens, "special_tokens")
  53. type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"])
  54. if vocab_size is not None:
  55. check_positive(vocab_size, "vocab_size")
  56. type_check(special_first, (bool,), special_first)
  57. return method(self, *args, **kwargs)
  58. return new_method
  59. def check_from_list(method):
  60. """A wrapper that wraps a parameter checker to the original function."""
  61. @wraps(method)
  62. def new_method(self, *args, **kwargs):
  63. [word_list, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs)
  64. word_set = check_unique_list_of_words(word_list, "word_list")
  65. if special_tokens is not None:
  66. token_set = check_unique_list_of_words(special_tokens, "special_tokens")
  67. intersect = word_set.intersection(token_set)
  68. if intersect != set():
  69. raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".")
  70. type_check(special_first, (bool,), "special_first")
  71. return method(self, *args, **kwargs)
  72. return new_method
  73. def check_from_dict(method):
  74. """A wrapper that wraps a parameter checker to the original function."""
  75. @wraps(method)
  76. def new_method(self, *args, **kwargs):
  77. [word_dict], _ = parse_user_args(method, *args, **kwargs)
  78. type_check(word_dict, (dict,), "word_dict")
  79. for word, word_id in word_dict.items():
  80. type_check(word, (str,), "word")
  81. type_check(word_id, (int,), "word_id")
  82. check_value(word_id, (0, INT32_MAX), "word_id")
  83. return method(self, *args, **kwargs)
  84. return new_method
  85. def check_jieba_init(method):
  86. """Wrapper method to check the parameters of jieba init."""
  87. @wraps(method)
  88. def new_method(self, *args, **kwargs):
  89. [hmm_path, mp_path, _, with_offsets], _ = parse_user_args(method, *args, **kwargs)
  90. if hmm_path is None:
  91. raise ValueError("The dict of HMMSegment in cppjieba is not provided.")
  92. if not isinstance(hmm_path, str):
  93. raise TypeError("Wrong input type for hmm_path, should be string.")
  94. if mp_path is None:
  95. raise ValueError("The dict of MPSegment in cppjieba is not provided.")
  96. if not isinstance(mp_path, str):
  97. raise TypeError("Wrong input type for mp_path, should be string.")
  98. if not isinstance(with_offsets, bool):
  99. raise TypeError("Wrong input type for with_offsets, should be boolean.")
  100. return method(self, *args, **kwargs)
  101. return new_method
  102. def check_jieba_add_word(method):
  103. """Wrapper method to check the parameters of jieba add word."""
  104. @wraps(method)
  105. def new_method(self, *args, **kwargs):
  106. [word, freq], _ = parse_user_args(method, *args, **kwargs)
  107. if word is None:
  108. raise ValueError("word is not provided.")
  109. if freq is not None:
  110. check_uint32(freq)
  111. return method(self, *args, **kwargs)
  112. return new_method
  113. def check_jieba_add_dict(method):
  114. """Wrapper method to check the parameters of add dict."""
  115. @wraps(method)
  116. def new_method(self, *args, **kwargs):
  117. parse_user_args(method, *args, **kwargs)
  118. return method(self, *args, **kwargs)
  119. return new_method
  120. def check_with_offsets(method):
  121. """Wrapper method to check if with_offsets is the only one parameter."""
  122. @wraps(method)
  123. def new_method(self, *args, **kwargs):
  124. [with_offsets], _ = parse_user_args(method, *args, **kwargs)
  125. if not isinstance(with_offsets, bool):
  126. raise TypeError("Wrong input type for with_offsets, should be boolean.")
  127. return method(self, *args, **kwargs)
  128. return new_method
  129. def check_unicode_script_tokenizer(method):
  130. """Wrapper method to check the parameter of UnicodeScriptTokenizer."""
  131. @wraps(method)
  132. def new_method(self, *args, **kwargs):
  133. [keep_whitespace, with_offsets], _ = parse_user_args(method, *args, **kwargs)
  134. if not isinstance(keep_whitespace, bool):
  135. raise TypeError("Wrong input type for keep_whitespace, should be boolean.")
  136. if not isinstance(with_offsets, bool):
  137. raise TypeError("Wrong input type for with_offsets, should be boolean.")
  138. return method(self, *args, **kwargs)
  139. return new_method
  140. def check_wordpiece_tokenizer(method):
  141. """Wrapper method to check the parameter of WordpieceTokenizer."""
  142. @wraps(method)
  143. def new_method(self, *args, **kwargs):
  144. [vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ = \
  145. parse_user_args(method, *args, **kwargs)
  146. if vocab is None:
  147. raise ValueError("vocab is not provided.")
  148. if not isinstance(vocab, cde.Vocab):
  149. raise TypeError("Wrong input type for vocab, should be Vocab object.")
  150. if not isinstance(suffix_indicator, str):
  151. raise TypeError("Wrong input type for suffix_indicator, should be string.")
  152. if not isinstance(unknown_token, str):
  153. raise TypeError("Wrong input type for unknown_token, should be string.")
  154. if not isinstance(with_offsets, bool):
  155. raise TypeError("Wrong input type for with_offsets, should be boolean.")
  156. check_uint32(max_bytes_per_token)
  157. return method(self, *args, **kwargs)
  158. return new_method
  159. def check_regex_tokenizer(method):
  160. """Wrapper method to check the parameter of RegexTokenizer."""
  161. @wraps(method)
  162. def new_method(self, *args, **kwargs):
  163. [delim_pattern, keep_delim_pattern, with_offsets], _ = parse_user_args(method, *args, **kwargs)
  164. if delim_pattern is None:
  165. raise ValueError("delim_pattern is not provided.")
  166. if not isinstance(delim_pattern, str):
  167. raise TypeError("Wrong input type for delim_pattern, should be string.")
  168. if not isinstance(keep_delim_pattern, str):
  169. raise TypeError("Wrong input type for keep_delim_pattern, should be string.")
  170. if not isinstance(with_offsets, bool):
  171. raise TypeError("Wrong input type for with_offsets, should be boolean.")
  172. return method(self, *args, **kwargs)
  173. return new_method
  174. def check_basic_tokenizer(method):
  175. """Wrapper method to check the parameter of RegexTokenizer."""
  176. @wraps(method)
  177. def new_method(self, *args, **kwargs):
  178. [lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ = \
  179. parse_user_args(method, *args, **kwargs)
  180. if not isinstance(lower_case, bool):
  181. raise TypeError("Wrong input type for lower_case, should be boolean.")
  182. if not isinstance(keep_whitespace, bool):
  183. raise TypeError("Wrong input type for keep_whitespace, should be boolean.")
  184. if not isinstance(preserve_unused, bool):
  185. raise TypeError("Wrong input type for preserve_unused_token, should be boolean.")
  186. if not isinstance(with_offsets, bool):
  187. raise TypeError("Wrong input type for with_offsets, should be boolean.")
  188. return method(self, *args, **kwargs)
  189. return new_method
  190. def check_bert_tokenizer(method):
  191. """Wrapper method to check the parameter of BertTokenizer."""
  192. @wraps(method)
  193. def new_method(self, *args, **kwargs):
  194. [vocab, suffix_indicator, max_bytes_per_token, unknown_token, lower_case, keep_whitespace, _,
  195. preserve_unused_token, with_offsets], _ = parse_user_args(method, *args, **kwargs)
  196. if vocab is None:
  197. raise ValueError("vacab is not provided.")
  198. if not isinstance(vocab, cde.Vocab):
  199. raise TypeError("Wrong input type for vocab, should be Vocab object.")
  200. if not isinstance(suffix_indicator, str):
  201. raise TypeError("Wrong input type for suffix_indicator, should be string.")
  202. if not isinstance(max_bytes_per_token, int):
  203. raise TypeError("Wrong input type for max_bytes_per_token, should be int.")
  204. check_uint32(max_bytes_per_token)
  205. if not isinstance(unknown_token, str):
  206. raise TypeError("Wrong input type for unknown_token, should be string.")
  207. if not isinstance(lower_case, bool):
  208. raise TypeError("Wrong input type for lower_case, should be boolean.")
  209. if not isinstance(keep_whitespace, bool):
  210. raise TypeError("Wrong input type for keep_whitespace, should be boolean.")
  211. if not isinstance(preserve_unused_token, bool):
  212. raise TypeError("Wrong input type for preserve_unused_token, should be boolean.")
  213. if not isinstance(with_offsets, bool):
  214. raise TypeError("Wrong input type for with_offsets, should be boolean.")
  215. return method(self, *args, **kwargs)
  216. return new_method
  217. def check_from_dataset(method):
  218. """A wrapper that wraps a parameter checker to the original function."""
  219. @wraps(method)
  220. def new_method(self, *args, **kwargs):
  221. [_, columns, freq_range, top_k, special_tokens, special_first], _ = parse_user_args(method, *args,
  222. **kwargs)
  223. if columns is not None:
  224. if not isinstance(columns, list):
  225. columns = [columns]
  226. type_check_list(columns, (str,), "col")
  227. if freq_range is not None:
  228. type_check(freq_range, (tuple,), "freq_range")
  229. if len(freq_range) != 2:
  230. raise ValueError("freq_range needs to be a tuple of 2 integers or an int and a None.")
  231. for num in freq_range:
  232. if num is not None and (not isinstance(num, int)):
  233. raise ValueError(
  234. "freq_range needs to be either None or a tuple of 2 integers or an int and a None.")
  235. if isinstance(freq_range[0], int) and isinstance(freq_range[1], int):
  236. if freq_range[0] > freq_range[1] or freq_range[0] < 0:
  237. raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).")
  238. type_check(top_k, (int, type(None)), "top_k")
  239. if isinstance(top_k, int):
  240. check_positive(top_k, "top_k")
  241. type_check(special_first, (bool,), "special_first")
  242. if special_tokens is not None:
  243. check_unique_list_of_words(special_tokens, "special_tokens")
  244. return method(self, *args, **kwargs)
  245. return new_method
  246. def check_slidingwindow(method):
  247. """A wrapper that wraps a parameter checker to the original function(sliding window operation)."""
  248. @wraps(method)
  249. def new_method(self, *args, **kwargs):
  250. [width, axis], _ = parse_user_args(method, *args, **kwargs)
  251. check_pos_int32(width, "width")
  252. type_check(axis, (int,), "axis")
  253. return method(self, *args, **kwargs)
  254. return new_method
  255. def check_ngram(method):
  256. """A wrapper that wraps a parameter checker to the original function."""
  257. @wraps(method)
  258. def new_method(self, *args, **kwargs):
  259. [n, left_pad, right_pad, separator], _ = parse_user_args(method, *args, **kwargs)
  260. if isinstance(n, int):
  261. n = [n]
  262. if not (isinstance(n, list) and n != []):
  263. raise ValueError("n needs to be a non-empty list of positive integers.")
  264. for i, gram in enumerate(n):
  265. type_check(gram, (int,), "gram[{0}]".format(i))
  266. check_positive(gram, "gram_{}".format(i))
  267. if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance(
  268. left_pad[1], int)):
  269. raise ValueError("left_pad needs to be a tuple of (str, int) str is pad token and int is pad_width.")
  270. if not (isinstance(right_pad, tuple) and len(right_pad) == 2 and isinstance(right_pad[0], str) and isinstance(
  271. right_pad[1], int)):
  272. raise ValueError("right_pad needs to be a tuple of (str, int) str is pad token and int is pad_width.")
  273. if not (left_pad[1] >= 0 and right_pad[1] >= 0):
  274. raise ValueError("padding width need to be positive numbers.")
  275. type_check(separator, (str,), "separator")
  276. kwargs["n"] = n
  277. kwargs["left_pad"] = left_pad
  278. kwargs["right_pad"] = right_pad
  279. kwargs["separator"] = separator
  280. return method(self, **kwargs)
  281. return new_method
  282. def check_pair_truncate(method):
  283. """Wrapper method to check the parameters of number of pair truncate."""
  284. @wraps(method)
  285. def new_method(self, *args, **kwargs):
  286. parse_user_args(method, *args, **kwargs)
  287. return method(self, *args, **kwargs)
  288. return new_method
  289. def check_to_number(method):
  290. """A wrapper that wraps a parameter check to the original function (ToNumber)."""
  291. @wraps(method)
  292. def new_method(self, *args, **kwargs):
  293. [data_type], _ = parse_user_args(method, *args, **kwargs)
  294. type_check(data_type, (typing.Type,), "data_type")
  295. if data_type not in mstype.number_type:
  296. raise TypeError("data_type is not numeric data type.")
  297. return method(self, *args, **kwargs)
  298. return new_method
  299. def check_python_tokenizer(method):
  300. """A wrapper that wraps a parameter check to the original function (PythonTokenizer)."""
  301. @wraps(method)
  302. def new_method(self, *args, **kwargs):
  303. [tokenizer], _ = parse_user_args(method, *args, **kwargs)
  304. if not callable(tokenizer):
  305. raise TypeError("tokenizer is not a callable Python function")
  306. return method(self, *args, **kwargs)
  307. return new_method
  308. def check_from_dataset_sentencepiece(method):
  309. """A wrapper that wraps a parameter checker to the original function (from_dataset)."""
  310. @wraps(method)
  311. def new_method(self, *args, **kwargs):
  312. [_, col_names, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)
  313. if col_names is not None:
  314. type_check(col_names, (list,), "col_names")
  315. if vocab_size is not None:
  316. check_uint32(vocab_size, "vocab_size")
  317. if character_coverage is not None:
  318. type_check(character_coverage, (float,), "character_coverage")
  319. if model_type is not None:
  320. from .utils import SentencePieceModel
  321. type_check(model_type, (str, SentencePieceModel), "model_type")
  322. if params is not None:
  323. type_check(params, (dict,), "params")
  324. return method(self, *args, **kwargs)
  325. return new_method
  326. def check_from_file_sentencepiece(method):
  327. """A wrapper that wraps a parameter checker to the original function (from_file)."""
  328. @wraps(method)
  329. def new_method(self, *args, **kwargs):
  330. [file_path, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)
  331. if file_path is not None:
  332. type_check(file_path, (list,), "file_path")
  333. if vocab_size is not None:
  334. check_uint32(vocab_size, "vocab_size")
  335. if character_coverage is not None:
  336. type_check(character_coverage, (float,), "character_coverage")
  337. if model_type is not None:
  338. from .utils import SentencePieceModel
  339. type_check(model_type, (str, SentencePieceModel), "model_type")
  340. if params is not None:
  341. type_check(params, (dict,), "params")
  342. return method(self, *args, **kwargs)
  343. return new_method
  344. def check_save_model(method):
  345. """A wrapper that wraps a parameter checker to the original function (save_model)."""
  346. @wraps(method)
  347. def new_method(self, *args, **kwargs):
  348. [vocab, path, filename], _ = parse_user_args(method, *args, **kwargs)
  349. if vocab is not None:
  350. type_check(vocab, (cde.SentencePieceVocab,), "vocab")
  351. if path is not None:
  352. type_check(path, (str,), "path")
  353. if filename is not None:
  354. type_check(filename, (str,), "filename")
  355. return method(self, *args, **kwargs)
  356. return new_method