dataset.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Create train or eval dataset.
"""
import os
import math
from enum import Enum

import pandas as pd
import numpy as np
import mindspore.dataset.engine as de
import mindspore.common.dtype as mstype

from .config import DataConfig


class DataType(Enum):
    """
    Enumerate supported dataset formats.
    """
    MINDRECORD = 1
    TFRECORD = 2
    H5 = 3


class H5Dataset():
    """
    Create dataset with H5 format.

    Args:
        data_path (str): Dataset directory.
        train_mode (bool): Whether the dataset is used for train or eval (default=True).
        train_num_of_parts (int): The number of train data files (default=21).
        test_num_of_parts (int): The number of test data files (default=3).
    """
    max_length = 39

    def __init__(self, data_path, train_mode=True,
                 train_num_of_parts=DataConfig.train_num_of_parts,
                 test_num_of_parts=DataConfig.test_num_of_parts):
        self._hdf_data_dir = data_path
        self._is_training = train_mode
        if self._is_training:
            self._file_prefix = 'train'
            self._num_of_parts = train_num_of_parts
        else:
            self._file_prefix = 'test'
            self._num_of_parts = test_num_of_parts
        self.data_size = self._bin_count(self._hdf_data_dir, self._file_prefix, self._num_of_parts)
        print("data_size: {}".format(self.data_size))

    def _bin_count(self, hdf_data_dir, file_prefix, num_of_parts):
        """Count the total number of samples across all label files."""
        size = 0
        for part in range(num_of_parts):
            _y = pd.read_hdf(os.path.join(hdf_data_dir, f'{file_prefix}_output_part_{str(part)}.h5'))
            size += _y.shape[0]
        return size

    def _iterate_hdf_files_(self, num_of_parts=None, shuffle_block=False):
        """
        Iterate among hdf files (blocks). When the whole dataset is finished, the
        iterator restarts from the beginning, so the data stream never stops.

        :param num_of_parts: number of files
        :param shuffle_block: shuffle block files at every round
        :return: input_hdf_file_name, output_hdf_file_name, finish_flag
        """
        parts = np.arange(num_of_parts)
        while True:
            if shuffle_block:
                np.random.shuffle(parts)
            for i, p in enumerate(parts):
                yield os.path.join(self._hdf_data_dir, f'{self._file_prefix}_input_part_{str(p)}.h5'), \
                      os.path.join(self._hdf_data_dir, f'{self._file_prefix}_output_part_{str(p)}.h5'), \
                      i + 1 == len(parts)

    def _generator(self, X, y, batch_size, shuffle=True):
        """
        Yield (X_batch, y_batch, finished) endlessly; should be accessed only in private.

        :param X: feature array
        :param y: label array
        :param batch_size: number of samples per batch
        :param shuffle: shuffle the sample order before each pass
        :return: X_batch, y_batch, finished
        """
        assert X.shape[0] > 0
        number_of_batches = int(np.ceil(1. * X.shape[0] / batch_size))
        counter = 0
        sample_index = np.arange(X.shape[0])
        if shuffle:
            np.random.shuffle(sample_index)
        while True:
            batch_index = sample_index[batch_size * counter: batch_size * (counter + 1)]
            X_batch = X[batch_index]
            y_batch = y[batch_index]
            counter += 1
            # Flag the last batch of a pass so the caller can move on to the next file.
            finished = counter == number_of_batches
            if finished:
                counter = 0
            yield X_batch, y_batch, finished

    def batch_generator(self, batch_size=1000,
                        random_sample=False, shuffle_block=False):
        """
        Yield (ids, weights, labels) batches across all hdf file parts.

        :param batch_size: number of samples per batch
        :param random_sample: if True, shuffle samples within each file
        :param shuffle_block: shuffle file blocks at every round
        :return: ids, weights, labels
        """
        for hdf_in, hdf_out, _ in self._iterate_hdf_files_(self._num_of_parts,
                                                           shuffle_block):
            X_all = pd.read_hdf(hdf_in).values
            y_all = pd.read_hdf(hdf_out).values
            data_gen = self._generator(X_all, y_all, batch_size,
                                       shuffle=random_sample)
            finished = False
            while not finished:
                X, y, finished = next(data_gen)
                # The first max_length columns are feature ids; the rest are feature values.
                X_id = X[:, 0:self.max_length]
                X_va = X[:, self.max_length:]
                yield np.array(X_id.astype(dtype=np.int32)), \
                      np.array(X_va.astype(dtype=np.float32)), \
                      np.array(y.astype(dtype=np.float32))
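

# A minimal usage sketch (illustrative only; './h5_data' is an assumed path):
# assuming the directory holds part files named train_input_part_0.h5 /
# train_output_part_0.h5 and so on, batches can be drawn straight from the generator:
#
#   h5_ds = H5Dataset(data_path='./h5_data', train_mode=True)
#   ids, weights, labels = next(h5_ds.batch_generator(batch_size=1000))
#   # ids is int32 and weights is float32, each with max_length (39) columns;
#   # labels is float32.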


def _get_h5_dataset(directory, train_mode=True, epochs=1, batch_size=1000):
    """
    Get dataset with h5 format.

    Args:
        directory (str): Dataset directory.
        train_mode (bool): Whether the dataset is used for train or eval (default=True).
        epochs (int): Dataset epoch size (default=1).
        batch_size (int): Dataset batch size (default=1000).

    Returns:
        Dataset.
    """
    data_para = {'batch_size': batch_size}
    if train_mode:
        data_para['random_sample'] = True
        data_para['shuffle_block'] = True

    h5_dataset = H5Dataset(data_path=directory, train_mode=train_mode)
    numbers_of_batch = math.ceil(h5_dataset.data_size / batch_size)

    def _iter_h5_data():
        train_eval_gen = h5_dataset.batch_generator(**data_para)
        # The underlying generator is infinite, so cap one epoch at numbers_of_batch.
        for _ in range(numbers_of_batch):
            yield next(train_eval_gen)

    ds = de.GeneratorDataset(_iter_h5_data, ["ids", "weights", "labels"])
    ds.set_dataset_size(numbers_of_batch)
    ds = ds.repeat(epochs)
    return ds


def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
                            line_per_sample=1000, rank_size=None, rank_id=None):
    """
    Get dataset with mindrecord format.

    Args:
        directory (str): Dataset directory.
        train_mode (bool): Whether the dataset is used for train or eval (default=True).
        epochs (int): Dataset epoch size (default=1).
        batch_size (int): Dataset batch size (default=1000).
        line_per_sample (int): The number of samples per line (default=1000).
        rank_size (int): The number of devices, not necessary for single device (default=None).
        rank_id (int): Id of the device, not necessary for single device (default=None).

    Returns:
        Dataset.
    """
    file_prefix_name = 'train_input_part.mindrecord' if train_mode else 'test_input_part.mindrecord'
    file_suffix_name = '00' if train_mode else '0'
    shuffle = train_mode

    if rank_size is not None and rank_id is not None:
        ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
                            columns_list=['feat_ids', 'feat_vals', 'label'],
                            num_shards=rank_size, shard_id=rank_id, shuffle=shuffle,
                            num_parallel_workers=8)
    else:
        ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name),
                            columns_list=['feat_ids', 'feat_vals', 'label'],
                            shuffle=shuffle, num_parallel_workers=8)
    ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
    ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39),
                                             np.array(y).flatten().reshape(batch_size, 39),
                                             np.array(z).flatten().reshape(batch_size, 1))),
                input_columns=['feat_ids', 'feat_vals', 'label'],
                columns_order=['feat_ids', 'feat_vals', 'label'],
                num_parallel_workers=8)
    ds = ds.repeat(epochs)
    return ds
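
# Note on the batch-and-reshape above: each mindrecord row packs line_per_sample
# samples, so batching batch_size / line_per_sample rows and flattening yields
# exactly batch_size samples of 39 features each. For example (sizes assumed for
# illustration), with batch_size=16000 and line_per_sample=1000, ds.batch(16)
# produces arrays that flatten to shape (16000, 39).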


def _get_tf_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
                    line_per_sample=1000, rank_size=None, rank_id=None):
    """
    Get dataset with tfrecord format.

    Args:
        directory (str): Dataset directory.
        train_mode (bool): Whether the dataset is used for train or eval (default=True).
        epochs (int): Dataset epoch size (default=1).
        batch_size (int): Dataset batch size (default=1000).
        line_per_sample (int): The number of samples per line (default=1000).
        rank_size (int): The number of devices, not necessary for single device (default=None).
        rank_id (int): Id of the device, not necessary for single device (default=None).

    Returns:
        Dataset.
    """
    dataset_files = []
    file_prefix_name = 'train' if train_mode else 'test'
    shuffle = train_mode
    for (dir_path, _, filenames) in os.walk(directory):
        for filename in filenames:
            if file_prefix_name in filename and 'tfrecord' in filename:
                dataset_files.append(os.path.join(dir_path, filename))

    schema = de.Schema()
    schema.add_column('feat_ids', de_type=mstype.int32)
    schema.add_column('feat_vals', de_type=mstype.float32)
    schema.add_column('label', de_type=mstype.float32)

    if rank_size is not None and rank_id is not None:
        ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle,
                                schema=schema, num_parallel_workers=8,
                                num_shards=rank_size, shard_id=rank_id,
                                shard_equal_rows=True)
    else:
        ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle,
                                schema=schema, num_parallel_workers=8)
    # Records pack line_per_sample samples per row, same as the mindrecord path.
    ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
    ds = ds.map(operations=(lambda x, y, z: (
        np.array(x).flatten().reshape(batch_size, 39),
        np.array(y).flatten().reshape(batch_size, 39),
        np.array(z).flatten().reshape(batch_size, 1))),
                input_columns=['feat_ids', 'feat_vals', 'label'],
                columns_order=['feat_ids', 'feat_vals', 'label'],
                num_parallel_workers=8)
    ds = ds.repeat(epochs)
    return ds


def create_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
                   data_type=DataType.TFRECORD, line_per_sample=1000,
                   rank_size=None, rank_id=None):
    """
    Get dataset.

    Args:
        directory (str): Dataset directory.
        train_mode (bool): Whether the dataset is used for train or eval (default=True).
        epochs (int): Dataset epoch size (default=1).
        batch_size (int): Dataset batch size (default=1000).
        data_type (DataType): The type of dataset, one of H5, TFRECORD, MINDRECORD (default=TFRECORD).
        line_per_sample (int): The number of samples per line (default=1000).
        rank_size (int): The number of devices, not necessary for single device (default=None).
        rank_id (int): Id of the device, not necessary for single device (default=None).

    Returns:
        Dataset.
    """
    if data_type == DataType.MINDRECORD:
        return _get_mindrecord_dataset(directory, train_mode, epochs,
                                       batch_size, line_per_sample,
                                       rank_size, rank_id)
    if data_type == DataType.TFRECORD:
        return _get_tf_dataset(directory, train_mode, epochs, batch_size,
                               line_per_sample, rank_size=rank_size, rank_id=rank_id)
    # The H5 path does not support sharding across devices.
    if rank_size is not None and rank_size > 1:
        raise ValueError('Please use mindrecord dataset.')
    return _get_h5_dataset(directory, train_mode, epochs, batch_size)
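

# A minimal usage sketch (illustrative only; './data' and the sizes are assumptions):
#
#   ds = create_dataset('./data', train_mode=True, epochs=1,
#                       batch_size=16000, data_type=DataType.MINDRECORD,
#                       line_per_sample=1000)
#   for batch in ds.create_dict_iterator():
#       ids, weights, labels = batch['feat_ids'], batch['feat_vals'], batch['label']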