
fix error examples in docs of dataset

tags/v1.2.0-rc1
Xiao Tianci 4 years ago
commit cd3206c26c
4 changed files with 94 additions and 65 deletions
  1. +68 -38  mindspore/dataset/engine/datasets.py
  2. +21 -23  mindspore/dataset/engine/graphdata.py
  3. +2  -2   mindspore/dataset/text/transforms.py
  4. +3  -2   mindspore/dataset/vision/c_transforms.py

+68 -38  mindspore/dataset/engine/datasets.py

@@ -507,7 +507,7 @@ class Dataset:


Examples:
>>> # use NumpySliceDataset as an example
- >>> dataset = ds.NumpySliceDataset([[0, 1], [2, 3]])
+ >>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]])
>>>
>>> def flat_map_func(array):
...     # create a NumpySliceDataset with the array
@@ -638,7 +638,7 @@ class Dataset:
>>>
>>> # Rename the column outputted by random_jitter_op to "image_mapped".
>>> # Specifying column order works in the same way as examples in 1).
- >>> dataset = dataset.map(operation=[decode_op, random_jitter_op], input_columns=["image"],
+ >>> dataset = dataset.map(operations=[decode_op, random_jitter_op], input_columns=["image"],
...                       output_columns=["image_mapped"])
>>>
>>> # Map with multiple operations using pyfunc. Renaming columns and specifying column order
@@ -2878,16 +2878,18 @@ class ImageFolderDataset(MappableDataset):
ValueError: If shard_id is invalid (< 0 or >= num_shards).

Examples:
+ >>> image_folder_dataset_dir = "/path/to/image_folder_dataset_directory"
+ >>>
>>> # 1) Read all samples (image files) in image_folder_dataset_dir with 8 threads
- >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
+ >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir,
...                                 num_parallel_workers=8)
>>>
>>> # 2) Read all samples (image files) from folder cat and folder dog with label 0 and 1
- >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
+ >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir,
...                                 class_indexing={"cat":0, "dog":1})
>>>
>>> # 3) Read all samples (image files) in image_folder_dataset_dir with extensions .JPEG and .png (case sensitive)
- >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
+ >>> dataset = ds.ImageFolderDataset(dataset_dir=image_folder_dataset_dir,
...                                 extensions=[".JPEG", ".png"])
"""


@@ -2985,8 +2987,11 @@ class MnistDataset(MappableDataset):
ValueError: If shard_id is invalid (< 0 or >= num_shards).

Examples:
+ >>> mnist_dataset_dir = "/path/to/mnist_dataset_directory"
+ >>>
>>> # Read 3 samples from MNIST dataset
>>> dataset = ds.MnistDataset(dataset_dir=mnist_dataset_dir, num_samples=3)
+ >>>
>>> # Note: In mnist_dataset dataset, each dictionary has keys "image" and "label"
"""


@@ -3034,6 +3039,10 @@ class MindDataset(MappableDataset):
Raises:
ValueError: If num_shards is specified but shard_id is None.
ValueError: If shard_id is specified but num_shards is None.

+ Examples:
+ >>> mind_dataset_dir = ["/path/to/mind_dataset_file"] # contains 1 or multiple MindRecord files
+ >>> dataset = ds.MindDataset(dataset_file=mind_dataset_dir)
"""


def parse(self, children=None):
@@ -3425,14 +3434,14 @@ class GeneratorDataset(MappableDataset):
...     for i in range(64):
...         yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
>>>
- >>> dataset = ds.GeneratorDataset(generator_multidimensional, ["multi_dimensional_data"])
+ >>> dataset = ds.GeneratorDataset(source=generator_multidimensional, column_names=["multi_dimensional_data"])
>>>
>>> # 2) Multi-column generator function as callable input.
>>> def generator_multi_column():
...     for i in range(64):
- ...         yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])
+ ...         yield np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])
>>>
- >>> dataset = ds.GeneratorDataset(generator_multi_column, ["col1", "col2"])
+ >>> dataset = ds.GeneratorDataset(source=generator_multi_column, column_names=["col1", "col2"])
>>>
>>> # 3) Iterable dataset as iterable input.
>>> class MyIterable:
@@ -3450,12 +3459,13 @@ class GeneratorDataset(MappableDataset):
...         return item
...
...     def __iter__(self):
+ ...         self._index = 0
...         return self
...
...     def __len__(self):
...         return len(self._data)
>>>
- >>> dataset = ds.GeneratorDataset(MyIterable(), ["data", "label"])
+ >>> dataset = ds.GeneratorDataset(source=MyIterable(), column_names=["data", "label"])
>>>
>>> # 4) Random accessible dataset as random accessible input.
>>> class MyAccessible:
@@ -3469,10 +3479,10 @@ class GeneratorDataset(MappableDataset):
...     def __len__(self):
...         return len(self._data)
>>>
- >>> dataset = ds.GeneratorDataset(MyAccessible(), ["data", "label"])
+ >>> dataset = ds.GeneratorDataset(source=MyAccessible(), column_names=["data", "label"])
>>>
>>> # list, dict, tuple of Python is also random accessible
- >>> dataset = ds.GeneratorDataset([(np.array(0),), (np.array(1),), (np.array(2),)], ["col"])
+ >>> dataset = ds.GeneratorDataset(source=[(np.array(0),), (np.array(1),), (np.array(2),)], column_names=["col"])
"""


@check_generatordataset
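
For readers trying the corrected GeneratorDataset examples outside the docstrings, here is a minimal runnable sketch of the keyword-argument style the new lines use. The column names and the create_dict_iterator(output_numpy=True) call are illustrative assumptions, not part of this commit.

import numpy as np
import mindspore.dataset as ds

def generator_multi_column():
    # yield one row per iteration: two columns, "col1" and "col2"
    for i in range(3):
        yield np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])

dataset = ds.GeneratorDataset(source=generator_multi_column, column_names=["col1", "col2"])
for row in dataset.create_dict_iterator(output_numpy=True):
    print(row["col1"], row["col2"])
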
@@ -3590,17 +3600,17 @@ class TFRecordDataset(SourceDataset):
Examples:
>>> import mindspore.common.dtype as mstype
>>>
- >>> tfrecord_dataset_dir = ["/path/to/tfrecord_dataset_file"] # contains 1 or multiple tf data files
+ >>> tfrecord_dataset_dir = ["/path/to/tfrecord_dataset_file"] # contains 1 or multiple TFRecord files
>>> tfrecord_schema_file = "/path/to/tfrecord_schema_file"
>>>
>>> # 1) Get all rows from tfrecord_dataset_dir with no explicit schema.
>>> # The meta-data in the first row will be used as a schema.
- >>> dataset = ds.TFRecordDataset(tfrecord_dataset_dir)
+ >>> dataset = ds.TFRecordDataset(dataset_files=tfrecord_dataset_dir)
>>>
>>> # 2) Get all rows from tfrecord_dataset_dir with user-defined schema.
>>> schema = ds.Schema()
- >>> schema.add_column('col_1d', de_type=mstype.int64, shape=[2])
- >>> dataset = ds.TFRecordDataset(tfrecord_dataset_dir, schema=schema)
+ >>> schema.add_column(name='col_1d', de_type=mstype.int64, shape=[2])
+ >>> dataset = ds.TFRecordDataset(dataset_files=tfrecord_dataset_dir, schema=schema)
>>>
>>> # 3) Get all rows from tfrecord_dataset_dir with schema file.
>>> dataset = ds.TFRecordDataset(dataset_files=tfrecord_dataset_dir, schema=tfrecord_schema_file)
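
As a cross-check of the corrected Schema and TFRecordDataset calls above, a small sketch that builds the schema with keyword arguments; the file path is a placeholder, as in the docstring, so this is a usage sketch rather than something that will read real data.

import mindspore.common.dtype as mstype
import mindspore.dataset as ds

tfrecord_files = ["/path/to/tfrecord_dataset_file"]  # placeholder path

# keyword-argument style matching the fixed examples
schema = ds.Schema()
schema.add_column(name='col_1d', de_type=mstype.int64, shape=[2])
dataset = ds.TFRecordDataset(dataset_files=tfrecord_files, schema=schema)
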
@@ -3697,13 +3707,13 @@ class ManifestDataset(MappableDataset):
ValueError: If shard_id is invalid (< 0 or >= num_shards).

Examples:
+ >>> manifest_dataset_dir = "/path/to/manifest_dataset_file"
+ >>>
>>> # 1) Read all samples specified in manifest_dataset_dir dataset with 8 threads for training
- >>> dataset = ds.ManifestDataset(manifest_dataset_dir, usage="train", num_parallel_workers=8)
+ >>> dataset = ds.ManifestDataset(dataset_file=manifest_dataset_dir, usage="train", num_parallel_workers=8)
>>>
- >>> # 2) Read samples (specified in manifest_file.manifest) for shard 0
- >>> # in a 2-way distributed training setup
- >>> dataset = ds.ManifestDataset(manifest_dataset_dir, num_shards=2, shard_id=0)
+ >>> # 2) Read samples (specified in manifest_file.manifest) for shard 0 in a 2-way distributed training setup
+ >>> dataset = ds.ManifestDataset(dataset_file=manifest_dataset_dir, num_shards=2, shard_id=0)
"""


@check_manifestdataset
@@ -3815,6 +3825,8 @@ class Cifar10Dataset(MappableDataset):
ValueError: If shard_id is invalid (< 0 or >= num_shards).

Examples:
+ >>> cifar10_dataset_dir = "/path/to/cifar10_dataset_directory"
+ >>>
>>> # 1) Get all samples from CIFAR10 dataset in sequence
>>> dataset = ds.Cifar10Dataset(dataset_dir=cifar10_dataset_dir, shuffle=False)
>>>
@@ -3920,6 +3932,8 @@ class Cifar100Dataset(MappableDataset):
ValueError: If shard_id is invalid (< 0 or >= num_shards).

Examples:
+ >>> cifar100_dataset_dir = "/path/to/cifar100_dataset_directory"
+ >>>
>>> # 1) Get all samples from CIFAR100 dataset in sequence
>>> dataset = ds.Cifar100Dataset(dataset_dir=cifar100_dataset_dir, shuffle=False)
>>>
@@ -3999,7 +4013,7 @@ class Schema:
>>>
>>> # Create schema; specify column name, mindspore.dtype and shape of the column
>>> schema = ds.Schema()
- >>> schema.add_column('col1', de_type=mstype.int64, shape=[2])
+ >>> schema.add_column(name='col1', de_type=mstype.int64, shape=[2])
"""


@check_schema
@@ -4190,17 +4204,21 @@ class VOCDataset(MappableDataset):
ValueError: If shard_id is invalid (< 0 or >= num_shards).

Examples:
+ >>> voc_dataset_dir = "/path/to/voc_dataset_directory"
+ >>>
>>> # 1) Read VOC data for segmentatation training
- >>> dataset = ds.VOCDataset(voc_dataset_dir, task="Segmentation", usage="train")
+ >>> dataset = ds.VOCDataset(dataset_dir=voc_dataset_dir, task="Segmentation", usage="train")
>>>
>>> # 2) Read VOC data for detection training
- >>> dataset = ds.VOCDataset(voc_dataset_dir, task="Detection", usage="train")
+ >>> dataset = ds.VOCDataset(dataset_dir=voc_dataset_dir, task="Detection", usage="train")
>>>
>>> # 3) Read all VOC dataset samples in voc_dataset_dir with 8 threads in random order
- >>> dataset = ds.VOCDataset(voc_dataset_dir, task="Detection", usage="train", num_parallel_workers=8)
+ >>> dataset = ds.VOCDataset(dataset_dir=voc_dataset_dir, task="Detection", usage="train",
+ ...                         num_parallel_workers=8)
>>>
>>> # 4) Read then decode all VOC dataset samples in voc_dataset_dir in sequence
- >>> dataset = ds.VOCDataset(voc_dataset_dir, task="Detection", usage="train", decode=True, shuffle=False)
+ >>> dataset = ds.VOCDataset(dataset_dir=voc_dataset_dir, task="Detection", usage="train",
+ ...                         decode=True, shuffle=False)
>>>
>>> # In VOC dataset, if task='Segmentation', each dictionary has keys "image" and "target"
>>> # In VOC dataset, if task='Detection', each dictionary has keys "image" and "annotation"
@@ -4344,17 +4362,28 @@ class CocoDataset(MappableDataset):
ValueError: If shard_id is invalid (< 0 or >= num_shards).

Examples:
+ >>> coco_dataset_dir = "/path/to/coco_dataset_directory/images"
+ >>> coco_annotation_file = "/path/to/coco_dataset_directory/annotation_file"
+ >>>
>>> # 1) Read COCO data for Detection task
- >>> dataset = ds.CocoDataset(coco_dataset_dir, annotation_file=coco_annotation_file, task='Detection')
+ >>> dataset = ds.CocoDataset(dataset_dir=coco_dataset_dir,
+ ...                          annotation_file=coco_annotation_file,
+ ...                          task='Detection')
>>>
>>> # 2) Read COCO data for Stuff task
- >>> dataset = ds.CocoDataset(coco_dataset_dir, annotation_file=coco_annotation_file, task='Stuff')
+ >>> dataset = ds.CocoDataset(dataset_dir=coco_dataset_dir,
+ ...                          annotation_file=coco_annotation_file,
+ ...                          task='Stuff')
>>>
>>> # 3) Read COCO data for Panoptic task
- >>> dataset = ds.CocoDataset(coco_dataset_dir, annotation_file=coco_annotation_file, task='Panoptic')
+ >>> dataset = ds.CocoDataset(dataset_dir=coco_dataset_dir,
+ ...                          annotation_file=coco_annotation_file,
+ ...                          task='Panoptic')
>>>
>>> # 4) Read COCO data for Keypoint task
- >>> dataset = ds.CocoDataset(coco_dataset_dir, annotation_file=coco_annotation_file, task='Keypoint')
+ >>> dataset = ds.CocoDataset(dataset_dir=coco_dataset_dir,
+ ...                          annotation_file=coco_annotation_file,
+ ...                          task='Keypoint')
>>>
>>> # In COCO dataset, each dictionary has keys "image" and "annotation"
"""
@@ -4445,6 +4474,7 @@ class CelebADataset(MappableDataset):
(default=None, which means no cache is used).

Examples:
+ >>> celeba_dataset_dir = "/path/to/celeba_dataset_directory"
>>> dataset = ds.CelebADataset(dataset_dir=celeba_dataset_dir, usage='train')
"""


@@ -4519,7 +4549,7 @@ class CLUEDataset(SourceDataset):
(default=None, which means no cache is used).

Examples:
- >>> clue_dataset_dir = ["/path/to/clue_dataset_file"] # contains 1 or multiple text files
+ >>> clue_dataset_dir = ["/path/to/clue_dataset_file"] # contains 1 or multiple clue files
>>> dataset = ds.CLUEDataset(dataset_files=clue_dataset_dir, task='AFQMC', usage='train')
"""


@@ -4647,7 +4677,7 @@ class CSVDataset(SourceDataset):




Examples:
- >>> csv_dataset_dir = ["/path/to/csv_dataset_file"]
+ >>> csv_dataset_dir = ["/path/to/csv_dataset_file"] # contains 1 or multiple csv files
>>> dataset = ds.CSVDataset(dataset_files=csv_dataset_dir, column_names=['col1', 'col2', 'col3', 'col4'])
"""


@@ -4696,7 +4726,7 @@ class TextFileDataset(SourceDataset):
(default=None, which means no cache is used).

Examples:
- >>> # contains 1 or multiple text files
+ >>> text_file_dataset_dir = ["/path/to/text_file_dataset_file"] # contains 1 or multiple text files
>>> dataset = ds.TextFileDataset(dataset_files=text_file_dataset_dir)
"""


@@ -4833,20 +4863,20 @@ class NumpySlicesDataset(GeneratorDataset):
Examples:
>>> # 1) Input data can be a list
>>> data = [1, 2, 3]
- >>> dataset = ds.NumpySlicesDataset(data, column_names=["column_1"])
+ >>> dataset = ds.NumpySlicesDataset(data=data, column_names=["column_1"])
>>>
>>> # 2) Input data can be a dictionary, and column_names will be its keys
>>> data = {"a": [1, 2], "b": [3, 4]}
- >>> dataset = ds.NumpySlicesDataset(data)
+ >>> dataset = ds.NumpySlicesDataset(data=data)
>>>
>>> # 3) Input data can be a tuple of lists (or NumPy arrays), each tuple element refers to data in each column
>>> data = ([1, 2], [3, 4], [5, 6])
- >>> dataset = ds.NumpySlicesDataset(data, column_names=["column_1", "column_2", "column_3"])
+ >>> dataset = ds.NumpySlicesDataset(data=data, column_names=["column_1", "column_2", "column_3"])
>>>
>>> # 4) Load data from CSV file
>>> import pandas as pd
- >>> df = pd.read_csv(csv_dataset_dir[0])
- >>> dataset = ds.NumpySlicesDataset(dict(df), shuffle=False)
+ >>> df = pd.read_csv(filepath_or_buffer=csv_dataset_dir[0])
+ >>> dataset = ds.NumpySlicesDataset(data=dict(df), shuffle=False)
"""


@check_numpyslicesdataset
@@ -4893,7 +4923,7 @@ class PaddedDataset(GeneratorDataset):
Examples:
>>> import numpy as np
>>> data = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)}]
- >>> dataset = ds.PaddedDataset(data)
+ >>> dataset = ds.PaddedDataset(padded_samples=data)
"""


@check_paddeddataset
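
The NumpySlicesDataset fix above is easy to verify locally because it needs no files on disk. A minimal sketch of the corrected keyword-argument call; the iteration via create_dict_iterator(output_numpy=True) is an assumption about the standard dataset API, not part of this commit.

import numpy as np
import mindspore.dataset as ds

data = {"a": np.array([1, 2]), "b": np.array([3, 4])}  # dict keys become column names
dataset = ds.NumpySlicesDataset(data=data, shuffle=False)
for row in dataset.create_dict_iterator(output_numpy=True):
    print(row["a"], row["b"])
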


+21 -23  mindspore/dataset/engine/graphdata.py

@@ -72,9 +72,10 @@ class GraphData:
the server automatically exits (default=True).

Examples:
- >>> graph_dataset = ds.GraphData(graph_dataset_dir, 2)
- >>> nodes = graph_dataset.get_all_nodes(1)
- >>> features = graph_dataset.get_node_feature(nodes, [1])
+ >>> graph_dataset_dir = "/path/to/graph_dataset_file"
+ >>> graph_dataset = ds.GraphData(dataset_file=graph_dataset_dir, num_parallel_workers=2)
+ >>> nodes = graph_dataset.get_all_nodes(node_type=1)
+ >>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=[1])
"""


@check_gnn_graphdata
@@ -114,7 +115,7 @@ class GraphData:
numpy.ndarray, array of nodes.

Examples:
- >>> nodes = graph_dataset.get_all_nodes(1)
+ >>> nodes = graph_dataset.get_all_nodes(node_type=1)

Raises:
TypeError: If `node_type` is not integer.
@@ -135,7 +136,7 @@ class GraphData:
numpy.ndarray, array of edges.

Examples:
- >>> edges = graph_dataset.get_all_edges(0)
+ >>> edges = graph_dataset.get_all_edges(edge_type=0)

Raises:
TypeError: If `edge_type` is not integer.
@@ -175,8 +176,8 @@ class GraphData:
numpy.ndarray, array of neighbors.

Examples:
- >>> nodes = graph_dataset.get_all_nodes(1)
- >>> neighbors = graph_dataset.get_all_neighbors(nodes, 2)
+ >>> nodes = graph_dataset.get_all_nodes(node_type=1)
+ >>> neighbors = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2)

Raises:
TypeError: If `node_list` is not list or ndarray.
@@ -211,8 +212,9 @@ class GraphData:
numpy.ndarray, array of neighbors.

Examples:
- >>> nodes = graph_dataset.get_all_nodes(1)
- >>> neighbors = graph_dataset.get_sampled_neighbors(nodes, [2, 2], [2, 1])
+ >>> nodes = graph_dataset.get_all_nodes(node_type=1)
+ >>> neighbors = graph_dataset.get_sampled_neighbors(node_list=nodes, neighbor_nums=[2, 2],
+ ...                                                 neighbor_types=[2, 1])

Raises:
TypeError: If `node_list` is not list or ndarray.
@@ -240,8 +242,9 @@ class GraphData:
numpy.ndarray, array of neighbors.

Examples:
- >>> nodes = graph_dataset.get_all_nodes(1)
- >>> neg_neighbors = graph_dataset.get_neg_sampled_neighbors(nodes, 5, 2)
+ >>> nodes = graph_dataset.get_all_nodes(node_type=1)
+ >>> neg_neighbors = graph_dataset.get_neg_sampled_neighbors(node_list=nodes, neg_neighbor_num=5,
+ ...                                                         neg_neighbor_type=2)

Raises:
TypeError: If `node_list` is not list or ndarray.
@@ -266,8 +269,8 @@ class GraphData:
numpy.ndarray, array of features.

Examples:
- >>> nodes = graph_dataset.get_all_nodes(1)
- >>> features = graph_dataset.get_node_feature(nodes, [2, 3])
+ >>> nodes = graph_dataset.get_all_nodes(node_type=1)
+ >>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=[2, 3])

Raises:
TypeError: If `node_list` is not list or ndarray.
@@ -295,8 +298,8 @@ class GraphData:
numpy.ndarray, array of features.

Examples:
- >>> edges = graph_dataset.get_all_edges(0)
- >>> features = graph_dataset.get_edge_feature(edges, [1])
+ >>> edges = graph_dataset.get_all_edges(edge_type=0)
+ >>> features = graph_dataset.get_edge_feature(edge_list=edges, feature_types=[1])

Raises:
TypeError: If `edge_list` is not list or ndarray.
@@ -325,13 +328,7 @@ class GraphData:
return self._graph_data.graph_info()

@check_gnn_random_walk
- def random_walk(
-     self,
-     target_nodes,
-     meta_path,
-     step_home_param=1.0,
-     step_away_param=1.0,
-     default_node=-1):
+ def random_walk(self, target_nodes, meta_path, step_home_param=1.0, step_away_param=1.0, default_node=-1):
"""
Random walk in nodes.


@@ -347,7 +344,8 @@ class GraphData:
numpy.ndarray, array of nodes.

Examples:
- >>> nodes = graph_dataset.random_walk([1, 2], [1, 2, 1, 2, 1])
+ >>> nodes = graph_dataset.get_all_nodes(node_type=1)
+ >>> walks = graph_dataset.random_walk(target_nodes=nodes, meta_path=[2, 1, 2])

Raises:
TypeError: If `target_nodes` is not list or ndarray.
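
Taken together, the graphdata.py fixes switch every example to keyword arguments. A hedged end-to-end sketch of that style; the dataset file path, node/edge types, and feature IDs below are placeholders, and a real graph file is required for the calls to return data.

import mindspore.dataset as ds

graph_dataset_dir = "/path/to/graph_dataset_file"  # placeholder path
graph = ds.GraphData(dataset_file=graph_dataset_dir, num_parallel_workers=2)

nodes = graph.get_all_nodes(node_type=1)
neighbors = graph.get_sampled_neighbors(node_list=nodes, neighbor_nums=[2, 2], neighbor_types=[2, 1])
features = graph.get_node_feature(node_list=nodes, feature_types=[1])
walks = graph.random_walk(target_nodes=nodes, meta_path=[2, 1, 2])
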


+2 -2  mindspore/dataset/text/transforms.py

@@ -459,8 +459,8 @@ class UnicodeCharTokenizer(TextTensorOperation):
>>> # ["offsets_limit", dtype=uint32]} >>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.UnicodeCharTokenizer(with_offsets=True) >>> tokenizer_op = text.UnicodeCharTokenizer(with_offsets=True)
>>> text_file_dataset = text_file_dataset.map(operations=tokenizer_op, input_columns=["text"], >>> text_file_dataset = text_file_dataset.map(operations=tokenizer_op, input_columns=["text"],
>>> output_columns=["token", "offsets_start", "offsets_limit"],
>>> column_order=["token", "offsets_start", "offsets_limit"])
... output_columns=["token", "offsets_start", "offsets_limit"],
... column_order=["token", "offsets_start", "offsets_limit"])
""" """


@check_with_offsets
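
For the tokenizer hunk, a sketch of the corrected continuation style in a complete map call. The text file path is a placeholder, and the TextFileDataset construction is assumed context rather than part of the diff.

import mindspore.dataset as ds
import mindspore.dataset.text as text

text_file_dataset = ds.TextFileDataset(dataset_files=["/path/to/text_file_dataset_file"])  # placeholder
tokenizer_op = text.UnicodeCharTokenizer(with_offsets=True)
text_file_dataset = text_file_dataset.map(operations=tokenizer_op, input_columns=["text"],
                                          output_columns=["token", "offsets_start", "offsets_limit"],
                                          column_order=["token", "offsets_start", "offsets_limit"])
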


+3 -2  mindspore/dataset/vision/c_transforms.py

@@ -62,6 +62,7 @@ class ImageTensorOperation(TensorOperation):
""" """
Base class of Image Tensor Ops Base class of Image Tensor Ops
""" """

def __call__(self, *input_tensor_list): def __call__(self, *input_tensor_list):
for tensor in input_tensor_list: for tensor in input_tensor_list:
if not isinstance(tensor, (np.ndarray, Image.Image)): if not isinstance(tensor, (np.ndarray, Image.Image)):
@@ -1142,8 +1143,8 @@ class RandomSelectSubpolicy(ImageTensorOperation):
...               (c_vision.RandomColorAdjust(), 0.8)],
...              [(c_vision.RandomRotation((90, 90)), 1),
...               (c_vision.RandomColorAdjust(), 0.2)]]
- >>> image_folder_dataset_1 = image_folder_dataset.map(operations=c_vision.RandomSelectSubpolicy(policy),
- ...                                                    input_columns=["image"])
+ >>> image_folder_dataset = image_folder_dataset.map(operations=c_vision.RandomSelectSubpolicy(policy),
+ ...                                                  input_columns=["image"])
"""


@check_random_select_subpolicy_op
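
For the last hunk, a sketch that shows the renamed variable in context. The dataset directory is a placeholder, the policy is abbreviated from the docstring, and the extra Decode map step is an assumption added here so the subpolicy operates on decoded images rather than raw file bytes.

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision

image_folder_dataset = ds.ImageFolderDataset(dataset_dir="/path/to/image_folder_dataset_directory")  # placeholder
policy = [[(c_vision.RandomRotation((45, 45)), 0.5), (c_vision.RandomColorAdjust(), 0.8)],
          [(c_vision.RandomRotation((90, 90)), 1), (c_vision.RandomColorAdjust(), 0.2)]]
image_folder_dataset = image_folder_dataset.map(operations=[c_vision.Decode()], input_columns=["image"])
image_folder_dataset = image_folder_dataset.map(operations=c_vision.RandomSelectSubpolicy(policy),
                                                input_columns=["image"])
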

