You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validators.py 15 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. validators for text ops
  17. """
  18. from functools import wraps
  19. import mindspore._c_dataengine as cde
  20. import mindspore.common.dtype as mstype
  21. from mindspore._c_expression import typing
  22. from ..transforms.validators import check_uint32, check_pos_int64
  23. def check_unique_list_of_words(words, arg_name):
  24. """Check that words is a list and each element is a str without any duplication"""
  25. if not isinstance(words, list):
  26. raise ValueError(arg_name + " needs to be a list of words of type string.")
  27. words_set = set()
  28. for word in words:
  29. if not isinstance(word, str):
  30. raise ValueError("each word in " + arg_name + " needs to be type str.")
  31. if word in words_set:
  32. raise ValueError(arg_name + " contains duplicate word: " + word + ".")
  33. words_set.add(word)
  34. return words_set
  35. def check_lookup(method):
  36. """A wrapper that wrap a parameter checker to the original function."""
  37. @wraps(method)
  38. def new_method(self, *args, **kwargs):
  39. vocab, unknown = (list(args) + 2 * [None])[:2]
  40. if "vocab" in kwargs:
  41. vocab = kwargs.get("vocab")
  42. if "unknown" in kwargs:
  43. unknown = kwargs.get("unknown")
  44. if unknown is not None:
  45. if not (isinstance(unknown, int) and unknown >= 0):
  46. raise ValueError("unknown needs to be a non-negative integer.")
  47. if not isinstance(vocab, cde.Vocab):
  48. raise ValueError("vocab is not an instance of cde.Vocab.")
  49. kwargs["vocab"] = vocab
  50. kwargs["unknown"] = unknown
  51. return method(self, **kwargs)
  52. return new_method
  53. def check_from_file(method):
  54. """A wrapper that wrap a parameter checker to the original function."""
  55. @wraps(method)
  56. def new_method(self, *args, **kwargs):
  57. file_path, delimiter, vocab_size, special_tokens, special_first = (list(args) + 5 * [None])[:5]
  58. if "file_path" in kwargs:
  59. file_path = kwargs.get("file_path")
  60. if "delimiter" in kwargs:
  61. delimiter = kwargs.get("delimiter")
  62. if "vocab_size" in kwargs:
  63. vocab_size = kwargs.get("vocab_size")
  64. if "special_tokens" in kwargs:
  65. special_tokens = kwargs.get("special_tokens")
  66. if "special_first" in kwargs:
  67. special_first = kwargs.get("special_first")
  68. if not isinstance(file_path, str):
  69. raise ValueError("file_path needs to be str.")
  70. if delimiter is not None:
  71. if not isinstance(delimiter, str):
  72. raise ValueError("delimiter needs to be str.")
  73. else:
  74. delimiter = ""
  75. if vocab_size is not None:
  76. if not (isinstance(vocab_size, int) and vocab_size > 0):
  77. raise ValueError("vocab size needs to be a positive integer.")
  78. else:
  79. vocab_size = -1
  80. if special_first is None:
  81. special_first = True
  82. if not isinstance(special_first, bool):
  83. raise ValueError("special_first needs to be a boolean value")
  84. if special_tokens is None:
  85. special_tokens = []
  86. check_unique_list_of_words(special_tokens, "special_tokens")
  87. kwargs["file_path"] = file_path
  88. kwargs["delimiter"] = delimiter
  89. kwargs["vocab_size"] = vocab_size
  90. kwargs["special_tokens"] = special_tokens
  91. kwargs["special_first"] = special_first
  92. return method(self, **kwargs)
  93. return new_method
  94. def check_from_list(method):
  95. """A wrapper that wrap a parameter checker to the original function."""
  96. @wraps(method)
  97. def new_method(self, *args, **kwargs):
  98. word_list, special_tokens, special_first = (list(args) + 3 * [None])[:3]
  99. if "word_list" in kwargs:
  100. word_list = kwargs.get("word_list")
  101. if "special_tokens" in kwargs:
  102. special_tokens = kwargs.get("special_tokens")
  103. if "special_first" in kwargs:
  104. special_first = kwargs.get("special_first")
  105. if special_tokens is None:
  106. special_tokens = []
  107. word_set = check_unique_list_of_words(word_list, "word_list")
  108. token_set = check_unique_list_of_words(special_tokens, "special_tokens")
  109. intersect = word_set.intersection(token_set)
  110. if intersect != set():
  111. raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".")
  112. if special_first is None:
  113. special_first = True
  114. if not isinstance(special_first, bool):
  115. raise ValueError("special_first needs to be a boolean value.")
  116. kwargs["word_list"] = word_list
  117. kwargs["special_tokens"] = special_tokens
  118. kwargs["special_first"] = special_first
  119. return method(self, **kwargs)
  120. return new_method
  121. def check_from_dict(method):
  122. """A wrapper that wrap a parameter checker to the original function."""
  123. @wraps(method)
  124. def new_method(self, *args, **kwargs):
  125. word_dict, = (list(args) + [None])[:1]
  126. if "word_dict" in kwargs:
  127. word_dict = kwargs.get("word_dict")
  128. if not isinstance(word_dict, dict):
  129. raise ValueError("word_dict needs to be a list of word,id pairs.")
  130. for word, word_id in word_dict.items():
  131. if not isinstance(word, str):
  132. raise ValueError("Each word in word_dict needs to be type string.")
  133. if not (isinstance(word_id, int) and word_id >= 0):
  134. raise ValueError("Each word id needs to be positive integer.")
  135. kwargs["word_dict"] = word_dict
  136. return method(self, **kwargs)
  137. return new_method
  138. def check_jieba_init(method):
  139. """Wrapper method to check the parameters of jieba add word."""
  140. @wraps(method)
  141. def new_method(self, *args, **kwargs):
  142. hmm_path, mp_path, model = (list(args) + 3 * [None])[:3]
  143. if "hmm_path" in kwargs:
  144. hmm_path = kwargs.get("hmm_path")
  145. if "mp_path" in kwargs:
  146. mp_path = kwargs.get("mp_path")
  147. if hmm_path is None:
  148. raise ValueError(
  149. "The dict of HMMSegment in cppjieba is not provided.")
  150. kwargs["hmm_path"] = hmm_path
  151. if mp_path is None:
  152. raise ValueError(
  153. "The dict of MPSegment in cppjieba is not provided.")
  154. kwargs["mp_path"] = mp_path
  155. if model is not None:
  156. kwargs["model"] = model
  157. return method(self, **kwargs)
  158. return new_method
  159. def check_jieba_add_word(method):
  160. """Wrapper method to check the parameters of jieba add word."""
  161. @wraps(method)
  162. def new_method(self, *args, **kwargs):
  163. word, freq = (list(args) + 2 * [None])[:2]
  164. if "word" in kwargs:
  165. word = kwargs.get("word")
  166. if "freq" in kwargs:
  167. freq = kwargs.get("freq")
  168. if word is None:
  169. raise ValueError("word is not provided.")
  170. kwargs["word"] = word
  171. if freq is not None:
  172. check_uint32(freq)
  173. kwargs["freq"] = freq
  174. return method(self, **kwargs)
  175. return new_method
  176. def check_jieba_add_dict(method):
  177. """Wrapper method to check the parameters of add dict."""
  178. @wraps(method)
  179. def new_method(self, *args, **kwargs):
  180. user_dict = (list(args) + [None])[0]
  181. if "user_dict" in kwargs:
  182. user_dict = kwargs.get("user_dict")
  183. if user_dict is None:
  184. raise ValueError("user_dict is not provided.")
  185. kwargs["user_dict"] = user_dict
  186. return method(self, **kwargs)
  187. return new_method
  188. def check_from_dataset(method):
  189. """A wrapper that wrap a parameter checker to the original function."""
  190. @wraps(method)
  191. def new_method(self, *args, **kwargs):
  192. dataset, columns, freq_range, top_k, special_tokens, special_first = (list(args) + 6 * [None])[:6]
  193. if "dataset" in kwargs:
  194. dataset = kwargs.get("dataset")
  195. if "columns" in kwargs:
  196. columns = kwargs.get("columns")
  197. if "freq_range" in kwargs:
  198. freq_range = kwargs.get("freq_range")
  199. if "top_k" in kwargs:
  200. top_k = kwargs.get("top_k")
  201. if "special_tokens" in kwargs:
  202. special_tokens = kwargs.get("special_tokens")
  203. if "special_first" in kwargs:
  204. special_first = kwargs.get("special_first")
  205. if columns is None:
  206. columns = []
  207. if not isinstance(columns, list):
  208. columns = [columns]
  209. for column in columns:
  210. if not isinstance(column, str):
  211. raise ValueError("columns need to be a list of strings.")
  212. if freq_range is None:
  213. freq_range = (None, None)
  214. if not isinstance(freq_range, tuple) or len(freq_range) != 2:
  215. raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.")
  216. for num in freq_range:
  217. if num is not None and (not isinstance(num, int)):
  218. raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.")
  219. if isinstance(freq_range[0], int) and isinstance(freq_range[1], int):
  220. if freq_range[0] > freq_range[1] or freq_range[0] < 0:
  221. raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).")
  222. if top_k is not None and (not isinstance(top_k, int)):
  223. raise ValueError("top_k needs to be a positive integer.")
  224. if isinstance(top_k, int) and top_k <= 0:
  225. raise ValueError("top_k needs to be a positive integer.")
  226. if special_first is None:
  227. special_first = True
  228. if special_tokens is None:
  229. special_tokens = []
  230. if not isinstance(special_first, bool):
  231. raise ValueError("special_first needs to be a boolean value.")
  232. check_unique_list_of_words(special_tokens, "special_tokens")
  233. kwargs["dataset"] = dataset
  234. kwargs["columns"] = columns
  235. kwargs["freq_range"] = freq_range
  236. kwargs["top_k"] = top_k
  237. kwargs["special_tokens"] = special_tokens
  238. kwargs["special_first"] = special_first
  239. return method(self, **kwargs)
  240. return new_method
  241. def check_ngram(method):
  242. """A wrapper that wrap a parameter checker to the original function."""
  243. @wraps(method)
  244. def new_method(self, *args, **kwargs):
  245. n, left_pad, right_pad, separator = (list(args) + 4 * [None])[:4]
  246. if "n" in kwargs:
  247. n = kwargs.get("n")
  248. if "left_pad" in kwargs:
  249. left_pad = kwargs.get("left_pad")
  250. if "right_pad" in kwargs:
  251. right_pad = kwargs.get("right_pad")
  252. if "separator" in kwargs:
  253. separator = kwargs.get("separator")
  254. if isinstance(n, int):
  255. n = [n]
  256. if not (isinstance(n, list) and n != []):
  257. raise ValueError("n needs to be a non-empty list of positive integers.")
  258. for gram in n:
  259. if not (isinstance(gram, int) and gram > 0):
  260. raise ValueError("n in ngram needs to be a positive number.")
  261. if left_pad is None:
  262. left_pad = ("", 0)
  263. if right_pad is None:
  264. right_pad = ("", 0)
  265. if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance(
  266. left_pad[1], int)):
  267. raise ValueError("left_pad needs to be a tuple of (str, int) str is pad token and int is pad_width.")
  268. if not (isinstance(right_pad, tuple) and len(right_pad) == 2 and isinstance(right_pad[0], str) and isinstance(
  269. right_pad[1], int)):
  270. raise ValueError("right_pad needs to be a tuple of (str, int) str is pad token and int is pad_width.")
  271. if not (left_pad[1] >= 0 and right_pad[1] >= 0):
  272. raise ValueError("padding width need to be positive numbers.")
  273. if separator is None:
  274. separator = " "
  275. if not isinstance(separator, str):
  276. raise ValueError("separator needs to be a string.")
  277. kwargs["n"] = n
  278. kwargs["left_pad"] = left_pad
  279. kwargs["right_pad"] = right_pad
  280. kwargs["separator"] = separator
  281. return method(self, **kwargs)
  282. return new_method
  283. def check_pair_truncate(method):
  284. """Wrapper method to check the parameters of number of pair truncate."""
  285. @wraps(method)
  286. def new_method(self, *args, **kwargs):
  287. max_length = (list(args) + [None])[0]
  288. if "max_length" in kwargs:
  289. max_length = kwargs.get("max_length")
  290. if max_length is None:
  291. raise ValueError("max_length is not provided.")
  292. check_pos_int64(max_length)
  293. kwargs["max_length"] = max_length
  294. return method(self, **kwargs)
  295. return new_method
  296. def check_to_number(method):
  297. """A wrapper that wraps a parameter check to the original function (ToNumber)."""
  298. @wraps(method)
  299. def new_method(self, *args, **kwargs):
  300. data_type = (list(args) + [None])[0]
  301. if "data_type" in kwargs:
  302. data_type = kwargs.get("data_type")
  303. if data_type is None:
  304. raise ValueError("data_type is a mandatory parameter but was not provided.")
  305. if not isinstance(data_type, typing.Type):
  306. raise TypeError("data_type is not a MindSpore data type.")
  307. if data_type not in mstype.number_type:
  308. raise TypeError("data_type is not numeric data type.")
  309. kwargs["data_type"] = data_type
  310. return method(self, **kwargs)
  311. return new_method
  312. def check_python_tokenizer(method):
  313. """A wrapper that wraps a parameter check to the original function (PythonTokenizer)."""
  314. @wraps(method)
  315. def new_method(self, *args, **kwargs):
  316. tokenizer = (list(args) + [None])[0]
  317. if "tokenizer" in kwargs:
  318. tokenizer = kwargs.get("tokenizer")
  319. if tokenizer is None:
  320. raise ValueError("tokenizer is a mandatory parameter.")
  321. if not callable(tokenizer):
  322. raise TypeError("tokenizer is not a callable python function")
  323. kwargs["tokenizer"] = tokenizer
  324. return method(self, **kwargs)
  325. return new_method