You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_rgb_bgr.py 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing RgbToBgr op in DE
  17. """
  18. import numpy as np
  19. from numpy.testing import assert_allclose
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.transforms.py_transforms
  22. import mindspore.dataset.vision.c_transforms as vision
  23. import mindspore.dataset.vision.py_transforms as py_vision
  24. import mindspore.dataset.vision.py_transforms_util as util
  25. DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  26. SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
  27. def generate_numpy_random_rgb(shape):
  28. # Only generate floating points that are fractions like n / 256, since they
  29. # are RGB pixels. Some low-precision floating point types in this test can't
  30. # handle arbitrary precision floating points well.
  31. return np.random.randint(0, 256, shape) / 255.
  32. def test_rgb_bgr_hwc_py():
  33. # Eager
  34. rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
  35. rgb_np = rgb_flat.reshape((8, 8, 3))
  36. bgr_np_pred = util.rgb_to_bgrs(rgb_np, True)
  37. r, g, b = rgb_np[:, :, 0], rgb_np[:, :, 1], rgb_np[:, :, 2]
  38. bgr_np_gt = np.stack((b, g, r), axis=2)
  39. assert bgr_np_pred.shape == rgb_np.shape
  40. assert_allclose(bgr_np_pred.flatten(),
  41. bgr_np_gt.flatten(),
  42. rtol=1e-5,
  43. atol=0)
  44. def test_rgb_bgr_hwc_c():
  45. # Eager
  46. rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
  47. rgb_np = rgb_flat.reshape((8, 8, 3))
  48. rgb2bgr_op = vision.RgbToBgr()
  49. bgr_np_pred = rgb2bgr_op(rgb_np)
  50. r, g, b = rgb_np[:, :, 0], rgb_np[:, :, 1], rgb_np[:, :, 2]
  51. bgr_np_gt = np.stack((b, g, r), axis=2)
  52. assert bgr_np_pred.shape == rgb_np.shape
  53. assert_allclose(bgr_np_pred.flatten(),
  54. bgr_np_gt.flatten(),
  55. rtol=1e-5,
  56. atol=0)
  57. def test_rgb_bgr_chw_py():
  58. rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
  59. rgb_np = rgb_flat.reshape((3, 8, 8))
  60. rgb_np_pred = util.rgb_to_bgrs(rgb_np, False)
  61. rgb_np_gt = rgb_np[::-1, :, :]
  62. assert rgb_np_pred.shape == rgb_np.shape
  63. assert_allclose(rgb_np_pred.flatten(),
  64. rgb_np_gt.flatten(),
  65. rtol=1e-5,
  66. atol=0)
  67. def test_rgb_bgr_pipeline_py():
  68. # First dataset
  69. transforms1 = [py_vision.Decode(), py_vision.Resize([64, 64]), py_vision.ToTensor()]
  70. transforms1 = mindspore.dataset.transforms.py_transforms.Compose(
  71. transforms1)
  72. ds1 = ds.TFRecordDataset(DATA_DIR,
  73. SCHEMA_DIR,
  74. columns_list=["image"],
  75. shuffle=False)
  76. ds1 = ds1.map(operations=transforms1, input_columns=["image"])
  77. # Second dataset
  78. transforms2 = [
  79. py_vision.Decode(),
  80. py_vision.Resize([64, 64]),
  81. py_vision.ToTensor(),
  82. py_vision.RgbToBgr()
  83. ]
  84. transforms2 = mindspore.dataset.transforms.py_transforms.Compose(
  85. transforms2)
  86. ds2 = ds.TFRecordDataset(DATA_DIR,
  87. SCHEMA_DIR,
  88. columns_list=["image"],
  89. shuffle=False)
  90. ds2 = ds2.map(operations=transforms2, input_columns=["image"])
  91. num_iter = 0
  92. for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
  93. ds2.create_dict_iterator(num_epochs=1)):
  94. num_iter += 1
  95. ori_img = data1["image"].asnumpy()
  96. cvt_img = data2["image"].asnumpy()
  97. cvt_img_gt = ori_img[::-1, :, :]
  98. assert_allclose(cvt_img_gt.flatten(),
  99. cvt_img.flatten(),
  100. rtol=1e-5,
  101. atol=0)
  102. assert ori_img.shape == cvt_img.shape
  103. def test_rgb_bgr_pipeline_c():
  104. # First dataset
  105. transforms1 = [
  106. vision.Decode(),
  107. vision.Resize([64, 64])
  108. ]
  109. transforms1 = mindspore.dataset.transforms.py_transforms.Compose(
  110. transforms1)
  111. ds1 = ds.TFRecordDataset(DATA_DIR,
  112. SCHEMA_DIR,
  113. columns_list=["image"],
  114. shuffle=False)
  115. ds1 = ds1.map(operations=transforms1, input_columns=["image"])
  116. # Second dataset
  117. transforms2 = [
  118. vision.Decode(),
  119. vision.Resize([64, 64]),
  120. vision.RgbToBgr()
  121. ]
  122. transforms2 = mindspore.dataset.transforms.py_transforms.Compose(
  123. transforms2)
  124. ds2 = ds.TFRecordDataset(DATA_DIR,
  125. SCHEMA_DIR,
  126. columns_list=["image"],
  127. shuffle=False)
  128. ds2 = ds2.map(operations=transforms2, input_columns=["image"])
  129. num_iter = 0
  130. for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
  131. ds2.create_dict_iterator(num_epochs=1)):
  132. num_iter += 1
  133. ori_img = data1["image"].asnumpy()
  134. cvt_img = data2["image"].asnumpy()
  135. cvt_img_gt = ori_img[:, :, ::-1]
  136. assert_allclose(cvt_img_gt.flatten(),
  137. cvt_img.flatten(),
  138. rtol=1e-5,
  139. atol=0)
  140. assert ori_img.shape == cvt_img.shape
  141. if __name__ == "__main__":
  142. test_rgb_bgr_hwc_py()
  143. test_rgb_bgr_hwc_c()
  144. test_rgb_bgr_chw_py()
  145. test_rgb_bgr_pipeline_py()
  146. test_rgb_bgr_pipeline_c()