| @@ -168,6 +168,17 @@ def create_node(node): | |||
| # Find a matching Dataset class and call the constructor with the corresponding args. | |||
| # When a new Dataset class is introduced, another if clause and parsing code needs to be added. | |||
| # Dataset Source Ops (in alphabetical order) | |||
| pyobj = create_dataset_node(pyclass, node, dataset_op) | |||
| if not pyobj: | |||
| # Dataset Ops (in alphabetical order) | |||
| pyobj = create_dataset_operation_node(node, dataset_op) | |||
| return pyobj | |||
| def create_dataset_node(pyclass, node, dataset_op): | |||
| """Parse the key, value in the dataset node dictionary and instantiate the Python Dataset object""" | |||
| pyobj = None | |||
| if dataset_op == 'CelebADataset': | |||
| sampler = construct_sampler(node.get('sampler')) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
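The comments above describe the extension pattern this refactor keeps: each new Dataset class still needs its own branch that pulls constructor arguments out of the node dictionary. A rough sketch of such a branch, using an invented `FooDataset` op name and a simplified argument list (not part of this change):

```python
# Illustrative only: how a hypothetical 'FooDataset' branch inside
# create_dataset_node() might look. The op name and arguments are made up.
def create_dataset_node_sketch(pyclass, node, dataset_op):
    pyobj = None
    if dataset_op == 'FooDataset':
        # Reuse the module's existing helper to normalize num_samples.
        num_samples = check_and_replace_input(node.get('num_samples'), 0, None)
        pyobj = pyclass(node['dataset_dir'], num_samples,
                        node.get('num_parallel_workers'), node.get('shuffle'))
    return pyobj
```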
| @@ -189,7 +200,7 @@ def create_node(node): | |||
| elif dataset_op == 'ClueDataset': | |||
| shuffle = to_shuffle_mode(node.get('shuffle')) | |||
| if shuffle is not None and isinstance(shuffle, str): | |||
| if isinstance(shuffle, str): | |||
| shuffle = de.Shuffle(shuffle) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_files'], node.get('task'), | |||
| @@ -205,7 +216,7 @@ def create_node(node): | |||
| elif dataset_op == 'CSVDataset': | |||
| shuffle = to_shuffle_mode(node.get('shuffle')) | |||
| if shuffle is not None and isinstance(shuffle, str): | |||
| if isinstance(shuffle, str): | |||
| shuffle = de.Shuffle(shuffle) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_files'], node.get('field_delim'), | |||
| @@ -237,7 +248,7 @@ def create_node(node): | |||
| elif dataset_op == 'TextFileDataset': | |||
| shuffle = to_shuffle_mode(node.get('shuffle')) | |||
| if shuffle is not None and isinstance(shuffle, str): | |||
| if isinstance(shuffle, str): | |||
| shuffle = de.Shuffle(shuffle) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_files'], num_samples, | |||
| @@ -246,7 +257,7 @@ def create_node(node): | |||
| elif dataset_op == 'TFRecordDataset': | |||
| shuffle = to_shuffle_mode(node.get('shuffle')) | |||
| if shuffle is not None and isinstance(shuffle, str): | |||
| if isinstance(shuffle, str): | |||
| shuffle = de.Shuffle(shuffle) | |||
| num_samples = check_and_replace_input(node.get('num_samples'), 0, None) | |||
| pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('columns_list'), | |||
| @@ -260,8 +271,13 @@ def create_node(node): | |||
| num_samples, node.get('num_parallel_workers'), node.get('shuffle'), | |||
| node.get('decode'), sampler, node.get('num_shards'), node.get('shard_id')) | |||
| # Dataset Ops (in alphabetical order) | |||
| elif dataset_op == 'Batch': | |||
| return pyobj | |||
| def create_dataset_operation_node(node, dataset_op): | |||
| """Parse the key, value in the dataset operation node dictionary and instantiate the Python Dataset object""" | |||
| pyobj = None | |||
| if dataset_op == 'Batch': | |||
| pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder')) | |||
| elif dataset_op == 'Map': | |||
| @@ -292,7 +308,7 @@ def create_node(node): | |||
| pyobj = de.Dataset().to_device(node.get('send_epoch_end'), node.get('create_data_info_queue')) | |||
| elif dataset_op == 'Zip': | |||
| # Create ZipDataset instance, giving dummy input dataset that will be overrided in the caller. | |||
| # Create ZipDataset instance, giving dummy input dataset that will be overridden in the caller. | |||
| pyobj = de.ZipDataset((de.Dataset(), de.Dataset())) | |||
| else: | |||
| @@ -24,7 +24,6 @@ import mindspore._c_dataengine as cde | |||
| from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset, \ | |||
| check_from_dataset_sentencepiece, check_from_file_sentencepiece, check_save_model | |||
| __all__ = [ | |||
| "Vocab", "SentencePieceVocab", "to_str", "to_bytes" | |||
| ] | |||
| @@ -66,7 +65,7 @@ class Vocab(cde.Vocab): | |||
| is specified and special_first is set to True, special_tokens will be prepended (default=True). | |||
| Returns: | |||
| Vocab, Vocab object built from dataset. | |||
| Vocab, vocab built from the dataset. | |||
| """ | |||
| return dataset.build_vocab(columns, freq_range, top_k, special_tokens, special_first) | |||
| @@ -82,6 +81,9 @@ class Vocab(cde.Vocab): | |||
| special_tokens=["<pad>","<unk>"] (default=None, no special tokens will be added). | |||
| special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens | |||
| is specified and special_first is set to True, special_tokens will be prepended (default=True). | |||
| Returns: | |||
| Vocab, vocab built from the `list`. | |||
| """ | |||
| if special_tokens is None: | |||
| special_tokens = [] | |||
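For reference, a minimal usage sketch of `from_list` with the special-token behaviour described above (the word list and tokens are made up):

```python
from mindspore.dataset import text

# With special_first=True the special tokens are prepended, so in this
# made-up example ids 0 and 1 go to "<pad>" and "<unk>".
vocab = text.Vocab.from_list(["home", "behind", "the", "world"],
                             special_tokens=["<pad>", "<unk>"],
                             special_first=True)
```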
| @@ -103,6 +105,9 @@ class Vocab(cde.Vocab): | |||
| special_first (bool, optional): whether special_tokens will be prepended/appended to vocab, | |||
| If special_tokens is specified and special_first is set to True, | |||
| special_tokens will be prepended (default=True). | |||
| Returns: | |||
| Vocab, vocab built from the file. | |||
| """ | |||
| if vocab_size is None: | |||
| vocab_size = -1 | |||
| @@ -119,6 +124,9 @@ class Vocab(cde.Vocab): | |||
| Args: | |||
| word_dict (dict): dict contains word and id pairs, where word should be str and id be int. id is recommended | |||
| to start from 0 and be continuous. ValueError will be raised if id is negative. | |||
| Returns: | |||
| Vocab, vocab built from the `dict`. | |||
| """ | |||
| return super().from_dict(word_dict) | |||
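A minimal usage sketch of `from_dict`, following the recommendation above that ids start from 0 and be continuous (the mapping itself is made up):

```python
from mindspore.dataset import text

# Ids start at 0 and are continuous; a negative id would raise ValueError.
vocab = text.Vocab.from_dict({"<pad>": 0, "<unk>": 1, "home": 2, "world": 3})
```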
| @@ -147,7 +155,7 @@ class SentencePieceVocab(cde.SentencePieceVocab): | |||
| params(dict): A dictionary with no incoming parameters. | |||
| Returns: | |||
| SentencePiece, SentencePiece object from dataset. | |||
| SentencePieceVocab, vocab built from the dataset. | |||
| """ | |||
| return dataset.build_sentencepiece_vocab(col_names, vocab_size, character_coverage, | |||
| @@ -174,6 +182,9 @@ class SentencePieceVocab(cde.SentencePieceVocab): | |||
| input_sentence_size 0 | |||
| max_sentencepiece_length 16 | |||
| Returns: | |||
| SentencePieceVocab, vocab built from the file. | |||
| """ | |||
| return super().from_file(file_path, vocab_size, character_coverage, | |||
| DE_C_INTER_SENTENCEPIECE_MODE[model_type], params) | |||
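A hedged usage sketch of `from_file` using the defaults listed above (the corpus file, vocab size, and model type are made up; `SentencePieceModel` is assumed to be exported from `mindspore.dataset.text`):

```python
from mindspore.dataset.text import SentencePieceVocab, SentencePieceModel

# Train a unigram SentencePiece vocab from a plain-text corpus file.
vocab = SentencePieceVocab.from_file(["corpus.txt"], vocab_size=5000,
                                     character_coverage=0.9995,
                                     model_type=SentencePieceModel.UNIGRAM,
                                     params={})
```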
| @@ -189,7 +200,7 @@ class SentencePieceVocab(cde.SentencePieceVocab): | |||
| path(str): Path to store model. | |||
| filename(str): The name of the file. | |||
| """ | |||
| return super().save_model(vocab, path, filename) | |||
| super().save_model(vocab, path, filename) | |||
| def to_str(array, encoding='utf8'): | |||
| @@ -38,6 +38,7 @@ class FileReader: | |||
| Raises: | |||
| ParamValueError: If file_name, num_consumer or columns is invalid. | |||
| """ | |||
| def __init__(self, file_name, num_consumer=4, columns=None, operator=None): | |||
| if isinstance(file_name, list): | |||
| for f in file_name: | |||
| @@ -66,7 +67,6 @@ class FileReader: | |||
| self._header = ShardHeader(self._reader.get_header()) | |||
| self._reader.launch() | |||
| def get_next(self): | |||
| """ | |||
| Yield a batch of data according to columns at a time. | |||
| @@ -85,4 +85,4 @@ class FileReader: | |||
| def close(self): | |||
| """Stop reader worker and close File.""" | |||
| return self._reader.close() | |||
| self._reader.close() | |||
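Taken together, a typical FileReader round trip based on the docstrings above might look like this (the file name and column names are made up):

```python
from mindspore.mindrecord import FileReader

# get_next() yields one dict per row, keyed by the requested columns;
# close() stops the reader worker and releases the file.
reader = FileReader("test.mindrecord", num_consumer=4, columns=["data", "label"])
for item in reader.get_next():
    print(item["label"])
reader.close()
```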
| @@ -69,7 +69,7 @@ class TFRecordToMR: | |||
| Args: | |||
| source (str): the TFRecord file to be transformed. | |||
| destination (str): the MindRecord file path to tranform into. | |||
| destination (str): the MindRecord file path to transform into. | |||
| feature_dict (dict): a dictionary that states the feature type, e.g. | |||
| feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \ | |||
| "yyyy": tf.io.FixedLenFeature([], tf.int64)} | |||
| @@ -90,31 +90,14 @@ class TFRecordToMR: | |||
| try: | |||
| self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord | |||
| except ModuleNotFoundError: | |||
| self.tf = None | |||
| if not self.tf: | |||
| raise Exception("Module tensorflow is not found, please use pip install it.") | |||
| if self.tf.__version__ < SupportedTensorFlowVersion: | |||
| raise Exception("Module tensorflow version must be greater or equal {}.".format(SupportedTensorFlowVersion)) | |||
| if not isinstance(source, str): | |||
| raise ValueError("Parameter source must be string.") | |||
| check_filename(source) | |||
| if not isinstance(destination, str): | |||
| raise ValueError("Parameter destination must be string.") | |||
| check_filename(destination) | |||
| self._check_input(source, destination, feature_dict) | |||
| self.source = source | |||
| self.destination = destination | |||
| if feature_dict is None or not isinstance(feature_dict, dict): | |||
| raise ValueError("Parameter feature_dict is None or not dict.") | |||
| for key, val in feature_dict.items(): | |||
| if not isinstance(val, self.tf.io.FixedLenFeature): | |||
| raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict)) | |||
| self.feature_dict = feature_dict | |||
| bytes_fields_list = [] | |||
| @@ -162,6 +145,23 @@ class TFRecordToMR: | |||
| mindrecord_schema[_cast_name(key)] = {"type": self._cast_type(val.dtype), "shape": [val.shape[0]]} | |||
| self.mindrecord_schema = mindrecord_schema | |||
| def _check_input(self, source, destination, feature_dict): | |||
| """Validation check for inputs of init method""" | |||
| if not isinstance(source, str): | |||
| raise ValueError("Parameter source must be string.") | |||
| check_filename(source) | |||
| if not isinstance(destination, str): | |||
| raise ValueError("Parameter destination must be string.") | |||
| check_filename(destination) | |||
| if feature_dict is None or not isinstance(feature_dict, dict): | |||
| raise ValueError("Parameter feature_dict is None or not dict.") | |||
| for _, val in feature_dict.items(): | |||
| if not isinstance(val, self.tf.io.FixedLenFeature): | |||
| raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict)) | |||
| def _parse_record(self, example): | |||
| """Returns features for a single example""" | |||
| features = self.tf.io.parse_single_example(example, features=self.feature_dict) | |||
| @@ -206,6 +206,9 @@ class TFRecordToMR: | |||
| """ | |||
| Yield a dict with key to be fields in schema, and value to be data. | |||
| This function is for old version tensorflow whose version number < 2.1.0 | |||
| Yields: | |||
| dict, data dictionary whose keys are the same as columns. | |||
| """ | |||
| dataset = self.tf.data.TFRecordDataset(self.source) | |||
| dataset = dataset.map(self._parse_record) | |||
| @@ -235,7 +238,12 @@ class TFRecordToMR: | |||
| raise ValueError("TFRecord feature_dict parameter error.") | |||
| def tfrecord_iterator(self): | |||
| """Yield a dictionary whose keys are fields in schema.""" | |||
| """ | |||
| Yield a dictionary whose keys are fields in schema. | |||
| Yields: | |||
| dict, data dictionary whose keys are the same as columns. | |||
| """ | |||
| dataset = self.tf.data.TFRecordDataset(self.source) | |||
| dataset = dataset.map(self._parse_record) | |||
| iterator = dataset.__iter__() | |||
| @@ -265,7 +273,7 @@ class TFRecordToMR: | |||
| Execute transformation from TFRecord to MindRecord. | |||
| Returns: | |||
| MSRStatus, whether TFRecord is successfuly transformed to MindRecord. | |||
| MSRStatus, whether TFRecord is successfully transformed to MindRecord. | |||
| """ | |||
| writer = FileWriter(self.destination) | |||
| logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}" | |||
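End to end, a hedged usage sketch of TFRecordToMR based on the constructor and `transform()` documented above (file names, feature names, and the bytes field are made up; `bytes_fields` is assumed to be the constructor's fourth parameter):

```python
import tensorflow as tf
from mindspore.mindrecord import TFRecordToMR

# feature_dict only supports FixedLenFeature entries, as checked in _check_input().
feature_dict = {"image": tf.io.FixedLenFeature([], tf.string),
                "label": tf.io.FixedLenFeature([], tf.int64)}
transformer = TFRecordToMR("data.tfrecord", "data.mindrecord",
                           feature_dict, bytes_fields=["image"])
transformer.transform()  # returns MSRStatus indicating whether the conversion succeeded
```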