You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

python_bindings.cc 31 kB

5 years ago
5 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <exception>
  17. #include "dataset/api/de_pipeline.h"
  18. #include "dataset/kernels/no_op.h"
  19. #include "dataset/kernels/data/one_hot_op.h"
  20. #include "dataset/kernels/image/center_crop_op.h"
  21. #include "dataset/kernels/image/cut_out_op.h"
  22. #include "dataset/kernels/image/decode_op.h"
  23. #include "dataset/kernels/image/hwc_to_chw_op.h"
  24. #include "dataset/kernels/image/image_utils.h"
  25. #include "dataset/kernels/image/normalize_op.h"
  26. #include "dataset/kernels/image/pad_op.h"
  27. #include "dataset/kernels/image/random_color_adjust_op.h"
  28. #include "dataset/kernels/image/random_crop_decode_resize_op.h"
  29. #include "dataset/kernels/image/random_crop_and_resize_op.h"
  30. #include "dataset/kernels/image/random_crop_op.h"
  31. #include "dataset/kernels/image/random_horizontal_flip_op.h"
  32. #include "dataset/kernels/image/random_resize_op.h"
  33. #include "dataset/kernels/image/random_rotation_op.h"
  34. #include "dataset/kernels/image/random_vertical_flip_op.h"
  35. #include "dataset/kernels/image/rescale_op.h"
  36. #include "dataset/kernels/image/resize_bilinear_op.h"
  37. #include "dataset/kernels/image/resize_op.h"
  38. #include "dataset/kernels/image/uniform_aug_op.h"
  39. #include "dataset/kernels/data/type_cast_op.h"
  40. #include "dataset/engine/datasetops/source/cifar_op.h"
  41. #include "dataset/engine/datasetops/source/image_folder_op.h"
  42. #include "dataset/engine/datasetops/source/io_block.h"
  43. #include "dataset/engine/datasetops/source/mnist_op.h"
  44. #include "dataset/engine/datasetops/source/manifest_op.h"
  45. #include "dataset/engine/datasetops/source/mindrecord_op.h"
  46. #include "dataset/engine/datasetops/source/random_data_op.h"
  47. #include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
  48. #include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
  49. #include "dataset/engine/datasetops/source/sampler/random_sampler.h"
  50. #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
  51. #include "dataset/engine/datasetops/source/sampler/subset_sampler.h"
  52. #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
  53. #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
  54. #include "dataset/engine/datasetops/source/sampler/python_sampler.h"
  55. #include "dataset/engine/datasetops/source/tf_reader_op.h"
  56. #include "dataset/engine/jagged_connector.h"
  57. #include "dataset/engine/datasetops/source/text_file_op.h"
  58. #include "dataset/engine/datasetops/source/voc_op.h"
  59. #include "dataset/engine/gnn/graph.h"
  60. #include "dataset/kernels/data/to_float16_op.h"
  61. #include "dataset/text/kernels/jieba_tokenizer_op.h"
  62. #include "dataset/text/kernels/unicode_char_tokenizer_op.h"
  63. #include "dataset/text/vocab.h"
  64. #include "dataset/text/kernels/lookup_op.h"
  65. #include "dataset/util/random.h"
  66. #include "mindrecord/include/shard_operator.h"
  67. #include "mindrecord/include/shard_pk_sample.h"
  68. #include "mindrecord/include/shard_sample.h"
  69. #include "pybind11/pybind11.h"
  70. #include "pybind11/stl.h"
  71. #include "pybind11/stl_bind.h"
  72. namespace py = pybind11;
  73. namespace mindspore {
  74. namespace dataset {
  // THROW_IF_ERROR(s): evaluate the Status expression `s` exactly once and, if it reports an error,
// throw a std::runtime_error carrying the Status' ToString() text. The binding lambdas in this file use
// it to turn Status-based failures into C++ exceptions.
// The temporary has a deliberately unusual name so it cannot capture a caller variable referenced in
// `s`: with a short name such as `rc`, THROW_IF_ERROR(rc) would expand to the self-initialising
// `Status rc = std::move(rc);` and read an uninitialised object.
#define THROW_IF_ERROR(s)                                          \
  do {                                                             \
    Status throw_if_error_status_ = std::move(s);                  \
    if (throw_if_error_status_.IsError()) {                        \
      throw std::runtime_error(throw_if_error_status_.ToString()); \
    }                                                              \
  } while (false)
  80. void bindDEPipeline(py::module *m) {
  81. (void)py::class_<DEPipeline>(*m, "DEPipeline")
  82. .def(py::init<>())
  83. .def(
  84. "AddNodeToTree",
  85. [](DEPipeline &de, const OpName &op_name, const py::dict &args) {
  86. DsOpPtr op;
  87. THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &op));
  88. return op;
  89. },
  90. py::return_value_policy::reference)
  91. .def_static("AddChildToParentNode",
  92. [](const DsOpPtr &child_op, const DsOpPtr &parent_op) {
  93. THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op));
  94. })
  95. .def("AssignRootNode",
  96. [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
  97. .def("SetBatchParameters",
  98. [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
  99. .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
  100. .def("GetNextAsMap",
  101. [](DEPipeline &de) {
  102. py::dict out;
  103. THROW_IF_ERROR(de.GetNextAsMap(&out));
  104. return out;
  105. })
  106. .def("GetNextAsList",
  107. [](DEPipeline &de) {
  108. py::list out;
  109. THROW_IF_ERROR(de.GetNextAsList(&out));
  110. return out;
  111. })
  112. .def("GetOutputShapes",
  113. [](DEPipeline &de) {
  114. py::list out;
  115. THROW_IF_ERROR(de.GetOutputShapes(&out));
  116. return out;
  117. })
  118. .def("GetOutputTypes",
  119. [](DEPipeline &de) {
  120. py::list out;
  121. THROW_IF_ERROR(de.GetOutputTypes(&out));
  122. return out;
  123. })
  124. .def("GetDatasetSize", &DEPipeline::GetDatasetSize)
  125. .def("GetBatchSize", &DEPipeline::GetBatchSize)
  126. .def("GetNumClasses", &DEPipeline::GetNumClasses)
  127. .def("GetRepeatCount", &DEPipeline::GetRepeatCount);
  128. }
  129. void bindDatasetOps(py::module *m) {
  130. (void)py::class_<TFReaderOp, DatasetOp, std::shared_ptr<TFReaderOp>>(*m, "TFReaderOp")
  131. .def_static("get_num_rows", [](const py::list &files, int64_t numParallelWorkers, bool estimate = false) {
  132. int64_t count = 0;
  133. std::vector<std::string> filenames;
  134. for (auto l : files) {
  135. !l.is_none() ? filenames.push_back(py::str(l)) : (void)filenames.emplace_back("");
  136. }
  137. THROW_IF_ERROR(TFReaderOp::CountTotalRows(&count, filenames, numParallelWorkers, estimate));
  138. return count;
  139. });
  140. (void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
  141. .def_static("get_num_rows", [](const std::string &dir, int64_t numSamples, bool isCifar10) {
  142. int64_t count = 0;
  143. THROW_IF_ERROR(CifarOp::CountTotalRows(dir, numSamples, isCifar10, &count));
  144. return count;
  145. });
  146. (void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
  147. .def_static("get_num_rows_and_classes", [](const std::string &path, int64_t numSamples) {
  148. int64_t count = 0, num_classes = 0;
  149. THROW_IF_ERROR(
  150. ImageFolderOp::CountRowsAndClasses(path, numSamples, std::set<std::string>{}, &count, &num_classes));
  151. return py::make_tuple(count, num_classes);
  152. });
  153. (void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
  154. .def_static("get_num_rows",
  155. [](const std::vector<std::string> &paths, bool load_dataset, const py::object &sampler) {
  156. int64_t count = 0;
  157. std::shared_ptr<mindrecord::ShardOperator> op;
  158. if (py::hasattr(sampler, "_create_for_minddataset")) {
  159. auto create = sampler.attr("_create_for_minddataset");
  160. op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
  161. }
  162. THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count));
  163. return count;
  164. });
  165. (void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
  166. .def_static("get_num_rows_and_classes",
  167. [](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
  168. int64_t count = 0, num_classes = 0;
  169. THROW_IF_ERROR(ManifestOp::CountTotalRows(file, numSamples, dict, usage, &count, &num_classes));
  170. return py::make_tuple(count, num_classes);
  171. })
  172. .def_static("get_class_indexing",
  173. [](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
  174. std::map<std::string, int32_t> output_class_indexing;
  175. THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, numSamples, dict, usage, &output_class_indexing));
  176. return output_class_indexing;
  177. });
  178. (void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
  179. .def_static("get_num_rows", [](const std::string &dir, int64_t numSamples) {
  180. int64_t count = 0;
  181. THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count));
  182. return count;
  183. });
  184. (void)py::class_<TextFileOp, DatasetOp, std::shared_ptr<TextFileOp>>(*m, "TextFileOp")
  185. .def_static("get_num_rows", [](const py::list &files) {
  186. int64_t count = 0;
  187. std::vector<std::string> filenames;
  188. for (auto file : files) {
  189. !file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back("");
  190. }
  191. THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
  192. return count;
  193. });
  194. (void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
  195. .def_static("get_num_rows",
  196. [](const std::string &dir, const std::string &task_type, const std::string &task_mode,
  197. const py::dict &dict, int64_t numSamples) {
  198. int64_t count = 0;
  199. THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count));
  200. return count;
  201. })
  202. .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
  203. const std::string &task_mode, const py::dict &dict, int64_t numSamples) {
  204. std::map<std::string, int32_t> output_class_indexing;
  205. THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, numSamples, &output_class_indexing));
  206. return output_class_indexing;
  207. });
  208. }
  209. void bindTensor(py::module *m) {
  210. (void)py::class_<GlobalContext>(*m, "GlobalContext")
  211. .def_static("config_manager", &GlobalContext::config_manager, py::return_value_policy::reference);
  212. (void)py::class_<ConfigManager, std::shared_ptr<ConfigManager>>(*m, "ConfigManager")
  213. .def("__str__", &ConfigManager::ToString)
  214. .def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer)
  215. .def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers)
  216. .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size)
  217. .def("set_op_connector_size", &ConfigManager::set_op_connector_size)
  218. .def("set_seed", &ConfigManager::set_seed)
  219. .def("get_rows_per_buffer", &ConfigManager::rows_per_buffer)
  220. .def("get_num_parallel_workers", &ConfigManager::num_parallel_workers)
  221. .def("get_worker_connector_size", &ConfigManager::worker_connector_size)
  222. .def("get_op_connector_size", &ConfigManager::op_connector_size)
  223. .def("get_seed", &ConfigManager::seed)
  224. .def("load", [](ConfigManager &c, std::string s) { (void)c.LoadFile(s); });
  225. (void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol())
  226. .def(py::init([](py::array arr) {
  227. std::shared_ptr<Tensor> out;
  228. THROW_IF_ERROR(Tensor::CreateTensor(&out, arr));
  229. return out;
  230. }))
  231. .def_buffer([](Tensor &tensor) {
  232. py::buffer_info info;
  233. THROW_IF_ERROR(Tensor::GetBufferInfo(tensor, &info));
  234. return info;
  235. })
  236. .def("__str__", &Tensor::ToString)
  237. .def("shape", &Tensor::shape)
  238. .def("type", &Tensor::type)
  239. .def("as_array", [](py::object &t) {
  240. auto &tensor = py::cast<Tensor &>(t);
  241. if (tensor.type() == DataType::DE_STRING) {
  242. py::array res;
  243. tensor.GetDataAsNumpyStrings(&res);
  244. return res;
  245. }
  246. py::buffer_info info;
  247. THROW_IF_ERROR(Tensor::GetBufferInfo(tensor, &info));
  248. return py::array(pybind11::dtype(info), info.shape, info.strides, info.ptr, t);
  249. });
  250. (void)py::class_<TensorShape>(*m, "TensorShape")
  251. .def(py::init<py::list>())
  252. .def("__str__", &TensorShape::ToString)
  253. .def("as_list", &TensorShape::AsPyList)
  254. .def("is_known", &TensorShape::known);
  255. (void)py::class_<DataType>(*m, "DataType")
  256. .def(py::init<std::string>())
  257. .def(py::self == py::self)
  258. .def("__str__", &DataType::ToString)
  259. .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
  260. }
  261. void bindTensorOps1(py::module *m) {
  262. (void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp")
  263. .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
  264. (void)py::class_<NormalizeOp, TensorOp, std::shared_ptr<NormalizeOp>>(
  265. *m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.")
  266. .def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"),
  267. py::arg("stdR"), py::arg("stdG"), py::arg("stdB"));
  268. (void)py::class_<RescaleOp, TensorOp, std::shared_ptr<RescaleOp>>(
  269. *m, "RescaleOp", "Tensor operation to rescale an image. Takes scale and shift.")
  270. .def(py::init<float, float>(), py::arg("rescale"), py::arg("shift"));
  271. (void)py::class_<CenterCropOp, TensorOp, std::shared_ptr<CenterCropOp>>(
  272. *m, "CenterCropOp", "Tensor operation to crop and image in the middle. Takes height and width (optional)")
  273. .def(py::init<int32_t, int32_t>(), py::arg("height"), py::arg("width") = CenterCropOp::kDefWidth);
  274. (void)py::class_<ResizeOp, TensorOp, std::shared_ptr<ResizeOp>>(
  275. *m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode")
  276. .def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"),
  277. py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation);
  278. (void)py::class_<UniformAugOp, TensorOp, std::shared_ptr<UniformAugOp>>(
  279. *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).")
  280. .def(py::init<std::vector<std::shared_ptr<TensorOp>>, int32_t>(), py::arg("operations"),
  281. py::arg("NumOps") = UniformAugOp::kDefNumOps);
  282. (void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>(
  283. *m, "ResizeBilinearOp",
  284. "Tensor operation to resize an image using "
  285. "Bilinear mode. Takes height and width.")
  286. .def(py::init<int32_t, int32_t>(), py::arg("targetHeight"), py::arg("targetWidth") = ResizeBilinearOp::kDefWidth);
  287. (void)py::class_<DecodeOp, TensorOp, std::shared_ptr<DecodeOp>>(*m, "DecodeOp",
  288. "Tensor operation to decode a jpg image")
  289. .def(py::init<>())
  290. .def(py::init<bool>(), py::arg("rgb_format") = DecodeOp::kDefRgbFormat);
  291. (void)py::class_<RandomHorizontalFlipOp, TensorOp, std::shared_ptr<RandomHorizontalFlipOp>>(
  292. *m, "RandomHorizontalFlipOp", "Tensor operation to randomly flip an image horizontally.")
  293. .def(py::init<float>(), py::arg("probability") = RandomHorizontalFlipOp::kDefProbability);
  294. }
  295. void bindTensorOps2(py::module *m) {
  296. (void)py::class_<RandomVerticalFlipOp, TensorOp, std::shared_ptr<RandomVerticalFlipOp>>(
  297. *m, "RandomVerticalFlipOp", "Tensor operation to randomly flip an image vertically.")
  298. .def(py::init<float>(), py::arg("probability") = RandomVerticalFlipOp::kDefProbability);
  299. (void)py::class_<RandomCropOp, TensorOp, std::shared_ptr<RandomCropOp>>(*m, "RandomCropOp",
  300. "Gives random crop of specified size "
  301. "Takes crop size")
  302. .def(py::init<int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, BorderType, bool, uint8_t, uint8_t, uint8_t>(),
  303. py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropOp::kDefPadTop,
  304. py::arg("padBottom") = RandomCropOp::kDefPadBottom, py::arg("padLeft") = RandomCropOp::kDefPadLeft,
  305. py::arg("padRight") = RandomCropOp::kDefPadRight, py::arg("borderType") = RandomCropOp::kDefBorderType,
  306. py::arg("padIfNeeded") = RandomCropOp::kDefPadIfNeeded, py::arg("fillR") = RandomCropOp::kDefFillR,
  307. py::arg("fillG") = RandomCropOp::kDefFillG, py::arg("fillB") = RandomCropOp::kDefFillB);
  308. (void)py::class_<HwcToChwOp, TensorOp, std::shared_ptr<HwcToChwOp>>(*m, "ChannelSwapOp").def(py::init<>());
  309. (void)py::class_<OneHotOp, TensorOp, std::shared_ptr<OneHotOp>>(
  310. *m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.")
  311. .def(py::init<int32_t>());
  312. (void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>(
  313. *m, "RandomRotationOp",
  314. "Tensor operation to apply RandomRotation."
  315. "Takes a range for degrees and "
  316. "optional parameters for rotation center and image expand")
  317. .def(py::init<float, float, float, float, InterpolationMode, bool, uint8_t, uint8_t, uint8_t>(),
  318. py::arg("startDegree"), py::arg("endDegree"), py::arg("centerX") = RandomRotationOp::kDefCenterX,
  319. py::arg("centerY") = RandomRotationOp::kDefCenterY,
  320. py::arg("interpolation") = RandomRotationOp::kDefInterpolation,
  321. py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR,
  322. py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB);
  323. }
  324. void bindTensorOps3(py::module *m) {
  325. (void)py::class_<RandomCropAndResizeOp, TensorOp, std::shared_ptr<RandomCropAndResizeOp>>(
  326. *m, "RandomCropAndResizeOp",
  327. "Tensor operation to randomly crop an image and resize to a given size."
  328. "Takes output height and width and"
  329. "optional parameters for lower and upper bound for aspect ratio (h/w) and scale,"
  330. "interpolation mode, and max attempts to crop")
  331. .def(py::init<int32_t, int32_t, float, float, float, float, InterpolationMode, int32_t>(), py::arg("targetHeight"),
  332. py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeOp::kDefScaleLb,
  333. py::arg("scaleUb") = RandomCropAndResizeOp::kDefScaleUb,
  334. py::arg("aspectLb") = RandomCropAndResizeOp::kDefAspectLb,
  335. py::arg("aspectUb") = RandomCropAndResizeOp::kDefAspectUb,
  336. py::arg("interpolation") = RandomCropAndResizeOp::kDefInterpolation,
  337. py::arg("maxIter") = RandomCropAndResizeOp::kDefMaxIter);
  338. (void)py::class_<RandomColorAdjustOp, TensorOp, std::shared_ptr<RandomColorAdjustOp>>(
  339. *m, "RandomColorAdjustOp",
  340. "Tensor operation to adjust an image's color randomly."
  341. "Takes range for brightness, contrast, saturation, hue and")
  342. .def(py::init<float, float, float, float, float, float, float, float>(), py::arg("bright_factor_start"),
  343. py::arg("bright_factor_end"), py::arg("contrast_factor_start"), py::arg("contrast_factor_end"),
  344. py::arg("saturation_factor_start"), py::arg("saturation_factor_end"), py::arg("hue_factor_start"),
  345. py::arg("hue_factor_end"));
  346. (void)py::class_<RandomResizeOp, TensorOp, std::shared_ptr<RandomResizeOp>>(
  347. *m, "RandomResizeOp",
  348. "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.")
  349. .def(py::init<int32_t, int32_t>(), py::arg("targetHeight"),
  350. py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth);
  351. (void)py::class_<CutOutOp, TensorOp, std::shared_ptr<CutOutOp>>(
  352. *m, "CutOutOp", "Tensor operation to randomly erase a portion of the image. Takes height and width.")
  353. .def(py::init<int32_t, int32_t, int32_t, bool, uint8_t, uint8_t, uint8_t>(), py::arg("boxHeight"),
  354. py::arg("boxWidth"), py::arg("numPatches"), py::arg("randomColor") = CutOutOp::kDefRandomColor,
  355. py::arg("fillR") = CutOutOp::kDefFillR, py::arg("fillG") = CutOutOp::kDefFillG,
  356. py::arg("fillB") = CutOutOp::kDefFillB);
  357. }
  358. void bindTensorOps4(py::module *m) {
  359. (void)py::class_<TypeCastOp, TensorOp, std::shared_ptr<TypeCastOp>>(
  360. *m, "TypeCastOp", "Tensor operator to type cast data to a specified type.")
  361. .def(py::init<DataType>(), py::arg("data_type"))
  362. .def(py::init<std::string>(), py::arg("data_type"));
  363. (void)py::class_<NoOp, TensorOp, std::shared_ptr<NoOp>>(*m, "NoOp",
  364. "TensorOp that does nothing, for testing purposes only.")
  365. .def(py::init<>());
  366. (void)py::class_<ToFloat16Op, TensorOp, std::shared_ptr<ToFloat16Op>>(
  367. *m, "ToFloat16Op", py::dynamic_attr(), "Tensor operator to type cast float32 data to a float16 type.")
  368. .def(py::init<>());
  369. (void)py::class_<RandomCropDecodeResizeOp, TensorOp, std::shared_ptr<RandomCropDecodeResizeOp>>(
  370. *m, "RandomCropDecodeResizeOp", "equivalent to RandomCropAndResize but crops before decoding")
  371. .def(py::init<int32_t, int32_t, float, float, float, float, InterpolationMode, int32_t>(), py::arg("targetHeight"),
  372. py::arg("targetWidth"), py::arg("scaleLb") = RandomCropDecodeResizeOp::kDefScaleLb,
  373. py::arg("scaleUb") = RandomCropDecodeResizeOp::kDefScaleUb,
  374. py::arg("aspectLb") = RandomCropDecodeResizeOp::kDefAspectLb,
  375. py::arg("aspectUb") = RandomCropDecodeResizeOp::kDefAspectUb,
  376. py::arg("interpolation") = RandomCropDecodeResizeOp::kDefInterpolation,
  377. py::arg("maxIter") = RandomCropDecodeResizeOp::kDefMaxIter);
  378. (void)py::class_<PadOp, TensorOp, std::shared_ptr<PadOp>>(
  379. *m, "PadOp",
  380. "Pads image with specified color, default black, "
  381. "Takes amount to pad for top, bottom, left, right of image, boarder type and color")
  382. .def(py::init<int32_t, int32_t, int32_t, int32_t, BorderType, uint8_t, uint8_t, uint8_t>(), py::arg("padTop"),
  383. py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), py::arg("borderTypes") = PadOp::kDefBorderType,
  384. py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB);
  385. }
  386. void bindTensorOps5(py::module *m) {
  387. (void)py::class_<JiebaTokenizerOp, TensorOp, std::shared_ptr<JiebaTokenizerOp>>(*m, "JiebaTokenizerOp", "")
  388. .def(py::init<const std::string, std::string, JiebaMode>(), py::arg("hmm_path"), py::arg("mp_path"),
  389. py::arg("mode") = JiebaMode::kMix)
  390. .def("add_word",
  391. [](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); });
  392. (void)py::class_<UnicodeCharTokenizerOp, TensorOp, std::shared_ptr<UnicodeCharTokenizerOp>>(
  393. *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.")
  394. .def(py::init<>());
  395. (void)py::class_<LookupOp, TensorOp, std::shared_ptr<LookupOp>>(*m, "LookupOp",
  396. "Tensor operation to LookUp each word")
  397. .def(py::init<std::shared_ptr<Vocab>, WordIdType>(), py::arg("vocab"), py::arg("unknown"))
  398. .def(py::init<std::shared_ptr<Vocab>>(), py::arg("vocab"));
  399. }
  400. void bindSamplerOps(py::module *m) {
  401. (void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler")
  402. .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
  403. .def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
  404. .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
  405. .def("get_indices",
  406. [](Sampler &self) {
  407. py::array ret;
  408. THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
  409. return ret;
  410. })
  411. .def("add_child",
  412. [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) { THROW_IF_ERROR(self->AddChild(child)); });
  413. (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
  414. (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
  415. .def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"),
  416. py::arg("seed"));
  417. (void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
  418. .def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));
  419. (void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
  420. .def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"),
  421. py::arg("num_samples"))
  422. .def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"));
  423. (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
  424. .def(py::init<>());
  425. (void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler")
  426. .def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size"));
  427. (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
  428. .def(py::init<std::vector<int64_t>>(), py::arg("indices"));
  429. (void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
  430. *m, "MindrecordSubsetRandomSampler")
  431. .def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
  432. (void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
  433. *m, "MindrecordPkSampler")
  434. .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
  435. if (shuffle == true) {
  436. return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
  437. GetSeed());
  438. } else {
  439. return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal);
  440. }
  441. }));
  442. (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
  443. .def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
  444. py::arg("replacement"));
  445. (void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
  446. .def(py::init<py::object>(), py::arg("pySampler"));
  447. }
  448. void bindInfoObjects(py::module *m) {
  449. (void)py::class_<BatchOp::CBatchInfo>(*m, "CBatchInfo")
  450. .def(py::init<int64_t, int64_t, int64_t>())
  451. .def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num)
  452. .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
  453. }
  454. void bindVocabObjects(py::module *m) {
  455. (void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
  456. .def_static("from_list",
  457. [](const py::list &words) {
  458. std::shared_ptr<Vocab> v;
  459. THROW_IF_ERROR(Vocab::BuildFromPyList(words, &v));
  460. return v;
  461. })
  462. .def_static("from_file",
  463. [](const std::string &path, const std::string &dlm, int32_t vocab_size) {
  464. std::shared_ptr<Vocab> v;
  465. THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, &v));
  466. return v;
  467. })
  468. .def_static("from_dict", [](const py::dict &words) {
  469. std::shared_ptr<Vocab> v;
  470. THROW_IF_ERROR(Vocab::BuildFromPyDict(words, &v));
  471. return v;
  472. });
  473. }
  474. void bindGraphData(py::module *m) {
  475. (void)py::class_<gnn::Graph, std::shared_ptr<gnn::Graph>>(*m, "Graph")
  476. .def(py::init([](std::string dataset_file, int32_t num_workers) {
  477. std::shared_ptr<gnn::Graph> g_out = std::make_shared<gnn::Graph>(dataset_file, num_workers);
  478. THROW_IF_ERROR(g_out->Init());
  479. return g_out;
  480. }))
  481. .def("get_nodes",
  482. [](gnn::Graph &g, gnn::NodeType node_type, gnn::NodeIdType node_num) {
  483. std::shared_ptr<Tensor> out;
  484. THROW_IF_ERROR(g.GetNodes(node_type, node_num, &out));
  485. return out;
  486. })
  487. .def("get_all_neighbors",
  488. [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
  489. std::shared_ptr<Tensor> out;
  490. THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
  491. return out;
  492. })
  493. .def("get_node_feature",
  494. [](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
  495. TensorRow out;
  496. THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
  497. return out;
  498. });
  499. }
  500. // This is where we externalize the C logic as python modules
  501. PYBIND11_MODULE(_c_dataengine, m) {
  502. m.doc() = "pybind11 for _c_dataengine";
  503. (void)py::class_<DatasetOp, std::shared_ptr<DatasetOp>>(m, "DatasetOp");
  504. (void)py::enum_<OpName>(m, "OpName", py::arithmetic())
  505. .value("STORAGE", OpName::kStorage)
  506. .value("SHUFFLE", OpName::kShuffle)
  507. .value("BATCH", OpName::kBatch)
  508. .value("BARRIER", OpName::kBarrier)
  509. .value("MINDRECORD", OpName::kMindrecord)
  510. .value("CACHE", OpName::kCache)
  511. .value("REPEAT", OpName::kRepeat)
  512. .value("SKIP", OpName::kSkip)
  513. .value("TAKE", OpName::kTake)
  514. .value("ZIP", OpName::kZip)
  515. .value("CONCAT", OpName::kConcat)
  516. .value("MAP", OpName::kMap)
  517. .value("FILTER", OpName::kFilter)
  518. .value("DEVICEQUEUE", OpName::kDeviceQueue)
  519. .value("GENERATOR", OpName::kGenerator)
  520. .export_values()
  521. .value("RENAME", OpName::kRename)
  522. .value("TFREADER", OpName::kTfReader)
  523. .value("PROJECT", OpName::kProject)
  524. .value("IMAGEFOLDER", OpName::kImageFolder)
  525. .value("MNIST", OpName::kMnist)
  526. .value("MANIFEST", OpName::kManifest)
  527. .value("VOC", OpName::kVoc)
  528. .value("CIFAR10", OpName::kCifar10)
  529. .value("CIFAR100", OpName::kCifar100)
  530. .value("RANDOMDATA", OpName::kRandomData)
  531. .value("CELEBA", OpName::kCelebA)
  532. .value("TEXTFILE", OpName::kTextFile);
  533. (void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
  534. .value("DE_JIEBA_MIX", JiebaMode::kMix)
  535. .value("DE_JIEBA_MP", JiebaMode::kMp)
  536. .value("DE_JIEBA_HMM", JiebaMode::kHmm)
  537. .export_values();
  538. (void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic())
  539. .value("DE_INTER_LINEAR", InterpolationMode::kLinear)
  540. .value("DE_INTER_CUBIC", InterpolationMode::kCubic)
  541. .value("DE_INTER_AREA", InterpolationMode::kArea)
  542. .value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour)
  543. .export_values();
  544. (void)py::enum_<BorderType>(m, "BorderType", py::arithmetic())
  545. .value("DE_BORDER_CONSTANT", BorderType::kConstant)
  546. .value("DE_BORDER_EDGE", BorderType::kEdge)
  547. .value("DE_BORDER_REFLECT", BorderType::kReflect)
  548. .value("DE_BORDER_SYMMETRIC", BorderType::kSymmetric)
  549. .export_values();
  550. bindDEPipeline(&m);
  551. bindTensor(&m);
  552. bindTensorOps1(&m);
  553. bindTensorOps2(&m);
  554. bindTensorOps3(&m);
  555. bindTensorOps4(&m);
  556. bindTensorOps5(&m);
  557. bindSamplerOps(&m);
  558. bindDatasetOps(&m);
  559. bindInfoObjects(&m);
  560. bindVocabObjects(&m);
  561. bindGraphData(&m);
  562. }
  563. } // namespace dataset
  564. } // namespace mindspore