You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.dataset.NumpySlicesDataset.rst 4.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. Class mindspore.dataset.NumpySlicesDataset(data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None)
  2. 由Python数据构建源数据集。
  3. 生成的数据集的列名和列类型取决于用户传入的Python数据。
  4. **参数:**
  5. - **data** (Union[list, tuple, dict]):输入的Python数据。支持的数据类型包括:list、tuple、dict和其他NumPy格式。
  6. 输入数据将沿着第一个维度切片,并生成额外的行。如果输入是单个list,则将生成一个数据列,若是嵌套多个list,则生成多个数据列。
  7. 不建议通过这种方式加载大量的数据,因为可能会在数据加载到内存时等待较长时间。
  8. - **column_names** (list[str], 可选): 指定数据集生成的列名(默认值为None)。
  9. 如果未指定列名称,且当输入数据的类型是dict时,输出列名称将被命名为dict的键名,否则它们将被命名为column_0,column_1...。
  10. - **num_samples** (int, 可选): 指定从数据集中读取的样本数(默认值为None,所有样本)。
  11. - **num_parallel_workers** (int, 可选): 指定读取数据的工作线程数(默认值为1)。
  12. - **shuffle** (bool, 可选): 是否混洗数据集。只有输入的`data`参数带有可随机访问属性(__getitem__)时,才可以指定该参数。(默认值为None,下表中会展示不同配置的预期行为)。
  13. - **sampler** (Union[Sampler, Iterable], 可选): 指定从数据集中选取样本的采样器。只有输入的`data`参数带有可随机访问属性(__getitem__)时,才可以指定该参数(默认值为None,下表中会展示不同配置的预期行为)。
  14. - **num_shards** (int, 可选): 分布式训练时,将数据集划分成指定的分片数(默认值None)。指定此参数后,`num_samples` 表示每个分片的最大样本数。需要输入`data`支持可随机访问才能指定该参数。
  15. - **shard_id** (int, 可选): 分布式训练时,指定使用的分片ID号(默认值None)。只有当指定了 `num_shards` 时才能指定此参数。
  16. **注:**
  17. - 此数据集可以指定`sampler`参数,但`sampler` 和 `shuffle` 是互斥的。下表展示了几种合法的输入参数及预期的行为。
  18. .. list-table:: 配置`sampler`和`shuffle`的不同组合得到的预期排序结果
  19. :widths: 25 25 50
  20. :header-rows: 1
  21. * - 参数`sampler`
  22. - 参数`shuffle`
  23. - 预期数据顺序
  24. * - None
  25. - None
  26. - 随机排列
  27. * - None
  28. - True
  29. - 随机排列
  30. * - None
  31. - False
  32. - 顺序排列
  33. * - 参数`sampler`
  34. - None
  35. - 由`sampler`行为定义的顺序
  36. * - 参数`sampler`
  37. - True
  38. - 不允许
  39. * - 参数`sampler`
  40. - False
  41. - 不允许
  42. **异常:**
  43. - **RuntimeError**: column_names列表的长度与数据的输出列表长度不匹配。
  44. - **RuntimeError**: num_parallel_workers超过系统最大线程数。
  45. - **RuntimeError**: 同时指定了sampler和shuffle。
  46. - **RuntimeError**: 同时指定了sampler和num_shards。
  47. - **RuntimeError**: 指定了`num_shards`参数,但是未指定`shard_id`参数。
  48. - **RuntimeError**: 指定了`shard_id`参数,但是未指定`num_shards`参数。
  49. - **ValueError**: `shard_id`参数错误(小于0或者大于等于 `num_shards`)。
  50. **示例:**
  51. >>> # 1) 输入的`data`参数类型为list
  52. >>> data = [1, 2, 3]
  53. >>> dataset = ds.NumpySlicesDataset(data=data, column_names=["column_1"])
  54. >>>
  55. >>> # 2) 输入的`data`参数类型为dict,并且使用column_names的默认行为,即采用键名作为生成列名。
  56. >>> data = {"a": [1, 2], "b": [3, 4]}
  57. >>> dataset = ds.NumpySlicesDataset(data=data)
  58. >>>
  59. >>> # 3) 输入的`data`参数类型是由list组成的tuple(或NumPy数组),每个元组分别生成一个输出列,共三个输出列
  60. >>> data = ([1, 2], [3, 4], [5, 6])
  61. >>> dataset = ds.NumpySlicesDataset(data=data, column_names=["column_1", "column_2", "column_3"])
  62. >>>
  63. >>> # 4) 从CSV文件加载数据
  64. >>> import pandas as pd
  65. >>> df = pd.read_csv(filepath_or_buffer=csv_dataset_dir[0])
  66. >>> dataset = ds.NumpySlicesDataset(data=dict(df), shuffle=False)