You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transforms.cc 31 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "minddata/dataset/include/transforms.h"
  17. #include "minddata/dataset/kernels/image/image_utils.h"
  18. #include "minddata/dataset/kernels/image/center_crop_op.h"
  19. #include "minddata/dataset/kernels/image/crop_op.h"
  20. #include "minddata/dataset/kernels/image/cutmix_batch_op.h"
  21. #include "minddata/dataset/kernels/image/cut_out_op.h"
  22. #include "minddata/dataset/kernels/image/decode_op.h"
  23. #include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
  24. #include "minddata/dataset/kernels/image/mixup_batch_op.h"
  25. #include "minddata/dataset/kernels/image/normalize_op.h"
  26. #include "minddata/dataset/kernels/data/one_hot_op.h"
  27. #include "minddata/dataset/kernels/image/pad_op.h"
  28. #include "minddata/dataset/kernels/image/random_affine_op.h"
  29. #include "minddata/dataset/kernels/image/random_color_op.h"
  30. #include "minddata/dataset/kernels/image/random_color_adjust_op.h"
  31. #include "minddata/dataset/kernels/image/random_crop_op.h"
  32. #include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
  33. #include "minddata/dataset/kernels/image/random_posterize_op.h"
  34. #include "minddata/dataset/kernels/image/random_rotation_op.h"
  35. #include "minddata/dataset/kernels/image/random_sharpness_op.h"
  36. #include "minddata/dataset/kernels/image/random_solarize_op.h"
  37. #include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
  38. #include "minddata/dataset/kernels/image/resize_op.h"
  39. #include "minddata/dataset/kernels/image/rgba_to_bgr_op.h"
  40. #include "minddata/dataset/kernels/image/rgba_to_rgb_op.h"
  41. #include "minddata/dataset/kernels/image/swap_red_blue_op.h"
  42. #include "minddata/dataset/kernels/image/uniform_aug_op.h"
  43. namespace mindspore {
  44. namespace dataset {
  45. namespace api {
  46. TensorOperation::TensorOperation() {}
  47. // Transform operations for computer vision.
  48. namespace vision {
  49. // Function to create CenterCropOperation.
  50. std::shared_ptr<CenterCropOperation> CenterCrop(std::vector<int32_t> size) {
  51. auto op = std::make_shared<CenterCropOperation>(size);
  52. // Input validation
  53. if (!op->ValidateParams()) {
  54. return nullptr;
  55. }
  56. return op;
  57. }
  58. // Function to create CropOperation.
  59. std::shared_ptr<CropOperation> Crop(std::vector<int32_t> coordinates, std::vector<int32_t> size) {
  60. auto op = std::make_shared<CropOperation>(coordinates, size);
  61. // Input validation
  62. if (!op->ValidateParams()) {
  63. return nullptr;
  64. }
  65. return op;
  66. }
  67. // Function to create CutMixBatchOperation.
  68. std::shared_ptr<CutMixBatchOperation> CutMixBatch(ImageBatchFormat image_batch_format, float alpha, float prob) {
  69. auto op = std::make_shared<CutMixBatchOperation>(image_batch_format, alpha, prob);
  70. // Input validation
  71. if (!op->ValidateParams()) {
  72. return nullptr;
  73. }
  74. return op;
  75. }
  76. // Function to create CutOutOp.
  77. std::shared_ptr<CutOutOperation> CutOut(int32_t length, int32_t num_patches) {
  78. auto op = std::make_shared<CutOutOperation>(length, num_patches);
  79. // Input validation
  80. if (!op->ValidateParams()) {
  81. return nullptr;
  82. }
  83. return op;
  84. }
  85. // Function to create DecodeOperation.
  86. std::shared_ptr<DecodeOperation> Decode(bool rgb) {
  87. auto op = std::make_shared<DecodeOperation>(rgb);
  88. // Input validation
  89. if (!op->ValidateParams()) {
  90. return nullptr;
  91. }
  92. return op;
  93. }
  94. // Function to create HwcToChwOperation.
  95. std::shared_ptr<HwcToChwOperation> HWC2CHW() {
  96. auto op = std::make_shared<HwcToChwOperation>();
  97. // Input validation
  98. if (!op->ValidateParams()) {
  99. return nullptr;
  100. }
  101. return op;
  102. }
  103. // Function to create MixUpBatchOperation.
  104. std::shared_ptr<MixUpBatchOperation> MixUpBatch(float alpha) {
  105. auto op = std::make_shared<MixUpBatchOperation>(alpha);
  106. // Input validation
  107. if (!op->ValidateParams()) {
  108. return nullptr;
  109. }
  110. return op;
  111. }
  112. // Function to create NormalizeOperation.
  113. std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std) {
  114. auto op = std::make_shared<NormalizeOperation>(mean, std);
  115. // Input validation
  116. if (!op->ValidateParams()) {
  117. return nullptr;
  118. }
  119. return op;
  120. }
  121. // Function to create OneHotOperation.
  122. std::shared_ptr<OneHotOperation> OneHot(int32_t num_classes) {
  123. auto op = std::make_shared<OneHotOperation>(num_classes);
  124. // Input validation
  125. if (!op->ValidateParams()) {
  126. return nullptr;
  127. }
  128. return op;
  129. }
  130. // Function to create PadOperation.
  131. std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value,
  132. BorderType padding_mode) {
  133. auto op = std::make_shared<PadOperation>(padding, fill_value, padding_mode);
  134. // Input validation
  135. if (!op->ValidateParams()) {
  136. return nullptr;
  137. }
  138. return op;
  139. }
  140. // Function to create RandomColorOperation.
  141. std::shared_ptr<RandomColorOperation> RandomColor(float t_lb, float t_ub) {
  142. auto op = std::make_shared<RandomColorOperation>(t_lb, t_ub);
  143. // Input validation
  144. if (!op->ValidateParams()) {
  145. return nullptr;
  146. }
  147. return op;
  148. }
  149. std::shared_ptr<TensorOp> RandomColorOperation::Build() {
  150. std::shared_ptr<RandomColorOp> tensor_op = std::make_shared<RandomColorOp>(t_lb_, t_ub_);
  151. return tensor_op;
  152. }
  153. // Function to create RandomColorAdjustOperation.
  154. std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float> brightness,
  155. std::vector<float> contrast,
  156. std::vector<float> saturation, std::vector<float> hue) {
  157. auto op = std::make_shared<RandomColorAdjustOperation>(brightness, contrast, saturation, hue);
  158. // Input validation
  159. if (!op->ValidateParams()) {
  160. return nullptr;
  161. }
  162. return op;
  163. }
  164. // Function to create RandomAffineOperation.
  165. std::shared_ptr<RandomAffineOperation> RandomAffine(const std::vector<float_t> &degrees,
  166. const std::vector<float_t> &translate_range,
  167. const std::vector<float_t> &scale_range,
  168. const std::vector<float_t> &shear_ranges,
  169. InterpolationMode interpolation,
  170. const std::vector<uint8_t> &fill_value) {
  171. auto op = std::make_shared<RandomAffineOperation>(degrees, translate_range, scale_range, shear_ranges, interpolation,
  172. fill_value);
  173. // Input validation
  174. if (!op->ValidateParams()) {
  175. return nullptr;
  176. }
  177. return op;
  178. }
  179. // Function to create RandomCropOperation.
  180. std::shared_ptr<RandomCropOperation> RandomCrop(std::vector<int32_t> size, std::vector<int32_t> padding,
  181. bool pad_if_needed, std::vector<uint8_t> fill_value,
  182. BorderType padding_mode) {
  183. auto op = std::make_shared<RandomCropOperation>(size, padding, pad_if_needed, fill_value, padding_mode);
  184. // Input validation
  185. if (!op->ValidateParams()) {
  186. return nullptr;
  187. }
  188. return op;
  189. }
  190. // Function to create RandomHorizontalFlipOperation.
  191. std::shared_ptr<RandomHorizontalFlipOperation> RandomHorizontalFlip(float prob) {
  192. auto op = std::make_shared<RandomHorizontalFlipOperation>(prob);
  193. // Input validation
  194. if (!op->ValidateParams()) {
  195. return nullptr;
  196. }
  197. return op;
  198. }
  199. // Function to create RandomPosterizeOperation.
  200. std::shared_ptr<RandomPosterizeOperation> RandomPosterize(uint8_t min_bit, uint8_t max_bit) {
  201. auto op = std::make_shared<RandomPosterizeOperation>(min_bit, max_bit);
  202. // Input validation
  203. if (!op->ValidateParams()) {
  204. return nullptr;
  205. }
  206. return op;
  207. }
  208. // Function to create RandomRotationOperation.
  209. std::shared_ptr<RandomRotationOperation> RandomRotation(std::vector<float> degrees, InterpolationMode resample,
  210. bool expand, std::vector<float> center,
  211. std::vector<uint8_t> fill_value) {
  212. auto op = std::make_shared<RandomRotationOperation>(degrees, resample, expand, center, fill_value);
  213. // Input validation
  214. if (!op->ValidateParams()) {
  215. return nullptr;
  216. }
  217. return op;
  218. }
  219. // Function to create RandomSolarizeOperation.
  220. std::shared_ptr<RandomSolarizeOperation> RandomSolarize(std::vector<uint8_t> threshold) {
  221. auto op = std::make_shared<RandomSolarizeOperation>(threshold);
  222. // Input validation
  223. if (!op->ValidateParams()) {
  224. return nullptr;
  225. }
  226. return op;
  227. }
  228. // Function to create RandomSharpnessOperation.
  229. std::shared_ptr<RandomSharpnessOperation> RandomSharpness(std::vector<float> degrees) {
  230. auto op = std::make_shared<RandomSharpnessOperation>(degrees);
  231. // Input validation
  232. if (!op->ValidateParams()) {
  233. return nullptr;
  234. }
  235. return op;
  236. }
  237. // Function to create RandomVerticalFlipOperation.
  238. std::shared_ptr<RandomVerticalFlipOperation> RandomVerticalFlip(float prob) {
  239. auto op = std::make_shared<RandomVerticalFlipOperation>(prob);
  240. // Input validation
  241. if (!op->ValidateParams()) {
  242. return nullptr;
  243. }
  244. return op;
  245. }
  246. // Function to create ResizeOperation.
  247. std::shared_ptr<ResizeOperation> Resize(std::vector<int32_t> size, InterpolationMode interpolation) {
  248. auto op = std::make_shared<ResizeOperation>(size, interpolation);
  249. // Input validation
  250. if (!op->ValidateParams()) {
  251. return nullptr;
  252. }
  253. return op;
  254. }
  255. // Function to create RgbaToBgrOperation.
  256. std::shared_ptr<RgbaToBgrOperation> RGBA2BGR() {
  257. auto op = std::make_shared<RgbaToBgrOperation>();
  258. // Input validation
  259. if (!op->ValidateParams()) {
  260. return nullptr;
  261. }
  262. return op;
  263. }
  264. // Function to create RgbaToRgbOperation.
  265. std::shared_ptr<RgbaToRgbOperation> RGBA2RGB() {
  266. auto op = std::make_shared<RgbaToRgbOperation>();
  267. // Input validation
  268. if (!op->ValidateParams()) {
  269. return nullptr;
  270. }
  271. return op;
  272. }
  273. // Function to create SwapRedBlueOperation.
  274. std::shared_ptr<SwapRedBlueOperation> SwapRedBlue() {
  275. auto op = std::make_shared<SwapRedBlueOperation>();
  276. // Input validation
  277. if (!op->ValidateParams()) {
  278. return nullptr;
  279. }
  280. return op;
  281. }
  282. // Function to create UniformAugOperation.
  283. std::shared_ptr<UniformAugOperation> UniformAugment(std::vector<std::shared_ptr<TensorOperation>> transforms,
  284. int32_t num_ops) {
  285. auto op = std::make_shared<UniformAugOperation>(transforms, num_ops);
  286. // Input validation
  287. if (!op->ValidateParams()) {
  288. return nullptr;
  289. }
  290. return op;
  291. }
  292. /* ####################################### Derived TensorOperation classes ################################# */
  293. // CenterCropOperation
  294. CenterCropOperation::CenterCropOperation(std::vector<int32_t> size) : size_(size) {}
  295. bool CenterCropOperation::ValidateParams() {
  296. if (size_.empty() || size_.size() > 2) {
  297. MS_LOG(ERROR) << "CenterCrop: size vector has incorrect size.";
  298. return false;
  299. }
  300. return true;
  301. }
  302. std::shared_ptr<TensorOp> CenterCropOperation::Build() {
  303. int32_t crop_height = size_[0];
  304. int32_t crop_width = 0;
  305. // User has specified crop_width.
  306. if (size_.size() == 2) {
  307. crop_width = size_[1];
  308. }
  309. std::shared_ptr<CenterCropOp> tensor_op = std::make_shared<CenterCropOp>(crop_height, crop_width);
  310. return tensor_op;
  311. }
  312. // CropOperation.
  313. CropOperation::CropOperation(std::vector<int32_t> coordinates, std::vector<int32_t> size)
  314. : coordinates_(coordinates), size_(size) {}
  315. bool CropOperation::ValidateParams() {
  316. // Do some input validation.
  317. if (coordinates_.empty() || coordinates_.size() > 2) {
  318. MS_LOG(ERROR) << "Crop: coordinates must be a vector of one or two values";
  319. return false;
  320. }
  321. if (size_.empty() || size_.size() > 2) {
  322. MS_LOG(ERROR) << "Crop: size must be a vector of one or two values";
  323. return false;
  324. }
  325. return true;
  326. }
  327. std::shared_ptr<TensorOp> CropOperation::Build() {
  328. int32_t x, y, height, width;
  329. x = coordinates_[0];
  330. y = coordinates_[1];
  331. height = size_[0];
  332. width = size_[1];
  333. std::shared_ptr<CropOp> tensor_op = std::make_shared<CropOp>(x, y, height, width);
  334. return tensor_op;
  335. }
  336. // CutMixBatchOperation
  337. CutMixBatchOperation::CutMixBatchOperation(ImageBatchFormat image_batch_format, float alpha, float prob)
  338. : image_batch_format_(image_batch_format), alpha_(alpha), prob_(prob) {}
  339. bool CutMixBatchOperation::ValidateParams() {
  340. if (alpha_ <= 0) {
  341. MS_LOG(ERROR) << "CutMixBatch: alpha cannot be negative.";
  342. return false;
  343. }
  344. if (prob_ < 0 || prob_ > 1) {
  345. MS_LOG(ERROR) << "CutMixBatch: Probability has to be between 0 and 1.";
  346. return false;
  347. }
  348. return true;
  349. }
  350. std::shared_ptr<TensorOp> CutMixBatchOperation::Build() {
  351. std::shared_ptr<CutMixBatchOp> tensor_op = std::make_shared<CutMixBatchOp>(image_batch_format_, alpha_, prob_);
  352. return tensor_op;
  353. }
  354. // CutOutOperation
  355. CutOutOperation::CutOutOperation(int32_t length, int32_t num_patches) : length_(length), num_patches_(num_patches) {}
  356. bool CutOutOperation::ValidateParams() {
  357. if (length_ < 0) {
  358. MS_LOG(ERROR) << "CutOut: length cannot be negative";
  359. return false;
  360. }
  361. if (num_patches_ < 0) {
  362. MS_LOG(ERROR) << "CutOut: number of patches cannot be negative";
  363. return false;
  364. }
  365. return true;
  366. }
  367. std::shared_ptr<TensorOp> CutOutOperation::Build() {
  368. std::shared_ptr<CutOutOp> tensor_op = std::make_shared<CutOutOp>(length_, length_, num_patches_, false, 0, 0, 0);
  369. return tensor_op;
  370. }
  371. // DecodeOperation
  372. DecodeOperation::DecodeOperation(bool rgb) : rgb_(rgb) {}
  373. bool DecodeOperation::ValidateParams() { return true; }
  374. std::shared_ptr<TensorOp> DecodeOperation::Build() { return std::make_shared<DecodeOp>(rgb_); }
  375. // HwcToChwOperation
  376. bool HwcToChwOperation::ValidateParams() { return true; }
  377. std::shared_ptr<TensorOp> HwcToChwOperation::Build() { return std::make_shared<HwcToChwOp>(); }
  378. // MixUpOperation
  379. MixUpBatchOperation::MixUpBatchOperation(float alpha) : alpha_(alpha) {}
  380. bool MixUpBatchOperation::ValidateParams() {
  381. if (alpha_ <= 0) {
  382. MS_LOG(ERROR) << "MixUpBatch: alpha must be a positive floating value however it is: " << alpha_;
  383. return false;
  384. }
  385. return true;
  386. }
  387. std::shared_ptr<TensorOp> MixUpBatchOperation::Build() { return std::make_shared<MixUpBatchOp>(alpha_); }
  388. // NormalizeOperation
  389. NormalizeOperation::NormalizeOperation(std::vector<float> mean, std::vector<float> std) : mean_(mean), std_(std) {}
  390. bool NormalizeOperation::ValidateParams() {
  391. if (mean_.size() != 3) {
  392. MS_LOG(ERROR) << "Normalize: mean vector has incorrect size: " << mean_.size();
  393. return false;
  394. }
  395. if (std_.size() != 3) {
  396. MS_LOG(ERROR) << "Normalize: std vector has incorrect size: " << std_.size();
  397. return false;
  398. }
  399. return true;
  400. }
  401. std::shared_ptr<TensorOp> NormalizeOperation::Build() {
  402. return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]);
  403. }
  404. // OneHotOperation
  405. OneHotOperation::OneHotOperation(int32_t num_classes) : num_classes_(num_classes) {}
  406. bool OneHotOperation::ValidateParams() {
  407. if (num_classes_ < 0) {
  408. MS_LOG(ERROR) << "OneHot: Number of classes cannot be negative. Number of classes: " << num_classes_;
  409. return false;
  410. }
  411. return true;
  412. }
  413. std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }
  414. // PadOperation
  415. PadOperation::PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode)
  416. : padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {}
  417. bool PadOperation::ValidateParams() {
  418. if (padding_.empty() || padding_.size() == 3 || padding_.size() > 4) {
  419. MS_LOG(ERROR) << "Pad: padding vector has incorrect size: padding.size()";
  420. return false;
  421. }
  422. if (fill_value_.empty() || (fill_value_.size() != 1 && fill_value_.size() != 3)) {
  423. MS_LOG(ERROR) << "Pad: fill_value vector has incorrect size: fill_value.size()";
  424. return false;
  425. }
  426. return true;
  427. }
  428. std::shared_ptr<TensorOp> PadOperation::Build() {
  429. int32_t pad_top, pad_bottom, pad_left, pad_right;
  430. switch (padding_.size()) {
  431. case 1:
  432. pad_left = padding_[0];
  433. pad_top = padding_[0];
  434. pad_right = padding_[0];
  435. pad_bottom = padding_[0];
  436. break;
  437. case 2:
  438. pad_left = padding_[0];
  439. pad_top = padding_[1];
  440. pad_right = padding_[0];
  441. pad_bottom = padding_[1];
  442. break;
  443. default:
  444. pad_left = padding_[0];
  445. pad_top = padding_[1];
  446. pad_right = padding_[2];
  447. pad_bottom = padding_[3];
  448. }
  449. uint8_t fill_r, fill_g, fill_b;
  450. fill_r = fill_value_[0];
  451. fill_g = fill_value_[0];
  452. fill_b = fill_value_[0];
  453. if (fill_value_.size() == 3) {
  454. fill_r = fill_value_[0];
  455. fill_g = fill_value_[1];
  456. fill_b = fill_value_[2];
  457. }
  458. std::shared_ptr<PadOp> tensor_op =
  459. std::make_shared<PadOp>(pad_top, pad_bottom, pad_left, pad_right, padding_mode_, fill_r, fill_g, fill_b);
  460. return tensor_op;
  461. }
  462. // RandomColorOperation.
  463. RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) {}
  464. bool RandomColorOperation::ValidateParams() {
  465. // Do some input validation.
  466. if (t_lb_ > t_ub_) {
  467. MS_LOG(ERROR) << "RandomColor: lower bound must be less or equal to upper bound";
  468. return false;
  469. }
  470. return true;
  471. }
  472. // RandomColorAdjustOperation.
  473. RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast,
  474. std::vector<float> saturation, std::vector<float> hue)
  475. : brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {}
  476. bool RandomColorAdjustOperation::ValidateParams() {
  477. // Do some input validation.
  478. if (brightness_.empty() || brightness_.size() > 2) {
  479. MS_LOG(ERROR) << "RandomColorAdjust: brightness must be a vector of one or two values";
  480. return false;
  481. }
  482. if (contrast_.empty() || contrast_.size() > 2) {
  483. MS_LOG(ERROR) << "RandomColorAdjust: contrast must be a vector of one or two values";
  484. return false;
  485. }
  486. if (saturation_.empty() || saturation_.size() > 2) {
  487. MS_LOG(ERROR) << "RandomColorAdjust: saturation must be a vector of one or two values";
  488. return false;
  489. }
  490. if (hue_.empty() || hue_.size() > 2) {
  491. MS_LOG(ERROR) << "RandomColorAdjust: hue must be a vector of one or two values";
  492. return false;
  493. }
  494. return true;
  495. }
  496. std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() {
  497. float brightness_lb, brightness_ub, contrast_lb, contrast_ub, saturation_lb, saturation_ub, hue_lb, hue_ub;
  498. brightness_lb = brightness_[0];
  499. brightness_ub = brightness_[0];
  500. if (brightness_.size() == 2) brightness_ub = brightness_[1];
  501. contrast_lb = contrast_[0];
  502. contrast_ub = contrast_[0];
  503. if (contrast_.size() == 2) contrast_ub = contrast_[1];
  504. saturation_lb = saturation_[0];
  505. saturation_ub = saturation_[0];
  506. if (saturation_.size() == 2) saturation_ub = saturation_[1];
  507. hue_lb = hue_[0];
  508. hue_ub = hue_[0];
  509. if (hue_.size() == 2) hue_ub = hue_[1];
  510. std::shared_ptr<RandomColorAdjustOp> tensor_op = std::make_shared<RandomColorAdjustOp>(
  511. brightness_lb, brightness_ub, contrast_lb, contrast_ub, saturation_lb, saturation_ub, hue_lb, hue_ub);
  512. return tensor_op;
  513. }
  514. // RandomAffineOperation
  515. RandomAffineOperation::RandomAffineOperation(const std::vector<float_t> &degrees,
  516. const std::vector<float_t> &translate_range,
  517. const std::vector<float_t> &scale_range,
  518. const std::vector<float_t> &shear_ranges, InterpolationMode interpolation,
  519. const std::vector<uint8_t> &fill_value)
  520. : degrees_(degrees),
  521. translate_range_(translate_range),
  522. scale_range_(scale_range),
  523. shear_ranges_(shear_ranges),
  524. interpolation_(interpolation),
  525. fill_value_(fill_value) {}
  526. bool RandomAffineOperation::ValidateParams() {
  527. // Degrees
  528. if (degrees_.size() != 2) {
  529. MS_LOG(ERROR) << "RandomAffine: degrees vector has incorrect size: degrees.size() = " << degrees_.size();
  530. return false;
  531. }
  532. if (degrees_[0] > degrees_[1]) {
  533. MS_LOG(ERROR) << "RandomAffine: minimum of degrees range is greater than maximum: min = " << degrees_[0]
  534. << ", max = " << degrees_[1];
  535. return false;
  536. }
  537. // Translate
  538. if (translate_range_.size() != 2) {
  539. MS_LOG(ERROR) << "RandomAffine: translate_range vector has incorrect size: translate_range.size() = "
  540. << translate_range_.size();
  541. return false;
  542. }
  543. if (translate_range_[0] > translate_range_[1]) {
  544. MS_LOG(ERROR) << "RandomAffine: minimum of translate range is greater than maximum: min = " << translate_range_[0]
  545. << ", max = " << translate_range_[1];
  546. return false;
  547. }
  548. // Scale
  549. if (scale_range_.size() != 2) {
  550. MS_LOG(ERROR) << "RandomAffine: scale_range vector has incorrect size: scale_range.size() = "
  551. << scale_range_.size();
  552. return false;
  553. }
  554. if (scale_range_[0] > scale_range_[1]) {
  555. MS_LOG(ERROR) << "RandomAffine: minimum of scale range is greater than maximum: min = " << scale_range_[0]
  556. << ", max = " << scale_range_[1];
  557. return false;
  558. }
  559. // Shear
  560. if (shear_ranges_.size() != 4) {
  561. MS_LOG(ERROR) << "RandomAffine: shear_ranges vector has incorrect size: shear_ranges.size() = "
  562. << shear_ranges_.size();
  563. return false;
  564. }
  565. if (shear_ranges_[0] > shear_ranges_[1]) {
  566. MS_LOG(ERROR) << "RandomAffine: minimum of horizontal shear range is greater than maximum: min = "
  567. << shear_ranges_[0] << ", max = " << shear_ranges_[1];
  568. return false;
  569. }
  570. if (shear_ranges_[2] > shear_ranges_[3]) {
  571. MS_LOG(ERROR) << "RandomAffine: minimum of vertical shear range is greater than maximum: min = " << shear_ranges_[2]
  572. << ", max = " << scale_range_[3];
  573. return false;
  574. }
  575. // Fill Value
  576. if (fill_value_.size() != 3) {
  577. MS_LOG(ERROR) << "RandomAffine: fill_value vector has incorrect size: fill_value.size() = " << fill_value_.size();
  578. return false;
  579. }
  580. return true;
  581. }
  582. std::shared_ptr<TensorOp> RandomAffineOperation::Build() {
  583. auto tensor_op = std::make_shared<RandomAffineOp>(degrees_, translate_range_, scale_range_, shear_ranges_,
  584. interpolation_, fill_value_);
  585. return tensor_op;
  586. }
  587. // RandomCropOperation
  588. RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
  589. std::vector<uint8_t> fill_value, BorderType padding_mode)
  590. : size_(size),
  591. padding_(padding),
  592. pad_if_needed_(pad_if_needed),
  593. fill_value_(fill_value),
  594. padding_mode_(padding_mode) {}
  595. bool RandomCropOperation::ValidateParams() {
  596. if (size_.empty() || size_.size() > 2) {
  597. MS_LOG(ERROR) << "RandomCrop: size vector has incorrect size: " << size_.size();
  598. return false;
  599. }
  600. if (padding_.empty() || padding_.size() != 4) {
  601. MS_LOG(ERROR) << "RandomCrop: padding vector has incorrect size: padding.size()";
  602. return false;
  603. }
  604. if (fill_value_.empty() || fill_value_.size() != 3) {
  605. MS_LOG(ERROR) << "RandomCrop: fill_value vector has incorrect size: fill_value.size()";
  606. return false;
  607. }
  608. return true;
  609. }
  610. std::shared_ptr<TensorOp> RandomCropOperation::Build() {
  611. int32_t crop_height = size_[0];
  612. int32_t crop_width = 0;
  613. int32_t pad_top = padding_[0];
  614. int32_t pad_bottom = padding_[1];
  615. int32_t pad_left = padding_[2];
  616. int32_t pad_right = padding_[3];
  617. uint8_t fill_r = fill_value_[0];
  618. uint8_t fill_g = fill_value_[1];
  619. uint8_t fill_b = fill_value_[2];
  620. // User has specified the crop_width value.
  621. if (size_.size() == 2) {
  622. crop_width = size_[1];
  623. }
  624. auto tensor_op = std::make_shared<RandomCropOp>(crop_height, crop_width, pad_top, pad_bottom, pad_left, pad_right,
  625. padding_mode_, pad_if_needed_, fill_r, fill_g, fill_b);
  626. return tensor_op;
  627. }
  628. // RandomHorizontalFlipOperation
  629. RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {}
  630. bool RandomHorizontalFlipOperation::ValidateParams() { return true; }
  631. std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() {
  632. std::shared_ptr<RandomHorizontalFlipOp> tensor_op = std::make_shared<RandomHorizontalFlipOp>(probability_);
  633. return tensor_op;
  634. }
  635. // RandomPosterizeOperation
  636. RandomPosterizeOperation::RandomPosterizeOperation(uint8_t min_bit, uint8_t max_bit)
  637. : min_bit_(min_bit), max_bit_(max_bit) {}
  638. bool RandomPosterizeOperation::ValidateParams() {
  639. if (min_bit_ < 1 || min_bit_ > 8) {
  640. MS_LOG(ERROR) << "RandomPosterize: min_bit value is out of range [1-8]: " << min_bit_;
  641. return false;
  642. }
  643. if (max_bit_ < 1 || max_bit_ > 8) {
  644. MS_LOG(ERROR) << "RandomPosterize: max_bit value is out of range [1-8]: " << max_bit_;
  645. return false;
  646. }
  647. if (max_bit_ < min_bit_) {
  648. MS_LOG(ERROR) << "RandomPosterize: max_bit value is less than min_bit: max =" << max_bit_ << ", min = " << min_bit_;
  649. return false;
  650. }
  651. return true;
  652. }
  653. std::shared_ptr<TensorOp> RandomPosterizeOperation::Build() {
  654. std::shared_ptr<RandomPosterizeOp> tensor_op = std::make_shared<RandomPosterizeOp>(min_bit_, max_bit_);
  655. return tensor_op;
  656. }
  657. // Function to create RandomRotationOperation.
  658. RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode,
  659. bool expand, std::vector<float> center,
  660. std::vector<uint8_t> fill_value)
  661. : degrees_(degrees),
  662. interpolation_mode_(interpolation_mode),
  663. expand_(expand),
  664. center_(center),
  665. fill_value_(fill_value) {}
  666. bool RandomRotationOperation::ValidateParams() {
  667. if (degrees_.empty() || degrees_.size() != 2) {
  668. MS_LOG(ERROR) << "RandomRotation: degrees vector has incorrect size: degrees.size()";
  669. return false;
  670. }
  671. if (center_.empty() || center_.size() != 2) {
  672. MS_LOG(ERROR) << "RandomRotation: center vector has incorrect size: center.size()";
  673. return false;
  674. }
  675. if (fill_value_.empty() || fill_value_.size() != 3) {
  676. MS_LOG(ERROR) << "RandomRotation: fill_value vector has incorrect size: fill_value.size()";
  677. return false;
  678. }
  679. return true;
  680. }
  681. std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
  682. std::shared_ptr<RandomRotationOp> tensor_op =
  683. std::make_shared<RandomRotationOp>(degrees_[0], degrees_[1], center_[0], center_[1], interpolation_mode_, expand_,
  684. fill_value_[0], fill_value_[1], fill_value_[2]);
  685. return tensor_op;
  686. }
  687. // Function to create RandomSharpness.
  688. RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees) : degrees_(degrees) {}
  689. bool RandomSharpnessOperation::ValidateParams() {
  690. if (degrees_.empty() || degrees_.size() != 2) {
  691. MS_LOG(ERROR) << "RandomSharpness: degrees vector has incorrect size: degrees.size()";
  692. return false;
  693. }
  694. return true;
  695. }
  696. std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() {
  697. std::shared_ptr<RandomSharpnessOp> tensor_op = std::make_shared<RandomSharpnessOp>(degrees_[0], degrees_[1]);
  698. return tensor_op;
  699. }
  700. // RandomSolarizeOperation.
  701. RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold) : threshold_(threshold) {}
  702. bool RandomSolarizeOperation::ValidateParams() {
  703. if (threshold_.size() != 2) {
  704. MS_LOG(ERROR) << "RandomSolarize: threshold vector has incorrect size: " << threshold_.size();
  705. return false;
  706. }
  707. if (threshold_.at(0) > threshold_.at(1)) {
  708. MS_LOG(ERROR) << "RandomSolarize: threshold must be passed in a min, max format";
  709. return false;
  710. }
  711. return true;
  712. }
  713. std::shared_ptr<TensorOp> RandomSolarizeOperation::Build() {
  714. std::shared_ptr<RandomSolarizeOp> tensor_op = std::make_shared<RandomSolarizeOp>(threshold_);
  715. return tensor_op;
  716. }
  717. // RandomVerticalFlipOperation
  718. RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {}
  719. bool RandomVerticalFlipOperation::ValidateParams() { return true; }
  720. std::shared_ptr<TensorOp> RandomVerticalFlipOperation::Build() {
  721. std::shared_ptr<RandomVerticalFlipOp> tensor_op = std::make_shared<RandomVerticalFlipOp>(probability_);
  722. return tensor_op;
  723. }
  724. // ResizeOperation
  725. ResizeOperation::ResizeOperation(std::vector<int32_t> size, InterpolationMode interpolation)
  726. : size_(size), interpolation_(interpolation) {}
  727. bool ResizeOperation::ValidateParams() {
  728. if (size_.empty() || size_.size() > 2) {
  729. MS_LOG(ERROR) << "Resize: size vector has incorrect size: " << size_.size();
  730. return false;
  731. }
  732. return true;
  733. }
  734. std::shared_ptr<TensorOp> ResizeOperation::Build() {
  735. int32_t height = size_[0];
  736. int32_t width = 0;
  737. // User specified the width value.
  738. if (size_.size() == 2) {
  739. width = size_[1];
  740. }
  741. return std::make_shared<ResizeOp>(height, width, interpolation_);
  742. }
  743. // RgbaToBgrOperation.
  744. RgbaToBgrOperation::RgbaToBgrOperation() {}
  745. bool RgbaToBgrOperation::ValidateParams() { return true; }
  746. std::shared_ptr<TensorOp> RgbaToBgrOperation::Build() {
  747. std::shared_ptr<RgbaToBgrOp> tensor_op = std::make_shared<RgbaToBgrOp>();
  748. return tensor_op;
  749. }
  750. // RgbaToRgbOperation.
  751. RgbaToRgbOperation::RgbaToRgbOperation() {}
  752. bool RgbaToRgbOperation::ValidateParams() { return true; }
  753. std::shared_ptr<TensorOp> RgbaToRgbOperation::Build() {
  754. std::shared_ptr<RgbaToRgbOp> tensor_op = std::make_shared<RgbaToRgbOp>();
  755. return tensor_op;
  756. }
  757. // SwapRedBlueOperation.
  758. SwapRedBlueOperation::SwapRedBlueOperation() {}
  759. bool SwapRedBlueOperation::ValidateParams() { return true; }
  760. std::shared_ptr<TensorOp> SwapRedBlueOperation::Build() {
  761. std::shared_ptr<SwapRedBlueOp> tensor_op = std::make_shared<SwapRedBlueOp>();
  762. return tensor_op;
  763. }
  764. // UniformAugOperation
  765. UniformAugOperation::UniformAugOperation(std::vector<std::shared_ptr<TensorOperation>> transforms, int32_t num_ops)
  766. : transforms_(transforms), num_ops_(num_ops) {}
  767. bool UniformAugOperation::ValidateParams() { return true; }
  768. std::shared_ptr<TensorOp> UniformAugOperation::Build() {
  769. std::vector<std::shared_ptr<TensorOp>> tensor_ops;
  770. (void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(tensor_ops),
  771. [](std::shared_ptr<TensorOperation> op) -> std::shared_ptr<TensorOp> { return op->Build(); });
  772. std::shared_ptr<UniformAugOp> tensor_op = std::make_shared<UniformAugOp>(tensor_ops, num_ops_);
  773. return tensor_op;
  774. }
  775. } // namespace vision
  776. } // namespace api
  777. } // namespace dataset
  778. } // namespace mindspore