
utils.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """ create train dataset. """
  16. import os
  17. import re
  18. import numpy as np
  19. from mindspore.communication.management import init
  20. from mindspore.communication.management import get_rank
  21. from mindspore.communication.management import get_group_size
  22. from mindspore import Tensor
def _count_unequal_element(data_expected, data_me, rtol, atol):
    """Assert that the fraction of out-of-tolerance elements stays below rtol."""
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_me) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
            data_expected[greater], data_me[greater], error[greater])


def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    """Compare two arrays elementwise; on mismatch, report the offending elements."""
    if np.any(np.isnan(data_expected)):
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)


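# A minimal usage sketch of the comparison helpers above (the arrays here are
# illustrative, not from the test suite):
#
#     expected = np.ones((2, 3), dtype=np.float32)
#     actual = expected + 1e-5
#     allclose_nparray(expected, actual, rtol=1e-3, atol=1e-3)        # passes
#     allclose_nparray(expected, actual + 1.0, rtol=1e-3, atol=1e-3)  # raises AssertionError

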
def clean_all_ir_files(folder_path):
    """Remove dumped IR artifacts (.ir/.dat/.dot) from folder_path."""
    if os.path.exists(folder_path):
        for file_name in os.listdir(folder_path):
            if file_name.endswith(('.ir', '.dat', '.dot')):
                os.remove(os.path.join(folder_path, file_name))


def find_newest_validateir_file(folder_path):
    """Return the most recently created '<n>_validate_<m>.ir' file in folder_path."""
    validate_files = map(lambda f: os.path.join(folder_path, f),
                         filter(lambda f: re.match(r'\d+_validate_\d+\.ir', f),
                                os.listdir(folder_path)))
    return max(validate_files, key=os.path.getctime)


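# A minimal usage sketch (assumes graph IR files were dumped into './ir_dump'
# beforehand, e.g. by running MindSpore with save_graphs enabled; the path is
# an assumption for illustration):
#
#     newest = find_newest_validateir_file('./ir_dump')
#     with open(newest) as f:
#         ir_text = f.read()
#     clean_all_ir_files('./ir_dump')

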
class FakeDataInitMode:
    """Strategies for generating the fake image data."""
    RandomInit = 0
    OnesInit = 1
    UniqueInit = 2
    ZerosInit = 3


class FakeData:
    """In-memory fake dataset that mimics a MindSpore dataset iterator."""
    def __init__(self, size=1024, batch_size=32, image_size=(3, 224, 224),
                 num_classes=10, random_offset=0, use_parallel=False,
                 fakedata_mode=FakeDataInitMode.RandomInit):
        self.size = size
        self.rank_batch_size = batch_size
        self.total_batch_size = self.rank_batch_size
        self.random_offset = random_offset
        self.image_size = image_size
        self.num_classes = num_classes
        self.rank_size = 1
        self.rank_id = 0
        self.batch_index = 0
        self.image_data_type = np.float32
        self.label_data_type = np.float32
        self.is_onehot = True
        self.fakedata_mode = fakedata_mode
        if use_parallel:
            # In parallel mode each rank later slices out its own part of the batch.
            init()
            self.rank_size = get_group_size()
            self.rank_id = get_rank()
        self.total_batch_size = self.rank_batch_size * self.rank_size
        assert (self.size % self.total_batch_size) == 0
        self.total_batch_data_size = (self.rank_size, self.rank_batch_size) + image_size

    def get_dataset_size(self):
        return int(self.size / self.total_batch_size)

    def get_repeat_count(self):
        return 1

    def set_image_data_type(self, data_type):
        self.image_data_type = data_type

    def set_label_data_type(self, data_type):
        self.label_data_type = data_type

    def set_label_onehot(self, is_onehot=True):
        self.is_onehot = is_onehot

    def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
        _ = num_epochs, do_copy
        return self

    def __getitem__(self, batch_index):
        if batch_index * self.total_batch_size >= len(self):
            raise IndexError("{} index out of range".format(self.__class__.__name__))
        # Seed deterministically per batch, then restore the global RNG state.
        rng_state = np.random.get_state()
        np.random.seed(batch_index + self.random_offset)
        if self.fakedata_mode == FakeDataInitMode.OnesInit:
            img = np.ones(self.total_batch_data_size)
        elif self.fakedata_mode == FakeDataInitMode.ZerosInit:
            img = np.zeros(self.total_batch_data_size)
        elif self.fakedata_mode == FakeDataInitMode.UniqueInit:
            total_size = 1
            for i in self.total_batch_data_size:
                total_size *= i
            img = np.reshape(np.arange(total_size) * 0.0001, self.total_batch_data_size)
        else:
            img = np.random.randn(*self.total_batch_data_size)
        target = np.random.randint(0, self.num_classes, size=(self.rank_size, self.rank_batch_size))
        np.random.set_state(rng_state)
        # Keep only this rank's slice of the global batch.
        img = img[self.rank_id]
        target = target[self.rank_id]
        img_ret = img.astype(self.image_data_type)
        target_ret = target.astype(self.label_data_type)
        if self.is_onehot:
            target_onehot = np.zeros(shape=(self.rank_batch_size, self.num_classes))
            target_onehot[np.arange(self.rank_batch_size), target] = 1
            target_ret = target_onehot.astype(self.label_data_type)
        return Tensor(img_ret), Tensor(target_ret)

    def __len__(self):
        return self.size

    def __iter__(self):
        self.batch_index = 0
        return self

    def reset(self):
        self.batch_index = 0

    def __next__(self):
        if self.batch_index * self.total_batch_size < len(self):
            data = self[self.batch_index]
            self.batch_index += 1
            return data
        raise StopIteration
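

# A minimal smoke-test sketch (an illustrative assumption, not part of the
# original suite): single process, no parallel init, ones-initialized images.
if __name__ == '__main__':
    dataset = FakeData(size=64, batch_size=32, image_size=(3, 224, 224),
                       num_classes=10, fakedata_mode=FakeDataInitMode.OnesInit)
    for image, label in dataset:
        # image: Tensor of shape (32, 3, 224, 224); label: one-hot Tensor of shape (32, 10).
        print(image.shape, label.shape)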