You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_mixup_label_smoothing.py 5.9 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import mindspore.dataset as ds
  17. import mindspore.dataset.transforms.c_transforms as c
  18. import mindspore.dataset.transforms.py_transforms as f
  19. import mindspore.dataset.transforms.vision.c_transforms as c_vision
  20. import mindspore.dataset.transforms.vision.py_transforms as py_vision
  21. from mindspore import log as logger
  22. DATA_DIR = "../data/dataset/testImageNetData/train"
  23. DATA_DIR_2 = "../data/dataset/testImageNetData2/train"
  24. def test_one_hot_op():
  25. """
  26. Test one hot encoding op
  27. """
  28. logger.info("Test one hot encoding op")
  29. # define map operations
  30. # ds = de.ImageFolderDataset(DATA_DIR, schema=SCHEMA_DIR)
  31. dataset = ds.ImageFolderDatasetV2(DATA_DIR)
  32. num_classes = 2
  33. epsilon_para = 0.1
  34. transforms = [f.OneHotOp(num_classes=num_classes, smoothing_rate=epsilon_para),
  35. ]
  36. transform_label = py_vision.ComposeOp(transforms)
  37. dataset = dataset.map(input_columns=["label"], operations=transform_label())
  38. golden_label = np.ones(num_classes) * epsilon_para / num_classes
  39. golden_label[1] = 1 - epsilon_para / num_classes
  40. for data in dataset.create_dict_iterator():
  41. label = data["label"]
  42. logger.info("label is {}".format(label))
  43. logger.info("golden_label is {}".format(golden_label))
  44. assert label.all() == golden_label.all()
  45. logger.info("====test one hot op ok====")
  46. def test_mix_up_single():
  47. """
  48. Test single batch mix up op
  49. """
  50. logger.info("Test single batch mix up op")
  51. resize_height = 224
  52. resize_width = 224
  53. # Create dataset and define map operations
  54. ds1 = ds.ImageFolderDatasetV2(DATA_DIR_2)
  55. num_classes = 10
  56. decode_op = c_vision.Decode()
  57. resize_op = c_vision.Resize((resize_height, resize_width), c_vision.Inter.LINEAR)
  58. one_hot_encode = c.OneHot(num_classes) # num_classes is input argument
  59. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  60. ds1 = ds1.map(input_columns=["image"], operations=resize_op)
  61. ds1 = ds1.map(input_columns=["label"], operations=one_hot_encode)
  62. # apply batch operations
  63. batch_size = 3
  64. ds1 = ds1.batch(batch_size, drop_remainder=True)
  65. ds2 = ds1
  66. alpha = 0.2
  67. transforms = [py_vision.MixUp(batch_size=batch_size, alpha=alpha, is_single=True)
  68. ]
  69. ds1 = ds1.map(input_columns=["image", "label"], operations=transforms)
  70. for data1, data2 in zip(ds1.create_dict_iterator(), ds2.create_dict_iterator()):
  71. image1 = data1["image"]
  72. label = data1["label"]
  73. logger.info("label is {}".format(label))
  74. image2 = data2["image"]
  75. label2 = data2["label"]
  76. logger.info("label2 is {}".format(label2))
  77. lam = np.abs(label - label2)
  78. for index in range(batch_size - 1):
  79. if np.square(lam[index]).mean() != 0:
  80. lam_value = 1 - np.sum(lam[index]) / 2
  81. img_golden = lam_value * image2[index] + (1 - lam_value) * image2[index + 1]
  82. assert image1[index].all() == img_golden.all()
  83. logger.info("====test single batch mixup ok====")
  84. def test_mix_up_multi():
  85. """
  86. Test multi batch mix up op
  87. """
  88. logger.info("Test several batch mix up op")
  89. resize_height = 224
  90. resize_width = 224
  91. # Create dataset and define map operations
  92. ds1 = ds.ImageFolderDatasetV2(DATA_DIR_2)
  93. num_classes = 3
  94. decode_op = c_vision.Decode()
  95. resize_op = c_vision.Resize((resize_height, resize_width), c_vision.Inter.LINEAR)
  96. one_hot_encode = c.OneHot(num_classes) # num_classes is input argument
  97. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  98. ds1 = ds1.map(input_columns=["image"], operations=resize_op)
  99. ds1 = ds1.map(input_columns=["label"], operations=one_hot_encode)
  100. # apply batch operations
  101. batch_size = 3
  102. ds1 = ds1.batch(batch_size, drop_remainder=True)
  103. ds2 = ds1
  104. alpha = 0.2
  105. transforms = [py_vision.MixUp(batch_size=batch_size, alpha=alpha, is_single=False)
  106. ]
  107. ds1 = ds1.map(input_columns=["image", "label"], operations=transforms)
  108. num_iter = 0
  109. batch1_image1 = 0
  110. for data1, data2 in zip(ds1.create_dict_iterator(), ds2.create_dict_iterator()):
  111. image1 = data1["image"]
  112. label1 = data1["label"]
  113. logger.info("label: {}".format(label1))
  114. image2 = data2["image"]
  115. label2 = data2["label"]
  116. logger.info("label2: {}".format(label2))
  117. if num_iter == 0:
  118. batch1_image1 = image1
  119. if num_iter == 1:
  120. lam = np.abs(label2 - label1)
  121. logger.info("lam value in multi: {}".format(lam))
  122. for index in range(batch_size):
  123. if np.square(lam[index]).mean() != 0:
  124. lam_value = 1 - np.sum(lam[index]) / 2
  125. img_golden = lam_value * image2[index] + (1 - lam_value) * batch1_image1[index]
  126. assert image1[index].all() == img_golden.all()
  127. logger.info("====test several batch mixup ok====")
  128. break
  129. num_iter = num_iter + 1
  130. if __name__ == "__main__":
  131. test_one_hot_op()
  132. test_mix_up_single()
  133. test_mix_up_multi()