read_imagenet.py 2.7 kB

'''
The imagenet-1k dataset is mounted into the training image as a disk mount.
Users can refer to the sample code below to read it and use it directly.
The mount path is:
.
└── cache/
    ├── ascend
    ├── outputs
    ├── user-job-dir
    └── sfs/
        └── data/
            └── imagenet/
                ├── train/
                │   └── n01440764/
                │       └── n01440764_11063.JPEG
                └── val/
                    └── n01440764/
                        └── ILSVRC2012_val_00011993.JPEG

mindspore.dataset.ImageFolderDataset
    - Reads the imagenet-1k data; images in the same folder belong to the same class.
mindspore.dataset.vision.c_transforms
    - Data loading and preprocessing.
mindspore.dataset.ImageFolderDataset
    - map: given a list of data augmentations, applies them to the dataset object in order.
    - batch: combines batch_size consecutive rows of the dataset into one batch.
    - to_json: serializes the data processing pipeline to a JSON string; if a filename is given, it is also dumped to that file.
'''
import os
import argparse
import moxing as mox
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision.c_transforms as transforms
from c2net.context import upload_output

parser = argparse.ArgumentParser(description='Read big dataset ImageNet Example')
parser.add_argument('--train_url',
                    help='output folder to save/load',
                    default='/cache/output/')

if __name__ == "__main__":
    args, unknown = parser.parse_known_args()
    # Note: the mounted dataset path is only available in training tasks.
    data_path = '/cache/sfs/data/imagenet/'
    modelart_output = '/cache/output'
    # Create the output folder if it does not exist yet.
    os.makedirs(modelart_output, exist_ok=True)

    # Per-channel RGB mean and std, scaled to the 0-255 pixel range.
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # Each sub-folder of train/ is treated as one class.
    dataset_train = ImageFolderDataset(os.path.join(data_path, "train"),
                                       shuffle=True)
    # Augmentations are applied to the "image" column in list order.
    trans_train = [
        transforms.RandomCropDecodeResize(size=224,
                                          scale=(0.08, 1.0),
                                          ratio=(0.75, 1.333)),
        transforms.RandomHorizontalFlip(prob=0.5),
        transforms.Normalize(mean=mean, std=std),
        transforms.HWC2CHW()
    ]
    dataset_train = dataset_train.map(operations=trans_train,
                                      input_columns=["image"])
    dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

    # Serialize the pipeline description to JSON and also write it to a file.
    data_info = dataset_train.to_json(
        filename=os.path.join(modelart_output, 'data_info.json'))
    print(data_info)
    upload_output()
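
The docstring notes that images in the same folder are treated as one class. The sketch below illustrates that idea with the standard library only, without MindSpore. It assumes the default behavior in which class folders are sorted by name and numbered from 0; passing a custom `class_indexing` to `ImageFolderDataset` would change the mapping. The helper names and the demo file `n01440764_2.JPEG` are hypothetical, and the demo files are empty placeholders rather than real images.

```python
import os
import tempfile


def folder_class_index(root):
    """Map each class sub-folder of `root` to an integer label, sorted by name."""
    classes = sorted(
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    )
    return {name: idx for idx, name in enumerate(classes)}


def list_samples(root):
    """Return (file_path, label) pairs, one per file in each class folder."""
    samples = []
    for name, label in folder_class_index(root).items():
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            samples.append((os.path.join(class_dir, fname), label))
    return samples


def make_demo_tree(root):
    """Create a tiny imagenet-style tree filled with empty placeholder files."""
    layout = {
        "n01443537": ["n01443537_1.JPEG"],  # hypothetical file name
        "n01440764": ["n01440764_11063.JPEG", "n01440764_2.JPEG"],
    }
    for wnid, files in layout.items():
        os.makedirs(os.path.join(root, wnid), exist_ok=True)
        for fname in files:
            open(os.path.join(root, wnid, fname), "wb").close()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        make_demo_tree(tmp)
        print(folder_class_index(tmp))
        for path, label in list_samples(tmp):
            print(label, os.path.basename(path))
```

On the mounted dataset the same functions could be pointed at `/cache/sfs/data/imagenet/train`, although listing the full training set this way would be slow compared with the dataset loader.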

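
Two numeric details of the pipeline can be checked by hand: the `mean` and `std` values are the usual 0-1 range statistics multiplied by 255 because the decoded pixels are in the 0-255 range, and `drop_remainder=True` discards a final partial batch. The following is a minimal sketch in plain Python, not MindSpore code. The function names are hypothetical, and the sample count of 1,281,167 is the commonly cited ILSVRC2012 training set size, assumed here to match the mounted copy.

```python
MEAN_01 = [0.485, 0.456, 0.406]  # channel means in the 0-1 range
STD_01 = [0.229, 0.224, 0.225]   # channel stds in the 0-1 range

# Same scaling as in the script: statistics expressed in the 0-255 pixel range.
mean = [m * 255 for m in MEAN_01]
std = [s * 255 for s in STD_01]


def normalize_pixel(rgb, mean, std):
    """Apply (value - mean) / std per channel to one RGB pixel."""
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]


def batches_per_epoch(num_samples, batch_size, drop_remainder=True):
    """Count batches per epoch, optionally dropping a final partial batch."""
    full, rest = divmod(num_samples, batch_size)
    if drop_remainder or rest == 0:
        return full
    return full + 1


if __name__ == "__main__":
    # A white pixel in the 0-255 range, normalized with the scaled statistics.
    print(normalize_pixel([255, 255, 255], mean, std))
    # Assumed training set size, batch size 16 as in the script.
    print(batches_per_epoch(1281167, 16))
    print(batches_per_epoch(1281167, 16, drop_remainder=False))
```

Under the assumed sample count, a batch size of 16 leaves 15 images over, so dropping the remainder gives one batch fewer per epoch than keeping it.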