|
- Class mindspore.dataset.NumpySlicesDataset(data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None)
-
- 由Python数据构建源数据集。
- 生成的数据集的列名和列类型取决于用户传入的Python数据。
-
- **参数:**
- - **data** (Union[list, tuple, dict]):输入的Python数据。支持的数据类型包括:list、tuple、dict和其他NumPy格式。
- 输入数据将沿着第一个维度切片,并生成额外的行。如果输入是单个list,则将生成一个数据列,若是嵌套多个list,则生成多个数据列。
- 不建议通过这种方式加载大量的数据,因为可能会在数据加载到内存时等待较长时间。
- - **column_names** (list[str], 可选): 指定数据集生成的列名(默认值为None)。
- 如果未指定列名称,且当输入数据的类型是dict时,输出列名称将被命名为dict的键名,否则它们将被命名为column_0,column_1...。
- - **num_samples** (int, 可选): 指定从数据集中读取的样本数(默认值为None,所有样本)。
- - **num_parallel_workers** (int, 可选): 指定读取数据的工作线程数(默认值为1)。
- - **shuffle** (bool, 可选): 是否混洗数据集。只有输入的`data`参数带有可随机访问属性(__getitem__)时,才可以指定该参数。(默认值为None,下表中会展示不同配置的预期行为)。
- - **sampler** (Union[Sampler, Iterable], 可选): 指定从数据集中选取样本的采样器。只有输入的`data`参数带有可随机访问属性(__getitem__)时,才可以指定该参数(默认值为None,下表中会展示不同配置的预期行为)。
- - **num_shards** (int, 可选): 分布式训练时,将数据集划分成指定的分片数(默认值None)。指定此参数后,`num_samples` 表示每个分片的最大样本数。需要输入`data`支持可随机访问才能指定该参数。
- - **shard_id** (int, 可选): 分布式训练时,指定使用的分片ID号(默认值None)。只有当指定了 `num_shards` 时才能指定此参数。
-
- **注:**
- - 此数据集可以指定`sampler`参数,但`sampler` 和 `shuffle` 是互斥的。下表展示了几种合法的输入参数及预期的行为。
-
- .. list-table:: 配置`sampler`和`shuffle`的不同组合得到的预期排序结果
- :widths: 25 25 50
- :header-rows: 1
-
- * - 参数`sampler`
- - 参数`shuffle`
- - 预期数据顺序
- * - None
- - None
- - 随机排列
- * - None
- - True
- - 随机排列
- * - None
- - False
- - 顺序排列
- * - 参数`sampler`
- - None
- - 由`sampler`行为定义的顺序
- * - 参数`sampler`
- - True
- - 不允许
- * - 参数`sampler`
- - False
- - 不允许
-
- **异常:**
- - **RuntimeError**: column_names列表的长度与数据的输出列表长度不匹配。
- - **RuntimeError**: num_parallel_workers超过系统最大线程数。
- - **RuntimeError**: 同时指定了sampler和shuffle。
- - **RuntimeError**: 同时指定了sampler和num_shards。
- - **RuntimeError**: 指定了`num_shards`参数,但是未指定`shard_id`参数。
- - **RuntimeError**: 指定了`shard_id`参数,但是未指定`num_shards`参数。
- - **ValueError**: `shard_id`参数错误(小于0或者大于等于 `num_shards`)。
-
- **示例:**
- >>> # 1) 输入的`data`参数类型为list
- >>> data = [1, 2, 3]
- >>> dataset = ds.NumpySlicesDataset(data=data, column_names=["column_1"])
- >>>
- >>> # 2) 输入的`data`参数类型为dict,并且使用column_names的默认行为,即采用键名作为生成列名。
- >>> data = {"a": [1, 2], "b": [3, 4]}
- >>> dataset = ds.NumpySlicesDataset(data=data)
- >>>
- >>> # 3) 输入的`data`参数类型是由list组成的tuple(或NumPy数组),每个元组分别生成一个输出列,共三个输出列
- >>> data = ([1, 2], [3, 4], [5, 6])
- >>> dataset = ds.NumpySlicesDataset(data=data, column_names=["column_1", "column_2", "column_3"])
- >>>
- >>> # 4) 从CSV文件加载数据
- >>> import pandas as pd
- >>> df = pd.read_csv(filepath_or_buffer=csv_dataset_dir[0])
- >>> dataset = ds.NumpySlicesDataset(data=dict(df), shuffle=False)
|