You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

serializer_deserializer.py 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Functions to support dataset serialize and deserialize.
  17. """
  18. import json
  19. import os
  20. import sys
  21. from mindspore import log as logger
  22. from . import datasets as de
  23. from ..vision.utils import Inter, Border
  24. from ..core import config
  25. def serialize(dataset, json_filepath=None):
  26. """
  27. Serialize dataset pipeline into a json file.
  28. Args:
  29. dataset (Dataset): the starting node.
  30. json_filepath (str): a filepath where a serialized json file will be generated.
  31. Returns:
  32. dict containing the serialized dataset graph.
  33. Raises:
  34. OSError cannot open a file
  35. Examples:
  36. >>> import mindspore.dataset as ds
  37. >>> import mindspore.dataset.transforms.c_transforms as C
  38. >>> DATA_DIR = "../../data/testMnistData"
  39. >>> data = ds.MnistDataset(DATA_DIR, 100)
  40. >>> one_hot_encode = C.OneHot(10) # num_classes is input argument
  41. >>> data = data.map(operation=one_hot_encode, input_column_names="label")
  42. >>> data = data.batch(batch_size=10, drop_remainder=True)
  43. >>>
  44. >>> ds.engine.serialize(data, json_filepath="mnist_dataset_pipeline.json") # serialize it to json file
  45. >>> serialized_data = ds.engine.serialize(data) # serialize it to Python dict
  46. """
  47. serialized_pipeline = traverse(dataset)
  48. if json_filepath:
  49. with open(json_filepath, 'w') as json_file:
  50. json.dump(serialized_pipeline, json_file, indent=2)
  51. return serialized_pipeline
  52. def deserialize(input_dict=None, json_filepath=None):
  53. """
  54. Construct a de pipeline from a json file produced by de.serialize().
  55. Args:
  56. input_dict (dict): a Python dictionary containing a serialized dataset graph
  57. json_filepath (str): a path to the json file.
  58. Returns:
  59. de.Dataset or None if error occurs.
  60. Raises:
  61. OSError cannot open a file.
  62. Examples:
  63. >>> import mindspore.dataset as ds
  64. >>> import mindspore.dataset.transforms.c_transforms as C
  65. >>> DATA_DIR = "../../data/testMnistData"
  66. >>> data = ds.MnistDataset(DATA_DIR, 100)
  67. >>> one_hot_encode = C.OneHot(10) # num_classes is input argument
  68. >>> data = data.map(operation=one_hot_encode, input_column_names="label")
  69. >>> data = data.batch(batch_size=10, drop_remainder=True)
  70. >>>
  71. >>> # Use case 1: to/from json file
  72. >>> ds.engine.serialize(data, json_filepath="mnist_dataset_pipeline.json")
  73. >>> data = ds.engine.deserialize(json_filepath="mnist_dataset_pipeline.json")
  74. >>> # Use case 2: to/from Python dictionary
  75. >>> serialized_data = ds.engine.serialize(data)
  76. >>> data = ds.engine.deserialize(input_dict=serialized_data)
  77. """
  78. data = None
  79. if input_dict:
  80. data = construct_pipeline(input_dict)
  81. if json_filepath:
  82. dict_pipeline = dict()
  83. with open(json_filepath, 'r') as json_file:
  84. dict_pipeline = json.load(json_file)
  85. data = construct_pipeline(dict_pipeline)
  86. return data
  87. def expand_path(node_repr, key, val):
  88. """Convert relative to absolute path."""
  89. if isinstance(val, list):
  90. node_repr[key] = [os.path.abspath(file) for file in val]
  91. else:
  92. node_repr[key] = os.path.abspath(val)
  93. def serialize_operations(node_repr, key, val):
  94. """Serialize tensor op (Python object) to dictionary."""
  95. if isinstance(val, list):
  96. node_repr[key] = []
  97. for op in val:
  98. node_repr[key].append(op.__dict__)
  99. # Extracting module and name information from a Python object
  100. # Example: tensor_op_module is 'minddata.transforms.c_transforms' and tensor_op_name is 'Decode'
  101. node_repr[key][-1]['tensor_op_name'] = type(op).__name__
  102. node_repr[key][-1]['tensor_op_module'] = type(op).__module__
  103. else:
  104. node_repr[key] = val.__dict__
  105. node_repr[key]['tensor_op_name'] = type(val).__name__
  106. node_repr[key]['tensor_op_module'] = type(val).__module__
  107. def serialize_sampler(node_repr, val):
  108. """Serialize sampler object to dictionary."""
  109. if val is None:
  110. node_repr['sampler'] = None
  111. else:
  112. node_repr['sampler'] = val.__dict__
  113. node_repr['sampler']['sampler_module'] = type(val).__module__
  114. node_repr['sampler']['sampler_name'] = type(val).__name__
  115. def traverse(node):
  116. """Pre-order traverse the pipeline and capture the information as we go."""
  117. # Node representation (node_repr) is a Python dictionary that capture and store the
  118. # dataset pipeline information before dumping it to JSON or other format.
  119. node_repr = dict()
  120. node_repr['op_type'] = type(node).__name__
  121. node_repr['op_module'] = type(node).__module__
  122. # Start with an empty list of children, will be added later as we traverse this node.
  123. node_repr["children"] = []
  124. # Retrieve the information about the current node. It should include arguments
  125. # passed to the node during object construction.
  126. node_args = node.get_args()
  127. for k, v in node_args.items():
  128. # Store the information about this node into node_repr.
  129. # Further serialize the object in the arguments if needed.
  130. if k == 'operations':
  131. serialize_operations(node_repr, k, v)
  132. elif k == 'sampler':
  133. serialize_sampler(node_repr, v)
  134. elif k == 'padded_sample' and v:
  135. v1 = {key: value for key, value in v.items() if not isinstance(value, bytes)}
  136. node_repr[k] = json.dumps(v1, indent=2)
  137. # return schema json str if its type is mindspore.dataset.Schema
  138. elif k == 'schema' and isinstance(v, de.Schema):
  139. node_repr[k] = v.to_json()
  140. elif k in set(['schema', 'dataset_files', 'dataset_dir', 'schema_file_path']):
  141. expand_path(node_repr, k, v)
  142. elif k == "num_parallel_workers" and v is None:
  143. node_repr[k] = config.get_num_parallel_workers()
  144. else:
  145. node_repr[k] = v
  146. # If a sampler exists in this node, then the following 4 arguments must be set to None:
  147. # num_samples, shard_id, num_shards, shuffle
  148. # These arguments get moved into the sampler itself, so they are no longer needed to
  149. # be set at the dataset level.
  150. # TF Record is a special case because it uses both the dataset and sampler arguments
  151. # which is not decided until later during tree preparation phase.
  152. if node_repr['op_type'] != 'TFRecordDataset' and 'sampler' in node_args.keys():
  153. if 'num_samples' in node_repr.keys():
  154. node_repr['num_samples'] = None
  155. if 'shuffle' in node_repr.keys():
  156. node_repr['shuffle'] = None
  157. if 'num_shards' in node_repr.keys():
  158. node_repr['num_shards'] = None
  159. if 'shard_id' in node_repr.keys():
  160. node_repr['shard_id'] = None
  161. # Leaf node doesn't have input attribute.
  162. if not node.children:
  163. return node_repr
  164. # Recursively traverse the child and assign it to the current node_repr['children'].
  165. for child in node.children:
  166. node_repr["children"].append(traverse(child))
  167. return node_repr
  168. def show(dataset, indentation=2):
  169. """
  170. Write the dataset pipeline graph onto logger.info.
  171. Args:
  172. dataset (Dataset): the starting node.
  173. indentation (int, optional): indentation used by the json print. Pass None to not indent.
  174. """
  175. pipeline = traverse(dataset)
  176. logger.info(json.dumps(pipeline, indent=indentation))
  177. def compare(pipeline1, pipeline2):
  178. """
  179. Compare if two dataset pipelines are the same.
  180. Args:
  181. pipeline1 (Dataset): a dataset pipeline.
  182. pipeline2 (Dataset): a dataset pipeline.
  183. """
  184. return traverse(pipeline1) == traverse(pipeline2)
  185. def construct_pipeline(node):
  186. """Construct the Python Dataset objects by following the dictionary deserialized from json file."""
  187. op_type = node.get('op_type')
  188. if not op_type:
  189. raise ValueError("op_type field in the json file can't be None.")
  190. # Instantiate Python Dataset object based on the current dictionary element
  191. dataset = create_node(node)
  192. # Initially it is not connected to any other object.
  193. dataset.children = []
  194. # Construct the children too and add edge between the children and parent.
  195. for child in node['children']:
  196. dataset.children.append(construct_pipeline(child))
  197. return dataset
  198. def create_node(node):
  199. """Parse the key, value in the node dictionary and instantiate the Python Dataset object"""
  200. logger.info('creating node: %s', node['op_type'])
  201. dataset_op = node['op_type']
  202. op_module = node['op_module']
  203. # Get the Python class to be instantiated.
  204. # Example:
  205. # "op_type": "MapDataset",
  206. # "op_module": "mindspore.dataset.datasets",
  207. pyclass = getattr(sys.modules[op_module], dataset_op)
  208. pyobj = None
  209. # Find a matching Dataset class and call the constructor with the corresponding args.
  210. # When a new Dataset class is introduced, another if clause and parsing code needs to be added.
  211. if dataset_op == 'ImageFolderDataset':
  212. sampler = construct_sampler(node.get('sampler'))
  213. pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
  214. node.get('shuffle'), sampler, node.get('extensions'),
  215. node.get('class_indexing'), node.get('decode'), node.get('num_shards'),
  216. node.get('shard_id'))
  217. elif dataset_op == 'RangeDataset':
  218. pyobj = pyclass(node['start'], node['stop'], node['step'])
  219. elif dataset_op == 'ImageFolderDataset':
  220. pyobj = pyclass(node['dataset_dir'], node['schema'], node.get('distribution'),
  221. node.get('column_list'), node.get('num_parallel_workers'),
  222. node.get('deterministic_output'), node.get('prefetch_size'),
  223. node.get('labels_filename'), node.get('dataset_usage'))
  224. elif dataset_op == 'MnistDataset':
  225. sampler = construct_sampler(node.get('sampler'))
  226. pyobj = pyclass(node['dataset_dir'], node['usage'], node.get('num_samples'), node.get('num_parallel_workers'),
  227. node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
  228. elif dataset_op == 'MindDataset':
  229. sampler = construct_sampler(node.get('sampler'))
  230. pyobj = pyclass(node['dataset_file'], node.get('columns_list'),
  231. node.get('num_parallel_workers'), node.get('seed'), node.get('num_shards'),
  232. node.get('shard_id'), sampler)
  233. elif dataset_op == 'TFRecordDataset':
  234. pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'),
  235. node.get('num_samples'), node.get('num_parallel_workers'),
  236. de.Shuffle(node.get('shuffle')), node.get('num_shards'), node.get('shard_id'))
  237. elif dataset_op == 'ManifestDataset':
  238. sampler = construct_sampler(node.get('sampler'))
  239. pyobj = pyclass(node['dataset_file'], node['usage'], node.get('num_samples'),
  240. node.get('num_parallel_workers'), node.get('shuffle'), sampler,
  241. node.get('class_indexing'), node.get('decode'), node.get('num_shards'),
  242. node.get('shard_id'))
  243. elif dataset_op == 'Cifar10Dataset':
  244. sampler = construct_sampler(node.get('sampler'))
  245. pyobj = pyclass(node['dataset_dir'], node['usage'], node.get('num_samples'), node.get('num_parallel_workers'),
  246. node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
  247. elif dataset_op == 'Cifar100Dataset':
  248. sampler = construct_sampler(node.get('sampler'))
  249. pyobj = pyclass(node['dataset_dir'], node['usage'], node.get('num_samples'), node.get('num_parallel_workers'),
  250. node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'))
  251. elif dataset_op == 'VOCDataset':
  252. sampler = construct_sampler(node.get('sampler'))
  253. pyobj = pyclass(node['dataset_dir'], node.get('task'), node.get('mode'), node.get('class_indexing'),
  254. node.get('num_samples'), node.get('num_parallel_workers'), node.get('shuffle'),
  255. node.get('decode'), sampler, node.get('num_shards'), node.get('shard_id'))
  256. elif dataset_op == 'CocoDataset':
  257. sampler = construct_sampler(node.get('sampler'))
  258. pyobj = pyclass(node['dataset_dir'], node.get('annotation_file'), node.get('task'), node.get('num_samples'),
  259. node.get('num_parallel_workers'), node.get('shuffle'), node.get('decode'), sampler,
  260. node.get('num_shards'), node.get('shard_id'))
  261. elif dataset_op == 'CelebADataset':
  262. sampler = construct_sampler(node.get('sampler'))
  263. pyobj = pyclass(node['dataset_dir'], node.get('num_parallel_workers'), node.get('shuffle'),
  264. node.get('dataset_type'), sampler, node.get('decode'), node.get('extensions'),
  265. node.get('num_samples'), sampler, node.get('num_shards'), node.get('shard_id'))
  266. elif dataset_op == 'GeneratorDataset':
  267. # Serializing py function can be done using marshal library
  268. raise RuntimeError(dataset_op + " is not yet supported")
  269. elif dataset_op == 'RepeatDataset':
  270. pyobj = de.Dataset().repeat(node.get('count'))
  271. elif dataset_op == 'SkipDataset':
  272. pyobj = de.Dataset().skip(node.get('count'))
  273. elif dataset_op == 'TakeDataset':
  274. pyobj = de.Dataset().take(node.get('count'))
  275. elif dataset_op == 'MapDataset':
  276. tensor_ops = construct_tensor_ops(node.get('operations'))
  277. pyobj = de.Dataset().map(tensor_ops, node.get('input_columns'), node.get('output_columns'),
  278. node.get('column_order'), node.get('num_parallel_workers'))
  279. elif dataset_op == 'ShuffleDataset':
  280. pyobj = de.Dataset().shuffle(node.get('buffer_size'))
  281. elif dataset_op == 'BatchDataset':
  282. pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder'))
  283. elif dataset_op == 'CacheDataset':
  284. # Member function cache() is not defined in class Dataset yet.
  285. raise RuntimeError(dataset_op + " is not yet supported")
  286. elif dataset_op == 'FilterDataset':
  287. # Member function filter() is not defined in class Dataset yet.
  288. raise RuntimeError(dataset_op + " is not yet supported")
  289. elif dataset_op == 'TakeDataset':
  290. # Member function take() is not defined in class Dataset yet.
  291. raise RuntimeError(dataset_op + " is not yet supported")
  292. elif dataset_op == 'ZipDataset':
  293. # Create ZipDataset instance, giving dummy input dataset that will be overrided in the caller.
  294. pyobj = de.ZipDataset((de.Dataset(), de.Dataset()))
  295. elif dataset_op == 'ConcatDataset':
  296. # Create ConcatDataset instance, giving dummy input dataset that will be overrided in the caller.
  297. pyobj = de.ConcatDataset((de.Dataset(), de.Dataset()))
  298. elif dataset_op == 'RenameDataset':
  299. pyobj = de.Dataset().rename(node['input_columns'], node['output_columns'])
  300. elif dataset_op == 'ProjectDataset':
  301. pyobj = de.Dataset().project(node['columns'])
  302. elif dataset_op == 'TransferDataset':
  303. pyobj = de.Dataset().to_device()
  304. else:
  305. raise RuntimeError(dataset_op + " is not yet supported by ds.engine.deserialize()")
  306. return pyobj
  307. def construct_sampler(in_sampler):
  308. """Instantiate Sampler object based on the information from dictionary['sampler']"""
  309. sampler = None
  310. if in_sampler is not None:
  311. sampler_name = in_sampler['sampler_name']
  312. sampler_module = in_sampler['sampler_module']
  313. sampler_class = getattr(sys.modules[sampler_module], sampler_name)
  314. if sampler_name == 'DistributedSampler':
  315. sampler = sampler_class(in_sampler['num_shards'], in_sampler['shard_id'], in_sampler.get('shuffle'))
  316. elif sampler_name == 'PKSampler':
  317. sampler = sampler_class(in_sampler['num_val'], in_sampler.get('num_class'), in_sampler('shuffle'))
  318. elif sampler_name == 'RandomSampler':
  319. sampler = sampler_class(in_sampler.get('replacement'), in_sampler.get('num_samples'))
  320. elif sampler_name == 'SequentialSampler':
  321. sampler = sampler_class()
  322. elif sampler_name == 'SubsetRandomSampler':
  323. sampler = sampler_class(in_sampler['indices'])
  324. elif sampler_name == 'WeightedRandomSampler':
  325. sampler = sampler_class(in_sampler['weights'], in_sampler['num_samples'], in_sampler.get('replacement'))
  326. else:
  327. raise ValueError("Sampler type is unknown: " + sampler_name)
  328. return sampler
  329. def construct_tensor_ops(operations):
  330. """Instantiate tensor op object(s) based on the information from dictionary['operations']"""
  331. result = []
  332. for op in operations:
  333. op_module = op['tensor_op_module']
  334. op_name = op['tensor_op_name']
  335. op_class = getattr(sys.modules[op_module], op_name)
  336. if op_name == 'Decode':
  337. result.append(op_class(op.get('rgb')))
  338. elif op_name == 'Normalize':
  339. result.append(op_class(op['mean'], op['std']))
  340. elif op_name == 'RandomCrop':
  341. result.append(op_class(op['size'], op.get('padding'), op.get('pad_if_needed'),
  342. op.get('fill_value'), Border(op.get('padding_mode'))))
  343. elif op_name == 'RandomHorizontalFlip':
  344. result.append(op_class(op.get('prob')))
  345. elif op_name == 'RandomVerticalFlip':
  346. result.append(op_class(op.get('prob')))
  347. elif op_name == 'Resize':
  348. result.append(op_class(op['size'], Inter(op.get('interpolation'))))
  349. elif op_name == 'RandomResizedCrop':
  350. result.append(op_class(op['size'], op.get('scale'), op.get('ratio'),
  351. Inter(op.get('interpolation')), op.get('max_attempts')))
  352. elif op_name == 'CenterCrop':
  353. result.append(op_class(op['size']))
  354. elif op_name == 'RandomColorAdjust':
  355. result.append(op_class(op.get('brightness'), op.get('contrast'), op.get('saturation'),
  356. op.get('hue')))
  357. elif op_name == 'RandomRotation':
  358. result.append(op_class(op['degree'], op.get('resample'), op.get('expand'),
  359. op.get('center'), op.get('fill_value')))
  360. elif op_name == 'Rescale':
  361. result.append(op_class(op['rescale'], op['shift']))
  362. elif op_name == 'RandomResize':
  363. result.append(op_class(op['size']))
  364. elif op_name == 'TypeCast':
  365. result.append(op_class(op['data_type']))
  366. elif op_name == 'HWC2CHW':
  367. result.append(op_class())
  368. elif op_name == 'CHW2HWC':
  369. raise ValueError("Tensor op is not supported: " + op_name)
  370. elif op_name == 'OneHot':
  371. result.append(op_class(op['num_classes']))
  372. elif op_name == 'RandomCropDecodeResize':
  373. result.append(op_class(op['size'], op.get('scale'), op.get('ratio'),
  374. Inter(op.get('interpolation')), op.get('max_attempts')))
  375. elif op_name == 'Pad':
  376. result.append(op_class(op['padding'], op['fill_value'], Border(op['padding_mode'])))
  377. else:
  378. raise ValueError("Tensor op name is unknown: " + op_name)
  379. return result