You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_rgb_bgr.py 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing RgbToBgr op in DE
  17. """
  18. import numpy as np
  19. from numpy.testing import assert_allclose
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.transforms.py_transforms
  22. import mindspore.dataset.vision.c_transforms as vision
  23. import mindspore.dataset.vision.py_transforms as py_vision
  24. import mindspore.dataset.vision.py_transforms_util as util
# Presumably toggles golden-file regeneration; not referenced by any test in this file.
GENERATE_GOLDEN = False
# List of TFRecord data files (despite the name, these are files, not a directory),
# given relative to the directory the tests are run from.
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
# Schema file for the TFRecord data; the pipeline tests below read only its "image" column.
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
  28. def generate_numpy_random_rgb(shape):
  29. # Only generate floating points that are fractions like n / 256, since they
  30. # are RGB pixels. Some low-precision floating point types in this test can't
  31. # handle arbitrary precision floating points well.
  32. return np.random.randint(0, 256, shape) / 255.
  33. def test_rgb_bgr_hwc_py():
  34. # Eager
  35. rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
  36. rgb_np = rgb_flat.reshape((8, 8, 3))
  37. bgr_np_pred = util.rgb_to_bgrs(rgb_np, True)
  38. r, g, b = rgb_np[:, :, 0], rgb_np[:, :, 1], rgb_np[:, :, 2]
  39. bgr_np_gt = np.stack((b, g, r), axis=2)
  40. assert bgr_np_pred.shape == rgb_np.shape
  41. assert_allclose(bgr_np_pred.flatten(),
  42. bgr_np_gt.flatten(),
  43. rtol=1e-5,
  44. atol=0)
  45. def test_rgb_bgr_hwc_c():
  46. # Eager
  47. rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
  48. rgb_np = rgb_flat.reshape((8, 8, 3))
  49. rgb2bgr_op = vision.RgbToBgr()
  50. bgr_np_pred = rgb2bgr_op(rgb_np)
  51. r, g, b = rgb_np[:, :, 0], rgb_np[:, :, 1], rgb_np[:, :, 2]
  52. bgr_np_gt = np.stack((b, g, r), axis=2)
  53. assert bgr_np_pred.shape == rgb_np.shape
  54. assert_allclose(bgr_np_pred.flatten(),
  55. bgr_np_gt.flatten(),
  56. rtol=1e-5,
  57. atol=0)
  58. def test_rgb_bgr_chw_py():
  59. rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
  60. rgb_np = rgb_flat.reshape((3, 8, 8))
  61. rgb_np_pred = util.rgb_to_bgrs(rgb_np, False)
  62. rgb_np_gt = rgb_np[::-1, :, :]
  63. assert rgb_np_pred.shape == rgb_np.shape
  64. assert_allclose(rgb_np_pred.flatten(),
  65. rgb_np_gt.flatten(),
  66. rtol=1e-5,
  67. atol=0)
  68. def test_rgb_bgr_pipeline_py():
  69. # First dataset
  70. transforms1 = [py_vision.Decode(), py_vision.Resize([64, 64]), py_vision.ToTensor()]
  71. transforms1 = mindspore.dataset.transforms.py_transforms.Compose(
  72. transforms1)
  73. ds1 = ds.TFRecordDataset(DATA_DIR,
  74. SCHEMA_DIR,
  75. columns_list=["image"],
  76. shuffle=False)
  77. ds1 = ds1.map(operations=transforms1, input_columns=["image"])
  78. # Second dataset
  79. transforms2 = [
  80. py_vision.Decode(),
  81. py_vision.Resize([64, 64]),
  82. py_vision.ToTensor(),
  83. py_vision.RgbToBgr()
  84. ]
  85. transforms2 = mindspore.dataset.transforms.py_transforms.Compose(
  86. transforms2)
  87. ds2 = ds.TFRecordDataset(DATA_DIR,
  88. SCHEMA_DIR,
  89. columns_list=["image"],
  90. shuffle=False)
  91. ds2 = ds2.map(operations=transforms2, input_columns=["image"])
  92. num_iter = 0
  93. for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
  94. ds2.create_dict_iterator(num_epochs=1)):
  95. num_iter += 1
  96. ori_img = data1["image"].asnumpy()
  97. cvt_img = data2["image"].asnumpy()
  98. cvt_img_gt = ori_img[::-1, :, :]
  99. assert_allclose(cvt_img_gt.flatten(),
  100. cvt_img.flatten(),
  101. rtol=1e-5,
  102. atol=0)
  103. assert ori_img.shape == cvt_img.shape
  104. def test_rgb_bgr_pipeline_c():
  105. # First dataset
  106. transforms1 = [
  107. vision.Decode(),
  108. vision.Resize([64, 64])
  109. ]
  110. transforms1 = mindspore.dataset.transforms.py_transforms.Compose(
  111. transforms1)
  112. ds1 = ds.TFRecordDataset(DATA_DIR,
  113. SCHEMA_DIR,
  114. columns_list=["image"],
  115. shuffle=False)
  116. ds1 = ds1.map(operations=transforms1, input_columns=["image"])
  117. # Second dataset
  118. transforms2 = [
  119. vision.Decode(),
  120. vision.Resize([64, 64]),
  121. vision.RgbToBgr()
  122. ]
  123. transforms2 = mindspore.dataset.transforms.py_transforms.Compose(
  124. transforms2)
  125. ds2 = ds.TFRecordDataset(DATA_DIR,
  126. SCHEMA_DIR,
  127. columns_list=["image"],
  128. shuffle=False)
  129. ds2 = ds2.map(operations=transforms2, input_columns=["image"])
  130. num_iter = 0
  131. for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
  132. ds2.create_dict_iterator(num_epochs=1)):
  133. num_iter += 1
  134. ori_img = data1["image"].asnumpy()
  135. cvt_img = data2["image"].asnumpy()
  136. cvt_img_gt = ori_img[:, :, ::-1]
  137. assert_allclose(cvt_img_gt.flatten(),
  138. cvt_img.flatten(),
  139. rtol=1e-5,
  140. atol=0)
  141. assert ori_img.shape == cvt_img.shape
  142. if __name__ == "__main__":
  143. test_rgb_bgr_hwc_py()
  144. test_rgb_bgr_hwc_c()
  145. test_rgb_bgr_chw_py()
  146. test_rgb_bgr_pipeline_py()
  147. test_rgb_bgr_pipeline_c()