You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

c_transforms.py 9.5 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. This module c_transforms provides common operations, including OneHotOp and TypeCast.
  17. """
  18. from enum import IntEnum
  19. import numpy as np
  20. import mindspore.common.dtype as mstype
  21. import mindspore._c_dataengine as cde
  22. from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op, \
  23. check_pad_end, check_concat_type, check_random_transform_ops
  24. from ..core.datatypes import mstype_to_detype
  25. class OneHot(cde.OneHotOp):
  26. """
  27. Tensor operation to apply one hot encoding.
  28. Args:
  29. num_classes (int): Number of classes of the label
  30. it should be bigger than or equal to label class number.
  31. Raises:
  32. RuntimeError: feature size is bigger than num_classes.
  33. """
  34. @check_num_classes
  35. def __init__(self, num_classes):
  36. self.num_classes = num_classes
  37. super().__init__(num_classes)
  38. class Fill(cde.FillOp):
  39. """
  40. Tensor operation to create a tensor filled with passed scalar value.
  41. The output tensor will have the same shape and type as the input tensor.
  42. Args:
  43. fill_value (Union[str, bytes, int, float, bool])) : scalar value
  44. to fill created tensor with.
  45. """
  46. @check_fill_value
  47. def __init__(self, fill_value):
  48. super().__init__(cde.Tensor(np.array(fill_value)))
  49. class TypeCast(cde.TypeCastOp):
  50. """
  51. Tensor operation to cast to a given MindSpore data type.
  52. Args:
  53. data_type (mindspore.dtype): mindspore.dtype to be casted to.
  54. """
  55. @check_de_type
  56. def __init__(self, data_type):
  57. data_type = mstype_to_detype(data_type)
  58. self.data_type = str(data_type)
  59. super().__init__(data_type)
  60. class Slice(cde.SliceOp):
  61. """
  62. Slice operation to extract a tensor out using the given n slices.
  63. The functionality of Slice is similar to NumPy indexing feature.
  64. (Currently only rank-1 tensors are supported).
  65. Args:
  66. slices(Union[int, list(int), slice, None, Ellipses]):
  67. Maximum `n` number of arguments to slice a tensor of rank `n`.
  68. One object in slices can be one of:
  69. 1. :py:obj:`int`: Slice this index only. Negative index is supported.
  70. 2. :py:obj:`list(int)`: Slice these indices ion the list only. Negative indices are supported.
  71. 3. :py:obj:`slice`: Slice the generated indices from the slice object. Similar to `start:stop:step`.
  72. 4. :py:obj:`None`: Slice the whole dimension. Similar to `:` in python indexing.
  73. 5. :py:obj:`Ellipses`: Slice all dimensions between the two slices. Similar to `...` in python indexing.
  74. Examples:
  75. >>> # Data before
  76. >>> # | col |
  77. >>> # +---------+
  78. >>> # | [1,2,3] |
  79. >>> # +---------|
  80. >>> data = data.map(operations=Slice(slice(1,3))) # slice indices 1 and 2 only
  81. >>> # Data after
  82. >>> # | col |
  83. >>> # +---------+
  84. >>> # | [2,3] |
  85. >>> # +---------|
  86. """
  87. @check_slice_op
  88. def __init__(self, *slices):
  89. dim0 = slices[0]
  90. if isinstance(dim0, int):
  91. dim0 = [dim0]
  92. elif dim0 is None:
  93. dim0 = True
  94. elif isinstance(dim0, slice):
  95. dim0 = (dim0.start, dim0.stop, dim0.step)
  96. elif dim0 is Ellipsis:
  97. dim0 = True
  98. super().__init__(dim0)
  99. class Relational(IntEnum):
  100. EQ = 0
  101. NE = 1
  102. GT = 2
  103. GE = 3
  104. LT = 4
  105. LE = 5
  106. DE_C_RELATIONAL = {Relational.EQ: cde.RelationalOp.EQ,
  107. Relational.NE: cde.RelationalOp.NE,
  108. Relational.GT: cde.RelationalOp.GT,
  109. Relational.GE: cde.RelationalOp.GE,
  110. Relational.LT: cde.RelationalOp.LT,
  111. Relational.LE: cde.RelationalOp.LE}
  112. class Mask(cde.MaskOp):
  113. """
  114. Mask content of the input tensor with the given predicate.
  115. Any element of the tensor that matches the predicate will be evaluated to True, otherwise False.
  116. Args:
  117. operator (Relational): One of the relational operator EQ, NE LT, GT, LE or GE
  118. constant (Union[str, int, float, bool]): constant to be compared to.
  119. Constant will be casted to the type of the input tensor
  120. dtype (mindspore.dtype, optional): type of the generated mask. Default to bool
  121. Examples:
  122. >>> # Data before
  123. >>> # | col1 |
  124. >>> # +---------+
  125. >>> # | [1,2,3] |
  126. >>> # +---------+
  127. >>> data = data.map(operations=Mask(Relational.EQ, 2))
  128. >>> # Data after
  129. >>> # | col1 |
  130. >>> # +--------------------+
  131. >>> # | [False,True,False] |
  132. >>> # +--------------------+
  133. """
  134. @check_mask_op
  135. def __init__(self, operator, constant, dtype=mstype.bool_):
  136. dtype = mstype_to_detype(dtype)
  137. constant = cde.Tensor(np.array(constant))
  138. super().__init__(DE_C_RELATIONAL[operator], constant, dtype)
  139. class PadEnd(cde.PadEndOp):
  140. """
  141. Pad input tensor according to `pad_shape`, need to have same rank.
  142. Args:
  143. pad_shape (list(int)): list on integers representing the shape needed. Dimensions that set to `None` will
  144. not be padded (i.e., original dim will be used). Shorter dimensions will truncate the values.
  145. pad_value (Union[str, bytes, int, float, bool]), optional): value used to pad. Default to 0 or empty
  146. string in case of Tensors of strings.
  147. Examples:
  148. >>> # Data before
  149. >>> # | col |
  150. >>> # +---------+
  151. >>> # | [1,2,3] |
  152. >>> # +---------|
  153. >>> data = data.map(operations=PadEnd(pad_shape=[4], pad_value=10))
  154. >>> # Data after
  155. >>> # | col |
  156. >>> # +------------+
  157. >>> # | [1,2,3,10] |
  158. >>> # +------------|
  159. """
  160. @check_pad_end
  161. def __init__(self, pad_shape, pad_value=None):
  162. if pad_value is not None:
  163. pad_value = cde.Tensor(np.array(pad_value))
  164. super().__init__(cde.TensorShape(pad_shape), pad_value)
  165. class Concatenate(cde.ConcatenateOp):
  166. """
  167. Tensor operation that concatenates all columns into a single tensor.
  168. Args:
  169. axis (int, optional): concatenate the tensors along given axis (Default=0).
  170. prepend (numpy.array, optional): numpy array to be prepended to the already concatenated tensors (Default=None).
  171. append (numpy.array, optional): numpy array to be appended to the already concatenated tensors (Default=None).
  172. """
  173. @check_concat_type
  174. def __init__(self, axis=0, prepend=None, append=None):
  175. if prepend is not None:
  176. prepend = cde.Tensor(np.array(prepend))
  177. if append is not None:
  178. append = cde.Tensor(np.array(append))
  179. super().__init__(axis, prepend, append)
  180. class Duplicate(cde.DuplicateOp):
  181. """
  182. Duplicate the input tensor to a new output tensor. The input tensor is carried over to the output list.
  183. Examples:
  184. >>> # Data before
  185. >>> # | x |
  186. >>> # +---------+
  187. >>> # | [1,2,3] |
  188. >>> # +---------+
  189. >>> data = data.map(input_columns=["x"], operations=Duplicate(),
  190. >>> output_columns=["x", "y"], columns_order=["x", "y"])
  191. >>> # Data after
  192. >>> # | x | y |
  193. >>> # +---------+---------+
  194. >>> # | [1,2,3] | [1,2,3] |
  195. >>> # +---------+---------+
  196. """
  197. class Compose(cde.ComposeOp):
  198. """
  199. Compose a list of transforms into a single transform.
  200. Args:
  201. transforms (list): List of transformations to be applied.
  202. Examples:
  203. >>> compose = Compose([vision.Decode(), vision.RandomCrop()])
  204. >>> dataset = ds.map(operations=compose)
  205. """
  206. @check_random_transform_ops
  207. def __init__(self, transforms):
  208. super().__init__(transforms)
  209. class RandomApply(cde.RandomApplyOp):
  210. """
  211. Randomly performs a series of transforms with a given probability.
  212. Args:
  213. transforms (list): List of transformations to be applied.
  214. prob (float, optional): The probability to apply the transformation list (default=0.5)
  215. Examples:
  216. >>> rand_apply = RandomApply([vision.RandomCrop()])
  217. >>> dataset = ds.map(operations=rand_apply)
  218. """
  219. @check_random_transform_ops
  220. def __init__(self, transforms, prob=0.5):
  221. super().__init__(prob, transforms)
  222. class RandomChoice(cde.RandomChoiceOp):
  223. """
  224. Randomly selects one transform from a list of transforms to perform operation.
  225. Args:
  226. transforms (list): List of transformations to be chosen from to apply.
  227. Examples:
  228. >>> rand_choice = RandomChoice([vision.CenterCrop(), vision.RandomCrop()])
  229. >>> dataset = ds.map(operations=rand_choice)
  230. """
  231. @check_random_transform_ops
  232. def __init__(self, transforms):
  233. super().__init__(transforms)