|
|
|
@@ -43,11 +43,10 @@ def serialize(dataset, json_filepath=""): |
|
|
|
Examples: |
|
|
|
>>> dataset = ds.MnistDataset(mnist_dataset_dir, num_samples=100) |
|
|
|
>>> one_hot_encode = c_transforms.OneHot(10) # num_classes is input argument |
|
|
|
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label") |
|
|
|
>>> dataset = dataset.map(operations=one_hot_encode, input_columns="label") |
|
|
|
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True) |
|
|
|
>>> # serialize it to JSON file |
|
|
|
>>> ds.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json") |
|
|
|
>>> serialized_data = ds.serialize(dataset) # serialize it to Python dict |
|
|
|
>>> serialized_data = ds.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json") |
|
|
|
""" |
|
|
|
return dataset.to_json(json_filepath) |
|
|
|
|
|
|
|
@@ -72,16 +71,16 @@ def deserialize(input_dict=None, json_filepath=None): |
|
|
|
Examples: |
|
|
|
>>> dataset = ds.MnistDataset(mnist_dataset_dir, num_samples=100) |
|
|
|
>>> one_hot_encode = c_transforms.OneHot(10) # num_classes is input argument |
|
|
|
>>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label") |
|
|
|
>>> dataset = dataset.map(operations=one_hot_encode, input_columns="label") |
|
|
|
>>> dataset = dataset.batch(batch_size=10, drop_remainder=True) |
|
|
|
>>> # Use case 1: to/from JSON file |
|
|
|
>>> ds.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json") |
|
|
|
>>> dataset = ds.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json") |
|
|
|
>>> # Use case 2: to/from Python dictionary |
|
|
|
>>> # Case 1: to/from JSON file |
|
|
|
>>> serialized_data = ds.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json") |
|
|
|
>>> deserialized_dataset = ds.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json") |
|
|
|
>>> # Case 2: to/from Python dictionary |
|
|
|
>>> serialized_data = ds.serialize(dataset) |
|
|
|
>>> dataset = ds.deserialize(input_dict=serialized_data) |
|
|
|
|
|
|
|
>>> deserialized_dataset = ds.deserialize(input_dict=serialized_data) |
|
|
|
""" |
|
|
|
|
|
|
|
data = None |
|
|
|
if input_dict: |
|
|
|
data = de.DeserializedDataset(input_dict) |
|
|
|
@@ -133,7 +132,7 @@ def compare(pipeline1, pipeline2): |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> pipeline1 = ds.MnistDataset(mnist_dataset_dir, num_samples=100) |
|
|
|
>>> pipeline2 = ds.Cifar10Dataset(cifar_dataset_dir, num_samples=100) |
|
|
|
>>> pipeline2 = ds.Cifar10Dataset(cifar10_dataset_dir, num_samples=100) |
|
|
|
>>> res = ds.compare(pipeline1, pipeline2) |
|
|
|
""" |
|
|
|
|
|
|
|
|