Browse Source

update API of sampler

tags/v1.3.0
shenwei41 4 years ago
parent
commit
19436375d8
6 changed files with 52 additions and 43 deletions
  1. +3
    -2
      mindspore/dataset/callback/ds_callback.py
  2. +2
    -1
      mindspore/dataset/engine/cache_client.py
  3. +2
    -2
      mindspore/dataset/engine/datasets.py
  4. +16
    -15
      mindspore/dataset/engine/samplers.py
  5. +14
    -11
      mindspore/dataset/engine/serializer_deserializer.py
  6. +15
    -12
      mindspore/dataset/text/transforms.py

+ 3
- 2
mindspore/dataset/callback/ds_callback.py View File

@@ -27,7 +27,7 @@ class DSCallback:
Abstract base class used to build a dataset callback class.

Args:
step_size (int, optional): The number of steps before the step_begin and step_end are called (Default=1).
step_size (int, optional): The number of steps between the calls of step_begin and step_end (Default=1).

Examples:
>>> class PrintInfo(DSCallback):
@@ -123,7 +123,8 @@ class WaitedDSCallback(Callback, DSCallback):
For example, one augmentation needs the loss from the previous trained epoch to update some of its parameters.

Args:
step_size: the number of rows in each step. Usually the step size will be equal to the batch size (Default=1).
step_size (int, optional): The number of rows in each step. Usually the step size
will be equal to the batch size (Default=1).

Examples:
>>> my_cb = MyWaitedCallback(32)


+ 2
- 1
mindspore/dataset/engine/cache_client.py View File

@@ -37,7 +37,8 @@ class DatasetCache:
hostname (str, optional): Host name (default=None, use default hostname '127.0.0.1').
port (int, optional): Port to connect to server (default=None, use default port 50052).
num_connections (int, optional): Number of tcp/ip connections (default=None, use default value 12).
prefetch_size (int, optional): Prefetch size (default=None, use default value 20).
prefetch_size (int, optional): The size of the cache queue between operations
(default=None, use default value 20).

Examples:
>>> import mindspore.dataset as ds


+ 2
- 2
mindspore/dataset/engine/datasets.py View File

@@ -4497,7 +4497,7 @@ class Schema:
Class to represent a schema of a dataset.

Args:
schema_file(str): Path of schema file (default=None).
schema_file(str): Path of the schema file (default=None).

Returns:
Schema object, schema info about dataset.
@@ -4524,7 +4524,7 @@ class Schema:
Add new column to the schema.

Args:
name (str): Name of the column.
name (str): The name of the new column.
de_type (str): Data type of the column.
shape (list[int], optional): Shape of the column
(default=None, [-1] which is an unknown shape of rank 1).


+ 16
- 15
mindspore/dataset/engine/samplers.py View File

@@ -121,7 +121,7 @@ class BuiltinSampler:
self.child_sampler = sampler

def get_child(self):
""" add a child sampler """
""" add a child sampler. """
return self.child_sampler

def parse_child(self):
@@ -188,7 +188,7 @@ class BuiltinSampler:
- None

Returns:
int, the number of samples, or None
int, the number of samples, or None.
"""
if self.child_sampler is not None:
child_samples = self.child_sampler.get_num_samples()
@@ -310,9 +310,9 @@ class DistributedSampler(BuiltinSampler):

Args:
num_shards (int): Number of shards to divide the dataset into.
shard_id (int): Shard ID of the current shard within num_shards.
shuffle (bool, optional): If True, the indices are shuffled (default=True).
num_samples (int, optional): The number of samples to draw (default=None, all elements).
shard_id (int): Shard ID of the current shard, which should be within the range of [0, num_shards-1].
shuffle (bool, optional): If True, the indices are shuffled; otherwise they are not shuffled (default=True).
num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements).
offset(int, optional): The starting shard ID where the elements in the dataset are sent to (default=-1), which
should be no more than num_shards.

@@ -408,11 +408,12 @@ class PKSampler(BuiltinSampler):

Args:
num_val (int): Number of elements to sample for each class.
num_class (int, optional): Number of classes to sample (default=None, all classes).
num_class (int, optional): Number of classes to sample (default=None, sample all classes).
This parameter is currently not supported.
shuffle (bool, optional): If True, the class IDs are shuffled (default=False).
shuffle (bool, optional): If True, the class IDs are shuffled; otherwise they are not
shuffled (default=False).
class_column (str, optional): Name of column with class labels for MindDataset (default='label').
num_samples (int, optional): The number of samples to draw (default=None, all elements).
num_samples (int, optional): The number of samples to draw (default=None, which means sample all elements).

Examples:
>>> # creates a PKSampler that will get 3 samples from every class.
@@ -495,7 +496,7 @@ class RandomSampler(BuiltinSampler):

Args:
replacement (bool, optional): If True, put the sample ID back for the next draw (default=False).
num_samples (int, optional): Number of elements to sample (default=None, all elements).
num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

Examples:
>>> # creates a RandomSampler
@@ -555,11 +556,11 @@ class RandomSampler(BuiltinSampler):

class SequentialSampler(BuiltinSampler):
"""
Samples the dataset elements sequentially, same as not having a sampler.
Samples the dataset elements sequentially, which is equivalent to not using a sampler.

Args:
start_index (int, optional): Index to start sampling at. (default=None, start at first ID)
num_samples (int, optional): Number of elements to sample (default=None, all elements).
num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

Examples:
>>> # creates a SequentialSampler
@@ -626,7 +627,7 @@ class SubsetSampler(BuiltinSampler):

Args:
indices (Any iterable Python object but string): A sequence of indices.
num_samples (int, optional): Number of elements to sample (default=None, all elements).
num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

Examples:
>>> indices = [0, 1, 2, 3, 4, 5]
@@ -713,7 +714,7 @@ class SubsetRandomSampler(SubsetSampler):

Args:
indices (Any iterable python object but string): A sequence of indices.
num_samples (int, optional): Number of elements to sample (default=None, all elements).
num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

Examples:
>>> indices = [0, 1, 2, 3, 7, 88, 119]
@@ -757,7 +758,7 @@ class IterSampler(Sampler):

Args:
sampler (iterable object): an user defined iterable object.
num_samples (int, optional): Number of elements to sample (default=None, all elements).
num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).

Examples:
>>> class MySampler:
@@ -788,7 +789,7 @@ class WeightedRandomSampler(BuiltinSampler):

Args:
weights (list[float, int]): A sequence of weights, not necessarily summing up to 1.
num_samples (int, optional): Number of elements to sample (default=None, all elements).
num_samples (int, optional): Number of elements to sample (default=None, which means sample all elements).
replacement (bool): If True, put the sample ID back for the next draw (default=True).

Examples:


+ 14
- 11
mindspore/dataset/engine/serializer_deserializer.py View File

@@ -29,15 +29,16 @@ def serialize(dataset, json_filepath=""):
"""
Serialize dataset pipeline into a json file.

Currently some python objects are not supported to be serialized.
For python function serialization of map operator, de.serialize will only return its function name.
Note:
Currently, some Python objects can not be serialized.
For Python function serialization in the map operator, de.serialize will only return its function name.

Args:
dataset (Dataset): the starting node.
json_filepath (str): a filepath where a serialized json file will be generated.
dataset (Dataset): The starting node.
json_filepath (str): The filepath where a serialized json file will be generated.

Returns:
dict containing the serialized dataset graph.
Dict, a dictionary containing the serialized dataset graph.

Raises:
OSError: Can not open a file
@@ -58,11 +59,12 @@ def deserialize(input_dict=None, json_filepath=None):
"""
Construct a de pipeline from a json file produced by de.serialize().

Currently python function deserialization of map operator are not supported.
Note:
Currently, Python function deserialization in the map operator is not supported.

Args:
input_dict (dict): a Python dictionary containing a serialized dataset graph
json_filepath (str): a path to the json file.
input_dict (dict): A Python dictionary containing a serialized dataset graph.
json_filepath (str): A path to the json file.

Returns:
de.Dataset or None if error occurs.
@@ -107,11 +109,12 @@ def expand_path(node_repr, key, val):

def show(dataset, indentation=2):
"""
Write the dataset pipeline graph onto logger.info.
Write the dataset pipeline graph to logger.info.

Args:
dataset (Dataset): the starting node.
indentation (int, optional): indentation used by the json print. Pass None to not indent.
dataset (Dataset): The starting node.
indentation (int, optional): The indentation used by the json print.
Do not indent if indentation is None.
"""

pipeline = dataset.to_json()


+ 15
- 12
mindspore/dataset/text/transforms.py View File

@@ -143,7 +143,7 @@ class JiebaTokenizer(TextTensorOperation):
@check_jieba_add_word
def add_word(self, word, freq=None):
"""
Add user defined word to JiebaTokenizer's dictionary.
Add a user defined word to JiebaTokenizer's dictionary.

Args:
word (str): The word to be added to the JiebaTokenizer instance.
@@ -172,7 +172,7 @@ class JiebaTokenizer(TextTensorOperation):
@check_jieba_add_dict
def add_dict(self, user_dict):
"""
Add user defined word to JiebaTokenizer's dictionary.
Add user defined words from a dictionary to JiebaTokenizer's dictionary.

Args:
user_dict (Union[str, dict]): One of the two loading methods is file path(str) loading
@@ -259,9 +259,11 @@ class Lookup(TextTensorOperation):

Args:
vocab (Vocab): A vocabulary object.
unknown_token (str, optional): Word used for lookup if the word being looked up is out-of-vocabulary (OOV).
If unknown_token is OOV, a runtime error will be thrown (default=None).
data_type (mindspore.dtype, optional): mindspore.dtype that lookup maps string to (default=mindspore.int32)
unknown_token (str, optional): Word used for lookup. If the word being looked up is out of vocabulary (OOV),
the result of the lookup will be replaced with unknown_token. If unknown_token is not specified or
it is OOV, a runtime error will be thrown (default={}, which means no unknown_token is specified).
data_type (mindspore.dtype, optional): The data type that lookup operation maps
string to (default=mindspore.int32).

Examples:
>>> # Load vocabulary from list
@@ -587,7 +589,7 @@ if platform.system().lower() != 'windows':
... preserve_unused_token=True,
... with_offsets=False)
>>> text_file_dataset = text_file_dataset.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str],
>>> # If with_offsets=True, then output three columns {["token", dtype=str],
>>> # ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.BasicTokenizer(lower_case=False,
@@ -630,14 +632,15 @@ if platform.system().lower() != 'windows':
Args:
vocab (Vocab): A vocabulary object.
suffix_indicator (str, optional): Used to show that the subword is the last part of a word (default='##').
max_bytes_per_token (int, optional): Tokens exceeding this length will not be further split (default=100).
max_bytes_per_token (int, optional): Tokens exceeding this length will not be further
split (default=100).
unknown_token (str, optional): When an unknown token is found, return the token directly if `unknown_token`
is an empty string, else return `unknown_token` instead (default='[UNK]').
lower_case (bool, optional): If True, apply CaseFold, NormalizeUTF8 with `NFD` mode, RegexReplace operation
on input text to fold the text to lower case and strip accented characters. If False, only apply
NormalizeUTF8 operation with the specified mode on input text (default=False).
keep_whitespace (bool, optional): If True, the whitespace will be kept in out tokens (default=False).
normalization_form (NormalizeForm, optional): Used to specify a specific normalize mode,
normalization_form (NormalizeForm, optional): This parameter is used to specify a specific normalize mode,
only effective when `lower_case` is False. See NormalizeUTF8 for details (default=NormalizeForm.NONE).
preserve_unused_token (bool, optional): If True, do not split special tokens like
'[CLS]', '[SEP]', '[UNK]', '[PAD]', '[MASK]' (default=True).
@@ -658,7 +661,7 @@ if platform.system().lower() != 'windows':
... normalization_form=NormalizeForm.NONE, preserve_unused_token=True,
... with_offsets=False)
>>> text_file_dataset = text_file_dataset.map(operations=tokenizer_op)
>>> # If with_offsets=False, then output three columns {["token", dtype=str],
>>> # If with_offsets=True, then output three columns {["token", dtype=str],
>>> # ["offsets_start", dtype=uint32],
>>> # ["offsets_limit", dtype=uint32]}
>>> tokenizer_op = text.BertTokenizer(vocab=vocab, suffix_indicator='##', max_bytes_per_token=100,
@@ -721,9 +724,9 @@ if platform.system().lower() != 'windows':
NormalizeUTF8 is not supported on Windows platform yet.

Args:
normalize_form (NormalizeForm, optional): Valid values can be any of [NormalizeForm.NONE,
NormalizeForm.NFC, NormalizeForm.NFKC, NormalizeForm.NFD,
NormalizeForm.NFKD](default=NormalizeForm.NFKC).
normalize_form (NormalizeForm, optional): Valid values can be [NormalizeForm.NONE, NormalizeForm.NFC,
NormalizeForm.NFKC, NormalizeForm.NFD, NormalizeForm.NFKD] any of the four unicode
normalized forms (default=NormalizeForm.NFKC).
See http://unicode.org/reports/tr15/ for details.

- NormalizeForm.NONE, do nothing for input string tensor.


Loading…
Cancel
Save