You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_dataset_graph.py 6.1 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define dataset graph related operations."""
  16. import json
  17. from importlib import import_module
  18. from mindspore import log as logger
  19. from mindspore.train import lineage_pb2
  20. class DatasetGraph:
  21. """Handle the data graph and packages it into binary data."""
  22. def package_dataset_graph(self, dataset):
  23. """
  24. packages dataset graph into binary data
  25. Args:
  26. dataset (MindDataset): Refer to MindDataset.
  27. Returns:
  28. DatasetGraph, a object of lineage_pb2.DatasetGraph.
  29. """
  30. dataset_package = import_module('mindspore.dataset')
  31. try:
  32. dataset_dict = dataset_package.serialize(dataset)
  33. except (TypeError, OSError) as exc:
  34. logger.warning("Summary can not collect dataset graph, there is an error in dataset internal, "
  35. "detail: %s.", str(exc))
  36. return None
  37. dataset_graph_proto = lineage_pb2.DatasetGraph()
  38. if not isinstance(dataset_dict, dict):
  39. logger.warning("The dataset graph serialized from dataset object is not a dict. "
  40. "Its type is %r.", type(dataset_dict).__name__)
  41. return dataset_graph_proto
  42. if "children" in dataset_dict:
  43. children = dataset_dict.pop("children")
  44. if children:
  45. self._package_children(children=children, message=dataset_graph_proto)
  46. self._package_current_dataset(operation=dataset_dict, message=dataset_graph_proto)
  47. return dataset_graph_proto
  48. def _package_children(self, children, message):
  49. """
  50. Package children in dataset operation.
  51. Args:
  52. children (list[dict]): Child operations.
  53. message (DatasetGraph): Children proto message.
  54. """
  55. for child in children:
  56. if child:
  57. child_graph_message = getattr(message, "children").add()
  58. grandson = child.pop("children")
  59. if grandson:
  60. self._package_children(children=grandson, message=child_graph_message)
  61. # package other parameters
  62. self._package_current_dataset(operation=child, message=child_graph_message)
  63. def _package_current_dataset(self, operation, message):
  64. """
  65. Package operation parameters in event message.
  66. Args:
  67. operation (dict): Operation dict.
  68. message (Operation): Operation proto message.
  69. """
  70. for key, value in operation.items():
  71. if value and key == "operations":
  72. for operator in value:
  73. self._package_enhancement_operation(
  74. operator,
  75. message.operations.add()
  76. )
  77. elif value and key == "sampler":
  78. self._package_enhancement_operation(
  79. value,
  80. message.sampler
  81. )
  82. else:
  83. self._package_parameter(key, value, message.parameter)
  84. def _package_enhancement_operation(self, operation, message):
  85. """
  86. Package enhancement operation in MapDataset.
  87. Args:
  88. operation (dict): Enhancement operation.
  89. message (Operation): Enhancement operation proto message.
  90. """
  91. if operation is None:
  92. logger.warning("Summary cannot collect the operation for dataset graph as the operation is none."
  93. "it may due to the custom operation cannot be pickled.")
  94. return
  95. for key, value in operation.items():
  96. if isinstance(value, (list, tuple)):
  97. if all(isinstance(ele, int) for ele in value):
  98. message.size.extend(value)
  99. else:
  100. message.weights.extend(value)
  101. else:
  102. self._package_parameter(key, value, message.operationParam)
  103. @staticmethod
  104. def _package_parameter(key, value, message):
  105. """
  106. Package parameters in operation.
  107. Args:
  108. key (str): Operation name.
  109. value (Union[str, bool, int, float, list, None]): Operation args.
  110. message (OperationParameter): Operation proto message.
  111. """
  112. if isinstance(value, str):
  113. message.mapStr[key] = value
  114. elif isinstance(value, bool):
  115. message.mapBool[key] = value
  116. elif isinstance(value, int):
  117. message.mapInt[key] = value
  118. elif isinstance(value, float):
  119. message.mapDouble[key] = value
  120. elif isinstance(value, (list, tuple)) and key != "operations":
  121. if value:
  122. replace_value_list = list(map(lambda x: "" if x is None else json.dumps(x), value))
  123. message.mapStrList[key].strValue.extend(replace_value_list)
  124. elif isinstance(value, dict):
  125. try:
  126. message.mapStr[key] = json.dumps(value)
  127. except TypeError as exo:
  128. logger.warning("Transform the value of parameter %r to string failed. Detail: %s.", key, str(exo))
  129. elif value is None:
  130. message.mapStr[key] = "None"
  131. else:
  132. logger.warning("The parameter %r is not recorded, because its type is not supported in event package. "
  133. "Its type is %r.", key, type(value).__name__)