You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

serializer_deserializer.py 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Functions to support dataset serialize and deserialize.
  17. """
  18. import json
  19. import os
  20. import sys
  21. from mindspore import log as logger
  22. from . import datasets as de
  23. from ..transforms.vision.utils import Inter, Border
  24. from ..core.configuration import config
  25. def serialize(dataset, json_filepath=None):
  26. """
  27. Serialize dataset pipeline into a json file.
  28. Args:
  29. dataset (Dataset): the starting node.
  30. json_filepath (string): a filepath where a serialized json file will be generated.
  31. Returns:
  32. dict containing the serialized dataset graph.
  33. Raises:
  34. OSError cannot open a file
  35. Examples:
  36. >>> import mindspore.dataset as ds
  37. >>> import mindspore.dataset.transforms.c_transforms as C
  38. >>> DATA_DIR = "../../data/testMnistData"
  39. >>> data = ds.MnistDataset(DATA_DIR, 100)
  40. >>> one_hot_encode = C.OneHot(10) # num_classes is input argument
  41. >>> data = data.map(input_column_names="label", operation=one_hot_encode)
  42. >>> data = data.batch(batch_size=10, drop_remainder=True)
  43. >>>
  44. >>> ds.engine.serialize(data, json_filepath="mnist_dataset_pipeline.json") # serialize it to json file
  45. >>> serialized_data = ds.engine.serialize(data) # serialize it to python dict
  46. """
  47. serialized_pipeline = traverse(dataset)
  48. if json_filepath:
  49. with open(json_filepath, 'w') as json_file:
  50. json.dump(serialized_pipeline, json_file, indent=2)
  51. return serialized_pipeline
  52. def deserialize(input_dict=None, json_filepath=None):
  53. """
  54. Construct a de pipeline from a json file produced by de.serialize().
  55. Args:
  56. input_dict (dict): a python dictionary containing a serialized dataset graph
  57. json_filepath (string): a path to the json file.
  58. Returns:
  59. de.Dataset or None if error occurs.
  60. Raises:
  61. OSError cannot open a file.
  62. Examples:
  63. >>> import mindspore.dataset as ds
  64. >>> import mindspore.dataset.transforms.c_transforms as C
  65. >>> DATA_DIR = "../../data/testMnistData"
  66. >>> data = ds.MnistDataset(DATA_DIR, 100)
  67. >>> one_hot_encode = C.OneHot(10) # num_classes is input argument
  68. >>> data = data.map(input_column_names="label", operation=one_hot_encode)
  69. >>> data = data.batch(batch_size=10, drop_remainder=True)
  70. >>>
  71. >>> # Use case 1: to/from json file
  72. >>> ds.engine.serialize(data, json_filepath="mnist_dataset_pipeline.json")
  73. >>> data = ds.engine.deserialize(json_filepath="mnist_dataset_pipeline.json")
  74. >>> # Use case 2: to/from python dictionary
  75. >>> serialized_data = ds.engine.serialize(data)
  76. >>> data = ds.engine.deserialize(input_dict=serialized_data)
  77. """
  78. data = None
  79. if input_dict:
  80. data = construct_pipeline(input_dict)
  81. if json_filepath:
  82. dict_pipeline = dict()
  83. with open(json_filepath, 'r') as json_file:
  84. dict_pipeline = json.load(json_file)
  85. data = construct_pipeline(dict_pipeline)
  86. return data
  87. def expand_path(node_repr, key, val):
  88. """Convert relative to absolute path."""
  89. if isinstance(val, list):
  90. node_repr[key] = [os.path.abspath(file) for file in val]
  91. else:
  92. node_repr[key] = os.path.abspath(val)
  93. def serialize_operations(node_repr, key, val):
  94. """Serialize tensor op (python object) to dictionary."""
  95. if isinstance(val, list):
  96. node_repr[key] = []
  97. for op in val:
  98. node_repr[key].append(op.__dict__)
  99. # Extracting module and name information from a python object
  100. # Example: tensor_op_module is 'minddata.transforms.c_transforms' and tensor_op_name is 'Decode'
  101. node_repr[key][-1]['tensor_op_name'] = type(op).__name__
  102. node_repr[key][-1]['tensor_op_module'] = type(op).__module__
  103. else:
  104. node_repr[key] = val.__dict__
  105. node_repr[key]['tensor_op_name'] = type(val).__name__
  106. node_repr[key]['tensor_op_module'] = type(val).__module__
  107. def serialize_sampler(node_repr, val):
  108. """Serialize sampler object to dictionary."""
  109. if val is None:
  110. node_repr['sampler'] = None
  111. else:
  112. node_repr['sampler'] = val.__dict__
  113. node_repr['sampler']['sampler_module'] = type(val).__module__
  114. node_repr['sampler']['sampler_name'] = type(val).__name__
  115. def traverse(node):
  116. """Pre-order traverse the pipeline and capture the information as we go."""
  117. # Node representation (node_repr) is a python dictionary that capture and store the
  118. # dataset pipeline information before dumping it to JSON or other format.
  119. node_repr = dict()
  120. node_repr['op_type'] = type(node).__name__
  121. node_repr['op_module'] = type(node).__module__
  122. # Start with an empty list of children, will be added later as we traverse this node.
  123. node_repr["children"] = []
  124. # Retrieve the information about the current node. It should include arguments
  125. # passed to the node during object construction.
  126. node_args = node.get_args()
  127. for k, v in node_args.items():
  128. # Store the information about this node into node_repr.
  129. # Further serialize the object in the arguments if needed.
  130. if k == 'operations':
  131. serialize_operations(node_repr, k, v)
  132. elif k == 'sampler':
  133. serialize_sampler(node_repr, v)
  134. elif k == 'padded_sample' and v:
  135. v1 = {key: value for key, value in v.items() if not isinstance(value, bytes)}
  136. node_repr[k] = json.dumps(v1, indent=2)
  137. # return schema json str if its type is mindspore.dataset.Schema
  138. elif k == 'schema' and isinstance(v, de.Schema):
  139. node_repr[k] = v.to_json()
  140. elif k in set(['schema', 'dataset_files', 'dataset_dir', 'schema_file_path']):
  141. expand_path(node_repr, k, v)
  142. elif k == "num_parallel_workers" and v is None:
  143. node_repr[k] = config.get_num_parallel_workers()
  144. else:
  145. node_repr[k] = v
  146. # If a sampler exists in this node, then the following 4 arguments must be set to None:
  147. # num_samples, shard_id, num_shards, shuffle
  148. # These arguments get moved into the sampler itself, so they are no longer needed to
  149. # be set at the dataset level.
  150. if 'sampler' in node_args.keys():
  151. if 'num_samples' in node_repr.keys():
  152. node_repr['num_samples'] = None
  153. if 'shuffle' in node_repr.keys():
  154. node_repr['shuffle'] = None
  155. if 'num_shards' in node_repr.keys():
  156. node_repr['num_shards'] = None
  157. if 'shard_id' in node_repr.keys():
  158. node_repr['shard_id'] = None
  159. # Leaf node doesn't have input attribute.
  160. if not node.children:
  161. return node_repr
  162. # Recursively traverse the child and assign it to the current node_repr['children'].
  163. for child in node.children:
  164. node_repr["children"].append(traverse(child))
  165. return node_repr
  166. def show(dataset, indentation=2):
  167. """
  168. Write the dataset pipeline graph onto logger.info.
  169. Args:
  170. dataset (Dataset): the starting node.
  171. indentation (int, optional): indentation used by the json print. Pass None to not indent.
  172. """
  173. pipeline = traverse(dataset)
  174. logger.info(json.dumps(pipeline, indent=indentation))
  175. def compare(pipeline1, pipeline2):
  176. """
  177. Compare if two dataset pipelines are the same.
  178. Args:
  179. pipeline1 (Dataset): a dataset pipeline.
  180. pipeline2 (Dataset): a dataset pipeline.
  181. """
  182. return traverse(pipeline1) == traverse(pipeline2)
  183. def construct_pipeline(node):
  184. """Construct the python Dataset objects by following the dictionary deserialized from json file."""
  185. op_type = node.get('op_type')
  186. if not op_type:
  187. raise ValueError("op_type field in the json file can't be None.")
  188. # Instantiate python Dataset object based on the current dictionary element
  189. dataset = create_node(node)
  190. # Initially it is not connected to any other object.
  191. dataset.children = []
  192. # Construct the children too and add edge between the children and parent.
  193. for child in node['children']:
  194. dataset.children.append(construct_pipeline(child))
  195. return dataset
  196. def create_node(node):
  197. """Parse the key, value in the node dictionary and instantiate the python Dataset object"""
  198. logger.info('creating node: %s', node['op_type'])
  199. dataset_op = node['op_type']
  200. op_module = node['op_module']
  201. # Get the python class to be instantiated.
  202. # Example:
  203. # "op_type": "MapDataset",
  204. # "op_module": "mindspore.dataset.datasets",
  205. pyclass = getattr(sys.modules[op_module], dataset_op)
  206. pyobj = None
  207. # Find a matching Dataset class and call the constructor with the corresponding args.
  208. # When a new Dataset class is introduced, another if clause and parsing code needs to be added.
  209. if dataset_op == 'ImageFolderDatasetV2':
  210. sampler = construct_sampler(node.get('sampler'))
  211. pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
  212. node.get('shuffle'), sampler, node.get('extensions'),
  213. node.get('class_indexing'), node.get('decode'), node.get('num_shards'),
  214. node.get('shard_id'))
  215. elif dataset_op == 'RangeDataset':
  216. pyobj = pyclass(node['start'], node['stop'], node['step'])
  217. elif dataset_op == 'ImageFolderDataset':
  218. pyobj = pyclass(node['dataset_dir'], node['schema'], node.get('distribution'),
  219. node.get('column_list'), node.get('num_parallel_workers'),
  220. node.get('deterministic_output'), node.get('prefetch_size'),
  221. node.get('labels_filename'), node.get('dataset_usage'))
  222. elif dataset_op == 'MnistDataset':
  223. sampler = construct_sampler(node.get('sampler'))
  224. pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
  225. node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
  226. elif dataset_op == 'MindDataset':
  227. sampler = construct_sampler(node.get('sampler'))
  228. pyobj = pyclass(node['dataset_file'], node.get('columns_list'),
  229. node.get('num_parallel_workers'), node.get('seed'), node.get('num_shards'),
  230. node.get('shard_id'), node.get('block_reader'), sampler)
  231. elif dataset_op == 'TFRecordDataset':
  232. pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'),
  233. node.get('num_samples'), node.get('num_parallel_workers'),
  234. de.Shuffle(node.get('shuffle')), node.get('num_shards'), node.get('shard_id'))
  235. elif dataset_op == 'ManifestDataset':
  236. sampler = construct_sampler(node.get('sampler'))
  237. pyobj = pyclass(node['dataset_file'], node['usage'], node.get('num_samples'),
  238. node.get('num_parallel_workers'), node.get('shuffle'), sampler,
  239. node.get('class_indexing'), node.get('decode'), node.get('num_shards'),
  240. node.get('shard_id'))
  241. elif dataset_op == 'Cifar10Dataset':
  242. sampler = construct_sampler(node.get('sampler'))
  243. pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
  244. node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
  245. elif dataset_op == 'Cifar100Dataset':
  246. sampler = construct_sampler(node.get('sampler'))
  247. pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
  248. node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
  249. elif dataset_op == 'VOCDataset':
  250. sampler = construct_sampler(node.get('sampler'))
  251. pyobj = pyclass(node['dataset_dir'], node.get('task'), node.get('mode'), node.get('class_indexing'),
  252. node.get('num_samples'), node.get('num_parallel_workers'), node.get('shuffle'),
  253. node.get('decode'), sampler, node.get('num_shards'), node.get('shard_id'))
  254. elif dataset_op == 'CocoDataset':
  255. sampler = construct_sampler(node.get('sampler'))
  256. pyobj = pyclass(node['dataset_dir'], node.get('annotation_file'), node.get('task'), node.get('num_samples'),
  257. node.get('num_parallel_workers'), node.get('shuffle'), node.get('decode'), sampler,
  258. node.get('num_shards'), node.get('shard_id'))
  259. elif dataset_op == 'CelebADataset':
  260. sampler = construct_sampler(node.get('sampler'))
  261. pyobj = pyclass(node['dataset_dir'], node.get('num_parallel_workers'), node.get('shuffle'),
  262. node.get('dataset_type'), sampler, node.get('decode'), node.get('extensions'),
  263. node.get('num_samples'), sampler, node.get('num_shards'), node.get('shard_id'))
  264. elif dataset_op == 'GeneratorDataset':
  265. # Serializing py function can be done using marshal library
  266. raise RuntimeError(dataset_op + " is not yet supported")
  267. elif dataset_op == 'RepeatDataset':
  268. pyobj = de.Dataset().repeat(node.get('count'))
  269. elif dataset_op == 'SkipDataset':
  270. pyobj = de.Dataset().skip(node.get('count'))
  271. elif dataset_op == 'TakeDataset':
  272. pyobj = de.Dataset().take(node.get('count'))
  273. elif dataset_op == 'MapDataset':
  274. tensor_ops = construct_tensor_ops(node.get('operations'))
  275. pyobj = de.Dataset().map(node.get('input_columns'), tensor_ops, node.get('output_columns'),
  276. node.get('columns_order'), node.get('num_parallel_workers'))
  277. elif dataset_op == 'ShuffleDataset':
  278. pyobj = de.Dataset().shuffle(node.get('buffer_size'))
  279. elif dataset_op == 'BatchDataset':
  280. pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder'))
  281. elif dataset_op == 'CacheDataset':
  282. # Member function cache() is not defined in class Dataset yet.
  283. raise RuntimeError(dataset_op + " is not yet supported")
  284. elif dataset_op == 'FilterDataset':
  285. # Member function filter() is not defined in class Dataset yet.
  286. raise RuntimeError(dataset_op + " is not yet supported")
  287. elif dataset_op == 'TakeDataset':
  288. # Member function take() is not defined in class Dataset yet.
  289. raise RuntimeError(dataset_op + " is not yet supported")
  290. elif dataset_op == 'ZipDataset':
  291. # Create ZipDataset instance, giving dummy input dataset that will be overrided in the caller.
  292. pyobj = de.ZipDataset((de.Dataset(), de.Dataset()))
  293. elif dataset_op == 'ConcatDataset':
  294. # Create ConcatDataset instance, giving dummy input dataset that will be overrided in the caller.
  295. pyobj = de.ConcatDataset((de.Dataset(), de.Dataset()))
  296. elif dataset_op == 'RenameDataset':
  297. pyobj = de.Dataset().rename(node['input_columns'], node['output_columns'])
  298. elif dataset_op == 'ProjectDataset':
  299. pyobj = de.Dataset().project(node['columns'])
  300. elif dataset_op == 'TransferDataset':
  301. pyobj = de.Dataset().to_device()
  302. else:
  303. raise RuntimeError(dataset_op + " is not yet supported by ds.engine.deserialize()")
  304. return pyobj
  305. def construct_sampler(in_sampler):
  306. """Instantiate Sampler object based on the information from dictionary['sampler']"""
  307. sampler = None
  308. if in_sampler is not None:
  309. sampler_name = in_sampler['sampler_name']
  310. sampler_module = in_sampler['sampler_module']
  311. sampler_class = getattr(sys.modules[sampler_module], sampler_name)
  312. if sampler_name == 'DistributedSampler':
  313. sampler = sampler_class(in_sampler['num_shards'], in_sampler['shard_id'], in_sampler.get('shuffle'))
  314. elif sampler_name == 'PKSampler':
  315. sampler = sampler_class(in_sampler['num_val'], in_sampler.get('num_class'), in_sampler('shuffle'))
  316. elif sampler_name == 'RandomSampler':
  317. sampler = sampler_class(in_sampler.get('replacement'), in_sampler.get('num_samples'))
  318. elif sampler_name == 'SequentialSampler':
  319. sampler = sampler_class()
  320. elif sampler_name == 'SubsetRandomSampler':
  321. sampler = sampler_class(in_sampler['indices'])
  322. elif sampler_name == 'WeightedRandomSampler':
  323. sampler = sampler_class(in_sampler['weights'], in_sampler['num_samples'], in_sampler.get('replacement'))
  324. else:
  325. raise ValueError("Sampler type is unknown: " + sampler_name)
  326. return sampler
  327. def construct_tensor_ops(operations):
  328. """Instantiate tensor op object(s) based on the information from dictionary['operations']"""
  329. result = []
  330. for op in operations:
  331. op_module = op['tensor_op_module']
  332. op_name = op['tensor_op_name']
  333. op_class = getattr(sys.modules[op_module], op_name)
  334. if op_name == 'Decode':
  335. result.append(op_class(op.get('rgb')))
  336. elif op_name == 'Normalize':
  337. result.append(op_class(op['mean'], op['std']))
  338. elif op_name == 'RandomCrop':
  339. result.append(op_class(op['size'], op.get('padding'), op.get('pad_if_needed'),
  340. op.get('fill_value'), Border(op.get('padding_mode'))))
  341. elif op_name == 'RandomHorizontalFlip':
  342. result.append(op_class(op.get('prob')))
  343. elif op_name == 'RandomVerticalFlip':
  344. result.append(op_class(op.get('prob')))
  345. elif op_name == 'Resize':
  346. result.append(op_class(op['size'], Inter(op.get('interpolation'))))
  347. elif op_name == 'RandomResizedCrop':
  348. result.append(op_class(op['size'], op.get('scale'), op.get('ratio'),
  349. Inter(op.get('interpolation')), op.get('max_attempts')))
  350. elif op_name == 'CenterCrop':
  351. result.append(op_class(op['size']))
  352. elif op_name == 'RandomColorAdjust':
  353. result.append(op_class(op.get('brightness'), op.get('contrast'), op.get('saturation'),
  354. op.get('hue')))
  355. elif op_name == 'RandomRotation':
  356. result.append(op_class(op['degree'], op.get('resample'), op.get('expand'),
  357. op.get('center'), op.get('fill_value')))
  358. elif op_name == 'Rescale':
  359. result.append(op_class(op['rescale'], op['shift']))
  360. elif op_name == 'RandomResize':
  361. result.append(op_class(op['size']))
  362. elif op_name == 'TypeCast':
  363. result.append(op_class(op['data_type']))
  364. elif op_name == 'HWC2CHW':
  365. result.append(op_class())
  366. elif op_name == 'CHW2HWC':
  367. raise ValueError("Tensor op is not supported: " + op_name)
  368. elif op_name == 'OneHot':
  369. result.append(op_class(op['num_classes']))
  370. elif op_name == 'RandomCropDecodeResize':
  371. result.append(op_class(op['size'], op.get('scale'), op.get('ratio'),
  372. Inter(op.get('interpolation')), op.get('max_attempts')))
  373. elif op_name == 'Pad':
  374. result.append(op_class(op['padding'], op['fill_value'], Border(op['padding_mode'])))
  375. else:
  376. raise ValueError("Tensor op name is unknown: " + op_name)
  377. return result