
validators.py 21 kB

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
validators for text ops
"""
from functools import wraps

import mindspore._c_dataengine as cde
import mindspore.common.dtype as mstype
from mindspore._c_expression import typing

from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \
    INT32_MAX, check_value, check_positive, check_pos_int32, check_filename, check_non_negative_int32


def check_unique_list_of_words(words, arg_name):
    """Check that words is a list and each element is a str without any duplication."""
    type_check(words, (list,), arg_name)
    words_set = set()
    for word in words:
        type_check(word, (str,), arg_name)
        if word in words_set:
            raise ValueError(arg_name + " contains duplicate word: " + word + ".")
        words_set.add(word)
    return words_set
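The helper above underpins most of the list validators in this file. A minimal standalone restatement of the same rule (assuming only the standard library; `unique_list_of_words` is a hypothetical illustration name, with `type_check` replaced by plain `isinstance` checks):

```python
# Standalone sketch (no MindSpore required) of the rule that
# check_unique_list_of_words enforces: the input must be a list of
# unique strings, and the set of words is returned for reuse.
def unique_list_of_words(words, arg_name):
    if not isinstance(words, list):
        raise TypeError(arg_name + " must be a list.")
    seen = set()
    for word in words:
        if not isinstance(word, str):
            raise TypeError(arg_name + " must contain only str elements.")
        if word in seen:
            raise ValueError(arg_name + " contains duplicate word: " + word + ".")
        seen.add(word)
    return seen
```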


def check_lookup(method):
    """A wrapper that wraps a parameter checker to the original function."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vocab, unknown_token, data_type], _ = parse_user_args(method, *args, **kwargs)
        if unknown_token is not None:
            type_check(unknown_token, (str,), "unknown_token")
        type_check(vocab, (cde.Vocab,), "vocab")
        type_check(data_type, (typing.Type,), "data_type")
        return method(self, *args, **kwargs)
    return new_method
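Every `check_*` wrapper below follows this same decorator shape: bind the caller's arguments against the wrapped method's signature, validate them, then invoke the original method unchanged. A self-contained sketch of the pattern, using `inspect.signature` in place of `parse_user_args` (the `check_width`/`ExampleOp` names are hypothetical illustrations, not part of this file):

```python
from functools import wraps
import inspect

# Minimal sketch of the validator-decorator pattern used by check_lookup and
# the other wrappers in this file. inspect.signature stands in for
# parse_user_args: it binds *args/**kwargs to named parameters so the checker
# can inspect them before the real __init__ runs.
def check_width(method):
    @wraps(method)
    def new_method(self, *args, **kwargs):
        bound = inspect.signature(method).bind(self, *args, **kwargs)
        bound.apply_defaults()
        width = bound.arguments["width"]
        if not isinstance(width, int) or width <= 0:
            raise ValueError("width must be a positive int.")
        return method(self, *args, **kwargs)
    return new_method

class ExampleOp:
    @check_width
    def __init__(self, width=1):
        self.width = width
```

Because the checker runs before the wrapped `__init__`, a bad argument fails fast with a precise message instead of surfacing later inside the C++ layer.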


def check_from_file(method):
    """A wrapper that wraps a parameter checker to the original function."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs)
        if special_tokens is not None:
            check_unique_list_of_words(special_tokens, "special_tokens")
        type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"])
        if vocab_size is not None:
            check_positive(vocab_size, "vocab_size")
        type_check(special_first, (bool,), "special_first")
        return method(self, *args, **kwargs)
    return new_method


def check_from_list(method):
    """A wrapper that wraps a parameter checker to the original function."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [word_list, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs)
        word_set = check_unique_list_of_words(word_list, "word_list")
        if special_tokens is not None:
            token_set = check_unique_list_of_words(special_tokens, "special_tokens")
            intersect = word_set.intersection(token_set)
            if intersect:
                raise ValueError("special_tokens and word_list contain duplicate word: " + str(intersect) + ".")
        type_check(special_first, (bool,), "special_first")
        return method(self, *args, **kwargs)
    return new_method


def check_from_dict(method):
    """A wrapper that wraps a parameter checker to the original function."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [word_dict], _ = parse_user_args(method, *args, **kwargs)
        type_check(word_dict, (dict,), "word_dict")
        for word, word_id in word_dict.items():
            type_check(word, (str,), "word")
            type_check(word_id, (int,), "word_id")
            check_value(word_id, (0, INT32_MAX), "word_id")
        return method(self, *args, **kwargs)
    return new_method


def check_jieba_init(method):
    """Wrapper method to check the parameters of jieba init."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [hmm_path, mp_path, _, with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if hmm_path is None:
            raise ValueError("The dict of HMMSegment in cppjieba is not provided.")
        if not isinstance(hmm_path, str):
            raise TypeError("Wrong input type for hmm_path, should be string.")
        if mp_path is None:
            raise ValueError("The dict of MPSegment in cppjieba is not provided.")
        if not isinstance(mp_path, str):
            raise TypeError("Wrong input type for mp_path, should be string.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)
    return new_method


def check_jieba_add_word(method):
    """Wrapper method to check the parameters of jieba add word."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [word, freq], _ = parse_user_args(method, *args, **kwargs)
        if word is None:
            raise ValueError("word is not provided.")
        if freq is not None:
            check_uint32(freq)
        return method(self, *args, **kwargs)
    return new_method


def check_jieba_add_dict(method):
    """Wrapper method to check the parameters of jieba add dict."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        parse_user_args(method, *args, **kwargs)
        return method(self, *args, **kwargs)
    return new_method


def check_with_offsets(method):
    """Wrapper method to check with_offsets when it is the only parameter."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)
    return new_method


def check_unicode_script_tokenizer(method):
    """Wrapper method to check the parameter of UnicodeScriptTokenizer."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [keep_whitespace, with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if not isinstance(keep_whitespace, bool):
            raise TypeError("Wrong input type for keep_whitespace, should be boolean.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)
    return new_method


def check_wordpiece_tokenizer(method):
    """Wrapper method to check the parameter of WordpieceTokenizer."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ = \
            parse_user_args(method, *args, **kwargs)
        if vocab is None:
            raise ValueError("vocab is not provided.")
        if not isinstance(vocab, cde.Vocab):
            raise TypeError("Wrong input type for vocab, should be Vocab object.")
        if not isinstance(suffix_indicator, str):
            raise TypeError("Wrong input type for suffix_indicator, should be string.")
        if not isinstance(unknown_token, str):
            raise TypeError("Wrong input type for unknown_token, should be string.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        check_uint32(max_bytes_per_token)
        return method(self, *args, **kwargs)
    return new_method


def check_regex_replace(method):
    """Wrapper method to check the parameter of RegexReplace."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [pattern, replace, replace_all], _ = parse_user_args(method, *args, **kwargs)
        type_check(pattern, (str,), "pattern")
        type_check(replace, (str,), "replace")
        type_check(replace_all, (bool,), "replace_all")
        return method(self, *args, **kwargs)
    return new_method


def check_regex_tokenizer(method):
    """Wrapper method to check the parameter of RegexTokenizer."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [delim_pattern, keep_delim_pattern, with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if delim_pattern is None:
            raise ValueError("delim_pattern is not provided.")
        if not isinstance(delim_pattern, str):
            raise TypeError("Wrong input type for delim_pattern, should be string.")
        if not isinstance(keep_delim_pattern, str):
            raise TypeError("Wrong input type for keep_delim_pattern, should be string.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)
    return new_method


def check_basic_tokenizer(method):
    """Wrapper method to check the parameter of BasicTokenizer."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ = \
            parse_user_args(method, *args, **kwargs)
        if not isinstance(lower_case, bool):
            raise TypeError("Wrong input type for lower_case, should be boolean.")
        if not isinstance(keep_whitespace, bool):
            raise TypeError("Wrong input type for keep_whitespace, should be boolean.")
        if not isinstance(preserve_unused, bool):
            raise TypeError("Wrong input type for preserve_unused_token, should be boolean.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)
    return new_method


def check_bert_tokenizer(method):
    """Wrapper method to check the parameter of BertTokenizer."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vocab, suffix_indicator, max_bytes_per_token, unknown_token, lower_case, keep_whitespace, _,
         preserve_unused_token, with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if vocab is None:
            raise ValueError("vocab is not provided.")
        if not isinstance(vocab, cde.Vocab):
            raise TypeError("Wrong input type for vocab, should be Vocab object.")
        if not isinstance(suffix_indicator, str):
            raise TypeError("Wrong input type for suffix_indicator, should be string.")
        if not isinstance(max_bytes_per_token, int):
            raise TypeError("Wrong input type for max_bytes_per_token, should be int.")
        check_uint32(max_bytes_per_token)
        if not isinstance(unknown_token, str):
            raise TypeError("Wrong input type for unknown_token, should be string.")
        if not isinstance(lower_case, bool):
            raise TypeError("Wrong input type for lower_case, should be boolean.")
        if not isinstance(keep_whitespace, bool):
            raise TypeError("Wrong input type for keep_whitespace, should be boolean.")
        if not isinstance(preserve_unused_token, bool):
            raise TypeError("Wrong input type for preserve_unused_token, should be boolean.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)
    return new_method


def check_from_dataset(method):
    """A wrapper that wraps a parameter checker to the original function."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [_, columns, freq_range, top_k, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs)
        if columns is not None:
            if not isinstance(columns, list):
                columns = [columns]
            type_check_list(columns, (str,), "columns")
        if freq_range is not None:
            type_check(freq_range, (tuple,), "freq_range")
            if len(freq_range) != 2:
                raise ValueError("freq_range needs to be a tuple of 2 elements.")
            for num in freq_range:
                if num is not None and (not isinstance(num, int)):
                    raise ValueError(
                        "freq_range needs to be either None or a tuple of 2 integers or an int and a None.")
            if isinstance(freq_range[0], int) and isinstance(freq_range[1], int):
                if freq_range[0] > freq_range[1] or freq_range[0] < 0:
                    raise ValueError("frequency range [a, b] should be 0 <= a <= b (a, b are inclusive).")
        type_check(top_k, (int, type(None)), "top_k")
        if isinstance(top_k, int):
            check_positive(top_k, "top_k")
        type_check(special_first, (bool,), "special_first")
        if special_tokens is not None:
            check_unique_list_of_words(special_tokens, "special_tokens")
        return method(self, *args, **kwargs)
    return new_method
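The `freq_range` rule above is the subtlest check in this file: either bound may be `None` (open-ended), and only when both bounds are ints must `0 <= a <= b` hold. A standalone restatement under the same semantics (`validate_freq_range` is a hypothetical helper name, not part of this module):

```python
# Standalone restatement of the freq_range rule enforced inside
# check_from_dataset: None bounds are open-ended; when both bounds are
# ints, 0 <= a <= b must hold (both inclusive).
def validate_freq_range(freq_range):
    if freq_range is None:
        return
    if not isinstance(freq_range, tuple) or len(freq_range) != 2:
        raise ValueError("freq_range needs to be a tuple of 2 elements.")
    for num in freq_range:
        if num is not None and not isinstance(num, int):
            raise ValueError("freq_range bounds must be int or None.")
    low, high = freq_range
    if isinstance(low, int) and isinstance(high, int) and (low > high or low < 0):
        raise ValueError("frequency range [a, b] should be 0 <= a <= b (a, b are inclusive).")
```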


def check_slidingwindow(method):
    """A wrapper that wraps a parameter checker to the original function (sliding window operation)."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [width, axis], _ = parse_user_args(method, *args, **kwargs)
        check_pos_int32(width, "width")
        type_check(axis, (int,), "axis")
        return method(self, *args, **kwargs)
    return new_method


def check_ngram(method):
    """A wrapper that wraps a parameter checker to the original function."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [n, left_pad, right_pad, separator], _ = parse_user_args(method, *args, **kwargs)
        if isinstance(n, int):
            n = [n]
        if not (isinstance(n, list) and n != []):
            raise ValueError("n needs to be a non-empty list of positive integers.")
        for i, gram in enumerate(n):
            type_check(gram, (int,), "gram[{0}]".format(i))
            check_positive(gram, "gram_{}".format(i))
        if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str)
                and isinstance(left_pad[1], int)):
            raise ValueError("left_pad needs to be a tuple of (str, int), where str is the pad token "
                             "and int is the pad width.")
        if not (isinstance(right_pad, tuple) and len(right_pad) == 2 and isinstance(right_pad[0], str)
                and isinstance(right_pad[1], int)):
            raise ValueError("right_pad needs to be a tuple of (str, int), where str is the pad token "
                             "and int is the pad width.")
        if not (left_pad[1] >= 0 and right_pad[1] >= 0):
            raise ValueError("padding widths need to be non-negative numbers.")
        type_check(separator, (str,), "separator")
        kwargs["n"] = n
        kwargs["left_pad"] = left_pad
        kwargs["right_pad"] = right_pad
        kwargs["separator"] = separator
        return method(self, **kwargs)
    return new_method
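check_ngram is the only validator here that rewrites `kwargs` before delegating, so the wrapped method always receives a normalized list for `n`. A hedged standalone sketch of that normalization (`normalize_ngram_args` is a hypothetical name; plain `isinstance` checks stand in for `type_check`/`check_positive`):

```python
# Sketch of the normalization check_ngram applies before calling the
# wrapped method: a bare int n is promoted to a one-element list, and
# both pads must be (str, int) tuples with non-negative widths.
def normalize_ngram_args(n, left_pad=("", 0), right_pad=("", 0)):
    if isinstance(n, int):
        n = [n]
    if not (isinstance(n, list) and n and all(isinstance(g, int) and g > 0 for g in n)):
        raise ValueError("n needs to be a non-empty list of positive integers.")
    for name, pad in (("left_pad", left_pad), ("right_pad", right_pad)):
        if not (isinstance(pad, tuple) and len(pad) == 2
                and isinstance(pad[0], str) and isinstance(pad[1], int) and pad[1] >= 0):
            raise ValueError(name + " needs to be a (str, int) tuple with a non-negative width.")
    return n, left_pad, right_pad
```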


def check_pair_truncate(method):
    """Wrapper method to check the parameters of pair truncate."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        parse_user_args(method, *args, **kwargs)
        return method(self, *args, **kwargs)
    return new_method


def check_to_number(method):
    """A wrapper that wraps a parameter check to the original function (ToNumber)."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [data_type], _ = parse_user_args(method, *args, **kwargs)
        type_check(data_type, (typing.Type,), "data_type")
        if data_type not in mstype.number_type:
            raise TypeError("data_type: " + str(data_type) + " is not numeric data type.")
        return method(self, *args, **kwargs)
    return new_method


def check_python_tokenizer(method):
    """A wrapper that wraps a parameter check to the original function (PythonTokenizer)."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [tokenizer], _ = parse_user_args(method, *args, **kwargs)
        if not callable(tokenizer):
            raise TypeError("tokenizer is not a callable Python function.")
        return method(self, *args, **kwargs)
    return new_method


def check_from_dataset_sentencepiece(method):
    """A wrapper that wraps a parameter checker to the original function (from_dataset)."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [_, col_names, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)
        if col_names is not None:
            type_check_list(col_names, (str,), "col_names")
        if vocab_size is not None:
            check_uint32(vocab_size, "vocab_size")
        else:
            raise TypeError("vocab_size must be provided.")
        if character_coverage is not None:
            type_check(character_coverage, (float,), "character_coverage")
        if model_type is not None:
            from .utils import SentencePieceModel
            type_check(model_type, (str, SentencePieceModel), "model_type")
        if params is not None:
            type_check(params, (dict,), "params")
        return method(self, *args, **kwargs)
    return new_method


def check_from_file_sentencepiece(method):
    """A wrapper that wraps a parameter checker to the original function (from_file)."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [file_path, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)
        if file_path is not None:
            type_check(file_path, (list,), "file_path")
        if vocab_size is not None:
            check_uint32(vocab_size, "vocab_size")
        if character_coverage is not None:
            type_check(character_coverage, (float,), "character_coverage")
        if model_type is not None:
            from .utils import SentencePieceModel
            type_check(model_type, (str, SentencePieceModel), "model_type")
        if params is not None:
            type_check(params, (dict,), "params")
        return method(self, *args, **kwargs)
    return new_method


def check_save_model(method):
    """A wrapper that wraps a parameter checker to the original function (save_model)."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vocab, path, filename], _ = parse_user_args(method, *args, **kwargs)
        if vocab is not None:
            type_check(vocab, (cde.SentencePieceVocab,), "vocab")
        if path is not None:
            type_check(path, (str,), "path")
        if filename is not None:
            type_check(filename, (str,), "filename")
        return method(self, *args, **kwargs)
    return new_method


def check_sentence_piece_tokenizer(method):
    """A wrapper that wraps a parameter checker to the original function."""
    from .utils import SPieceTokenizerOutType

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [mode, out_type], _ = parse_user_args(method, *args, **kwargs)
        type_check(mode, (str, cde.SentencePieceVocab), "mode")
        type_check(out_type, (SPieceTokenizerOutType,), "out_type")
        return method(self, *args, **kwargs)
    return new_method


def check_from_file_vectors(method):
    """A wrapper that wraps a parameter checker to from_file of class Vectors."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [file_path, max_vectors], _ = parse_user_args(method, *args, **kwargs)
        type_check(file_path, (str,), "file_path")
        check_filename(file_path)
        if max_vectors is not None:
            type_check(max_vectors, (int,), "max_vectors")
            check_non_negative_int32(max_vectors, "max_vectors")
        return method(self, *args, **kwargs)
    return new_method


def check_to_vectors(method):
    """A wrapper that wraps a parameter checker to ToVectors."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vectors, unk_init, lower_case_backup], _ = parse_user_args(method, *args, **kwargs)
        type_check(vectors, (cde.Vectors,), "vectors")
        if unk_init is not None:
            type_check(unk_init, (list, tuple), "unk_init")
            for i, value in enumerate(unk_init):
                type_check(value, (int, float), "unk_init[{0}]".format(i))
        type_check(lower_case_backup, (bool,), "lower_case_backup")
        return method(self, *args, **kwargs)
    return new_method