You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

python_bindings.cc 52 kB

6 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <exception>
  17. #include "minddata/dataset/api/de_pipeline.h"
  18. #include "minddata/dataset/engine/datasetops/source/cifar_op.h"
  19. #include "minddata/dataset/engine/datasetops/source/clue_op.h"
  20. #include "minddata/dataset/engine/datasetops/source/coco_op.h"
  21. #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
  22. #include "minddata/dataset/engine/datasetops/source/io_block.h"
  23. #include "minddata/dataset/engine/datasetops/source/manifest_op.h"
  24. #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
  25. #include "minddata/dataset/engine/datasetops/source/mnist_op.h"
  26. #include "minddata/dataset/engine/datasetops/source/random_data_op.h"
  27. #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
  28. #include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
  29. #include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
  30. #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
  31. #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
  32. #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
  33. #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
  34. #include "minddata/dataset/engine/datasetops/source/text_file_op.h"
  35. #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
  36. #include "minddata/dataset/engine/datasetops/source/voc_op.h"
  37. #include "minddata/dataset/engine/cache/cache_client.h"
  38. #include "minddata/dataset/engine/gnn/graph.h"
  39. #include "minddata/dataset/engine/jagged_connector.h"
  40. #include "minddata/dataset/kernels/data/concatenate_op.h"
  41. #include "minddata/dataset/kernels/data/duplicate_op.h"
  42. #include "minddata/dataset/kernels/data/fill_op.h"
  43. #include "minddata/dataset/kernels/data/mask_op.h"
  44. #include "minddata/dataset/kernels/data/one_hot_op.h"
  45. #include "minddata/dataset/kernels/data/pad_end_op.h"
  46. #include "minddata/dataset/kernels/data/slice_op.h"
  47. #include "minddata/dataset/kernels/data/to_float16_op.h"
  48. #include "minddata/dataset/kernels/data/type_cast_op.h"
  49. #include "minddata/dataset/kernels/image/bounding_box_augment_op.h"
  50. #include "minddata/dataset/kernels/image/center_crop_op.h"
  51. #include "minddata/dataset/kernels/image/cut_out_op.h"
  52. #include "minddata/dataset/kernels/image/decode_op.h"
  53. #include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
  54. #include "minddata/dataset/kernels/image/image_utils.h"
  55. #include "minddata/dataset/kernels/image/normalize_op.h"
  56. #include "minddata/dataset/kernels/image/pad_op.h"
  57. #include "minddata/dataset/kernels/image/random_color_adjust_op.h"
  58. #include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
  59. #include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h"
  60. #include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h"
  61. #include "minddata/dataset/kernels/image/random_crop_op.h"
  62. #include "minddata/dataset/kernels/image/random_crop_with_bbox_op.h"
  63. #include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h"
  64. #include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
  65. #include "minddata/dataset/kernels/image/random_resize_op.h"
  66. #include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h"
  67. #include "minddata/dataset/kernels/image/random_rotation_op.h"
  68. #include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
  69. #include "minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h"
  70. #include "minddata/dataset/kernels/image/rescale_op.h"
  71. #include "minddata/dataset/kernels/image/resize_bilinear_op.h"
  72. #include "minddata/dataset/kernels/image/resize_op.h"
  73. #include "minddata/dataset/kernels/image/resize_with_bbox_op.h"
  74. #include "minddata/dataset/kernels/image/uniform_aug_op.h"
  75. #include "minddata/dataset/kernels/no_op.h"
  76. #include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
  77. #include "minddata/dataset/text/kernels/lookup_op.h"
  78. #include "minddata/dataset/text/kernels/ngram_op.h"
  79. #include "minddata/dataset/text/kernels/sliding_window_op.h"
  80. #include "minddata/dataset/text/kernels/to_number_op.h"
  81. #include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
  82. #include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
  83. #include "minddata/dataset/text/vocab.h"
  84. #include "minddata/dataset/util/random.h"
  85. #include "minddata/mindrecord/include/shard_distributed_sample.h"
  86. #include "minddata/mindrecord/include/shard_operator.h"
  87. #include "minddata/mindrecord/include/shard_pk_sample.h"
  88. #include "minddata/mindrecord/include/shard_sample.h"
  89. #include "minddata/mindrecord/include/shard_sequential_sample.h"
  90. #include "mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h"
  91. #include "pybind11/pybind11.h"
  92. #include "pybind11/stl.h"
  93. #include "pybind11/stl_bind.h"
  94. #ifdef ENABLE_ICU4C
  95. #include "minddata/dataset/text/kernels/basic_tokenizer_op.h"
  96. #include "minddata/dataset/text/kernels/bert_tokenizer_op.h"
  97. #include "minddata/dataset/text/kernels/case_fold_op.h"
  98. #include "minddata/dataset/text/kernels/normalize_utf8_op.h"
  99. #include "minddata/dataset/text/kernels/regex_replace_op.h"
  100. #include "minddata/dataset/text/kernels/regex_tokenizer_op.h"
  101. #include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h"
  102. #include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h"
  103. #endif
  104. namespace py = pybind11;
  105. namespace mindspore {
  106. namespace dataset {
  // THROW_IF_ERROR(s): evaluates the Status expression `s` exactly once and, if it
  // reports an error, throws std::runtime_error whose message is Status::ToString()
  // (pybind11 translates std::runtime_error into a Python RuntimeError).
  // The temporary deliberately has an unusual name: with a plain `rc`, a call such as
  // THROW_IF_ERROR(rc) would expand to `Status rc = std::move(rc);`, a self-initialisation
  // that reads an uninitialised object and can silently drop the error.
  #define THROW_IF_ERROR(s)                                    \
    do {                                                       \
      Status throw_if_error_rc_ = std::move(s);                \
      if (throw_if_error_rc_.IsError()) {                      \
        throw std::runtime_error(throw_if_error_rc_.ToString()); \
      }                                                        \
    } while (false)
  112. void bindDEPipeline(py::module *m) {
  113. (void)py::class_<DEPipeline>(*m, "DEPipeline")
  114. .def(py::init<>())
  115. .def(
  116. "AddNodeToTree",
  117. [](DEPipeline &de, const OpName &op_name, const py::dict &args) {
  118. py::dict out;
  119. THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out));
  120. return out;
  121. },
  122. py::return_value_policy::reference)
  123. .def_static("AddChildToParentNode",
  124. [](const DsOpPtr &child_op, const DsOpPtr &parent_op) {
  125. THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op));
  126. })
  127. .def("AssignRootNode",
  128. [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
  129. .def("SetBatchParameters",
  130. [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
  131. .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
  132. .def("GetNextAsMap",
  133. [](DEPipeline &de) {
  134. py::dict out;
  135. THROW_IF_ERROR(de.GetNextAsMap(&out));
  136. return out;
  137. })
  138. .def("GetNextAsList",
  139. [](DEPipeline &de) {
  140. py::list out;
  141. THROW_IF_ERROR(de.GetNextAsList(&out));
  142. return out;
  143. })
  144. .def("GetOutputShapes",
  145. [](DEPipeline &de) {
  146. py::list out;
  147. THROW_IF_ERROR(de.GetOutputShapes(&out));
  148. return out;
  149. })
  150. .def("GetOutputTypes",
  151. [](DEPipeline &de) {
  152. py::list out;
  153. THROW_IF_ERROR(de.GetOutputTypes(&out));
  154. return out;
  155. })
  156. .def("GetDatasetSize", &DEPipeline::GetDatasetSize)
  157. .def("GetBatchSize", &DEPipeline::GetBatchSize)
  158. .def("GetNumClasses", &DEPipeline::GetNumClasses)
  159. .def("GetRepeatCount", &DEPipeline::GetRepeatCount);
  160. }
  161. void bindDatasetOps(py::module *m) {
  162. (void)py::class_<TFReaderOp, DatasetOp, std::shared_ptr<TFReaderOp>>(*m, "TFReaderOp")
  163. .def_static("get_num_rows", [](const py::list &files, int64_t numParallelWorkers, bool estimate = false) {
  164. int64_t count = 0;
  165. std::vector<std::string> filenames;
  166. for (auto l : files) {
  167. !l.is_none() ? filenames.push_back(py::str(l)) : (void)filenames.emplace_back("");
  168. }
  169. THROW_IF_ERROR(TFReaderOp::CountTotalRows(&count, filenames, numParallelWorkers, estimate));
  170. return count;
  171. });
  172. (void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
  173. .def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
  174. int64_t count = 0;
  175. THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
  176. return count;
  177. });
  178. (void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
  179. .def_static("get_num_rows_and_classes", [](const std::string &path) {
  180. int64_t count = 0, num_classes = 0;
  181. THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set<std::string>{}, &count, &num_classes));
  182. return py::make_tuple(count, num_classes);
  183. });
  184. (void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
  185. .def_static("get_num_rows", [](const std::vector<std::string> &paths, bool load_dataset, const py::object &sampler,
  186. const int64_t num_padded) {
  187. int64_t count = 0;
  188. std::shared_ptr<mindrecord::ShardOperator> op;
  189. if (py::hasattr(sampler, "create_for_minddataset")) {
  190. auto create = sampler.attr("create_for_minddataset");
  191. op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
  192. }
  193. THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded));
  194. return count;
  195. });
  196. (void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
  197. .def_static("get_num_rows_and_classes",
  198. [](const std::string &file, const py::dict &dict, const std::string &usage) {
  199. int64_t count = 0, num_classes = 0;
  200. THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes));
  201. return py::make_tuple(count, num_classes);
  202. })
  203. .def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) {
  204. std::map<std::string, int32_t> output_class_indexing;
  205. THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing));
  206. return output_class_indexing;
  207. });
  208. (void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
  209. .def_static("get_num_rows", [](const std::string &dir) {
  210. int64_t count = 0;
  211. THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
  212. return count;
  213. });
  214. (void)py::class_<TextFileOp, DatasetOp, std::shared_ptr<TextFileOp>>(*m, "TextFileOp")
  215. .def_static("get_num_rows", [](const py::list &files) {
  216. int64_t count = 0;
  217. std::vector<std::string> filenames;
  218. for (auto file : files) {
  219. !file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back("");
  220. }
  221. THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
  222. return count;
  223. });
  224. (void)py::class_<ClueOp, DatasetOp, std::shared_ptr<ClueOp>>(*m, "ClueOp")
  225. .def_static("get_num_rows", [](const py::list &files) {
  226. int64_t count = 0;
  227. std::vector<std::string> filenames;
  228. for (auto file : files) {
  229. file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
  230. }
  231. THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count));
  232. return count;
  233. });
  234. (void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
  235. .def_static("get_num_rows",
  236. [](const std::string &dir, const std::string &task_type, const std::string &task_mode,
  237. const py::dict &dict, int64_t numSamples) {
  238. int64_t count = 0;
  239. THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count));
  240. return count;
  241. })
  242. .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
  243. const std::string &task_mode, const py::dict &dict) {
  244. std::map<std::string, int32_t> output_class_indexing;
  245. THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing));
  246. return output_class_indexing;
  247. });
  248. (void)py::class_<CocoOp, DatasetOp, std::shared_ptr<CocoOp>>(*m, "CocoOp")
  249. .def_static("get_class_indexing",
  250. [](const std::string &dir, const std::string &file, const std::string &task) {
  251. std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
  252. THROW_IF_ERROR(CocoOp::GetClassIndexing(dir, file, task, &output_class_indexing));
  253. return output_class_indexing;
  254. })
  255. .def_static("get_num_rows", [](const std::string &dir, const std::string &file, const std::string &task) {
  256. int64_t count = 0;
  257. THROW_IF_ERROR(CocoOp::CountTotalRows(dir, file, task, &count));
  258. return count;
  259. });
  260. }
  261. void bindTensor(py::module *m) {
  262. (void)py::class_<GlobalContext>(*m, "GlobalContext")
  263. .def_static("config_manager", &GlobalContext::config_manager, py::return_value_policy::reference);
  264. (void)py::class_<ConfigManager, std::shared_ptr<ConfigManager>>(*m, "ConfigManager")
  265. .def("__str__", &ConfigManager::ToString)
  266. .def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer)
  267. .def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers)
  268. .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size)
  269. .def("set_op_connector_size", &ConfigManager::set_op_connector_size)
  270. .def("set_seed", &ConfigManager::set_seed)
  271. .def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval)
  272. .def("get_rows_per_buffer", &ConfigManager::rows_per_buffer)
  273. .def("get_num_parallel_workers", &ConfigManager::num_parallel_workers)
  274. .def("get_worker_connector_size", &ConfigManager::worker_connector_size)
  275. .def("get_op_connector_size", &ConfigManager::op_connector_size)
  276. .def("get_seed", &ConfigManager::seed)
  277. .def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval)
  278. .def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); });
  279. (void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol())
  280. .def(py::init([](py::array arr) {
  281. std::shared_ptr<Tensor> out;
  282. THROW_IF_ERROR(Tensor::CreateTensor(&out, arr));
  283. return out;
  284. }))
  285. .def_buffer([](Tensor &tensor) {
  286. py::buffer_info info;
  287. THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info));
  288. return info;
  289. })
  290. .def("__str__", &Tensor::ToString)
  291. .def("shape", &Tensor::shape)
  292. .def("type", &Tensor::type)
  293. .def("as_array", [](py::object &t) {
  294. auto &tensor = py::cast<Tensor &>(t);
  295. if (tensor.type() == DataType::DE_STRING) {
  296. py::array res;
  297. tensor.GetDataAsNumpyStrings(&res);
  298. return res;
  299. }
  300. py::buffer_info info;
  301. THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info));
  302. return py::array(pybind11::dtype(info), info.shape, info.strides, info.ptr, t);
  303. });
  304. (void)py::class_<TensorShape>(*m, "TensorShape")
  305. .def(py::init<py::list>())
  306. .def("__str__", &TensorShape::ToString)
  307. .def("as_list", &TensorShape::AsPyList)
  308. .def("is_known", &TensorShape::known);
  309. (void)py::class_<DataType>(*m, "DataType")
  310. .def(py::init<std::string>())
  311. .def(py::self == py::self)
  312. .def("__str__", &DataType::ToString)
  313. .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
  314. }
  315. void bindTensorOps1(py::module *m) {
  316. (void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp")
  317. .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
  318. (void)py::class_<NormalizeOp, TensorOp, std::shared_ptr<NormalizeOp>>(
  319. *m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.")
  320. .def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"),
  321. py::arg("stdR"), py::arg("stdG"), py::arg("stdB"));
  322. (void)py::class_<RescaleOp, TensorOp, std::shared_ptr<RescaleOp>>(
  323. *m, "RescaleOp", "Tensor operation to rescale an image. Takes scale and shift.")
  324. .def(py::init<float, float>(), py::arg("rescale"), py::arg("shift"));
  325. (void)py::class_<CenterCropOp, TensorOp, std::shared_ptr<CenterCropOp>>(
  326. *m, "CenterCropOp", "Tensor operation to crop and image in the middle. Takes height and width (optional)")
  327. .def(py::init<int32_t, int32_t>(), py::arg("height"), py::arg("width") = CenterCropOp::kDefWidth);
  328. (void)py::class_<ResizeOp, TensorOp, std::shared_ptr<ResizeOp>>(
  329. *m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode")
  330. .def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"),
  331. py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation);
  332. (void)py::class_<ResizeWithBBoxOp, TensorOp, std::shared_ptr<ResizeWithBBoxOp>>(
  333. *m, "ResizeWithBBoxOp", "Tensor operation to resize an image. Takes height, width and mode.")
  334. .def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"),
  335. py::arg("targetWidth") = ResizeWithBBoxOp::kDefWidth,
  336. py::arg("interpolation") = ResizeWithBBoxOp::kDefInterpolation);
  337. (void)py::class_<RandomResizeWithBBoxOp, TensorOp, std::shared_ptr<RandomResizeWithBBoxOp>>(
  338. *m, "RandomResizeWithBBoxOp",
  339. "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.")
  340. .def(py::init<int32_t, int32_t>(), py::arg("targetHeight"),
  341. py::arg("targetWidth") = RandomResizeWithBBoxOp::kDefTargetWidth);
  342. (void)py::class_<UniformAugOp, TensorOp, std::shared_ptr<UniformAugOp>>(
  343. *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).")
  344. .def(py::init<std::vector<std::shared_ptr<TensorOp>>, int32_t>(), py::arg("operations"),
  345. py::arg("NumOps") = UniformAugOp::kDefNumOps);
  346. (void)py::class_<BoundingBoxAugmentOp, TensorOp, std::shared_ptr<BoundingBoxAugmentOp>>(
  347. *m, "BoundingBoxAugmentOp", "Tensor operation to apply a transformation on a random choice of bounding boxes.")
  348. .def(py::init<std::shared_ptr<TensorOp>, float>(), py::arg("transform"),
  349. py::arg("ratio") = BoundingBoxAugmentOp::kDefRatio);
  350. (void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>(
  351. *m, "ResizeBilinearOp",
  352. "Tensor operation to resize an image using "
  353. "Bilinear mode. Takes height and width.")
  354. .def(py::init<int32_t, int32_t>(), py::arg("targetHeight"), py::arg("targetWidth") = ResizeBilinearOp::kDefWidth);
  355. (void)py::class_<DecodeOp, TensorOp, std::shared_ptr<DecodeOp>>(*m, "DecodeOp",
  356. "Tensor operation to decode a jpg image")
  357. .def(py::init<>())
  358. .def(py::init<bool>(), py::arg("rgb_format") = DecodeOp::kDefRgbFormat);
  359. (void)py::class_<RandomHorizontalFlipOp, TensorOp, std::shared_ptr<RandomHorizontalFlipOp>>(
  360. *m, "RandomHorizontalFlipOp", "Tensor operation to randomly flip an image horizontally.")
  361. .def(py::init<float>(), py::arg("probability") = RandomHorizontalFlipOp::kDefProbability);
  362. (void)py::class_<RandomHorizontalFlipWithBBoxOp, TensorOp, std::shared_ptr<RandomHorizontalFlipWithBBoxOp>>(
  363. *m, "RandomHorizontalFlipWithBBoxOp",
  364. "Tensor operation to randomly flip an image horizontally, while flipping bounding boxes.")
  365. .def(py::init<float>(), py::arg("probability") = RandomHorizontalFlipWithBBoxOp::kDefProbability);
  366. }
  367. void bindTensorOps2(py::module *m) {
  368. (void)py::class_<RandomVerticalFlipOp, TensorOp, std::shared_ptr<RandomVerticalFlipOp>>(
  369. *m, "RandomVerticalFlipOp", "Tensor operation to randomly flip an image vertically.")
  370. .def(py::init<float>(), py::arg("probability") = RandomVerticalFlipOp::kDefProbability);
  371. (void)py::class_<RandomVerticalFlipWithBBoxOp, TensorOp, std::shared_ptr<RandomVerticalFlipWithBBoxOp>>(
  372. *m, "RandomVerticalFlipWithBBoxOp",
  373. "Tensor operation to randomly flip an image vertically"
  374. " and adjust bounding boxes.")
  375. .def(py::init<float>(), py::arg("probability") = RandomVerticalFlipWithBBoxOp::kDefProbability);
  376. (void)py::class_<RandomCropOp, TensorOp, std::shared_ptr<RandomCropOp>>(*m, "RandomCropOp",
  377. "Gives random crop of specified size "
  378. "Takes crop size")
  379. .def(py::init<int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, BorderType, bool, uint8_t, uint8_t, uint8_t>(),
  380. py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropOp::kDefPadTop,
  381. py::arg("padBottom") = RandomCropOp::kDefPadBottom, py::arg("padLeft") = RandomCropOp::kDefPadLeft,
  382. py::arg("padRight") = RandomCropOp::kDefPadRight, py::arg("borderType") = RandomCropOp::kDefBorderType,
  383. py::arg("padIfNeeded") = RandomCropOp::kDefPadIfNeeded, py::arg("fillR") = RandomCropOp::kDefFillR,
  384. py::arg("fillG") = RandomCropOp::kDefFillG, py::arg("fillB") = RandomCropOp::kDefFillB);
  385. (void)py::class_<HwcToChwOp, TensorOp, std::shared_ptr<HwcToChwOp>>(*m, "ChannelSwapOp").def(py::init<>());
  386. (void)py::class_<RandomCropWithBBoxOp, TensorOp, std::shared_ptr<RandomCropWithBBoxOp>>(*m, "RandomCropWithBBoxOp",
  387. "Gives random crop of given "
  388. "size + adjusts bboxes "
  389. "Takes crop size")
  390. .def(py::init<int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, BorderType, bool, uint8_t, uint8_t, uint8_t>(),
  391. py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropWithBBoxOp::kDefPadTop,
  392. py::arg("padBottom") = RandomCropWithBBoxOp::kDefPadBottom,
  393. py::arg("padLeft") = RandomCropWithBBoxOp::kDefPadLeft,
  394. py::arg("padRight") = RandomCropWithBBoxOp::kDefPadRight,
  395. py::arg("borderType") = RandomCropWithBBoxOp::kDefBorderType,
  396. py::arg("padIfNeeded") = RandomCropWithBBoxOp::kDefPadIfNeeded,
  397. py::arg("fillR") = RandomCropWithBBoxOp::kDefFillR, py::arg("fillG") = RandomCropWithBBoxOp::kDefFillG,
  398. py::arg("fillB") = RandomCropWithBBoxOp::kDefFillB);
  399. (void)py::class_<OneHotOp, TensorOp, std::shared_ptr<OneHotOp>>(
  400. *m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.")
  401. .def(py::init<int32_t>());
  402. (void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(
  403. *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.")
  404. .def(py::init<std::shared_ptr<Tensor>>());
  405. (void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp", "Tensor slice operation.")
  406. .def(py::init<bool>())
  407. .def(py::init([](const py::list &py_list) {
  408. std::vector<dsize_t> c_list;
  409. for (auto l : py_list) {
  410. if (!l.is_none()) {
  411. c_list.push_back(py::reinterpret_borrow<py::int_>(l));
  412. }
  413. }
  414. return std::make_shared<SliceOp>(c_list);
  415. }))
  416. .def(py::init([](const py::tuple &py_slice) {
  417. if (py_slice.size() != 3) {
  418. THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
  419. }
  420. Slice c_slice;
  421. if (!py_slice[0].is_none() && !py_slice[1].is_none() && !py_slice[2].is_none()) {
  422. c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]), py::reinterpret_borrow<py::int_>(py_slice[1]),
  423. py::reinterpret_borrow<py::int_>(py_slice[2]));
  424. } else if (py_slice[0].is_none() && py_slice[2].is_none()) {
  425. c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[1]));
  426. } else if (!py_slice[0].is_none() && !py_slice[1].is_none()) {
  427. c_slice = Slice(py::reinterpret_borrow<py::int_>(py_slice[0]), py::reinterpret_borrow<py::int_>(py_slice[1]));
  428. }
  429. if (!c_slice.valid()) {
  430. THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object"));
  431. }
  432. return std::make_shared<SliceOp>(c_slice);
  433. }));
  434. (void)py::enum_<RelationalOp>(*m, "RelationalOp", py::arithmetic())
  435. .value("EQ", RelationalOp::kEqual)
  436. .value("NE", RelationalOp::kNotEqual)
  437. .value("LT", RelationalOp::kLess)
  438. .value("LE", RelationalOp::kLessEqual)
  439. .value("GT", RelationalOp::kGreater)
  440. .value("GE", RelationalOp::kGreaterEqual)
  441. .export_values();
  442. (void)py::class_<MaskOp, TensorOp, std::shared_ptr<MaskOp>>(*m, "MaskOp",
  443. "Tensor mask operation using relational comparator")
  444. .def(py::init<RelationalOp, std::shared_ptr<Tensor>, DataType>());
  445. (void)py::class_<DuplicateOp, TensorOp, std::shared_ptr<DuplicateOp>>(*m, "DuplicateOp", "Duplicate tensor.")
  446. .def(py::init<>());
  447. (void)py::class_<TruncateSequencePairOp, TensorOp, std::shared_ptr<TruncateSequencePairOp>>(
  448. *m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length")
  449. .def(py::init<int64_t>());
  450. (void)py::class_<ConcatenateOp, TensorOp, std::shared_ptr<ConcatenateOp>>(*m, "ConcatenateOp",
  451. "Tensor operation concatenate tensors.")
  452. .def(py::init<int8_t, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>(), py::arg("axis"),
  453. py::arg("prepend").none(true), py::arg("append").none(true));
  454. (void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>(
  455. *m, "RandomRotationOp",
  456. "Tensor operation to apply RandomRotation."
  457. "Takes a range for degrees and "
  458. "optional parameters for rotation center and image expand")
  459. .def(py::init<float, float, float, float, InterpolationMode, bool, uint8_t, uint8_t, uint8_t>(),
  460. py::arg("startDegree"), py::arg("endDegree"), py::arg("centerX") = RandomRotationOp::kDefCenterX,
  461. py::arg("centerY") = RandomRotationOp::kDefCenterY,
  462. py::arg("interpolation") = RandomRotationOp::kDefInterpolation,
  463. py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR,
  464. py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB);
  465. (void)py::class_<PadEndOp, TensorOp, std::shared_ptr<PadEndOp>>(
  466. *m, "PadEndOp", "Tensor operation to pad end of tensor with a pad value.")
  467. .def(py::init<TensorShape, std::shared_ptr<Tensor>>());
  468. }
  469. void bindTensorOps3(py::module *m) {
  470. (void)py::class_<RandomCropAndResizeOp, TensorOp, std::shared_ptr<RandomCropAndResizeOp>>(
  471. *m, "RandomCropAndResizeOp",
  472. "Tensor operation to randomly crop an image and resize to a given size."
  473. "Takes output height and width and"
  474. "optional parameters for lower and upper bound for aspect ratio (h/w) and scale,"
  475. "interpolation mode, and max attempts to crop")
  476. .def(py::init<int32_t, int32_t, float, float, float, float, InterpolationMode, int32_t>(), py::arg("targetHeight"),
  477. py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeOp::kDefScaleLb,
  478. py::arg("scaleUb") = RandomCropAndResizeOp::kDefScaleUb,
  479. py::arg("aspectLb") = RandomCropAndResizeOp::kDefAspectLb,
  480. py::arg("aspectUb") = RandomCropAndResizeOp::kDefAspectUb,
  481. py::arg("interpolation") = RandomCropAndResizeOp::kDefInterpolation,
  482. py::arg("maxIter") = RandomCropAndResizeOp::kDefMaxIter);
  483. (void)py::class_<RandomCropAndResizeWithBBoxOp, TensorOp, std::shared_ptr<RandomCropAndResizeWithBBoxOp>>(
  484. *m, "RandomCropAndResizeWithBBoxOp",
  485. "Tensor operation to randomly crop an image (with BBoxes) and resize to a given size."
  486. "Takes output height and width and"
  487. "optional parameters for lower and upper bound for aspect ratio (h/w) and scale,"
  488. "interpolation mode, and max attempts to crop")
  489. .def(py::init<int32_t, int32_t, float, float, float, float, InterpolationMode, int32_t>(), py::arg("targetHeight"),
  490. py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeWithBBoxOp::kDefScaleLb,
  491. py::arg("scaleUb") = RandomCropAndResizeWithBBoxOp::kDefScaleUb,
  492. py::arg("aspectLb") = RandomCropAndResizeWithBBoxOp::kDefAspectLb,
  493. py::arg("aspectUb") = RandomCropAndResizeWithBBoxOp::kDefAspectUb,
  494. py::arg("interpolation") = RandomCropAndResizeWithBBoxOp::kDefInterpolation,
  495. py::arg("maxIter") = RandomCropAndResizeWithBBoxOp::kDefMaxIter);
  496. (void)py::class_<RandomColorAdjustOp, TensorOp, std::shared_ptr<RandomColorAdjustOp>>(
  497. *m, "RandomColorAdjustOp",
  498. "Tensor operation to adjust an image's color randomly."
  499. "Takes range for brightness, contrast, saturation, hue and")
  500. .def(py::init<float, float, float, float, float, float, float, float>(), py::arg("bright_factor_start"),
  501. py::arg("bright_factor_end"), py::arg("contrast_factor_start"), py::arg("contrast_factor_end"),
  502. py::arg("saturation_factor_start"), py::arg("saturation_factor_end"), py::arg("hue_factor_start"),
  503. py::arg("hue_factor_end"));
  504. (void)py::class_<RandomResizeOp, TensorOp, std::shared_ptr<RandomResizeOp>>(
  505. *m, "RandomResizeOp",
  506. "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.")
  507. .def(py::init<int32_t, int32_t>(), py::arg("targetHeight"),
  508. py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth);
  509. (void)py::class_<CutOutOp, TensorOp, std::shared_ptr<CutOutOp>>(
  510. *m, "CutOutOp", "Tensor operation to randomly erase a portion of the image. Takes height and width.")
  511. .def(py::init<int32_t, int32_t, int32_t, bool, uint8_t, uint8_t, uint8_t>(), py::arg("boxHeight"),
  512. py::arg("boxWidth"), py::arg("numPatches"), py::arg("randomColor") = CutOutOp::kDefRandomColor,
  513. py::arg("fillR") = CutOutOp::kDefFillR, py::arg("fillG") = CutOutOp::kDefFillG,
  514. py::arg("fillB") = CutOutOp::kDefFillB);
  515. }
  516. void bindTensorOps4(py::module *m) {
  517. (void)py::class_<TypeCastOp, TensorOp, std::shared_ptr<TypeCastOp>>(
  518. *m, "TypeCastOp", "Tensor operator to type cast data to a specified type.")
  519. .def(py::init<DataType>(), py::arg("data_type"))
  520. .def(py::init<std::string>(), py::arg("data_type"));
  521. (void)py::class_<NoOp, TensorOp, std::shared_ptr<NoOp>>(*m, "NoOp",
  522. "TensorOp that does nothing, for testing purposes only.")
  523. .def(py::init<>());
  524. (void)py::class_<ToFloat16Op, TensorOp, std::shared_ptr<ToFloat16Op>>(
  525. *m, "ToFloat16Op", py::dynamic_attr(), "Tensor operator to type cast float32 data to a float16 type.")
  526. .def(py::init<>());
  527. (void)py::class_<RandomCropDecodeResizeOp, TensorOp, std::shared_ptr<RandomCropDecodeResizeOp>>(
  528. *m, "RandomCropDecodeResizeOp", "equivalent to RandomCropAndResize but crops before decoding")
  529. .def(py::init<int32_t, int32_t, float, float, float, float, InterpolationMode, int32_t>(), py::arg("targetHeight"),
  530. py::arg("targetWidth"), py::arg("scaleLb") = RandomCropDecodeResizeOp::kDefScaleLb,
  531. py::arg("scaleUb") = RandomCropDecodeResizeOp::kDefScaleUb,
  532. py::arg("aspectLb") = RandomCropDecodeResizeOp::kDefAspectLb,
  533. py::arg("aspectUb") = RandomCropDecodeResizeOp::kDefAspectUb,
  534. py::arg("interpolation") = RandomCropDecodeResizeOp::kDefInterpolation,
  535. py::arg("maxIter") = RandomCropDecodeResizeOp::kDefMaxIter);
  536. (void)py::class_<PadOp, TensorOp, std::shared_ptr<PadOp>>(
  537. *m, "PadOp",
  538. "Pads image with specified color, default black, "
  539. "Takes amount to pad for top, bottom, left, right of image, boarder type and color")
  540. .def(py::init<int32_t, int32_t, int32_t, int32_t, BorderType, uint8_t, uint8_t, uint8_t>(), py::arg("padTop"),
  541. py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), py::arg("borderTypes") = PadOp::kDefBorderType,
  542. py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB);
  543. (void)py::class_<ToNumberOp, TensorOp, std::shared_ptr<ToNumberOp>>(*m, "ToNumberOp",
  544. "TensorOp to convert strings to numbers.")
  545. .def(py::init<DataType>(), py::arg("data_type"))
  546. .def(py::init<std::string>(), py::arg("data_type"));
  547. }
  548. void bindTokenizerOps(py::module *m) {
  549. (void)py::class_<JiebaTokenizerOp, TensorOp, std::shared_ptr<JiebaTokenizerOp>>(*m, "JiebaTokenizerOp", "")
  550. .def(py::init<const std::string &, const std::string &, const JiebaMode &, const bool &>(), py::arg("hmm_path"),
  551. py::arg("mp_path"), py::arg("mode") = JiebaMode::kMix,
  552. py::arg("with_offsets") = JiebaTokenizerOp::kDefWithOffsets)
  553. .def("add_word",
  554. [](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); });
  555. (void)py::class_<UnicodeCharTokenizerOp, TensorOp, std::shared_ptr<UnicodeCharTokenizerOp>>(
  556. *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.")
  557. .def(py::init<const bool &>(), py::arg("with_offsets") = UnicodeCharTokenizerOp::kDefWithOffsets);
  558. (void)py::class_<LookupOp, TensorOp, std::shared_ptr<LookupOp>>(*m, "LookupOp",
  559. "Tensor operation to LookUp each word.")
  560. .def(py::init([](std::shared_ptr<Vocab> vocab, const py::object &py_word) {
  561. if (vocab == nullptr) {
  562. THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "vocab object type is incorrect or null."));
  563. }
  564. if (py_word.is_none()) {
  565. return std::make_shared<LookupOp>(vocab, Vocab::kNoTokenExists);
  566. }
  567. std::string word = py::reinterpret_borrow<py::str>(py_word);
  568. WordIdType default_id = vocab->Lookup(word);
  569. if (default_id == Vocab::kNoTokenExists) {
  570. THROW_IF_ERROR(
  571. Status(StatusCode::kUnexpectedError, "default unknown token:" + word + " doesn't exist in vocab."));
  572. }
  573. return std::make_shared<LookupOp>(vocab, default_id);
  574. }));
  575. (void)py::class_<NgramOp, TensorOp, std::shared_ptr<NgramOp>>(*m, "NgramOp", "TensorOp performs ngram mapping.")
  576. .def(py::init<const std::vector<int32_t> &, int32_t, int32_t, const std::string &, const std::string &,
  577. const std::string &>(),
  578. py::arg("ngrams"), py::arg("l_pad_len"), py::arg("r_pad_len"), py::arg("l_pad_token"), py::arg("r_pad_token"),
  579. py::arg("separator"));
  580. (void)py::class_<WordpieceTokenizerOp, TensorOp, std::shared_ptr<WordpieceTokenizerOp>>(
  581. *m, "WordpieceTokenizerOp", "Tokenize scalar token or 1-D tokens to subword tokens.")
  582. .def(
  583. py::init<const std::shared_ptr<Vocab> &, const std::string &, const int &, const std::string &, const bool &>(),
  584. py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator),
  585. py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
  586. py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken),
  587. py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets);
  588. (void)py::class_<SlidingWindowOp, TensorOp, std::shared_ptr<SlidingWindowOp>>(
  589. *m, "SlidingWindowOp", "TensorOp to apply sliding window to a 1-D Tensor.")
  590. .def(py::init<uint32_t, int32_t>(), py::arg("width"), py::arg("axis"));
  591. }
  592. void bindDependIcuTokenizerOps(py::module *m) {
  593. #ifdef ENABLE_ICU4C
  594. (void)py::class_<WhitespaceTokenizerOp, TensorOp, std::shared_ptr<WhitespaceTokenizerOp>>(
  595. *m, "WhitespaceTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on ICU defined whitespaces.")
  596. .def(py::init<const bool &>(), py::arg("with_offsets") = WhitespaceTokenizerOp::kDefWithOffsets);
  597. (void)py::class_<UnicodeScriptTokenizerOp, TensorOp, std::shared_ptr<UnicodeScriptTokenizerOp>>(
  598. *m, "UnicodeScriptTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.")
  599. .def(py::init<>())
  600. .def(py::init<const bool &, const bool &>(),
  601. py::arg("keep_whitespace") = UnicodeScriptTokenizerOp::kDefKeepWhitespace,
  602. py::arg("with_offsets") = UnicodeScriptTokenizerOp::kDefWithOffsets);
  603. (void)py::class_<CaseFoldOp, TensorOp, std::shared_ptr<CaseFoldOp>>(
  604. *m, "CaseFoldOp", "Apply case fold operation on utf-8 string tensor")
  605. .def(py::init<>());
  606. (void)py::class_<NormalizeUTF8Op, TensorOp, std::shared_ptr<NormalizeUTF8Op>>(
  607. *m, "NormalizeUTF8Op", "Apply normalize operation on utf-8 string tensor.")
  608. .def(py::init<>())
  609. .def(py::init<NormalizeForm>(), py::arg("normalize_form") = NormalizeUTF8Op::kDefNormalizeForm);
  610. (void)py::class_<RegexReplaceOp, TensorOp, std::shared_ptr<RegexReplaceOp>>(
  611. *m, "RegexReplaceOp", "Replace utf-8 string tensor with 'replace' according to regular expression 'pattern'.")
  612. .def(py::init<const std::string &, const std::string &, bool>(), py::arg("pattern"), py::arg("replace"),
  613. py::arg("replace_all"));
  614. (void)py::class_<RegexTokenizerOp, TensorOp, std::shared_ptr<RegexTokenizerOp>>(
  615. *m, "RegexTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by regex expression pattern.")
  616. .def(py::init<const std::string &, const std::string &, const bool &>(), py::arg("delim_pattern"),
  617. py::arg("keep_delim_pattern"), py::arg("with_offsets") = RegexTokenizerOp::kDefWithOffsets);
  618. (void)py::class_<BasicTokenizerOp, TensorOp, std::shared_ptr<BasicTokenizerOp>>(
  619. *m, "BasicTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by specific rules.")
  620. .def(py::init<const bool &, const bool &, const NormalizeForm &, const bool &, const bool &>(),
  621. py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase,
  622. py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace,
  623. py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm,
  624. py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken,
  625. py::arg("with_offsets") = BasicTokenizerOp::kDefWithOffsets);
  626. (void)py::class_<BertTokenizerOp, TensorOp, std::shared_ptr<BertTokenizerOp>>(*m, "BertTokenizerOp",
  627. "Tokenizer used for Bert text process.")
  628. .def(py::init<const std::shared_ptr<Vocab> &, const std::string &, const int &, const std::string &, const bool &,
  629. const bool &, const NormalizeForm &, const bool &, const bool &>(),
  630. py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator),
  631. py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken,
  632. py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken),
  633. py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase,
  634. py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace,
  635. py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm,
  636. py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken,
  637. py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets);
  638. #endif
  639. }
  640. void bindSamplerOps(py::module *m) {
  641. (void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler")
  642. .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
  643. .def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
  644. .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
  645. .def("get_indices",
  646. [](Sampler &self) {
  647. py::array ret;
  648. THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
  649. return ret;
  650. })
  651. .def("add_child",
  652. [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); });
  653. (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator")
  654. .def("add_child", [](std::shared_ptr<mindrecord::ShardOperator> self,
  655. std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); });
  656. (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
  657. .def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>());
  658. (void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
  659. .def(py::init<int64_t, int64_t, bool>());
  660. (void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
  661. .def(py::init<int64_t, bool, bool>());
  662. (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
  663. .def(py::init<int64_t, int64_t>());
  664. (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
  665. .def(py::init<int64_t, std::vector<int64_t>>());
  666. (void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
  667. *m, "MindrecordSubsetRandomSampler")
  668. .def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
  669. (void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
  670. *m, "MindrecordPkSampler")
  671. .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
  672. if (shuffle == true) {
  673. return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
  674. GetSeed());
  675. } else {
  676. return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal);
  677. }
  678. }));
  679. (void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
  680. std::shared_ptr<mindrecord::ShardDistributedSample>>(*m, "MindrecordDistributedSampler")
  681. .def(py::init<int64_t, int64_t, bool, uint32_t>());
  682. (void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
  683. *m, "MindrecordRandomSampler")
  684. .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) {
  685. return std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples, replacement, reshuffle_each_epoch);
  686. }));
  687. (void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
  688. std::shared_ptr<mindrecord::ShardSequentialSample>>(*m, "MindrecordSequentialSampler")
  689. .def(py::init([](int num_samples, int start_index) {
  690. return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
  691. }));
  692. (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
  693. .def(py::init<int64_t, std::vector<double>, bool>());
  694. (void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
  695. .def(py::init<int64_t, py::object>());
  696. }
  697. void bindInfoObjects(py::module *m) {
  698. (void)py::class_<BatchOp::CBatchInfo>(*m, "CBatchInfo")
  699. .def(py::init<int64_t, int64_t, int64_t>())
  700. .def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num)
  701. .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
  702. }
  703. void bindCacheClient(py::module *m) {
  704. (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
  705. .def(py::init<uint32_t, uint64_t, bool>());
  706. }
  707. void bindVocabObjects(py::module *m) {
  708. (void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
  709. .def(py::init<>())
  710. .def_static("from_list",
  711. [](const py::list &words, const py::list &special_tokens, bool special_first) {
  712. std::shared_ptr<Vocab> v;
  713. THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v));
  714. return v;
  715. })
  716. .def_static("from_file",
  717. [](const std::string &path, const std::string &dlm, int32_t vocab_size, const py::list &special_tokens,
  718. bool special_first) {
  719. std::shared_ptr<Vocab> v;
  720. THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v));
  721. return v;
  722. })
  723. .def_static("from_dict", [](const py::dict &words) {
  724. std::shared_ptr<Vocab> v;
  725. THROW_IF_ERROR(Vocab::BuildFromPyDict(words, &v));
  726. return v;
  727. });
  728. }
  729. void bindGraphData(py::module *m) {
  730. (void)py::class_<gnn::Graph, std::shared_ptr<gnn::Graph>>(*m, "Graph")
  731. .def(py::init([](std::string dataset_file, int32_t num_workers) {
  732. std::shared_ptr<gnn::Graph> g_out = std::make_shared<gnn::Graph>(dataset_file, num_workers);
  733. THROW_IF_ERROR(g_out->Init());
  734. return g_out;
  735. }))
  736. .def("get_all_nodes",
  737. [](gnn::Graph &g, gnn::NodeType node_type) {
  738. std::shared_ptr<Tensor> out;
  739. THROW_IF_ERROR(g.GetAllNodes(node_type, &out));
  740. return out;
  741. })
  742. .def("get_all_edges",
  743. [](gnn::Graph &g, gnn::EdgeType edge_type) {
  744. std::shared_ptr<Tensor> out;
  745. THROW_IF_ERROR(g.GetAllEdges(edge_type, &out));
  746. return out;
  747. })
  748. .def("get_nodes_from_edges",
  749. [](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) {
  750. std::shared_ptr<Tensor> out;
  751. THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
  752. return out;
  753. })
  754. .def("get_all_neighbors",
  755. [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
  756. std::shared_ptr<Tensor> out;
  757. THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
  758. return out;
  759. })
  760. .def("get_sampled_neighbors",
  761. [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
  762. std::vector<gnn::NodeType> neighbor_types) {
  763. std::shared_ptr<Tensor> out;
  764. THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
  765. return out;
  766. })
  767. .def("get_neg_sampled_neighbors",
  768. [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
  769. gnn::NodeType neg_neighbor_type) {
  770. std::shared_ptr<Tensor> out;
  771. THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
  772. return out;
  773. })
  774. .def("get_node_feature",
  775. [](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
  776. TensorRow out;
  777. THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
  778. return out.getRow();
  779. })
  780. .def("get_edge_feature",
  781. [](gnn::Graph &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) {
  782. TensorRow out;
  783. THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out));
  784. return out.getRow();
  785. })
  786. .def("graph_info",
  787. [](gnn::Graph &g) {
  788. py::dict out;
  789. THROW_IF_ERROR(g.GraphInfo(&out));
  790. return out;
  791. })
  792. .def("random_walk", [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
  793. float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
  794. std::shared_ptr<Tensor> out;
  795. THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
  796. return out;
  797. });
  798. }
  799. // This is where we externalize the C logic as python modules
  800. PYBIND11_MODULE(_c_dataengine, m) {
  801. m.doc() = "pybind11 for _c_dataengine";
  802. (void)py::class_<DatasetOp, std::shared_ptr<DatasetOp>>(m, "DatasetOp");
  803. (void)py::enum_<OpName>(m, "OpName", py::arithmetic())
  804. .value("SHUFFLE", OpName::kShuffle)
  805. .value("BATCH", OpName::kBatch)
  806. .value("BUCKETBATCH", OpName::kBucketBatch)
  807. .value("BARRIER", OpName::kBarrier)
  808. .value("MINDRECORD", OpName::kMindrecord)
  809. .value("CACHE", OpName::kCache)
  810. .value("REPEAT", OpName::kRepeat)
  811. .value("SKIP", OpName::kSkip)
  812. .value("TAKE", OpName::kTake)
  813. .value("ZIP", OpName::kZip)
  814. .value("CONCAT", OpName::kConcat)
  815. .value("MAP", OpName::kMap)
  816. .value("FILTER", OpName::kFilter)
  817. .value("DEVICEQUEUE", OpName::kDeviceQueue)
  818. .value("GENERATOR", OpName::kGenerator)
  819. .export_values()
  820. .value("RENAME", OpName::kRename)
  821. .value("TFREADER", OpName::kTfReader)
  822. .value("PROJECT", OpName::kProject)
  823. .value("IMAGEFOLDER", OpName::kImageFolder)
  824. .value("MNIST", OpName::kMnist)
  825. .value("MANIFEST", OpName::kManifest)
  826. .value("VOC", OpName::kVoc)
  827. .value("COCO", OpName::kCoco)
  828. .value("CIFAR10", OpName::kCifar10)
  829. .value("CIFAR100", OpName::kCifar100)
  830. .value("RANDOMDATA", OpName::kRandomData)
  831. .value("BUILDVOCAB", OpName::kBuildVocab)
  832. .value("CELEBA", OpName::kCelebA)
  833. .value("TEXTFILE", OpName::kTextFile)
  834. .value("CLUE", OpName::kClue);
  835. (void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
  836. .value("DE_JIEBA_MIX", JiebaMode::kMix)
  837. .value("DE_JIEBA_MP", JiebaMode::kMp)
  838. .value("DE_JIEBA_HMM", JiebaMode::kHmm)
  839. .export_values();
  840. #ifdef ENABLE_ICU4C
  841. (void)py::enum_<NormalizeForm>(m, "NormalizeForm", py::arithmetic())
  842. .value("DE_NORMALIZE_NONE", NormalizeForm::kNone)
  843. .value("DE_NORMALIZE_NFC", NormalizeForm::kNfc)
  844. .value("DE_NORMALIZE_NFKC", NormalizeForm::kNfkc)
  845. .value("DE_NORMALIZE_NFD", NormalizeForm::kNfd)
  846. .value("DE_NORMALIZE_NFKD", NormalizeForm::kNfkd)
  847. .export_values();
  848. #endif
  849. (void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic())
  850. .value("DE_INTER_LINEAR", InterpolationMode::kLinear)
  851. .value("DE_INTER_CUBIC", InterpolationMode::kCubic)
  852. .value("DE_INTER_AREA", InterpolationMode::kArea)
  853. .value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour)
  854. .export_values();
  855. (void)py::enum_<BorderType>(m, "BorderType", py::arithmetic())
  856. .value("DE_BORDER_CONSTANT", BorderType::kConstant)
  857. .value("DE_BORDER_EDGE", BorderType::kEdge)
  858. .value("DE_BORDER_REFLECT", BorderType::kReflect)
  859. .value("DE_BORDER_SYMMETRIC", BorderType::kSymmetric)
  860. .export_values();
  861. bindDEPipeline(&m);
  862. bindTensor(&m);
  863. bindTensorOps1(&m);
  864. bindTensorOps2(&m);
  865. bindTensorOps3(&m);
  866. bindTensorOps4(&m);
  867. bindTokenizerOps(&m);
  868. bindSamplerOps(&m);
  869. bindDatasetOps(&m);
  870. bindInfoObjects(&m);
  871. bindCacheClient(&m);
  872. bindVocabObjects(&m);
  873. bindGraphData(&m);
  874. bindDependIcuTokenizerOps(&m);
  875. }
  876. } // namespace dataset
  877. } // namespace mindspore