You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transforms.cc 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "minddata/dataset/include/transforms.h"
  17. #include "minddata/dataset/kernels/image/image_utils.h"
  18. #include "minddata/dataset/kernels/image/normalize_op.h"
  19. #include "minddata/dataset/kernels/image/decode_op.h"
  20. #include "minddata/dataset/kernels/image/resize_op.h"
  21. #include "minddata/dataset/kernels/image/random_crop_op.h"
  22. #include "minddata/dataset/kernels/image/center_crop_op.h"
  23. #include "minddata/dataset/kernels/image/uniform_aug_op.h"
  24. #include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
  25. #include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
  26. #include "minddata/dataset/kernels/image/random_rotation_op.h"
  27. #include "minddata/dataset/kernels/image/cut_out_op.h"
  28. #include "minddata/dataset/kernels/image/random_color_adjust_op.h"
  29. #include "minddata/dataset/kernels/image/pad_op.h"
  30. namespace mindspore {
  31. namespace dataset {
  32. namespace api {
  33. TensorOperation::TensorOperation() {}
  34. // Transform operations for computer vision.
  35. namespace vision {
  36. // Function to create NormalizeOperation.
  37. std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std) {
  38. auto op = std::make_shared<NormalizeOperation>(mean, std);
  39. // Input validation
  40. if (!op->ValidateParams()) {
  41. return nullptr;
  42. }
  43. return op;
  44. }
  45. // Function to create DecodeOperation.
  46. std::shared_ptr<DecodeOperation> Decode(bool rgb) {
  47. auto op = std::make_shared<DecodeOperation>(rgb);
  48. // Input validation
  49. if (!op->ValidateParams()) {
  50. return nullptr;
  51. }
  52. return op;
  53. }
  54. // Function to create ResizeOperation.
  55. std::shared_ptr<ResizeOperation> Resize(std::vector<int32_t> size, InterpolationMode interpolation) {
  56. auto op = std::make_shared<ResizeOperation>(size, interpolation);
  57. // Input validation
  58. if (!op->ValidateParams()) {
  59. return nullptr;
  60. }
  61. return op;
  62. }
  63. // Function to create RandomCropOperation.
  64. std::shared_ptr<RandomCropOperation> RandomCrop(std::vector<int32_t> size, std::vector<int32_t> padding,
  65. bool pad_if_needed, std::vector<uint8_t> fill_value) {
  66. auto op = std::make_shared<RandomCropOperation>(size, padding, pad_if_needed, fill_value);
  67. // Input validation
  68. if (!op->ValidateParams()) {
  69. return nullptr;
  70. }
  71. return op;
  72. }
  73. // Function to create CenterCropOperation.
  74. std::shared_ptr<CenterCropOperation> CenterCrop(std::vector<int32_t> size) {
  75. auto op = std::make_shared<CenterCropOperation>(size);
  76. // Input validation
  77. if (!op->ValidateParams()) {
  78. return nullptr;
  79. }
  80. return op;
  81. }
  82. // Function to create UniformAugOperation.
  83. std::shared_ptr<UniformAugOperation> UniformAugment(std::vector<std::shared_ptr<TensorOperation>> operations,
  84. int32_t num_ops) {
  85. auto op = std::make_shared<UniformAugOperation>(operations, num_ops);
  86. // Input validation
  87. if (!op->ValidateParams()) {
  88. return nullptr;
  89. }
  90. return op;
  91. }
  92. // Function to create RandomHorizontalFlipOperation.
  93. std::shared_ptr<RandomHorizontalFlipOperation> RandomHorizontalFlip(float prob) {
  94. auto op = std::make_shared<RandomHorizontalFlipOperation>(prob);
  95. // Input validation
  96. if (!op->ValidateParams()) {
  97. return nullptr;
  98. }
  99. return op;
  100. }
  101. // Function to create RandomVerticalFlipOperation.
  102. std::shared_ptr<RandomVerticalFlipOperation> RandomVerticalFlip(float prob) {
  103. auto op = std::make_shared<RandomVerticalFlipOperation>(prob);
  104. // Input validation
  105. if (!op->ValidateParams()) {
  106. return nullptr;
  107. }
  108. return op;
  109. }
  110. // Function to create RandomRotationOperation.
  111. std::shared_ptr<RandomRotationOperation> RandomRotation(std::vector<float> degrees, InterpolationMode resample,
  112. bool expand, std::vector<float> center,
  113. std::vector<uint8_t> fill_value) {
  114. auto op = std::make_shared<RandomRotationOperation>(degrees, resample, expand, center, fill_value);
  115. // Input validation
  116. if (!op->ValidateParams()) {
  117. return nullptr;
  118. }
  119. return op;
  120. }
  121. // Function to create PadOperation.
  122. std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value,
  123. BorderType padding_mode) {
  124. auto op = std::make_shared<PadOperation>(padding, fill_value, padding_mode);
  125. // Input validation
  126. if (!op->ValidateParams()) {
  127. return nullptr;
  128. }
  129. return op;
  130. }
  131. // Function to create CutOutOp.
  132. std::shared_ptr<CutOutOperation> CutOut(int32_t length, int32_t num_patches) {
  133. auto op = std::make_shared<CutOutOperation>(length, num_patches);
  134. // Input validation
  135. if (!op->ValidateParams()) {
  136. return nullptr;
  137. }
  138. return op;
  139. }
  140. // Function to create RandomColorAdjustOperation.
  141. std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float> brightness,
  142. std::vector<float> contrast,
  143. std::vector<float> saturation, std::vector<float> hue) {
  144. auto op = std::make_shared<RandomColorAdjustOperation>(brightness, contrast, saturation, hue);
  145. // Input validation
  146. if (!op->ValidateParams()) {
  147. return nullptr;
  148. }
  149. return op;
  150. }
  151. /* ####################################### Derived TensorOperation classes ################################# */
  152. // NormalizeOperation
  153. NormalizeOperation::NormalizeOperation(std::vector<float> mean, std::vector<float> std) : mean_(mean), std_(std) {}
  154. bool NormalizeOperation::ValidateParams() {
  155. if (mean_.size() != 3) {
  156. MS_LOG(ERROR) << "Normalize: mean vector has incorrect size: " << mean_.size();
  157. return false;
  158. }
  159. if (std_.size() != 3) {
  160. MS_LOG(ERROR) << "Normalize: std vector has incorrect size: " << std_.size();
  161. return false;
  162. }
  163. return true;
  164. }
  165. std::shared_ptr<TensorOp> NormalizeOperation::Build() {
  166. return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]);
  167. }
  168. // DecodeOperation
  169. DecodeOperation::DecodeOperation(bool rgb) : rgb_(rgb) {}
  170. bool DecodeOperation::ValidateParams() { return true; }
  171. std::shared_ptr<TensorOp> DecodeOperation::Build() { return std::make_shared<DecodeOp>(rgb_); }
  172. // ResizeOperation
  173. ResizeOperation::ResizeOperation(std::vector<int32_t> size, InterpolationMode interpolation)
  174. : size_(size), interpolation_(interpolation) {}
  175. bool ResizeOperation::ValidateParams() {
  176. if (size_.empty() || size_.size() > 2) {
  177. MS_LOG(ERROR) << "Resize: size vector has incorrect size: " << size_.size();
  178. return false;
  179. }
  180. return true;
  181. }
  182. std::shared_ptr<TensorOp> ResizeOperation::Build() {
  183. int32_t height = size_[0];
  184. int32_t width = 0;
  185. // User specified the width value.
  186. if (size_.size() == 2) {
  187. width = size_[1];
  188. }
  189. return std::make_shared<ResizeOp>(height, width, interpolation_);
  190. }
  191. // RandomCropOperation
  192. RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
  193. std::vector<uint8_t> fill_value)
  194. : size_(size), padding_(padding), pad_if_needed_(pad_if_needed), fill_value_(fill_value) {}
  195. bool RandomCropOperation::ValidateParams() {
  196. if (size_.empty() || size_.size() > 2) {
  197. MS_LOG(ERROR) << "RandomCrop: size vector has incorrect size: " << size_.size();
  198. return false;
  199. }
  200. if (padding_.empty() || padding_.size() != 4) {
  201. MS_LOG(ERROR) << "RandomCrop: padding vector has incorrect size: padding.size()";
  202. return false;
  203. }
  204. if (fill_value_.empty() || fill_value_.size() != 3) {
  205. MS_LOG(ERROR) << "RandomCrop: fill_value vector has incorrect size: fill_value.size()";
  206. return false;
  207. }
  208. return true;
  209. }
  210. std::shared_ptr<TensorOp> RandomCropOperation::Build() {
  211. int32_t crop_height = size_[0];
  212. int32_t crop_width = 0;
  213. int32_t pad_top = padding_[0];
  214. int32_t pad_bottom = padding_[1];
  215. int32_t pad_left = padding_[2];
  216. int32_t pad_right = padding_[3];
  217. uint8_t fill_r = fill_value_[0];
  218. uint8_t fill_g = fill_value_[1];
  219. uint8_t fill_b = fill_value_[2];
  220. // User has specified the crop_width value.
  221. if (size_.size() == 2) {
  222. crop_width = size_[1];
  223. }
  224. auto tensor_op = std::make_shared<RandomCropOp>(crop_height, crop_width, pad_top, pad_bottom, pad_left, pad_right,
  225. BorderType::kConstant, pad_if_needed_, fill_r, fill_g, fill_b);
  226. return tensor_op;
  227. }
  228. // CenterCropOperation
  229. CenterCropOperation::CenterCropOperation(std::vector<int32_t> size) : size_(size) {}
  230. bool CenterCropOperation::ValidateParams() {
  231. if (size_.empty() || size_.size() > 2) {
  232. MS_LOG(ERROR) << "CenterCrop: size vector has incorrect size.";
  233. return false;
  234. }
  235. return true;
  236. }
  237. std::shared_ptr<TensorOp> CenterCropOperation::Build() {
  238. int32_t crop_height = size_[0];
  239. int32_t crop_width = 0;
  240. // User has specified crop_width.
  241. if (size_.size() == 2) {
  242. crop_width = size_[1];
  243. }
  244. std::shared_ptr<CenterCropOp> tensor_op = std::make_shared<CenterCropOp>(crop_height, crop_width);
  245. return tensor_op;
  246. }
  247. // UniformAugOperation
  248. UniformAugOperation::UniformAugOperation(std::vector<std::shared_ptr<TensorOperation>> operations, int32_t num_ops)
  249. : operations_(operations), num_ops_(num_ops) {}
  250. bool UniformAugOperation::ValidateParams() { return true; }
  251. std::shared_ptr<TensorOp> UniformAugOperation::Build() {
  252. std::vector<std::shared_ptr<TensorOp>> tensor_ops;
  253. (void)std::transform(operations_.begin(), operations_.end(), std::back_inserter(tensor_ops),
  254. [](std::shared_ptr<TensorOperation> op) -> std::shared_ptr<TensorOp> { return op->Build(); });
  255. std::shared_ptr<UniformAugOp> tensor_op = std::make_shared<UniformAugOp>(tensor_ops, num_ops_);
  256. return tensor_op;
  257. }
  258. // RandomHorizontalFlipOperation
  259. RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {}
  260. bool RandomHorizontalFlipOperation::ValidateParams() { return true; }
  261. std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() {
  262. std::shared_ptr<RandomHorizontalFlipOp> tensor_op = std::make_shared<RandomHorizontalFlipOp>(probability_);
  263. return tensor_op;
  264. }
  265. // RandomVerticalFlipOperation
  266. RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {}
  267. bool RandomVerticalFlipOperation::ValidateParams() { return true; }
  268. std::shared_ptr<TensorOp> RandomVerticalFlipOperation::Build() {
  269. std::shared_ptr<RandomVerticalFlipOp> tensor_op = std::make_shared<RandomVerticalFlipOp>(probability_);
  270. return tensor_op;
  271. }
  272. // Function to create RandomRotationOperation.
  273. RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode,
  274. bool expand, std::vector<float> center,
  275. std::vector<uint8_t> fill_value)
  276. : degrees_(degrees),
  277. interpolation_mode_(interpolation_mode),
  278. expand_(expand),
  279. center_(center),
  280. fill_value_(fill_value) {}
  281. bool RandomRotationOperation::ValidateParams() {
  282. if (degrees_.empty() || degrees_.size() != 2) {
  283. MS_LOG(ERROR) << "RandomRotation: degrees vector has incorrect size: degrees.size()";
  284. return false;
  285. }
  286. if (center_.empty() || center_.size() != 2) {
  287. MS_LOG(ERROR) << "RandomRotation: center vector has incorrect size: center.size()";
  288. return false;
  289. }
  290. if (fill_value_.empty() || fill_value_.size() != 3) {
  291. MS_LOG(ERROR) << "RandomRotation: fill_value vector has incorrect size: fill_value.size()";
  292. return false;
  293. }
  294. return true;
  295. }
  296. std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
  297. std::shared_ptr<RandomRotationOp> tensor_op =
  298. std::make_shared<RandomRotationOp>(degrees_[0], degrees_[1], center_[0], center_[1], interpolation_mode_, expand_,
  299. fill_value_[0], fill_value_[1], fill_value_[2]);
  300. return tensor_op;
  301. }
  302. // PadOperation
  303. PadOperation::PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode)
  304. : padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {}
  305. bool PadOperation::ValidateParams() {
  306. if (padding_.empty() || padding_.size() == 3 || padding_.size() > 4) {
  307. MS_LOG(ERROR) << "Pad: padding vector has incorrect size: padding.size()";
  308. return false;
  309. }
  310. if (fill_value_.empty() || (fill_value_.size() != 1 && fill_value_.size() != 3)) {
  311. MS_LOG(ERROR) << "Pad: fill_value vector has incorrect size: fill_value.size()";
  312. return false;
  313. }
  314. return true;
  315. }
  316. std::shared_ptr<TensorOp> PadOperation::Build() {
  317. int32_t pad_top, pad_bottom, pad_left, pad_right;
  318. switch (padding_.size()) {
  319. case 1:
  320. pad_left = padding_[0];
  321. pad_top = padding_[0];
  322. pad_right = padding_[0];
  323. pad_bottom = padding_[0];
  324. break;
  325. case 2:
  326. pad_left = padding_[0];
  327. pad_top = padding_[1];
  328. pad_right = padding_[0];
  329. pad_bottom = padding_[1];
  330. break;
  331. default:
  332. pad_left = padding_[0];
  333. pad_top = padding_[1];
  334. pad_right = padding_[2];
  335. pad_bottom = padding_[3];
  336. }
  337. uint8_t fill_r, fill_g, fill_b;
  338. fill_r = fill_value_[0];
  339. fill_g = fill_value_[0];
  340. fill_b = fill_value_[0];
  341. if (fill_value_.size() == 3) {
  342. fill_r = fill_value_[0];
  343. fill_g = fill_value_[1];
  344. fill_b = fill_value_[2];
  345. }
  346. std::shared_ptr<PadOp> tensor_op =
  347. std::make_shared<PadOp>(pad_top, pad_bottom, pad_left, pad_right, padding_mode_, fill_r, fill_g, fill_b);
  348. return tensor_op;
  349. }
  350. // CutOutOperation
  351. CutOutOperation::CutOutOperation(int32_t length, int32_t num_patches) : length_(length), num_patches_(num_patches) {}
  352. bool CutOutOperation::ValidateParams() {
  353. if (length_ < 0) {
  354. MS_LOG(ERROR) << "CutOut: length cannot be negative";
  355. return false;
  356. }
  357. if (num_patches_ < 0) {
  358. MS_LOG(ERROR) << "CutOut: number of patches cannot be negative";
  359. return false;
  360. }
  361. return true;
  362. }
  363. std::shared_ptr<TensorOp> CutOutOperation::Build() {
  364. std::shared_ptr<CutOutOp> tensor_op = std::make_shared<CutOutOp>(length_, length_, num_patches_, false, 0, 0, 0);
  365. return tensor_op;
  366. }
  367. // RandomColorAdjustOperation.
  368. RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast,
  369. std::vector<float> saturation, std::vector<float> hue)
  370. : brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {}
  371. bool RandomColorAdjustOperation::ValidateParams() {
  372. // Do some input validation.
  373. if (brightness_.empty() || brightness_.size() > 2) {
  374. MS_LOG(ERROR) << "RandomColorAdjust: brightness must be a vector of one or two values";
  375. return false;
  376. }
  377. if (contrast_.empty() || contrast_.size() > 2) {
  378. MS_LOG(ERROR) << "RandomColorAdjust: contrast must be a vector of one or two values";
  379. return false;
  380. }
  381. if (saturation_.empty() || saturation_.size() > 2) {
  382. MS_LOG(ERROR) << "RandomColorAdjust: saturation must be a vector of one or two values";
  383. return false;
  384. }
  385. if (hue_.empty() || hue_.size() > 2) {
  386. MS_LOG(ERROR) << "RandomColorAdjust: hue must be a vector of one or two values";
  387. return false;
  388. }
  389. return true;
  390. }
  391. std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() {
  392. float brightness_lb, brightness_ub, contrast_lb, contrast_ub, saturation_lb, saturation_ub, hue_lb, hue_ub;
  393. brightness_lb = brightness_[0];
  394. brightness_ub = brightness_[0];
  395. if (brightness_.size() == 2) brightness_ub = brightness_[1];
  396. contrast_lb = contrast_[0];
  397. contrast_ub = contrast_[0];
  398. if (contrast_.size() == 2) contrast_ub = contrast_[1];
  399. saturation_lb = saturation_[0];
  400. saturation_ub = saturation_[0];
  401. if (saturation_.size() == 2) saturation_ub = saturation_[1];
  402. hue_lb = hue_[0];
  403. hue_ub = hue_[0];
  404. if (hue_.size() == 2) hue_ub = hue_[1];
  405. std::shared_ptr<RandomColorAdjustOp> tensor_op = std::make_shared<RandomColorAdjustOp>(
  406. brightness_lb, brightness_ub, contrast_lb, contrast_ub, saturation_lb, saturation_ub, hue_lb, hue_ub);
  407. return tensor_op;
  408. }
  409. } // namespace vision
  410. } // namespace api
  411. } // namespace dataset
  412. } // namespace mindspore