You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

de_pipeline.cc 40 kB

6 years ago
6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/api/de_pipeline.h"
  17. #include <set>
  18. #include <map>
  19. #include "common/utils.h"
  20. #include "dataset/kernels/py_func_op.h"
  21. #include "dataset/engine/datasetops/source/image_folder_op.h"
  22. #include "dataset/engine/datasetops/source/mnist_op.h"
  23. #include "dataset/engine/datasetops/source/voc_op.h"
  24. #include "dataset/core/tensor.h"
  25. #include "dataset/engine/dataset_iterator.h"
  26. #include "dataset/engine/datasetops/source/manifest_op.h"
  27. #include "dataset/engine/datasetops/source/cifar_op.h"
  28. #include "dataset/engine/datasetops/source/celeba_op.h"
  29. #include "dataset/engine/datasetops/source/text_file_op.h"
  30. #include "dataset/engine/datasetops/filter_op.h"
  31. #include "mindrecord/include/shard_category.h"
  32. #include "mindrecord/include/shard_sample.h"
  33. #include "mindrecord/include/shard_shuffle.h"
  34. #include "dataset/util/random.h"
  35. #include "dataset/util/status.h"
  36. #include "utils/log_adapter.h"
  37. #include "pybind11/stl.h"
  38. namespace mindspore {
  39. namespace dataset {
  40. using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *);
  41. static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &DEPipeline::ParseStorageOp},
  42. {kShuffle, &DEPipeline::ParseShuffleOp},
  43. {kMindrecord, &DEPipeline::ParseMindRecordOp},
  44. {kMap, &DEPipeline::ParseMapOp},
  45. {kFilter, &DEPipeline::ParseFilterOp},
  46. {kBatch, &DEPipeline::ParseBatchOp},
  47. {kBarrier, &DEPipeline::ParseBarrierOp},
  48. {kRepeat, &DEPipeline::ParseRepeatOp},
  49. {kSkip, &DEPipeline::ParseSkipOp},
  50. {kZip, &DEPipeline::ParseZipOp},
  51. {kRename, &DEPipeline::ParseRenameOp},
  52. {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp},
  53. {kGenerator, &DEPipeline::ParseGeneratorOp},
  54. {kTfReader, &DEPipeline::ParseTFReaderOp},
  55. {kProject, &DEPipeline::ParseProjectOp},
  56. {kTake, &DEPipeline::ParseTakeOp},
  57. {kImageFolder, &DEPipeline::ParseImageFolderOp},
  58. {kMnist, &DEPipeline::ParseMnistOp},
  59. {kManifest, &DEPipeline::ParseManifestOp},
  60. {kVoc, &DEPipeline::ParseVOCOp},
  61. {kCifar10, &DEPipeline::ParseCifar10Op},
  62. {kCifar100, &DEPipeline::ParseCifar100Op},
  63. {kCelebA, &DEPipeline::ParseCelebAOp},
  64. {kTextFile, &DEPipeline::ParseTextFileOp}};
  65. DEPipeline::DEPipeline() : iterator_(nullptr) {
  66. try {
  67. // One time init
  68. (void)GlobalInit();
  69. // Instantiate the execution tree
  70. tree_ = std::make_shared<ExecutionTree>();
  71. repeat_num_ = 1;
  72. batch_size_ = 1;
  73. num_rows_ = 0;
  74. num_classes_ = 0;
  75. temp_batch_size_ = 1;
  76. temp_drop_remainder_ = false;
  77. } catch (const std::exception &err) {
  78. MS_LOG(ERROR) << "Dataset pipeline exception caught on init: " << err.what() << ".";
  79. return;
  80. }
  81. }
  82. DEPipeline::~DEPipeline() {
  83. {
  84. // Release GIL before joining all threads
  85. py::gil_scoped_release gil_release;
  86. // Release tree
  87. tree_.reset();
  88. }
  89. }
  90. // Function to add a Node to the Execution Tree.
  91. Status DEPipeline::AddNodeToTree(const OpName &op_name, const py::dict &args, DsOpPtr *out) {
  92. // For each operator, Parse through the list of arguments,
  93. // then call the respective builder/constructor.
  94. auto iter = g_parse_op_func_.find(op_name);
  95. if (iter != g_parse_op_func_.end()) {
  96. pFunction func = iter->second;
  97. RETURN_IF_NOT_OK((this->*func)(args, out));
  98. } else {
  99. RETURN_STATUS_UNEXPECTED("No such Op");
  100. }
  101. // Associate current dataset op node with the tree.
  102. RETURN_IF_NOT_OK(tree_->AssociateNode(*out));
  103. return Status::OK();
  104. }
  105. // Function to add a child and parent relationship.
  106. Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op) {
  107. // Link this relationship.
  108. // Note parent node takes ownership of the child
  109. return (parent_op->AddChild(child_op));
  110. }
  111. // Function to assign the node as root.
  112. Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); }
  113. // Function to launch the tree execution.
  114. Status DEPipeline::LaunchTreeExec() {
  115. RETURN_IF_NOT_OK(tree_->Prepare());
  116. RETURN_IF_NOT_OK(tree_->Launch());
  117. iterator_ = std::make_unique<DatasetIterator>(tree_);
  118. if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator.");
  119. return Status::OK();
  120. }
  121. void DEPipeline::PrintTree() {
  122. for (auto itr = tree_->begin(); itr != tree_->end(); ++itr) {
  123. std::stringstream ss;
  124. ss << *itr;
  125. MS_LOG(INFO) << "Operator ID is " << itr->id() << ". Details: " << ss.str().c_str() << ".";
  126. }
  127. }
  128. Status DEPipeline::GetNextAsMap(py::dict *output) {
  129. TensorMap row;
  130. Status s;
  131. {
  132. py::gil_scoped_release gil_release;
  133. s = iterator_->GetNextAsMap(&row);
  134. }
  135. RETURN_IF_NOT_OK(s);
  136. // Generate Python dict as return
  137. for (auto el : row) {
  138. (*output)[common::SafeCStr(el.first)] = el.second;
  139. }
  140. return Status::OK();
  141. }
  142. Status DEPipeline::GetNextAsList(py::list *output) {
  143. TensorRow row;
  144. Status s;
  145. {
  146. py::gil_scoped_release gil_release;
  147. s = iterator_->FetchNextTensorRow(&row);
  148. }
  149. RETURN_IF_NOT_OK(s);
  150. // Generate Python list as return
  151. for (auto el : row) {
  152. output->append(el);
  153. }
  154. return Status::OK();
  155. }
  156. Status DEPipeline::GetOutputShapes(py::list *output) {
  157. std::vector<TensorShape> shapes;
  158. Status s;
  159. {
  160. py::gil_scoped_release gil_release;
  161. s = iterator_->GetOutputShapes(&shapes);
  162. }
  163. RETURN_IF_NOT_OK(s);
  164. for (auto el : shapes) {
  165. py::list shape;
  166. for (auto dim : el.AsVector()) {
  167. shape.append(dim);
  168. }
  169. output->append(shape);
  170. }
  171. return Status::OK();
  172. }
  173. Status DEPipeline::GetOutputTypes(py::list *output) {
  174. std::vector<DataType> types;
  175. Status s;
  176. {
  177. py::gil_scoped_release gil_release;
  178. s = iterator_->GetOutputTypes(&types);
  179. }
  180. RETURN_IF_NOT_OK(s);
  181. for (auto el : types) {
  182. output->append(el.AsNumpyType());
  183. }
  184. return Status::OK();
  185. }
  186. int DEPipeline::GetDatasetSize() const { return num_rows_ / batch_size_; }
  187. int DEPipeline::GetBatchSize() const { return batch_size_; }
  188. int DEPipeline::GetRepeatCount() const { return repeat_num_; }
  189. int ToInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
  190. bool ToBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }
  191. std::string ToString(const py::handle &handle) { return py::reinterpret_borrow<py::str>(handle); }
  192. std::vector<std::string> ToStringVector(const py::handle handle) {
  193. py::list list = py::reinterpret_borrow<py::list>(handle);
  194. std::vector<std::string> vector;
  195. for (auto l : list) {
  196. if (!l.is_none())
  197. vector.push_back(py::str(l));
  198. else
  199. vector.emplace_back("");
  200. }
  201. return vector;
  202. }
  203. std::set<std::string> ToStringSet(const py::handle handle) {
  204. py::list list = py::reinterpret_borrow<py::list>(handle);
  205. std::set<std::string> set;
  206. for (auto l : list) {
  207. if (!l.is_none()) {
  208. (void)set.insert(py::str(l));
  209. }
  210. }
  211. return set;
  212. }
  213. std::map<std::string, int32_t> ToStringMap(const py::handle handle) {
  214. py::dict dict = py::reinterpret_borrow<py::dict>(handle);
  215. std::map<std::string, int32_t> map;
  216. for (auto p : dict) {
  217. (void)map.insert(std::make_pair(ToString(p.first), ToInt(p.second)));
  218. }
  219. return map;
  220. }
  221. std::vector<int> ToIntVector(const py::handle handle) {
  222. py::list list = py::reinterpret_borrow<py::list>(handle);
  223. std::vector<int> vector;
  224. for (auto l : list) {
  225. if (!l.is_none()) {
  226. vector.push_back(ToInt(l));
  227. }
  228. }
  229. return vector;
  230. }
  231. std::vector<DataType> ToTypeVector(const py::handle handle) {
  232. py::list list = py::reinterpret_borrow<py::list>(handle);
  233. std::vector<DataType> vector;
  234. for (auto l : list) {
  235. if (l.is_none()) {
  236. vector.emplace_back(DataType());
  237. } else {
  238. vector.push_back(l.cast<DataType>());
  239. }
  240. }
  241. return vector;
  242. }
  243. Status DEPipeline::SetBatchParameters(const py::dict &args) {
  244. if (args["batch_size"].is_none()) {
  245. std::string err_msg = "Error: batchSize is invalid or not set.";
  246. RETURN_STATUS_UNEXPECTED(err_msg);
  247. }
  248. temp_batch_size_ = ToInt(args["batch_size"]);
  249. CHECK_FAIL_RETURN_UNEXPECTED(temp_batch_size_ > 0, "Error: batchSize is invalid.");
  250. for (auto arg : args) {
  251. std::string key = py::str(arg.first);
  252. py::handle value = arg.second;
  253. if (!value.is_none()) {
  254. if (key == "drop_remainder") {
  255. temp_drop_remainder_ = ToBool(value);
  256. }
  257. }
  258. }
  259. return Status::OK();
  260. }
  261. Status DEPipeline::ValidateArgStorageOp(const py::dict &args) {
  262. // Required arguments
  263. if (((args.contains("dataset_files") && args["dataset_files"].is_none()) || args["schema"].is_none()) &&
  264. ((args.contains("dataset_dir") && args["dataset_dir"].is_none()) ||
  265. (args["schema"].is_none() && args["schema_json_string"].is_none()))) {
  266. std::string err_msg = "Error: at least one of dataset_files or schema_file is missing";
  267. RETURN_STATUS_UNEXPECTED(err_msg);
  268. }
  269. return Status::OK();
  270. }
  271. Status DEPipeline::ParseStorageOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  272. RETURN_IF_NOT_OK(ValidateArgStorageOp(args));
  273. std::shared_ptr<StorageOp::Builder> builder;
  274. if (args.contains("dataset_files") && !args["dataset_files"].is_none()) {
  275. builder = std::make_shared<StorageOp::Builder>();
  276. (void)builder->SetDatasetFileList(ToStringVector(args["dataset_files"]));
  277. (void)builder->SetSchemaFile(ToString(args["schema"]));
  278. } else if (args.contains("dataset_dir") && !args["dataset_dir"].is_none()) {
  279. builder = std::make_shared<StorageOp::Builder>();
  280. (void)builder->SetDatasetFilesDir(ToString(args["dataset_dir"]));
  281. if (!args["schema"].is_none()) {
  282. (void)builder->SetSchemaFile(ToString(args["schema"]));
  283. } else if (!args["schema_json_string"].is_none()) {
  284. std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  285. std::string s = ToString(args["schema_json_string"]);
  286. RETURN_IF_NOT_OK(schema->LoadSchemaString(s, std::vector<std::string>()));
  287. (void)builder->SetNumRows(schema->num_rows());
  288. (void)builder->SetSchema(std::move(schema));
  289. }
  290. }
  291. // Optional arguments
  292. for (auto arg : args) {
  293. std::string key = py::str(arg.first);
  294. py::handle value = arg.second;
  295. if (!value.is_none()) {
  296. if (key == "num_parallel_workers") {
  297. (void)builder->SetNumWorkers(ToInt(value));
  298. } else if (key == "prefetch_size") {
  299. (void)builder->SetOpConnectorSize(ToInt(value));
  300. } else if (key == "columns_list") {
  301. (void)builder->SetColumnsToLoad(ToStringVector(value));
  302. } else if (key == "distribution") {
  303. (void)builder->SetDataDistributionFile(ToString(value));
  304. } else if (key == "labels_filename") {
  305. (void)builder->setLabelsFileName(ToString(value));
  306. } else if (key == "dataset_usage") {
  307. (void)builder->SetDatasetUsage(ToString(value));
  308. }
  309. }
  310. }
  311. (void)builder->SetBatchSize(temp_batch_size_);
  312. (void)builder->SetDropRemainder(temp_drop_remainder_);
  313. std::shared_ptr<StorageOp> op;
  314. RETURN_IF_NOT_OK(builder->Build(&op));
  315. num_rows_ = op->num_rows();
  316. num_classes_ = op->num_classes();
  317. *ptr = op;
  318. return Status::OK();
  319. }
  320. Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  321. std::shared_ptr<ShuffleOp::Builder> builder = std::make_shared<ShuffleOp::Builder>();
  322. if (!args["buffer_size"].is_none()) {
  323. (void)builder->SetShuffleSize(ToInt(args["buffer_size"]));
  324. } else {
  325. std::string err_msg = "Error: Shuffle buffer size is missing";
  326. RETURN_STATUS_UNEXPECTED(err_msg);
  327. }
  328. std::shared_ptr<ShuffleOp> op;
  329. RETURN_IF_NOT_OK(builder->Build(&op));
  330. *ptr = op;
  331. return Status::OK();
  332. }
  333. Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vector<int> *in_partitions) {
  334. if (args["partitions"].is_none()) {
  335. std::string err_msg = "Error: partitions is not set (None)";
  336. RETURN_STATUS_UNEXPECTED(err_msg);
  337. }
  338. py::list list = py::reinterpret_borrow<py::list>(args["partitions"]);
  339. for (auto l : list) {
  340. if (!l.is_none()) {
  341. in_partitions->push_back(ToInt(l));
  342. }
  343. }
  344. if (in_partitions->size() != 2) {
  345. std::string err_msg = "Error: partitions is invalid or not set.";
  346. RETURN_STATUS_UNEXPECTED(err_msg);
  347. }
  348. constexpr int kMaxPartitions = 64;
  349. if (in_partitions->at(0) <= 0 || in_partitions->at(0) > kMaxPartitions) {
  350. std::string err_msg = "Error: partitions is invalid or not set.";
  351. RETURN_STATUS_UNEXPECTED(err_msg);
  352. }
  353. if (in_partitions->at(1) < 0 || in_partitions->at(1) >= in_partitions->at(0)) {
  354. std::string err_msg = "Error: partitions is invalid or not set.";
  355. RETURN_STATUS_UNEXPECTED(err_msg);
  356. }
  357. return Status::OK();
  358. }
  359. Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  360. if (args["dataset_file"].is_none()) {
  361. std::string err_msg = "Error: at least one of dataset_files is missing";
  362. RETURN_STATUS_UNEXPECTED(err_msg);
  363. }
  364. std::shared_ptr<MindRecordOp::Builder> builder = std::make_shared<MindRecordOp::Builder>();
  365. (void)builder->SetDatasetFile(ToString(args["dataset_file"]));
  366. std::vector<std::string> in_col_names;
  367. if (!args["columns_list"].is_none()) {
  368. in_col_names = ToStringVector(args["columns_list"]);
  369. if (in_col_names.empty() || in_col_names[0].empty()) {
  370. std::string err_msg = "Error: columns_list is invalid or not set.";
  371. RETURN_STATUS_UNEXPECTED(err_msg);
  372. }
  373. (void)builder->SetColumnsToLoad(in_col_names);
  374. }
  375. std::vector<std::shared_ptr<mindrecord::ShardOperator>> operators;
  376. for (auto arg : args) {
  377. std::string key = py::str(arg.first);
  378. py::handle value = arg.second;
  379. if (!value.is_none()) {
  380. if (key == "num_parallel_workers") {
  381. (void)builder->SetNumMindRecordWorkers(ToInt(value));
  382. } else if (key == "block_reader" && ToBool(value) == true) {
  383. (void)builder->SetBlockReader();
  384. } else if (key == "global_shuffle" && ToBool(value) == true) {
  385. uint32_t seed = args["partitions"].is_none() ? GetSeed() : 0;
  386. operators.push_back(std::make_shared<mindrecord::ShardShuffle>(seed));
  387. } else if (key == "sampler") {
  388. auto create = py::reinterpret_borrow<py::object>(value).attr("_create_for_minddataset");
  389. std::shared_ptr<mindrecord::ShardOperator> sample_op =
  390. create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
  391. operators.push_back(sample_op);
  392. }
  393. }
  394. }
  395. std::vector<int> in_partitions;
  396. if (!args["partitions"].is_none()) {
  397. auto ret = CheckMindRecordPartitionInfo(args, &in_partitions);
  398. if (Status::OK() != ret) {
  399. return ret;
  400. }
  401. operators.push_back(std::make_shared<mindrecord::ShardSample>(1, in_partitions[0], in_partitions[1]));
  402. }
  403. if (!operators.empty()) {
  404. (void)builder->SetOperators(operators);
  405. }
  406. std::shared_ptr<MindRecordOp> op;
  407. RETURN_IF_NOT_OK(builder->Build(&op));
  408. num_rows_ = op->num_rows();
  409. *ptr = op;
  410. return Status::OK();
  411. }
  412. Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  413. std::shared_ptr<MapOp::Builder> builder = std::make_shared<MapOp::Builder>();
  414. std::vector<std::shared_ptr<TensorOp>> tensor_op_list;
  415. if (args["operations"].is_none()) RETURN_STATUS_UNEXPECTED("Error: 'operations' is not set. \n");
  416. for (auto arg : args) {
  417. std::string key = py::str(arg.first);
  418. py::handle value = arg.second;
  419. if (!value.is_none()) {
  420. if (key == "input_columns") {
  421. std::vector<std::string> in_col_names = ToStringVector(args["input_columns"]);
  422. (void)builder->SetInColNames(in_col_names);
  423. } else if (key == "output_columns") {
  424. (void)builder->SetOutColNames(ToStringVector(value));
  425. } else if (key == "num_parallel_workers") {
  426. (void)builder->SetNumWorkers(ToInt(value));
  427. } else if (key == "prefetch_size") {
  428. (void)builder->SetOpConnectorSize(ToInt(value));
  429. } else if (key == "operations") {
  430. py::handle tensor_ops = args["operations"];
  431. // operation can be a list of TensorOps or a single TensorOp.
  432. if (py::isinstance<py::list>(tensor_ops)) {
  433. for (auto op : tensor_ops) {
  434. std::shared_ptr<TensorOp> tensor_op;
  435. if (py::isinstance<TensorOp>(op)) {
  436. tensor_op = op.cast<std::shared_ptr<TensorOp>>();
  437. } else if (py::isinstance<py::function>(op)) {
  438. tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>());
  439. } else {
  440. RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc).");
  441. }
  442. tensor_op_list.push_back(tensor_op);
  443. }
  444. }
  445. if (tensor_op_list.empty()) RETURN_STATUS_UNEXPECTED("Error: tensor_op is invalid or not set.");
  446. (void)builder->SetTensorFuncs(std::move(tensor_op_list));
  447. } else {
  448. RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
  449. }
  450. }
  451. }
  452. std::shared_ptr<MapOp> op;
  453. RETURN_IF_NOT_OK(builder->Build(&op));
  454. *ptr = op;
  455. return Status::OK();
  456. }
  457. Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  458. std::shared_ptr<FilterOp::Builder> builder = std::make_shared<FilterOp::Builder>();
  459. if (args["predicate"].is_none()) {
  460. RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set. \n");
  461. }
  462. for (auto arg : args) {
  463. std::string key = py::str(arg.first);
  464. py::handle value = arg.second;
  465. if (!value.is_none()) {
  466. if (key == "num_parallel_workers") {
  467. (void)builder->SetNumWorkers(ToInt(value));
  468. } else if (key == "predicate") {
  469. py::handle op = args["predicate"];
  470. if (!py::isinstance<py::function>(op)) {
  471. RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc).");
  472. }
  473. py::function predicate_func = op.cast<py::function>();
  474. (void)builder->SetPredicateFunc(std::move(predicate_func));
  475. } else if (key == "input_columns") {
  476. std::vector<std::string> in_col_names = ToStringVector(args["input_columns"]);
  477. (void)builder->SetInColNames(in_col_names);
  478. } else {
  479. RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
  480. }
  481. }
  482. }
  483. std::shared_ptr<FilterOp> op;
  484. RETURN_IF_NOT_OK(builder->Build(&op));
  485. *ptr = op;
  486. return Status::OK();
  487. }
  488. Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  489. if (args["count"].is_none()) {
  490. std::string err_msg = "Error: count is invalid or not set.";
  491. RETURN_STATUS_UNEXPECTED(err_msg);
  492. }
  493. repeat_num_ = ToInt(args["count"]);
  494. std::shared_ptr<RepeatOp> op;
  495. RETURN_IF_NOT_OK(RepeatOp::Builder(ToInt(args["count"])).Build(&op));
  496. *ptr = op;
  497. return Status::OK();
  498. }
  499. Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  500. if (args["count"].is_none()) {
  501. std::string err_msg = "Error: count is invalid or not set.";
  502. RETURN_STATUS_UNEXPECTED(err_msg);
  503. }
  504. std::shared_ptr<SkipOp> op;
  505. RETURN_IF_NOT_OK(SkipOp::Builder(ToInt(args["count"])).Build(&op));
  506. *ptr = op;
  507. return Status::OK();
  508. }
  509. Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  510. std::shared_ptr<GeneratorOp::Builder> builder = std::make_shared<GeneratorOp::Builder>();
  511. for (auto arg : args) {
  512. std::string key = py::str(arg.first);
  513. py::handle value = arg.second;
  514. if (!value.is_none()) {
  515. if (key == "source") {
  516. py::object obj = py::cast(&value);
  517. if (!py::isinstance<py::function>(obj)) {
  518. std::string err_msg = "Error: generator is invalid or not set.";
  519. RETURN_STATUS_UNEXPECTED(err_msg);
  520. }
  521. (void)builder->SetGeneratorFunction(obj.cast<py::function>());
  522. } else if (key == "column_names") {
  523. (void)builder->SetColumnNames(ToStringVector(value));
  524. } else if (key == "column_types") {
  525. (void)builder->SetColumnTypes(ToTypeVector(value));
  526. }
  527. }
  528. }
  529. std::shared_ptr<GeneratorOp> op;
  530. RETURN_IF_NOT_OK(builder->Build(&op));
  531. *ptr = op;
  532. return Status::OK();
  533. }
  534. Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  535. std::shared_ptr<BatchOp::Builder> builder;
  536. if (py::isinstance<py::int_>(args["batch_size"])) {
  537. batch_size_ = ToInt(args["batch_size"]);
  538. CHECK_FAIL_RETURN_UNEXPECTED(batch_size_ > 0, "Error: batch_size is invalid.");
  539. builder = std::make_shared<BatchOp::Builder>(ToInt(args["batch_size"]));
  540. } else if (py::isinstance<py::function>(args["batch_size"])) {
  541. builder = std::make_shared<BatchOp::Builder>(1);
  542. (void)builder->SetBatchSizeFunc(args["batch_size"].cast<py::function>());
  543. } else {
  544. std::string err_msg = "Error: batch_size is neither an Integer nor a python function";
  545. RETURN_STATUS_UNEXPECTED(err_msg);
  546. }
  547. for (auto arg : args) {
  548. std::string key = py::str(arg.first);
  549. py::handle value = arg.second;
  550. if (!value.is_none()) {
  551. if (key == "drop_remainder") {
  552. (void)builder->SetDrop(ToBool(value));
  553. }
  554. if (key == "num_parallel_workers") {
  555. (void)builder->SetNumWorkers(ToInt(value));
  556. }
  557. if (key == "per_batch_map") {
  558. (void)builder->SetBatchMapFunc(value.cast<py::function>());
  559. }
  560. if (key == "input_columns") {
  561. (void)builder->SetColumnsToMap(ToStringVector(value));
  562. }
  563. }
  564. }
  565. std::shared_ptr<BatchOp> op;
  566. RETURN_IF_NOT_OK(builder->Build(&op));
  567. *ptr = op;
  568. return Status::OK();
  569. }
  570. Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  571. std::shared_ptr<BarrierOp::Builder> builder = std::make_shared<BarrierOp::Builder>();
  572. // Right now barrier should only take num_rows_per_buffer = 1
  573. // The reason for this is because having it otherwise can lead to blocking issues
  574. // See barrier_op.h for more details
  575. (void)builder->SetRowsPerBuffer(1);
  576. for (auto arg : args) {
  577. std::string key = py::str(arg.first);
  578. py::handle value = arg.second;
  579. if (!value.is_none()) {
  580. if (key == "condition_name") {
  581. (void)builder->SetConditionName(ToString(value));
  582. } else if (key == "condition_func") {
  583. (void)builder->SetConditionFunc(value.cast<py::function>());
  584. }
  585. }
  586. }
  587. std::shared_ptr<BarrierOp> op;
  588. RETURN_IF_NOT_OK(builder->Build(&op));
  589. *ptr = op;
  590. return Status::OK();
  591. }
  592. Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  593. int32_t prefetch_size = 0;
  594. if (args.contains("prefetch_size")) {
  595. if (args["prefetch_size"].is_none()) {
  596. prefetch_size = 16;
  597. } else {
  598. prefetch_size = ToInt(args["prefetch_size"]);
  599. }
  600. }
  601. std::shared_ptr<DeviceQueueOp::Builder> builder = std::make_shared<DeviceQueueOp::Builder>(prefetch_size);
  602. for (auto arg : args) {
  603. std::string key = py::str(arg.first);
  604. py::handle value = arg.second;
  605. if (!value.is_none()) {
  606. if (key == "queue_name") {
  607. (void)builder->SetChannelName(ToString(value));
  608. } else if (key == "device_type") {
  609. (void)builder->SetDeviceType(ToString(value));
  610. } else if (key == "device_id") {
  611. (void)builder->SetDeviceId(ToInt(value));
  612. } else if (key == "num_batch") {
  613. (void)builder->SetNumBatch(ToInt(value));
  614. }
  615. }
  616. }
  617. std::shared_ptr<DeviceQueueOp> op;
  618. RETURN_IF_NOT_OK(builder->Build(&op));
  619. *ptr = op;
  620. return Status::OK();
  621. }
  622. Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  623. std::vector<std::string> in_col_names;
  624. std::vector<std::string> out_col_names;
  625. std::shared_ptr<RenameOp::Builder> builder = std::make_shared<RenameOp::Builder>();
  626. for (auto arg : args) {
  627. std::string key = py::str(arg.first);
  628. py::handle value = arg.second;
  629. if (!value.is_none()) {
  630. if (key == "input_columns") {
  631. in_col_names = ToStringVector(value);
  632. } else if (key == "output_columns") {
  633. out_col_names = ToStringVector(value);
  634. }
  635. }
  636. }
  637. if (in_col_names.empty() || in_col_names[0].empty()) {
  638. std::string err_msg = "Error: input_column_names is invalid or not set.";
  639. RETURN_STATUS_UNEXPECTED(err_msg);
  640. }
  641. if (out_col_names.empty() || out_col_names[0].empty()) {
  642. std::string err_msg = "Error: output_column_names is invalid or not set.";
  643. RETURN_STATUS_UNEXPECTED(err_msg);
  644. }
  645. (void)builder->SetInColNames(in_col_names);
  646. (void)builder->SetOutColNames(out_col_names);
  647. std::shared_ptr<RenameOp> op;
  648. RETURN_IF_NOT_OK(builder->Build(&op));
  649. *ptr = op;
  650. return Status::OK();
  651. }
  652. Status DEPipeline::ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  653. if (args["count"].is_none()) {
  654. std::string err_msg = "Error: count is invalid or not set.";
  655. RETURN_STATUS_UNEXPECTED(err_msg);
  656. }
  657. std::shared_ptr<TakeOp> op;
  658. RETURN_IF_NOT_OK(TakeOp::Builder(ToInt(args["count"])).Build(&op));
  659. *ptr = op;
  660. return Status::OK();
  661. }
  662. Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  663. std::shared_ptr<ZipOp::Builder> builder = std::make_shared<ZipOp::Builder>();
  664. std::shared_ptr<ZipOp> op;
  665. RETURN_IF_NOT_OK(builder->Build(&op));
  666. *ptr = op;
  667. return Status::OK();
  668. }
  669. Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  670. // Required arguments
  671. std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
  672. if (!args["dataset_files"].is_none()) {
  673. (void)builder->SetDatasetFilesList(ToStringVector(args["dataset_files"]));
  674. } else {
  675. std::string err_msg = "Error: at least one of dataset_files or schema_file is missing";
  676. RETURN_STATUS_UNEXPECTED(err_msg);
  677. }
  678. std::vector<std::string> columns_to_load;
  679. bool schema_exists = false;
  680. // Optional arguments
  681. for (auto arg : args) {
  682. std::string key = py::str(arg.first);
  683. py::handle value = arg.second;
  684. if (!value.is_none()) {
  685. if (key == "num_parallel_workers") {
  686. (void)builder->SetNumWorkers(ToInt(value));
  687. } else if (key == "columns_list") {
  688. columns_to_load = ToStringVector(value);
  689. (void)builder->SetColumnsToLoad(columns_to_load);
  690. } else if (key == "shuffle_files") {
  691. (void)builder->SetShuffleFiles(ToBool(value));
  692. } else if (key == "schema_file_path" || key == "schema_json_string") {
  693. schema_exists = true;
  694. } else if (key == "num_samples") {
  695. (void)builder->setTotalRows(ToInt(value));
  696. } else if (key == "num_shards") {
  697. (void)builder->SetNumDevices(ToInt(value));
  698. } else if (key == "shard_id") {
  699. (void)builder->SetDeviceId(ToInt(value));
  700. } else if (key == "shard_equal_rows") {
  701. (void)builder->SetShardEqualRows(ToBool(value));
  702. }
  703. }
  704. }
  705. if (schema_exists) {
  706. std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  707. if (args.contains("schema_file_path")) {
  708. RETURN_IF_NOT_OK(schema->LoadSchemaFile(ToString(args["schema_file_path"]), columns_to_load));
  709. } else {
  710. RETURN_IF_NOT_OK(schema->LoadSchemaString(ToString(args["schema_json_string"]), columns_to_load));
  711. }
  712. (void)builder->SetDataSchema(std::move(schema));
  713. }
  714. std::shared_ptr<TFReaderOp> op;
  715. RETURN_IF_NOT_OK(builder->Build(&op));
  716. *ptr = op;
  717. return Status::OK();
  718. }
  719. Status DEPipeline::ParseProjectOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  720. if (args["columns"].is_none()) {
  721. std::string err_msg = "Error: columns is missing";
  722. RETURN_STATUS_UNEXPECTED(err_msg);
  723. }
  724. std::vector<std::string> columns_to_project = ToStringVector(args["columns"]);
  725. std::shared_ptr<ProjectOp::Builder> builder = std::make_shared<ProjectOp::Builder>(columns_to_project);
  726. std::shared_ptr<ProjectOp> op;
  727. RETURN_IF_NOT_OK(builder->Build(&op));
  728. *ptr = op;
  729. return Status::OK();
  730. }
  731. Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  732. // Required arguments
  733. if (args["dataset_dir"].is_none()) {
  734. std::string err_msg = "Error: No dataset path specified";
  735. RETURN_STATUS_UNEXPECTED(err_msg);
  736. }
  737. std::shared_ptr<ImageFolderOp::Builder> builder = std::make_shared<ImageFolderOp::Builder>();
  738. (void)builder->SetImageFolderDir(ToString(args["dataset_dir"]));
  739. // Optional arguments
  740. for (auto arg : args) {
  741. std::string key = py::str(arg.first);
  742. py::handle value = arg.second;
  743. if (!value.is_none()) {
  744. if (key == "num_samples") {
  745. (void)builder->SetNumSamples(ToInt(value));
  746. } else if (key == "num_parallel_workers") {
  747. (void)builder->SetNumWorkers(ToInt(value));
  748. } else if (key == "sampler") {
  749. auto create = py::reinterpret_borrow<py::object>(value).attr("create");
  750. std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
  751. (void)builder->SetSampler(std::move(sampler));
  752. } else if (key == "extensions") {
  753. (void)builder->SetExtensions(ToStringSet(value));
  754. } else if (key == "class_indexing") {
  755. (void)builder->SetClassIndex(ToStringMap(value));
  756. } else if (key == "decode") {
  757. (void)builder->SetDecode(ToBool(value));
  758. }
  759. }
  760. }
  761. std::shared_ptr<ImageFolderOp> op;
  762. RETURN_IF_NOT_OK(builder->Build(&op));
  763. *ptr = op;
  764. return Status::OK();
  765. }
  766. Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  767. // Required arguments
  768. if (args["dataset_file"].is_none()) {
  769. std::string err_msg = "Error: No dataset files specified for manifest";
  770. RETURN_STATUS_UNEXPECTED(err_msg);
  771. }
  772. std::shared_ptr<ManifestOp::Builder> builder = std::make_shared<ManifestOp::Builder>();
  773. (void)builder->SetManifestFile(ToString(args["dataset_file"]));
  774. // Optional arguments
  775. for (auto arg : args) {
  776. std::string key = py::str(arg.first);
  777. py::handle value = arg.second;
  778. if (!value.is_none()) {
  779. if (key == "num_samples") {
  780. (void)builder->SetNumSamples(ToInt(value));
  781. } else if (key == "num_parallel_workers") {
  782. (void)builder->SetNumWorkers(ToInt(value));
  783. } else if (key == "sampler") {
  784. auto create = py::reinterpret_borrow<py::object>(value).attr("create");
  785. std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
  786. (void)builder->SetSampler(std::move(sampler));
  787. } else if (key == "class_indexing") {
  788. (void)builder->SetClassIndex(ToStringMap(value));
  789. } else if (key == "decode") {
  790. (void)builder->SetDecode(ToBool(value));
  791. } else if (key == "usage") {
  792. (void)builder->SetUsage(ToString(value));
  793. }
  794. }
  795. }
  796. std::shared_ptr<ManifestOp> op;
  797. RETURN_IF_NOT_OK(builder->Build(&op));
  798. *ptr = op;
  799. return Status::OK();
  800. }
  801. Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  802. if (args["dataset_dir"].is_none()) {
  803. std::string err_msg = "Error: No dataset path specified";
  804. RETURN_STATUS_UNEXPECTED(err_msg);
  805. }
  806. std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>();
  807. (void)builder->SetDir(ToString(args["dataset_dir"]));
  808. for (auto arg : args) {
  809. std::string key = py::str(arg.first);
  810. py::handle value = arg.second;
  811. if (!value.is_none()) {
  812. if (key == "num_samples") {
  813. (void)builder->SetNumSamples(ToInt(value));
  814. } else if (key == "num_parallel_workers") {
  815. (void)builder->SetNumWorkers(ToInt(value));
  816. } else if (key == "sampler") {
  817. auto create = py::reinterpret_borrow<py::object>(value).attr("create");
  818. std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
  819. (void)builder->SetSampler(std::move(sampler));
  820. } else if (key == "decode") {
  821. (void)builder->SetDecode(ToBool(value));
  822. }
  823. }
  824. }
  825. std::shared_ptr<VOCOp> op;
  826. RETURN_IF_NOT_OK(builder->Build(&op));
  827. *ptr = op;
  828. return Status::OK();
  829. }
  830. Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  831. // Required arguments
  832. if (args["dataset_dir"].is_none()) {
  833. std::string err_msg = "Error: No dataset path specified";
  834. RETURN_STATUS_UNEXPECTED(err_msg);
  835. }
  836. std::shared_ptr<CifarOp::Builder> builder = std::make_shared<CifarOp::Builder>();
  837. (void)builder->SetCifarDir(ToString(args["dataset_dir"]));
  838. // Optional arguments
  839. for (auto arg : args) {
  840. std::string key = py::str(arg.first);
  841. py::handle value = arg.second;
  842. if (!value.is_none()) {
  843. if (key == "num_samples") {
  844. (void)builder->SetNumSamples(ToInt(value));
  845. } else if (key == "num_parallel_workers") {
  846. (void)builder->SetNumWorkers(ToInt(value));
  847. } else if (key == "sampler") {
  848. auto create = py::reinterpret_borrow<py::object>(value).attr("create");
  849. std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
  850. (void)builder->SetSampler(std::move(sampler));
  851. }
  852. }
  853. }
  854. (void)builder->SetCifarType(true);
  855. std::shared_ptr<CifarOp> op;
  856. RETURN_IF_NOT_OK(builder->Build(&op));
  857. *ptr = op;
  858. return Status::OK();
  859. }
  860. Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  861. // Required arguments
  862. if (args["dataset_dir"].is_none()) {
  863. std::string err_msg = "Error: No dataset path specified";
  864. RETURN_STATUS_UNEXPECTED(err_msg);
  865. }
  866. std::shared_ptr<CifarOp::Builder> builder = std::make_shared<CifarOp::Builder>();
  867. (void)builder->SetCifarDir(ToString(args["dataset_dir"]));
  868. // Optional arguments
  869. for (auto arg : args) {
  870. std::string key = py::str(arg.first);
  871. py::handle value = arg.second;
  872. if (!value.is_none()) {
  873. if (key == "num_samples") {
  874. (void)builder->SetNumSamples(ToInt(value));
  875. } else if (key == "num_parallel_workers") {
  876. (void)builder->SetNumWorkers(ToInt(value));
  877. } else if (key == "sampler") {
  878. auto create = py::reinterpret_borrow<py::object>(value).attr("create");
  879. std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
  880. (void)builder->SetSampler(std::move(sampler));
  881. }
  882. }
  883. }
  884. (void)builder->SetCifarType(false);
  885. std::shared_ptr<CifarOp> op;
  886. RETURN_IF_NOT_OK(builder->Build(&op));
  887. *ptr = op;
  888. return Status::OK();
  889. }
  890. int32_t DEPipeline::GetNumClasses() const { return num_classes_; }
  891. Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  892. // Required arguments
  893. if (args["dataset_dir"].is_none()) {
  894. std::string err_msg = "Error: No dataset path specified";
  895. RETURN_STATUS_UNEXPECTED(err_msg);
  896. }
  897. std::shared_ptr<MnistOp::Builder> builder = std::make_shared<MnistOp::Builder>();
  898. (void)builder->SetDir(ToString(args["dataset_dir"]));
  899. // Optional arguments
  900. for (auto arg : args) {
  901. std::string key = py::str(arg.first);
  902. py::handle value = arg.second;
  903. if (!value.is_none()) {
  904. if (key == "num_samples") {
  905. (void)builder->SetNumSamples(ToInt(value));
  906. } else if (key == "num_parallel_workers") {
  907. (void)builder->SetNumWorkers(ToInt(value));
  908. } else if (key == "sampler") {
  909. auto create = py::reinterpret_borrow<py::object>(value).attr("create");
  910. std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
  911. (void)builder->SetSampler(std::move(sampler));
  912. }
  913. }
  914. }
  915. std::shared_ptr<MnistOp> op;
  916. RETURN_IF_NOT_OK(builder->Build(&op));
  917. *ptr = op;
  918. return Status::OK();
  919. }
  920. Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  921. // Required arguments
  922. if (args["dataset_dir"].is_none()) {
  923. std::string err_msg = "Error: No dataset path specified";
  924. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
  925. }
  926. std::shared_ptr<CelebAOp::Builder> builder = std::make_shared<CelebAOp::Builder>();
  927. if (builder == nullptr) {
  928. std::string err_msg = "Create celebaop builder failed";
  929. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
  930. }
  931. (void)builder->SetCelebADir(ToString(args["dataset_dir"]));
  932. for (const auto &arg : args) {
  933. std::string key = py::str(arg.first);
  934. py::handle value = arg.second;
  935. if (!value.is_none()) {
  936. if (key == "num_parallel_workers") {
  937. (void)builder->SetNumWorkers(ToInt(value));
  938. } else if (key == "sampler") {
  939. auto create = py::reinterpret_borrow<py::object>(value).attr("create");
  940. std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
  941. (void)builder->SetSampler(std::move(sampler));
  942. } else if (key == "decode") {
  943. (void)builder->SetDecode(ToBool(value));
  944. } else if (key == "extensions") {
  945. (void)builder->SetExtensions(ToStringSet(value));
  946. } else if (key == "num_samples") {
  947. (void)builder->SetNumSamples(ToInt(value));
  948. } else if (key == "dataset_type") {
  949. (void)builder->SetDatasetType(ToString(value));
  950. }
  951. }
  952. }
  953. std::shared_ptr<CelebAOp> op;
  954. RETURN_IF_NOT_OK(builder->Build(&op));
  955. *ptr = op;
  956. return Status::OK();
  957. }
  958. Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  959. // Required arguments
  960. std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>();
  961. if (!args["dataset_files"].is_none()) {
  962. (void)builder->SetTextFilesList(ToStringVector(args["dataset_files"]));
  963. } else {
  964. RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing");
  965. }
  966. // Optional arguments
  967. for (auto arg : args) {
  968. std::string key = py::str(arg.first);
  969. py::handle value = arg.second;
  970. if (!value.is_none()) {
  971. if (key == "num_parallel_workers") {
  972. (void)builder->SetNumWorkers(ToInt(value));
  973. } else if (key == "shuffle_files") {
  974. (void)builder->SetShuffleFiles(ToBool(value));
  975. } else if (key == "num_samples") {
  976. (void)builder->SetNumSamples(ToInt(value));
  977. } else if (key == "num_shards") {
  978. (void)builder->SetNumDevices(ToInt(value));
  979. } else if (key == "shard_id") {
  980. (void)builder->SetDeviceId(ToInt(value));
  981. }
  982. }
  983. }
  984. std::shared_ptr<TextFileOp> op;
  985. RETURN_IF_NOT_OK(builder->Build(&op));
  986. *ptr = op;
  987. return Status::OK();
  988. }
  989. } // namespace dataset
  990. } // namespace mindspore