You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

datasets.cc 49 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <fstream>
  17. #include "minddata/dataset/include/datasets.h"
  18. #include "minddata/dataset/include/samplers.h"
  19. #include "minddata/dataset/include/transforms.h"
  20. #include "minddata/dataset/engine/dataset_iterator.h"
  21. // Source dataset headers (in alphabetical order)
  22. #include "minddata/dataset/engine/datasetops/source/celeba_op.h"
  23. #include "minddata/dataset/engine/datasetops/source/cifar_op.h"
  24. #include "minddata/dataset/engine/datasetops/source/clue_op.h"
  25. #include "minddata/dataset/engine/datasetops/source/coco_op.h"
  26. #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
  27. #include "minddata/dataset/engine/datasetops/source/mnist_op.h"
  28. #include "minddata/dataset/engine/datasetops/source/text_file_op.h"
  29. #include "minddata/dataset/engine/datasetops/source/voc_op.h"
  30. // Dataset operator headers (in alphabetical order)
  31. #include "minddata/dataset/engine/datasetops/batch_op.h"
  32. #include "minddata/dataset/engine/datasetops/concat_op.h"
  33. #include "minddata/dataset/engine/datasetops/map_op/map_op.h"
  34. #include "minddata/dataset/engine/datasetops/project_op.h"
  35. #include "minddata/dataset/engine/datasetops/rename_op.h"
  36. #include "minddata/dataset/engine/datasetops/repeat_op.h"
  37. #include "minddata/dataset/engine/datasetops/shuffle_op.h"
  38. #include "minddata/dataset/engine/datasetops/skip_op.h"
  39. #include "minddata/dataset/engine/datasetops/take_op.h"
  40. #include "minddata/dataset/engine/datasetops/zip_op.h"
  41. // Sampler headers (in alphabetical order)
  42. #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
  43. #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
  44. #include "minddata/dataset/core/config_manager.h"
  45. #include "minddata/dataset/util/random.h"
  46. #include "minddata/dataset/util/path.h"
  47. namespace mindspore {
  48. namespace dataset {
  49. namespace api {
// Evaluates the Status expression _s; on error, logs it and returns an empty
// value ({}) from the enclosing function.
// NOTE(review): only meaningful in functions whose return type is sensibly
// list/value-initialized from {} (e.g. std::vector, std::shared_ptr). In a
// function returning Status, `return {};` yields a default-constructed Status,
// which would silently convert the error into success — do not use it there.
#define RETURN_EMPTY_IF_ERROR(_s) \
  do { \
    Status __rc = (_s); \
    if (__rc.IsError()) { \
      MS_LOG(ERROR) << __rc; \
      return {}; \
    } \
  } while (false)
  58. // Function to create the iterator, which will build and launch the execution tree.
  59. std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> columns) {
  60. std::shared_ptr<Iterator> iter;
  61. try {
  62. auto ds = shared_from_this();
  63. // The specified columns will be selected from the dataset and passed down the pipeline
  64. // in the order specified, other columns will be discarded.
  65. if (!columns.empty()) {
  66. ds = ds->Project(columns);
  67. }
  68. iter = std::make_shared<Iterator>();
  69. Status rc = iter->BuildAndLaunchTree(ds);
  70. if (rc.IsError()) {
  71. MS_LOG(ERROR) << "CreateIterator failed." << rc;
  72. return nullptr;
  73. }
  74. return iter;
  75. } catch (const std::exception &err) {
  76. MS_LOG(ERROR) << "CreateIterator: Iterator exception caught: " << err.what();
  77. return nullptr;
  78. }
  79. return iter;
  80. }
  81. // Constructor
  82. Dataset::Dataset() {
  83. // Fetch some default value from config manager
  84. std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  85. num_workers_ = cfg->num_parallel_workers();
  86. rows_per_buffer_ = cfg->rows_per_buffer();
  87. connector_que_size_ = cfg->op_connector_size();
  88. worker_connector_size_ = cfg->worker_connector_size();
  89. }
  90. // FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
  91. // (In alphabetical order)
  92. // Function to create a CelebADataset.
  93. std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type,
  94. const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
  95. const std::set<std::string> &extensions) {
  96. auto ds = std::make_shared<CelebADataset>(dataset_dir, dataset_type, sampler, decode, extensions);
  97. // Call derived class validation method.
  98. return ds->ValidateParams() ? ds : nullptr;
  99. }
  100. // Function to create a Cifar10Dataset.
  101. std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) {
  102. auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, sampler);
  103. // Call derived class validation method.
  104. return ds->ValidateParams() ? ds : nullptr;
  105. }
  106. // Function to create a Cifar100Dataset.
  107. std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) {
  108. auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, sampler);
  109. // Call derived class validation method.
  110. return ds->ValidateParams() ? ds : nullptr;
  111. }
  112. // Function to create a CLUEDataset.
  113. std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &clue_files, const std::string &task,
  114. const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int num_shards,
  115. int shard_id) {
  116. auto ds = std::make_shared<CLUEDataset>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id);
  117. // Call derived class validation method.
  118. return ds->ValidateParams() ? ds : nullptr;
  119. }
  120. // Function to create a CocoDataset.
  121. std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
  122. const std::string &task, const bool &decode,
  123. const std::shared_ptr<SamplerObj> &sampler) {
  124. auto ds = std::make_shared<CocoDataset>(dataset_dir, annotation_file, task, decode, sampler);
  125. // Call derived class validation method.
  126. return ds->ValidateParams() ? ds : nullptr;
  127. }
  128. // Function to create a ImageFolderDataset.
  129. std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode,
  130. std::shared_ptr<SamplerObj> sampler, std::set<std::string> extensions,
  131. std::map<std::string, int32_t> class_indexing) {
  132. // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
  133. bool recursive = false;
  134. // Create logical representation of ImageFolderDataset.
  135. auto ds = std::make_shared<ImageFolderDataset>(dataset_dir, decode, sampler, recursive, extensions, class_indexing);
  136. // Call derived class validation method.
  137. return ds->ValidateParams() ? ds : nullptr;
  138. }
  139. // Function to create a MnistDataset.
  140. std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler) {
  141. auto ds = std::make_shared<MnistDataset>(dataset_dir, sampler);
  142. // Call derived class validation method.
  143. return ds->ValidateParams() ? ds : nullptr;
  144. }
  145. // Function to overload "+" operator to concat two datasets
  146. std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
  147. const std::shared_ptr<Dataset> &datasets2) {
  148. std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets1, datasets2}));
  149. // Call derived class validation method.
  150. return ds->ValidateParams() ? ds : nullptr;
  151. }
  152. // Function to create a TextFileDataset.
  153. std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples,
  154. ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) {
  155. auto ds = std::make_shared<TextFileDataset>(dataset_files, num_samples, shuffle, num_shards, shard_id);
  156. // Call derived class validation method.
  157. return ds->ValidateParams() ? ds : nullptr;
  158. }
  159. // Function to create a VOCDataset.
  160. std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &mode,
  161. const std::map<std::string, int32_t> &class_index, bool decode,
  162. std::shared_ptr<SamplerObj> sampler) {
  163. auto ds = std::make_shared<VOCDataset>(dataset_dir, task, mode, class_index, decode, sampler);
  164. // Call derived class validation method.
  165. return ds->ValidateParams() ? ds : nullptr;
  166. }
  167. // Function to create a ZipDataset.
  168. std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
  169. auto ds = std::make_shared<ZipDataset>(datasets);
  170. // Call derived class validation method.
  171. return ds->ValidateParams() ? ds : nullptr;
  172. }
  173. // FUNCTIONS TO CREATE DATASETS FOR DATASET OPS
  174. // (In alphabetical order)
  175. // Function to create a Batch dataset
  176. std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
  177. // Default values
  178. std::vector<std::string> cols_to_map = {};
  179. std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map;
  180. bool pad = false;
  181. auto ds = std::make_shared<BatchDataset>(batch_size, drop_remainder, pad, cols_to_map, pad_map);
  182. if (!ds->ValidateParams()) {
  183. return nullptr;
  184. }
  185. ds->children.push_back(shared_from_this());
  186. return ds;
  187. }
  188. // Function to create a Concat dataset
  189. std::shared_ptr<ConcatDataset> Dataset::Concat(const std::vector<std::shared_ptr<Dataset>> &datasets) {
  190. auto ds = std::make_shared<ConcatDataset>(datasets);
  191. ds->children.push_back(shared_from_this());
  192. return ds->ValidateParams() ? ds : nullptr;
  193. }
  194. // Function to create a Map dataset.
  195. std::shared_ptr<MapDataset> Dataset::Map(std::vector<std::shared_ptr<TensorOperation>> operations,
  196. std::vector<std::string> input_columns,
  197. std::vector<std::string> output_columns,
  198. const std::vector<std::string> &project_columns) {
  199. auto ds = std::make_shared<MapDataset>(operations, input_columns, output_columns, project_columns);
  200. if (!ds->ValidateParams()) {
  201. return nullptr;
  202. }
  203. ds->children.push_back(shared_from_this());
  204. return ds;
  205. }
  206. // Function to create a ProjectDataset.
  207. std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) {
  208. auto ds = std::make_shared<ProjectDataset>(columns);
  209. // Call derived class validation method.
  210. if (!ds->ValidateParams()) {
  211. return nullptr;
  212. }
  213. ds->children.push_back(shared_from_this());
  214. return ds;
  215. }
  216. // Function to create a RenameDataset.
  217. std::shared_ptr<RenameDataset> Dataset::Rename(const std::vector<std::string> &input_columns,
  218. const std::vector<std::string> &output_columns) {
  219. auto ds = std::make_shared<RenameDataset>(input_columns, output_columns);
  220. // Call derived class validation method.
  221. if (!ds->ValidateParams()) {
  222. return nullptr;
  223. }
  224. ds->children.push_back(shared_from_this());
  225. return ds;
  226. }
  227. // Function to create Repeat dataset.
  228. std::shared_ptr<Dataset> Dataset::Repeat(int32_t count) {
  229. // Workaround for repeat == 1, do not inject repeat.
  230. if (count == 1) {
  231. return shared_from_this();
  232. }
  233. auto ds = std::make_shared<RepeatDataset>(count);
  234. if (!ds->ValidateParams()) {
  235. return nullptr;
  236. }
  237. ds->children.push_back(shared_from_this());
  238. return ds;
  239. }
  240. // Function to create a ShuffleOp
  241. std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t buffer_size) {
  242. // Pass in reshuffle_each_epoch with true
  243. auto ds = std::make_shared<ShuffleDataset>(buffer_size, true);
  244. if (!ds->ValidateParams()) {
  245. return nullptr;
  246. }
  247. ds->children.push_back(shared_from_this());
  248. return ds;
  249. }
  250. // Function to create a SkipDataset.
  251. std::shared_ptr<SkipDataset> Dataset::Skip(int32_t count) {
  252. auto ds = std::make_shared<SkipDataset>(count);
  253. // Call derived class validation method.
  254. if (!ds->ValidateParams()) {
  255. return nullptr;
  256. }
  257. ds->children.push_back(shared_from_this());
  258. return ds;
  259. }
  260. // Function to create a TakeDataset.
  261. std::shared_ptr<Dataset> Dataset::Take(int32_t count) {
  262. // If count is greater than the number of element in dataset or equal to -1,
  263. // all the element in dataset will be taken
  264. if (count == -1) {
  265. return shared_from_this();
  266. }
  267. auto ds = std::make_shared<TakeDataset>(count);
  268. // Call derived class validation method.
  269. if (!ds->ValidateParams()) {
  270. return nullptr;
  271. }
  272. ds->children.push_back(shared_from_this());
  273. return ds;
  274. }
  275. // Function to create a Zip dataset
  276. std::shared_ptr<ZipDataset> Dataset::Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
  277. // Default values
  278. auto ds = std::make_shared<ZipDataset>(datasets);
  279. ds->children.push_back(shared_from_this());
  280. return ds->ValidateParams() ? ds : nullptr;
  281. }
  282. // OTHER FUNCTIONS
  283. // Helper function to create default RandomSampler.
  284. std::shared_ptr<SamplerObj> CreateDefaultSampler() {
  285. const int32_t num_samples = 0; // 0 means to sample all ids.
  286. bool replacement = false;
  287. return std::make_shared<RandomSamplerObj>(replacement, num_samples);
  288. }
// Helper function to compute a default shuffle size
// @param num_files - number of files feeding the dataset (see NOTE below)
// @param num_devices - shard count; 0 disables the per-shard adjustment
// @param num_rows - total row count across all files
// @param total_rows - optional cap on rows; 0 means "no cap"
// @param shuffle_size - output: the computed shuffle buffer size (>= shuffle_max)
Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
                          int64_t *shuffle_size) {
  const int64_t average_files_multiplier = 4;
  const int64_t shuffle_max = 10000;
  int64_t avg_rows_per_file = 0;
  // Adjust the num rows per shard if sharding was given
  if (num_devices > 0) {
    if (num_rows % num_devices == 0) {
      num_rows = num_rows / num_devices;
    } else {
      // Round up so shards with a partial share are still accounted for.
      num_rows = (num_rows / num_devices) + 1;
    }
  }
  // Cap based on total rows directive. Some ops do not have this and give value of 0.
  if (total_rows > 0) {
    num_rows = std::min(num_rows, total_rows);
  }
  // get the average per file
  // NOTE(review): this divides by num_files with no guard; a caller passing
  // num_files == 0 triggers integer division by zero — confirm all call sites
  // guarantee at least one file.
  avg_rows_per_file = num_rows / num_files;
  // The buffer is 4x the per-file average, but never smaller than shuffle_max.
  *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
  return Status::OK();
}
  312. // Helper function to inject a shuffle operator over top of current operator being built
  313. Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
  314. int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op) {
  315. std::shared_ptr<ShuffleOp> new_shuffle_op = nullptr;
  316. int64_t shuffle_size = 0;
  317. RETURN_EMPTY_IF_ERROR(ComputeShuffleSize(num_files, num_devices, num_rows, total_rows, &shuffle_size));
  318. MS_LOG(INFO) << "Dataset::AddShuffleOp - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
  319. // Add the shuffle op
  320. *shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size, true, rows_per_buffer);
  321. return Status::OK();
  322. }
  323. // Helper function to validate dataset params
  324. bool ValidateCommonDatasetParams(std::string dataset_dir) {
  325. if (dataset_dir.empty()) {
  326. MS_LOG(ERROR) << "No dataset path is specified";
  327. return false;
  328. }
  329. return true;
  330. }
/* ####################################### Derived Dataset classes ################################# */
// DERIVED DATASET CLASSES LEAF-NODE DATASETS
// (In alphabetical order)
// Constructor for CelebADataset
// @param dataset_dir - root directory of the CelebA data
// @param dataset_type - partition name; ValidateParams accepts "all", "train", "valid" or "test"
// @param sampler - sampler object; nullptr makes Build() install a default RandomSampler
// @param decode - whether images are decoded when read
// @param extensions - file extensions to include
CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &dataset_type,
                             const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
                             const std::set<std::string> &extensions)
    : dataset_dir_(dataset_dir),
      dataset_type_(dataset_type),
      sampler_(sampler),
      decode_(decode),
      extensions_(extensions) {}
  343. bool CelebADataset::ValidateParams() {
  344. Path dir(dataset_dir_);
  345. if (!dir.IsDirectory()) {
  346. MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified.";
  347. return false;
  348. }
  349. std::set<std::string> dataset_type_list = {"all", "train", "valid", "test"};
  350. auto iter = dataset_type_list.find(dataset_type_);
  351. if (iter == dataset_type_list.end()) {
  352. MS_LOG(ERROR) << "dataset_type should be one of 'all', 'train', 'valid' or 'test'.";
  353. return false;
  354. }
  355. return true;
  356. }
  357. // Function to build CelebADataset
  358. std::vector<std::shared_ptr<DatasetOp>> CelebADataset::Build() {
  359. // A vector containing shared pointer to the Dataset Ops that this object will create
  360. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  361. // If user does not specify Sampler, create a default sampler based on the shuffle variable.
  362. if (sampler_ == nullptr) {
  363. sampler_ = CreateDefaultSampler();
  364. }
  365. std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  366. RETURN_EMPTY_IF_ERROR(
  367. schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  368. // label is like this:0 1 0 0 1......
  369. RETURN_EMPTY_IF_ERROR(
  370. schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  371. node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
  372. decode_, dataset_type_, extensions_, std::move(schema),
  373. std::move(sampler_->Build())));
  374. return node_ops;
  375. }
// Constructor for Cifar10Dataset
// @param dataset_dir - directory containing the CIFAR-10 data
// @param sampler - sampler object; nullptr makes Build() install a default RandomSampler
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
    : dataset_dir_(dataset_dir), sampler_(sampler) {}
// Validation only requires a non-empty dataset directory.
bool Cifar10Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
  380. // Function to build CifarOp for Cifar10
  381. std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
  382. // A vector containing shared pointer to the Dataset Ops that this object will create
  383. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  384. // If user does not specify Sampler, create a default sampler based on the shuffle variable.
  385. if (sampler_ == nullptr) {
  386. sampler_ = CreateDefaultSampler();
  387. }
  388. // Do internal Schema generation.
  389. auto schema = std::make_unique<DataSchema>();
  390. RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  391. TensorShape scalar = TensorShape::CreateScalar();
  392. RETURN_EMPTY_IF_ERROR(
  393. schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  394. node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_,
  395. dataset_dir_, connector_que_size_, std::move(schema),
  396. std::move(sampler_->Build())));
  397. return node_ops;
  398. }
// Constructor for Cifar100Dataset
// @param dataset_dir - directory containing the CIFAR-100 data
// @param sampler - sampler object; nullptr makes Build() install a default RandomSampler
Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
    : dataset_dir_(dataset_dir), sampler_(sampler) {}
// Validation only requires a non-empty dataset directory.
bool Cifar100Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
  403. // Function to build CifarOp for Cifar100
  404. std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
  405. // A vector containing shared pointer to the Dataset Ops that this object will create
  406. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  407. // If user does not specify Sampler, create a default sampler based on the shuffle variable.
  408. if (sampler_ == nullptr) {
  409. sampler_ = CreateDefaultSampler();
  410. }
  411. // Do internal Schema generation.
  412. auto schema = std::make_unique<DataSchema>();
  413. RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  414. TensorShape scalar = TensorShape::CreateScalar();
  415. RETURN_EMPTY_IF_ERROR(
  416. schema->AddColumn(ColDescriptor("coarse_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  417. RETURN_EMPTY_IF_ERROR(
  418. schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  419. node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, num_workers_, rows_per_buffer_,
  420. dataset_dir_, connector_que_size_, std::move(schema),
  421. std::move(sampler_->Build())));
  422. return node_ops;
  423. }
// Constructor for CLUEDataset
// @param clue_files - paths of the CLUE json files to read
// @param task - one of "AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL" (checked in ValidateParams)
// @param usage - one of "train", "test", "eval" (checked in ValidateParams)
// @param num_samples - number of samples to read (0 means all; must be >= 0)
// @param shuffle - shuffle mode (none / files / global)
// @param num_shards / shard_id - sharding configuration
// NOTE(review): clue_files, task and usage are taken by (const) value and
// copied here, even though the CLUE() factory holds const references — confirm
// whether the declaration could take const references instead.
CLUEDataset::CLUEDataset(const std::vector<std::string> clue_files, std::string task, std::string usage,
                         int64_t num_samples, ShuffleMode shuffle, int num_shards, int shard_id)
    : dataset_files_(clue_files),
      task_(task),
      usage_(usage),
      num_samples_(num_samples),
      shuffle_(shuffle),
      num_shards_(num_shards),
      shard_id_(shard_id) {}
  434. bool CLUEDataset::ValidateParams() {
  435. if (dataset_files_.empty()) {
  436. MS_LOG(ERROR) << "CLUEDataset: dataset_files is not specified.";
  437. return false;
  438. }
  439. for (auto f : dataset_files_) {
  440. Path clue_file(f);
  441. if (!clue_file.Exists()) {
  442. MS_LOG(ERROR) << "dataset file: [" << f << "] is invalid or not exist";
  443. return false;
  444. }
  445. }
  446. std::vector<std::string> task_list = {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"};
  447. std::vector<std::string> usage_list = {"train", "test", "eval"};
  448. if (find(task_list.begin(), task_list.end(), task_) == task_list.end()) {
  449. MS_LOG(ERROR) << "task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL.";
  450. return false;
  451. }
  452. if (find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) {
  453. MS_LOG(ERROR) << "usage should be train, test or eval.";
  454. return false;
  455. }
  456. if (num_samples_ < 0) {
  457. MS_LOG(ERROR) << "CLUEDataset: Invalid number of samples: " << num_samples_;
  458. return false;
  459. }
  460. if (num_shards_ <= 0) {
  461. MS_LOG(ERROR) << "CLUEDataset: Invalid num_shards: " << num_shards_;
  462. return false;
  463. }
  464. if (shard_id_ < 0 || shard_id_ >= num_shards_) {
  465. MS_LOG(ERROR) << "CLUEDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
  466. return false;
  467. }
  468. return true;
  469. }
  470. // Function to split string based on a character delimiter
  471. std::vector<std::string> CLUEDataset::split(const std::string &s, char delim) {
  472. std::vector<std::string> res;
  473. std::stringstream ss(s);
  474. std::string item;
  475. while (getline(ss, item, delim)) {
  476. res.push_back(item);
  477. }
  478. return res;
  479. }
// Function to build CLUEDataset
// Constructs the per-task column map, creates the ClueOp, and (for global
// shuffling) injects a ShuffleOp above it.
// @return the DatasetOps for this node, or {} if Init/row-count/shuffle setup fails.
std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;
  // key_map maps an output column name to the JSON key that supplies it.
  // Values containing '/' (e.g. "target/span1_index" for WSC) are nested key
  // paths, split on '/' below before being handed to ClueOp.
  std::map<std::string, std::string> key_map;
  if (task_ == "AFQMC") {
    if (usage_ == "train") {
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
      key_map["label"] = "label";
    } else if (usage_ == "test") {
      // test files carry ids instead of labels
      key_map["id"] = "id";
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
    } else if (usage_ == "eval") {
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
      key_map["label"] = "label";
    }
  } else if (task_ == "CMNLI") {
    if (usage_ == "train") {
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
      key_map["label"] = "label";
    } else if (usage_ == "test") {
      key_map["id"] = "id";
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
    } else if (usage_ == "eval") {
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
      key_map["label"] = "label";
    }
  } else if (task_ == "CSL") {
    if (usage_ == "train") {
      key_map["id"] = "id";
      key_map["abst"] = "abst";
      key_map["keyword"] = "keyword";
      key_map["label"] = "label";
    } else if (usage_ == "test") {
      key_map["id"] = "id";
      key_map["abst"] = "abst";
      key_map["keyword"] = "keyword";
    } else if (usage_ == "eval") {
      key_map["id"] = "id";
      key_map["abst"] = "abst";
      key_map["keyword"] = "keyword";
      key_map["label"] = "label";
    }
  } else if (task_ == "IFLYTEK") {
    if (usage_ == "train") {
      key_map["label"] = "label";
      key_map["label_des"] = "label_des";
      key_map["sentence"] = "sentence";
    } else if (usage_ == "test") {
      key_map["id"] = "id";
      key_map["sentence"] = "sentence";
    } else if (usage_ == "eval") {
      key_map["label"] = "label";
      key_map["label_des"] = "label_des";
      key_map["sentence"] = "sentence";
    }
  } else if (task_ == "TNEWS") {
    if (usage_ == "train") {
      key_map["label"] = "label";
      key_map["label_desc"] = "label_desc";
      key_map["sentence"] = "sentence";
      key_map["keywords"] = "keywords";
    } else if (usage_ == "test") {
      key_map["id"] = "id";
      key_map["sentence"] = "sentence";
      key_map["keywords"] = "keywords";
    } else if (usage_ == "eval") {
      key_map["label"] = "label";
      key_map["label_desc"] = "label_desc";
      key_map["sentence"] = "sentence";
      key_map["keywords"] = "keywords";
    }
  } else if (task_ == "WSC") {
    if (usage_ == "train") {
      key_map["span1_index"] = "target/span1_index";
      key_map["span2_index"] = "target/span2_index";
      key_map["span1_text"] = "target/span1_text";
      key_map["span2_text"] = "target/span2_text";
      key_map["idx"] = "idx";
      key_map["label"] = "label";
      key_map["text"] = "text";
    } else if (usage_ == "test") {
      key_map["span1_index"] = "target/span1_index";
      key_map["span2_index"] = "target/span2_index";
      key_map["span1_text"] = "target/span1_text";
      key_map["span2_text"] = "target/span2_text";
      key_map["idx"] = "idx";
      key_map["text"] = "text";
    } else if (usage_ == "eval") {
      key_map["span1_index"] = "target/span1_index";
      key_map["span2_index"] = "target/span2_index";
      key_map["span1_text"] = "target/span1_text";
      key_map["span2_text"] = "target/span2_text";
      key_map["idx"] = "idx";
      key_map["label"] = "label";
      key_map["text"] = "text";
    }
  }
  // Convert each value string into a key path (vector of path components).
  ColKeyMap ck_map;
  for (auto &p : key_map) {
    ck_map.insert({p.first, split(p.second, '/')});
  }
  // File-level shuffling is needed for both kFiles and kGlobal modes.
  bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
  std::shared_ptr<ClueOp> clue_op =
    std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
                             dataset_files_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
  // On any error below, RETURN_EMPTY_IF_ERROR returns an empty node_ops vector.
  RETURN_EMPTY_IF_ERROR(clue_op->Init());
  if (shuffle_ == ShuffleMode::kGlobal) {
    // Inject ShuffleOp
    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
    int64_t num_rows = 0;
    // First, get the number of rows in the dataset
    RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(dataset_files_, &num_rows));
    // Add the shuffle op after this op
    RETURN_EMPTY_IF_ERROR(AddShuffleOp(dataset_files_.size(), num_shards_, num_rows, 0, connector_que_size_,
                                       rows_per_buffer_, &shuffle_op));
    // The shuffle op is pushed before the leaf ClueOp — presumably the
    // consumer treats earlier entries as parents when assembling the tree;
    // confirm against the tree-building code.
    node_ops.push_back(shuffle_op);
  }
  node_ops.push_back(clue_op);
  return node_ops;
}
// Constructor for CocoDataset
// dataset_dir: root directory of the COCO images (must be a directory, see ValidateParams)
// annotation_file: path to the COCO annotation file (must exist, see ValidateParams)
// task: one of "Detection", "Stuff", "Panoptic", "Keypoint" (checked in ValidateParams)
// decode: decode flag forwarded to CocoOp in Build()
// sampler: sampler object; when null, Build() substitutes a default sampler
CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
                         const bool &decode, const std::shared_ptr<SamplerObj> &sampler)
    : dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {}
  611. bool CocoDataset::ValidateParams() {
  612. Path dir(dataset_dir_);
  613. if (!dir.IsDirectory()) {
  614. MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified.";
  615. return false;
  616. }
  617. Path annotation_file(annotation_file_);
  618. if (!annotation_file.Exists()) {
  619. MS_LOG(ERROR) << "annotation_file is invalid or not exist";
  620. return false;
  621. }
  622. std::set<std::string> task_list = {"Detection", "Stuff", "Panoptic", "Keypoint"};
  623. auto task_iter = task_list.find(task_);
  624. if (task_iter == task_list.end()) {
  625. MS_LOG(ERROR) << "Invalid task type";
  626. return false;
  627. }
  628. return true;
  629. }
  630. // Function to build CocoDataset
  631. std::vector<std::shared_ptr<DatasetOp>> CocoDataset::Build() {
  632. // A vector containing shared pointer to the Dataset Ops that this object will create
  633. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  634. // If user does not specify Sampler, create a default sampler based on the shuffle variable.
  635. if (sampler_ == nullptr) {
  636. sampler_ = CreateDefaultSampler();
  637. }
  638. CocoOp::TaskType task_type;
  639. if (task_ == "Detection") {
  640. task_type = CocoOp::TaskType::Detection;
  641. } else if (task_ == "Stuff") {
  642. task_type = CocoOp::TaskType::Stuff;
  643. } else if (task_ == "Keypoint") {
  644. task_type = CocoOp::TaskType::Keypoint;
  645. } else if (task_ == "Panoptic") {
  646. task_type = CocoOp::TaskType::Panoptic;
  647. }
  648. std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  649. RETURN_EMPTY_IF_ERROR(
  650. schema->AddColumn(ColDescriptor(std::string("image"), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  651. switch (task_type) {
  652. case CocoOp::TaskType::Detection:
  653. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  654. ColDescriptor(std::string("bbox"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
  655. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  656. ColDescriptor(std::string("category_id"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  657. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  658. ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  659. break;
  660. case CocoOp::TaskType::Stuff:
  661. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  662. ColDescriptor(std::string("segmentation"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
  663. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  664. ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  665. break;
  666. case CocoOp::TaskType::Keypoint:
  667. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  668. ColDescriptor(std::string("keypoints"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
  669. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  670. ColDescriptor(std::string("num_keypoints"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  671. break;
  672. case CocoOp::TaskType::Panoptic:
  673. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  674. ColDescriptor(std::string("bbox"), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
  675. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  676. ColDescriptor(std::string("category_id"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  677. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  678. ColDescriptor(std::string("iscrowd"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  679. RETURN_EMPTY_IF_ERROR(
  680. schema->AddColumn(ColDescriptor(std::string("area"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  681. break;
  682. default:
  683. MS_LOG(ERROR) << "CocoDataset::Build : Invalid task type";
  684. return {};
  685. }
  686. std::shared_ptr<CocoOp> op =
  687. std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
  688. connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
  689. node_ops.push_back(op);
  690. return node_ops;
  691. }
// Constructor for ImageFolderDataset
// dataset_dir: root folder of the dataset (validated in ValidateParams)
// decode: decode flag forwarded to ImageFolderOp in Build()
// sampler: sampler object; when null, Build() substitutes a default sampler
// recursive: recursion flag forwarded to ImageFolderOp
// extensions: file-extension filter forwarded to ImageFolderOp
// class_indexing: explicit class-name -> label mapping forwarded to ImageFolderOp
ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
                                       bool recursive, std::set<std::string> extensions,
                                       std::map<std::string, int32_t> class_indexing)
    : dataset_dir_(dataset_dir),
      decode_(decode),
      sampler_(sampler),
      recursive_(recursive),
      class_indexing_(class_indexing),
      exts_(extensions) {}
// Validation only checks the common dataset-directory requirements.
bool ImageFolderDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
  702. std::vector<std::shared_ptr<DatasetOp>> ImageFolderDataset::Build() {
  703. // A vector containing shared pointer to the Dataset Ops that this object will create
  704. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  705. // If user does not specify Sampler, create a default sampler, i.e., RandomSampler.
  706. if (sampler_ == nullptr) {
  707. sampler_ = CreateDefaultSampler();
  708. }
  709. // Do internal Schema generation.
  710. // This arg is exist in ImageFolderOp, but not externalized (in Python API).
  711. std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  712. TensorShape scalar = TensorShape::CreateScalar();
  713. RETURN_EMPTY_IF_ERROR(
  714. schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  715. RETURN_EMPTY_IF_ERROR(
  716. schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
  717. node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
  718. recursive_, decode_, exts_, class_indexing_, std::move(schema),
  719. std::move(sampler_->Build())));
  720. return node_ops;
  721. }
// Constructor for MnistDataset
// dataset_dir: folder holding the MNIST files (validated in ValidateParams)
// sampler: sampler object; when null, Build() substitutes a default sampler
MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler)
    : dataset_dir_(dataset_dir), sampler_(sampler) {}
// Validation only checks the common dataset-directory requirements.
bool MnistDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
  725. std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
  726. // A vector containing shared pointer to the Dataset Ops that this object will create
  727. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  728. // If user does not specify Sampler, create a default sampler, i.e., RandomSampler.
  729. if (sampler_ == nullptr) {
  730. sampler_ = CreateDefaultSampler();
  731. }
  732. // Do internal Schema generation.
  733. auto schema = std::make_unique<DataSchema>();
  734. RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  735. TensorShape scalar = TensorShape::CreateScalar();
  736. RETURN_EMPTY_IF_ERROR(
  737. schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  738. node_ops.push_back(std::make_shared<MnistOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
  739. std::move(schema), std::move(sampler_->Build())));
  740. return node_ops;
  741. }
// Constructor for TextFileDataset
// dataset_files: list of text files to read (checked for readability in ValidateParams)
// num_samples: row-count cap; must be >= 0 (0 is accepted — presumably "read all", confirm with TextFileOp)
// shuffle: shuffle mode (kGlobal / kFiles / none), interpreted in Build()
// num_shards / shard_id: sharding configuration, range-checked in ValidateParams
TextFileDataset::TextFileDataset(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle,
                                 int32_t num_shards, int32_t shard_id)
    : dataset_files_(dataset_files),
      num_samples_(num_samples),
      shuffle_(shuffle),
      num_shards_(num_shards),
      shard_id_(shard_id) {}
  750. bool TextFileDataset::ValidateParams() {
  751. if (dataset_files_.empty()) {
  752. MS_LOG(ERROR) << "TextFileDataset: dataset_files is not specified.";
  753. return false;
  754. }
  755. for (auto file : dataset_files_) {
  756. std::ifstream handle(file);
  757. if (!handle.is_open()) {
  758. MS_LOG(ERROR) << "TextFileDataset: Failed to open file: " << file;
  759. return false;
  760. }
  761. }
  762. if (num_samples_ < 0) {
  763. MS_LOG(ERROR) << "TextFileDataset: Invalid number of samples: " << num_samples_;
  764. return false;
  765. }
  766. if (num_shards_ <= 0) {
  767. MS_LOG(ERROR) << "TextFileDataset: Invalid num_shards: " << num_shards_;
  768. return false;
  769. }
  770. if (shard_id_ < 0 || shard_id_ >= num_shards_) {
  771. MS_LOG(ERROR) << "TextFileDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
  772. return false;
  773. }
  774. return true;
  775. }
  776. // Function to build TextFileDataset
  777. std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
  778. // A vector containing shared pointer to the Dataset Ops that this object will create
  779. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  780. bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
  781. // Do internal Schema generation.
  782. auto schema = std::make_unique<DataSchema>();
  783. RETURN_EMPTY_IF_ERROR(
  784. schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  785. // Create and initalize TextFileOp
  786. std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
  787. num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), dataset_files_,
  788. connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr));
  789. RETURN_EMPTY_IF_ERROR(text_file_op->Init());
  790. if (shuffle_ == ShuffleMode::kGlobal) {
  791. // Inject ShuffleOp
  792. std::shared_ptr<DatasetOp> shuffle_op = nullptr;
  793. int64_t num_rows = 0;
  794. // First, get the number of rows in the dataset
  795. RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(dataset_files_, &num_rows));
  796. // Add the shuffle op after this op
  797. RETURN_EMPTY_IF_ERROR(AddShuffleOp(dataset_files_.size(), num_shards_, num_rows, 0, connector_que_size_,
  798. rows_per_buffer_, &shuffle_op));
  799. node_ops.push_back(shuffle_op);
  800. }
  801. // Add TextFileOp
  802. node_ops.push_back(text_file_op);
  803. return node_ops;
  804. }
// Constructor for VOCDataset
// dataset_dir: root of the VOC-layout dataset (validated in ValidateParams)
// task: "Segmentation" or "Detection" (checked in ValidateParams)
// mode: image-set name, used to locate "ImageSets/<task dir>/<mode>.txt"
// class_index: class-name -> label mapping; only valid for Detection
// decode: decode flag forwarded to VOCOp in Build()
// sampler: sampler object; when null, Build() substitutes a default sampler
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
                       const std::map<std::string, int32_t> &class_index, bool decode,
                       std::shared_ptr<SamplerObj> sampler)
    : dataset_dir_(dataset_dir),
      task_(task),
      mode_(mode),
      class_index_(class_index),
      decode_(decode),
      sampler_(sampler) {}
  815. bool VOCDataset::ValidateParams() {
  816. Path dir(dataset_dir_);
  817. if (!dir.IsDirectory()) {
  818. MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified.";
  819. return false;
  820. }
  821. if (task_ == "Segmentation") {
  822. if (!class_index_.empty()) {
  823. MS_LOG(ERROR) << "class_indexing is invalid in Segmentation task.";
  824. return false;
  825. }
  826. Path imagesets_file = dir / "ImageSets" / "Segmentation" / mode_ + ".txt";
  827. if (!imagesets_file.Exists()) {
  828. MS_LOG(ERROR) << "Invalid mode: " << mode_ << ", file \"" << imagesets_file << "\" is not exists!";
  829. return false;
  830. }
  831. } else if (task_ == "Detection") {
  832. Path imagesets_file = dir / "ImageSets" / "Main" / mode_ + ".txt";
  833. if (!imagesets_file.Exists()) {
  834. MS_LOG(ERROR) << "Invalid mode: " << mode_ << ", file \"" << imagesets_file << "\" is not exists!";
  835. return false;
  836. }
  837. } else {
  838. MS_LOG(ERROR) << "Invalid task: " << task_;
  839. return false;
  840. }
  841. return true;
  842. }
  843. // Function to build VOCDataset
  844. std::vector<std::shared_ptr<DatasetOp>> VOCDataset::Build() {
  845. // A vector containing shared pointer to the Dataset Ops that this object will create
  846. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  847. // If user does not specify Sampler, create a default sampler based on the shuffle variable.
  848. if (sampler_ == nullptr) {
  849. sampler_ = CreateDefaultSampler();
  850. }
  851. auto schema = std::make_unique<DataSchema>();
  852. VOCOp::TaskType task_type_;
  853. if (task_ == "Segmentation") {
  854. task_type_ = VOCOp::TaskType::Segmentation;
  855. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  856. ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  857. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  858. ColDescriptor(std::string(kColumnTarget), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  859. } else if (task_ == "Detection") {
  860. task_type_ = VOCOp::TaskType::Detection;
  861. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  862. ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  863. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  864. ColDescriptor(std::string(kColumnBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
  865. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  866. ColDescriptor(std::string(kColumnLabel), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  867. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  868. ColDescriptor(std::string(kColumnDifficult), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  869. RETURN_EMPTY_IF_ERROR(schema->AddColumn(
  870. ColDescriptor(std::string(kColumnTruncate), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
  871. }
  872. std::shared_ptr<VOCOp> voc_op;
  873. voc_op = std::make_shared<VOCOp>(task_type_, mode_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
  874. connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
  875. node_ops.push_back(voc_op);
  876. return node_ops;
  877. }
  878. // DERIVED DATASET CLASSES LEAF-NODE DATASETS
  879. // (In alphabetical order)
// Constructor for BatchDataset
// batch_size: rows per batch; must be positive (see ValidateParams)
// drop_remainder / pad: flags forwarded unchanged to BatchOp in Build()
// cols_to_map: column names forwarded to BatchOp
// pad_map: column name -> (TensorShape, pad-value Tensor), forwarded to BatchOp
BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector<std::string> cols_to_map,
                           std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map)
    : batch_size_(batch_size),
      drop_remainder_(drop_remainder),
      pad_(pad),
      cols_to_map_(cols_to_map),
      pad_map_(pad_map) {}
// Function to build BatchDataset: a single BatchOp.
std::vector<std::shared_ptr<DatasetOp>> BatchDataset::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;
#ifdef ENABLE_PYTHON
  // The Python-enabled BatchOp constructor takes two extra py::function
  // callbacks; the C++ API does not expose them, so pass default-constructed
  // (empty) functions.
  py::function noop;
  node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
                                               cols_to_map_, noop, noop, pad_map_));
#else
  node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
                                               cols_to_map_, pad_map_));
#endif
  return node_ops;
}
  900. bool BatchDataset::ValidateParams() {
  901. if (batch_size_ <= 0) {
  902. MS_LOG(ERROR) << "Batch: Batch size cannot be negative";
  903. return false;
  904. }
  905. return true;
  906. }
  907. // Function to build ConcatOp
  908. ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
  909. this->children = datasets_;
  910. }
  911. bool ConcatDataset::ValidateParams() {
  912. if (datasets_.empty()) {
  913. MS_LOG(ERROR) << "Concat: concatenated datasets are not specified.";
  914. return false;
  915. }
  916. return true;
  917. }
  918. std::vector<std::shared_ptr<DatasetOp>> ConcatDataset::Build() {
  919. // A vector containing shared pointer to the Dataset Ops that this object will create
  920. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  921. node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_));
  922. return node_ops;
  923. }
// Constructor for MapDataset
// operations: tensor operations applied per row (must be non-empty, see ValidateParams)
// input_columns / output_columns: columns consumed and produced by the operations
// project_columns: when non-empty, Build() additionally creates a ProjectOp over them
MapDataset::MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations, std::vector<std::string> input_columns,
                       std::vector<std::string> output_columns, const std::vector<std::string> &project_columns)
    : operations_(operations),
      input_columns_(input_columns),
      output_columns_(output_columns),
      project_columns_(project_columns) {}
  930. std::vector<std::shared_ptr<DatasetOp>> MapDataset::Build() {
  931. // A vector containing shared pointer to the Dataset Ops that this object will create
  932. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  933. std::vector<std::shared_ptr<TensorOp>> tensor_ops;
  934. // Build tensorOp from tensorOperation vector
  935. // This is to ensure each iterator hold its own copy of the tensorOp objects.
  936. (void)std::transform(
  937. operations_.begin(), operations_.end(), std::back_inserter(tensor_ops),
  938. [](std::shared_ptr<TensorOperation> operation) -> std::shared_ptr<TensorOp> { return operation->Build(); });
  939. // This parameter will be removed with next rebase
  940. std::vector<std::string> col_orders;
  941. auto map_op = std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_);
  942. if (!project_columns_.empty()) {
  943. auto project_op = std::make_shared<ProjectOp>(project_columns_);
  944. node_ops.push_back(project_op);
  945. }
  946. node_ops.push_back(map_op);
  947. return node_ops;
  948. }
  949. bool MapDataset::ValidateParams() {
  950. if (operations_.empty()) {
  951. MS_LOG(ERROR) << "Map: No operation is specified.";
  952. return false;
  953. }
  954. return true;
  955. }
  956. // Function to build ProjectOp
  957. ProjectDataset::ProjectDataset(const std::vector<std::string> &columns) : columns_(columns) {}
  958. bool ProjectDataset::ValidateParams() {
  959. if (columns_.empty()) {
  960. MS_LOG(ERROR) << "No columns are specified.";
  961. return false;
  962. }
  963. return true;
  964. }
  965. std::vector<std::shared_ptr<DatasetOp>> ProjectDataset::Build() {
  966. // A vector containing shared pointer to the Dataset Ops that this object will create
  967. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  968. node_ops.push_back(std::make_shared<ProjectOp>(columns_));
  969. return node_ops;
  970. }
  971. // Function to build RenameOp
  972. RenameDataset::RenameDataset(const std::vector<std::string> &input_columns,
  973. const std::vector<std::string> &output_columns)
  974. : input_columns_(input_columns), output_columns_(output_columns) {}
  975. bool RenameDataset::ValidateParams() {
  976. if (input_columns_.empty() || output_columns_.empty()) {
  977. MS_LOG(ERROR) << "input and output columns must be specified";
  978. return false;
  979. }
  980. if (input_columns_.size() != output_columns_.size()) {
  981. MS_LOG(ERROR) << "input and output columns must be the same size";
  982. return false;
  983. }
  984. return true;
  985. }
  986. std::vector<std::shared_ptr<DatasetOp>> RenameDataset::Build() {
  987. // A vector containing shared pointer to the Dataset Ops that this object will create
  988. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  989. node_ops.push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
  990. return node_ops;
  991. }
  992. RepeatDataset::RepeatDataset(int32_t count) : repeat_count_(count) {}
  993. std::vector<std::shared_ptr<DatasetOp>> RepeatDataset::Build() {
  994. // A vector containing shared pointer to the Dataset Ops that this object will create
  995. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  996. node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
  997. return node_ops;
  998. }
  999. bool RepeatDataset::ValidateParams() {
  1000. if (repeat_count_ != -1 && repeat_count_ <= 0) {
  1001. MS_LOG(ERROR) << "Repeat: Repeat count cannot be" << repeat_count_;
  1002. return false;
  1003. }
  1004. return true;
  1005. }
  1006. // Constructor for ShuffleDataset
  1007. ShuffleDataset::ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch)
  1008. : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {}
  1009. // Function to build the ShuffleOp
  1010. std::vector<std::shared_ptr<DatasetOp>> ShuffleDataset::Build() {
  1011. // A vector containing shared pointer to the Dataset Ops that this object will create
  1012. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  1013. node_ops.push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
  1014. rows_per_buffer_));
  1015. return node_ops;
  1016. }
  1017. // Function to validate the parameters for ShuffleDataset
  1018. bool ShuffleDataset::ValidateParams() {
  1019. if (shuffle_size_ <= 1) {
  1020. MS_LOG(ERROR) << "ShuffleDataset: Invalid input, shuffle_size: " << shuffle_size_;
  1021. return false;
  1022. }
  1023. return true;
  1024. }
  1025. // Constructor for SkipDataset
  1026. SkipDataset::SkipDataset(int32_t count) : skip_count_(count) {}
  1027. // Function to build the SkipOp
  1028. std::vector<std::shared_ptr<DatasetOp>> SkipDataset::Build() {
  1029. // A vector containing shared pointer to the Dataset Ops that this object will create
  1030. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  1031. node_ops.push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_));
  1032. return node_ops;
  1033. }
  1034. // Function to validate the parameters for SkipDataset
  1035. bool SkipDataset::ValidateParams() {
  1036. if (skip_count_ <= -1) {
  1037. MS_LOG(ERROR) << "Skip: Invalid input, skip_count: " << skip_count_;
  1038. return false;
  1039. }
  1040. return true;
  1041. }
  1042. // Constructor for TakeDataset
  1043. TakeDataset::TakeDataset(int32_t count) : take_count_(count) {}
  1044. // Function to build the TakeOp
  1045. std::vector<std::shared_ptr<DatasetOp>> TakeDataset::Build() {
  1046. // A vector containing shared pointer to the Dataset Ops that this object will create
  1047. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  1048. node_ops.push_back(std::make_shared<TakeOp>(take_count_, connector_que_size_));
  1049. return node_ops;
  1050. }
  1051. // Function to validate the parameters for TakeDataset
  1052. bool TakeDataset::ValidateParams() {
  1053. if (take_count_ < -1) {
  1054. MS_LOG(ERROR) << "Take: Invalid input, take_count: " << take_count_;
  1055. return false;
  1056. }
  1057. return true;
  1058. }
  1059. // Function to build ZipOp
  1060. ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
  1061. for (auto dataset : datasets_) {
  1062. this->children.push_back(dataset);
  1063. }
  1064. }
  1065. bool ZipDataset::ValidateParams() {
  1066. if (datasets_.empty()) {
  1067. MS_LOG(ERROR) << "Zip: dataset to zip are not specified.";
  1068. return false;
  1069. }
  1070. return true;
  1071. }
  1072. std::vector<std::shared_ptr<DatasetOp>> ZipDataset::Build() {
  1073. // A vector containing shared pointer to the Dataset Ops that this object will create
  1074. std::vector<std::shared_ptr<DatasetOp>> node_ops;
  1075. node_ops.push_back(std::make_shared<ZipOp>(rows_per_buffer_, connector_que_size_));
  1076. return node_ops;
  1077. }
  1078. } // namespace api
  1079. } // namespace dataset
  1080. } // namespace mindspore