You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_num_samples.py 2.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import mindspore.dataset as ds
  16. from mindspore import log as logger
  def test_num_samples():
      """Iterate a sharded ManifestDataset and check the number of rows.

      The dataset is built with ``num_samples=1`` as shard 1 of 3; the test
      asserts that one epoch over it yields exactly one row.
      """
      manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
      num_samples = 1
      dataset = ds.ManifestDataset(
          manifest_file,
          num_samples=num_samples,
          num_shards=3,
          shard_id=1,
      )
      rows = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
      # Exhaust the iterator, counting one per row produced.
      row_count = sum(1 for _ in rows)
      assert row_count == 1
  def test_num_samples_tf():
      """Iterate a TFRecordDataset built with ``num_samples=2`` and count rows.

      Asserts that a single epoch over the dataset yields exactly two rows.
      """
      # NOTE(review): the logged name differs from this function's name --
      # looks like a copy-paste from another test; confirm before changing.
      logger.info("test_tfrecord_read_all_dataset")
      schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
      files = ["../data/dataset/testTFTestAllTypes/test.data"]
      # num_samples is the number of rows per shard; the file holds 12 rows
      # in total.
      dataset = ds.TFRecordDataset(files, schema_file, num_samples=2)
      rows = dataset.create_tuple_iterator(num_epochs=1)
      # Exhaust the iterator, counting one per row produced.
      count = sum(1 for _ in rows)
      assert count == 2
  def test_num_samples_image_folder():
      """Iterate a sharded ImageFolderDataset and check the number of rows.

      The dataset is built with ``num_samples=2`` as shard 0 of 2; the test
      asserts that one epoch over it yields exactly two rows.
      """
      data_dir = "../data/dataset/testPK/data"
      dataset = ds.ImageFolderDataset(
          data_dir,
          num_samples=2,
          num_shards=2,
          shard_id=0,
      )
      rows = dataset.create_tuple_iterator(num_epochs=1)
      # Exhaust the iterator, counting one per row produced.
      count = sum(1 for _ in rows)
      assert count == 2
  if __name__ == "__main__":
      # When executed as a script, run each test in this module in order.
      for run_test in (
          test_num_samples,
          test_num_samples_tf,
          test_num_samples_image_folder,
      ):
          run_test()