You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset_distributed.py 2.5 kB

3 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. """
  2. Produce the dataset:
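  3. Unlike the single-device case, the dataset interface must be given the num_shards and shard_id arguments, which correspond to the number of devices and the logical index of the current device respectively. It is recommended to obtain them through the HCCL interface:
  4. get_rank: get the ID of the current device in the cluster.
  5. get_group_size: get the number of devices in the cluster.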
  6. """
  7. import mindspore.dataset as ds
  8. import mindspore.dataset.vision.c_transforms as CV
  9. import mindspore.dataset.transforms.c_transforms as C
  10. from mindspore.dataset.vision import Inter
  11. from mindspore.common import dtype as mstype
  12. from mindspore.communication.management import init, get_rank, get_group_size
  13. def create_dataset_parallel(data_path, batch_size=32, repeat_size=1,
  14. num_parallel_workers=1, shard_id=0, num_shards=8):
  15. """
  16. create dataset for train or test
  17. """
  18. resize_height, resize_width = 32, 32
  19. rescale = 1.0 / 255.0
  20. shift = 0.0
  21. rescale_nml = 1 / 0.3081
  22. shift_nml = -1 * 0.1307 / 0.3081
  23. # get shard_id and num_shards.Get the ID of the current device in the cluster And Get the number of clusters.
  24. shard_id = get_rank()
  25. num_shards = get_group_size()
  26. # define dataset
  27. mnist_ds = ds.MnistDataset(data_path, num_shards=num_shards, shard_id=shard_id)
  28. # define map operations
  29. resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
  30. rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
  31. rescale_op = CV.Rescale(rescale, shift)
  32. hwc2chw_op = CV.HWC2CHW()
  33. type_cast_op = C.TypeCast(mstype.int32)
  34. # apply map operations on images
  35. mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
  36. mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
  37. mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
  38. mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
  39. mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
  40. # apply DatasetOps
  41. buffer_size = 10000
  42. mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
  43. mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
  44. mnist_ds = mnist_ds.repeat(repeat_size)
  45. return mnist_ds