You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_rgb_hsv.py 6.6 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing RgbToHsv and HsvToRgb op in DE
  17. """
  18. import colorsys
  19. import numpy as np
  20. from numpy.testing import assert_allclose
  21. import mindspore.dataset as ds
  22. import mindspore.dataset.transforms.py_transforms
  23. import mindspore.dataset.vision.py_transforms as vision
  24. import mindspore.dataset.vision.py_transforms_util as util
  25. DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  26. SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
  27. def generate_numpy_random_rgb(shape):
  28. # Only generate floating points that are fractions like n / 256, since they
  29. # are RGB pixels. Some low-precision floating point types in this test can't
  30. # handle arbitrary precision floating points well.
  31. return np.random.randint(0, 256, shape) / 255.
  32. def test_rgb_hsv_hwc():
  33. rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
  34. rgb_np = rgb_flat.reshape((8, 8, 3))
  35. hsv_base = np.array([
  36. colorsys.rgb_to_hsv(
  37. r.astype(np.float64), g.astype(np.float64), b.astype(np.float64))
  38. for r, g, b in rgb_flat
  39. ])
  40. hsv_base = hsv_base.reshape((8, 8, 3))
  41. hsv_de = util.rgb_to_hsvs(rgb_np, True)
  42. assert hsv_base.shape == hsv_de.shape
  43. assert_allclose(hsv_base.flatten(), hsv_de.flatten(), rtol=1e-5, atol=0)
  44. hsv_flat = hsv_base.reshape(64, 3)
  45. rgb_base = np.array([
  46. colorsys.hsv_to_rgb(
  47. h.astype(np.float64), s.astype(np.float64), v.astype(np.float64))
  48. for h, s, v in hsv_flat
  49. ])
  50. rgb_base = rgb_base.reshape((8, 8, 3))
  51. rgb_de = util.hsv_to_rgbs(hsv_base, True)
  52. assert rgb_base.shape == rgb_de.shape
  53. assert_allclose(rgb_base.flatten(), rgb_de.flatten(), rtol=1e-5, atol=0)
  54. def test_rgb_hsv_batch_hwc():
  55. rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
  56. rgb_np = rgb_flat.reshape((4, 2, 8, 3))
  57. hsv_base = np.array([
  58. colorsys.rgb_to_hsv(
  59. r.astype(np.float64), g.astype(np.float64), b.astype(np.float64))
  60. for r, g, b in rgb_flat
  61. ])
  62. hsv_base = hsv_base.reshape((4, 2, 8, 3))
  63. hsv_de = util.rgb_to_hsvs(rgb_np, True)
  64. assert hsv_base.shape == hsv_de.shape
  65. assert_allclose(hsv_base.flatten(), hsv_de.flatten(), rtol=1e-5, atol=0)
  66. hsv_flat = hsv_base.reshape((64, 3))
  67. rgb_base = np.array([
  68. colorsys.hsv_to_rgb(
  69. h.astype(np.float64), s.astype(np.float64), v.astype(np.float64))
  70. for h, s, v in hsv_flat
  71. ])
  72. rgb_base = rgb_base.reshape((4, 2, 8, 3))
  73. rgb_de = util.hsv_to_rgbs(hsv_base, True)
  74. assert rgb_de.shape == rgb_base.shape
  75. assert_allclose(rgb_base.flatten(), rgb_de.flatten(), rtol=1e-5, atol=0)
  76. def test_rgb_hsv_chw():
  77. rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
  78. rgb_np = rgb_flat.reshape((3, 8, 8))
  79. hsv_base = np.array([
  80. np.vectorize(colorsys.rgb_to_hsv)(
  81. rgb_np[0, :, :].astype(np.float64), rgb_np[1, :, :].astype(np.float64), rgb_np[2, :, :].astype(np.float64))
  82. ])
  83. hsv_base = hsv_base.reshape((3, 8, 8))
  84. hsv_de = util.rgb_to_hsvs(rgb_np, False)
  85. assert hsv_base.shape == hsv_de.shape
  86. assert_allclose(hsv_base.flatten(), hsv_de.flatten(), rtol=1e-5, atol=0)
  87. rgb_base = np.array([
  88. np.vectorize(colorsys.hsv_to_rgb)(
  89. hsv_base[0, :, :].astype(np.float64), hsv_base[1, :, :].astype(np.float64),
  90. hsv_base[2, :, :].astype(np.float64))
  91. ])
  92. rgb_base = rgb_base.reshape((3, 8, 8))
  93. rgb_de = util.hsv_to_rgbs(hsv_base, False)
  94. assert rgb_de.shape == rgb_base.shape
  95. assert_allclose(rgb_base.flatten(), rgb_de.flatten(), rtol=1e-5, atol=0)
  96. def test_rgb_hsv_batch_chw():
  97. rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
  98. rgb_imgs = rgb_flat.reshape((4, 3, 2, 8))
  99. hsv_base_imgs = np.array([
  100. np.vectorize(colorsys.rgb_to_hsv)(
  101. img[0, :, :].astype(np.float64), img[1, :, :].astype(np.float64), img[2, :, :].astype(np.float64))
  102. for img in rgb_imgs
  103. ])
  104. hsv_de = util.rgb_to_hsvs(rgb_imgs, False)
  105. assert hsv_base_imgs.shape == hsv_de.shape
  106. assert_allclose(hsv_base_imgs.flatten(), hsv_de.flatten(), rtol=1e-5, atol=0)
  107. rgb_base = np.array([
  108. np.vectorize(colorsys.hsv_to_rgb)(
  109. img[0, :, :].astype(np.float64), img[1, :, :].astype(np.float64), img[2, :, :].astype(np.float64))
  110. for img in hsv_base_imgs
  111. ])
  112. rgb_de = util.hsv_to_rgbs(hsv_base_imgs, False)
  113. assert rgb_base.shape == rgb_de.shape
  114. assert_allclose(rgb_base.flatten(), rgb_de.flatten(), rtol=1e-5, atol=0)
  115. def test_rgb_hsv_pipeline():
  116. # First dataset
  117. transforms1 = [
  118. vision.Decode(),
  119. vision.Resize([64, 64]),
  120. vision.ToTensor()
  121. ]
  122. transforms1 = mindspore.dataset.transforms.py_transforms.Compose(transforms1)
  123. ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  124. ds1 = ds1.map(operations=transforms1, input_columns=["image"])
  125. # Second dataset
  126. transforms2 = [
  127. vision.Decode(),
  128. vision.Resize([64, 64]),
  129. vision.ToTensor(),
  130. vision.RgbToHsv(),
  131. vision.HsvToRgb()
  132. ]
  133. transform2 = mindspore.dataset.transforms.py_transforms.Compose(transforms2)
  134. ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  135. ds2 = ds2.map(operations=transform2, input_columns=["image"])
  136. num_iter = 0
  137. for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1), ds2.create_dict_iterator(num_epochs=1)):
  138. num_iter += 1
  139. ori_img = data1["image"]
  140. cvt_img = data2["image"]
  141. assert_allclose(ori_img.flatten(), cvt_img.flatten(), rtol=1e-5, atol=0)
  142. assert ori_img.shape == cvt_img.shape
  143. if __name__ == "__main__":
  144. test_rgb_hsv_hwc()
  145. test_rgb_hsv_batch_hwc()
  146. test_rgb_hsv_chw()
  147. test_rgb_hsv_batch_chw()
  148. test_rgb_hsv_pipeline()