You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_mixup_label_smoothing.py 6.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import mindspore.dataset as ds
  17. import mindspore.dataset.transforms.c_transforms as c
  18. import mindspore.dataset.transforms.py_transforms as f
  19. import mindspore.dataset.vision.c_transforms as c_vision
  20. import mindspore.dataset.vision.py_transforms as py_vision
  21. from mindspore import log as logger
# Image-folder dataset read by test_one_hot_op (relative path).
DATA_DIR = "../data/dataset/testImageNetData/train"
# Image-folder dataset read by test_mix_up_single and test_mix_up_multi (relative path).
DATA_DIR_2 = "../data/dataset/testImageNetData2/train"
  24. def test_one_hot_op():
  25. """
  26. Test one hot encoding op
  27. """
  28. logger.info("Test one hot encoding op")
  29. # define map operations
  30. # ds = de.ImageFolderDataset(DATA_DIR, schema=SCHEMA_DIR)
  31. dataset = ds.ImageFolderDataset(DATA_DIR)
  32. num_classes = 2
  33. epsilon_para = 0.1
  34. transforms = [f.OneHotOp(num_classes=num_classes, smoothing_rate=epsilon_para)]
  35. transform_label = f.Compose(transforms)
  36. dataset = dataset.map(operations=transform_label, input_columns=["label"])
  37. golden_label = np.ones(num_classes) * epsilon_para / num_classes
  38. golden_label[1] = 1 - epsilon_para / num_classes
  39. for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  40. label = data["label"]
  41. logger.info("label is {}".format(label))
  42. logger.info("golden_label is {}".format(golden_label))
  43. assert label.all() == golden_label.all()
  44. logger.info("====test one hot op ok====")
  45. def test_mix_up_single():
  46. """
  47. Test single batch mix up op
  48. """
  49. logger.info("Test single batch mix up op")
  50. resize_height = 224
  51. resize_width = 224
  52. # Create dataset and define map operations
  53. ds1 = ds.ImageFolderDataset(DATA_DIR_2)
  54. num_classes = 10
  55. decode_op = c_vision.Decode()
  56. resize_op = c_vision.Resize((resize_height, resize_width), c_vision.Inter.LINEAR)
  57. one_hot_encode = c.OneHot(num_classes) # num_classes is input argument
  58. ds1 = ds1.map(operations=decode_op, input_columns=["image"])
  59. ds1 = ds1.map(operations=resize_op, input_columns=["image"])
  60. ds1 = ds1.map(operations=one_hot_encode, input_columns=["label"])
  61. # apply batch operations
  62. batch_size = 3
  63. ds1 = ds1.batch(batch_size, drop_remainder=True)
  64. ds2 = ds1
  65. alpha = 0.2
  66. transforms = [py_vision.MixUp(batch_size=batch_size, alpha=alpha, is_single=True)
  67. ]
  68. ds1 = ds1.map(operations=transforms, input_columns=["image", "label"])
  69. for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1, output_numpy=True),
  70. ds2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  71. image1 = data1["image"]
  72. label = data1["label"]
  73. logger.info("label is {}".format(label))
  74. image2 = data2["image"]
  75. label2 = data2["label"]
  76. logger.info("label2 is {}".format(label2))
  77. lam = np.abs(label - label2)
  78. for index in range(batch_size - 1):
  79. if np.square(lam[index]).mean() != 0:
  80. lam_value = 1 - np.sum(lam[index]) / 2
  81. img_golden = lam_value * image2[index] + (1 - lam_value) * image2[index + 1]
  82. assert image1[index].all() == img_golden.all()
  83. logger.info("====test single batch mixup ok====")
  84. def test_mix_up_multi():
  85. """
  86. Test multi batch mix up op
  87. """
  88. logger.info("Test several batch mix up op")
  89. resize_height = 224
  90. resize_width = 224
  91. # Create dataset and define map operations
  92. ds1 = ds.ImageFolderDataset(DATA_DIR_2)
  93. num_classes = 3
  94. decode_op = c_vision.Decode()
  95. resize_op = c_vision.Resize((resize_height, resize_width), c_vision.Inter.LINEAR)
  96. one_hot_encode = c.OneHot(num_classes) # num_classes is input argument
  97. ds1 = ds1.map(operations=decode_op, input_columns=["image"])
  98. ds1 = ds1.map(operations=resize_op, input_columns=["image"])
  99. ds1 = ds1.map(operations=one_hot_encode, input_columns=["label"])
  100. # apply batch operations
  101. batch_size = 3
  102. ds1 = ds1.batch(batch_size, drop_remainder=True)
  103. ds2 = ds1
  104. alpha = 0.2
  105. transforms = [py_vision.MixUp(batch_size=batch_size, alpha=alpha, is_single=False)
  106. ]
  107. ds1 = ds1.map(operations=transforms, input_columns=["image", "label"])
  108. num_iter = 0
  109. batch1_image1 = 0
  110. for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1, output_numpy=True),
  111. ds2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  112. image1 = data1["image"]
  113. label1 = data1["label"]
  114. logger.info("label: {}".format(label1))
  115. image2 = data2["image"]
  116. label2 = data2["label"]
  117. logger.info("label2: {}".format(label2))
  118. if num_iter == 0:
  119. batch1_image1 = image1
  120. if num_iter == 1:
  121. lam = np.abs(label2 - label1)
  122. logger.info("lam value in multi: {}".format(lam))
  123. for index in range(batch_size):
  124. if np.square(lam[index]).mean() != 0:
  125. lam_value = 1 - np.sum(lam[index]) / 2
  126. img_golden = lam_value * image2[index] + (1 - lam_value) * batch1_image1[index]
  127. assert image1[index].all() == img_golden.all()
  128. logger.info("====test several batch mixup ok====")
  129. break
  130. num_iter = num_iter + 1
  131. if __name__ == "__main__":
  132. test_one_hot_op()
  133. test_mix_up_single()
  134. test_mix_up_multi()