You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

datasets.cc 50 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <fstream>
  17. #include "minddata/dataset/include/datasets.h"
  18. #include "minddata/dataset/include/samplers.h"
  19. #include "minddata/dataset/include/transforms.h"
  20. #include "minddata/dataset/engine/dataset_iterator.h"
  21. // Source dataset headers (in alphabetical order)
  22. #include "minddata/dataset/engine/datasetops/source/celeba_op.h"
  23. #include "minddata/dataset/engine/datasetops/source/cifar_op.h"
  24. #include "minddata/dataset/engine/datasetops/source/clue_op.h"
  25. #include "minddata/dataset/engine/datasetops/source/coco_op.h"
  26. #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
  27. #include "minddata/dataset/engine/datasetops/source/mnist_op.h"
  28. #include "minddata/dataset/engine/datasetops/source/text_file_op.h"
  29. #include "minddata/dataset/engine/datasetops/source/voc_op.h"
  30. // Dataset operator headers (in alphabetical order)
  31. #include "minddata/dataset/engine/datasetops/batch_op.h"
  32. #include "minddata/dataset/engine/datasetops/concat_op.h"
  33. #include "minddata/dataset/engine/datasetops/map_op/map_op.h"
  34. #include "minddata/dataset/engine/datasetops/project_op.h"
  35. #include "minddata/dataset/engine/datasetops/rename_op.h"
  36. #include "minddata/dataset/engine/datasetops/repeat_op.h"
  37. #include "minddata/dataset/engine/datasetops/shuffle_op.h"
  38. #include "minddata/dataset/engine/datasetops/skip_op.h"
  39. #include "minddata/dataset/engine/datasetops/take_op.h"
  40. #include "minddata/dataset/engine/datasetops/zip_op.h"
  41. // Sampler headers (in alphabetical order)
  42. #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
  43. #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
  44. #include "minddata/dataset/core/config_manager.h"
  45. #include "minddata/dataset/util/random.h"
  46. #include "minddata/dataset/util/path.h"
  47. namespace mindspore {
  48. namespace dataset {
  49. namespace api {
// Evaluate `_s`; on error, log the status and bail out of the enclosing function
// with a default-constructed ("empty") return value.  Intended for functions whose
// return type is a container or smart pointer (e.g. the Build() methods returning
// std::vector<std::shared_ptr<DatasetOp>>), where `return {};` yields an empty result.
// NOTE(review): in a function returning Status, `return {};` produces a
// default-constructed Status (presumably OK), not the failing one — confirm before
// using this macro in Status-returning code.
#define RETURN_EMPTY_IF_ERROR(_s) \
do { \
Status __rc = (_s); \
if (__rc.IsError()) { \
MS_LOG(ERROR) << __rc; \
return {}; \
} \
} while (false)
  58. // Function to create the iterator, which will build and launch the execution tree.
  59. std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns) {
  60. std::shared_ptr<Iterator> iter;
  61. try {
  62. auto ds = shared_from_this();
  63. // The specified columns will be selected from the dataset and passed down the pipeline
  64. // in the order specified, other columns will be discarded.
  65. if (!columns.empty()) {
  66. ds = ds->Project(columns);
  67. }
  68. iter = std::make_shared<Iterator>();
  69. Status rc = iter->BuildAndLaunchTree(ds);
  70. if (rc.IsError()) {
  71. MS_LOG(ERROR) << "CreateIterator failed." << rc;
  72. return nullptr;
  73. }
  74. return iter;
  75. } catch (const std::exception &err) {
  76. MS_LOG(ERROR) << "CreateIterator: Iterator exception caught: " << err.what();
  77. return nullptr;
  78. }
  79. return iter;
  80. }
  81. // Constructor
  82. Dataset::Dataset() {
  83. // Fetch some default value from config manager
  84. std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  85. num_workers_ = cfg->num_parallel_workers();
  86. rows_per_buffer_ = cfg->rows_per_buffer();
  87. connector_que_size_ = cfg->op_connector_size();
  88. worker_connector_size_ = cfg->worker_connector_size();
  89. }
  90. // FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
  91. // (In alphabetical order)
  92. // Function to create a CelebADataset.
  93. std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type,
  94. const std::shared_ptr<SamplerObj> &sampler, bool decode,
  95. const std::set<std::string> &extensions) {
  96. auto ds = std::make_shared<CelebADataset>(dataset_dir, dataset_type, sampler, decode, extensions);
  97. // Call derived class validation method.
  98. return ds->ValidateParams() ? ds : nullptr;
  99. }
  100. // Function to create a Cifar10Dataset.
  101. std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
  102. auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, sampler);
  103. // Call derived class validation method.
  104. return ds->ValidateParams() ? ds : nullptr;
  105. }
  106. // Function to create a Cifar100Dataset.
  107. std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
  108. auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, sampler);
  109. // Call derived class validation method.
  110. return ds->ValidateParams() ? ds : nullptr;
  111. }
  112. // Function to create a CLUEDataset.
  113. std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &clue_files, const std::string &task,
  114. const std::string &usage, int64_t num_samples, ShuffleMode shuffle,
  115. int32_t num_shards, int32_t shard_id) {
  116. auto ds = std::make_shared<CLUEDataset>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id);
  117. // Call derived class validation method.
  118. return ds->ValidateParams() ? ds : nullptr;
  119. }
  120. // Function to create a CocoDataset.
  121. std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
  122. const std::string &task, const bool &decode,
  123. const std::shared_ptr<SamplerObj> &sampler) {
  124. auto ds = std::make_shared<CocoDataset>(dataset_dir, annotation_file, task, decode, sampler);
  125. // Call derived class validation method.
  126. return ds->ValidateParams() ? ds : nullptr;
  127. }
  128. // Function to create a ImageFolderDataset.
  129. std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode,
  130. const std::shared_ptr<SamplerObj> &sampler,
  131. const std::set<std::string> &extensions,
  132. const std::map<std::string, int32_t> &class_indexing) {
  133. // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
  134. bool recursive = false;
  135. // Create logical representation of ImageFolderDataset.
  136. auto ds = std::make_shared<ImageFolderDataset>(dataset_dir, decode, sampler, recursive, extensions, class_indexing);
  137. // Call derived class validation method.
  138. return ds->ValidateParams() ? ds : nullptr;
  139. }
  140. // Function to create a MnistDataset.
  141. std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
  142. auto ds = std::make_shared<MnistDataset>(dataset_dir, sampler);
  143. // Call derived class validation method.
  144. return ds->ValidateParams() ? ds : nullptr;
  145. }
  146. // Function to overload "+" operator to concat two datasets
  147. std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
  148. const std::shared_ptr<Dataset> &datasets2) {
  149. std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets1, datasets2}));
  150. // Call derived class validation method.
  151. return ds->ValidateParams() ? ds : nullptr;
  152. }
  153. // Function to create a TextFileDataset.
  154. std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int32_t num_samples,
  155. ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) {
  156. auto ds = std::make_shared<TextFileDataset>(dataset_files, num_samples, shuffle, num_shards, shard_id);
  157. // Call derived class validation method.
  158. return ds->ValidateParams() ? ds : nullptr;
  159. }
  160. // Function to create a VOCDataset.
  161. std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &mode,
  162. const std::map<std::string, int32_t> &class_indexing, bool decode,
  163. const std::shared_ptr<SamplerObj> &sampler) {
  164. auto ds = std::make_shared<VOCDataset>(dataset_dir, task, mode, class_indexing, decode, sampler);
  165. // Call derived class validation method.
  166. return ds->ValidateParams() ? ds : nullptr;
  167. }
  168. // Function to create a ZipDataset.
  169. std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
  170. auto ds = std::make_shared<ZipDataset>(datasets);
  171. // Call derived class validation method.
  172. return ds->ValidateParams() ? ds : nullptr;
  173. }
  174. // FUNCTIONS TO CREATE DATASETS FOR DATASET OPS
  175. // (In alphabetical order)
  176. // Function to create a Batch dataset
  177. std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
  178. // Default values
  179. std::vector<std::string> cols_to_map = {};
  180. std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map;
  181. bool pad = false;
  182. auto ds = std::make_shared<BatchDataset>(batch_size, drop_remainder, pad, cols_to_map, pad_map);
  183. if (!ds->ValidateParams()) {
  184. return nullptr;
  185. }
  186. ds->children.push_back(shared_from_this());
  187. return ds;
  188. }
  189. // Function to create a Concat dataset
  190. std::shared_ptr<ConcatDataset> Dataset::Concat(const std::vector<std::shared_ptr<Dataset>> &datasets) {
  191. auto ds = std::make_shared<ConcatDataset>(datasets);
  192. ds->children.push_back(shared_from_this());
  193. return ds->ValidateParams() ? ds : nullptr;
  194. }
  195. // Function to create a Map dataset.
  196. std::shared_ptr<MapDataset> Dataset::Map(std::vector<std::shared_ptr<TensorOperation>> operations,
  197. std::vector<std::string> input_columns,
  198. std::vector<std::string> output_columns,
  199. const std::vector<std::string> &project_columns) {
  200. auto ds = std::make_shared<MapDataset>(operations, input_columns, output_columns, project_columns);
  201. if (!ds->ValidateParams()) {
  202. return nullptr;
  203. }
  204. ds->children.push_back(shared_from_this());
  205. return ds;
  206. }
  207. // Function to create a ProjectDataset.
  208. std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) {
  209. auto ds = std::make_shared<ProjectDataset>(columns);
  210. // Call derived class validation method.
  211. if (!ds->ValidateParams()) {
  212. return nullptr;
  213. }
  214. ds->children.push_back(shared_from_this());
  215. return ds;
  216. }
  217. // Function to create a RenameDataset.
  218. std::shared_ptr<RenameDataset> Dataset::Rename(const std::vector<std::string> &input_columns,
  219. const std::vector<std::string> &output_columns) {
  220. auto ds = std::make_shared<RenameDataset>(input_columns, output_columns);
  221. // Call derived class validation method.
  222. if (!ds->ValidateParams()) {
  223. return nullptr;
  224. }
  225. ds->children.push_back(shared_from_this());
  226. return ds;
  227. }
  228. // Function to create Repeat dataset.
  229. std::shared_ptr<Dataset> Dataset::Repeat(int32_t count) {
  230. // Workaround for repeat == 1, do not inject repeat.
  231. if (count == 1) {
  232. return shared_from_this();
  233. }
  234. auto ds = std::make_shared<RepeatDataset>(count);
  235. if (!ds->ValidateParams()) {
  236. return nullptr;
  237. }
  238. ds->children.push_back(shared_from_this());
  239. return ds;
  240. }
  241. // Function to create a ShuffleOp
  242. std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t buffer_size) {
  243. // Pass in reshuffle_each_epoch with true
  244. auto ds = std::make_shared<ShuffleDataset>(buffer_size, true);
  245. if (!ds->ValidateParams()) {
  246. return nullptr;
  247. }
  248. ds->children.push_back(shared_from_this());
  249. return ds;
  250. }
  251. // Function to create a SkipDataset.
  252. std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) {
  253. auto ds = std::make_shared<SkipDataset>(count);
  254. // Call derived class validation method.
  255. if (!ds->ValidateParams()) {
  256. return nullptr;
  257. }
  258. ds->children.push_back(shared_from_this());
  259. return ds;
  260. }
  261. // Function to create a TakeDataset.
  262. std::shared_ptr<Dataset> Dataset::Take(int32_t count) {
  263. // If count is greater than the number of element in dataset or equal to -1,
  264. // all the element in dataset will be taken
  265. if (count == -1) {
  266. return shared_from_this();
  267. }
  268. auto ds = std::make_shared<TakeDataset>(count);
  269. // Call derived class validation method.
  270. if (!ds->ValidateParams()) {
  271. return nullptr;
  272. }
  273. ds->children.push_back(shared_from_this());
  274. return ds;
  275. }
  276. // Function to create a Zip dataset
  277. std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
  278. // Default values
  279. auto ds = std::make_shared<ZipDataset>(datasets);
  280. ds->children.push_back(shared_from_this());
  281. return ds->ValidateParams() ? ds : nullptr;
  282. }
  283. // OTHER FUNCTIONS
  284. // Helper function to create default RandomSampler.
  285. std::shared_ptr<SamplerObj> CreateDefaultSampler() {
  286. const int32_t num_samples = 0; // 0 means to sample all ids.
  287. bool replacement = false;
  288. return std::make_shared<RandomSamplerObj>(replacement, num_samples);
  289. }
  290. // Helper function to compute a default shuffle size
  291. Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
  292. int64_t *shuffle_size) {
  293. const int64_t average_files_multiplier = 4;
  294. const int64_t shuffle_max = 10000;
  295. int64_t avg_rows_per_file = 0;
  296. // Adjust the num rows per shard if sharding was given
  297. if (num_devices > 0) {
  298. if (num_rows % num_devices == 0) {
  299. num_rows = num_rows / num_devices;
  300. } else {
  301. num_rows = (num_rows / num_devices) + 1;
  302. }
  303. }
  304. // Cap based on total rows directive. Some ops do not have this and give value of 0.
  305. if (total_rows > 0) {
  306. num_rows = std::min(num_rows, total_rows);
  307. }
  308. // get the average per file
  309. avg_rows_per_file = num_rows / num_files;
  310. *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
  311. return Status::OK();
  312. }
  313. // Helper function to inject a shuffle operator over top of current operator being built
  314. Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
  315. int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op) {
  316. std::shared_ptr<ShuffleOp> new_shuffle_op = nullptr;
  317. int64_t shuffle_size = 0;
  318. RETURN_EMPTY_IF_ERROR(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size));
  319. MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
  320. // Add the shuffle op
  321. *shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size, true, rows_per_buffer);
  322. return Status::OK();
  323. }
  324. // Helper function to validate dataset directory parameter
  325. bool ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) {
  326. if (dataset_dir.empty()) {
  327. MS_LOG(ERROR) << dataset_name << ": dataset_dir is not specified.";
  328. return false;
  329. }
  330. Path dir(dataset_dir);
  331. if (!dir.IsDirectory()) {
  332. MS_LOG(ERROR) << dataset_name << ": dataset_dir: [" << dataset_dir << "] is an invalid directory path.";
  333. return false;
  334. }
  335. if (access(dataset_dir.c_str(), R_OK) == -1) {
  336. MS_LOG(ERROR) << dataset_name << ": No access to specified dataset path: " << dataset_dir;
  337. return false;
  338. }
  339. return true;
  340. }
  341. // Helper function to validate dataset dataset files parameter
  342. bool ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) {
  343. if (dataset_files.empty()) {
  344. MS_LOG(ERROR) << dataset_name << ": dataset_files is not specified.";
  345. return false;
  346. }
  347. for (auto f : dataset_files) {
  348. Path dataset_file(f);
  349. if (!dataset_file.Exists()) {
  350. MS_LOG(ERROR) << dataset_name << ": dataset file: [" << f << "] is invalid or does not exist.";
  351. return false;
  352. }
  353. }
  354. return true;
  355. }
  356. // Helper function to validate dataset num_shards and shard_id parameters
  357. bool ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) {
  358. if (num_shards <= 0) {
  359. MS_LOG(ERROR) << dataset_name << ": Invalid num_shards: " << num_shards;
  360. return false;
  361. }
  362. if (shard_id < 0 || shard_id >= num_shards) {
  363. MS_LOG(ERROR) << dataset_name << ": Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards;
  364. return false;
  365. }
  366. return true;
  367. }
  368. /* ####################################### Derived Dataset classes ################################# */
  369. // DERIVED DATASET CLASSES LEAF-NODE DATASETS
  370. // (In alphabetical order)
// Constructor for CelebADataset: stores the leaf-node parameters; no validation
// happens here (see ValidateParams()).
CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &dataset_type,
                             const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
                             const std::set<std::string> &extensions)
    : dataset_dir_(dataset_dir),
      dataset_type_(dataset_type),
      sampler_(sampler),  // may be nullptr; Build() substitutes a default sampler
      decode_(decode),
      extensions_(extensions) {}
  380. bool CelebADataset::ValidateParams() {
  381. if (!ValidateDatasetDirParam("CelebADataset", dataset_dir_)) {
  382. return false;
  383. }
  384. std::set<std::string> dataset_type_list = {"all", "train", "valid", "test"};
  385. auto iter = dataset_type_list.find(dataset_type_);
  386. if (iter == dataset_type_list.end()) {
  387. MS_LOG(ERROR) << "dataset_type should be one of 'all', 'train', 'valid' or 'test'.";
  388. return false;
  389. }
  390. return true;
  391. }
  392. // Function to build CelebADataset
  393. std::vector<std::shared_ptr<DatasetOp>> CelebADataset::Build() {
  394. // A vector containing shared pointer to the Dataset Ops that this object will create
  395. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  396. // If user does not specify Sampler, create a default sampler based on the shuffle variable.
  397. if (sampler_ == nullptr) {
  398. sampler_ = CreateDefaultSampler();
  399. }
  400. std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  401. RETURN_EMPTY_IF_ERROR(
  402. schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  403. // label is like this:0 1 0 0 1......
  404. RETURN_EMPTY_IF_ERROR(
  405. schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  406. node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
  407. decode_, dataset_type_, extensions_, std::move(schema),
  408. std::move(sampler_->Build())));
  409. return node_ops;
  410. }
  411. // Constructor for Cifar10Dataset
  412. Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
  413. : dataset_dir_(dataset_dir), sampler_(sampler) {}
  414. bool Cifar10Dataset::ValidateParams() { return ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_); }
  415. // Function to build CifarOp for Cifar10
  416. std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
  417. // A vector containing shared pointer to the Dataset Ops that this object will create
  418. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  419. // If user does not specify Sampler, create a default sampler based on the shuffle variable.
  420. if (sampler_ == nullptr) {
  421. sampler_ = CreateDefaultSampler();
  422. }
  423. // Do internal Schema generation.
  424. auto schema = std::make_unique<DataSchema>();
  425. RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  426. TensorShape scalar = TensorShape::CreateScalar();
  427. RETURN_EMPTY_IF_ERROR(
  428. schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  429. node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_,
  430. dataset_dir_, connector_que_size_, std::move(schema),
  431. std::move(sampler_->Build())));
  432. return node_ops;
  433. }
  434. // Constructor for Cifar100Dataset
  435. Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
  436. : dataset_dir_(dataset_dir), sampler_(sampler) {}
  437. bool Cifar100Dataset::ValidateParams() { return ValidateDatasetDirParam("Cifar100Dataset", dataset_dir_); }
  438. // Function to build CifarOp for Cifar100
  439. std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
  440. // A vector containing shared pointer to the Dataset Ops that this object will create
  441. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  442. // If user does not specify Sampler, create a default sampler based on the shuffle variable.
  443. if (sampler_ == nullptr) {
  444. sampler_ = CreateDefaultSampler();
  445. }
  446. // Do internal Schema generation.
  447. auto schema = std::make_unique<DataSchema>();
  448. RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  449. TensorShape scalar = TensorShape::CreateScalar();
  450. RETURN_EMPTY_IF_ERROR(
  451. schema->AddColumn(ColDescriptor("coarse_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  452. RETURN_EMPTY_IF_ERROR(
  453. schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  454. node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, num_workers_, rows_per_buffer_,
  455. dataset_dir_, connector_que_size_, std::move(schema),
  456. std::move(sampler_->Build())));
  457. return node_ops;
  458. }
  459. // Constructor for CLUEDataset
  460. CLUEDataset::CLUEDataset(const std::vector<std::string> clue_files, std::string task, std::string usage,
  461. int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id)
  462. : dataset_files_(clue_files),
  463. task_(task),
  464. usage_(usage),
  465. num_samples_(num_samples),
  466. shuffle_(shuffle),
  467. num_shards_(num_shards),
  468. shard_id_(shard_id) {}
  469. bool CLUEDataset::ValidateParams() {
  470. if (!ValidateDatasetFilesParam("CLUEDataset", dataset_files_)) {
  471. return false;
  472. }
  473. std::vector<std::string> task_list = {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"};
  474. std::vector<std::string> usage_list = {"train", "test", "eval"};
  475. if (find(task_list.begin(), task_list.end(), task_) == task_list.end()) {
  476. MS_LOG(ERROR) << "task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL.";
  477. return false;
  478. }
  479. if (find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) {
  480. MS_LOG(ERROR) << "usage should be train, test or eval.";
  481. return false;
  482. }
  483. if (num_samples_ < 0) {
  484. MS_LOG(ERROR) << "CLUEDataset: Invalid number of samples: " << num_samples_;
  485. return false;
  486. }
  487. if (!ValidateDatasetShardParams("CLUEDataset", num_shards_, shard_id_)) {
  488. return false;
  489. }
  490. return true;
  491. }
  492. // Function to split string based on a character delimiter
  493. std::vector<std::string> CLUEDataset::split(const std::string &s, char delim) {
  494. std::vector<std::string> res;
  495. std::stringstream ss(s);
  496. std::string item;
  497. while (getline(ss, item, delim)) {
  498. res.push_back(item);
  499. }
  500. return res;
  501. }
  502. // Function to build CLUEDataset
  503. std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() {
  504. // A vector containing shared pointer to the Dataset Ops that this object will create
  505. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  506. std::map<std::string, std::string> key_map;
  507. if (task_ == "AFQMC") {
  508. if (usage_ == "train") {
  509. key_map["sentence1"] = "sentence1";
  510. key_map["sentence2"] = "sentence2";
  511. key_map["label"] = "label";
  512. } else if (usage_ == "test") {
  513. key_map["id"] = "id";
  514. key_map["sentence1"] = "sentence1";
  515. key_map["sentence2"] = "sentence2";
  516. } else if (usage_ == "eval") {
  517. key_map["sentence1"] = "sentence1";
  518. key_map["sentence2"] = "sentence2";
  519. key_map["label"] = "label";
  520. }
  521. } else if (task_ == "CMNLI") {
  522. if (usage_ == "train") {
  523. key_map["sentence1"] = "sentence1";
  524. key_map["sentence2"] = "sentence2";
  525. key_map["label"] = "label";
  526. } else if (usage_ == "test") {
  527. key_map["id"] = "id";
  528. key_map["sentence1"] = "sentence1";
  529. key_map["sentence2"] = "sentence2";
  530. } else if (usage_ == "eval") {
  531. key_map["sentence1"] = "sentence1";
  532. key_map["sentence2"] = "sentence2";
  533. key_map["label"] = "label";
  534. }
  535. } else if (task_ == "CSL") {
  536. if (usage_ == "train") {
  537. key_map["id"] = "id";
  538. key_map["abst"] = "abst";
  539. key_map["keyword"] = "keyword";
  540. key_map["label"] = "label";
  541. } else if (usage_ == "test") {
  542. key_map["id"] = "id";
  543. key_map["abst"] = "abst";
  544. key_map["keyword"] = "keyword";
  545. } else if (usage_ == "eval") {
  546. key_map["id"] = "id";
  547. key_map["abst"] = "abst";
  548. key_map["keyword"] = "keyword";
  549. key_map["label"] = "label";
  550. }
  551. } else if (task_ == "IFLYTEK") {
  552. if (usage_ == "train") {
  553. key_map["label"] = "label";
  554. key_map["label_des"] = "label_des";
  555. key_map["sentence"] = "sentence";
  556. } else if (usage_ == "test") {
  557. key_map["id"] = "id";
  558. key_map["sentence"] = "sentence";
  559. } else if (usage_ == "eval") {
  560. key_map["label"] = "label";
  561. key_map["label_des"] = "label_des";
  562. key_map["sentence"] = "sentence";
  563. }
  564. } else if (task_ == "TNEWS") {
  565. if (usage_ == "train") {
  566. key_map["label"] = "label";
  567. key_map["label_desc"] = "label_desc";
  568. key_map["sentence"] = "sentence";
  569. key_map["keywords"] = "keywords";
  570. } else if (usage_ == "test") {
  571. key_map["id"] = "id";
  572. key_map["sentence"] = "sentence";
  573. key_map["keywords"] = "keywords";
  574. } else if (usage_ == "eval") {
  575. key_map["label"] = "label";
  576. key_map["label_desc"] = "label_desc";
  577. key_map["sentence"] = "sentence";
  578. key_map["keywords"] = "keywords";
  579. }
  580. } else if (task_ == "WSC") {
  581. if (usage_ == "train") {
  582. key_map["span1_index"] = "target/span1_index";
  583. key_map["span2_index"] = "target/span2_index";
  584. key_map["span1_text"] = "target/span1_text";
  585. key_map["span2_text"] = "target/span2_text";
  586. key_map["idx"] = "idx";
  587. key_map["label"] = "label";
  588. key_map["text"] = "text";
  589. } else if (usage_ == "test") {
  590. key_map["span1_index"] = "target/span1_index";
  591. key_map["span2_index"] = "target/span2_index";
  592. key_map["span1_text"] = "target/span1_text";
  593. key_map["span2_text"] = "target/span2_text";
  594. key_map["idx"] = "idx";
  595. key_map["text"] = "text";
  596. } else if (usage_ == "eval") {
  597. key_map["span1_index"] = "target/span1_index";
  598. key_map["span2_index"] = "target/span2_index";
  599. key_map["span1_text"] = "target/span1_text";
  600. key_map["span2_text"] = "target/span2_text";
  601. key_map["idx"] = "idx";
  602. key_map["label"] = "label";
  603. key_map["text"] = "text";
  604. }
  605. }
  606. ColKeyMap ck_map;
  607. for (auto &p : key_map) {
  608. ck_map.insert({p.first, split(p.second, '/')});
  609. }
  610. bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
  611. std::shared_ptr<ClueOp> clue_op =
  612. std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
  613. dataset_files_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
  614. RETURN_EMPTY_IF_ERROR(clue_op->Init());
  615. if (shuffle_ == ShuffleMode::kGlobal) {
  616. // Inject ShuffleOp
  617. std::shared_ptr<DatasetOp> shuffle_op = nullptr;
  618. int64_t num_rows = 0;
  619. // First, get the number of rows in the dataset
  620. RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(dataset_files_, &num_rows));
  621. // Add the shuffle op after this op
  622. RETURN_EMPTY_IF_ERROR(AddShuffleOp(dataset_files_.size(), num_shards_, num_rows, 0, connector_que_size_,
  623. rows_per_buffer_, &shuffle_op));
  624. node_ops.push_back(shuffle_op);
  625. }
  626. node_ops.push_back(clue_op);
  627. return node_ops;
  628. }
  629. // Constructor for CocoDataset
  630. CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
  631. const bool &decode, const std::shared_ptr<SamplerObj> &sampler)
  632. : dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {}
  633. bool CocoDataset::ValidateParams() {
  634. if (!ValidateDatasetDirParam("CocoDataset", dataset_dir_)) {
  635. return false;
  636. }
  637. Path annotation_file(annotation_file_);
  638. if (!annotation_file.Exists()) {
  639. MS_LOG(ERROR) << "annotation_file is invalid or not exist";
  640. return false;
  641. }
  642. std::set<std::string> task_list = {"Detection", "Stuff", "Panoptic", "Keypoint"};
  643. auto task_iter = task_list.find(task_);
  644. if (task_iter == task_list.end()) {
  645. MS_LOG(ERROR) << "Invalid task type";
  646. return false;
  647. }
  648. return true;
  649. }
  650. // Function to build CocoDataset
  651. std::vector<std::shared_ptr<DatasetOp>> CocoDataset::Build() {
  652. // A vector containing shared pointer to the Dataset Ops that this object will create
  653. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  654. // If user does not specify Sampler, create a default sampler based on the shuffle variable.
  655. if (sampler_ == nullptr) {
  656. sampler_ = CreateDefaultSampler();
  657. }
  658. CocoOp::TaskType task_type;
  659. if (task_ == "Detection") {
  660. task_type = CocoOp::TaskType::Detection;
  661. } else if (task_ == "Stuff") {
  662. task_type = CocoOp::TaskType::Stuff;
  663. } else if (task_ == "Keypoint") {
  664. task_type = CocoOp::TaskType::Keypoint;
  665. } else if (task_ == "Panoptic") {
  666. task_type = CocoOp::TaskType::Panoptic;
  667. }
  668. std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  669. RETURN_EMPTY_IF_ERROR(
  670. schema->AddColumn(ColDescriptor(std::string("image"), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  671. switch (task_type) {
  672. case CocoOp::TaskType::Detection:
  673. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  674. ColDescriptor(std::string("bbox"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
  675. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  676. ColDescriptor(std::string("category_id"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  677. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  678. ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  679. break;
  680. case CocoOp::TaskType::Stuff:
  681. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  682. ColDescriptor(std::string("segmentation"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
  683. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  684. ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  685. break;
  686. case CocoOp::TaskType::Keypoint:
  687. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  688. ColDescriptor(std::string("keypoints"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
  689. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  690. ColDescriptor(std::string("num_keypoints"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  691. break;
  692. case CocoOp::TaskType::Panoptic:
  693. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  694. ColDescriptor(std::string("bbox"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
  695. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  696. ColDescriptor(std::string("category_id"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  697. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  698. ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  699. RETURN_EMPTY_IF_ERROR(
  700. schema->AddColumn(ColDescriptor(std::string("area"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  701. break;
  702. default:
  703. MS_LOG(ERROR) << "CocoDataset::Build : Invalid task type";
  704. return {};
  705. }
  706. std::shared_ptr<CocoOp> op =
  707. std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
  708. connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
  709. node_ops.push_back(op);
  710. return node_ops;
  711. }
  712. ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
  713. bool recursive, std::set<std::string> extensions,
  714. std::map<std::string, int32_t> class_indexing)
  715. : dataset_dir_(dataset_dir),
  716. decode_(decode),
  717. sampler_(sampler),
  718. recursive_(recursive),
  719. class_indexing_(class_indexing),
  720. exts_(extensions) {}
  721. bool ImageFolderDataset::ValidateParams() { return ValidateDatasetDirParam("ImageFolderDataset", dataset_dir_); }
  722. std::vector<std::shared_ptr<DatasetOp>> ImageFolderDataset::Build() {
  723. // A vector containing shared pointer to the Dataset Ops that this object will create
  724. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  725. // If user does not specify Sampler, create a default sampler, i.e., RandomSampler.
  726. if (sampler_ == nullptr) {
  727. sampler_ = CreateDefaultSampler();
  728. }
  729. // Do internal Schema generation.
  730. // This arg is exist in ImageFolderOp, but not externalized (in Python API).
  731. std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  732. TensorShape scalar = TensorShape::CreateScalar();
  733. RETURN_EMPTY_IF_ERROR(
  734. schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  735. RETURN_EMPTY_IF_ERROR(
  736. schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
  737. node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
  738. recursive_, decode_, exts_, class_indexing_, std::move(schema),
  739. std::move(sampler_->Build())));
  740. return node_ops;
  741. }
  742. MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler)
  743. : dataset_dir_(dataset_dir), sampler_(sampler) {}
  744. bool MnistDataset::ValidateParams() { return ValidateDatasetDirParam("MnistDataset", dataset_dir_); }
  745. std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
  746. // A vector containing shared pointer to the Dataset Ops that this object will create
  747. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  748. // If user does not specify Sampler, create a default sampler, i.e., RandomSampler.
  749. if (sampler_ == nullptr) {
  750. sampler_ = CreateDefaultSampler();
  751. }
  752. // Do internal Schema generation.
  753. auto schema = std::make_unique<DataSchema>();
  754. RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  755. TensorShape scalar = TensorShape::CreateScalar();
  756. RETURN_EMPTY_IF_ERROR(
  757. schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  758. node_ops.push_back(std::make_shared<MnistOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
  759. std::move(schema), std::move(sampler_->Build())));
  760. return node_ops;
  761. }
  762. // Constructor for TextFileDataset
  763. TextFileDataset::TextFileDataset(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle,
  764. int32_t num_shards, int32_t shard_id)
  765. : dataset_files_(dataset_files),
  766. num_samples_(num_samples),
  767. shuffle_(shuffle),
  768. num_shards_(num_shards),
  769. shard_id_(shard_id) {}
  770. bool TextFileDataset::ValidateParams() {
  771. if (!ValidateDatasetFilesParam("TextFileDataset", dataset_files_)) {
  772. return false;
  773. }
  774. if (num_samples_ < 0) {
  775. MS_LOG(ERROR) << "TextFileDataset: Invalid number of samples: " << num_samples_;
  776. return false;
  777. }
  778. if (!ValidateDatasetShardParams("TextfileDataset", num_shards_, shard_id_)) {
  779. return false;
  780. }
  781. return true;
  782. }
  783. // Function to build TextFileDataset
  784. std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
  785. // A vector containing shared pointer to the Dataset Ops that this object will create
  786. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  787. bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
  788. // Do internal Schema generation.
  789. auto schema = std::make_unique<DataSchema>();
  790. RETURN_EMPTY_IF_ERROR(
  791. schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  792. // Create and initalize TextFileOp
  793. std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
  794. num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), dataset_files_,
  795. connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr));
  796. RETURN_EMPTY_IF_ERROR(text_file_op->Init());
  797. if (shuffle_ == ShuffleMode::kGlobal) {
  798. // Inject ShuffleOp
  799. std::shared_ptr<DatasetOp> shuffle_op = nullptr;
  800. int64_t num_rows = 0;
  801. // First, get the number of rows in the dataset
  802. RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(dataset_files_, &num_rows));
  803. // Add the shuffle op after this op
  804. RETURN_EMPTY_IF_ERROR(AddShuffleOp(dataset_files_.size(), num_shards_, num_rows, 0, connector_que_size_,
  805. rows_per_buffer_, &shuffle_op));
  806. node_ops.push_back(shuffle_op);
  807. }
  808. // Add TextFileOp
  809. node_ops.push_back(text_file_op);
  810. return node_ops;
  811. }
  812. // Constructor for VOCDataset
  813. VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
  814. const std::map<std::string, int32_t> &class_indexing, bool decode,
  815. std::shared_ptr<SamplerObj> sampler)
  816. : dataset_dir_(dataset_dir),
  817. task_(task),
  818. mode_(mode),
  819. class_index_(class_indexing),
  820. decode_(decode),
  821. sampler_(sampler) {}
  822. bool VOCDataset::ValidateParams() {
  823. Path dir(dataset_dir_);
  824. if (!dir.IsDirectory()) {
  825. MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified.";
  826. return false;
  827. }
  828. if (task_ == "Segmentation") {
  829. if (!class_index_.empty()) {
  830. MS_LOG(ERROR) << "class_indexing is invalid in Segmentation task.";
  831. return false;
  832. }
  833. Path imagesets_file = dir / "ImageSets" / "Segmentation" / mode_ + ".txt";
  834. if (!imagesets_file.Exists()) {
  835. MS_LOG(ERROR) << "Invalid mode: " << mode_ << ", file \"" << imagesets_file << "\" is not exists!";
  836. return false;
  837. }
  838. } else if (task_ == "Detection") {
  839. Path imagesets_file = dir / "ImageSets" / "Main" / mode_ + ".txt";
  840. if (!imagesets_file.Exists()) {
  841. MS_LOG(ERROR) << "Invalid mode: " << mode_ << ", file \"" << imagesets_file << "\" is not exists!";
  842. return false;
  843. }
  844. } else {
  845. MS_LOG(ERROR) << "Invalid task: " << task_;
  846. return false;
  847. }
  848. return true;
  849. }
  850. // Function to build VOCDataset
  851. std::vector<std::shared_ptr<DatasetOp>> VOCDataset::Build() {
  852. // A vector containing shared pointer to the Dataset Ops that this object will create
  853. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  854. // If user does not specify Sampler, create a default sampler based on the shuffle variable.
  855. if (sampler_ == nullptr) {
  856. sampler_ = CreateDefaultSampler();
  857. }
  858. auto schema = std::make_unique<DataSchema>();
  859. VOCOp::TaskType task_type_;
  860. if (task_ == "Segmentation") {
  861. task_type_ = VOCOp::TaskType::Segmentation;
  862. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  863. ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  864. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  865. ColDescriptor(std::string(kColumnTarget), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  866. } else if (task_ == "Detection") {
  867. task_type_ = VOCOp::TaskType::Detection;
  868. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  869. ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  870. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  871. ColDescriptor(std::string(kColumnBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
  872. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  873. ColDescriptor(std::string(kColumnLabel), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  874. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  875. ColDescriptor(std::string(kColumnDifficult), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  876. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  877. ColDescriptor(std::string(kColumnTruncate), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  878. }
  879. std::shared_ptr<VOCOp> voc_op;
  880. voc_op = std::make_shared<VOCOp>(task_type_, mode_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
  881. connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
  882. node_ops.push_back(voc_op);
  883. return node_ops;
  884. }
// DERIVED DATASET CLASSES FOR NON-LEAF (OPERATION) DATASETS
// (In alphabetical order)
  887. BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector<std::string> cols_to_map,
  888. std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map)
  889. : batch_size_(batch_size),
  890. drop_remainder_(drop_remainder),
  891. pad_(pad),
  892. cols_to_map_(cols_to_map),
  893. pad_map_(pad_map) {}
  894. std::vector<std::shared_ptr<DatasetOp>> BatchDataset::Build() {
  895. // A vector containing shared pointer to the Dataset Ops that this object will create
  896. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  897. #ifdef ENABLE_PYTHON
  898. py::function noop;
  899. node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
  900. cols_to_map_, noop, noop, pad_map_));
  901. #else
  902. node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
  903. cols_to_map_, pad_map_));
  904. #endif
  905. return node_ops;
  906. }
  907. bool BatchDataset::ValidateParams() {
  908. if (batch_size_ <= 0) {
  909. MS_LOG(ERROR) << "Batch: Batch size cannot be negative";
  910. return false;
  911. }
  912. return true;
  913. }
  914. // Function to build ConcatOp
  915. ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
  916. this->children = datasets_;
  917. }
  918. bool ConcatDataset::ValidateParams() {
  919. if (datasets_.empty()) {
  920. MS_LOG(ERROR) << "Concat: concatenated datasets are not specified.";
  921. return false;
  922. }
  923. return true;
  924. }
  925. std::vector<std::shared_ptr<DatasetOp>> ConcatDataset::Build() {
  926. // A vector containing shared pointer to the Dataset Ops that this object will create
  927. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  928. node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
  929. return node_ops;
  930. }
  931. MapDataset::MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations, std::vector<std::string> input_columns,
  932. std::vector<std::string> output_columns, const std::vector<std::string> &project_columns)
  933. : operations_(operations),
  934. input_columns_(input_columns),
  935. output_columns_(output_columns),
  936. project_columns_(project_columns) {}
  937. std::vector<std::shared_ptr<DatasetOp>> MapDataset::Build() {
  938. // A vector containing shared pointer to the Dataset Ops that this object will create
  939. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  940. std::vector<std::shared_ptr<TensorOp>> tensor_ops;
  941. // Build tensorOp from tensorOperation vector
  942. // This is to ensure each iterator hold its own copy of the tensorOp objects.
  943. (void)std::transform(
  944. operations_.begin(), operations_.end(), std::back_inserter(tensor_ops),
  945. [](std::shared_ptr<TensorOperation> operation) -> std::shared_ptr<TensorOp> { return operation->Build(); });
  946. // This parameter will be removed with next rebase
  947. std::vector<std::string> col_orders;
  948. auto map_op = std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_);
  949. if (!project_columns_.empty()) {
  950. auto project_op = std::make_shared<ProjectOp>(project_columns_);
  951. node_ops.push_back(project_op);
  952. }
  953. node_ops.push_back(map_op);
  954. return node_ops;
  955. }
  956. bool MapDataset::ValidateParams() {
  957. if (operations_.empty()) {
  958. MS_LOG(ERROR) << "Map: No operation is specified.";
  959. return false;
  960. }
  961. return true;
  962. }
  963. // Function to build ProjectOp
  964. ProjectDataset::ProjectDataset(const std::vector<std::string> &columns) : columns_(columns) {}
  965. bool ProjectDataset::ValidateParams() {
  966. if (columns_.empty()) {
  967. MS_LOG(ERROR) << "No columns are specified.";
  968. return false;
  969. }
  970. return true;
  971. }
  972. std::vector<std::shared_ptr<DatasetOp>> ProjectDataset::Build() {
  973. // A vector containing shared pointer to the Dataset Ops that this object will create
  974. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  975. node_ops.push_back(std::make_shared<ProjectOp>(columns_));
  976. return node_ops;
  977. }
  978. // Function to build RenameOp
  979. RenameDataset::RenameDataset(const std::vector<std::string> &input_columns,
  980. const std::vector<std::string> &output_columns)
  981. : input_columns_(input_columns), output_columns_(output_columns) {}
  982. bool RenameDataset::ValidateParams() {
  983. if (input_columns_.empty() || output_columns_.empty()) {
  984. MS_LOG(ERROR) << "input and output columns must be specified";
  985. return false;
  986. }
  987. if (input_columns_.size() != output_columns_.size()) {
  988. MS_LOG(ERROR) << "input and output columns must be the same size";
  989. return false;
  990. }
  991. return true;
  992. }
  993. std::vector<std::shared_ptr<DatasetOp>> RenameDataset::Build() {
  994. // A vector containing shared pointer to the Dataset Ops that this object will create
  995. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  996. node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
  997. return node_ops;
  998. }
  999. RepeatDataset::RepeatDataset(int32_t count) : repeat_count_(count) {}
  1000. std::vector<std::shared_ptr<DatasetOp>> RepeatDataset::Build() {
  1001. // A vector containing shared pointer to the Dataset Ops that this object will create
  1002. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  1003. node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
  1004. return node_ops;
  1005. }
  1006. bool RepeatDataset::ValidateParams() {
  1007. if (repeat_count_ <= 0 && repeat_count_ != -1) {
  1008. MS_LOG(ERROR) << "Repeat: repeat_count should be either -1 or positive integer, repeat_count_: " << repeat_count_;
  1009. return false;
  1010. }
  1011. return true;
  1012. }
  1013. // Constructor for ShuffleDataset
  1014. ShuffleDataset::ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch)
  1015. : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {}
  1016. // Function to build the ShuffleOp
  1017. std::vector<std::shared_ptr<DatasetOp>> ShuffleDataset::Build() {
  1018. // A vector containing shared pointer to the Dataset Ops that this object will create
  1019. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  1020. node_ops.push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
  1021. rows_per_buffer_));
  1022. return node_ops;
  1023. }
  1024. // Function to validate the parameters for ShuffleDataset
  1025. bool ShuffleDataset::ValidateParams() {
  1026. if (shuffle_size_ <= 1) {
  1027. MS_LOG(ERROR) << "ShuffleDataset: Invalid input, shuffle_size: " << shuffle_size_;
  1028. return false;
  1029. }
  1030. return true;
  1031. }
  1032. // Constructor for SkipDataset
  1033. SkipDataset::SkipDataset(int32_t count) : skip_count_(count) {}
  1034. // Function to build the SkipOp
  1035. std::vector<std::shared_ptr<DatasetOp>> SkipDataset::Build() {
  1036. // A vector containing shared pointer to the Dataset Ops that this object will create
  1037. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  1038. node_ops.push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_));
  1039. return node_ops;
  1040. }
  1041. // Function to validate the parameters for SkipDataset
  1042. bool SkipDataset::ValidateParams() {
  1043. if (skip_count_ <= -1) {
  1044. MS_LOG(ERROR) << "Skip: skip_count should not be negative, skip_count: " << skip_count_;
  1045. return false;
  1046. }
  1047. return true;
  1048. }
  1049. // Constructor for TakeDataset
  1050. TakeDataset::TakeDataset(int32_t count) : take_count_(count) {}
  1051. // Function to build the TakeOp
  1052. std::vector<std::shared_ptr<DatasetOp>> TakeDataset::Build() {
  1053. // A vector containing shared pointer to the Dataset Ops that this object will create
  1054. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  1055. node_ops.push_back(std::make_shared<TakeOp>(take_count_, connector_que_size_));
  1056. return node_ops;
  1057. }
  1058. // Function to validate the parameters for TakeDataset
  1059. bool TakeDataset::ValidateParams() {
  1060. if (take_count_ < 0 && take_count_ != -1) {
  1061. MS_LOG(ERROR) << "Take: take_count should be either -1 or positive integer, take_count: " << take_count_;
  1062. return false;
  1063. }
  1064. return true;
  1065. }
  1066. // Function to build ZipOp
  1067. ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
  1068. for (auto dataset : datasets_) {
  1069. this->children.push_back(dataset);
  1070. }
  1071. }
  1072. bool ZipDataset::ValidateParams() {
  1073. if (datasets_.empty()) {
  1074. MS_LOG(ERROR) << "Zip: dataset to zip are not specified.";
  1075. return false;
  1076. }
  1077. return true;
  1078. }
  1079. std::vector<std::shared_ptr<DatasetOp>> ZipDataset::Build() {
  1080. // A vector containing shared pointer to the Dataset Ops that this object will create
  1081. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  1082. node_ops.push_back(std::make_shared<ZipOp>(rows_per_buffer_, connector_que_size_));
  1083. return node_ops;
  1084. }
  1085. } // namespace api
  1086. } // namespace dataset
  1087. } // namespace mindspore