You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transforms.cc 29 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "minddata/dataset/include/transforms.h"
  17. #include "minddata/dataset/kernels/image/image_utils.h"
  18. #include "minddata/dataset/kernels/image/center_crop_op.h"
  19. #include "minddata/dataset/kernels/image/crop_op.h"
  20. #include "minddata/dataset/kernels/image/cutmix_batch_op.h"
  21. #include "minddata/dataset/kernels/image/cut_out_op.h"
  22. #include "minddata/dataset/kernels/image/decode_op.h"
  23. #include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
  24. #include "minddata/dataset/kernels/image/mixup_batch_op.h"
  25. #include "minddata/dataset/kernels/image/normalize_op.h"
  26. #include "minddata/dataset/kernels/data/one_hot_op.h"
  27. #include "minddata/dataset/kernels/image/pad_op.h"
  28. #include "minddata/dataset/kernels/image/random_affine_op.h"
  29. #include "minddata/dataset/kernels/image/random_color_adjust_op.h"
  30. #include "minddata/dataset/kernels/image/random_crop_op.h"
  31. #include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
  32. #include "minddata/dataset/kernels/image/random_rotation_op.h"
  33. #include "minddata/dataset/kernels/image/random_sharpness_op.h"
  34. #include "minddata/dataset/kernels/image/random_solarize_op.h"
  35. #include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
  36. #include "minddata/dataset/kernels/image/resize_op.h"
  37. #include "minddata/dataset/kernels/image/rgba_to_bgr_op.h"
  38. #include "minddata/dataset/kernels/image/rgba_to_rgb_op.h"
  39. #include "minddata/dataset/kernels/image/swap_red_blue_op.h"
  40. #include "minddata/dataset/kernels/image/uniform_aug_op.h"
  41. namespace mindspore {
  42. namespace dataset {
  43. namespace api {
  44. TensorOperation::TensorOperation() {}
  45. // Transform operations for computer vision.
  46. namespace vision {
  47. // Function to create CenterCropOperation.
  48. std::shared_ptr<CenterCropOperation> CenterCrop(std::vector<int32_t> size) {
  49. auto op = std::make_shared<CenterCropOperation>(size);
  50. // Input validation
  51. if (!op->ValidateParams()) {
  52. return nullptr;
  53. }
  54. return op;
  55. }
  56. // Function to create CropOperation.
  57. std::shared_ptr<CropOperation> Crop(std::vector<int32_t> coordinates, std::vector<int32_t> size) {
  58. auto op = std::make_shared<CropOperation>(coordinates, size);
  59. // Input validation
  60. if (!op->ValidateParams()) {
  61. return nullptr;
  62. }
  63. return op;
  64. }
  65. // Function to create CutMixBatchOperation.
  66. std::shared_ptr<CutMixBatchOperation> CutMixBatch(ImageBatchFormat image_batch_format, float alpha, float prob) {
  67. auto op = std::make_shared<CutMixBatchOperation>(image_batch_format, alpha, prob);
  68. // Input validation
  69. if (!op->ValidateParams()) {
  70. return nullptr;
  71. }
  72. return op;
  73. }
  74. // Function to create CutOutOp.
  75. std::shared_ptr<CutOutOperation> CutOut(int32_t length, int32_t num_patches) {
  76. auto op = std::make_shared<CutOutOperation>(length, num_patches);
  77. // Input validation
  78. if (!op->ValidateParams()) {
  79. return nullptr;
  80. }
  81. return op;
  82. }
  83. // Function to create DecodeOperation.
  84. std::shared_ptr<DecodeOperation> Decode(bool rgb) {
  85. auto op = std::make_shared<DecodeOperation>(rgb);
  86. // Input validation
  87. if (!op->ValidateParams()) {
  88. return nullptr;
  89. }
  90. return op;
  91. }
  92. // Function to create HwcToChwOperation.
  93. std::shared_ptr<HwcToChwOperation> HWC2CHW() {
  94. auto op = std::make_shared<HwcToChwOperation>();
  95. // Input validation
  96. if (!op->ValidateParams()) {
  97. return nullptr;
  98. }
  99. return op;
  100. }
  101. // Function to create MixUpBatchOperation.
  102. std::shared_ptr<MixUpBatchOperation> MixUpBatch(float alpha) {
  103. auto op = std::make_shared<MixUpBatchOperation>(alpha);
  104. // Input validation
  105. if (!op->ValidateParams()) {
  106. return nullptr;
  107. }
  108. return op;
  109. }
  110. // Function to create NormalizeOperation.
  111. std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std) {
  112. auto op = std::make_shared<NormalizeOperation>(mean, std);
  113. // Input validation
  114. if (!op->ValidateParams()) {
  115. return nullptr;
  116. }
  117. return op;
  118. }
  119. // Function to create OneHotOperation.
  120. std::shared_ptr<OneHotOperation> OneHot(int32_t num_classes) {
  121. auto op = std::make_shared<OneHotOperation>(num_classes);
  122. // Input validation
  123. if (!op->ValidateParams()) {
  124. return nullptr;
  125. }
  126. return op;
  127. }
  128. // Function to create PadOperation.
  129. std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value,
  130. BorderType padding_mode) {
  131. auto op = std::make_shared<PadOperation>(padding, fill_value, padding_mode);
  132. // Input validation
  133. if (!op->ValidateParams()) {
  134. return nullptr;
  135. }
  136. return op;
  137. }
  138. // Function to create RandomColorAdjustOperation.
  139. std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float> brightness,
  140. std::vector<float> contrast,
  141. std::vector<float> saturation, std::vector<float> hue) {
  142. auto op = std::make_shared<RandomColorAdjustOperation>(brightness, contrast, saturation, hue);
  143. // Input validation
  144. if (!op->ValidateParams()) {
  145. return nullptr;
  146. }
  147. return op;
  148. }
  149. // Function to create RandomAffineOperation.
  150. std::shared_ptr<RandomAffineOperation> RandomAffine(const std::vector<float_t> &degrees,
  151. const std::vector<float_t> &translate_range,
  152. const std::vector<float_t> &scale_range,
  153. const std::vector<float_t> &shear_ranges,
  154. InterpolationMode interpolation,
  155. const std::vector<uint8_t> &fill_value) {
  156. auto op = std::make_shared<RandomAffineOperation>(degrees, translate_range, scale_range, shear_ranges, interpolation,
  157. fill_value);
  158. // Input validation
  159. if (!op->ValidateParams()) {
  160. return nullptr;
  161. }
  162. return op;
  163. }
  164. // Function to create RandomCropOperation.
  165. std::shared_ptr<RandomCropOperation> RandomCrop(std::vector<int32_t> size, std::vector<int32_t> padding,
  166. bool pad_if_needed, std::vector<uint8_t> fill_value,
  167. BorderType padding_mode) {
  168. auto op = std::make_shared<RandomCropOperation>(size, padding, pad_if_needed, fill_value, padding_mode);
  169. // Input validation
  170. if (!op->ValidateParams()) {
  171. return nullptr;
  172. }
  173. return op;
  174. }
  175. // Function to create RandomHorizontalFlipOperation.
  176. std::shared_ptr<RandomHorizontalFlipOperation> RandomHorizontalFlip(float prob) {
  177. auto op = std::make_shared<RandomHorizontalFlipOperation>(prob);
  178. // Input validation
  179. if (!op->ValidateParams()) {
  180. return nullptr;
  181. }
  182. return op;
  183. }
  184. // Function to create RandomRotationOperation.
  185. std::shared_ptr<RandomRotationOperation> RandomRotation(std::vector<float> degrees, InterpolationMode resample,
  186. bool expand, std::vector<float> center,
  187. std::vector<uint8_t> fill_value) {
  188. auto op = std::make_shared<RandomRotationOperation>(degrees, resample, expand, center, fill_value);
  189. // Input validation
  190. if (!op->ValidateParams()) {
  191. return nullptr;
  192. }
  193. return op;
  194. }
  195. // Function to create RandomSolarizeOperation.
  196. std::shared_ptr<RandomSolarizeOperation> RandomSolarize(uint8_t threshold_min, uint8_t threshold_max) {
  197. auto op = std::make_shared<RandomSolarizeOperation>(threshold_min, threshold_max);
  198. // Input validation
  199. if (!op->ValidateParams()) {
  200. return nullptr;
  201. }
  202. return op;
  203. }
  204. // Function to create RandomSharpnessOperation.
  205. std::shared_ptr<RandomSharpnessOperation> RandomSharpness(std::vector<float> degrees) {
  206. auto op = std::make_shared<RandomSharpnessOperation>(degrees);
  207. // Input validation
  208. if (!op->ValidateParams()) {
  209. return nullptr;
  210. }
  211. return op;
  212. }
  213. // Function to create RandomVerticalFlipOperation.
  214. std::shared_ptr<RandomVerticalFlipOperation> RandomVerticalFlip(float prob) {
  215. auto op = std::make_shared<RandomVerticalFlipOperation>(prob);
  216. // Input validation
  217. if (!op->ValidateParams()) {
  218. return nullptr;
  219. }
  220. return op;
  221. }
  222. // Function to create ResizeOperation.
  223. std::shared_ptr<ResizeOperation> Resize(std::vector<int32_t> size, InterpolationMode interpolation) {
  224. auto op = std::make_shared<ResizeOperation>(size, interpolation);
  225. // Input validation
  226. if (!op->ValidateParams()) {
  227. return nullptr;
  228. }
  229. return op;
  230. }
  231. // Function to create RgbaToBgrOperation.
  232. std::shared_ptr<RgbaToBgrOperation> RGBA2BGR() {
  233. auto op = std::make_shared<RgbaToBgrOperation>();
  234. // Input validation
  235. if (!op->ValidateParams()) {
  236. return nullptr;
  237. }
  238. return op;
  239. }
  240. // Function to create RgbaToRgbOperation.
  241. std::shared_ptr<RgbaToRgbOperation> RGBA2RGB() {
  242. auto op = std::make_shared<RgbaToRgbOperation>();
  243. // Input validation
  244. if (!op->ValidateParams()) {
  245. return nullptr;
  246. }
  247. return op;
  248. }
  249. // Function to create SwapRedBlueOperation.
  250. std::shared_ptr<SwapRedBlueOperation> SwapRedBlue() {
  251. auto op = std::make_shared<SwapRedBlueOperation>();
  252. // Input validation
  253. if (!op->ValidateParams()) {
  254. return nullptr;
  255. }
  256. return op;
  257. }
  258. // Function to create UniformAugOperation.
  259. std::shared_ptr<UniformAugOperation> UniformAugment(std::vector<std::shared_ptr<TensorOperation>> transforms,
  260. int32_t num_ops) {
  261. auto op = std::make_shared<UniformAugOperation>(transforms, num_ops);
  262. // Input validation
  263. if (!op->ValidateParams()) {
  264. return nullptr;
  265. }
  266. return op;
  267. }
  268. /* ####################################### Derived TensorOperation classes ################################# */
  269. // CenterCropOperation
  270. CenterCropOperation::CenterCropOperation(std::vector<int32_t> size) : size_(size) {}
  271. bool CenterCropOperation::ValidateParams() {
  272. if (size_.empty() || size_.size() > 2) {
  273. MS_LOG(ERROR) << "CenterCrop: size vector has incorrect size.";
  274. return false;
  275. }
  276. return true;
  277. }
  278. std::shared_ptr<TensorOp> CenterCropOperation::Build() {
  279. int32_t crop_height = size_[0];
  280. int32_t crop_width = 0;
  281. // User has specified crop_width.
  282. if (size_.size() == 2) {
  283. crop_width = size_[1];
  284. }
  285. std::shared_ptr<CenterCropOp> tensor_op = std::make_shared<CenterCropOp>(crop_height, crop_width);
  286. return tensor_op;
  287. }
  288. // CropOperation.
  289. CropOperation::CropOperation(std::vector<int32_t> coordinates, std::vector<int32_t> size)
  290. : coordinates_(coordinates), size_(size) {}
  291. bool CropOperation::ValidateParams() {
  292. // Do some input validation.
  293. if (coordinates_.empty() || coordinates_.size() > 2) {
  294. MS_LOG(ERROR) << "Crop: coordinates must be a vector of one or two values";
  295. return false;
  296. }
  297. if (size_.empty() || size_.size() > 2) {
  298. MS_LOG(ERROR) << "Crop: size must be a vector of one or two values";
  299. return false;
  300. }
  301. return true;
  302. }
  303. std::shared_ptr<TensorOp> CropOperation::Build() {
  304. int32_t x, y, height, width;
  305. x = coordinates_[0];
  306. y = coordinates_[1];
  307. height = size_[0];
  308. width = size_[1];
  309. std::shared_ptr<CropOp> tensor_op = std::make_shared<CropOp>(x, y, height, width);
  310. return tensor_op;
  311. }
  312. // CutMixBatchOperation
  313. CutMixBatchOperation::CutMixBatchOperation(ImageBatchFormat image_batch_format, float alpha, float prob)
  314. : image_batch_format_(image_batch_format), alpha_(alpha), prob_(prob) {}
  315. bool CutMixBatchOperation::ValidateParams() {
  316. if (alpha_ < 0) {
  317. MS_LOG(ERROR) << "CutMixBatch: alpha cannot be negative.";
  318. return false;
  319. }
  320. if (prob_ < 0 || prob_ > 1) {
  321. MS_LOG(ERROR) << "CutMixBatch: Probability has to be between 0 and 1.";
  322. return false;
  323. }
  324. return true;
  325. }
  326. std::shared_ptr<TensorOp> CutMixBatchOperation::Build() {
  327. std::shared_ptr<CutMixBatchOp> tensor_op = std::make_shared<CutMixBatchOp>(image_batch_format_, alpha_, prob_);
  328. return tensor_op;
  329. }
  330. // CutOutOperation
  331. CutOutOperation::CutOutOperation(int32_t length, int32_t num_patches) : length_(length), num_patches_(num_patches) {}
  332. bool CutOutOperation::ValidateParams() {
  333. if (length_ < 0) {
  334. MS_LOG(ERROR) << "CutOut: length cannot be negative";
  335. return false;
  336. }
  337. if (num_patches_ < 0) {
  338. MS_LOG(ERROR) << "CutOut: number of patches cannot be negative";
  339. return false;
  340. }
  341. return true;
  342. }
  343. std::shared_ptr<TensorOp> CutOutOperation::Build() {
  344. std::shared_ptr<CutOutOp> tensor_op = std::make_shared<CutOutOp>(length_, length_, num_patches_, false, 0, 0, 0);
  345. return tensor_op;
  346. }
  347. // DecodeOperation
  348. DecodeOperation::DecodeOperation(bool rgb) : rgb_(rgb) {}
  349. bool DecodeOperation::ValidateParams() { return true; }
  350. std::shared_ptr<TensorOp> DecodeOperation::Build() { return std::make_shared<DecodeOp>(rgb_); }
  351. // HwcToChwOperation
  352. bool HwcToChwOperation::ValidateParams() { return true; }
  353. std::shared_ptr<TensorOp> HwcToChwOperation::Build() { return std::make_shared<HwcToChwOp>(); }
  354. // MixUpOperation
  355. MixUpBatchOperation::MixUpBatchOperation(float alpha) : alpha_(alpha) {}
  356. bool MixUpBatchOperation::ValidateParams() {
  357. if (alpha_ < 0) {
  358. MS_LOG(ERROR) << "MixUpBatch: alpha must be a positive floating value however it is: " << alpha_;
  359. return false;
  360. }
  361. return true;
  362. }
  363. std::shared_ptr<TensorOp> MixUpBatchOperation::Build() { return std::make_shared<MixUpBatchOp>(alpha_); }
  364. // NormalizeOperation
  365. NormalizeOperation::NormalizeOperation(std::vector<float> mean, std::vector<float> std) : mean_(mean), std_(std) {}
  366. bool NormalizeOperation::ValidateParams() {
  367. if (mean_.size() != 3) {
  368. MS_LOG(ERROR) << "Normalize: mean vector has incorrect size: " << mean_.size();
  369. return false;
  370. }
  371. if (std_.size() != 3) {
  372. MS_LOG(ERROR) << "Normalize: std vector has incorrect size: " << std_.size();
  373. return false;
  374. }
  375. return true;
  376. }
  377. std::shared_ptr<TensorOp> NormalizeOperation::Build() {
  378. return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]);
  379. }
  380. // OneHotOperation
  381. OneHotOperation::OneHotOperation(int32_t num_classes) : num_classes_(num_classes) {}
  382. bool OneHotOperation::ValidateParams() {
  383. if (num_classes_ < 0) {
  384. MS_LOG(ERROR) << "OneHot: Number of classes cannot be negative. Number of classes: " << num_classes_;
  385. return false;
  386. }
  387. return true;
  388. }
  389. std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }
  390. // PadOperation
  391. PadOperation::PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode)
  392. : padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {}
  393. bool PadOperation::ValidateParams() {
  394. if (padding_.empty() || padding_.size() == 3 || padding_.size() > 4) {
  395. MS_LOG(ERROR) << "Pad: padding vector has incorrect size: padding.size()";
  396. return false;
  397. }
  398. if (fill_value_.empty() || (fill_value_.size() != 1 && fill_value_.size() != 3)) {
  399. MS_LOG(ERROR) << "Pad: fill_value vector has incorrect size: fill_value.size()";
  400. return false;
  401. }
  402. return true;
  403. }
  404. std::shared_ptr<TensorOp> PadOperation::Build() {
  405. int32_t pad_top, pad_bottom, pad_left, pad_right;
  406. switch (padding_.size()) {
  407. case 1:
  408. pad_left = padding_[0];
  409. pad_top = padding_[0];
  410. pad_right = padding_[0];
  411. pad_bottom = padding_[0];
  412. break;
  413. case 2:
  414. pad_left = padding_[0];
  415. pad_top = padding_[1];
  416. pad_right = padding_[0];
  417. pad_bottom = padding_[1];
  418. break;
  419. default:
  420. pad_left = padding_[0];
  421. pad_top = padding_[1];
  422. pad_right = padding_[2];
  423. pad_bottom = padding_[3];
  424. }
  425. uint8_t fill_r, fill_g, fill_b;
  426. fill_r = fill_value_[0];
  427. fill_g = fill_value_[0];
  428. fill_b = fill_value_[0];
  429. if (fill_value_.size() == 3) {
  430. fill_r = fill_value_[0];
  431. fill_g = fill_value_[1];
  432. fill_b = fill_value_[2];
  433. }
  434. std::shared_ptr<PadOp> tensor_op =
  435. std::make_shared<PadOp>(pad_top, pad_bottom, pad_left, pad_right, padding_mode_, fill_r, fill_g, fill_b);
  436. return tensor_op;
  437. }
  438. // RandomColorAdjustOperation.
  439. RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast,
  440. std::vector<float> saturation, std::vector<float> hue)
  441. : brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {}
  442. bool RandomColorAdjustOperation::ValidateParams() {
  443. // Do some input validation.
  444. if (brightness_.empty() || brightness_.size() > 2) {
  445. MS_LOG(ERROR) << "RandomColorAdjust: brightness must be a vector of one or two values";
  446. return false;
  447. }
  448. if (contrast_.empty() || contrast_.size() > 2) {
  449. MS_LOG(ERROR) << "RandomColorAdjust: contrast must be a vector of one or two values";
  450. return false;
  451. }
  452. if (saturation_.empty() || saturation_.size() > 2) {
  453. MS_LOG(ERROR) << "RandomColorAdjust: saturation must be a vector of one or two values";
  454. return false;
  455. }
  456. if (hue_.empty() || hue_.size() > 2) {
  457. MS_LOG(ERROR) << "RandomColorAdjust: hue must be a vector of one or two values";
  458. return false;
  459. }
  460. return true;
  461. }
  462. std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() {
  463. float brightness_lb, brightness_ub, contrast_lb, contrast_ub, saturation_lb, saturation_ub, hue_lb, hue_ub;
  464. brightness_lb = brightness_[0];
  465. brightness_ub = brightness_[0];
  466. if (brightness_.size() == 2) brightness_ub = brightness_[1];
  467. contrast_lb = contrast_[0];
  468. contrast_ub = contrast_[0];
  469. if (contrast_.size() == 2) contrast_ub = contrast_[1];
  470. saturation_lb = saturation_[0];
  471. saturation_ub = saturation_[0];
  472. if (saturation_.size() == 2) saturation_ub = saturation_[1];
  473. hue_lb = hue_[0];
  474. hue_ub = hue_[0];
  475. if (hue_.size() == 2) hue_ub = hue_[1];
  476. std::shared_ptr<RandomColorAdjustOp> tensor_op = std::make_shared<RandomColorAdjustOp>(
  477. brightness_lb, brightness_ub, contrast_lb, contrast_ub, saturation_lb, saturation_ub, hue_lb, hue_ub);
  478. return tensor_op;
  479. }
  480. // RandomAffineOperation
  481. RandomAffineOperation::RandomAffineOperation(const std::vector<float_t> &degrees,
  482. const std::vector<float_t> &translate_range,
  483. const std::vector<float_t> &scale_range,
  484. const std::vector<float_t> &shear_ranges, InterpolationMode interpolation,
  485. const std::vector<uint8_t> &fill_value)
  486. : degrees_(degrees),
  487. translate_range_(translate_range),
  488. scale_range_(scale_range),
  489. shear_ranges_(shear_ranges),
  490. interpolation_(interpolation),
  491. fill_value_(fill_value) {}
  492. bool RandomAffineOperation::ValidateParams() {
  493. // Degrees
  494. if (degrees_.size() != 2) {
  495. MS_LOG(ERROR) << "RandomAffine: degrees vector has incorrect size: degrees.size() = " << degrees_.size();
  496. return false;
  497. }
  498. if (degrees_[0] > degrees_[1]) {
  499. MS_LOG(ERROR) << "RandomAffine: minimum of degrees range is greater than maximum: min = " << degrees_[0]
  500. << ", max = " << degrees_[1];
  501. return false;
  502. }
  503. // Translate
  504. if (translate_range_.size() != 2) {
  505. MS_LOG(ERROR) << "RandomAffine: translate_range vector has incorrect size: translate_range.size() = "
  506. << translate_range_.size();
  507. return false;
  508. }
  509. if (translate_range_[0] > translate_range_[1]) {
  510. MS_LOG(ERROR) << "RandomAffine: minimum of translate range is greater than maximum: min = " << translate_range_[0]
  511. << ", max = " << translate_range_[1];
  512. return false;
  513. }
  514. // Scale
  515. if (scale_range_.size() != 2) {
  516. MS_LOG(ERROR) << "RandomAffine: scale_range vector has incorrect size: scale_range.size() = "
  517. << scale_range_.size();
  518. return false;
  519. }
  520. if (scale_range_[0] > scale_range_[1]) {
  521. MS_LOG(ERROR) << "RandomAffine: minimum of scale range is greater than maximum: min = " << scale_range_[0]
  522. << ", max = " << scale_range_[1];
  523. return false;
  524. }
  525. // Shear
  526. if (shear_ranges_.size() != 4) {
  527. MS_LOG(ERROR) << "RandomAffine: shear_ranges vector has incorrect size: shear_ranges.size() = "
  528. << shear_ranges_.size();
  529. return false;
  530. }
  531. if (shear_ranges_[0] > shear_ranges_[1]) {
  532. MS_LOG(ERROR) << "RandomAffine: minimum of horizontal shear range is greater than maximum: min = "
  533. << shear_ranges_[0] << ", max = " << shear_ranges_[1];
  534. return false;
  535. }
  536. if (shear_ranges_[2] > shear_ranges_[3]) {
  537. MS_LOG(ERROR) << "RandomAffine: minimum of vertical shear range is greater than maximum: min = " << shear_ranges_[2]
  538. << ", max = " << scale_range_[3];
  539. return false;
  540. }
  541. // Fill Value
  542. if (fill_value_.size() != 3) {
  543. MS_LOG(ERROR) << "RandomAffine: fill_value vector has incorrect size: fill_value.size() = " << fill_value_.size();
  544. return false;
  545. }
  546. return true;
  547. }
  548. std::shared_ptr<TensorOp> RandomAffineOperation::Build() {
  549. auto tensor_op = std::make_shared<RandomAffineOp>(degrees_, translate_range_, scale_range_, shear_ranges_,
  550. interpolation_, fill_value_);
  551. return tensor_op;
  552. }
  553. // RandomCropOperation
  554. RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
  555. std::vector<uint8_t> fill_value, BorderType padding_mode)
  556. : size_(size),
  557. padding_(padding),
  558. pad_if_needed_(pad_if_needed),
  559. fill_value_(fill_value),
  560. padding_mode_(padding_mode) {}
  561. bool RandomCropOperation::ValidateParams() {
  562. if (size_.empty() || size_.size() > 2) {
  563. MS_LOG(ERROR) << "RandomCrop: size vector has incorrect size: " << size_.size();
  564. return false;
  565. }
  566. if (padding_.empty() || padding_.size() != 4) {
  567. MS_LOG(ERROR) << "RandomCrop: padding vector has incorrect size: padding.size()";
  568. return false;
  569. }
  570. if (fill_value_.empty() || fill_value_.size() != 3) {
  571. MS_LOG(ERROR) << "RandomCrop: fill_value vector has incorrect size: fill_value.size()";
  572. return false;
  573. }
  574. return true;
  575. }
  576. std::shared_ptr<TensorOp> RandomCropOperation::Build() {
  577. int32_t crop_height = size_[0];
  578. int32_t crop_width = 0;
  579. int32_t pad_top = padding_[0];
  580. int32_t pad_bottom = padding_[1];
  581. int32_t pad_left = padding_[2];
  582. int32_t pad_right = padding_[3];
  583. uint8_t fill_r = fill_value_[0];
  584. uint8_t fill_g = fill_value_[1];
  585. uint8_t fill_b = fill_value_[2];
  586. // User has specified the crop_width value.
  587. if (size_.size() == 2) {
  588. crop_width = size_[1];
  589. }
  590. auto tensor_op = std::make_shared<RandomCropOp>(crop_height, crop_width, pad_top, pad_bottom, pad_left, pad_right,
  591. padding_mode_, pad_if_needed_, fill_r, fill_g, fill_b);
  592. return tensor_op;
  593. }
  594. // RandomHorizontalFlipOperation
  595. RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {}
  596. bool RandomHorizontalFlipOperation::ValidateParams() { return true; }
  597. std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() {
  598. std::shared_ptr<RandomHorizontalFlipOp> tensor_op = std::make_shared<RandomHorizontalFlipOp>(probability_);
  599. return tensor_op;
  600. }
  601. // Function to create RandomRotationOperation.
  602. RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode,
  603. bool expand, std::vector<float> center,
  604. std::vector<uint8_t> fill_value)
  605. : degrees_(degrees),
  606. interpolation_mode_(interpolation_mode),
  607. expand_(expand),
  608. center_(center),
  609. fill_value_(fill_value) {}
  610. bool RandomRotationOperation::ValidateParams() {
  611. if (degrees_.empty() || degrees_.size() != 2) {
  612. MS_LOG(ERROR) << "RandomRotation: degrees vector has incorrect size: degrees.size()";
  613. return false;
  614. }
  615. if (center_.empty() || center_.size() != 2) {
  616. MS_LOG(ERROR) << "RandomRotation: center vector has incorrect size: center.size()";
  617. return false;
  618. }
  619. if (fill_value_.empty() || fill_value_.size() != 3) {
  620. MS_LOG(ERROR) << "RandomRotation: fill_value vector has incorrect size: fill_value.size()";
  621. return false;
  622. }
  623. return true;
  624. }
  625. std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
  626. std::shared_ptr<RandomRotationOp> tensor_op =
  627. std::make_shared<RandomRotationOp>(degrees_[0], degrees_[1], center_[0], center_[1], interpolation_mode_, expand_,
  628. fill_value_[0], fill_value_[1], fill_value_[2]);
  629. return tensor_op;
  630. }
  631. // Function to create RandomSharpness.
  632. RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees) : degrees_(degrees) {}
  633. bool RandomSharpnessOperation::ValidateParams() {
  634. if (degrees_.empty() || degrees_.size() != 2) {
  635. MS_LOG(ERROR) << "RandomSharpness: degrees vector has incorrect size: degrees.size()";
  636. return false;
  637. }
  638. return true;
  639. }
  640. std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() {
  641. std::shared_ptr<RandomSharpnessOp> tensor_op = std::make_shared<RandomSharpnessOp>(degrees_[0], degrees_[1]);
  642. return tensor_op;
  643. }
  644. // RandomSolarizeOperation.
  645. RandomSolarizeOperation::RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max)
  646. : threshold_min_(threshold_min), threshold_max_(threshold_max) {}
  647. bool RandomSolarizeOperation::ValidateParams() {
  648. if (threshold_max_ < threshold_min_) {
  649. MS_LOG(ERROR) << "RandomSolarize: threshold_max must be greater or equal to threshold_min";
  650. return false;
  651. }
  652. return true;
  653. }
  654. std::shared_ptr<TensorOp> RandomSolarizeOperation::Build() {
  655. std::shared_ptr<RandomSolarizeOp> tensor_op = std::make_shared<RandomSolarizeOp>(threshold_min_, threshold_max_);
  656. return tensor_op;
  657. }
  658. // RandomVerticalFlipOperation
  659. RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {}
  660. bool RandomVerticalFlipOperation::ValidateParams() { return true; }
  661. std::shared_ptr<TensorOp> RandomVerticalFlipOperation::Build() {
  662. std::shared_ptr<RandomVerticalFlipOp> tensor_op = std::make_shared<RandomVerticalFlipOp>(probability_);
  663. return tensor_op;
  664. }
  665. // ResizeOperation
  666. ResizeOperation::ResizeOperation(std::vector<int32_t> size, InterpolationMode interpolation)
  667. : size_(size), interpolation_(interpolation) {}
  668. bool ResizeOperation::ValidateParams() {
  669. if (size_.empty() || size_.size() > 2) {
  670. MS_LOG(ERROR) << "Resize: size vector has incorrect size: " << size_.size();
  671. return false;
  672. }
  673. return true;
  674. }
  675. std::shared_ptr<TensorOp> ResizeOperation::Build() {
  676. int32_t height = size_[0];
  677. int32_t width = 0;
  678. // User specified the width value.
  679. if (size_.size() == 2) {
  680. width = size_[1];
  681. }
  682. return std::make_shared<ResizeOp>(height, width, interpolation_);
  683. }
  684. // RgbaToBgrOperation.
  685. RgbaToBgrOperation::RgbaToBgrOperation() {}
  686. bool RgbaToBgrOperation::ValidateParams() { return true; }
  687. std::shared_ptr<TensorOp> RgbaToBgrOperation::Build() {
  688. std::shared_ptr<RgbaToBgrOp> tensor_op = std::make_shared<RgbaToBgrOp>();
  689. return tensor_op;
  690. }
  691. // RgbaToRgbOperation.
  692. RgbaToRgbOperation::RgbaToRgbOperation() {}
  693. bool RgbaToRgbOperation::ValidateParams() { return true; }
  694. std::shared_ptr<TensorOp> RgbaToRgbOperation::Build() {
  695. std::shared_ptr<RgbaToRgbOp> tensor_op = std::make_shared<RgbaToRgbOp>();
  696. return tensor_op;
  697. }
  698. // SwapRedBlueOperation.
  699. SwapRedBlueOperation::SwapRedBlueOperation() {}
  700. bool SwapRedBlueOperation::ValidateParams() { return true; }
  701. std::shared_ptr<TensorOp> SwapRedBlueOperation::Build() {
  702. std::shared_ptr<SwapRedBlueOp> tensor_op = std::make_shared<SwapRedBlueOp>();
  703. return tensor_op;
  704. }
  705. // UniformAugOperation
  706. UniformAugOperation::UniformAugOperation(std::vector<std::shared_ptr<TensorOperation>> transforms, int32_t num_ops)
  707. : transforms_(transforms), num_ops_(num_ops) {}
  708. bool UniformAugOperation::ValidateParams() { return true; }
  709. std::shared_ptr<TensorOp> UniformAugOperation::Build() {
  710. std::vector<std::shared_ptr<TensorOp>> tensor_ops;
  711. (void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(tensor_ops),
  712. [](std::shared_ptr<TensorOperation> op) -> std::shared_ptr<TensorOp> { return op->Build(); });
  713. std::shared_ptr<UniformAugOp> tensor_op = std::make_shared<UniformAugOp>(tensor_ops, num_ops_);
  714. return tensor_op;
  715. }
  716. } // namespace vision
  717. } // namespace api
  718. } // namespace dataset
  719. } // namespace mindspore