You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

transforms.h 8.5 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_
  #include <cstdint>
  #include <functional>
  #include <map>
  #include <memory>
  #include <string>
  #include <vector>

  #include "include/api/status.h"
  #include "minddata/dataset/include/constants.h"
  24. namespace mindspore {
  25. namespace dataset {
  // The IR operation type produced by TensorTransform::Parse(). This header only refers to it
  // through std::shared_ptr, so a forward declaration is sufficient.
  class TensorOperation;

  // Composite transforms that TensorTransform names as friends below. A qualified friend
  // declaration such as `friend class vision::UniformAugment;` requires the class to be declared
  // beforehand in its own namespace, hence these forward declarations.
  namespace vision {
  class BoundingBoxAugment;
  class RandomSelectSubpolicy;
  class UniformAugment;
  }  // namespace vision

  namespace transforms {
  class Compose;
  class RandomApply;
  class RandomChoice;
  }  // namespace transforms
  38. // Abstract class to represent a tensor transform operation in the data pipeline.
  39. /// \class TensorTransform transforms.h
  40. /// \brief A base class to represent a tensor transform operation in the data pipeline.
  41. class TensorTransform : public std::enable_shared_from_this<TensorTransform> {
  42. friend class Dataset;
  43. friend class Execute;
  44. friend class transforms::Compose;
  45. friend class transforms::RandomApply;
  46. friend class transforms::RandomChoice;
  47. friend class vision::BoundingBoxAugment;
  48. friend class vision::RandomSelectSubpolicy;
  49. friend class vision::UniformAugment;
  50. public:
  51. /// \brief Constructor
  52. TensorTransform() {}
  53. /// \brief Destructor
  54. ~TensorTransform() = default;
  55. protected:
  56. /// \brief Pure virtual function to convert a TensorTransform class into a IR TensorOperation object.
  57. /// \return shared pointer to the newly created TensorOperation.
  58. virtual std::shared_ptr<TensorOperation> Parse() = 0;
  59. /// \brief Virtual function to convert a TensorTransform class into a IR TensorOperation object.
  60. /// \param[in] env A string to determine the running environment
  61. /// \return shared pointer to the newly created TensorOperation.
  62. virtual std::shared_ptr<TensorOperation> Parse(const MapTargetDevice &env) { return nullptr; }
  63. };
  64. // Transform operations for performing data transformation.
  65. namespace transforms {
  66. /// \brief Compose Op.
  67. /// \notes Compose a list of transforms into a single transform.
  68. class Compose : public TensorTransform {
  69. public:
  70. /// \brief Constructor.
  71. /// \param[in] transforms A vector of raw pointers to TensorTransform objects to be applied.
  72. explicit Compose(const std::vector<TensorTransform *> &transforms);
  73. /// \brief Constructor.
  74. /// \param[in] transforms A vector of shared pointers to TensorTransform objects to be applied.
  75. explicit Compose(const std::vector<std::shared_ptr<TensorTransform>> &transforms);
  76. /// \brief Constructor.
  77. /// \param[in] transforms A vector of TensorTransform objects to be applied.
  78. explicit Compose(const std::vector<std::reference_wrapper<TensorTransform>> &transforms);
  79. /// \brief Destructor
  80. ~Compose() = default;
  81. protected:
  82. /// \brief Function to convert TensorTransform object into a TensorOperation object.
  83. /// \return Shared pointer to TensorOperation object.
  84. std::shared_ptr<TensorOperation> Parse() override;
  85. private:
  86. std::vector<std::shared_ptr<TensorOperation>> transforms_;
  87. };
  88. /// \brief Duplicate Op.
  89. /// \notes Duplicate the input tensor to a new output tensor.
  90. /// The input tensor is carried over to the output list.
  91. class Duplicate : public TensorTransform {
  92. public:
  93. /// \brief Constructor.
  94. Duplicate();
  95. /// \brief Destructor
  96. ~Duplicate() = default;
  97. protected:
  98. /// \brief Function to convert TensorTransform object into a TensorOperation object.
  99. /// \return Shared pointer to TensorOperation object.
  100. std::shared_ptr<TensorOperation> Parse() override;
  101. };
  102. /// \brief OneHot Op.
  103. /// \notes Convert the labels into OneHot format.
  104. class OneHot : public TensorTransform {
  105. public:
  106. /// \brief Constructor.
  107. /// \param[in] num_classes number of classes.
  108. explicit OneHot(int32_t num_classes);
  109. /// \brief Destructor
  110. ~OneHot() = default;
  111. protected:
  112. /// \brief Function to convert TensorTransform object into a TensorOperation object.
  113. /// \return Shared pointer to TensorOperation object.
  114. std::shared_ptr<TensorOperation> Parse() override;
  115. private:
  116. float num_classes_;
  117. };
  118. /// \brief RandomApply Op.
  119. /// \notes Randomly perform a series of transforms with a given probability.
  120. class RandomApply : public TensorTransform {
  121. public:
  122. /// \brief Constructor.
  123. /// \param[in] transforms A vector of raw pointers to TensorTransform objects to be applied.
  124. /// \param[in] prob The probability to apply the transformation list (default=0.5)
  125. explicit RandomApply(const std::vector<TensorTransform *> &transforms, double prob = 0.5);
  126. /// \brief Constructor.
  127. /// \param[in] transforms A vector of shared pointers to TensorTransform objects to be applied.
  128. /// \param[in] prob The probability to apply the transformation list (default=0.5)
  129. explicit RandomApply(const std::vector<std::shared_ptr<TensorTransform>> &transforms, double prob = 0.5);
  130. /// \brief Constructor.
  131. /// \param[in] transforms A vector of TensorTransform objects to be applied.
  132. /// \param[in] prob The probability to apply the transformation list (default=0.5)
  133. explicit RandomApply(const std::vector<std::reference_wrapper<TensorTransform>> &transforms, double prob = 0.5);
  134. /// \brief Destructor
  135. ~RandomApply() = default;
  136. protected:
  137. /// \brief Function to convert TensorTransform object into a TensorOperation object.
  138. /// \return Shared pointer to TensorOperation object.
  139. std::shared_ptr<TensorOperation> Parse() override;
  140. private:
  141. std::vector<std::shared_ptr<TensorOperation>> transforms_;
  142. double prob_;
  143. };
  144. /// \brief RandomChoice Op.
  145. /// \notes Randomly selects one transform from a list of transforms to perform operation.
  146. class RandomChoice : public TensorTransform {
  147. public:
  148. /// \brief Constructor.
  149. /// \param[in] transforms A vector of raw pointers to TensorTransform objects to be applied.
  150. explicit RandomChoice(const std::vector<TensorTransform *> &transforms);
  151. /// \brief Constructor.
  152. /// \param[in] transforms A vector of shared pointers to TensorTransform objects to be applied.
  153. explicit RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> &transforms);
  154. /// \brief Constructor.
  155. /// \param[in] transforms A vector of TensorTransform objects to be applied.
  156. explicit RandomChoice(const std::vector<std::reference_wrapper<TensorTransform>> &transforms);
  157. /// \brief Destructor
  158. ~RandomChoice() = default;
  159. protected:
  160. /// \brief Function to convert TensorTransform object into a TensorOperation object.
  161. /// \return Shared pointer to TensorOperation object.
  162. std::shared_ptr<TensorOperation> Parse() override;
  163. private:
  164. std::vector<std::shared_ptr<TensorOperation>> transforms_;
  165. };
  166. /// \brief TypeCast Op.
  167. /// \notes Tensor operation to cast to a given MindSpore data type.
  168. class TypeCast : public TensorTransform {
  169. public:
  170. /// \brief Constructor.
  171. /// \param[in] data_type mindspore.dtype to be cast to.
  172. explicit TypeCast(std::string data_type);
  173. /// \brief Destructor
  174. ~TypeCast() = default;
  175. protected:
  176. /// \brief Function to convert TensorTransform object into a TensorOperation object.
  177. /// \return Shared pointer to TensorOperation object.
  178. std::shared_ptr<TensorOperation> Parse() override;
  179. private:
  180. std::string data_type_;
  181. };
  182. /// \brief Unique Op.
  183. /// \notes Return an output tensor containing all the unique elements of the input tensor in
  184. /// the same order that they occur in the input tensor.
  185. class Unique : public TensorTransform {
  186. public:
  187. /// \brief Constructor.
  188. Unique();
  189. /// \brief Destructor
  190. ~Unique() = default;
  191. protected:
  192. /// \brief Function to convert TensorTransform object into a TensorOperation object.
  193. /// \return Shared pointer to TensorOperation object.
  194. std::shared_ptr<TensorOperation> Parse() override;
  195. };
  196. } // namespace transforms
  197. } // namespace dataset
  198. } // namespace mindspore
  199. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_