You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

serializer_deserializer.py 4.8 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # Copyright 2019-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Functions to support dataset serialize and deserialize.
  17. """
  18. import json
  19. import os
  20. from mindspore import log as logger
  21. from . import datasets as de
  22. def serialize(dataset, json_filepath=""):
  23. """
  24. Serialize dataset pipeline into a JSON file.
  25. Note:
  26. Currently some Python objects are not supported to be serialized.
  27. For Python function serialization of map operator, de.serialize will only return its function name.
  28. Args:
  29. dataset (Dataset): The starting node.
  30. json_filepath (str): The filepath where a serialized JSON file will be generated.
  31. Returns:
  32. Dict, The dictionary contains the serialized dataset graph.
  33. Raises:
  34. OSError: Can not open a file
  35. Examples:
  36. >>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
  37. >>> one_hot_encode = c_transforms.OneHot(10) # num_classes is input argument
  38. >>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
  39. >>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
  40. >>> # serialize it to JSON file
  41. >>> ds.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
  42. >>> serialized_data = ds.serialize(dataset) # serialize it to Python dict
  43. """
  44. return dataset.to_json(json_filepath)
  45. def deserialize(input_dict=None, json_filepath=None):
  46. """
  47. Construct dataset pipeline from a JSON file produced by de.serialize().
  48. Note:
  49. Currently Python function deserialization of map operator are not supported.
  50. Args:
  51. input_dict (dict): A Python dictionary containing a serialized dataset graph.
  52. json_filepath (str): A path to the JSON file.
  53. Returns:
  54. de.Dataset or None if error occurs.
  55. Raises:
  56. OSError: Can not open the JSON file.
  57. Examples:
  58. >>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
  59. >>> one_hot_encode = c_transforms.OneHot(10) # num_classes is input argument
  60. >>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
  61. >>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
  62. >>> # Use case 1: to/from JSON file
  63. >>> ds.serialize(dataset, json_filepath="/path/to/mnist_dataset_pipeline.json")
  64. >>> dataset = ds.deserialize(json_filepath="/path/to/mnist_dataset_pipeline.json")
  65. >>> # Use case 2: to/from Python dictionary
  66. >>> serialized_data = ds.serialize(dataset)
  67. >>> dataset = ds.deserialize(input_dict=serialized_data)
  68. """
  69. data = None
  70. if input_dict:
  71. data = de.DeserializedDataset(input_dict)
  72. if json_filepath:
  73. data = de.DeserializedDataset(json_filepath)
  74. return data
  75. def expand_path(node_repr, key, val):
  76. """Convert relative to absolute path."""
  77. if isinstance(val, list):
  78. node_repr[key] = [os.path.abspath(file) for file in val]
  79. else:
  80. node_repr[key] = os.path.abspath(val)
  81. def show(dataset, indentation=2):
  82. """
  83. Write the dataset pipeline graph to logger.info file.
  84. Args:
  85. dataset (Dataset): The starting node.
  86. indentation (int, optional): The indentation used by the JSON print.
  87. Do not indent if indentation is None.
  88. Examples:
  89. >>> dataset = ds.MnistDataset(mnist_dataset_dir, 100)
  90. >>> one_hot_encode = c_transforms.OneHot(10)
  91. >>> dataset = dataset.map(operation=one_hot_encode, input_column_names="label")
  92. >>> dataset = dataset.batch(batch_size=10, drop_remainder=True)
  93. >>> ds.show(dataset)
  94. """
  95. pipeline = dataset.to_json()
  96. logger.info(json.dumps(pipeline, indent=indentation))
  97. def compare(pipeline1, pipeline2):
  98. """
  99. Compare if two dataset pipelines are the same.
  100. Args:
  101. pipeline1 (Dataset): a dataset pipeline.
  102. pipeline2 (Dataset): a dataset pipeline.
  103. Returns:
  104. Whether pipeline1 is equal to pipeline2.
  105. Examples:
  106. >>> pipeline1 = ds.MnistDataset(mnist_dataset_dir, 100)
  107. >>> pipeline2 = ds.Cifar10Dataset(cifar_dataset_dir, 100)
  108. >>> ds.compare(pipeline1, pipeline2)
  109. """
  110. return pipeline1.to_json() == pipeline2.to_json()