You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

transforms.h 24 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_TRANSFORMS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_API_TRANSFORMS_H_
  18. #include <vector>
  19. #include <memory>
  20. #include "minddata/dataset/core/constants.h"
  21. namespace mindspore {
  22. namespace dataset {
  23. class TensorOp;
  24. namespace api {
  25. // Abstract class to represent a dataset in the data pipeline.
  26. class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
  27. public:
  28. /// \brief Constructor
  29. TensorOperation();
  30. /// \brief Destructor
  31. ~TensorOperation() = default;
  32. /// \brief Pure virtual function to convert a TensorOperation class into a runtime TensorOp object.
  33. /// \return shared pointer to the newly created TensorOp.
  34. virtual std::shared_ptr<TensorOp> Build() = 0;
  35. virtual bool ValidateParams() = 0;
  36. };
  37. // Transform operations for performing computer vision.
  38. namespace vision {
  39. // Forward declarations of the transform operation classes (in alphabetical order); each is defined later in this header.
  40. class CenterCropOperation;
  41. class CropOperation;
  42. class CutMixBatchOperation;
  43. class CutOutOperation;
  44. class DecodeOperation;
  45. class HwcToChwOperation;
  46. class MixUpBatchOperation;
  47. class NormalizeOperation;
  48. class OneHotOperation;
  49. class PadOperation;
  50. class RandomAffineOperation;
  51. class RandomColorAdjustOperation;
  52. class RandomCropOperation;
  53. class RandomHorizontalFlipOperation;
  54. class RandomRotationOperation;
  55. class RandomSharpnessOperation;
  56. class RandomSolarizeOperation;
  57. class RandomVerticalFlipOperation;
  58. class ResizeOperation;
  59. class RgbaToBgrOperation;
  60. class RgbaToRgbOperation;
  61. class SwapRedBlueOperation;
  62. class UniformAugOperation;
  63. /// \brief Function to create a CenterCrop TensorOperation.
  64. /// \notes Crops the input image at the center to the given size.
  65. /// \param[in] size A vector representing the output size of the cropped image.
  66. /// If size is a single value, a square crop of size (size, size) is returned.
  67. /// If size has 2 values, it should be (height, width).
  68. /// \return Shared pointer to the created CenterCropOperation.
  69. std::shared_ptr<CenterCropOperation> CenterCrop(std::vector<int32_t> size);
  70. /// \brief Function to create a Crop TensorOperation.
  71. /// \notes Crops an image based on a starting location and a crop size.
  72. /// \param[in] coordinates Starting location of the crop. Must be a vector of two values, in the form of {x_coor, y_coor}.
  73. /// \param[in] size Size of the cropped area. Must be a vector of two values, in the form of {height, width}.
  74. /// \return Shared pointer to the created CropOperation.
  75. std::shared_ptr<CropOperation> Crop(std::vector<int32_t> coordinates, std::vector<int32_t> size);
  76. /// \brief Function to create a CutMixBatch TensorOperation, which applies CutMix on a batch of images.
  77. /// \notes Masks a random section of each image with the corresponding part of another randomly selected image in
  78. /// that batch.
  79. /// \param[in] image_batch_format The format of the batch.
  80. /// \param[in] alpha The hyperparameter of the beta distribution (default = 1.0).
  81. /// \param[in] prob The probability by which CutMix is applied to each image (default = 1.0).
  82. /// \return Shared pointer to the created CutMixBatchOperation.
  83. std::shared_ptr<CutMixBatchOperation> CutMixBatch(ImageBatchFormat image_batch_format, float alpha = 1.0,
  84. float prob = 1.0);
  85. /// \brief Function to create a CutOut TensorOperation.
  86. /// \notes Randomly cuts (masks) out a given number of square patches from the input image.
  87. /// \param[in] length Integer representing the side length of each square patch.
  88. /// \param[in] num_patches Integer representing the number of patches to be cut out of an image (default = 1).
  89. /// \return Shared pointer to the created CutOutOperation.
  90. std::shared_ptr<CutOutOperation> CutOut(int32_t length, int32_t num_patches = 1);
  91. /// \brief Function to create a Decode TensorOperation.
  92. /// \notes Decodes the input image; the rgb flag selects whether RGB mode is used.
  93. /// \param[in] rgb A boolean of whether to decode in RGB mode or not (default = true).
  94. /// \return Shared pointer to the created DecodeOperation.
  95. std::shared_ptr<DecodeOperation> Decode(bool rgb = true);
  96. /// \brief Function to create a HwcToChw TensorOperation.
  97. /// \notes Transposes the input image from shape (H, W, C) to shape (C, H, W).
  98. /// \return Shared pointer to the created HwcToChwOperation.
  99. std::shared_ptr<HwcToChwOperation> HWC2CHW();
  100. /// \brief Function to create a MixUpBatch TensorOperation.
  101. /// \notes Applies the MixUp transformation on an input batch of images and labels. The labels must be in one-hot
  102. /// format, and Batch must be called before calling this function.
  103. /// \param[in] alpha The hyperparameter of the beta distribution (default = 1.0).
  104. /// \return Shared pointer to the created MixUpBatchOperation.
  105. std::shared_ptr<MixUpBatchOperation> MixUpBatch(float alpha = 1);
  106. /// \brief Function to create a Normalize TensorOperation.
  107. /// \notes Normalizes the input image with respect to mean and standard deviation.
  108. /// \param[in] mean A vector of mean values for each channel, w.r.t. channel order.
  109. /// \param[in] std A vector of standard deviations for each channel, w.r.t. channel order.
  110. /// \return Shared pointer to the created NormalizeOperation.
  111. std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std);
  112. /// \brief Function to create a OneHot TensorOperation.
  113. /// \notes Converts the labels into OneHot format.
  114. /// \param[in] num_classes Number of classes.
  115. /// \return Shared pointer to the created OneHotOperation.
  116. std::shared_ptr<OneHotOperation> OneHot(int32_t num_classes);
  117. /// \brief Function to create a Pad TensorOperation.
  118. /// \notes Pads the image according to the padding parameters.
  119. /// \param[in] padding A vector representing the number of pixels to pad the image.
  120. /// If the vector has one value, it pads all sides of the image with that value.
  121. /// If the vector has two values, it pads left and right with the first and
  122. /// top and bottom with the second value.
  123. /// If the vector has four values, it pads left, top, right, and bottom with
  124. /// those values respectively.
  125. /// \param[in] fill_value A vector representing the pixel intensity of the borders if the padding_mode is
  126. /// BorderType::kConstant (default = {0}). If 3 values are provided,
  127. /// they are used to fill the R, G, B channels respectively.
  128. /// \param[in] padding_mode The method of padding (default = BorderType::kConstant).
  129. /// Can be any of
  130. /// [BorderType::kConstant, BorderType::kEdge, BorderType::kReflect, BorderType::kSymmetric]
  131. /// - BorderType::kConstant, means it fills the border with constant values
  132. /// - BorderType::kEdge, means it pads with the last value on the edge
  133. /// - BorderType::kReflect, means it reflects the values on the edge omitting the last value of edge
  134. /// - BorderType::kSymmetric, means it reflects the values on the edge repeating the last value of edge
  135. /// \return Shared pointer to the created PadOperation.
  136. std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0},
  137. BorderType padding_mode = BorderType::kConstant);
  138. /// \brief Function to create a RandomAffine TensorOperation.
  139. /// \notes Applies a random affine transformation on the input image in RGB or greyscale mode.
  140. /// \param[in] degrees A float vector of size 2, representing the starting and ending degree.
  141. /// \param[in] translate_range A float vector of size 2, representing percentages of translation on x and y axes (default = {0.0, 0.0}).
  142. /// \param[in] scale_range A float vector of size 2, representing the starting and ending scales in the range (default = {1.0, 1.0}).
  143. /// \param[in] shear_ranges A float vector of size 4, representing the starting and ending shear degrees vertically and
  144. /// horizontally (default = {0.0, 0.0, 0.0, 0.0}).
  145. /// \param[in] interpolation An enum for the mode of interpolation (default = InterpolationMode::kNearestNeighbour).
  146. /// \param[in] fill_value A uint8_t vector of size 3, representing the pixel intensity of the borders; it is used to
  147. /// fill R, G, B channels respectively (default = {0, 0, 0}).
  148. /// \return Shared pointer to the created RandomAffineOperation.
  149. std::shared_ptr<RandomAffineOperation> RandomAffine(
  150. const std::vector<float_t> &degrees, const std::vector<float_t> &translate_range = {0.0, 0.0},
  151. const std::vector<float_t> &scale_range = {1.0, 1.0}, const std::vector<float_t> &shear_ranges = {0.0, 0.0, 0.0, 0.0},
  152. InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
  153. const std::vector<uint8_t> &fill_value = {0, 0, 0});
  154. /// \brief Function to create a RandomColorAdjust TensorOperation, which randomly adjusts the brightness, contrast, saturation, and hue of the input image.
  155. /// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values;
  156. /// if it is a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1}.
  157. /// \param[in] contrast Contrast adjustment factor. Must be a vector of one or two values;
  158. /// if it is a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1}.
  159. /// \param[in] saturation Saturation adjustment factor. Must be a vector of one or two values;
  160. /// if it is a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1}.
  161. /// \param[in] hue Hue adjustment factor. Must be a vector of one or two values;
  162. /// if it is a vector of two values it must be in the form of [min, max] where -0.5 <= min <= max <= 0.5.
  163. /// Default value is {0, 0}.
  164. /// \return Shared pointer to the created RandomColorAdjustOperation.
  165. std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float> brightness = {1.0, 1.0},
  166. std::vector<float> contrast = {1.0, 1.0},
  167. std::vector<float> saturation = {1.0, 1.0},
  168. std::vector<float> hue = {0.0, 0.0});
  169. /// \brief Function to create a RandomCrop TensorOperation.
  170. /// \notes Crops the input image at a random location.
  171. /// \param[in] size A vector representing the output size of the cropped image.
  172. /// If size is a single value, a square crop of size (size, size) is returned.
  173. /// If size has 2 values, it should be (height, width).
  174. /// \param[in] padding A vector with the value of pixels to pad the image (default = {0, 0, 0, 0}). If 4 values are provided,
  175. /// it pads the left, top, right and bottom respectively.
  176. /// \param[in] pad_if_needed A boolean whether to pad the image if either side is smaller than
  177. /// the given output size (default = false).
  178. /// \param[in] fill_value A vector representing the pixel intensity of the borders, used to fill R, G, B channels respectively (default = {0, 0, 0}).
  179. /// \param[in] padding_mode The method of padding (default = BorderType::kConstant).
  180. /// \return Shared pointer to the created RandomCropOperation.
  181. std::shared_ptr<RandomCropOperation> RandomCrop(std::vector<int32_t> size, std::vector<int32_t> padding = {0, 0, 0, 0},
  182. bool pad_if_needed = false, std::vector<uint8_t> fill_value = {0, 0, 0},
  183. BorderType padding_mode = BorderType::kConstant);
  184. /// \brief Function to create a RandomHorizontalFlip TensorOperation.
  185. /// \notes Tensor operation to perform a random horizontal flip.
  186. /// \param[in] prob Float representing the probability of flip (default = 0.5).
  187. /// \return Shared pointer to the created RandomHorizontalFlipOperation.
  188. std::shared_ptr<RandomHorizontalFlipOperation> RandomHorizontalFlip(float prob = 0.5);
  189. /// \brief Function to create a RandomRotation TensorOperation.
  190. /// \notes Rotates the image according to the parameters.
  191. /// \param[in] degrees A float vector of size 2, representing the starting and ending degree.
  192. /// \param[in] resample An enum for the mode of interpolation (default = InterpolationMode::kNearestNeighbour).
  193. /// \param[in] expand A boolean representing whether the image is expanded after rotation (default = false).
  194. /// \param[in] center A float vector of size 2, representing the x and y center of rotation (default = {-1, -1}).
  195. /// \param[in] fill_value A uint8_t vector of size 3, representing the rgb value of the fill color (default = {0, 0, 0}).
  196. /// \return Shared pointer to the created RandomRotationOperation.
  197. std::shared_ptr<RandomRotationOperation> RandomRotation(
  198. std::vector<float> degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false,
  199. std::vector<float> center = {-1, -1}, std::vector<uint8_t> fill_value = {0, 0, 0});
  200. /// \brief Function to create a RandomSharpness TensorOperation.
  201. /// \notes Tensor operation to perform random sharpness.
  202. /// \param[in] degrees A float vector giving the start and the end of the range from which the factor is
  203. /// uniformly sampled (default = {0.1, 1.9}).
  204. /// \return Shared pointer to the created RandomSharpnessOperation.
  205. std::shared_ptr<RandomSharpnessOperation> RandomSharpness(std::vector<float> degrees = {0.1, 1.9});
  206. /// \brief Function to create a RandomSolarize TensorOperation.
  207. /// \notes Inverts pixels within the specified range. If min = max, then it inverts all pixels above that threshold.
  208. /// \param[in] threshold_min Lower limit (default = 0).
  209. /// \param[in] threshold_max Upper limit (default = 255).
  210. /// \return Shared pointer to the created RandomSolarizeOperation.
  211. std::shared_ptr<RandomSolarizeOperation> RandomSolarize(uint8_t threshold_min = 0, uint8_t threshold_max = 255);
  212. /// \brief Function to create a RandomVerticalFlip TensorOperation.
  213. /// \notes Tensor operation to perform a random vertical flip.
  214. /// \param[in] prob Float representing the probability of flip (default = 0.5).
  215. /// \return Shared pointer to the created RandomVerticalFlipOperation.
  216. std::shared_ptr<RandomVerticalFlipOperation> RandomVerticalFlip(float prob = 0.5);
  217. /// \brief Function to create a RgbaToBgr TensorOperation.
  218. /// \notes Changes the input 4-channel RGBA tensor to 3-channel BGR.
  219. /// \return Shared pointer to the created RgbaToBgrOperation.
  220. std::shared_ptr<RgbaToBgrOperation> RGBA2BGR();
  221. /// \brief Function to create a RgbaToRgb TensorOperation.
  222. /// \notes Changes the input 4-channel RGBA tensor to 3-channel RGB.
  223. /// \return Shared pointer to the created RgbaToRgbOperation.
  224. std::shared_ptr<RgbaToRgbOperation> RGBA2RGB();
  225. /// \brief Function to create a Resize TensorOperation.
  226. /// \notes Resizes the input image to the given size.
  227. /// \param[in] size A vector representing the output size of the resized image.
  228. /// If size is a single value, the image will be resized to this value with
  229. /// the same image aspect ratio. If size has 2 values, it should be (height, width).
  230. /// \param[in] interpolation An enum for the mode of interpolation (default = InterpolationMode::kLinear).
  231. /// \return Shared pointer to the created ResizeOperation.
  232. std::shared_ptr<ResizeOperation> Resize(std::vector<int32_t> size,
  233. InterpolationMode interpolation = InterpolationMode::kLinear);
  234. /// \brief Function to create a SwapRedBlue TensorOperation.
  235. /// \notes Swaps the red and blue channels in the image.
  236. /// \return Shared pointer to the created SwapRedBlueOperation.
  237. std::shared_ptr<SwapRedBlueOperation> SwapRedBlue();
  238. /// \brief Function to create a UniformAugment TensorOperation.
  239. /// \notes Tensor operation to perform randomly selected augmentation.
  240. /// \param[in] transforms A vector of TensorOperation transforms.
  241. /// \param[in] num_ops Integer representing the number of OPs to be selected and applied (default = 2).
  242. /// \return Shared pointer to the created UniformAugOperation.
  243. std::shared_ptr<UniformAugOperation> UniformAugment(std::vector<std::shared_ptr<TensorOperation>> transforms,
  244. int32_t num_ops = 2);
  245. /* ####################################### Derived TensorOperation classes ################################# */
  246. class CenterCropOperation : public TensorOperation {
  247. public:
  248. explicit CenterCropOperation(std::vector<int32_t> size);
  249. ~CenterCropOperation() = default;
  250. std::shared_ptr<TensorOp> Build() override;
  251. bool ValidateParams() override;
  252. private:
  253. std::vector<int32_t> size_;
  254. };
  255. class CropOperation : public TensorOperation {
  256. public:
  257. CropOperation(std::vector<int32_t> coordinates, std::vector<int32_t> size);
  258. ~CropOperation() = default;
  259. std::shared_ptr<TensorOp> Build() override;
  260. bool ValidateParams() override;
  261. private:
  262. std::vector<int32_t> coordinates_;
  263. std::vector<int32_t> size_;
  264. };
  265. class CutMixBatchOperation : public TensorOperation {
  266. public:
  267. explicit CutMixBatchOperation(ImageBatchFormat image_batch_format, float alpha = 1.0, float prob = 1.0);
  268. ~CutMixBatchOperation() = default;
  269. std::shared_ptr<TensorOp> Build() override;
  270. bool ValidateParams() override;
  271. private:
  272. float alpha_;
  273. float prob_;
  274. ImageBatchFormat image_batch_format_;
  275. };
  276. class CutOutOperation : public TensorOperation {
  277. public:
  278. explicit CutOutOperation(int32_t length, int32_t num_patches = 1);
  279. ~CutOutOperation() = default;
  280. std::shared_ptr<TensorOp> Build() override;
  281. bool ValidateParams() override;
  282. private:
  283. int32_t length_;
  284. int32_t num_patches_;
  285. ImageBatchFormat image_batch_format_;
  286. };
  287. class DecodeOperation : public TensorOperation {
  288. public:
  289. explicit DecodeOperation(bool rgb = true);
  290. ~DecodeOperation() = default;
  291. std::shared_ptr<TensorOp> Build() override;
  292. bool ValidateParams() override;
  293. private:
  294. bool rgb_;
  295. };
  296. class HwcToChwOperation : public TensorOperation {
  297. public:
  298. ~HwcToChwOperation() = default;
  299. std::shared_ptr<TensorOp> Build() override;
  300. bool ValidateParams() override;
  301. };
  302. class MixUpBatchOperation : public TensorOperation {
  303. public:
  304. explicit MixUpBatchOperation(float alpha = 1);
  305. ~MixUpBatchOperation() = default;
  306. std::shared_ptr<TensorOp> Build() override;
  307. bool ValidateParams() override;
  308. private:
  309. float alpha_;
  310. };
  311. class NormalizeOperation : public TensorOperation {
  312. public:
  313. NormalizeOperation(std::vector<float> mean, std::vector<float> std);
  314. ~NormalizeOperation() = default;
  315. std::shared_ptr<TensorOp> Build() override;
  316. bool ValidateParams() override;
  317. private:
  318. std::vector<float> mean_;
  319. std::vector<float> std_;
  320. };
  321. class OneHotOperation : public TensorOperation {
  322. public:
  323. explicit OneHotOperation(int32_t num_classes_);
  324. ~OneHotOperation() = default;
  325. std::shared_ptr<TensorOp> Build() override;
  326. bool ValidateParams() override;
  327. private:
  328. float num_classes_;
  329. };
  330. class PadOperation : public TensorOperation {
  331. public:
  332. PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0},
  333. BorderType padding_mode = BorderType::kConstant);
  334. ~PadOperation() = default;
  335. std::shared_ptr<TensorOp> Build() override;
  336. bool ValidateParams() override;
  337. private:
  338. std::vector<int32_t> padding_;
  339. std::vector<uint8_t> fill_value_;
  340. BorderType padding_mode_;
  341. };
  342. class RandomAffineOperation : public TensorOperation {
  343. public:
  344. RandomAffineOperation(const std::vector<float_t> &degrees, const std::vector<float_t> &translate_range = {0.0, 0.0},
  345. const std::vector<float_t> &scale_range = {1.0, 1.0},
  346. const std::vector<float_t> &shear_ranges = {0.0, 0.0, 0.0, 0.0},
  347. InterpolationMode interpolation = InterpolationMode::kNearestNeighbour,
  348. const std::vector<uint8_t> &fill_value = {0, 0, 0});
  349. ~RandomAffineOperation() = default;
  350. std::shared_ptr<TensorOp> Build() override;
  351. bool ValidateParams() override;
  352. private:
  353. std::vector<float_t> degrees_; // min_degree, max_degree
  354. std::vector<float_t> translate_range_; // maximum x translation percentage, maximum y translation percentage
  355. std::vector<float_t> scale_range_; // min_scale, max_scale
  356. std::vector<float_t> shear_ranges_; // min_x_shear, max_x_shear, min_y_shear, max_y_shear
  357. InterpolationMode interpolation_;
  358. std::vector<uint8_t> fill_value_;
  359. };
  360. class RandomColorAdjustOperation : public TensorOperation {
  361. public:
  362. RandomColorAdjustOperation(std::vector<float> brightness = {1.0, 1.0}, std::vector<float> contrast = {1.0, 1.0},
  363. std::vector<float> saturation = {1.0, 1.0}, std::vector<float> hue = {0.0, 0.0});
  364. ~RandomColorAdjustOperation() = default;
  365. std::shared_ptr<TensorOp> Build() override;
  366. bool ValidateParams() override;
  367. private:
  368. std::vector<float> brightness_;
  369. std::vector<float> contrast_;
  370. std::vector<float> saturation_;
  371. std::vector<float> hue_;
  372. };
  373. class RandomCropOperation : public TensorOperation {
  374. public:
  375. RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding = {0, 0, 0, 0},
  376. bool pad_if_needed = false, std::vector<uint8_t> fill_value = {0, 0, 0},
  377. BorderType padding_mode = BorderType::kConstant);
  378. ~RandomCropOperation() = default;
  379. std::shared_ptr<TensorOp> Build() override;
  380. bool ValidateParams() override;
  381. private:
  382. std::vector<int32_t> size_;
  383. std::vector<int32_t> padding_;
  384. bool pad_if_needed_;
  385. std::vector<uint8_t> fill_value_;
  386. BorderType padding_mode_;
  387. };
  388. class RandomHorizontalFlipOperation : public TensorOperation {
  389. public:
  390. explicit RandomHorizontalFlipOperation(float probability = 0.5);
  391. ~RandomHorizontalFlipOperation() = default;
  392. std::shared_ptr<TensorOp> Build() override;
  393. bool ValidateParams() override;
  394. private:
  395. float probability_;
  396. };
  397. class RandomRotationOperation : public TensorOperation {
  398. public:
  399. RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode, bool expand,
  400. std::vector<float> center, std::vector<uint8_t> fill_value);
  401. ~RandomRotationOperation() = default;
  402. std::shared_ptr<TensorOp> Build() override;
  403. bool ValidateParams() override;
  404. private:
  405. std::vector<float> degrees_;
  406. InterpolationMode interpolation_mode_;
  407. std::vector<float> center_;
  408. bool expand_;
  409. std::vector<uint8_t> fill_value_;
  410. };
  411. class RandomSharpnessOperation : public TensorOperation {
  412. public:
  413. explicit RandomSharpnessOperation(std::vector<float> degrees = {0.1, 1.9});
  414. ~RandomSharpnessOperation() = default;
  415. std::shared_ptr<TensorOp> Build() override;
  416. bool ValidateParams() override;
  417. private:
  418. std::vector<float> degrees_;
  419. };
  420. class RandomVerticalFlipOperation : public TensorOperation {
  421. public:
  422. explicit RandomVerticalFlipOperation(float probability = 0.5);
  423. ~RandomVerticalFlipOperation() = default;
  424. std::shared_ptr<TensorOp> Build() override;
  425. bool ValidateParams() override;
  426. private:
  427. float probability_;
  428. };
  429. class ResizeOperation : public TensorOperation {
  430. public:
  431. explicit ResizeOperation(std::vector<int32_t> size,
  432. InterpolationMode interpolation_mode = InterpolationMode::kLinear);
  433. ~ResizeOperation() = default;
  434. std::shared_ptr<TensorOp> Build() override;
  435. bool ValidateParams() override;
  436. private:
  437. std::vector<int32_t> size_;
  438. InterpolationMode interpolation_;
  439. };
  440. class RgbaToBgrOperation : public TensorOperation {
  441. public:
  442. RgbaToBgrOperation();
  443. ~RgbaToBgrOperation() = default;
  444. std::shared_ptr<TensorOp> Build() override;
  445. bool ValidateParams() override;
  446. };
  447. class RgbaToRgbOperation : public TensorOperation {
  448. public:
  449. RgbaToRgbOperation();
  450. ~RgbaToRgbOperation() = default;
  451. std::shared_ptr<TensorOp> Build() override;
  452. bool ValidateParams() override;
  453. };
  454. class UniformAugOperation : public TensorOperation {
  455. public:
  456. explicit UniformAugOperation(std::vector<std::shared_ptr<TensorOperation>> transforms, int32_t num_ops = 2);
  457. ~UniformAugOperation() = default;
  458. std::shared_ptr<TensorOp> Build() override;
  459. bool ValidateParams() override;
  460. private:
  461. std::vector<std::shared_ptr<TensorOperation>> transforms_;
  462. int32_t num_ops_;
  463. };
  464. class SwapRedBlueOperation : public TensorOperation {
  465. public:
  466. SwapRedBlueOperation();
  467. ~SwapRedBlueOperation() = default;
  468. std::shared_ptr<TensorOp> Build() override;
  469. bool ValidateParams() override;
  470. };
  471. class RandomSolarizeOperation : public TensorOperation {
  472. public:
  473. explicit RandomSolarizeOperation(uint8_t threshold_min, uint8_t threshold_max);
  474. ~RandomSolarizeOperation() = default;
  475. std::shared_ptr<TensorOp> Build() override;
  476. bool ValidateParams() override;
  477. private:
  478. uint8_t threshold_min_;
  479. uint8_t threshold_max_;
  480. };
  481. } // namespace vision
  482. } // namespace api
  483. } // namespace dataset
  484. } // namespace mindspore
  485. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_TRANSFORMS_H_