
read_imagenet.py
```python
'''
The ImageNet-1k dataset is mounted into the training image as a disk mount,
so it can be read and used directly, as in the example code below.
Mount layout:
.
└── cache/
    ├── ascend
    ├── outputs
    ├── user-job-dir
    └── sfs/
        └── data/
            └── imagenet/
                ├── train/
                │   └── n01440764/
                │       └── n01440764_11063.JPEG
                └── val/
                    └── n01440764/
                        └── ILSVRC2012_val_00011993.JPEG

mindspore.dataset.ImageFolderDataset
    - Reads the ImageNet-1k data; all images in the same folder belong to the same class.
mindspore.dataset.vision.c_transforms
    - Image decoding, augmentation and preprocessing operations.
Dataset methods used below (inherited by ImageFolderDataset)
    - map: applies a list of data augmentation operations to the dataset, in order.
    - batch: combines batch_size consecutive rows of the dataset into one batch.
    - to_json: serializes the data processing pipeline to a JSON string and,
      if a filename is given, also dumps it to that file.
'''
import os
import argparse
import moxing as mox  # not used in this example
import mindspore as ms  # not used in this example
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision.c_transforms as transforms
from openi.context import upload_openi

parser = argparse.ArgumentParser(description='Read big dataset ImageNet Example')
parser.add_argument('--train_url',
                    help='output folder to save/load',
                    default='/cache/output/')

if __name__ == "__main__":
    # parse_known_args returns unrecognized command-line arguments in `unknown`
    # instead of raising an error.
    args, unknown = parser.parse_known_args()

    # Mounted dataset root and local output directory.
    data_path = '/cache/sfs/data/imagenet/'
    modelart_output = '/cache/output'
    os.makedirs(modelart_output, exist_ok=True)

    # ImageNet per-channel mean/std, scaled to the 0-255 range of decoded images.
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # Each subfolder of train/ is treated as one class.
    dataset_train = ImageFolderDataset(os.path.join(data_path, "train"),
                                       shuffle=True)

    # Training-time augmentation, applied in order to the "image" column:
    # decode + random crop resized to 224, random flip, normalize, HWC -> CHW.
    trans_train = [
        transforms.RandomCropDecodeResize(size=224,
                                          scale=(0.08, 1.0),
                                          ratio=(0.75, 1.333)),
        transforms.RandomHorizontalFlip(prob=0.5),
        transforms.Normalize(mean=mean, std=std),
        transforms.HWC2CHW()
    ]
    dataset_train = dataset_train.map(operations=trans_train,
                                      input_columns=["image"])

    # Group 16 consecutive samples per batch and drop the last incomplete batch.
    dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

    # Serialize the pipeline to JSON, save it in the output directory and print it.
    data_info = dataset_train.to_json(filename=os.path.join(modelart_output, 'data_info.json'))
    print(data_info)

    # OpenI helper that hands the job's outputs back to the platform.
    upload_openi()
```
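`ImageFolderDataset` labels every image by the folder it sits in. According to the MindSpore API documentation, when no `class_indexing` is passed the folder names are sorted alphabetically and numbered from 0. The sketch below mimics that rule with the standard library only, so the folder-to-label mapping can be checked without MindSpore; `build_class_index`, `list_samples` and the miniature directory tree are hypothetical.

```python
import os
import tempfile


def build_class_index(root):
    """Map each subdirectory of `root` to an integer label (alphabetical order, from 0)."""
    classes = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )
    return {name: idx for idx, name in enumerate(classes)}


def list_samples(root):
    """Return (file_path, label) pairs; every file inherits its folder's label."""
    samples = []
    for name, label in build_class_index(root).items():
        folder = os.path.join(root, name)
        for fname in sorted(os.listdir(folder)):
            samples.append((os.path.join(folder, fname), label))
    return samples


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        # Hypothetical miniature train/ tree with two synset folders.
        layout = {"n01443537": ["a.JPEG"], "n01440764": ["b.JPEG", "c.JPEG"]}
        for synset, files in layout.items():
            os.makedirs(os.path.join(root, synset))
            for fname in files:
                open(os.path.join(root, synset, fname), "wb").close()
        print(build_class_index(root))  # {'n01440764': 0, 'n01443537': 1}
        print([label for _, label in list_samples(root)])  # [0, 0, 1]
```

Because labels follow the sorted folder names, the same synset always receives the same index, regardless of the order in which the folders were created on disk.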
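`RandomCropDecodeResize` decodes the JPEG, crops a random region and resizes that region to 224×224. Here `scale=(0.08, 1.0)` bounds the fraction of the image area kept and `ratio=(0.75, 1.333)` bounds the crop's width/height ratio. The sketch below shows the common Inception-style sampling rule under that reading; it is an illustration, not MindSpore's actual implementation, and the attempt count of 10 and the center-crop fallback are assumptions.

```python
import math
import random


def sample_crop(height, width, scale=(0.08, 1.0), ratio=(0.75, 1.333),
                attempts=10, rng=random):
    """Return (top, left, crop_h, crop_w) for an Inception-style random crop.

    Sketch only: the attempt count and the fallback are assumptions.
    """
    area = height * width
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(attempts):
        # Fraction of the original area to keep.
        target_area = area * rng.uniform(scale[0], scale[1])
        # Aspect ratio (w / h) sampled log-uniformly, so 3:4 and 4:3 are equally likely.
        aspect = math.exp(rng.uniform(log_lo, log_hi))
        crop_w = int(round(math.sqrt(target_area * aspect)))
        crop_h = int(round(math.sqrt(target_area / aspect)))
        if 0 < crop_w <= width and 0 < crop_h <= height:
            top = rng.randint(0, height - crop_h)
            left = rng.randint(0, width - crop_w)
            return top, left, crop_h, crop_w
    # Fallback: central crop, clamped to the allowed aspect-ratio range.
    in_ratio = width / height
    if in_ratio < ratio[0]:
        crop_w, crop_h = width, int(round(width / ratio[0]))
    elif in_ratio > ratio[1]:
        crop_h, crop_w = height, int(round(height * ratio[1]))
    else:
        crop_h, crop_w = height, width
    return (height - crop_h) // 2, (width - crop_w) // 2, crop_h, crop_w


rng = random.Random(0)  # seeded so the sampled crop is reproducible
top, left, crop_h, crop_w = sample_crop(375, 500, rng=rng)
print(top, left, crop_h, crop_w)
```

The chosen region is then resized to `size=224`, so a crop covering only a few percent of the image is strongly upscaled; this is what makes the augmentation aggressive at the low end of `scale`.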
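The `mean` and `std` lists are the usual ImageNet channel statistics (0.485/0.456/0.406 and 0.229/0.224/0.225) multiplied by 255, because the decoded image is still in the 0 to 255 range when `Normalize` runs. `Normalize` computes `(pixel - mean[c]) / std[c]` for each channel `c`, and `HWC2CHW` then moves the channel axis to the front. The pure-Python sketch below reproduces both steps on a tiny nested-list image to make the arithmetic concrete; `normalize_hwc` and `hwc_to_chw` are hypothetical helper names.

```python
MEAN = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STD = [0.229 * 255, 0.224 * 255, 0.225 * 255]


def normalize_hwc(image, mean=MEAN, std=STD):
    """Apply (pixel - mean[c]) / std[c] to an H x W x C nested list."""
    return [[[(px[c] - mean[c]) / std[c] for c in range(len(px))]
             for px in row]
            for row in image]


def hwc_to_chw(image):
    """Reorder an H x W x C nested list into C x H x W."""
    h, w, c = len(image), len(image[0]), len(image[0][0])
    return [[[image[y][x][ch] for x in range(w)] for y in range(h)]
            for ch in range(c)]


# A hypothetical 1 x 2 RGB image: one white pixel, one black pixel.
img = [[[255, 255, 255], [0, 0, 0]]]
chw = hwc_to_chw(normalize_hwc(img))
print(len(chw), len(chw[0]), len(chw[0][0]))  # 3 1 2
print(round(chw[0][0][0], 4))  # 2.2489 (red channel of the white pixel)
```

For the white pixel's red channel this is (255 − 123.675) / 58.395 ≈ 2.2489, the same value the scaled statistics produce as (1 − 0.485) / 0.229 on a 0 to 1 image.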
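`batch(batch_size=16, drop_remainder=True)` groups consecutive rows into batches of 16 and discards a final partial batch, so an epoch over N training images yields ⌊N / 16⌋ batches. The generator below is a standard-library sketch of that grouping rule, not the MindSpore operator itself; `batched` and `num_batches` are hypothetical names, and the 1,281,167 figure assumes the mounted copy is the standard ILSVRC2012 training split.

```python
from itertools import islice


def batched(rows, batch_size, drop_remainder=True):
    """Yield lists of `batch_size` consecutive rows, optionally dropping a short tail."""
    it = iter(rows)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        if len(batch) < batch_size and drop_remainder:
            return  # incomplete final batch is discarded
        yield batch


def num_batches(n_rows, batch_size, drop_remainder=True):
    """Number of batches per epoch produced by the rule above."""
    full, rest = divmod(n_rows, batch_size)
    return full + (0 if drop_remainder or rest == 0 else 1)


print([len(b) for b in batched(range(40), 16)])  # [16, 16]
print([len(b) for b in batched(range(40), 16, drop_remainder=False)])  # [16, 16, 8]
print(num_batches(1281167, 16))  # 80072
```

Dropping the remainder keeps every batch the same shape (16 × 3 × 224 × 224 after `HWC2CHW`), at the cost of skipping up to 15 images per epoch; with `shuffle=True` a different set of images is skipped each time.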
