
fix missing return descriptions

tags/v1.2.0-rc1
Xiao Tianci 4 years ago
commit 8e332a7494
4 changed files with 69 additions and 34 deletions
  1. +23 -7   mindspore/dataset/engine/serializer_deserializer.py
  2. +15 -4   mindspore/dataset/text/utils.py
  3. +2  -2   mindspore/mindrecord/filereader.py
  4. +29 -21  mindspore/mindrecord/tools/tfrecord_to_mr.py

+ 23  - 7   mindspore/dataset/engine/serializer_deserializer.py

@@ -168,6 +168,17 @@ def create_node(node):
     # Find a matching Dataset class and call the constructor with the corresponding args.
     # When a new Dataset class is introduced, another if clause and parsing code needs to be added.
     # Dataset Source Ops (in alphabetical order)
+    pyobj = create_dataset_node(pyclass, node, dataset_op)
+    if not pyobj:
+        # Dataset Ops (in alphabetical order)
+        pyobj = create_dataset_operation_node(node, dataset_op)
+
+    return pyobj
+
+
+def create_dataset_node(pyclass, node, dataset_op):
+    """Parse the key, value in the dataset node dictionary and instantiate the Python Dataset object"""
+    pyobj = None
     if dataset_op == 'CelebADataset':
         sampler = construct_sampler(node.get('sampler'))
         num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
@@ -189,7 +200,7 @@ def create_node(node):

     elif dataset_op == 'ClueDataset':
         shuffle = to_shuffle_mode(node.get('shuffle'))
-        if shuffle is not None and isinstance(shuffle, str):
+        if isinstance(shuffle, str):
             shuffle = de.Shuffle(shuffle)
         num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
         pyobj = pyclass(node['dataset_files'], node.get('task'),
@@ -205,7 +216,7 @@ def create_node(node):

     elif dataset_op == 'CSVDataset':
         shuffle = to_shuffle_mode(node.get('shuffle'))
-        if shuffle is not None and isinstance(shuffle, str):
+        if isinstance(shuffle, str):
             shuffle = de.Shuffle(shuffle)
         num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
         pyobj = pyclass(node['dataset_files'], node.get('field_delim'),
@@ -237,7 +248,7 @@ def create_node(node):

     elif dataset_op == 'TextFileDataset':
         shuffle = to_shuffle_mode(node.get('shuffle'))
-        if shuffle is not None and isinstance(shuffle, str):
+        if isinstance(shuffle, str):
             shuffle = de.Shuffle(shuffle)
         num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
         pyobj = pyclass(node['dataset_files'], num_samples,
@@ -246,7 +257,7 @@ def create_node(node):

     elif dataset_op == 'TFRecordDataset':
         shuffle = to_shuffle_mode(node.get('shuffle'))
-        if shuffle is not None and isinstance(shuffle, str):
+        if isinstance(shuffle, str):
             shuffle = de.Shuffle(shuffle)
         num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
         pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('columns_list'),
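The same simplification recurs in the four hunks above: the extra "shuffle is not None" guard is dropped before converting the string to the Shuffle enum. A minimal sketch of why the shorter check is sufficient; only the de.Shuffle(shuffle) call is taken from the diff, the rest is illustrative and assumes mindspore.dataset is importable:

    import mindspore.dataset as de   # the module refers to mindspore.dataset as de

    shuffle = None
    assert not isinstance(shuffle, str)    # False for None, so the dropped guard was redundant
    shuffle = "global"
    if isinstance(shuffle, str):           # the simplified branch kept by this commit
        shuffle = de.Shuffle(shuffle)      # string -> Shuffle enum, as in the hunks above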
@@ -260,8 +271,13 @@ def create_node(node):
                         num_samples, node.get('num_parallel_workers'), node.get('shuffle'),
                         node.get('decode'), sampler, node.get('num_shards'), node.get('shard_id'))

-    # Dataset Ops (in alphabetical order)
-    elif dataset_op == 'Batch':
+    return pyobj
+
+
+def create_dataset_operation_node(node, dataset_op):
+    """Parse the key, value in the dataset operation node dictionary and instantiate the Python Dataset object"""
+    pyobj = None
+    if dataset_op == 'Batch':
         pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder'))

     elif dataset_op == 'Map':
@@ -292,7 +308,7 @@ def create_node(node):
         pyobj = de.Dataset().to_device(node.get('send_epoch_end'), node.get('create_data_info_queue'))

     elif dataset_op == 'Zip':
-        # Create ZipDataset instance, giving dummy input dataset that will be overrided in the caller.
+        # Create ZipDataset instance, giving dummy input dataset that will be overrode in the caller.
         pyobj = de.ZipDataset((de.Dataset(), de.Dataset()))

     else:
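For context, a hedged sketch of the round trip that exercises create_node(): serialize a small pipeline to JSON, then rebuild it. The file path and batch size are placeholders, and it assumes the public ds.serialize/ds.deserialize helpers route each stored node dictionary back through this module:

    import mindspore.dataset as ds

    # Hypothetical source file; any dataset op handled above would serialize the same way.
    pipeline = ds.TFRecordDataset("/path/to/train.tfrecord", shuffle=ds.Shuffle.GLOBAL)
    pipeline = pipeline.batch(32, drop_remainder=True)

    ds.serialize(pipeline, json_filepath="pipeline.json")      # dump the node dictionaries
    restored = ds.deserialize(json_filepath="pipeline.json")   # each node is rebuilt via create_node()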


+ 15  - 4   mindspore/dataset/text/utils.py

@@ -24,7 +24,6 @@ import mindspore._c_dataengine as cde
 from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset, \
     check_from_dataset_sentencepiece, check_from_file_sentencepiece, check_save_model

-
 __all__ = [
     "Vocab", "SentencePieceVocab", "to_str", "to_bytes"
 ]
@@ -66,7 +65,7 @@ class Vocab(cde.Vocab):
                 is specified and special_first is set to True, special_tokens will be prepended (default=True).

         Returns:
-            Vocab, Vocab object built from dataset.
+            Vocab, vocab built from the dataset.
         """
         return dataset.build_vocab(columns, freq_range, top_k, special_tokens, special_first)


@@ -82,6 +81,9 @@ class Vocab(cde.Vocab):
                 special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added).
             special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens
                 is specified and special_first is set to True, special_tokens will be prepended (default=True).
+
+        Returns:
+            Vocab, vocab built from the `list`.
         """
         if special_tokens is None:
             special_tokens = []
@@ -103,6 +105,9 @@ class Vocab(cde.Vocab):
             special_first (bool, optional): whether special_tokens will be prepended/appended to vocab,
                 If special_tokens is specified and special_first is set to True,
                 special_tokens will be prepended (default=True).
+
+        Returns:
+            Vocab, vocab built from the file.
         """
         if vocab_size is None:
             vocab_size = -1
@@ -119,6 +124,9 @@ class Vocab(cde.Vocab):
         Args:
             word_dict (dict): dict contains word and id pairs, where word should be str and id be int. id is recommended
                 to start from 0 and be continuous. ValueError will be raised if id is negative.
+
+        Returns:
+            Vocab, vocab built from the `dict`.
         """

         return super().from_dict(word_dict)
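A short usage sketch of the Vocab return value the new docstrings describe; the word list and the Lookup step are illustrative and assume mindspore.dataset.text is available:

    from mindspore.dataset import text

    vocab = text.Vocab.from_list(["hello", "world", "<unk>"])   # returns a Vocab, as now documented
    lookup = text.Lookup(vocab, unknown_token="<unk>")          # the returned Vocab feeds straight into Lookup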
@@ -147,7 +155,7 @@ class SentencePieceVocab(cde.SentencePieceVocab):
             params(dict): A dictionary with no incoming parameters.

         Returns:
-            SentencePiece, SentencePiece object from dataset.
+            SentencePieceVocab, vocab built from the dataset.
         """

         return dataset.build_sentencepiece_vocab(col_names, vocab_size, character_coverage,
@@ -174,6 +182,9 @@ class SentencePieceVocab(cde.SentencePieceVocab):

                 input_sentence_size 0
                 max_sentencepiece_length 16
+
+        Returns:
+            SentencePieceVocab, vocab built from the file.
         """
         return super().from_file(file_path, vocab_size, character_coverage,
                                  DE_C_INTER_SENTENCEPIECE_MODE[model_type], params)
@@ -189,7 +200,7 @@ class SentencePieceVocab(cde.SentencePieceVocab):
             path(str): Path to store model.
             filename(str): The name of the file.
         """
-        return super().save_model(vocab, path, filename)
+        super().save_model(vocab, path, filename)


 def to_str(array, encoding='utf8'):
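Likewise for SentencePieceVocab, a hedged sketch pairing the documented from_file return value with save_model, which now returns None; the corpus path, vocab size, and output names are placeholders, and the SentencePieceModel import location is an assumption:

    from mindspore.dataset.text import SentencePieceVocab, SentencePieceModel  # import path assumed

    vocab = SentencePieceVocab.from_file(["corpus.txt"], 5000, 0.9995,
                                         SentencePieceModel.UNIGRAM, {})   # SentencePieceVocab, as documented
    SentencePieceVocab.save_model(vocab, "./", "m.model")                  # returns None after this change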


+ 2   - 2   mindspore/mindrecord/filereader.py

@@ -38,6 +38,7 @@ class FileReader:
     Raises:
         ParamValueError: If file_name, num_consumer or columns is invalid.
     """
+
     def __init__(self, file_name, num_consumer=4, columns=None, operator=None):
         if isinstance(file_name, list):
             for f in file_name:
@@ -66,7 +67,6 @@ class FileReader:
         self._header = ShardHeader(self._reader.get_header())
         self._reader.launch()

-
     def get_next(self):
         """
         Yield a batch of data according to columns at a time.
@@ -85,4 +85,4 @@

     def close(self):
         """Stop reader worker and close File."""
-        return self._reader.close()
+        self._reader.close()
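A hedged sketch of the reader touched above; the file name and column names are placeholders:

    from mindspore.mindrecord import FileReader

    reader = FileReader("train.mindrecord", num_consumer=4, columns=["data", "label"])
    for item in reader.get_next():   # yields one dict per record, keyed by the requested columns
        print(item["label"])
    reader.close()                   # now simply stops the reader worker and returns None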

+ 29  - 21  mindspore/mindrecord/tools/tfrecord_to_mr.py

@@ -69,7 +69,7 @@ class TFRecordToMR:

     Args:
         source (str): the TFRecord file to be transformed.
-        destination (str): the MindRecord file path to tranform into.
+        destination (str): the MindRecord file path to transform into.
         feature_dict (dict): a dictionary that states the feature type, e.g.
             feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \
                             "yyyy": tf.io.FixedLenFeature([], tf.int64)}
@@ -90,31 +90,14 @@ class TFRecordToMR:
         try:
             self.tf = import_module("tensorflow")    # just used to convert tfrecord to mindrecord
         except ModuleNotFoundError:
-            self.tf = None
-        if not self.tf:
             raise Exception("Module tensorflow is not found, please use pip install it.")

         if self.tf.__version__ < SupportedTensorFlowVersion:
             raise Exception("Module tensorflow version must be greater or equal {}.".format(SupportedTensorFlowVersion))

-        if not isinstance(source, str):
-            raise ValueError("Parameter source must be string.")
-        check_filename(source)
-
-        if not isinstance(destination, str):
-            raise ValueError("Parameter destination must be string.")
-        check_filename(destination)
-
+        self._check_input(source, destination, feature_dict)
         self.source = source
         self.destination = destination
-
-        if feature_dict is None or not isinstance(feature_dict, dict):
-            raise ValueError("Parameter feature_dict is None or not dict.")
-
-        for key, val in feature_dict.items():
-            if not isinstance(val, self.tf.io.FixedLenFeature):
-                raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict))
-
         self.feature_dict = feature_dict

         bytes_fields_list = []
@@ -162,6 +145,23 @@ class TFRecordToMR:
                 mindrecord_schema[_cast_name(key)] = {"type": self._cast_type(val.dtype), "shape": [val.shape[0]]}
         self.mindrecord_schema = mindrecord_schema

+    def _check_input(self, source, destination, feature_dict):
+        """Validation check for inputs of init method"""
+        if not isinstance(source, str):
+            raise ValueError("Parameter source must be string.")
+        check_filename(source)
+
+        if not isinstance(destination, str):
+            raise ValueError("Parameter destination must be string.")
+        check_filename(destination)
+
+        if feature_dict is None or not isinstance(feature_dict, dict):
+            raise ValueError("Parameter feature_dict is None or not dict.")
+
+        for _, val in feature_dict.items():
+            if not isinstance(val, self.tf.io.FixedLenFeature):
+                raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict))
+
     def _parse_record(self, example):
         """Returns features for a single example"""
         features = self.tf.io.parse_single_example(example, features=self.feature_dict)
@@ -206,6 +206,9 @@ class TFRecordToMR:
""" """
Yield a dict with key to be fields in schema, and value to be data. Yield a dict with key to be fields in schema, and value to be data.
This function is for old version tensorflow whose version number < 2.1.0 This function is for old version tensorflow whose version number < 2.1.0

Yields:
dict, data dictionary whose keys are the same as columns.
""" """
dataset = self.tf.data.TFRecordDataset(self.source) dataset = self.tf.data.TFRecordDataset(self.source)
dataset = dataset.map(self._parse_record) dataset = dataset.map(self._parse_record)
@@ -235,7 +238,12 @@ class TFRecordToMR:
             raise ValueError("TFRecord feature_dict parameter error.")

     def tfrecord_iterator(self):
-        """Yield a dictionary whose keys are fields in schema."""
+        """
+        Yield a dictionary whose keys are fields in schema.
+
+        Yields:
+            dict, data dictionary whose keys are the same as columns.
+        """
         dataset = self.tf.data.TFRecordDataset(self.source)
         dataset = dataset.map(self._parse_record)
         iterator = dataset.__iter__()
@@ -265,7 +273,7 @@ class TFRecordToMR:
         Execute transformation from TFRecord to MindRecord.

         Returns:
-            MSRStatus, whether TFRecord is successfuly transformed to MindRecord.
+            MSRStatus, whether TFRecord is successfully transformed to MindRecord.
         """
         writer = FileWriter(self.destination)
         logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}"
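Finally, a hedged end-to-end sketch of the converter; the paths, feature names, and bytes_fields value are illustrative, not part of the commit:

    import tensorflow as tf
    from mindspore.mindrecord import TFRecordToMR

    feature_dict = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    converter = TFRecordToMR("train.tfrecord", "train.mindrecord",
                             feature_dict, bytes_fields=["image"])
    status = converter.transform()   # MSRStatus, whether TFRecord is successfully transformed to MindRecord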

