You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

_utils.py 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Train utility."""
  16. import os
  17. from collections.abc import Iterable
  18. import numpy as np
  19. from mindspore.common.tensor import Tensor
  20. from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
  21. from mindspore.common import dtype as mstype
  22. from mindspore import log as logger
  23. from mindspore.common.api import _executor
  24. from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
  25. from mindspore.train.anf_ir_pb2 import ModelProto as anf_model
  26. from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
  27. def _convert_type(types):
  28. """
  29. Convert from numpy type to tensor type.
  30. Args:
  31. types (list): Numpy type list of element in dataset.
  32. Returns:
  33. list, list of element in dataset.
  34. """
  35. ms_types = []
  36. for np_type in types:
  37. ms_type = pytype_to_dtype(np_type)
  38. ms_types.append(ms_type)
  39. return ms_types
  40. def _get_types_and_shapes(dataset):
  41. """Get dataset types and shapes."""
  42. dataset_types = _convert_type(dataset.output_types())
  43. dataset_shapes = dataset.output_shapes()
  44. return dataset_types, dataset_shapes
  45. def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_info_queue=False):
  46. """Initialize and execute the dataset graph."""
  47. batch_size = exec_dataset.get_batch_size()
  48. input_indexs = exec_dataset.input_indexs
  49. # transform data format
  50. dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
  51. send_epoch_end = bool(dataset_size == -1)
  52. exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue)
  53. _executor.init_dataset(exec_dataset.queue_name,
  54. dataset_size,
  55. batch_size,
  56. dataset_types,
  57. dataset_shapes,
  58. input_indexs,
  59. phase=phase)
  60. return exec_dataset
  61. def _make_directory(path: str):
  62. """Make directory."""
  63. if path is None or not isinstance(path, str) or path.strip() == "":
  64. logger.error("The path(%r) is invalid type.", path)
  65. raise TypeError("Input path is invalid type")
  66. path = os.path.realpath(path)
  67. logger.debug("The abs path is %r", path)
  68. if os.path.exists(path):
  69. real_path = path
  70. else:
  71. logger.debug("The directory(%s) doesn't exist, will create it", path)
  72. try:
  73. permissions = os.R_OK | os.W_OK | os.X_OK
  74. os.umask(permissions << 3 | permissions)
  75. mode = permissions << 6
  76. os.makedirs(path, mode=mode, exist_ok=True)
  77. real_path = path
  78. except PermissionError as e:
  79. logger.error("No write permission on the directory(%r), error = %r", path, e)
  80. raise TypeError("No write permission on the directory.")
  81. return real_path
  82. def _construct_tensor_list(types, shapes, batch_expand_num=1):
  83. """
  84. Construct list of tensors with types and shapes, used to initialize the network.
  85. Args:
  86. types: List or Tuple. The output types of element in dataset.
  87. shapes: List or Tuple. The output shapes of element in dataset.
  88. batch_expand_num (int): Batch expand number.
  89. Returns:
  90. List, list of Tensors.
  91. """
  92. if len(types) != len(shapes):
  93. raise ValueError("The length of dataset types must equal to dataset shapes, "
  94. "but got dataset types={} and dataset shapes={}".format(types, shapes))
  95. tensor_list = []
  96. for type_, shape in zip(types, shapes):
  97. new_shape = ()
  98. for i, item in enumerate(shape):
  99. if i == 0:
  100. new_shape += (item * batch_expand_num,)
  101. else:
  102. new_shape += (item,)
  103. tensor = Tensor(np.zeros(new_shape, dtype_to_nptype(type_)))
  104. tensor.virtual_flag = True
  105. tensor_list.append(tensor)
  106. return tensor_list
  107. def _to_tensor(elem, scaling_sens=None):
  108. """Convert numpy to tensor, adapt to feed the data from host solution."""
  109. lst = []
  110. if not isinstance(elem, (tuple, list)):
  111. elem = [elem]
  112. for data in elem:
  113. if not isinstance(data, np.ndarray):
  114. if scaling_sens:
  115. elem_tuple = tuple(elem) + (Tensor(scaling_sens, mstype.float32),)
  116. else:
  117. elem_tuple = tuple(elem)
  118. return elem_tuple
  119. lst.append(Tensor(data))
  120. if scaling_sens:
  121. lst.append(Tensor(scaling_sens, mstype.float32))
  122. return lst[0] if len(lst) == 1 else tuple(lst)
  123. def _construct_input_tensors(dataset_types, dataset_shapes, device_number=1):
  124. """Construct tensor list to initialize the network which implemented in dataset sink."""
  125. tensor_list_run = _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=1)
  126. tensor_list_compile = _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=device_number)
  127. return tensor_list_run, tensor_list_compile
  128. def _check_to_numpy(plugin, tensor):
  129. """Check the tensor and return a numpy.ndarray."""
  130. np_value = tensor.asnumpy()
  131. np_value = np_value.copy()
  132. if plugin == 'scalar':
  133. if np_value.size == 1:
  134. return np_value
  135. raise ValueError('The tensor holds more than one value, but the scalar plugin expects on value.')
  136. if plugin == 'image':
  137. if np_value.ndim == 4:
  138. return np_value
  139. raise ValueError('The tensor seems not to hold a valid image.')
  140. if plugin in ('tensor', 'histogram'):
  141. if np_value.ndim > 0:
  142. return np_value
  143. raise ValueError('The tensor should not be empty.')
  144. return np_value
  145. def _check_lineage_value(plugin, value):
  146. """Check the lineage value."""
  147. def raises(plugin, prototype):
  148. raise TypeError(f'Plugin {repr(plugin)} expects a {prototype.__name__} value.')
  149. if plugin == 'dataset_graph' and not isinstance(value, DatasetGraph):
  150. raises(plugin, DatasetGraph)
  151. if plugin == 'eval_lineage' and not isinstance(value, EvaluationLineage):
  152. raises(plugin, EvaluationLineage)
  153. if plugin == 'train_lineage' and not isinstance(value, TrainLineage):
  154. raises(plugin, TrainLineage)
  155. if plugin == 'custom_lineage_data' and not isinstance(value, UserDefinedInfo):
  156. raises(plugin, UserDefinedInfo)
  157. def check_value_type(arg_name, arg_value, valid_types):
  158. """Checks whether a value is instance of some types."""
  159. valid_types = tuple(valid_types) if isinstance(valid_types, Iterable) else (valid_types,)
  160. is_valid = True
  161. # bool is subclass of int, so for a bool value, we need to extra check
  162. if isinstance(arg_value, int) and isinstance(arg_value, bool) and bool not in valid_types:
  163. is_valid = False
  164. if not isinstance(arg_value, valid_types):
  165. is_valid = False
  166. if not is_valid:
  167. raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, '
  168. f'but got {type(arg_value).__name__}.')
  169. def read_proto(file_name, proto_format="MINDIR"):
  170. """
  171. Read protobuf file.
  172. Args:
  173. file_name (str): File name.
  174. proto_format (str): Proto format.
  175. Returns:
  176. Object, proto object.
  177. """
  178. if proto_format == "MINDIR":
  179. model = mindir_model()
  180. elif model_format == "ANF":
  181. model = anf_model()
  182. else:
  183. raise ValueError("Unsupported proto format.")
  184. try:
  185. with open(file_name, "rb") as f:
  186. pb_content = f.read()
  187. model.ParseFromString(pb_content)
  188. except BaseException as e:
  189. logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name)
  190. raise ValueError(e.__str__())
  191. return model