You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; can include dashes ('-'); and can be up to 35 characters long.

_utils.py 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Train utility."""
  16. import os
  17. from collections.abc import Iterable
  18. import numpy as np
  19. from mindspore.common.tensor import Tensor
  20. from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
  21. from mindspore.common import dtype as mstype
  22. from mindspore import log as logger
  23. from mindspore.common.api import _executor
  24. from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
  25. from mindspore.train.checkpoint_pb2 import Checkpoint
  26. from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strategy
  27. from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
  28. def _convert_type(types):
  29. """
  30. Convert from numpy type to tensor type.
  31. Args:
  32. types (list): Numpy type list of element in dataset.
  33. Returns:
  34. list, list of element in dataset.
  35. """
  36. ms_types = []
  37. for np_type in types:
  38. ms_type = pytype_to_dtype(np_type)
  39. ms_types.append(ms_type)
  40. return ms_types
  41. def _get_types_and_shapes(dataset):
  42. """Get dataset types and shapes."""
  43. dataset_types = _convert_type(dataset.output_types())
  44. dataset_shapes = dataset.output_shapes()
  45. return dataset_types, dataset_shapes
  46. def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_info_queue=False):
  47. """Initialize and execute the dataset graph."""
  48. batch_size = exec_dataset.get_batch_size()
  49. input_indexs = exec_dataset.input_indexs
  50. # transform data format
  51. dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
  52. send_epoch_end = bool(dataset_size == -1)
  53. exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue)
  54. _executor.init_dataset(exec_dataset.queue_name,
  55. dataset_size,
  56. batch_size,
  57. dataset_types,
  58. dataset_shapes,
  59. input_indexs,
  60. phase=phase)
  61. return exec_dataset
  62. def _make_directory(path: str):
  63. """Make directory."""
  64. if path is None or not isinstance(path, str) or path.strip() == "":
  65. logger.error("The path(%r) is invalid type.", path)
  66. raise TypeError("Input path is invalid type")
  67. path = os.path.realpath(path)
  68. logger.debug("The abs path is %r", path)
  69. if os.path.exists(path):
  70. real_path = path
  71. else:
  72. logger.debug("The directory(%s) doesn't exist, will create it", path)
  73. try:
  74. permissions = os.R_OK | os.W_OK | os.X_OK
  75. os.umask(permissions << 3 | permissions)
  76. mode = permissions << 6
  77. os.makedirs(path, mode=mode, exist_ok=True)
  78. real_path = path
  79. except PermissionError as e:
  80. logger.error("No write permission on the directory(%r), error = %r", path, e)
  81. raise TypeError("No write permission on the directory.")
  82. return real_path
  83. def _construct_tensor_list(types, shapes, batch_expand_num=1):
  84. """
  85. Construct list of tensors with types and shapes, used to initialize the network.
  86. Args:
  87. types: List or Tuple. The output types of element in dataset.
  88. shapes: List or Tuple. The output shapes of element in dataset.
  89. batch_expand_num (int): Batch expand number.
  90. Returns:
  91. List, list of Tensors.
  92. """
  93. if len(types) != len(shapes):
  94. raise ValueError("The length of dataset types must equal to dataset shapes, "
  95. "but got dataset types={} and dataset shapes={}".format(types, shapes))
  96. tensor_list = []
  97. for type_, shape in zip(types, shapes):
  98. new_shape = ()
  99. for i, item in enumerate(shape):
  100. if i == 0:
  101. new_shape += (item * batch_expand_num,)
  102. else:
  103. new_shape += (item,)
  104. tensor = Tensor(np.zeros(new_shape, dtype_to_nptype(type_)))
  105. tensor.virtual_flag = True
  106. tensor_list.append(tensor)
  107. return tensor_list
  108. def _to_tensor(elem, scaling_sens=None):
  109. """Convert numpy to tensor, adapt to feed the data from host solution."""
  110. lst = []
  111. if not isinstance(elem, (tuple, list)):
  112. elem = [elem]
  113. for data in elem:
  114. if not isinstance(data, np.ndarray):
  115. if scaling_sens:
  116. elem_tuple = tuple(elem) + (Tensor(scaling_sens, mstype.float32),)
  117. else:
  118. elem_tuple = tuple(elem)
  119. return elem_tuple
  120. lst.append(Tensor(data))
  121. if scaling_sens:
  122. lst.append(Tensor(scaling_sens, mstype.float32))
  123. return lst[0] if len(lst) == 1 else tuple(lst)
  124. def _construct_input_tensors(dataset_types, dataset_shapes, device_number=1):
  125. """Construct tensor list to initialize the network which implemented in dataset sink."""
  126. tensor_list_run = _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=1)
  127. tensor_list_compile = _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=device_number)
  128. return tensor_list_run, tensor_list_compile
  129. def _check_to_numpy(plugin, tensor):
  130. """Check the tensor and return a numpy.ndarray."""
  131. np_value = tensor.asnumpy()
  132. np_value = np_value.copy()
  133. if plugin == 'scalar':
  134. if np_value.size == 1:
  135. return np_value
  136. raise ValueError('The tensor holds more than one value, but the scalar plugin expects on value.')
  137. if plugin == 'image':
  138. if np_value.ndim == 4:
  139. return np_value
  140. raise ValueError('The tensor seems not to hold a valid image.')
  141. if plugin in ('tensor', 'histogram'):
  142. if np_value.ndim > 0:
  143. return np_value
  144. raise ValueError('The tensor should not be empty.')
  145. return np_value
  146. def _check_lineage_value(plugin, value):
  147. """Check the lineage value."""
  148. def raises(plugin, prototype):
  149. raise TypeError(f'Plugin {repr(plugin)} expects a {prototype.__name__} value.')
  150. if plugin == 'dataset_graph' and not isinstance(value, DatasetGraph):
  151. raises(plugin, DatasetGraph)
  152. if plugin == 'eval_lineage' and not isinstance(value, EvaluationLineage):
  153. raises(plugin, EvaluationLineage)
  154. if plugin == 'train_lineage' and not isinstance(value, TrainLineage):
  155. raises(plugin, TrainLineage)
  156. if plugin == 'custom_lineage_data' and not isinstance(value, UserDefinedInfo):
  157. raises(plugin, UserDefinedInfo)
  158. def check_value_type(arg_name, arg_value, valid_types):
  159. """Checks whether a value is instance of some types."""
  160. valid_types = tuple(valid_types) if isinstance(valid_types, Iterable) else (valid_types,)
  161. is_valid = True
  162. # bool is subclass of int, so for a bool value, we need to extra check
  163. if isinstance(arg_value, int) and isinstance(arg_value, bool) and bool not in valid_types:
  164. is_valid = False
  165. if not isinstance(arg_value, valid_types):
  166. is_valid = False
  167. if not is_valid:
  168. raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, '
  169. f'but got {type(arg_value).__name__}.')
  170. def read_proto(file_name, proto_format="MINDIR", display_data=False):
  171. """
  172. Read protobuf file.
  173. Args:
  174. file_name (str): File name.
  175. proto_format (str): Proto format {MINDIR, CKPT, CKPT_STRATEGY}. Default: MINDIR.
  176. display_data (bool): Whether display data. Default: False.
  177. Returns:
  178. Object, proto object.
  179. """
  180. if proto_format == "MINDIR":
  181. model = mindir_model()
  182. elif proto_format == "CKPT":
  183. model = Checkpoint()
  184. elif proto_format == "CKPT_STRATEGY":
  185. model = ckpt_strategy()
  186. else:
  187. raise ValueError("Unsupported proto format.")
  188. try:
  189. with open(file_name, "rb") as f:
  190. pb_content = f.read()
  191. model.ParseFromString(pb_content)
  192. except BaseException as e:
  193. logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name)
  194. raise ValueError(e.__str__())
  195. if proto_format == "MINDIR" and not display_data:
  196. for param_proto in model.graph.parameter:
  197. param_proto.raw_data = b'\0'
  198. if proto_format == "CKPT" and not display_data:
  199. for element in model.value:
  200. element.tensor.tensor_content = b'\0'
  201. return model