You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

serializer_deserializer.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Functions to support dataset serialize and deserialize.
  17. """
  18. import json
  19. import os
  20. import sys
  21. from mindspore import log as logger
  22. from . import datasets as de
  23. from ..transforms.vision.utils import Inter, Border
  24. def serialize(dataset, json_filepath=None):
  25. """
  26. Serialize dataset pipeline into a json file.
  27. Args:
  28. dataset (Dataset): the starting node.
  29. json_filepath (string): a filepath where a serialized json file will be generated.
  30. Returns:
  31. dict containing the serialized dataset graph.
  32. Raises:
  33. OSError cannot open a file
  34. Examples:
  35. >>> import mindspore.dataset as ds
  36. >>> import mindspore.dataset.transforms.c_transforms as C
  37. >>> DATA_DIR = "../../data/testMnistData"
  38. >>> data = ds.MnistDataset(DATA_DIR, 100)
  39. >>> one_hot_encode = C.OneHot(10) # num_classes is input argument
  40. >>> data = data.map(input_column_names="label", operation=one_hot_encode)
  41. >>> data = data.batch(batch_size=10, drop_remainder=True)
  42. >>>
  43. >>> ds.engine.serialize(data, json_filepath="mnist_dataset_pipeline.json") # serialize it to json file
  44. >>> serialized_data = ds.engine.serialize(data) # serialize it to python dict
  45. """
  46. serialized_pipeline = traverse(dataset)
  47. if json_filepath:
  48. with open(json_filepath, 'w') as json_file:
  49. json.dump(serialized_pipeline, json_file, indent=2)
  50. return serialized_pipeline
  51. def deserialize(input_dict=None, json_filepath=None):
  52. """
  53. Construct a de pipeline from a json file produced by de.serialize().
  54. Args:
  55. input_dict (dict): a python dictionary containing a serialized dataset graph
  56. json_filepath (string): a path to the json file.
  57. Returns:
  58. de.Dataset or None if error occurs.
  59. Raises:
  60. OSError cannot open a file.
  61. Examples:
  62. >>> import mindspore.dataset as ds
  63. >>> import mindspore.dataset.transforms.c_transforms as C
  64. >>> DATA_DIR = "../../data/testMnistData"
  65. >>> data = ds.MnistDataset(DATA_DIR, 100)
  66. >>> one_hot_encode = C.OneHot(10) # num_classes is input argument
  67. >>> data = data.map(input_column_names="label", operation=one_hot_encode)
  68. >>> data = data.batch(batch_size=10, drop_remainder=True)
  69. >>>
  70. >>> # Use case 1: to/from json file
  71. >>> ds.engine.serialize(data, json_filepath="mnist_dataset_pipeline.json")
  72. >>> data = ds.engine.deserialize(json_filepath="mnist_dataset_pipeline.json")
  73. >>> # Use case 2: to/from python dictionary
  74. >>> serialized_data = ds.engine.serialize(data)
  75. >>> data = ds.engine.deserialize(input_dict=serialized_data)
  76. """
  77. data = None
  78. if input_dict:
  79. data = construct_pipeline(input_dict)
  80. if json_filepath:
  81. dict_pipeline = dict()
  82. with open(json_filepath, 'r') as json_file:
  83. dict_pipeline = json.load(json_file)
  84. data = construct_pipeline(dict_pipeline)
  85. return data
  86. def expand_path(node_repr, key, val):
  87. """Convert relative to absolute path."""
  88. if isinstance(val, list):
  89. node_repr[key] = [os.path.abspath(file) for file in val]
  90. else:
  91. node_repr[key] = os.path.abspath(val)
  92. def serialize_operations(node_repr, key, val):
  93. """Serialize tensor op (python object) to dictionary."""
  94. if isinstance(val, list):
  95. node_repr[key] = []
  96. for op in val:
  97. node_repr[key].append(op.__dict__)
  98. # Extracting module and name information from a python object
  99. # Example: tensor_op_module is 'minddata.transforms.c_transforms' and tensor_op_name is 'Decode'
  100. node_repr[key][-1]['tensor_op_name'] = type(op).__name__
  101. node_repr[key][-1]['tensor_op_module'] = type(op).__module__
  102. else:
  103. node_repr[key] = val.__dict__
  104. node_repr[key]['tensor_op_name'] = type(val).__name__
  105. node_repr[key]['tensor_op_module'] = type(val).__module__
  106. def serialize_sampler(node_repr, val):
  107. """Serialize sampler object to dictionary."""
  108. if val is None:
  109. node_repr['sampler'] = None
  110. else:
  111. node_repr['sampler'] = val.__dict__
  112. node_repr['sampler']['sampler_module'] = type(val).__module__
  113. node_repr['sampler']['sampler_name'] = type(val).__name__
  114. def traverse(node):
  115. """Pre-order traverse the pipeline and capture the information as we go."""
  116. # Node representation (node_repr) is a python dictionary that capture and store the
  117. # dataset pipeline information before dumping it to JSON or other format.
  118. node_repr = dict()
  119. node_repr['op_type'] = type(node).__name__
  120. node_repr['op_module'] = type(node).__module__
  121. # Start with an empty list of children, will be added later as we traverse this node.
  122. node_repr["children"] = []
  123. # Retrieve the information about the current node. It should include arguments
  124. # passed to the node during object construction.
  125. node_args = node.get_args()
  126. for k, v in node_args.items():
  127. # Store the information about this node into node_repr.
  128. # Further serialize the object in the arguments if needed.
  129. if k == 'operations':
  130. serialize_operations(node_repr, k, v)
  131. elif k == 'sampler':
  132. serialize_sampler(node_repr, v)
  133. elif k in set(['schema', 'dataset_files', 'dataset_dir', 'schema_file_path']):
  134. expand_path(node_repr, k, v)
  135. else:
  136. node_repr[k] = v
  137. # Leaf node doesn't have input attribute.
  138. if not node.input:
  139. return node_repr
  140. # Recursively traverse the child and assign it to the current node_repr['children'].
  141. for child in node.input:
  142. node_repr["children"].append(traverse(child))
  143. return node_repr
  144. def show(dataset, indentation=2):
  145. """
  146. Write the dataset pipeline graph onto logger.info.
  147. Args:
  148. dataset (Dataset): the starting node.
  149. indentation (int, optional): indentation used by the json print. Pass None to not indent.
  150. """
  151. pipeline = traverse(dataset)
  152. logger.info(json.dumps(pipeline, indent=indentation))
  153. def compare(pipeline1, pipeline2):
  154. """
  155. Compare if two dataset pipelines are the same.
  156. Args:
  157. pipeline1 (Dataset): a dataset pipeline.
  158. pipeline2 (Dataset): a dataset pipeline.
  159. """
  160. return traverse(pipeline1) == traverse(pipeline2)
  161. def construct_pipeline(node):
  162. """Construct the python Dataset objects by following the dictionary deserialized from json file."""
  163. op_type = node.get('op_type')
  164. if not op_type:
  165. raise ValueError("op_type field in the json file can't be None.")
  166. # Instantiate python Dataset object based on the current dictionary element
  167. dataset = create_node(node)
  168. # Initially it is not connected to any other object.
  169. dataset.input = []
  170. # Construct the children too and add edge between the children and parent.
  171. for child in node['children']:
  172. dataset.input.append(construct_pipeline(child))
  173. return dataset
  174. def create_node(node):
  175. """Parse the key, value in the node dictionary and instantiate the python Dataset object"""
  176. logger.info('creating node: %s', node['op_type'])
  177. dataset_op = node['op_type']
  178. op_module = node['op_module']
  179. # Get the python class to be instantiated.
  180. # Example:
  181. # "op_type": "MapDataset",
  182. # "op_module": "mindspore.dataset.datasets",
  183. pyclass = getattr(sys.modules[op_module], dataset_op)
  184. pyobj = None
  185. # Find a matching Dataset class and call the constructor with the corresponding args.
  186. # When a new Dataset class is introduced, another if clause and parsing code needs to be added.
  187. if dataset_op == 'ImageFolderDatasetV2':
  188. sampler = construct_sampler(node.get('sampler'))
  189. pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
  190. node.get('shuffle'), sampler, node.get('extensions'),
  191. node.get('class_indexing'), node.get('decode'), node.get('num_shards'),
  192. node.get('shard_id'))
  193. elif dataset_op == 'RangeDataset':
  194. pyobj = pyclass(node['start'], node['stop'], node['step'])
  195. elif dataset_op == 'ImageFolderDataset':
  196. pyobj = pyclass(node['dataset_dir'], node['schema'], node.get('distribution'),
  197. node.get('column_list'), node.get('num_parallel_workers'),
  198. node.get('deterministic_output'), node.get('prefetch_size'),
  199. node.get('labels_filename'), node.get('dataset_usage'))
  200. elif dataset_op == 'MnistDataset':
  201. sampler = construct_sampler(node.get('sampler'))
  202. pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
  203. node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
  204. elif dataset_op == 'MindDataset':
  205. sampler = construct_sampler(node.get('sampler'))
  206. pyobj = pyclass(node['dataset_file'], node.get('columns_list'),
  207. node.get('num_parallel_workers'), node.get('seed'), node.get('num_shards'),
  208. node.get('shard_id'), node.get('block_reader'), sampler)
  209. elif dataset_op == 'TFRecordDataset':
  210. pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'),
  211. node.get('num_samples'), node.get('num_parallel_workers'),
  212. de.Shuffle(node.get('shuffle')), node.get('num_shards'), node.get('shard_id'))
  213. elif dataset_op == 'ManifestDataset':
  214. sampler = construct_sampler(node.get('sampler'))
  215. pyobj = pyclass(node['dataset_file'], node['usage'], node.get('num_samples'),
  216. node.get('num_parallel_workers'), node.get('shuffle'), sampler,
  217. node.get('class_indexing'), node.get('decode'), node.get('num_shards'),
  218. node.get('shard_id'))
  219. elif dataset_op == 'Cifar10Dataset':
  220. sampler = construct_sampler(node.get('sampler'))
  221. pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
  222. node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
  223. elif dataset_op == 'Cifar100Dataset':
  224. sampler = construct_sampler(node.get('sampler'))
  225. pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
  226. node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
  227. elif dataset_op == 'VOCDataset':
  228. sampler = construct_sampler(node.get('sampler'))
  229. pyobj = pyclass(node['dataset_dir'], node.get('task'), node.get('mode'), node.get('class_indexing'),
  230. node.get('num_samples'), node.get('num_parallel_workers'), node.get('shuffle'),
  231. node.get('decode'), sampler, node.get('num_shards'), node.get('shard_id'))
  232. elif dataset_op == 'CelebADataset':
  233. sampler = construct_sampler(node.get('sampler'))
  234. pyobj = pyclass(node['dataset_dir'], node.get('num_parallel_workers'), node.get('shuffle'),
  235. node.get('dataset_type'), sampler, node.get('decode'), node.get('extensions'),
  236. node.get('num_samples'), sampler, node.get('num_shards'), node.get('shard_id'))
  237. elif dataset_op == 'GeneratorDataset':
  238. # Serializing py function can be done using marshal library
  239. raise RuntimeError(dataset_op + " is not yet supported")
  240. elif dataset_op == 'RepeatDataset':
  241. pyobj = de.Dataset().repeat(node.get('count'))
  242. elif dataset_op == 'SkipDataset':
  243. pyobj = de.Dataset().skip(node.get('count'))
  244. elif dataset_op == 'TakeDataset':
  245. pyobj = de.Dataset().take(node.get('count'))
  246. elif dataset_op == 'MapDataset':
  247. tensor_ops = construct_tensor_ops(node.get('operations'))
  248. pyobj = de.Dataset().map(node.get('input_columns'), tensor_ops, node.get('output_columns'),
  249. node.get('columns_order'), node.get('num_parallel_workers'))
  250. elif dataset_op == 'ShuffleDataset':
  251. pyobj = de.Dataset().shuffle(node.get('buffer_size'))
  252. elif dataset_op == 'BatchDataset':
  253. pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder'))
  254. elif dataset_op == 'CacheDataset':
  255. # Member function cache() is not defined in class Dataset yet.
  256. raise RuntimeError(dataset_op + " is not yet supported")
  257. elif dataset_op == 'FilterDataset':
  258. # Member function filter() is not defined in class Dataset yet.
  259. raise RuntimeError(dataset_op + " is not yet supported")
  260. elif dataset_op == 'TakeDataset':
  261. # Member function take() is not defined in class Dataset yet.
  262. raise RuntimeError(dataset_op + " is not yet supported")
  263. elif dataset_op == 'ZipDataset':
  264. # Create ZipDataset instance, giving dummy input dataset that will be overrided in the caller.
  265. pyobj = de.ZipDataset((de.Dataset(), de.Dataset()))
  266. elif dataset_op == 'ConcatDataset':
  267. # Create ConcatDataset instance, giving dummy input dataset that will be overrided in the caller.
  268. pyobj = de.ConcatDataset((de.Dataset(), de.Dataset()))
  269. elif dataset_op == 'RenameDataset':
  270. pyobj = de.Dataset().rename(node['input_columns'], node['output_columns'])
  271. elif dataset_op == 'ProjectDataset':
  272. pyobj = de.Dataset().project(node['columns'])
  273. elif dataset_op == 'TransferDataset':
  274. pyobj = de.Dataset().to_device()
  275. else:
  276. raise RuntimeError(dataset_op + " is not yet supported by ds.engine.deserialize()")
  277. return pyobj
  278. def construct_sampler(in_sampler):
  279. """Instantiate Sampler object based on the information from dictionary['sampler']"""
  280. sampler = None
  281. if in_sampler is not None:
  282. sampler_name = in_sampler['sampler_name']
  283. sampler_module = in_sampler['sampler_module']
  284. sampler_class = getattr(sys.modules[sampler_module], sampler_name)
  285. if sampler_name == 'DistributedSampler':
  286. sampler = sampler_class(in_sampler['num_shards'], in_sampler['shard_id'], in_sampler.get('shuffle'))
  287. elif sampler_name == 'PKSampler':
  288. sampler = sampler_class(in_sampler['num_val'], in_sampler.get('num_class'), in_sampler('shuffle'))
  289. elif sampler_name == 'RandomSampler':
  290. sampler = sampler_class(in_sampler.get('replacement'), in_sampler.get('num_samples'))
  291. elif sampler_name == 'SequentialSampler':
  292. sampler = sampler_class()
  293. elif sampler_name == 'SubsetRandomSampler':
  294. sampler = sampler_class(in_sampler['indices'])
  295. elif sampler_name == 'WeightedRandomSampler':
  296. sampler = sampler_class(in_sampler['weights'], in_sampler['num_samples'], in_sampler.get('replacement'))
  297. else:
  298. raise ValueError("Sampler type is unknown: " + sampler_name)
  299. return sampler
  300. def construct_tensor_ops(operations):
  301. """Instantiate tensor op object(s) based on the information from dictionary['operations']"""
  302. result = []
  303. for op in operations:
  304. op_module = op['tensor_op_module']
  305. op_name = op['tensor_op_name']
  306. op_class = getattr(sys.modules[op_module], op_name)
  307. if op_name == 'Decode':
  308. result.append(op_class(op.get('rgb')))
  309. elif op_name == 'Normalize':
  310. result.append(op_class(op['mean'], op['std']))
  311. elif op_name == 'RandomCrop':
  312. result.append(op_class(op['size'], op.get('padding'), op.get('pad_if_needed'),
  313. op.get('fill_value'), Border(op.get('padding_mode'))))
  314. elif op_name == 'RandomHorizontalFlip':
  315. result.append(op_class(op.get('prob')))
  316. elif op_name == 'RandomVerticalFlip':
  317. result.append(op_class(op.get('prob')))
  318. elif op_name == 'Resize':
  319. result.append(op_class(op['size'], Inter(op.get('interpolation'))))
  320. elif op_name == 'RandomResizedCrop':
  321. result.append(op_class(op['size'], op.get('scale'), op.get('ratio'),
  322. Inter(op.get('interpolation')), op.get('max_attempts')))
  323. elif op_name == 'CenterCrop':
  324. result.append(op_class(op['size']))
  325. elif op_name == 'RandomColorAdjust':
  326. result.append(op_class(op.get('brightness'), op.get('contrast'), op.get('saturation'),
  327. op.get('hue')))
  328. elif op_name == 'RandomRotation':
  329. result.append(op_class(op['degree'], op.get('resample'), op.get('expand'),
  330. op.get('center'), op.get('fill_value')))
  331. elif op_name == 'Rescale':
  332. result.append(op_class(op['rescale'], op['shift']))
  333. elif op_name == 'RandomResize':
  334. result.append(op_class(op['size']))
  335. elif op_name == 'TypeCast':
  336. result.append(op_class(op['data_type']))
  337. elif op_name == 'HWC2CHW':
  338. result.append(op_class())
  339. elif op_name == 'CHW2HWC':
  340. raise ValueError("Tensor op is not supported: " + op_name)
  341. elif op_name == 'OneHot':
  342. result.append(op_class(op['num_classes']))
  343. elif op_name == 'RandomCropDecodeResize':
  344. result.append(op_class(op['size'], op.get('scale'), op.get('ratio'),
  345. Inter(op.get('interpolation')), op.get('max_attempts')))
  346. elif op_name == 'Pad':
  347. result.append(op_class(op['padding'], op['fill_value'], Border(op['padding_mode'])))
  348. else:
  349. raise ValueError("Tensor op name is unknown: " + op_name)
  350. return result