You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

python_bindings.cc 27 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <exception>
  17. #include "dataset/api/de_pipeline.h"
  18. #include "dataset/kernels/no_op.h"
  19. #include "dataset/kernels/data/one_hot_op.h"
  20. #include "dataset/kernels/image/center_crop_op.h"
  21. #if !defined(_WIN32) && !defined(_WIN64)
  22. #include "dataset/kernels/image/change_mode_op.h"
  23. #endif
  24. #include "dataset/kernels/image/cut_out_op.h"
  25. #include "dataset/kernels/image/decode_op.h"
  26. #include "dataset/kernels/image/distort_bounding_box_crop_op.h"
  27. #include "dataset/kernels/image/hwc_to_chw_op.h"
  28. #include "dataset/kernels/image/image_utils.h"
  29. #include "dataset/kernels/image/normalize_op.h"
  30. #include "dataset/kernels/image/pad_op.h"
  31. #include "dataset/kernels/image/random_color_adjust_op.h"
  32. #include "dataset/kernels/image/random_crop_decode_resize_op.h"
  33. #include "dataset/kernels/image/random_crop_and_resize_op.h"
  34. #include "dataset/kernels/image/random_crop_op.h"
  35. #include "dataset/kernels/image/random_horizontal_flip_op.h"
  36. #include "dataset/kernels/image/random_resize_op.h"
  37. #include "dataset/kernels/image/random_rotation_op.h"
  38. #include "dataset/kernels/image/random_vertical_flip_op.h"
  39. #include "dataset/kernels/image/rescale_op.h"
  40. #include "dataset/kernels/image/resize_bilinear_op.h"
  41. #include "dataset/kernels/image/resize_op.h"
  42. #include "dataset/kernels/image/uniform_aug_op.h"
  43. #include "dataset/kernels/data/type_cast_op.h"
  44. #include "dataset/engine/datasetops/source/cifar_op.h"
  45. #include "dataset/engine/datasetops/source/image_folder_op.h"
  46. #include "dataset/engine/datasetops/source/io_block.h"
  47. #include "dataset/engine/datasetops/source/mnist_op.h"
  48. #include "dataset/engine/datasetops/source/manifest_op.h"
  49. #include "dataset/engine/datasetops/source/mindrecord_op.h"
  50. #include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
  51. #include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
  52. #include "dataset/engine/datasetops/source/sampler/random_sampler.h"
  53. #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
  54. #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
  55. #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
  56. #include "dataset/engine/datasetops/source/sampler/python_sampler.h"
  57. #include "dataset/engine/datasetops/source/tf_reader_op.h"
  58. #include "dataset/engine/jagged_connector.h"
  59. #include "dataset/engine/datasetops/source/text_file_op.h"
  60. #include "dataset/kernels/data/to_float16_op.h"
  61. #include "dataset/util/random.h"
  62. #include "mindrecord/include/shard_operator.h"
  63. #include "mindrecord/include/shard_pk_sample.h"
  64. #include "mindrecord/include/shard_sample.h"
  65. #include "pybind11/pybind11.h"
  66. #include "pybind11/stl.h"
  67. #include "pybind11/stl_bind.h"
  68. namespace py = pybind11;
  69. namespace mindspore {
  70. namespace dataset {
  // Evaluates a Status-producing expression exactly once and, if the resulting Status
// reports an error, throws std::runtime_error carrying Status::ToString(). pybind11
// translates std::runtime_error into a Python RuntimeError.
// The local uses a distinctive name so the macro cannot shadow a caller variable:
// with a plain `rc`, THROW_IF_ERROR(Foo(rc)) would read the macro's own,
// not-yet-initialized local instead of the caller's `rc`.
#define THROW_IF_ERROR(s) \
  do { \
    Status throw_if_error_status_ = std::move(s); \
    if (throw_if_error_status_.IsError()) { \
      throw std::runtime_error(throw_if_error_status_.ToString()); \
    } \
  } while (false)
  76. void bindDEPipeline(py::module *m) {
  77. (void)py::class_<DEPipeline>(*m, "DEPipeline")
  78. .def(py::init<>())
  79. .def(
  80. "AddNodeToTree",
  81. [](DEPipeline &de, const OpName &op_name, const py::dict &args) {
  82. DsOpPtr op;
  83. THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &op));
  84. return op;
  85. },
  86. py::return_value_policy::reference)
  87. .def_static("AddChildToParentNode",
  88. [](const DsOpPtr &child_op, const DsOpPtr &parent_op) {
  89. THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op));
  90. })
  91. .def("AssignRootNode",
  92. [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
  93. .def("SetBatchParameters",
  94. [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
  95. .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
  96. .def("GetNextAsMap",
  97. [](DEPipeline &de) {
  98. py::dict out;
  99. THROW_IF_ERROR(de.GetNextAsMap(&out));
  100. return out;
  101. })
  102. .def("GetNextAsList",
  103. [](DEPipeline &de) {
  104. py::list out;
  105. THROW_IF_ERROR(de.GetNextAsList(&out));
  106. return out;
  107. })
  108. .def("GetOutputShapes",
  109. [](DEPipeline &de) {
  110. py::list out;
  111. THROW_IF_ERROR(de.GetOutputShapes(&out));
  112. return out;
  113. })
  114. .def("GetOutputTypes",
  115. [](DEPipeline &de) {
  116. py::list out;
  117. THROW_IF_ERROR(de.GetOutputTypes(&out));
  118. return out;
  119. })
  120. .def("GetDatasetSize", &DEPipeline::GetDatasetSize)
  121. .def("GetBatchSize", &DEPipeline::GetBatchSize)
  122. .def("GetNumClasses", &DEPipeline::GetNumClasses)
  123. .def("GetRepeatCount", &DEPipeline::GetRepeatCount);
  124. }
  125. void bindDatasetOps(py::module *m) {
  126. (void)py::class_<TFReaderOp, DatasetOp, std::shared_ptr<TFReaderOp>>(*m, "TFReaderOp")
  127. .def_static("get_num_rows", [](const py::list &files, int64_t numParallelWorkers, bool estimate = false) {
  128. int64_t count = 0;
  129. std::vector<std::string> filenames;
  130. for (auto l : files) {
  131. !l.is_none() ? filenames.push_back(py::str(l)) : (void)filenames.emplace_back("");
  132. }
  133. THROW_IF_ERROR(TFReaderOp::CountTotalRows(&count, filenames, numParallelWorkers, estimate));
  134. return count;
  135. });
  136. (void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
  137. .def_static("get_num_rows", [](const std::string &dir, int64_t numSamples, bool isCifar10) {
  138. int64_t count = 0;
  139. THROW_IF_ERROR(CifarOp::CountTotalRows(dir, numSamples, isCifar10, &count));
  140. return count;
  141. });
  142. (void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
  143. .def_static("get_num_rows_and_classes", [](const std::string &path, int64_t numSamples) {
  144. int64_t count = 0, num_classes = 0;
  145. THROW_IF_ERROR(
  146. ImageFolderOp::CountRowsAndClasses(path, numSamples, std::set<std::string>{}, &count, &num_classes));
  147. return py::make_tuple(count, num_classes);
  148. });
  149. (void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
  150. .def_static("get_num_rows", [](const std::string &path, const py::object &sampler) {
  151. int64_t count = 0;
  152. std::shared_ptr<mindrecord::ShardOperator> op;
  153. if (py::hasattr(sampler, "_create_for_minddataset")) {
  154. auto create = sampler.attr("_create_for_minddataset");
  155. op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
  156. }
  157. THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, op, &count));
  158. return count;
  159. });
  160. (void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
  161. .def_static("get_num_rows_and_classes",
  162. [](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
  163. int64_t count = 0, num_classes = 0;
  164. THROW_IF_ERROR(ManifestOp::CountTotalRows(file, numSamples, dict, usage, &count, &num_classes));
  165. return py::make_tuple(count, num_classes);
  166. })
  167. .def_static("get_class_indexing",
  168. [](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
  169. std::map<std::string, int32_t> output_class_indexing;
  170. THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, numSamples, dict, usage, &output_class_indexing));
  171. return output_class_indexing;
  172. });
  173. (void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
  174. .def_static("get_num_rows", [](const std::string &dir, int64_t numSamples) {
  175. int64_t count = 0;
  176. THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count));
  177. return count;
  178. });
  179. (void)py::class_<TextFileOp, DatasetOp, std::shared_ptr<TextFileOp>>(*m, "TextFileOp")
  180. .def_static("get_num_rows", [](const py::list &files) {
  181. int64_t count = 0;
  182. std::vector<std::string> filenames;
  183. for (auto file : files) {
  184. !file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back("");
  185. }
  186. THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
  187. return count;
  188. });
  189. }
  190. void bindTensor(py::module *m) {
  191. (void)py::class_<GlobalContext>(*m, "GlobalContext")
  192. .def_static("config_manager", &GlobalContext::config_manager, py::return_value_policy::reference);
  193. (void)py::class_<ConfigManager, std::shared_ptr<ConfigManager>>(*m, "ConfigManager")
  194. .def("__str__", &ConfigManager::ToString)
  195. .def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer)
  196. .def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers)
  197. .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size)
  198. .def("set_op_connector_size", &ConfigManager::set_op_connector_size)
  199. .def("set_seed", &ConfigManager::set_seed)
  200. .def("get_rows_per_buffer", &ConfigManager::rows_per_buffer)
  201. .def("get_num_parallel_workers", &ConfigManager::num_parallel_workers)
  202. .def("get_worker_connector_size", &ConfigManager::worker_connector_size)
  203. .def("get_op_connector_size", &ConfigManager::op_connector_size)
  204. .def("get_seed", &ConfigManager::seed)
  205. .def("load", [](ConfigManager &c, std::string s) { (void)c.LoadFile(s); });
  206. (void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol())
  207. .def(py::init([](py::array arr) {
  208. std::shared_ptr<Tensor> out;
  209. THROW_IF_ERROR(Tensor::CreateTensor(&out, arr));
  210. return out;
  211. }))
  212. .def_buffer([](Tensor &tensor) {
  213. py::buffer_info info;
  214. THROW_IF_ERROR(Tensor::GetBufferInfo(tensor, &info));
  215. return info;
  216. })
  217. .def("__str__", &Tensor::ToString)
  218. .def("shape", &Tensor::shape)
  219. .def("type", &Tensor::type)
  220. .def("as_array", [](py::object &t) {
  221. auto &tensor = py::cast<Tensor &>(t);
  222. py::buffer_info info;
  223. THROW_IF_ERROR(Tensor::GetBufferInfo(tensor, &info));
  224. return py::array(pybind11::dtype(info), info.shape, info.strides, info.ptr, t);
  225. });
  226. (void)py::class_<TensorShape>(*m, "TensorShape")
  227. .def(py::init<py::list>())
  228. .def("__str__", &TensorShape::ToString)
  229. .def("as_list", &TensorShape::AsPyList)
  230. .def("is_known", &TensorShape::known);
  231. (void)py::class_<DataType>(*m, "DataType")
  232. .def(py::init<std::string>())
  233. .def(py::self == py::self)
  234. .def("__str__", &DataType::ToString)
  235. .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
  236. }
  237. void bindTensorOps1(py::module *m) {
  238. (void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp")
  239. .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
  240. (void)py::class_<NormalizeOp, TensorOp, std::shared_ptr<NormalizeOp>>(
  241. *m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.")
  242. .def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"),
  243. py::arg("stdR"), py::arg("stdG"), py::arg("stdB"));
  244. (void)py::class_<RescaleOp, TensorOp, std::shared_ptr<RescaleOp>>(
  245. *m, "RescaleOp", "Tensor operation to rescale an image. Takes scale and shift.")
  246. .def(py::init<float, float>(), py::arg("rescale"), py::arg("shift"));
  247. (void)py::class_<CenterCropOp, TensorOp, std::shared_ptr<CenterCropOp>>(
  248. *m, "CenterCropOp", "Tensor operation to crop and image in the middle. Takes height and width (optional)")
  249. .def(py::init<int32_t, int32_t>(), py::arg("height"), py::arg("width") = CenterCropOp::kDefWidth);
  250. (void)py::class_<ResizeOp, TensorOp, std::shared_ptr<ResizeOp>>(
  251. *m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode")
  252. .def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"),
  253. py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation);
  254. (void)py::class_<UniformAugOp, TensorOp, std::shared_ptr<UniformAugOp>>(
  255. *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).")
  256. .def(py::init<py::list, int32_t>(), py::arg("operations"), py::arg("NumOps") = UniformAugOp::kDefNumOps);
  257. (void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>(
  258. *m, "ResizeBilinearOp",
  259. "Tensor operation to resize an image using "
  260. "Bilinear mode. Takes height and width.")
  261. .def(py::init<int32_t, int32_t>(), py::arg("targetHeight"), py::arg("targetWidth") = ResizeBilinearOp::kDefWidth);
  262. (void)py::class_<DecodeOp, TensorOp, std::shared_ptr<DecodeOp>>(*m, "DecodeOp",
  263. "Tensor operation to decode a jpg image")
  264. .def(py::init<>())
  265. .def(py::init<bool>(), py::arg("rgb_format") = DecodeOp::kDefRgbFormat);
  266. (void)py::class_<RandomHorizontalFlipOp, TensorOp, std::shared_ptr<RandomHorizontalFlipOp>>(
  267. *m, "RandomHorizontalFlipOp", "Tensor operation to randomly flip an image horizontally.")
  268. .def(py::init<float>(), py::arg("probability") = RandomHorizontalFlipOp::kDefProbability);
  269. }
  270. void bindTensorOps2(py::module *m) {
  271. (void)py::class_<RandomVerticalFlipOp, TensorOp, std::shared_ptr<RandomVerticalFlipOp>>(
  272. *m, "RandomVerticalFlipOp", "Tensor operation to randomly flip an image vertically.")
  273. .def(py::init<float>(), py::arg("probability") = RandomVerticalFlipOp::kDefProbability);
  274. (void)py::class_<RandomCropOp, TensorOp, std::shared_ptr<RandomCropOp>>(*m, "RandomCropOp",
  275. "Gives random crop of specified size "
  276. "Takes crop size")
  277. .def(py::init<int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, BorderType, bool, uint8_t, uint8_t, uint8_t>(),
  278. py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropOp::kDefPadTop,
  279. py::arg("padBottom") = RandomCropOp::kDefPadBottom, py::arg("padLeft") = RandomCropOp::kDefPadLeft,
  280. py::arg("padRight") = RandomCropOp::kDefPadRight, py::arg("borderType") = RandomCropOp::kDefBorderType,
  281. py::arg("padIfNeeded") = RandomCropOp::kDefPadIfNeeded, py::arg("fillR") = RandomCropOp::kDefFillR,
  282. py::arg("fillG") = RandomCropOp::kDefFillG, py::arg("fillB") = RandomCropOp::kDefFillB);
  283. (void)py::class_<HwcToChwOp, TensorOp, std::shared_ptr<HwcToChwOp>>(*m, "ChannelSwapOp").def(py::init<>());
  284. #if !defined(_WIN32) && !defined(_WIN64)
  285. (void)py::class_<ChangeModeOp, TensorOp, std::shared_ptr<ChangeModeOp>>(
  286. *m, "ChangeModeOp", "Tensor operation to change colors from BGR to RGB")
  287. .def(py::init<>());
  288. #endif
  289. (void)py::class_<OneHotOp, TensorOp, std::shared_ptr<OneHotOp>>(
  290. *m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.")
  291. .def(py::init<int32_t>());
  292. (void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>(
  293. *m, "RandomRotationOp",
  294. "Tensor operation to apply RandomRotation."
  295. "Takes a range for degrees and "
  296. "optional parameters for rotation center and image expand")
  297. .def(py::init<float, float, float, float, InterpolationMode, bool, uint8_t, uint8_t, uint8_t>(),
  298. py::arg("startDegree"), py::arg("endDegree"), py::arg("centerX") = RandomRotationOp::kDefCenterX,
  299. py::arg("centerY") = RandomRotationOp::kDefCenterY,
  300. py::arg("interpolation") = RandomRotationOp::kDefInterpolation,
  301. py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR,
  302. py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB);
  303. }
  304. void bindTensorOps3(py::module *m) {
  305. (void)py::class_<RandomCropAndResizeOp, TensorOp, std::shared_ptr<RandomCropAndResizeOp>>(
  306. *m, "RandomCropAndResizeOp",
  307. "Tensor operation to randomly crop an image and resize to a given size."
  308. "Takes output height and width and"
  309. "optional parameters for lower and upper bound for aspect ratio (h/w) and scale,"
  310. "interpolation mode, and max attempts to crop")
  311. .def(py::init<int32_t, int32_t, float, float, float, float, InterpolationMode, int32_t>(), py::arg("targetHeight"),
  312. py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeOp::kDefScaleLb,
  313. py::arg("scaleUb") = RandomCropAndResizeOp::kDefScaleUb,
  314. py::arg("aspectLb") = RandomCropAndResizeOp::kDefAspectLb,
  315. py::arg("aspectUb") = RandomCropAndResizeOp::kDefAspectUb,
  316. py::arg("interpolation") = RandomCropAndResizeOp::kDefInterpolation,
  317. py::arg("maxIter") = RandomCropAndResizeOp::kDefMaxIter);
  318. (void)py::class_<RandomColorAdjustOp, TensorOp, std::shared_ptr<RandomColorAdjustOp>>(
  319. *m, "RandomColorAdjustOp",
  320. "Tensor operation to adjust an image's color randomly."
  321. "Takes range for brightness, contrast, saturation, hue and")
  322. .def(py::init<float, float, float, float, float, float, float, float>(), py::arg("bright_factor_start"),
  323. py::arg("bright_factor_end"), py::arg("contrast_factor_start"), py::arg("contrast_factor_end"),
  324. py::arg("saturation_factor_start"), py::arg("saturation_factor_end"), py::arg("hue_factor_start"),
  325. py::arg("hue_factor_end"));
  326. (void)py::class_<RandomResizeOp, TensorOp, std::shared_ptr<RandomResizeOp>>(
  327. *m, "RandomResizeOp",
  328. "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.")
  329. .def(py::init<int32_t, int32_t>(), py::arg("targetHeight"),
  330. py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth);
  331. (void)py::class_<CutOutOp, TensorOp, std::shared_ptr<CutOutOp>>(
  332. *m, "CutOutOp", "Tensor operation to randomly erase a portion of the image. Takes height and width.")
  333. .def(py::init<int32_t, int32_t, int32_t, bool, uint8_t, uint8_t, uint8_t>(), py::arg("boxHeight"),
  334. py::arg("boxWidth"), py::arg("numPatches"), py::arg("randomColor") = CutOutOp::kDefRandomColor,
  335. py::arg("fillR") = CutOutOp::kDefFillR, py::arg("fillG") = CutOutOp::kDefFillG,
  336. py::arg("fillB") = CutOutOp::kDefFillB);
  337. }
  338. void bindTensorOps4(py::module *m) {
  339. (void)py::class_<DistortBoundingBoxCropOp, TensorOp, std::shared_ptr<DistortBoundingBoxCropOp>>(
  340. *m, "DistortBoundingBoxCropOp",
  341. "Tensor operator to crop an image randomly as long as the cropped image has sufficient "
  342. "overlap with any one bounding box associated with original image"
  343. "Takes aspect ratio of the generated crop box, the intersection ratio of crop box and bounding box,"
  344. "crop ratio lower and upper bounds"
  345. "Optional parameters: number of attempts for crop, number of attempts of crop box generation")
  346. .def(py::init<float, float, float, float, int32_t, int32_t>(), py::arg("aspect_ratio"), py::arg("intersect_ratio"),
  347. py::arg("crop_ratio_lower_bound"), py::arg("crop_ratio_upper_bound"),
  348. py::arg("max_attempts") = DistortBoundingBoxCropOp::kDefMaxAttempts,
  349. py::arg("box_gen_attempts") = DistortBoundingBoxCropOp::kDefBoxGenAttempts);
  350. (void)py::class_<TypeCastOp, TensorOp, std::shared_ptr<TypeCastOp>>(
  351. *m, "TypeCastOp", "Tensor operator to type cast data to a specified type.")
  352. .def(py::init<DataType>(), py::arg("data_type"))
  353. .def(py::init<std::string>(), py::arg("data_type"));
  354. (void)py::class_<NoOp, TensorOp, std::shared_ptr<NoOp>>(*m, "NoOp",
  355. "TensorOp that does nothing, for testing purposes only.")
  356. .def(py::init<>());
  357. (void)py::class_<ToFloat16Op, TensorOp, std::shared_ptr<ToFloat16Op>>(
  358. *m, "ToFloat16Op", py::dynamic_attr(), "Tensor operator to type cast float32 data to a float16 type.")
  359. .def(py::init<>());
  360. (void)py::class_<RandomCropDecodeResizeOp, TensorOp, std::shared_ptr<RandomCropDecodeResizeOp>>(
  361. *m, "RandomCropDecodeResizeOp", "equivalent to RandomCropAndResize but crops before decoding")
  362. .def(py::init<int32_t, int32_t, float, float, float, float, InterpolationMode, int32_t>(), py::arg("targetHeight"),
  363. py::arg("targetWidth"), py::arg("scaleLb") = RandomCropDecodeResizeOp::kDefScaleLb,
  364. py::arg("scaleUb") = RandomCropDecodeResizeOp::kDefScaleUb,
  365. py::arg("aspectLb") = RandomCropDecodeResizeOp::kDefAspectLb,
  366. py::arg("aspectUb") = RandomCropDecodeResizeOp::kDefAspectUb,
  367. py::arg("interpolation") = RandomCropDecodeResizeOp::kDefInterpolation,
  368. py::arg("maxIter") = RandomCropDecodeResizeOp::kDefMaxIter);
  369. (void)py::class_<PadOp, TensorOp, std::shared_ptr<PadOp>>(
  370. *m, "PadOp",
  371. "Pads image with specified color, default black, "
  372. "Takes amount to pad for top, bottom, left, right of image, boarder type and color")
  373. .def(py::init<int32_t, int32_t, int32_t, int32_t, BorderType, uint8_t, uint8_t, uint8_t>(), py::arg("padTop"),
  374. py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), py::arg("borderTypes") = PadOp::kDefBorderType,
  375. py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB);
  376. }
  377. void bindSamplerOps(py::module *m) {
  378. (void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler")
  379. .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
  380. .def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
  381. .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
  382. .def("get_indices", [](Sampler &self) {
  383. py::array ret;
  384. THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
  385. return ret;
  386. });
  387. (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator");
  388. (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
  389. .def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"),
  390. py::arg("seed"));
  391. (void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
  392. .def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));
  393. (void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
  394. .def(py::init<bool, int64_t>(), py::arg("replacement"), py::arg("numSamples"))
  395. .def(py::init<bool>(), py::arg("replacement"));
  396. (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
  397. .def(py::init<>());
  398. (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
  399. .def(py::init<std::vector<int64_t>>(), py::arg("indices"));
  400. (void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
  401. *m, "MindrecordSubsetRandomSampler")
  402. .def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
  403. (void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
  404. *m, "MindrecordPkSampler")
  405. .def(py::init([](int64_t kVal, bool shuffle) {
  406. if (shuffle == true) {
  407. return std::make_shared<mindrecord::ShardPkSample>("label", kVal, std::numeric_limits<int64_t>::max(),
  408. GetSeed());
  409. } else {
  410. return std::make_shared<mindrecord::ShardPkSample>("label", kVal);
  411. }
  412. }));
  413. (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
  414. .def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
  415. py::arg("replacement"));
  416. (void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
  417. .def(py::init<py::object>(), py::arg("pySampler"));
  418. }
  419. void bindInfoObjects(py::module *m) {
  420. (void)py::class_<BatchOp::CBatchInfo>(*m, "CBatchInfo")
  421. .def(py::init<int64_t, int64_t, int64_t>())
  422. .def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num)
  423. .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
  424. }
  425. // This is where we externalize the C logic as python modules
  426. PYBIND11_MODULE(_c_dataengine, m) {
  427. m.doc() = "pybind11 for _c_dataengine";
  428. (void)py::class_<DatasetOp, std::shared_ptr<DatasetOp>>(m, "DatasetOp");
  429. (void)py::enum_<OpName>(m, "OpName", py::arithmetic())
  430. .value("STORAGE", OpName::kStorage)
  431. .value("SHUFFLE", OpName::kShuffle)
  432. .value("BATCH", OpName::kBatch)
  433. .value("BARRIER", OpName::kBarrier)
  434. .value("MINDRECORD", OpName::kMindrecord)
  435. .value("CACHE", OpName::kCache)
  436. .value("REPEAT", OpName::kRepeat)
  437. .value("SKIP", OpName::kSkip)
  438. .value("TAKE", OpName::kTake)
  439. .value("ZIP", OpName::kZip)
  440. .value("MAP", OpName::kMap)
  441. .value("FILTER", OpName::kFilter)
  442. .value("DEVICEQUEUE", OpName::kDeviceQueue)
  443. .value("GENERATOR", OpName::kGenerator)
  444. .export_values()
  445. .value("RENAME", OpName::kRename)
  446. .value("TFREADER", OpName::kTfReader)
  447. .value("PROJECT", OpName::kProject)
  448. .value("IMAGEFOLDER", OpName::kImageFolder)
  449. .value("MNIST", OpName::kMnist)
  450. .value("MANIFEST", OpName::kManifest)
  451. .value("VOC", OpName::kVoc)
  452. .value("CIFAR10", OpName::kCifar10)
  453. .value("CIFAR100", OpName::kCifar100)
  454. .value("CELEBA", OpName::kCelebA)
  455. .value("TEXTFILE", OpName::kTextFile);
  456. (void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic())
  457. .value("DE_INTER_LINEAR", InterpolationMode::kLinear)
  458. .value("DE_INTER_CUBIC", InterpolationMode::kCubic)
  459. .value("DE_INTER_AREA", InterpolationMode::kArea)
  460. .value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour)
  461. .export_values();
  462. (void)py::enum_<BorderType>(m, "BorderType", py::arithmetic())
  463. .value("DE_BORDER_CONSTANT", BorderType::kConstant)
  464. .value("DE_BORDER_EDGE", BorderType::kEdge)
  465. .value("DE_BORDER_REFLECT", BorderType::kReflect)
  466. .value("DE_BORDER_SYMMETRIC", BorderType::kSymmetric)
  467. .export_values();
  468. bindDEPipeline(&m);
  469. bindTensor(&m);
  470. bindTensorOps1(&m);
  471. bindTensorOps2(&m);
  472. bindTensorOps3(&m);
  473. bindTensorOps4(&m);
  474. bindSamplerOps(&m);
  475. bindDatasetOps(&m);
  476. bindInfoObjects(&m);
  477. }
  478. } // namespace dataset
  479. } // namespace mindspore