You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

datasets.cc 42 kB

5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "minddata/dataset/include/datasets.h"
  17. #include <algorithm>
  18. #include <fstream>
  19. #include <unordered_set>
  20. #include <utility>
  21. #include "minddata/dataset/include/samplers.h"
  22. #include "minddata/dataset/include/transforms.h"
  23. #ifndef ENABLE_ANDROID
  24. #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
  25. #endif
  26. // Sampler headers (in alphabetical order)
  27. #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
  28. // IR non-leaf nodes
  29. #include "minddata/dataset/engine/ir/datasetops/batch_node.h"
  30. #ifndef ENABLE_ANDROID
  31. #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
  32. #include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
  33. #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
  34. #include "minddata/dataset/engine/ir/datasetops/concat_node.h"
  35. #include "minddata/dataset/engine/ir/datasetops/filter_node.h"
  36. #endif
  37. #include "minddata/dataset/engine/ir/datasetops/map_node.h"
  38. #include "minddata/dataset/engine/ir/datasetops/project_node.h"
  39. #ifndef ENABLE_ANDROID
  40. #include "minddata/dataset/engine/ir/datasetops/rename_node.h"
  41. #endif
  42. #include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
  43. #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
  44. #ifndef ENABLE_ANDROID
  45. #include "minddata/dataset/engine/ir/datasetops/skip_node.h"
  46. #include "minddata/dataset/engine/ir/datasetops/take_node.h"
  47. #include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
  48. #include "minddata/dataset/engine/ir/datasetops/zip_node.h"
  49. #endif
  50. #include "minddata/dataset/core/config_manager.h"
  51. #include "minddata/dataset/util/path.h"
  52. #include "minddata/dataset/util/random.h"
  53. #include "minddata/dataset/util/services.h"
  54. // IR leaf nodes
  55. #include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
  56. // IR leaf nodes disabled for android
  57. #ifndef ENABLE_ANDROID
  58. #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
  59. #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
  60. #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
  61. #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
  62. #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
  63. #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
  64. #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
  65. #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
  66. #include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
  67. #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
  68. #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
  69. #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
  70. #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
  71. #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
  72. #endif
  73. namespace mindspore {
  74. namespace dataset {
  75. // Function to create the iterator, which will build and launch the execution tree.
  76. std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns) {
  77. std::shared_ptr<Iterator> iter;
  78. try {
  79. auto ds = shared_from_this();
  80. // The specified columns will be selected from the dataset and passed down the pipeline
  81. // in the order specified, other columns will be discarded.
  82. if (!columns.empty()) {
  83. ds = ds->Project(columns);
  84. }
  85. iter = std::make_shared<Iterator>();
  86. Status rc = iter->BuildAndLaunchTree(ds);
  87. if (rc.IsError()) {
  88. MS_LOG(ERROR) << "CreateIterator failed." << rc;
  89. return nullptr;
  90. }
  91. return iter;
  92. } catch (const std::exception &err) {
  93. MS_LOG(ERROR) << "CreateIterator: Iterator exception caught: " << err.what();
  94. return nullptr;
  95. }
  96. return iter;
  97. }
  98. #ifndef ENABLE_ANDROID
  99. // Function to return a transferred Node that transfers data through a device.
  100. bool Dataset::DeviceQueue(bool send_epoch_end) {
  101. Status rc;
  102. // Build and launch tree
  103. std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
  104. rc = runtime_context->Init();
  105. if (rc.IsError()) {
  106. MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
  107. return false;
  108. }
  109. // Add TransferNode IR on top of dataset d
  110. auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), send_epoch_end);
  111. // Get ToDevice consumer
  112. auto consumer = std::make_unique<ToDevice>(send_epoch_end, -1);
  113. ToDevice *consumer_ = consumer.get();
  114. rc = consumer->Init(ds);
  115. if (rc.IsError()) {
  116. MS_LOG(ERROR) << "ToDevice: Failed to init. Error status: " << rc;
  117. return false;
  118. }
  119. runtime_context->AssignConsumer(std::move(consumer));
  120. // Send data to device
  121. rc = consumer_->Send();
  122. if (rc.IsError()) {
  123. MS_LOG(ERROR) << "ToDevice: Failed to send data to device. Error status: " << rc;
  124. return false;
  125. }
  126. return true;
  127. }
  128. // Function to create the saver, which will build and launch the execution tree and save data
  129. bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string dataset_type) {
  130. Status rc;
  131. // Build and launch tree
  132. auto ds = shared_from_this();
  133. std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
  134. rc = runtime_context->Init();
  135. if (rc.IsError()) {
  136. MS_LOG(ERROR) << "CreateSaver failed." << rc;
  137. return false;
  138. }
  139. // Get SaveToDisk consumer
  140. auto consumer = std::make_unique<SaveToDisk>(dataset_path, num_files, dataset_type);
  141. rc = consumer->ValidateParams();
  142. if (rc.IsError()) {
  143. MS_LOG(ERROR) << "CreateSaver failed." << rc;
  144. return false;
  145. }
  146. SaveToDisk *consumer_ = consumer.get();
  147. rc = consumer->Init(ds->IRNode());
  148. if (rc.IsError()) {
  149. MS_LOG(ERROR) << "CreateSaver failed." << rc;
  150. return false;
  151. }
  152. runtime_context->AssignConsumer(std::move(consumer));
  153. // Save data into file
  154. rc = consumer_->Save();
  155. if (rc.IsError()) {
  156. MS_LOG(ERROR) << "Saver: Failed to save data into file. Error status: " << rc;
  157. return false;
  158. }
  159. // Shut down the data pipeline
  160. rc = runtime_context->Terminate();
  161. if (rc.IsError()) {
  162. MS_LOG(ERROR) << "Saver: Failed to shut down pipeline. Error status: " << rc;
  163. return false;
  164. }
  165. return true;
  166. }
  167. #endif
  168. // Constructor
  169. Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
  170. int64_t Dataset::GetDatasetSize() {
  171. int64_t dataset_size;
  172. Status rc;
  173. std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
  174. rc = runtime_context->Init();
  175. if (rc.IsError()) {
  176. MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
  177. return -1;
  178. }
  179. rc = tree_getters_->Init(this->IRNode());
  180. if (rc.IsError()) {
  181. MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
  182. return -1;
  183. }
  184. rc = tree_getters_->GetDatasetSize(&dataset_size);
  185. return rc.IsError() ? -1 : dataset_size;
  186. }
  187. std::vector<DataType> Dataset::GetOutputTypes() {
  188. std::vector<DataType> types;
  189. Status rc;
  190. std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
  191. rc = runtime_context->Init();
  192. if (rc.IsError()) {
  193. MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
  194. return types;
  195. }
  196. rc = tree_getters_->Init(this->IRNode());
  197. if (rc.IsError()) {
  198. MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
  199. return types;
  200. }
  201. rc = tree_getters_->GetOutputTypes(&types);
  202. if (rc.IsError()) {
  203. MS_LOG(ERROR) << "GetOutputTypes: Get Output Types failed.";
  204. types.clear();
  205. return types;
  206. }
  207. return types;
  208. }
  209. std::vector<TensorShape> Dataset::GetOutputShapes() {
  210. std::vector<TensorShape> shapes;
  211. Status rc;
  212. std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
  213. rc = runtime_context->Init();
  214. if (rc.IsError()) {
  215. MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
  216. return shapes;
  217. }
  218. rc = tree_getters_->Init(this->IRNode());
  219. if (rc.IsError()) {
  220. MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
  221. return shapes;
  222. }
  223. rc = tree_getters_->GetOutputShapes(&shapes);
  224. if (rc.IsError()) {
  225. MS_LOG(ERROR) << "GetOutputShapes: Get Output Shapes failed.";
  226. shapes.clear();
  227. return shapes;
  228. }
  229. return shapes;
  230. }
  231. int64_t Dataset::GetNumClasses() {
  232. int64_t num_classes;
  233. auto ds = shared_from_this();
  234. Status rc;
  235. std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
  236. rc = runtime_context->Init();
  237. if (rc.IsError()) {
  238. MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
  239. return -1;
  240. }
  241. rc = tree_getters_->Init(ds->IRNode());
  242. if (rc.IsError()) {
  243. MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed.";
  244. return -1;
  245. }
  246. rc = tree_getters_->GetNumClasses(&num_classes);
  247. return rc.IsError() ? -1 : num_classes;
  248. }
  249. std::vector<std::string> Dataset::GetColumnNames() {
  250. std::vector<std::string> col_names;
  251. auto ds = shared_from_this();
  252. Status rc;
  253. std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
  254. rc = runtime_context->Init();
  255. if (rc.IsError()) {
  256. MS_LOG(ERROR) << "GetColumnNames: Initializing RuntimeContext failed.";
  257. return std::vector<std::string>();
  258. }
  259. rc = tree_getters_->Init(ds->IRNode());
  260. if (rc.IsError()) {
  261. MS_LOG(ERROR) << "GetColumnNames: Initializing TreeGetters failed.";
  262. return std::vector<std::string>();
  263. }
  264. rc = tree_getters_->GetColumnNames(&col_names);
  265. return rc.IsError() ? std::vector<std::string>() : col_names;
  266. }
  267. std::vector<std::pair<std::string, std::vector<int32_t>>> Dataset::GetClassIndexing() {
  268. std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
  269. auto ds = shared_from_this();
  270. Status rc;
  271. std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
  272. rc = runtime_context->Init();
  273. if (rc.IsError()) {
  274. MS_LOG(ERROR) << "GetClassIndexing: Initializing RuntimeContext failed.";
  275. return output_class_indexing;
  276. }
  277. rc = tree_getters_->Init(ds->IRNode());
  278. if (rc.IsError()) {
  279. MS_LOG(ERROR) << "GetClassIndexing: Initializing TreeGetters failed.";
  280. return output_class_indexing;
  281. }
  282. rc = tree_getters_->GetClassIndexing(&output_class_indexing);
  283. if (rc.IsError()) {
  284. MS_LOG(ERROR) << "GetClassIndexing: Get Class Index failed.";
  285. output_class_indexing.clear();
  286. return output_class_indexing;
  287. }
  288. return output_class_indexing;
  289. }
  290. /// \brief Function to create a SchemaObj
  291. /// \param[in] schema_file Path of schema file
  292. /// \return Shared pointer to the current schema
  293. std::shared_ptr<SchemaObj> Schema(const std::string &schema_file) {
  294. auto schema = std::make_shared<SchemaObj>(schema_file);
  295. return schema->init() ? schema : nullptr;
  296. }
  297. // FUNCTIONS TO CREATE DATASETS FOR LEAF CLASSES
  298. // (In alphabetical order)
  299. // Function to create a AlbumDataset.
  300. std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::string &data_schema,
  301. const std::vector<std::string> &column_names, bool decode,
  302. const std::shared_ptr<SamplerObj> &sampler,
  303. const std::shared_ptr<DatasetCache> &cache) {
  304. auto ds = std::make_shared<AlbumDataset>(dataset_dir, data_schema, column_names, decode, sampler, cache);
  305. return ds;
  306. }
  307. #ifndef ENABLE_ANDROID
  308. // Function to create a CelebADataset.
  309. std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage,
  310. const std::shared_ptr<SamplerObj> &sampler, bool decode,
  311. const std::set<std::string> &extensions,
  312. const std::shared_ptr<DatasetCache> &cache) {
  313. auto ds = std::make_shared<CelebADataset>(dataset_dir, usage, sampler, decode, extensions, cache);
  314. return ds;
  315. }
  316. // Function to create a Cifar10Dataset.
  317. std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage,
  318. const std::shared_ptr<SamplerObj> &sampler,
  319. const std::shared_ptr<DatasetCache> &cache) {
  320. auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, usage, sampler, cache);
  321. return ds;
  322. }
  323. // Function to create a Cifar100Dataset.
  324. std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage,
  325. const std::shared_ptr<SamplerObj> &sampler,
  326. const std::shared_ptr<DatasetCache> &cache) {
  327. auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, usage, sampler, cache);
  328. return ds;
  329. }
  330. // Function to create a CLUEDataset.
  331. std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &clue_files, const std::string &task,
  332. const std::string &usage, int64_t num_samples, ShuffleMode shuffle,
  333. int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
  334. auto ds = std::make_shared<CLUEDataset>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache);
  335. return ds;
  336. }
  337. // Function to create a CocoDataset.
  338. std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
  339. const std::string &task, const bool &decode,
  340. const std::shared_ptr<SamplerObj> &sampler,
  341. const std::shared_ptr<DatasetCache> &cache) {
  342. auto ds = std::make_shared<CocoDataset>(dataset_dir, annotation_file, task, decode, sampler, cache);
  343. return ds;
  344. }
  345. // Function to create a CSVDataset.
  346. std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_files, char field_delim,
  347. const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
  348. const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
  349. int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
  350. auto ds = std::make_shared<CSVDataset>(dataset_files, field_delim, column_defaults, column_names, num_samples,
  351. shuffle, num_shards, shard_id, cache);
  352. return ds;
  353. }
  354. // Function to create a ImageFolderDataset.
  355. std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode,
  356. const std::shared_ptr<SamplerObj> &sampler,
  357. const std::set<std::string> &extensions,
  358. const std::map<std::string, int32_t> &class_indexing,
  359. const std::shared_ptr<DatasetCache> &cache) {
  360. auto ds = std::make_shared<ImageFolderDataset>(dataset_dir, decode, sampler, extensions, class_indexing, cache);
  361. return ds;
  362. }
  363. // Function to create a ManifestDataset.
  364. std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const std::string &usage,
  365. const std::shared_ptr<SamplerObj> &sampler,
  366. const std::map<std::string, int32_t> &class_indexing, bool decode,
  367. const std::shared_ptr<DatasetCache> &cache) {
  368. auto ds = std::make_shared<ManifestDataset>(dataset_file, usage, sampler, class_indexing, decode, cache);
  369. return ds;
  370. }
  371. // Function to create a MindDataDataset.
  372. std::shared_ptr<MindDataDataset> MindData(const std::string &dataset_file, const std::vector<std::string> &columns_list,
  373. const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
  374. int64_t num_padded) {
  375. auto ds = std::make_shared<MindDataDataset>(dataset_file, columns_list, sampler, padded_sample, num_padded);
  376. return ds;
  377. }
  378. // Function to create a MindDataDataset.
  379. std::shared_ptr<MindDataDataset> MindData(const std::vector<std::string> &dataset_files,
  380. const std::vector<std::string> &columns_list,
  381. const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
  382. int64_t num_padded) {
  383. auto ds = std::make_shared<MindDataDataset>(dataset_files, columns_list, sampler, padded_sample, num_padded);
  384. return ds;
  385. }
  386. // Function to create a MnistDataset.
  387. std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage,
  388. const std::shared_ptr<SamplerObj> &sampler,
  389. const std::shared_ptr<DatasetCache> &cache) {
  390. auto ds = std::make_shared<MnistDataset>(dataset_dir, usage, sampler, cache);
  391. return ds;
  392. }
  393. // Function to overload "+" operator to concat two datasets
  394. std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
  395. const std::shared_ptr<Dataset> &datasets2) {
  396. return std::make_shared<ConcatDataset>(std::vector({datasets2, datasets1}));
  397. }
  398. // Function to create a TextFileDataset.
  399. std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int64_t num_samples,
  400. ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
  401. const std::shared_ptr<DatasetCache> &cache) {
  402. auto ds = std::make_shared<TextFileDataset>(dataset_files, num_samples, shuffle, num_shards, shard_id, cache);
  403. return ds;
  404. }
  405. // Function to create a VOCDataset.
  406. std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &usage,
  407. const std::map<std::string, int32_t> &class_indexing, bool decode,
  408. const std::shared_ptr<SamplerObj> &sampler,
  409. const std::shared_ptr<DatasetCache> &cache) {
  410. auto ds = std::make_shared<VOCDataset>(dataset_dir, task, usage, class_indexing, decode, sampler, cache);
  411. return ds;
  412. }
  413. // Function to create a ZipDatset.
  414. std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
  415. auto ds = std::make_shared<ZipDataset>(datasets);
  416. return ds;
  417. }
  418. #endif
  419. // FUNCTIONS TO CREATE DATASETS FOR DATASET OPS
  420. // (In alphabetical order)
  421. // Function to create a Batch dataset
  422. BatchDataset::BatchDataset(std::shared_ptr<Dataset> input, int32_t batch_size, bool drop_remainder) {
  423. // Default values
  424. auto ds = std::make_shared<BatchNode>(input->IRNode(), batch_size, drop_remainder);
  425. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  426. }
  427. #ifndef ENABLE_ANDROID
  428. // Function to create a BucketBatchByLength dataset
  429. BucketBatchByLengthDataset::BucketBatchByLengthDataset(
  430. std::shared_ptr<Dataset> input, const std::vector<std::string> &column_names,
  431. const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
  432. std::function<TensorRow(TensorRow)> element_length_function,
  433. const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
  434. bool drop_remainder) {
  435. auto ds = std::make_shared<BucketBatchByLengthNode>(input->IRNode(), column_names, bucket_boundaries,
  436. bucket_batch_sizes, element_length_function, pad_info,
  437. pad_to_bucket_boundary, drop_remainder);
  438. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  439. }
  440. ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
  441. std::vector<std::shared_ptr<DatasetNode>> all_datasets;
  442. (void)std::transform(datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
  443. [](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> {
  444. return (dataset != nullptr) ? dataset->IRNode() : nullptr;
  445. });
  446. auto ds = std::make_shared<ConcatNode>(all_datasets);
  447. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  448. }
  449. FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<TensorRow(TensorRow)> predicate,
  450. std::vector<std::string> input_columns) {
  451. auto ds = std::make_shared<FilterNode>(input->IRNode(), predicate, input_columns);
  452. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  453. }
  454. #endif
  455. MapDataset::MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
  456. std::vector<std::string> input_columns, std::vector<std::string> output_columns,
  457. const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache,
  458. std::vector<std::shared_ptr<DSCallback>> callbacks) {
  459. auto ds = std::make_shared<MapNode>(input->IRNode(), operations, input_columns, output_columns, project_columns,
  460. cache, callbacks);
  461. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  462. }
  463. ProjectDataset::ProjectDataset(std::shared_ptr<Dataset> input, const std::vector<std::string> &columns) {
  464. auto ds = std::make_shared<ProjectNode>(input->IRNode(), columns);
  465. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  466. }
  467. #ifndef ENABLE_ANDROID
  468. RenameDataset::RenameDataset(std::shared_ptr<Dataset> input, const std::vector<std::string> &input_columns,
  469. const std::vector<std::string> &output_columns) {
  470. auto ds = std::make_shared<RenameNode>(input->IRNode(), input_columns, output_columns);
  471. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  472. }
  473. #endif
  474. RepeatDataset::RepeatDataset(std::shared_ptr<Dataset> input, int32_t count) {
  475. // Workaround for repeat == 1, do not inject repeat.
  476. if (count == 1) {
  477. ir_node_ = input->IRNode();
  478. return;
  479. }
  480. auto ds = std::make_shared<RepeatNode>(input->IRNode(), count);
  481. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  482. }
  483. ShuffleDataset::ShuffleDataset(std::shared_ptr<Dataset> input, int32_t buffer_size) {
  484. // Pass in reshuffle_each_epoch with true
  485. auto ds = std::make_shared<ShuffleNode>(input->IRNode(), buffer_size, true);
  486. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  487. }
  488. #ifndef ENABLE_ANDROID
  489. SkipDataset::SkipDataset(std::shared_ptr<Dataset> input, int32_t count) {
  490. auto ds = std::make_shared<SkipNode>(input->IRNode(), count);
  491. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  492. }
  493. TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) {
  494. // If count is greater than the number of element in dataset or equal to -1,
  495. // all the element in dataset will be taken
  496. if (count == -1) {
  497. ir_node_ = input->IRNode();
  498. return;
  499. }
  500. auto ds = std::make_shared<TakeNode>(input->IRNode(), count);
  501. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  502. }
  503. ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
  504. std::vector<std::shared_ptr<DatasetNode>> all_datasets;
  505. (void)std::transform(
  506. datasets.begin(), datasets.end(), std::back_inserter(all_datasets),
  507. [](std::shared_ptr<Dataset> dataset) -> std::shared_ptr<DatasetNode> { return dataset->IRNode(); });
  508. auto ds = std::make_shared<ZipNode>(all_datasets);
  509. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  510. }
  511. #endif
  512. int64_t Dataset::GetBatchSize() {
  513. int64_t batch_size;
  514. auto ds = shared_from_this();
  515. Status rc;
  516. std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
  517. rc = runtime_context->Init();
  518. if (rc.IsError()) {
  519. MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
  520. return -1;
  521. }
  522. rc = tree_getters_->Init(ds->IRNode());
  523. if (rc.IsError()) {
  524. MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
  525. return -1;
  526. }
  527. rc = tree_getters_->GetBatchSize(&batch_size);
  528. return rc.IsError() ? -1 : batch_size;
  529. }
  530. int64_t Dataset::GetRepeatCount() {
  531. int64_t repeat_count;
  532. auto ds = shared_from_this();
  533. Status rc;
  534. std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
  535. rc = runtime_context->Init();
  536. if (rc.IsError()) {
  537. MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
  538. return -1;
  539. }
  540. rc = tree_getters_->Init(ds->IRNode());
  541. if (rc.IsError()) {
  542. MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
  543. return -1;
  544. }
  545. rc = tree_getters_->GetRepeatCount(&repeat_count);
  546. return rc.IsError() ? 0 : repeat_count;
  547. }
  548. std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) {
  549. if (ir_node_ == nullptr || ir_node_->SetNumWorkers(num_workers) == nullptr) {
  550. return nullptr;
  551. }
  552. return shared_from_this();
  553. }
  554. #ifndef ENABLE_ANDROID
  555. std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
  556. const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
  557. SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params) {
  558. auto vocab = std::make_shared<SentencePieceVocab>();
  559. auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
  560. model_type, params);
  561. std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
  562. Status rc = runtime_context->Init();
  563. if (rc.IsError()) {
  564. MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init runtime context. Error status: " << rc;
  565. return nullptr;
  566. }
  567. auto consumer = std::make_unique<BuildVocabConsumer>();
  568. BuildVocabConsumer *bv_consumer = consumer.get();
  569. rc = consumer->Init(ds);
  570. if (rc.IsError()) {
  571. MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to init consumer. Error status: " << rc;
  572. return nullptr;
  573. }
  574. runtime_context->AssignConsumer(std::move(consumer));
  575. // Run tree here to starting building SentencePieceVocab
  576. rc = bv_consumer->Start();
  577. if (rc.IsError()) {
  578. MS_LOG(ERROR) << "BuildSentencePieceVocab: Failed to start consumer. Error status: " << rc;
  579. return nullptr;
  580. }
  581. return vocab;
  582. }
  583. std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &columns,
  584. const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
  585. const std::vector<std::string> &special_tokens, bool special_first) {
  586. auto vocab = std::make_shared<Vocab>();
  587. auto ds =
  588. std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);
  589. std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
  590. Status rc = runtime_context->Init();
  591. if (rc.IsError()) {
  592. MS_LOG(ERROR) << "BuildVocab: Failed to init runtime context. Error status: " << rc;
  593. return nullptr;
  594. }
  595. auto consumer = std::make_unique<BuildVocabConsumer>();
  596. BuildVocabConsumer *bv_consumer = consumer.get();
  597. rc = consumer->Init(ds);
  598. if (rc.IsError()) {
  599. MS_LOG(ERROR) << "BuildVocab: Failed to init consumer. Error status: " << rc;
  600. return nullptr;
  601. }
  602. runtime_context->AssignConsumer(std::move(consumer));
  603. // Run tree here to starting building vocab
  604. rc = bv_consumer->Start();
  605. if (rc.IsError()) {
  606. MS_LOG(ERROR) << "BuildVocab: Failed to start consumer. Error status: " << rc;
  607. return nullptr;
  608. }
  609. return vocab;
  610. }
  611. #endif
  612. std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
  613. return std::make_shared<BatchDataset>(shared_from_this(), batch_size, drop_remainder);
  614. }
  615. SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}
  616. // SchemaObj init function
  617. bool SchemaObj::init() {
  618. if (schema_file_ != "") {
  619. Path schema_file(schema_file_);
  620. if (!schema_file.Exists()) {
  621. MS_LOG(ERROR) << "The file " << schema_file << " does not exist or permission denied!";
  622. return false;
  623. }
  624. nlohmann::json js;
  625. try {
  626. std::ifstream in(schema_file_);
  627. in >> js;
  628. if (js.find("columns") == js.end()) {
  629. MS_LOG(ERROR) << "\"columns\" node is required in the schema json file.";
  630. return false;
  631. }
  632. } catch (const std::exception &err) {
  633. MS_LOG(ERROR) << "Schema file failed to load";
  634. return false;
  635. }
  636. return from_json(js);
  637. }
  638. return true;
  639. }
  640. // Function to add a column to schema with a mstype de_type
  641. bool SchemaObj::add_column(std::string name, TypeId de_type, std::vector<int32_t> shape) {
  642. nlohmann::json new_column;
  643. new_column["name"] = name;
  644. // if de_type is mstype
  645. DataType data_type = dataset::MSTypeToDEType(de_type);
  646. new_column["type"] = data_type.ToString();
  647. if (shape.size() > 0) {
  648. new_column["shape"] = shape;
  649. new_column["rank"] = shape.size();
  650. } else {
  651. new_column["rank"] = 1;
  652. }
  653. columns_.push_back(new_column);
  654. return true;
  655. }
  656. // Function to add a column to schema with a string de_type
  657. bool SchemaObj::add_column(std::string name, std::string de_type, std::vector<int32_t> shape) {
  658. nlohmann::json new_column;
  659. new_column["name"] = name;
  660. DataType data_type(de_type);
  661. new_column["type"] = data_type.ToString();
  662. if (shape.size() > 0) {
  663. new_column["shape"] = shape;
  664. new_column["rank"] = shape.size();
  665. } else {
  666. new_column["rank"] = 1;
  667. }
  668. columns_.push_back(new_column);
  669. return true;
  670. }
  671. std::string SchemaObj::to_json() {
  672. nlohmann::json json_file;
  673. json_file["columns"] = columns_;
  674. if (dataset_type_ != "") {
  675. json_file["datasetType"] = dataset_type_;
  676. }
  677. if (num_rows_ > 0) {
  678. json_file["numRows"] = num_rows_;
  679. }
  680. return json_file.dump(2);
  681. }
  682. bool SchemaObj::parse_column(nlohmann::json columns) {
  683. std::string name, de_type;
  684. std::vector<int32_t> shape;
  685. columns_.clear();
  686. if (columns.type() == nlohmann::json::value_t::array) {
  687. // reference to python list
  688. for (auto column : columns) {
  689. auto key_name = column.find("name");
  690. if (key_name == column.end()) {
  691. MS_LOG(ERROR) << "Column's name is missing";
  692. return false;
  693. }
  694. name = *key_name;
  695. auto key_type = column.find("type");
  696. if (key_type == column.end()) {
  697. MS_LOG(ERROR) << "Column's type is missing";
  698. return false;
  699. }
  700. de_type = *key_type;
  701. shape.clear();
  702. auto key_shape = column.find("shape");
  703. if (key_shape != column.end()) {
  704. shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
  705. }
  706. if (!add_column(name, de_type, shape)) {
  707. return false;
  708. }
  709. }
  710. } else if (columns.type() == nlohmann::json::value_t::object) {
  711. for (const auto &it_child : columns.items()) {
  712. name = it_child.key();
  713. auto key_type = it_child.value().find("type");
  714. if (key_type == it_child.value().end()) {
  715. MS_LOG(ERROR) << "Column's type is missing";
  716. return false;
  717. }
  718. de_type = *key_type;
  719. shape.clear();
  720. auto key_shape = it_child.value().find("shape");
  721. if (key_shape != it_child.value().end()) {
  722. shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
  723. }
  724. if (!add_column(name, de_type, shape)) {
  725. return false;
  726. }
  727. }
  728. } else {
  729. MS_LOG(ERROR) << "columns must be dict or list, columns contain name, type, shape(optional).";
  730. return false;
  731. }
  732. return true;
  733. }
  734. bool SchemaObj::from_json(nlohmann::json json_obj) {
  735. for (const auto &it_child : json_obj.items()) {
  736. if (it_child.key() == "datasetType") {
  737. dataset_type_ = it_child.value();
  738. } else if (it_child.key() == "numRows") {
  739. num_rows_ = it_child.value();
  740. } else if (it_child.key() == "columns") {
  741. if (!parse_column(it_child.value())) {
  742. MS_LOG(ERROR) << "parse columns failed";
  743. return false;
  744. }
  745. } else {
  746. MS_LOG(ERROR) << "Unknown field " << it_child.key();
  747. return false;
  748. }
  749. }
  750. if (columns_.empty()) {
  751. MS_LOG(ERROR) << "Columns are missing.";
  752. return false;
  753. }
  754. if (num_rows_ <= 0) {
  755. MS_LOG(ERROR) << "numRows must be greater than 0";
  756. return false;
  757. }
  758. return true;
  759. }
  760. // OTHER FUNCTIONS
  761. #ifndef ENABLE_ANDROID
  762. std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t mem_sz, bool spill,
  763. std::optional<std::string> hostname, std::optional<int32_t> port,
  764. std::optional<int32_t> num_connections,
  765. std::optional<int32_t> prefetch_sz) {
  766. auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz);
  767. return cache;
  768. }
  769. #endif
  770. AlbumDataset::AlbumDataset(const std::string &dataset_dir, const std::string &data_schema,
  771. const std::vector<std::string> &column_names, bool decode,
  772. const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
  773. auto ds = std::make_shared<AlbumNode>(dataset_dir, data_schema, column_names, decode, sampler, cache);
  774. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  775. }
  776. #ifndef ENABLE_ANDROID
  777. CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &usage,
  778. const std::shared_ptr<SamplerObj> &sampler, bool decode,
  779. const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache) {
  780. auto ds = std::make_shared<CelebANode>(dataset_dir, usage, sampler, decode, extensions, cache);
  781. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  782. }
  783. Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, const std::string &usage,
  784. const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
  785. auto ds = std::make_shared<Cifar10Node>(dataset_dir, usage, sampler, cache);
  786. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  787. }
  788. Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, const std::string &usage,
  789. const std::shared_ptr<SamplerObj> &sampler,
  790. const std::shared_ptr<DatasetCache> &cache) {
  791. auto ds = std::make_shared<Cifar100Node>(dataset_dir, usage, sampler, cache);
  792. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  793. }
  794. CLUEDataset::CLUEDataset(const std::vector<std::string> &dataset_files, const std::string &task,
  795. const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards,
  796. int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
  797. auto ds = std::make_shared<CLUENode>(dataset_files, task, usage, num_samples, shuffle, num_shards, shard_id, cache);
  798. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  799. }
  800. CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
  801. const bool &decode, const std::shared_ptr<SamplerObj> &sampler,
  802. const std::shared_ptr<DatasetCache> &cache) {
  803. auto ds = std::make_shared<CocoNode>(dataset_dir, annotation_file, task, decode, sampler, cache);
  804. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  805. }
  806. CSVDataset::CSVDataset(const std::vector<std::string> &dataset_files, char field_delim,
  807. const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
  808. const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
  809. int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) {
  810. auto ds = std::make_shared<CSVNode>(dataset_files, field_delim, column_defaults, column_names, num_samples, shuffle,
  811. num_shards, shard_id, cache);
  812. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  813. }
  814. ImageFolderDataset::ImageFolderDataset(const std::string &dataset_dir, bool decode,
  815. const std::shared_ptr<SamplerObj> &sampler,
  816. const std::set<std::string> &extensions,
  817. const std::map<std::string, int32_t> &class_indexing,
  818. const std::shared_ptr<DatasetCache> &cache) {
  819. // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
  820. bool recursive = false;
  821. // Create logical representation of ImageFolderDataset.
  822. auto ds =
  823. std::make_shared<ImageFolderNode>(dataset_dir, decode, sampler, recursive, extensions, class_indexing, cache);
  824. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  825. }
  826. ManifestDataset::ManifestDataset(const std::string &dataset_file, const std::string &usage,
  827. const std::shared_ptr<SamplerObj> &sampler,
  828. const std::map<std::string, int32_t> &class_indexing, bool decode,
  829. const std::shared_ptr<DatasetCache> &cache) {
  830. auto ds = std::make_shared<ManifestNode>(dataset_file, usage, sampler, class_indexing, decode, cache);
  831. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  832. }
  833. MindDataDataset::MindDataDataset(const std::string &dataset_file, const std::vector<std::string> &columns_list,
  834. const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
  835. int64_t num_padded) {
  836. auto ds = std::make_shared<MindDataNode>(dataset_file, columns_list, sampler, padded_sample, num_padded);
  837. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  838. }
  839. MindDataDataset::MindDataDataset(const std::vector<std::string> &dataset_files,
  840. const std::vector<std::string> &columns_list,
  841. const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
  842. int64_t num_padded) {
  843. auto ds = std::make_shared<MindDataNode>(dataset_files, columns_list, sampler, padded_sample, num_padded);
  844. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  845. }
  846. MnistDataset::MnistDataset(const std::string &dataset_dir, const std::string &usage,
  847. const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
  848. auto ds = std::make_shared<MnistNode>(dataset_dir, usage, sampler, cache);
  849. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  850. }
  851. TextFileDataset::TextFileDataset(const std::vector<std::string> &dataset_files, int64_t num_samples,
  852. ShuffleMode shuffle, int32_t num_shards, int32_t shard_id,
  853. const std::shared_ptr<DatasetCache> &cache) {
  854. auto ds = std::make_shared<TextFileNode>(dataset_files, num_samples, shuffle, num_shards, shard_id, cache);
  855. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  856. }
  857. VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &usage,
  858. const std::map<std::string, int32_t> &class_indexing, bool decode,
  859. const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache) {
  860. auto ds = std::make_shared<VOCNode>(dataset_dir, task, usage, class_indexing, decode, sampler, cache);
  861. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  862. }
  863. RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
  864. const std::vector<std::string> &columns_list,
  865. std::shared_ptr<DatasetCache> cache) {
  866. auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema), std::move(columns_list), cache);
  867. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  868. }
  869. RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::string schema_path,
  870. const std::vector<std::string> &columns_list,
  871. std::shared_ptr<DatasetCache> cache) {
  872. auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema_path), std::move(columns_list), cache);
  873. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  874. }
  875. TFRecordDataset::TFRecordDataset(const std::vector<std::string> &dataset_files, std::string schema,
  876. const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
  877. int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
  878. std::shared_ptr<DatasetCache> cache) {
  879. auto ds = std::make_shared<TFRecordNode>(dataset_files, schema, columns_list, num_samples, shuffle, num_shards,
  880. shard_id, shard_equal_rows, cache);
  881. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  882. }
  883. TFRecordDataset::TFRecordDataset(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
  884. const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
  885. int32_t num_shards, int32_t shard_id, bool shard_equal_rows,
  886. std::shared_ptr<DatasetCache> cache) {
  887. auto ds = std::make_shared<TFRecordNode>(dataset_files, schema, columns_list, num_samples, shuffle, num_shards,
  888. shard_id, shard_equal_rows, cache);
  889. ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
  890. }
  891. #endif
  892. } // namespace dataset
  893. } // namespace mindspore