You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

python_bindings.cc 61 kB

6 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <exception>
  17. #include "minddata/dataset/api/de_pipeline.h"
  18. #include "minddata/dataset/engine/cache/cache_client.h"
  19. #include "minddata/dataset/engine/datasetops/source/cifar_op.h"
  20. #include "minddata/dataset/engine/datasetops/source/clue_op.h"
  21. #include "minddata/dataset/engine/datasetops/source/csv_op.h"
  22. #include "minddata/dataset/engine/datasetops/source/coco_op.h"
  23. #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
  24. #include "minddata/dataset/engine/datasetops/source/io_block.h"
  25. #include "minddata/dataset/engine/datasetops/source/manifest_op.h"
  26. #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
  27. #include "minddata/dataset/engine/datasetops/source/mnist_op.h"
  28. #include "minddata/dataset/engine/datasetops/source/random_data_op.h"
  29. #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
  30. #include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
  31. #include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
  32. #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
  33. #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
  34. #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
  35. #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
  36. #include "minddata/dataset/engine/datasetops/source/text_file_op.h"
  37. #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
  38. #include "minddata/dataset/engine/datasetops/source/voc_op.h"
  39. #include "minddata/dataset/engine/gnn/graph.h"
  40. #include "minddata/dataset/engine/jagged_connector.h"
  41. #include "minddata/dataset/kernels/compose_op.h"
  42. #include "minddata/dataset/kernels/data/concatenate_op.h"
  43. #include "minddata/dataset/kernels/data/duplicate_op.h"
  44. #include "minddata/dataset/kernels/data/fill_op.h"
  45. #include "minddata/dataset/kernels/data/mask_op.h"
  46. #include "minddata/dataset/kernels/data/one_hot_op.h"
  47. #include "minddata/dataset/kernels/data/pad_end_op.h"
  48. #include "minddata/dataset/kernels/data/slice_op.h"
  49. #include "minddata/dataset/kernels/data/to_float16_op.h"
  50. #include "minddata/dataset/kernels/data/type_cast_op.h"
  51. #include "minddata/dataset/kernels/image/auto_contrast_op.h"
  52. #include "minddata/dataset/kernels/image/bounding_box_augment_op.h"
  53. #include "minddata/dataset/kernels/image/center_crop_op.h"
  54. #include "minddata/dataset/kernels/image/cut_out_op.h"
  55. #include "minddata/dataset/kernels/image/decode_op.h"
  56. #include "minddata/dataset/kernels/image/equalize_op.h"
  57. #include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
  58. #include "minddata/dataset/kernels/image/image_utils.h"
  59. #include "minddata/dataset/kernels/image/invert_op.h"
  60. #include "minddata/dataset/kernels/image/normalize_op.h"
  61. #include "minddata/dataset/kernels/image/pad_op.h"
  62. #include "minddata/dataset/kernels/image/random_color_adjust_op.h"
  63. #include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
  64. #include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h"
  65. #include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h"
  66. #include "minddata/dataset/kernels/image/random_crop_op.h"
  67. #include "minddata/dataset/kernels/image/random_crop_with_bbox_op.h"
  68. #include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
  69. #include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h"
  70. #include "minddata/dataset/kernels/image/random_resize_op.h"
  71. #include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h"
  72. #include "minddata/dataset/kernels/image/random_rotation_op.h"
  73. #include "minddata/dataset/kernels/image/random_select_subpolicy_op.h"
  74. #include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
  75. #include "minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h"
  76. #include "minddata/dataset/kernels/image/rescale_op.h"
  77. #include "minddata/dataset/kernels/image/resize_bilinear_op.h"
  78. #include "minddata/dataset/kernels/image/resize_op.h"
  79. #include "minddata/dataset/kernels/image/resize_with_bbox_op.h"
  80. #include "minddata/dataset/kernels/image/uniform_aug_op.h"
  81. #include "minddata/dataset/kernels/no_op.h"
  82. #include "minddata/dataset/kernels/py_func_op.h"
  83. #include "minddata/dataset/kernels/random_apply_op.h"
  84. #include "minddata/dataset/kernels/random_choice_op.h"
  85. #include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
  86. #include "minddata/dataset/text/kernels/lookup_op.h"
  87. #include "minddata/dataset/text/kernels/ngram_op.h"
  88. #include "minddata/dataset/text/kernels/sliding_window_op.h"
  89. #include "minddata/dataset/text/kernels/to_number_op.h"
  90. #include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
  91. #include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
  92. #include "minddata/dataset/text/kernels/sentence_piece_tokenizer_op.h"
  93. #include "minddata/dataset/text/vocab.h"
  94. #include "minddata/dataset/text/sentence_piece_vocab.h"
  95. #include "minddata/dataset/util/random.h"
  96. #include "minddata/mindrecord/include/shard_distributed_sample.h"
  97. #include "minddata/mindrecord/include/shard_operator.h"
  98. #include "minddata/mindrecord/include/shard_pk_sample.h"
  99. #include "minddata/mindrecord/include/shard_sample.h"
  100. #include "minddata/mindrecord/include/shard_sequential_sample.h"
  101. #include "mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h"
  102. #include "pybind11/pybind11.h"
  103. #include "pybind11/stl.h"
  104. #include "pybind11/stl_bind.h"
  105. #ifdef ENABLE_ICU4C
  106. #include "minddata/dataset/text/kernels/basic_tokenizer_op.h"
  107. #include "minddata/dataset/text/kernels/bert_tokenizer_op.h"
  108. #include "minddata/dataset/text/kernels/case_fold_op.h"
  109. #include "minddata/dataset/text/kernels/normalize_utf8_op.h"
  110. #include "minddata/dataset/text/kernels/regex_replace_op.h"
  111. #include "minddata/dataset/text/kernels/regex_tokenizer_op.h"
  112. #include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h"
  113. #include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h"
  114. #endif
  115. namespace py = pybind11;
  116. namespace mindspore {
  117. namespace dataset {
  118. #define THROW_IF_ERROR(s) \
  119. do { \
  120. Status rc = std::move(s); \
  121. if (rc.IsError()) throw std::runtime_error(rc.ToString()); \
  122. } while (false)
  123. Status PyListToTensorOps(const py::list &py_ops, std::vector<std::shared_ptr<TensorOp>> *ops) {
  124. RETURN_UNEXPECTED_IF_NULL(ops);
  125. for (auto op : py_ops) {
  126. if (py::isinstance<TensorOp>(op)) {
  127. ops->emplace_back(op.cast<std::shared_ptr<TensorOp>>());
  128. } else if (py::isinstance<py::function>(op)) {
  129. ops->emplace_back(std::make_shared<PyFuncOp>(op.cast<py::function>()));
  130. } else {
  131. RETURN_STATUS_UNEXPECTED("element is neither a TensorOp nor a pyfunc.");
  132. }
  133. }
  134. CHECK_FAIL_RETURN_UNEXPECTED(!ops->empty(), "TensorOp list is empty.");
  135. for (auto const &op : *ops) {
  136. RETURN_UNEXPECTED_IF_NULL(op);
  137. }
  138. return Status::OK();
  139. }
  140. void bindDEPipeline(py::module *m) {
  141. (void)py::class_<DEPipeline>(*m, "DEPipeline")
  142. .def(py::init<>())
  143. .def(
  144. "AddNodeToTree",
  145. [](DEPipeline &de, const OpName &op_name, const py::dict &args) {
  146. py::dict out;
  147. THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out));
  148. return out;
  149. },
  150. py::return_value_policy::reference)
  151. .def_static("AddChildToParentNode",
  152. [](const DsOpPtr &child_op, const DsOpPtr &parent_op) {
  153. THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op));
  154. })
  155. .def("AssignRootNode",
  156. [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
  157. .def("SetBatchParameters",
  158. [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
  159. .def("LaunchTreeExec", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.LaunchTreeExec(num_epochs)); })
  160. .def("GetNextAsMap",
  161. [](DEPipeline &de) {
  162. py::dict out;
  163. THROW_IF_ERROR(de.GetNextAsMap(&out));
  164. return out;
  165. })
  166. .def("GetNextAsList",
  167. [](DEPipeline &de) {
  168. py::list out;
  169. THROW_IF_ERROR(de.GetNextAsList(&out));
  170. return out;
  171. })
  172. .def("GetOutputShapes",
  173. [](DEPipeline &de) {
  174. py::list out;
  175. THROW_IF_ERROR(de.GetOutputShapes(&out));
  176. return out;
  177. })
  178. .def("GetOutputTypes",
  179. [](DEPipeline &de) {
  180. py::list out;
  181. THROW_IF_ERROR(de.GetOutputTypes(&out));
  182. return out;
  183. })
  184. .def("GetDatasetSize", &DEPipeline::GetDatasetSize)
  185. .def("GetBatchSize", &DEPipeline::GetBatchSize)
  186. .def("GetNumClasses", &DEPipeline::GetNumClasses)
  187. .def("GetRepeatCount", &DEPipeline::GetRepeatCount)
  188. .def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
  189. .def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
  190. THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
  191. return true;
  192. });
  193. }
  194. void bindDatasetOps(py::module *m) {
  195. (void)py::class_<TFReaderOp, DatasetOp, std::shared_ptr<TFReaderOp>>(*m, "TFReaderOp")
  196. .def_static("get_num_rows", [](const py::list &files, int64_t numParallelWorkers, bool estimate = false) {
  197. int64_t count = 0;
  198. std::vector<std::string> filenames;
  199. for (auto l : files) {
  200. !l.is_none() ? filenames.push_back(py::str(l)) : (void)filenames.emplace_back("");
  201. }
  202. THROW_IF_ERROR(TFReaderOp::CountTotalRows(&count, filenames, numParallelWorkers, estimate));
  203. return count;
  204. });
  205. (void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
  206. .def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
  207. int64_t count = 0;
  208. THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
  209. return count;
  210. });
  211. (void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
  212. .def_static("get_num_rows_and_classes", [](const std::string &path) {
  213. int64_t count = 0, num_classes = 0;
  214. THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set<std::string>{}, &count, &num_classes));
  215. return py::make_tuple(count, num_classes);
  216. });
  217. (void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
  218. .def_static("get_num_rows", [](const std::vector<std::string> &paths, bool load_dataset, const py::object &sampler,
  219. const int64_t num_padded) {
  220. int64_t count = 0;
  221. std::shared_ptr<mindrecord::ShardOperator> op;
  222. if (py::hasattr(sampler, "create_for_minddataset")) {
  223. auto create = sampler.attr("create_for_minddataset");
  224. op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
  225. }
  226. THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded));
  227. return count;
  228. });
  229. (void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
  230. .def_static("get_num_rows_and_classes",
  231. [](const std::string &file, const py::dict &dict, const std::string &usage) {
  232. int64_t count = 0, num_classes = 0;
  233. THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
  234. return py::make_tuple(count, num_classes);
  235. })
  236. .def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) {
  237. std::map<std::string, int32_t> output_class_indexing;
  238. THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
  239. return output_class_indexing;
  240. });
  241. (void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
  242. .def_static("get_num_rows", [](const std::string &dir) {
  243. int64_t count = 0;
  244. THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
  245. return count;
  246. });
  247. (void)py::class_<TextFileOp, DatasetOp, std::shared_ptr<TextFileOp>>(*m, "TextFileOp")
  248. .def_static("get_num_rows", [](const py::list &files) {
  249. int64_t count = 0;
  250. std::vector<std::string> filenames;
  251. for (auto file : files) {
  252. !file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back("");
  253. }
  254. THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
  255. return count;
  256. });
  257. (void)py::class_<ClueOp, DatasetOp, std::shared_ptr<ClueOp>>(*m, "ClueOp")
  258. .def_static("get_num_rows", [](const py::list &files) {
  259. int64_t count = 0;
  260. std::vector<std::string> filenames;
  261. for (auto file : files) {
  262. file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
  263. }
  264. THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count));
  265. return count;
  266. });
  267. (void)py::class_<CsvOp, DatasetOp, std::shared_ptr<CsvOp>>(*m, "CsvOp")
  268. .def_static("get_num_rows", [](const py::list &files, bool csv_header) {
  269. int64_t count = 0;
  270. std::vector<std::string> filenames;
  271. for (auto file : files) {
  272. file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
  273. }
  274. THROW_IF_ERROR(CsvOp::CountAllFileRows(filenames, csv_header, &count));
  275. return count;
  276. });
  277. (void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
  278. .def_static("get_num_rows",
  279. [](const std::string &dir, const std::string &task_type, const std::string &task_mode,
  280. const py::dict &dict, int64_t numSamples) {
  281. int64_t count = 0;
  282. THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
  283. return count;
  284. })
  285. .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
  286. const std::string &task_mode, const py::dict &dict) {
  287. std::map<std::string, int32_t> output_class_indexing;
  288. THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing));
  289. return output_class_indexing;
  290. });
  291. (void)py::class_<CocoOp, DatasetOp, std::shared_ptr<CocoOp>>(*m, "CocoOp")
  292. .def_static("get_class_indexing",
  293. [](const std::string &dir, const std::string &file, const std::string &task) {
  294. std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
  295. THROW_IF_ERROR(CocoOp::GetClassIndexing(dir, file, task, &output_class_indexing));
  296. return output_class_indexing;
  297. })
  298. .def_static("get_num_rows", [](const std::string &dir, const std::string &file, const std::string &task) {
  299. int64_t count = 0;
  300. THROW_IF_ERROR(CocoOp::CountTotalRows(dir, file, task, &count));
  301. return count;
  302. });
  303. }
  304. void bindTensor(py::module *m) {
  305. (void)py::class_<GlobalContext>(*m, "GlobalContext")
  306. .def_static("config_manager", &GlobalContext::config_manager, py::return_value_policy::reference);
  307. (void)py::class_<ConfigManager, std::shared_ptr<ConfigManager>>(*m, "ConfigManager")
  308. .def("__str__", &ConfigManager::ToString)
  309. .def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer)
  310. .def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers)
  311. .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size)
  312. .def("set_op_connector_size", &ConfigManager::set_op_connector_size)
  313. .def("set_seed", &ConfigManager::set_seed)
  314. .def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval)
  315. .def("get_rows_per_buffer", &ConfigManager::rows_per_buffer)
  316. .def("get_num_parallel_workers", &ConfigManager::num_parallel_workers)
  317. .def("get_worker_connector_size", &ConfigManager::worker_connector_size)
  318. .def("get_op_connector_size", &ConfigManager::op_connector_size)
  319. .def("get_seed", &ConfigManager::seed)
  320. .def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval)
  321. .def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); });
  322. (void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol())
  323. .def(py::init([](py::array arr) {
  324. std::shared_ptr<Tensor> out;
  325. THROW_IF_ERROR(Tensor::CreateFromNpArray(arr, &out));
  326. return out;
  327. }))
  328. .def_buffer([](Tensor &tensor) {
  329. py::buffer_info info;
  330. THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info));
  331. return info;
  332. })
  333. .def("__str__", &Tensor::ToString)
  334. .def("shape", &Tensor::shape)
  335. .def("type", &Tensor::type)
  336. .def("as_array", [](py::object &t) {
  337. auto &tensor = py::cast<Tensor &>(t);
  338. if (tensor.type() == DataType::DE_STRING) {
  339. py::array res;
  340. tensor.GetDataAsNumpyStrings(&res);
  341. return res;
  342. }
  343. py::buffer_info info;
  344. THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info));
  345. return py::array(pybind11::dtype(info), info.shape, info.strides, info.ptr, t);
  346. });
  347. (void)py::class_<TensorShape>(*m, "TensorShape")
  348. .def(py::init([](const py::list &list) {
  349. std::vector<dsize_t> list_c;
  350. for (auto &i : list) {
  351. if (!i.is_none()) {
  352. list_c.push_back(i.cast<int>());
  353. } else {
  354. list_c.push_back(TensorShape::kDimUnknown);
  355. }
  356. }
  357. TensorShape out(list_c);
  358. return out;
  359. }))
  360. .def("__str__", &TensorShape::ToString)
  361. .def("as_list", &TensorShape::AsPyList)
  362. .def("is_known", &TensorShape::known);
  363. (void)py::class_<DataType>(*m, "DataType")
  364. .def(py::init<std::string>())
  365. .def(py::self == py::self)
  366. .def("__str__", &DataType::ToString)
  367. .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
  368. }
  369. void bindTensorOps1(py::module *m) {
  370. (void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp")
  371. .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
  372. (void)py::class_<AutoContrastOp, TensorOp, std::shared_ptr<AutoContrastOp>>(
  373. *m, "AutoContrastOp", "Tensor operation to apply autocontrast on an image.")
  374. .def(py::init<float, std::vector<uint32_t>>(), py::arg("cutoff") = AutoContrastOp::kCutOff,
  375. py::arg("ignore") = AutoContrastOp::kIgnore);
  376. (void)py::class_<NormalizeOp, TensorOp, std::shared_ptr<NormalizeOp>>(
  377. *m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.")
  378. .def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"),
  379. py::arg("stdR"), py::arg("stdG"), py::arg("stdB"));
  380. (void)py::class_<EqualizeOp, TensorOp, std::shared_ptr<EqualizeOp>>(
  381. *m, "EqualizeOp", "Tensor operation to apply histogram equalization on images.")
  382. .def(py::init<>());
  383. (void)py::class_<InvertOp, TensorOp, std::shared_ptr<InvertOp>>(*m, "InvertOp",
  384. "Tensor operation to apply invert on RGB images.")
  385. .def(py::init<>());
  386. (void)py::class_<RescaleOp, TensorOp, std::shared_ptr<RescaleOp>>(
  387. *m, "RescaleOp", "Tensor operation to rescale an image. Takes scale and shift.")
  388. .def(py::init<float, float>(), py::arg("rescale"), py::arg("shift"));
  389. (void)py::class_<CenterCropOp, TensorOp, std::shared_ptr<CenterCropOp>>(
  390. *m, "CenterCropOp", "Tensor operation to crop and image in the middle. Takes height and width (optional)")
  391. .def(py::init<int32_t, int32_t>(), py::arg("height"), py::arg("width") = CenterCropOp::kDefWidth);
  392. (void)py::class_<ResizeOp, TensorOp, std::shared_ptr<ResizeOp>>(
  393. *m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode")
  394. .def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"),
  395. py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation);
  396. (void)py::class_<ResizeWithBBoxOp, TensorOp, std::shared_ptr<ResizeWithBBoxOp>>(
  397. *m, "ResizeWithBBoxOp", "Tensor operation to resize an image. Takes height, width and mode.")
  398. .def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"),
  399. py::arg("targetWidth") = ResizeWithBBoxOp::kDefWidth,
  400. py::arg("interpolation") = ResizeWithBBoxOp::kDefInterpolation);
  401. (void)py::class_<RandomResizeWithBBoxOp, TensorOp, std::shared_ptr<RandomResizeWithBBoxOp>>(
  402. *m, "RandomResizeWithBBoxOp",
  403. "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.")
  404. .def(py::init<int32_t, int32_t>(), py::arg("targetHeight"),
  405. py::arg("targetWidth") = RandomResizeWithBBoxOp::kDefTargetWidth);
  406. (void)py::class_<UniformAugOp, TensorOp, std::shared_ptr<UniformAugOp>>(
  407. *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).")
  408. .def(py::init<std::vector<std::shared_ptr<TensorOp>>, int32_t>(), py::arg("transforms"),
  409. py::arg("NumOps") = UniformAugOp::kDefNumOps);
  410. (void)py::class_<BoundingBoxAugmentOp, TensorOp, std::shared_ptr<BoundingBoxAugmentOp>>(
  411. *m, "BoundingBoxAugmentOp", "Tensor operation to apply a transformation on a random choice of bounding boxes.")
  412. .def(py::init<std::shared_ptr<TensorOp>, float>(), py::arg("transform"),
  413. py::arg("ratio") = BoundingBoxAugmentOp::kDefRatio);
  414. (void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>(
  415. *m, "ResizeBilinearOp",
  416. "Tensor operation to resize an image using "
  417. "Bilinear mode. Takes height and width.")
  418. .def(py::init<int32_t, int32_t>(), py::arg("targetHeight"), py::arg("targetWidth") = ResizeBilinearOp::kDefWidth);
  419. (void)py::class_<DecodeOp, TensorOp, std::shared_ptr<DecodeOp>>(*m, "DecodeOp",
  420. "Tensor operation to decode a jpg image")
  421. .def(py::init<>())
  422. .def(py::init<bool>(), py::arg("rgb_format") = DecodeOp::kDefRgbFormat);
  423. (void)py::class_<RandomHorizontalFlipOp, TensorOp, std::shared_ptr<RandomHorizontalFlipOp>>(
  424. *m, "RandomHorizontalFlipOp", "Tensor operation to randomly flip an image horizontally.")
  425. .def(py::init<float>(), py::arg("probability") = RandomHorizontalFlipOp::kDefProbability);
  426. (void)py::class_<RandomHorizontalFlipWithBBoxOp, TensorOp, std::shared_ptr<RandomHorizontalFlipWithBBoxOp>>(
  427. *m, "RandomHorizontalFlipWithBBoxOp",
  428. "Tensor operation to randomly flip an image horizontally, while flipping bounding boxes.")
  429. .def(py::init<float>(), py::arg("probability") = RandomHorizontalFlipWithBBoxOp::kDefProbability);
  430. }
  431. void bindTensorOps2(py::module *m) {
  432. (void)py::class_<RandomVerticalFlipOp, TensorOp, std::shared_ptr<RandomVerticalFlipOp>>(
  433. *m, "RandomVerticalFlipOp", "Tensor operation to randomly flip an image vertically.")
  434. .def(py::init<float>(), py::arg("probability") = RandomVerticalFlipOp::kDefProbability);
  435. (void)py::class_<RandomVerticalFlipWithBBoxOp, TensorOp, std::shared_ptr<RandomVerticalFlipWithBBoxOp>>(
  436. *m, "RandomVerticalFlipWithBBoxOp",
  437. "Tensor operation to randomly flip an image vertically"
  438. " and adjust bounding boxes.")
  439. .def(py::init<float>(), py::arg("probability") = RandomVerticalFlipWithBBoxOp::kDefProbability);
  440. (void)py::class_<RandomCropOp, TensorOp, std::shared_ptr<RandomCropOp>>(*m, "RandomCropOp",
  441. "Gives random crop of specified size "
  442. "Takes crop size")
  443. .def(py::init<int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, BorderType, bool, uint8_t, uint8_t, uint8_t>(),
  444. py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropOp::kDefPadTop,
  445. py::arg("padBottom") = RandomCropOp::kDefPadBottom, py::arg("padLeft") = RandomCropOp::kDefPadLeft,
  446. py::arg("padRight") = RandomCropOp::kDefPadRight, py::arg("borderType") = RandomCropOp::kDefBorderType,
  447. py::arg("padIfNeeded") = RandomCropOp::kDefPadIfNeeded, py::arg("fillR") = RandomCropOp::kDefFillR,
  448. py::arg("fillG") = RandomCropOp::kDefFillG, py::arg("fillB") = RandomCropOp::kDefFillB);
  449. (void)py::class_<HwcToChwOp, TensorOp, std::shared_ptr<HwcToChwOp>>(*m, "ChannelSwapOp").def(py::init<>());
  450. (void)py::class_<RandomCropWithBBoxOp, TensorOp, std::shared_ptr<RandomCropWithBBoxOp>>(*m, "RandomCropWithBBoxOp",
  451. "Gives random crop of given "
  452. "size + adjusts bboxes "
  453. "Takes crop size")
  454. .def(py::init<int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, BorderType, bool, uint8_t, uint8_t, uint8_t>(),
  455. py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropWithBBoxOp::kDefPadTop,
  456. py::arg("padBottom") = RandomCropWithBBoxOp::kDefPadBottom,
  457. py::arg("padLeft") = RandomCropWithBBoxOp::kDefPadLeft,
  458. py::arg("padRight") = RandomCropWithBBoxOp::kDefPadRight,
  459. py::arg("borderType") = RandomCropWithBBoxOp::kDefBorderType,
  460. py::arg("padIfNeeded") = RandomCropWithBBoxOp::kDefPadIfNeeded,
  461. py::arg("fillR") = RandomCropWithBBoxOp::kDefFillR, py::arg("fillG") = RandomCropWithBBoxOp::kDefFillG,
  462. py::arg("fillB") = RandomCropWithBBoxOp::kDefFillB);
  463. (void)py::class_<OneHotOp, TensorOp, std::shared_ptr<OneHotOp>>(
  464. *m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.")
  465. .def(py::init<int32_t>());
  466. (void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(
  467. *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.")
  468. .def(py::init<std::shared_ptr<Tensor>>());
  469. (void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp", "Tensor slice operation.")
  470. .def(py::init<bool>())
  471. .def(py::init([](const py::list &py_list) {
  472. std::vector<dsize_t> c_list;
  473. for (auto l : py_list) {
  474. if (!l.is_none()) {
  475. c_list.push_back(py::reinterpret_borrow<py::int_>(l));
  476. }
  477. }
  478. return std::make_shared<SliceOp>(c_list);
  479. }))
  480. .def(py::init([](const py::tuple &py_slice) {
  481. if (py_slice.size() != 3) {
  482. THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
  483. }
  484. Slice c_slice;
  485. if (!py_slice[0].is_none() && !py_slice[1].is_none() && !py_slice[2].is_none()) {
  486. c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]), py::reinterpret_borrow<py::int_>(py_slice[1]),
  487. py::reinterpret_borrow<py::int_>(py_slice[2]));
  488. } else if (py_slice[0].is_none() && py_slice[2].is_none()) {
  489. c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[1]));
  490. } else if (!py_slice[0].is_none() && !py_slice[1].is_none()) {
  491. c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]), py::reinterpret_borrow<py::int_>(py_slice[1]));
  492. }
  493. if (!c_slice.valid()) {
  494. THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
  495. }
  496. return std::make_shared<SliceOp>(c_slice);
  497. }));
  498. (void)py::enum_<RelationalOp>(*m, "RelationalOp", py::arithmetic())
  499. .value("EQ", RelationalOp::kEqual)
  500. .value("NE", RelationalOp::kNotEqual)
  501. .value("LT", RelationalOp::kLess)
  502. .value("LE", RelationalOp::kLessEqual)
  503. .value("GT", RelationalOp::kGreater)
  504. .value("GE", RelationalOp::kGreaterEqual)
  505. .export_values();
  506. (void)py::class_<MaskOp, TensorOp, std::shared_ptr<MaskOp>>(*m, "MaskOp",
  507. "Tensor mask operation using relational comparator")
  508. .def(py::init<RelationalOp, std::shared_ptr<Tensor>, DataType>());
  509. (void)py::class_<DuplicateOp, TensorOp, std::shared_ptr<DuplicateOp>>(*m, "DuplicateOp", "Duplicate tensor.")
  510. .def(py::init<>());
  511. (void)py::class_<TruncateSequencePairOp, TensorOp, std::shared_ptr<TruncateSequencePairOp>>(
  512. *m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length")
  513. .def(py::init<int64_t>());
  514. (void)py::class_<ConcatenateOp, TensorOp, std::shared_ptr<ConcatenateOp>>(*m, "ConcatenateOp",
  515. "Tensor operation concatenate tensors.")
  516. .def(py::init<int8_t, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>(), py::arg("axis"),
  517. py::arg("prepend").none(true), py::arg("append").none(true));
  518. (void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>(
  519. *m, "RandomRotationOp",
  520. "Tensor operation to apply RandomRotation."
  521. "Takes a range for degrees and "
  522. "optional parameters for rotation center and image expand")
  523. .def(py::init<float, float, float, float, InterpolationMode, bool, uint8_t, uint8_t, uint8_t>(),
  524. py::arg("startDegree"), py::arg("endDegree"), py::arg("centerX") = RandomRotationOp::kDefCenterX,
  525. py::arg("centerY") = RandomRotationOp::kDefCenterY,
  526. py::arg("interpolation") = RandomRotationOp::kDefInterpolation,
  527. py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR,
  528. py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB);
  529. (void)py::class_<PadEndOp, TensorOp, std::shared_ptr<PadEndOp>>(
  530. *m, "PadEndOp", "Tensor operation to pad end of tensor with a pad value.")
  531. .def(py::init<TensorShape, std::shared_ptr<Tensor>>());
  532. }
  533. void bindTensorOps3(py::module *m) {
  534. (void)py::class_<RandomCropAndResizeOp, TensorOp, std::shared_ptr<RandomCropAndResizeOp>>(
  535. *m, "RandomCropAndResizeOp",
  536. "Tensor operation to randomly crop an image and resize to a given size."
  537. "Takes output height and width and"
  538. "optional parameters for lower and upper bound for aspect ratio (h/w) and scale,"
  539. "interpolation mode, and max attempts to crop")
  540. .def(py::init<int32_t, int32_t, float, float, float, float, InterpolationMode, int32_t>(), py::arg("targetHeight"),
  541. py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeOp::kDefScaleLb,
  542. py::arg("scaleUb") = RandomCropAndResizeOp::kDefScaleUb,
  543. py::arg("aspectLb") = RandomCropAndResizeOp::kDefAspectLb,
  544. py::arg("aspectUb") = RandomCropAndResizeOp::kDefAspectUb,
  545. py::arg("interpolation") = RandomCropAndResizeOp::kDefInterpolation,
  546. py::arg("maxIter") = RandomCropAndResizeOp::kDefMaxIter);
  547. (void)py::class_<RandomCropAndResizeWithBBoxOp, TensorOp, std::shared_ptr<RandomCropAndResizeWithBBoxOp>>(
  548. *m, "RandomCropAndResizeWithBBoxOp",
  549. "Tensor operation to randomly crop an image (with BBoxes) and resize to a given size."
  550. "Takes output height and width and"
  551. "optional parameters for lower and upper bound for aspect ratio (h/w) and scale,"
  552. "interpolation mode, and max attempts to crop")
  553. .def(py::init<int32_t, int32_t, float, float, float, float, InterpolationMode, int32_t>(), py::arg("targetHeight"),
  554. py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeWithBBoxOp::kDefScaleLb,
  555. py::arg("scaleUb") = RandomCropAndResizeWithBBoxOp::kDefScaleUb,
  556. py::arg("aspectLb") = RandomCropAndResizeWithBBoxOp::kDefAspectLb,
  557. py::arg("aspectUb") = RandomCropAndResizeWithBBoxOp::kDefAspectUb,
  558. py::arg("interpolation") = RandomCropAndResizeWithBBoxOp::kDefInterpolation,
  559. py::arg("maxIter") = RandomCropAndResizeWithBBoxOp::kDefMaxIter);
  560. (void)py::class_<RandomColorAdjustOp, TensorOp, std::shared_ptr<RandomColorAdjustOp>>(
  561. *m, "RandomColorAdjustOp",
  562. "Tensor operation to adjust an image's color randomly."
  563. "Takes range for brightness, contrast, saturation, hue and")
  564. .def(py::init<float, float, float, float, float, float, float, float>(), py::arg("bright_factor_start"),
  565. py::arg("bright_factor_end"), py::arg("contrast_factor_start"), py::arg("contrast_factor_end"),
  566. py::arg("saturation_factor_start"), py::arg("saturation_factor_end"), py::arg("hue_factor_start"),
  567. py::arg("hue_factor_end"));
  568. (void)py::class_<RandomResizeOp, TensorOp, std::shared_ptr<RandomResizeOp>>(
  569. *m, "RandomResizeOp",
  570. "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.")
  571. .def(py::init<int32_t, int32_t>(), py::arg("targetHeight"),
  572. py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth);
  573. (void)py::class_<CutOutOp, TensorOp, std::shared_ptr<CutOutOp>>(
  574. *m, "CutOutOp", "Tensor operation to randomly erase a portion of the image. Takes height and width.")
  575. .def(py::init<int32_t, int32_t, int32_t, bool, uint8_t, uint8_t, uint8_t>(), py::arg("boxHeight"),
  576. py::arg("boxWidth"), py::arg("numPatches"), py::arg("randomColor") = CutOutOp::kDefRandomColor,
  577. py::arg("fillR") = CutOutOp::kDefFillR, py::arg("fillG") = CutOutOp::kDefFillG,
  578. py::arg("fillB") = CutOutOp::kDefFillB);
  579. }
  580. void bindTensorOps4(py::module *m) {
  581. (void)py::class_<TypeCastOp, TensorOp, std::shared_ptr<TypeCastOp>>(
  582. *m, "TypeCastOp", "Tensor operator to type cast data to a specified type.")
  583. .def(py::init<DataType>(), py::arg("data_type"))
  584. .def(py::init<std::string>(), py::arg("data_type"));
  585. (void)py::class_<NoOp, TensorOp, std::shared_ptr<NoOp>>(*m, "NoOp",
  586. "TensorOp that does nothing, for testing purposes only.")
  587. .def(py::init<>());
  588. (void)py::class_<ToFloat16Op, TensorOp, std::shared_ptr<ToFloat16Op>>(
  589. *m, "ToFloat16Op", py::dynamic_attr(), "Tensor operator to type cast float32 data to a float16 type.")
  590. .def(py::init<>());
  591. (void)py::class_<RandomCropDecodeResizeOp, TensorOp, std::shared_ptr<RandomCropDecodeResizeOp>>(
  592. *m, "RandomCropDecodeResizeOp", "equivalent to RandomCropAndResize but crops before decoding")
  593. .def(py::init<int32_t, int32_t, float, float, float, float, InterpolationMode, int32_t>(), py::arg("targetHeight"),
  594. py::arg("targetWidth"), py::arg("scaleLb") = RandomCropDecodeResizeOp::kDefScaleLb,
  595. py::arg("scaleUb") = RandomCropDecodeResizeOp::kDefScaleUb,
  596. py::arg("aspectLb") = RandomCropDecodeResizeOp::kDefAspectLb,
  597. py::arg("aspectUb") = RandomCropDecodeResizeOp::kDefAspectUb,
  598. py::arg("interpolation") = RandomCropDecodeResizeOp::kDefInterpolation,
  599. py::arg("maxIter") = RandomCropDecodeResizeOp::kDefMaxIter);
  600. (void)py::class_<PadOp, TensorOp, std::shared_ptr<PadOp>>(
  601. *m, "PadOp",
  602. "Pads image with specified color, default black, "
  603. "Takes amount to pad for top, bottom, left, right of image, boarder type and color")
  604. .def(py::init<int32_t, int32_t, int32_t, int32_t, BorderType, uint8_t, uint8_t, uint8_t>(), py::arg("padTop"),
  605. py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), py::arg("borderTypes") = PadOp::kDefBorderType,
  606. py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB);
  607. (void)py::class_<ToNumberOp, TensorOp, std::shared_ptr<ToNumberOp>>(*m, "ToNumberOp",
  608. "TensorOp to convert strings to numbers.")
  609. .def(py::init<DataType>(), py::arg("data_type"))
  610. .def(py::init<std::string>(), py::arg("data_type"));
  611. }
  612. void bindTokenizerOps(py::module *m) {
  613. (void)py::class_<JiebaTokenizerOp, TensorOp, std::shared_ptr<JiebaTokenizerOp>>(*m, "JiebaTokenizerOp", "")
  614. .def(py::init<const std::string &, const std::string &, const JiebaMode &, const bool &>(), py::arg("hmm_path"),
  615. py::arg("mp_path"), py::arg("mode") = JiebaMode::kMix,
  616. py::arg("with_offsets") = JiebaTokenizerOp::kDefWithOffsets)
  617. .def("add_word",
  618. [](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); });
  619. (void)py::class_<UnicodeCharTokenizerOp, TensorOp, std::shared_ptr<UnicodeCharTokenizerOp>>(
  620. *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.")
  621. .def(py::init<const bool &>(), py::arg("with_offsets") = UnicodeCharTokenizerOp::kDefWithOffsets);
  622. (void)py::class_<LookupOp, TensorOp, std::shared_ptr<LookupOp>>(*m, "LookupOp",
  623. "Tensor operation to LookUp each word.")
  624. .def(py::init([](std::shared_ptr<Vocab> vocab, const py::object &py_word) {
  625. if (vocab == nullptr) {
  626. THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "vocab object type is incorrect or null."));
  627. }
  628. if (py_word.is_none()) {
  629. return std::make_shared<LookupOp>(vocab, Vocab::kNoTokenExists);
  630. }
  631. std::string word = py::reinterpret_borrow<py::str>(py_word);
  632. WordIdType default_id = vocab->Lookup(word);
  633. if (default_id == Vocab::kNoTokenExists) {
  634. THROW_IF_ERROR(
  635. Status(StatusCode::kUnexpectedError, "default unknown token: " + word + " doesn't exist in vocab."));
  636. }
  637. return std::make_shared<LookupOp>(vocab, default_id);
  638. }));
  639. (void)py::class_<NgramOp, TensorOp, std::shared_ptr<NgramOp>>(*m, "NgramOp", "TensorOp performs ngram mapping.")
  640. .def(py::init<const std::vector<int32_t> &, int32_t, int32_t, const std::string &, const std::string &,
  641. const std::string &>(),
  642. py::arg("ngrams"), py::arg("l_pad_len"), py::arg("r_pad_len"), py::arg("l_pad_token"), py::arg("r_pad_token"),
  643. py::arg("separator"));
  644. (void)py::class_<WordpieceTokenizerOp, TensorOp, std::shared_ptr<WordpieceTokenizerOp>>(
  645. *m, "WordpieceTokenizerOp", "Tokenize scalar token or 1-D tokens to subword tokens.")
  646. .def(
  647. py::init<const std::shared_ptr<Vocab> &, const std::string &, const int &, const std::string &, const bool &>(),
  648. py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator),
  649. py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
  650. py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken),
  651. py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets);
  652. (void)py::class_<SlidingWindowOp, TensorOp, std::shared_ptr<SlidingWindowOp>>(
  653. *m, "SlidingWindowOp", "TensorOp to apply sliding window to a 1-D Tensor.")
  654. .def(py::init<uint32_t, int32_t>(), py::arg("width"), py::arg("axis"));
  655. (void)py::class_<SentencePieceTokenizerOp, TensorOp, std::shared_ptr<SentencePieceTokenizerOp>>(
  656. *m, "SentencePieceTokenizerOp", "Tokenize scalar token or 1-D tokens to tokens by sentence piece.")
  657. .def(py::init<std::shared_ptr<SentencePieceVocab> &, const SPieceTokenizerLoadType, const SPieceTokenizerOutType>(),
  658. py::arg("vocab"), py::arg("load_type") = SPieceTokenizerLoadType::kModel,
  659. py::arg("out_type") = SPieceTokenizerOutType::kString)
  660. .def(
  661. py::init<const std::string &, const std::string &, const SPieceTokenizerLoadType, const SPieceTokenizerOutType>(),
  662. py::arg("model_path"), py::arg("model_filename"), py::arg("load_type") = SPieceTokenizerLoadType::kFile,
  663. py::arg("out_type") = SPieceTokenizerOutType::kString);
  664. }
  665. void bindDependIcuTokenizerOps(py::module *m) {
  666. #ifdef ENABLE_ICU4C
  667. (void)py::class_<WhitespaceTokenizerOp, TensorOp, std::shared_ptr<WhitespaceTokenizerOp>>(
  668. *m, "WhitespaceTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on ICU defined whitespaces.")
  669. .def(py::init<const bool &>(), py::arg("with_offsets") = WhitespaceTokenizerOp::kDefWithOffsets);
  670. (void)py::class_<UnicodeScriptTokenizerOp, TensorOp, std::shared_ptr<UnicodeScriptTokenizerOp>>(
  671. *m, "UnicodeScriptTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.")
  672. .def(py::init<>())
  673. .def(py::init<const bool &, const bool &>(),
  674. py::arg("keep_whitespace") = UnicodeScriptTokenizerOp::kDefKeepWhitespace,
  675. py::arg("with_offsets") = UnicodeScriptTokenizerOp::kDefWithOffsets);
  676. (void)py::class_<CaseFoldOp, TensorOp, std::shared_ptr<CaseFoldOp>>(
  677. *m, "CaseFoldOp", "Apply case fold operation on utf-8 string tensor")
  678. .def(py::init<>());
  679. (void)py::class_<NormalizeUTF8Op, TensorOp, std::shared_ptr<NormalizeUTF8Op>>(
  680. *m, "NormalizeUTF8Op", "Apply normalize operation on utf-8 string tensor.")
  681. .def(py::init<>())
  682. .def(py::init<NormalizeForm>(), py::arg("normalize_form") = NormalizeUTF8Op::kDefNormalizeForm);
  683. (void)py::class_<RegexReplaceOp, TensorOp, std::shared_ptr<RegexReplaceOp>>(
  684. *m, "RegexReplaceOp", "Replace utf-8 string tensor with 'replace' according to regular expression 'pattern'.")
  685. .def(py::init<const std::string &, const std::string &, bool>(), py::arg("pattern"), py::arg("replace"),
  686. py::arg("replace_all"));
  687. (void)py::class_<RegexTokenizerOp, TensorOp, std::shared_ptr<RegexTokenizerOp>>(
  688. *m, "RegexTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by regex expression pattern.")
  689. .def(py::init<const std::string &, const std::string &, const bool &>(), py::arg("delim_pattern"),
  690. py::arg("keep_delim_pattern"), py::arg("with_offsets") = RegexTokenizerOp::kDefWithOffsets);
  691. (void)py::class_<BasicTokenizerOp, TensorOp, std::shared_ptr<BasicTokenizerOp>>(
  692. *m, "BasicTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by specific rules.")
  693. .def(py::init<const bool &, const bool &, const NormalizeForm &, const bool &, const bool &>(),
  694. py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase,
  695. py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace,
  696. py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm,
  697. py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken,
  698. py::arg("with_offsets") = BasicTokenizerOp::kDefWithOffsets);
  699. (void)py::class_<BertTokenizerOp, TensorOp, std::shared_ptr<BertTokenizerOp>>(*m, "BertTokenizerOp",
  700. "Tokenizer used for Bert text process.")
  701. .def(py::init<const std::shared_ptr<Vocab> &, const std::string &, const int &, const std::string &, const bool &,
  702. const bool &, const NormalizeForm &, const bool &, const bool &>(),
  703. py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator),
  704. py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
  705. py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken),
  706. py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase,
  707. py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace,
  708. py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm,
  709. py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken,
  710. py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets);
  711. #endif
  712. }
  713. void bindSamplerOps(py::module *m) {
  714. (void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler")
  715. .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
  716. .def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
  717. .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
  718. .def("get_indices",
  719. [](Sampler &self) {
  720. py::array ret;
  721. THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
  722. return ret;
  723. })
  724. .def("add_child",
  725. [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); });
  726. (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator")
  727. .def("add_child", [](std::shared_ptr<mindrecord::ShardOperator> self,
  728. std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); });
  729. (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
  730. .def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());
  731. (void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
  732. .def(py::init<int64_t, int64_t, bool>());
  733. (void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
  734. .def(py::init<int64_t, bool, bool>());
  735. (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
  736. .def(py::init<int64_t, int64_t>());
  737. (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
  738. .def(py::init<int64_t, std::vector<int64_t>>());
  739. (void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
  740. *m, "MindrecordSubsetRandomSampler")
  741. .def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
  742. (void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
  743. *m, "MindrecordPkSampler")
  744. .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
  745. if (shuffle == true) {
  746. return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
  747. GetSeed());
  748. } else {
  749. return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal);
  750. }
  751. }));
  752. (void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
  753. std::shared_ptr<mindrecord::ShardDistributedSample>>(*m, "MindrecordDistributedSampler")
  754. .def(py::init<int64_t, int64_t, bool, uint32_t, int64_t>());
  755. (void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
  756. *m, "MindrecordRandomSampler")
  757. .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) {
  758. return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch);
  759. }));
  760. (void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
  761. std::shared_ptr<mindrecord::ShardSequentialSample>>(*m, "MindrecordSequentialSampler")
  762. .def(py::init([](int num_samples, int start_index) {
  763. return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
  764. }));
  765. (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
  766. .def(py::init<int64_t, std::vector<double>, bool>());
  767. (void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
  768. .def(py::init<int64_t, py::object>());
  769. }
  770. void bindInfoObjects(py::module *m) {
  771. (void)py::class_<BatchOp::CBatchInfo>(*m, "CBatchInfo")
  772. .def(py::init<int64_t, int64_t, int64_t>())
  773. .def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num)
  774. .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
  775. }
  776. void bindCacheClient(py::module *m) {
  777. (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
  778. .def(py::init<uint32_t, uint64_t, bool>());
  779. }
  780. void bindVocabObjects(py::module *m) {
  781. (void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
  782. .def(py::init<>())
  783. .def_static("from_list",
  784. [](const py::list &words, const py::list &special_tokens, bool special_first) {
  785. std::shared_ptr<Vocab> v;
  786. THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v));
  787. return v;
  788. })
  789. .def_static("from_file",
  790. [](const std::string &path, const std::string &dlm, int32_t vocab_size, const py::list &special_tokens,
  791. bool special_first) {
  792. std::shared_ptr<Vocab> v;
  793. THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v));
  794. return v;
  795. })
  796. .def_static("from_dict", [](const py::dict &words) {
  797. std::shared_ptr<Vocab> v;
  798. THROW_IF_ERROR(Vocab::BuildFromPyDict(words, &v));
  799. return v;
  800. });
  801. (void)py::class_<SentencePieceVocab, std::shared_ptr<SentencePieceVocab>>(*m, "SentencePieceVocab")
  802. .def(py::init<>())
  803. .def_static("from_file",
  804. [](const py::list &paths, const int vocab_size, const float character_coverage,
  805. const SentencePieceModel model_type, const py::dict &params) {
  806. std::shared_ptr<SentencePieceVocab> v;
  807. std::vector<std::string> path_list;
  808. for (auto path : paths) {
  809. path_list.emplace_back(py::str(path));
  810. }
  811. std::unordered_map<std::string, std::string> param_map;
  812. for (auto param : params) {
  813. std::string key = py::reinterpret_borrow<py::str>(param.first);
  814. if (key == "input" || key == "vocab_size" || key == "model_prefix" || key == "character_coverage" ||
  815. key == "model_type") {
  816. continue;
  817. }
  818. param_map[key] = py::reinterpret_borrow<py::str>(param.second);
  819. }
  820. THROW_IF_ERROR(SentencePieceVocab::BuildFromFile(path_list, vocab_size, character_coverage,
  821. model_type, param_map, &v));
  822. return v;
  823. })
  824. .def_static("save_model",
  825. [](const std::shared_ptr<SentencePieceVocab> *vocab, std::string path, std::string filename) {
  826. THROW_IF_ERROR(SentencePieceVocab::SaveModel(vocab, path, filename));
  827. });
  828. }
  829. void bindGraphData(py::module *m) {
  830. (void)py::class_<gnn::Graph, std::shared_ptr<gnn::Graph>>(*m, "Graph")
  831. .def(py::init([](std::string dataset_file, int32_t num_workers) {
  832. std::shared_ptr<gnn::Graph> g_out = std::make_shared<gnn::Graph>(dataset_file, num_workers);
  833. THROW_IF_ERROR(g_out->Init());
  834. return g_out;
  835. }))
  836. .def("get_all_nodes",
  837. [](gnn::Graph &g, gnn::NodeType node_type) {
  838. std::shared_ptr<Tensor> out;
  839. THROW_IF_ERROR(g.GetAllNodes(node_type, &out));
  840. return out;
  841. })
  842. .def("get_all_edges",
  843. [](gnn::Graph &g, gnn::EdgeType edge_type) {
  844. std::shared_ptr<Tensor> out;
  845. THROW_IF_ERROR(g.GetAllEdges(edge_type, &out));
  846. return out;
  847. })
  848. .def("get_nodes_from_edges",
  849. [](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) {
  850. std::shared_ptr<Tensor> out;
  851. THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
  852. return out;
  853. })
  854. .def("get_all_neighbors",
  855. [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
  856. std::shared_ptr<Tensor> out;
  857. THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
  858. return out;
  859. })
  860. .def("get_sampled_neighbors",
  861. [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
  862. std::vector<gnn::NodeType> neighbor_types) {
  863. std::shared_ptr<Tensor> out;
  864. THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
  865. return out;
  866. })
  867. .def("get_neg_sampled_neighbors",
  868. [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
  869. gnn::NodeType neg_neighbor_type) {
  870. std::shared_ptr<Tensor> out;
  871. THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
  872. return out;
  873. })
  874. .def("get_node_feature",
  875. [](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
  876. TensorRow out;
  877. THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
  878. return out.getRow();
  879. })
  880. .def("get_edge_feature",
  881. [](gnn::Graph &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) {
  882. TensorRow out;
  883. THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out));
  884. return out.getRow();
  885. })
  886. .def("graph_info",
  887. [](gnn::Graph &g) {
  888. py::dict out;
  889. THROW_IF_ERROR(g.GraphInfo(&out));
  890. return out;
  891. })
  892. .def("random_walk", [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
  893. float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
  894. std::shared_ptr<Tensor> out;
  895. THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
  896. return out;
  897. });
  898. }
  899. void bindRandomTransformTensorOps(py::module *m) {
  900. (void)py::class_<ComposeOp, TensorOp, std::shared_ptr<ComposeOp>>(*m, "ComposeOp")
  901. .def(py::init([](const py::list &ops) {
  902. std::vector<std::shared_ptr<TensorOp>> t_ops;
  903. THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
  904. return std::make_shared<ComposeOp>(t_ops);
  905. }));
  906. (void)py::class_<RandomChoiceOp, TensorOp, std::shared_ptr<RandomChoiceOp>>(*m, "RandomChoiceOp")
  907. .def(py::init([](const py::list &ops) {
  908. std::vector<std::shared_ptr<TensorOp>> t_ops;
  909. THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
  910. return std::make_shared<RandomChoiceOp>(t_ops);
  911. }));
  912. (void)py::class_<RandomApplyOp, TensorOp, std::shared_ptr<RandomApplyOp>>(*m, "RandomApplyOp")
  913. .def(py::init([](double prob, const py::list &ops) {
  914. std::vector<std::shared_ptr<TensorOp>> t_ops;
  915. THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
  916. if (prob < 0 || prob > 1) {
  917. THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be within [0,1]."));
  918. }
  919. return std::make_shared<RandomApplyOp>(prob, t_ops);
  920. }));
  921. (void)py::class_<RandomSelectSubpolicyOp, TensorOp, std::shared_ptr<RandomSelectSubpolicyOp>>(
  922. *m, "RandomSelectSubpolicyOp")
  923. .def(py::init([](const py::list &py_policy) {
  924. std::vector<Subpolicy> cpp_policy;
  925. for (auto &py_sub : py_policy) {
  926. cpp_policy.push_back({});
  927. for (auto handle : py_sub.cast<py::list>()) {
  928. py::tuple tp = handle.cast<py::tuple>();
  929. if (tp.is_none() || tp.size() != 2) {
  930. THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "Each tuple in subpolicy should be (op, prob)."));
  931. }
  932. std::shared_ptr<TensorOp> t_op;
  933. if (py::isinstance<TensorOp>(tp[0])) {
  934. t_op = (tp[0]).cast<std::shared_ptr<TensorOp>>();
  935. } else if (py::isinstance<py::function>(tp[0])) {
  936. t_op = std::make_shared<PyFuncOp>((tp[0]).cast<py::function>());
  937. } else {
  938. THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "op is neither a tensorOp nor a pyfunc."));
  939. }
  940. double prob = (tp[1]).cast<py::float_>();
  941. if (prob < 0 || prob > 1) {
  942. THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be with [0,1]."));
  943. }
  944. cpp_policy.back().emplace_back(std::make_pair(t_op, prob));
  945. }
  946. }
  947. return std::make_shared<RandomSelectSubpolicyOp>(cpp_policy);
  948. }));
  949. }
  950. // This is where we externalize the C logic as python modules
  951. PYBIND11_MODULE(_c_dataengine, m) {
  952. m.doc() = "pybind11 for _c_dataengine";
  953. (void)py::class_<DatasetOp, std::shared_ptr<DatasetOp>>(m, "DatasetOp");
  954. (void)py::enum_<OpName>(m, "OpName", py::arithmetic())
  955. .value("SHUFFLE", OpName::kShuffle)
  956. .value("BATCH", OpName::kBatch)
  957. .value("BUCKETBATCH", OpName::kBucketBatch)
  958. .value("BARRIER", OpName::kBarrier)
  959. .value("MINDRECORD", OpName::kMindrecord)
  960. .value("CACHE", OpName::kCache)
  961. .value("REPEAT", OpName::kRepeat)
  962. .value("SKIP", OpName::kSkip)
  963. .value("TAKE", OpName::kTake)
  964. .value("ZIP", OpName::kZip)
  965. .value("CONCAT", OpName::kConcat)
  966. .value("MAP", OpName::kMap)
  967. .value("FILTER", OpName::kFilter)
  968. .value("DEVICEQUEUE", OpName::kDeviceQueue)
  969. .value("GENERATOR", OpName::kGenerator)
  970. .export_values()
  971. .value("RENAME", OpName::kRename)
  972. .value("TFREADER", OpName::kTfReader)
  973. .value("PROJECT", OpName::kProject)
  974. .value("IMAGEFOLDER", OpName::kImageFolder)
  975. .value("MNIST", OpName::kMnist)
  976. .value("MANIFEST", OpName::kManifest)
  977. .value("VOC", OpName::kVoc)
  978. .value("COCO", OpName::kCoco)
  979. .value("CIFAR10", OpName::kCifar10)
  980. .value("CIFAR100", OpName::kCifar100)
  981. .value("RANDOMDATA", OpName::kRandomData)
  982. .value("BUILDVOCAB", OpName::kBuildVocab)
  983. .value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab)
  984. .value("CELEBA", OpName::kCelebA)
  985. .value("TEXTFILE", OpName::kTextFile)
  986. .value("EPOCHCTRL", OpName::kEpochCtrl)
  987. .value("CSV", OpName::kCsv)
  988. .value("CLUE", OpName::kClue);
  989. (void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
  990. .value("DE_JIEBA_MIX", JiebaMode::kMix)
  991. .value("DE_JIEBA_MP", JiebaMode::kMp)
  992. .value("DE_JIEBA_HMM", JiebaMode::kHmm)
  993. .export_values();
  994. #ifdef ENABLE_ICU4C
  995. (void)py::enum_<NormalizeForm>(m, "NormalizeForm", py::arithmetic())
  996. .value("DE_NORMALIZE_NONE", NormalizeForm::kNone)
  997. .value("DE_NORMALIZE_NFC", NormalizeForm::kNfc)
  998. .value("DE_NORMALIZE_NFKC", NormalizeForm::kNfkc)
  999. .value("DE_NORMALIZE_NFD", NormalizeForm::kNfd)
  1000. .value("DE_NORMALIZE_NFKD", NormalizeForm::kNfkd)
  1001. .export_values();
  1002. #endif
  1003. (void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic())
  1004. .value("DE_INTER_LINEAR", InterpolationMode::kLinear)
  1005. .value("DE_INTER_CUBIC", InterpolationMode::kCubic)
  1006. .value("DE_INTER_AREA", InterpolationMode::kArea)
  1007. .value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour)
  1008. .export_values();
  1009. (void)py::enum_<BorderType>(m, "BorderType", py::arithmetic())
  1010. .value("DE_BORDER_CONSTANT", BorderType::kConstant)
  1011. .value("DE_BORDER_EDGE", BorderType::kEdge)
  1012. .value("DE_BORDER_REFLECT", BorderType::kReflect)
  1013. .value("DE_BORDER_SYMMETRIC", BorderType::kSymmetric)
  1014. .export_values();
  1015. (void)py::enum_<SentencePieceModel>(m, "SentencePieceModel", py::arithmetic())
  1016. .value("DE_SENTENCE_PIECE_UNIGRAM", SentencePieceModel::kUnigram)
  1017. .value("DE_SENTENCE_PIECE_BPE", SentencePieceModel::kBpe)
  1018. .value("DE_SENTENCE_PIECE_CHAR", SentencePieceModel::kChar)
  1019. .value("DE_SENTENCE_PIECE_WORD", SentencePieceModel::kWord)
  1020. .export_values();
  1021. (void)py::enum_<SPieceTokenizerOutType>(m, "SPieceTokenizerOutType", py::arithmetic())
  1022. .value("DE_SPIECE_TOKENIZER_OUTTYPE_KString", SPieceTokenizerOutType::kString)
  1023. .value("DE_SPIECE_TOKENIZER_OUTTYPE_KINT", SPieceTokenizerOutType::kInt)
  1024. .export_values();
  1025. (void)py::enum_<SPieceTokenizerLoadType>(m, "SPieceTokenizerLoadType", py::arithmetic())
  1026. .value("DE_SPIECE_TOKENIZER_LOAD_KFILE", SPieceTokenizerLoadType::kFile)
  1027. .value("DE_SPIECE_TOKENIZER_LOAD_KMODEL", SPieceTokenizerLoadType::kModel)
  1028. .export_values();
  1029. bindDEPipeline(&m);
  1030. bindTensor(&m);
  1031. bindTensorOps1(&m);
  1032. bindTensorOps2(&m);
  1033. bindTensorOps3(&m);
  1034. bindTensorOps4(&m);
  1035. bindTokenizerOps(&m);
  1036. bindSamplerOps(&m);
  1037. bindDatasetOps(&m);
  1038. bindInfoObjects(&m);
  1039. bindCacheClient(&m);
  1040. bindVocabObjects(&m);
  1041. bindGraphData(&m);
  1042. bindDependIcuTokenizerOps(&m);
  1043. bindRandomTransformTensorOps(&m);
  1044. }
  1045. } // namespace dataset
  1046. } // namespace mindspore