You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

transforms.h 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_TRANSFORMS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_API_TRANSFORMS_H_
  18. #include <vector>
  19. #include <memory>
  20. #include "minddata/dataset/core/constants.h"
  21. namespace mindspore {
  22. namespace dataset {
  23. class TensorOp;
  24. namespace api {
  25. // Abstract class to represent a dataset in the data pipeline.
  26. class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
  27. public:
  28. /// \brief Constructor
  29. TensorOperation();
  30. /// \brief Destructor
  31. ~TensorOperation() = default;
  32. /// \brief Pure virtual function to convert a TensorOperation class into a runtime TensorOp object.
  33. /// \return shared pointer to the newly created TensorOp.
  34. virtual std::shared_ptr<TensorOp> Build() = 0;
  35. virtual bool ValidateParams() = 0;
  36. };
  37. // Transform operations for performing computer vision.
  38. namespace vision {
  // Forward declarations of the vision operation classes, so that the factory functions
  // can name them as return types before the full class definitions appear.

  // Deterministic image operations.
  class CenterCropOperation;
  class DecodeOperation;
  class NormalizeOperation;
  class PadOperation;
  class ResizeOperation;

  // Randomized / augmentation operations.
  class CutOutOperation;
  class RandomColorAdjustOperation;
  class RandomCropOperation;
  class RandomHorizontalFlipOperation;
  class RandomRotationOperation;
  class RandomVerticalFlipOperation;
  class UniformAugOperation;
  51. /// \brief Function to create a Normalize TensorOperation.
  52. /// \notes Normalize the input image with respect to mean and standard deviation.
  53. /// \param[in] mean - a vector of mean values for each channel, w.r.t channel order.
  54. /// \param[in] std - a vector of standard deviations for each channel, w.r.t. channel order.
  55. /// \return Shared pointer to the current TensorOperation.
  56. std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std);
  57. /// \brief Function to create a Decode TensorOperation.
  58. /// \notes Decode the input image in RGB mode.
  59. /// \param[in] rgb - a boolean of whether to decode in RGB mode or not.
  60. /// \return Shared pointer to the current TensorOperation.
  61. std::shared_ptr<DecodeOperation> Decode(bool rgb = true);
  62. /// \brief Function to create a Resize TensorOperation.
  63. /// \notes Resize the input image to the given size..
  64. /// \param[in] size - a vector representing the output size of the resized image.
  65. /// If size is a single value, the image will be resized to this value with
  66. /// the same image aspect ratio. If size has 2 values, it should be (height, width).
  67. /// \param[in] interpolation An enum for the mode of interpolation
  68. /// \return Shared pointer to the current TensorOperation.
  69. std::shared_ptr<ResizeOperation> Resize(std::vector<int32_t> size,
  70. InterpolationMode interpolation = InterpolationMode::kLinear);
  71. /// \brief Function to create a RandomCrop TensorOperation.
  72. /// \notes Crop the input image at a random location.
  73. /// \param[in] size - a vector representing the output size of the cropped image.
  74. /// If size is a single value, a square crop of size (size, size) is returned.
  75. /// If size has 2 values, it should be (height, width).
  76. /// \param[in] padding - a vector with the value of pixels to pad the image. If 4 values are provided,
  77. /// it pads the left, top, right and bottom respectively.
  78. /// \param[in] pad_if_needed - a boolean whether to pad the image if either side is smaller than
  79. /// the given output size.
  80. /// \param[in] fill_value - a vector representing the pixel intensity of the borders, it is used to
  81. /// fill R, G, B channels respectively.
  82. /// \return Shared pointer to the current TensorOperation.
  83. std::shared_ptr<RandomCropOperation> RandomCrop(std::vector<int32_t> size, std::vector<int32_t> padding = {0, 0, 0, 0},
  84. bool pad_if_needed = false,
  85. std::vector<uint8_t> fill_value = {0, 0, 0});
  86. /// \brief Function to create a CenterCrop TensorOperation.
  87. /// \notes Crops the input image at the center to the given size.
  88. /// \param[in] size - a vector representing the output size of the cropped image.
  89. /// If size is a single value, a square crop of size (size, size) is returned.
  90. /// If size has 2 values, it should be (height, width).
  91. /// \return Shared pointer to the current TensorOperation.
  92. std::shared_ptr<CenterCropOperation> CenterCrop(std::vector<int32_t> size);
  93. /// \brief Function to create a UniformAugment TensorOperation.
  94. /// \notes Tensor operation to perform randomly selected augmentation.
  95. /// \param[in] transforms - a vector of TensorOperation transforms.
  96. /// \param[in] num_ops - integer representing the number of OPs to be selected and applied.
  97. /// \return Shared pointer to the current TensorOperation.
  98. std::shared_ptr<UniformAugOperation> UniformAugment(std::vector<std::shared_ptr<TensorOperation>> transforms,
  99. int32_t num_ops = 2);
  100. /// \brief Function to create a RandomHorizontalFlip TensorOperation.
  101. /// \notes Tensor operation to perform random horizontal flip.
  102. /// \param[in] prob - float representing the probability of flip.
  103. /// \return Shared pointer to the current TensorOperation.
  104. std::shared_ptr<RandomHorizontalFlipOperation> RandomHorizontalFlip(float prob = 0.5);
  105. /// \brief Function to create a RandomVerticalFlip TensorOperation.
  106. /// \notes Tensor operation to perform random vertical flip.
  107. /// \param[in] prob - float representing the probability of flip.
  108. /// \return Shared pointer to the current TensorOperation.
  109. std::shared_ptr<RandomVerticalFlipOperation> RandomVerticalFlip(float prob = 0.5);
  110. /// \brief Function to create a RandomRotation TensorOp
  111. /// \notes Rotates the image according to parameters
  112. /// \param[in] degrees A float vector size 2, representing the starting and ending degree
  113. /// \param[in] resample An enum for the mode of interpolation
  114. /// \param[in] expand A boolean representing whether the image is expanded after rotation
  115. /// \param[in] center A float vector size 2, representing the x and y center of rotation.
  116. /// \param[in] fill_value A uint8_t vector size 3, representing the rgb value of the fill color
  117. /// \return Shared pointer to the current TensorOp
  118. std::shared_ptr<RandomRotationOperation> RandomRotation(
  119. std::vector<float> degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false,
  120. std::vector<float> center = {-1, -1}, std::vector<uint8_t> fill_value = {0, 0, 0});
  121. /// \brief Function to create a Pad TensorOp
  122. /// \notes Pads the image according to padding parameters
  123. /// \param[in] padding A vector representing the number of pixels to pad the image
  124. /// If vector has one value, it pads all sides of the image with that value
  125. /// If vector has two values, it pads left and right with the first and
  126. /// top and bottom with the second value
  127. /// If vector has four values, it pads left, top, right, and bottom with
  128. /// those values respectively
  129. /// \param[in] fill_value A vector representing the pixel intensity of the borders if the padding_mode is
  130. /// BorderType.kConstant. If 3 values are provided,
  131. /// it is used to fill R, G, B channels respectively
  132. /// \param[in] padding_mode The method of padding (default=BorderType.kConstant)
  133. /// Can be any of
  134. /// [BorderType.kConstant, BorderType.kEdge, BorderType.kReflect, BorderType.kSymmetric]
  135. /// - BorderType.kConstant, means it fills the border with constant values
  136. /// - BorderType.kEdge, means it pads with the last value on the edge
  137. /// - BorderType.kReflect, means it reflects the values on the edge omitting the last value of edge
  138. /// - BorderType.kSymmetric, means it reflects the values on the edge repeating the last value of edge
  139. /// \return Shared pointer to the current TensorOp
  140. std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0},
  141. BorderType padding_mode = BorderType::kConstant);
  142. /// \brief Function to create a CutOut TensorOp
  143. /// \notes Randomly cut (mask) out a given number of square patches from the input image
  144. /// \param[in] length Integer representing the side length of each square patch
  145. /// \param[in] num_patches Integer representing the number of patches to be cut out of an image
  146. /// \return Shared pointer to the current TensorOp
  147. std::shared_ptr<CutOutOperation> CutOut(int32_t length, int32_t num_patches = 1);
  148. /// \brief Randomly adjust the brightness, contrast, saturation, and hue of the input image
  149. /// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values
  150. /// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1}
  151. /// \param[in] contrast Contrast adjustment factor. Must be a vector of one or two values
  152. /// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1}
  153. /// \param[in] saturation Saturation adjustment factor. Must be a vector of one or two values
  154. /// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1}
  155. /// \param[in] hue Brightness adjustment factor. Must be a vector of one or two values
  156. /// if it's a vector of two values it must be in the form of [min, max] where -0.5 <= min <= max <= 0.5
  157. /// Default value is {0, 0}
  158. /// \return Shared pointer to the current TensorOp
  159. std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float> brightness = {1.0, 1.0},
  160. std::vector<float> contrast = {1.0, 1.0},
  161. std::vector<float> saturation = {1.0, 1.0},
  162. std::vector<float> hue = {0.0, 0.0});
  163. /* ####################################### Derived TensorOperation classes ################################# */
  164. class NormalizeOperation : public TensorOperation {
  165. public:
  166. NormalizeOperation(std::vector<float> mean, std::vector<float> std);
  167. ~NormalizeOperation() = default;
  168. std::shared_ptr<TensorOp> Build() override;
  169. bool ValidateParams() override;
  170. private:
  171. std::vector<float> mean_;
  172. std::vector<float> std_;
  173. };
  174. class DecodeOperation : public TensorOperation {
  175. public:
  176. explicit DecodeOperation(bool rgb = true);
  177. ~DecodeOperation() = default;
  178. std::shared_ptr<TensorOp> Build() override;
  179. bool ValidateParams() override;
  180. private:
  181. bool rgb_;
  182. };
  183. class ResizeOperation : public TensorOperation {
  184. public:
  185. explicit ResizeOperation(std::vector<int32_t> size,
  186. InterpolationMode interpolation_mode = InterpolationMode::kLinear);
  187. ~ResizeOperation() = default;
  188. std::shared_ptr<TensorOp> Build() override;
  189. bool ValidateParams() override;
  190. private:
  191. std::vector<int32_t> size_;
  192. InterpolationMode interpolation_;
  193. };
  194. class RandomCropOperation : public TensorOperation {
  195. public:
  196. RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding = {0, 0, 0, 0},
  197. bool pad_if_needed = false, std::vector<uint8_t> fill_value = {0, 0, 0});
  198. ~RandomCropOperation() = default;
  199. std::shared_ptr<TensorOp> Build() override;
  200. bool ValidateParams() override;
  201. private:
  202. std::vector<int32_t> size_;
  203. std::vector<int32_t> padding_;
  204. bool pad_if_needed_;
  205. std::vector<uint8_t> fill_value_;
  206. };
  207. class CenterCropOperation : public TensorOperation {
  208. public:
  209. explicit CenterCropOperation(std::vector<int32_t> size);
  210. ~CenterCropOperation() = default;
  211. std::shared_ptr<TensorOp> Build() override;
  212. bool ValidateParams() override;
  213. private:
  214. std::vector<int32_t> size_;
  215. };
  216. class UniformAugOperation : public TensorOperation {
  217. public:
  218. explicit UniformAugOperation(std::vector<std::shared_ptr<TensorOperation>> transforms, int32_t num_ops = 2);
  219. ~UniformAugOperation() = default;
  220. std::shared_ptr<TensorOp> Build() override;
  221. bool ValidateParams() override;
  222. private:
  223. std::vector<std::shared_ptr<TensorOperation>> transforms_;
  224. int32_t num_ops_;
  225. };
  226. class RandomHorizontalFlipOperation : public TensorOperation {
  227. public:
  228. explicit RandomHorizontalFlipOperation(float probability = 0.5);
  229. ~RandomHorizontalFlipOperation() = default;
  230. std::shared_ptr<TensorOp> Build() override;
  231. bool ValidateParams() override;
  232. private:
  233. float probability_;
  234. };
  235. class RandomVerticalFlipOperation : public TensorOperation {
  236. public:
  237. explicit RandomVerticalFlipOperation(float probability = 0.5);
  238. ~RandomVerticalFlipOperation() = default;
  239. std::shared_ptr<TensorOp> Build() override;
  240. bool ValidateParams() override;
  241. private:
  242. float probability_;
  243. };
  244. class RandomRotationOperation : public TensorOperation {
  245. public:
  246. RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode, bool expand,
  247. std::vector<float> center, std::vector<uint8_t> fill_value);
  248. ~RandomRotationOperation() = default;
  249. std::shared_ptr<TensorOp> Build() override;
  250. bool ValidateParams() override;
  251. private:
  252. std::vector<float> degrees_;
  253. InterpolationMode interpolation_mode_;
  254. std::vector<float> center_;
  255. bool expand_;
  256. std::vector<uint8_t> fill_value_;
  257. };
  258. class PadOperation : public TensorOperation {
  259. public:
  260. PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0},
  261. BorderType padding_mode = BorderType::kConstant);
  262. ~PadOperation() = default;
  263. std::shared_ptr<TensorOp> Build() override;
  264. bool ValidateParams() override;
  265. private:
  266. std::vector<int32_t> padding_;
  267. std::vector<uint8_t> fill_value_;
  268. BorderType padding_mode_;
  269. };
  270. class CutOutOperation : public TensorOperation {
  271. public:
  272. explicit CutOutOperation(int32_t length, int32_t num_patches = 1);
  273. ~CutOutOperation() = default;
  274. std::shared_ptr<TensorOp> Build() override;
  275. bool ValidateParams() override;
  276. private:
  277. int32_t length_;
  278. int32_t num_patches_;
  279. };
  280. class RandomColorAdjustOperation : public TensorOperation {
  281. public:
  282. RandomColorAdjustOperation(std::vector<float> brightness = {1.0, 1.0}, std::vector<float> contrast = {1.0, 1.0},
  283. std::vector<float> saturation = {1.0, 1.0}, std::vector<float> hue = {0.0, 0.0});
  284. ~RandomColorAdjustOperation() = default;
  285. std::shared_ptr<TensorOp> Build() override;
  286. bool ValidateParams() override;
  287. private:
  288. std::vector<float> brightness_;
  289. std::vector<float> contrast_;
  290. std::vector<float> saturation_;
  291. std::vector<float> hue_;
  292. };
  293. } // namespace vision
  294. } // namespace api
  295. } // namespace dataset
  296. } // namespace mindspore
  297. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_TRANSFORMS_H_