You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_utils.py 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Train utility."""
  16. import os
  17. from collections.abc import Iterable
  18. import numpy as np
  19. from mindspore.common.tensor import Tensor
  20. from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
  21. from mindspore.common import dtype as mstype
  22. from mindspore import log as logger
  23. from mindspore.common.api import _executor
  24. from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
  25. def _convert_type(types):
  26. """
  27. Convert from numpy type to tensor type.
  28. Args:
  29. types (list): Numpy type list of element in dataset.
  30. Returns:
  31. list, list of element in dataset.
  32. """
  33. ms_types = []
  34. for np_type in types:
  35. ms_type = pytype_to_dtype(np_type)
  36. ms_types.append(ms_type)
  37. return ms_types
  38. def _get_types_and_shapes(dataset):
  39. """Get dataset types and shapes."""
  40. dataset_types = _convert_type(dataset.output_types())
  41. dataset_shapes = dataset.output_shapes()
  42. return dataset_types, dataset_shapes
  43. def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
  44. """Initialize and execute the dataset graph."""
  45. batch_size = exec_dataset.get_batch_size()
  46. input_indexs = exec_dataset.input_indexs
  47. # transform data format
  48. dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
  49. send_epoch_end = bool(dataset_size == -1)
  50. exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end)
  51. _executor.init_dataset(exec_dataset.queue_name,
  52. dataset_size,
  53. batch_size,
  54. dataset_types,
  55. dataset_shapes,
  56. input_indexs,
  57. phase=phase)
  58. return exec_dataset
  59. def _make_directory(path: str):
  60. """Make directory."""
  61. real_path = None
  62. if path is None or not isinstance(path, str) or path.strip() == "":
  63. logger.error("The path(%r) is invalid type.", path)
  64. raise TypeError("Input path is invaild type")
  65. # convert the relative paths
  66. path = os.path.realpath(path)
  67. logger.debug("The abs path is %r", path)
  68. # check the path is exist and write permissions?
  69. if os.path.exists(path):
  70. real_path = path
  71. else:
  72. # All exceptions need to be caught because create directory maybe have some limit(permissions)
  73. logger.debug("The directory(%s) doesn't exist, will create it", path)
  74. try:
  75. os.makedirs(path, exist_ok=True)
  76. real_path = path
  77. except PermissionError as e:
  78. logger.error("No write permission on the directory(%r), error = %r", path, e)
  79. raise TypeError("No write permission on the directory.")
  80. return real_path
  81. def _construct_tensor_list(types, shapes, batch_expand_num=1):
  82. """
  83. Construct list of tensors with types and shapes, used to initialize the network.
  84. Args:
  85. types: List or Tuple. The output types of element in dataset.
  86. shapes: List or Tuple. The output shapes of element in dataset.
  87. batch_expand_num (int): Batch expand number.
  88. Returns:
  89. List, list of Tensors.
  90. """
  91. if len(types) != len(shapes):
  92. raise ValueError("The length of dataset types must equal to dataset shapes, "
  93. "but got dataset types={} and dataset shapes={}".format(types, shapes))
  94. tensor_list = []
  95. for type_, shape in zip(types, shapes):
  96. new_shape = ()
  97. for i, item in enumerate(shape):
  98. if i == 0:
  99. new_shape += (item * batch_expand_num,)
  100. else:
  101. new_shape += (item,)
  102. tensor = Tensor(np.zeros(new_shape, dtype_to_nptype(type_)))
  103. tensor.virtual_flag = True
  104. tensor_list.append(tensor)
  105. return tensor_list
  106. def _to_tensor(elem, scaling_sens=None):
  107. """Convert numpy to tensor, adapt to feed the data from host solution."""
  108. lst = []
  109. if not isinstance(elem, (tuple, list)):
  110. elem = [elem]
  111. for data in elem:
  112. if not isinstance(data, np.ndarray):
  113. if scaling_sens:
  114. elem_tuple = tuple(elem) + (Tensor(scaling_sens, mstype.float32),)
  115. else:
  116. elem_tuple = tuple(elem)
  117. return elem_tuple
  118. lst.append(Tensor(data))
  119. if scaling_sens:
  120. lst.append(Tensor(scaling_sens, mstype.float32))
  121. return lst[0] if len(lst) == 1 else tuple(lst)
  122. def _construct_input_tensors(dataset_types, dataset_shapes, device_number=1):
  123. """Construct tensor list to initialize the network which implemented in dataset sink."""
  124. tensor_list_run = _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=1)
  125. tensor_list_compile = _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=device_number)
  126. return tensor_list_run, tensor_list_compile
  127. def _check_to_numpy(plugin, tensor):
  128. """Check the tensor and return a numpy.ndarray."""
  129. np_value = tensor.asnumpy()
  130. if plugin == 'scalar':
  131. if np_value.size == 1:
  132. return np_value
  133. raise ValueError('The tensor holds more than one value, but the scalar plugin expects on value.')
  134. if plugin == 'image':
  135. if np_value.ndim == 4:
  136. return np_value
  137. raise ValueError('The tensor seems not to hold a valid image.')
  138. if plugin in ('tensor', 'histogram'):
  139. if np_value.ndim > 0:
  140. return np_value
  141. raise ValueError('The tensor should not be empty.')
  142. return np_value
  143. def _check_lineage_value(plugin, value):
  144. """Check the lineage value."""
  145. def raises(plugin, prototype):
  146. raise TypeError(f'Plugin {repr(plugin)} expects a {prototype.__name__} value.')
  147. if plugin == 'dataset_graph' and not isinstance(value, DatasetGraph):
  148. raises(plugin, DatasetGraph)
  149. if plugin == 'eval_lineage' and not isinstance(value, EvaluationLineage):
  150. raises(plugin, EvaluationLineage)
  151. if plugin == 'train_lineage' and not isinstance(value, TrainLineage):
  152. raises(plugin, TrainLineage)
  153. if plugin == 'custom_lineage_data' and not isinstance(value, UserDefinedInfo):
  154. raises(plugin, UserDefinedInfo)
  155. def check_value_type(arg_name, arg_value, valid_types):
  156. """Checks whether a value is instance of some types."""
  157. valid_types = tuple(valid_types) if isinstance(valid_types, Iterable) else (valid_types,)
  158. is_valid = True
  159. # bool is subclass of int, so for a bool value, we need to extra check
  160. if isinstance(arg_value, int) and isinstance(arg_value, bool) and bool not in valid_types:
  161. is_valid = False
  162. if not isinstance(arg_value, valid_types):
  163. is_valid = False
  164. if not is_valid:
  165. raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, '
  166. f'bug got {type(arg_value).__name__}.')