You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_deviceop_cpu.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import mindspore.dataset.transforms.vision.c_transforms as vision
  16. import mindspore.dataset as ds
  17. from mindspore import log as logger
  18. DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  19. SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
  20. TF_FILES = ["../data/dataset/testTFTestAllTypes/test.data"]
  21. TF_SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
  22. def test_case_0():
  23. """
  24. Test Repeat
  25. """
  26. # apply dataset operations
  27. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  28. # define parameters
  29. repeat_count = 2
  30. data = data.repeat(repeat_count)
  31. data = data.device_que()
  32. data.send()
  33. def test_case_1():
  34. """
  35. Test Batch
  36. """
  37. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  38. # define data augmentation parameters
  39. resize_height, resize_width = 224, 224
  40. # define map operations
  41. decode_op = vision.Decode()
  42. resize_op = vision.Resize((resize_height, resize_width))
  43. # apply map operations on images
  44. data = data.map(input_columns=["image"], operations=decode_op)
  45. data = data.map(input_columns=["image"], operations=resize_op)
  46. batch_size = 3
  47. data = data.batch(batch_size, drop_remainder=True)
  48. data = data.device_que()
  49. data.send()
  50. def test_case_2():
  51. """
  52. Test Batch & Repeat
  53. """
  54. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  55. # define data augmentation parameters
  56. resize_height, resize_width = 224, 224
  57. # define map operations
  58. decode_op = vision.Decode()
  59. resize_op = vision.Resize((resize_height, resize_width))
  60. # apply map operations on images
  61. data = data.map(input_columns=["image"], operations=decode_op)
  62. data = data.map(input_columns=["image"], operations=resize_op)
  63. batch_size = 2
  64. data = data.batch(batch_size, drop_remainder=True)
  65. data = data.repeat(2)
  66. data = data.device_que()
  67. assert data.get_repeat_count() == 2
  68. data.send()
  69. def test_case_3():
  70. """
  71. Test Repeat & Batch
  72. """
  73. data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  74. # define data augmentation parameters
  75. resize_height, resize_width = 224, 224
  76. # define map operations
  77. decode_op = vision.Decode()
  78. resize_op = vision.Resize((resize_height, resize_width))
  79. # apply map operations on images
  80. data = data.map(input_columns=["image"], operations=decode_op)
  81. data = data.map(input_columns=["image"], operations=resize_op)
  82. data = data.repeat(2)
  83. batch_size = 2
  84. data = data.batch(batch_size, drop_remainder=True)
  85. data = data.device_que()
  86. data.send()
  87. def test_case_tf_file():
  88. data = ds.TFRecordDataset(TF_FILES, TF_SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
  89. data = data.to_device(num_batch=10)
  90. data.send()
  91. if __name__ == '__main__':
  92. logger.info('===========now test Repeat============')
  93. test_case_0()
  94. logger.info('===========now test Batch============')
  95. test_case_1()
  96. logger.info('===========now test Batch & Repeat============')
  97. test_case_2()
  98. logger.info('===========now test Repeat & Batch============')
  99. test_case_3()
  100. logger.info('===========now test tf file============')
  101. test_case_tf_file()