You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

_utils.py 7.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Train utility."""
  16. import os
  17. import numpy as np
  18. from mindspore.common.tensor import Tensor
  19. from mindspore.common.dtype import dtype_to_nptype
  20. from mindspore.common import dtype as mstype
  21. from mindspore import log as logger
  22. from mindspore.common.api import _executor
  23. from mindspore.common.dtype import pytype_to_dtype
  24. def _convert_type(types):
  25. """
  26. Convert from numpy type to tensor type.
  27. Args:
  28. types (list): Numpy type list of element in dataset.
  29. Returns:
  30. list, list of element in dataset.
  31. """
  32. ms_types = []
  33. for np_type in types:
  34. ms_type = pytype_to_dtype(np_type)
  35. ms_types.append(ms_type)
  36. return ms_types
  37. def _get_types_and_shapes(dataset):
  38. """Get dataset types and shapes."""
  39. dataset_types = _convert_type(dataset.output_types())
  40. dataset_shapes = dataset.output_shapes()
  41. return dataset_types, dataset_shapes
  42. def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
  43. """Initialize and execute the dataset graph."""
  44. batch_size = exec_dataset.get_batch_size()
  45. input_indexs = exec_dataset.input_indexs
  46. # transform data format
  47. dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
  48. exec_dataset = exec_dataset.device_que()
  49. _executor.init_dataset(exec_dataset.queue_name,
  50. dataset_size,
  51. batch_size,
  52. dataset_types,
  53. dataset_shapes,
  54. input_indexs,
  55. phase=phase)
  56. # engine dataset to write data to tdt queue
  57. exec_dataset.send()
  58. return exec_dataset
  59. def _make_directory(path: str):
  60. """Make directory."""
  61. real_path = None
  62. if path is None or not isinstance(path, str) or path.strip() == "":
  63. logger.error("The path(%r) is invalid type.", path)
  64. raise TypeError("Input path is invaild type")
  65. # convert the relative paths
  66. path = os.path.realpath(path)
  67. logger.debug("The abs path is %r", path)
  68. # check the path is exist and write permissions?
  69. if os.path.exists(path):
  70. real_path = path
  71. else:
  72. # All exceptions need to be caught because create directory maybe have some limit(permissions)
  73. logger.debug("The directory(%s) doesn't exist, will create it", path)
  74. try:
  75. os.makedirs(path, exist_ok=True)
  76. real_path = path
  77. except PermissionError as e:
  78. logger.error("No write permission on the directory(%r), error = %r", path, e)
  79. raise TypeError("No write permission on the directory.")
  80. return real_path
  81. def _construct_tensor_list(types, shapes, batch_expand_num=1):
  82. """
  83. Construct list of tensors with types and shapes, used to initialize the network.
  84. Args:
  85. types: List or Tuple. The output types of element in dataset.
  86. shapes: List or Tuple. The output shapes of element in dataset.
  87. batch_expand_num (int): Batch expand number.
  88. Returns:
  89. List, list of Tensors.
  90. """
  91. if len(types) != len(shapes):
  92. raise ValueError("The length of dataset types must equal to dataset shapes, "
  93. "but got dataset types={} and dataset shapes={}".format(types, shapes))
  94. tensor_list = []
  95. for type_, shape in zip(types, shapes):
  96. new_shape = ()
  97. for i, item in enumerate(shape):
  98. if i == 0:
  99. new_shape += (item * batch_expand_num,)
  100. else:
  101. new_shape += (item,)
  102. tensor = Tensor(np.zeros(new_shape, dtype_to_nptype(type_)))
  103. tensor.virtual_flag = True
  104. tensor_list.append(tensor)
  105. return tensor_list
  106. def _to_tensor(elem, scaling_sens=None):
  107. """Conver numpy to tensor, adapt to minddata feed solution."""
  108. lst = []
  109. if not isinstance(elem, (tuple, list)):
  110. elem = [elem]
  111. for data in elem:
  112. if not isinstance(data, np.ndarray):
  113. if scaling_sens:
  114. elem_tuple = tuple(elem) + (Tensor(scaling_sens, mstype.float32),)
  115. else:
  116. elem_tuple = tuple(elem)
  117. return elem_tuple
  118. lst.append(Tensor(data))
  119. if scaling_sens:
  120. lst.append(Tensor(scaling_sens, mstype.float32))
  121. return lst[0] if len(lst) == 1 else tuple(lst)
  122. def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
  123. """Conver numpy to tensor, expanding batch dimension according to device_num, adapt to minddata feed solution."""
  124. lst = []
  125. if not isinstance(elem, (tuple, list)):
  126. elem = [elem]
  127. if global_rank >= device_num:
  128. raise ValueError("The global rank must be smaller than device number, the global rank is {}, "
  129. "the device num is {}".format(global_rank, device_num))
  130. for data in elem:
  131. if isinstance(data, np.ndarray):
  132. data = Tensor(data)
  133. if not isinstance(data, Tensor):
  134. raise ValueError("elements in tensors must be Tensor")
  135. shape_ = data.shape()
  136. type_ = data.dtype()
  137. new_shape = ()
  138. batchsize_per_device = 1
  139. for i, item in enumerate(shape_):
  140. if i == 0:
  141. new_shape += (item * device_num,)
  142. batchsize_per_device = item
  143. else:
  144. new_shape += (item,)
  145. new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
  146. start = global_rank * batchsize_per_device
  147. new_tensor_numpy[start: start + batchsize_per_device] = data.asnumpy()
  148. new_tensor = Tensor(new_tensor_numpy)
  149. lst.append(new_tensor)
  150. if scaling_sens:
  151. lst.append(Tensor(scaling_sens, mstype.float32))
  152. return tuple(lst)
  153. def _construct_input_tensors(dataset_types, dataset_shapes, device_number=1):
  154. """Construct tensor list to initialize the network which implemented in dataset sink."""
  155. tensor_list_run = _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=1)
  156. tensor_list_compile = _construct_tensor_list(dataset_types, dataset_shapes, batch_expand_num=device_number)
  157. return tensor_list_run, tensor_list_compile
  158. def _to_full_shapes(shapes, device_num):
  159. """Expanding batch dimension according to device_num, adapt to mindspore minddata graph solution."""
  160. new_shapes = []
  161. for shape in shapes:
  162. new_shape = ()
  163. for i, item in enumerate(shape):
  164. if i == 0:
  165. new_shape += (item * device_num,)
  166. else:
  167. new_shape += (item,)
  168. new_shapes.append(new_shape)
  169. return new_shapes