
process_data.py

# coding:utf-8
import os
import pickle
import collections
import argparse

import numpy as np
import pandas as pd

TRAIN_LINE_COUNT = 45840617
TEST_LINE_COUNT = 6042135
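# Input format note: each line of Criteo train.txt carries 40 tab-separated
# fields -- <label>, 13 numeric features (val_1..val_13, possibly empty) and
# 26 hashed categorical features (cat_1..cat_26) -- which is what the
# `len(items) != 40` checks below rely on.
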
class DataStatsDict:
    """Collects column statistics and the feature-id vocabulary for Criteo data."""

    def __init__(self):
        self.field_size = 39  # val_1-13; cat_1-26
        self.val_cols = ["val_{}".format(i + 1) for i in range(13)]
        self.cat_cols = ["cat_{}".format(i + 1) for i in range(26)]

        self.val_min_dict = {col: 0 for col in self.val_cols}
        self.val_max_dict = {col: 0 for col in self.val_cols}
        self.cat_count_dict = {col: collections.defaultdict(int) for col in self.cat_cols}

        self.oov_prefix = "OOV_"
        self.cat2id_dict = {}
        self.cat2id_dict.update({col: i for i, col in enumerate(self.val_cols)})
        self.cat2id_dict.update(
            {self.oov_prefix + col: i + len(self.val_cols) for i, col in enumerate(self.cat_cols)})
        # {"val_1": 0, ..., "val_13": 12, "OOV_cat_1": 13, ..., "OOV_cat_26": 38}

    def stats_vals(self, val_list):
        """Update the running min/max of each numeric column with one row of values."""
        assert len(val_list) == len(self.val_cols)

        def map_max_min(i, val):
            key = self.val_cols[i]
            if val != "":
                if float(val) > self.val_max_dict[key]:
                    self.val_max_dict[key] = float(val)
                if float(val) < self.val_min_dict[key]:
                    self.val_min_dict[key] = float(val)

        for i, val in enumerate(val_list):
            map_max_min(i, val)

    def stats_cats(self, cat_list):
        """Update the occurrence count of each category with one row of values."""
        assert len(cat_list) == len(self.cat_cols)

        def map_cat_count(i, cat):
            key = self.cat_cols[i]
            self.cat_count_dict[key][cat] += 1

        for i, cat in enumerate(cat_list):
            map_cat_count(i, cat)

    def save_dict(self, output_path, prefix=""):
        with open(os.path.join(output_path, "{}val_max_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.val_max_dict, file_wrt)
        with open(os.path.join(output_path, "{}val_min_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.val_min_dict, file_wrt)
        with open(os.path.join(output_path, "{}cat_count_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.cat_count_dict, file_wrt)

    def load_dict(self, dict_path, prefix=""):
        with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.val_max_dict = pickle.load(file_wrt)
        with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.val_min_dict = pickle.load(file_wrt)
        with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.cat_count_dict = pickle.load(file_wrt)
        print("val_max_dict.items(): {}".format(list(self.val_max_dict.items())))
        print("val_min_dict.items(): {}".format(list(self.val_min_dict.items())))

    def get_cat2id(self, threshold=100):
        """Assign an id to every category seen more than `threshold` times; rarer ones map to OOV."""
        for key, cat_count_d in self.cat_count_dict.items():
            new_cat_count_d = dict(filter(lambda x: x[1] > threshold, cat_count_d.items()))
            for cat_str, _ in new_cat_count_d.items():
                self.cat2id_dict[key + "_" + cat_str] = len(self.cat2id_dict)
        # With threshold=100 on the full train set, the raw category vocabulary
        # shrinks from 33762577 to 184926 entries (counts from the original source).
        print("cat2id_dict.size: {}".format(len(self.cat2id_dict)))
        print("cat2id_dict.items()[:50]: {}".format(list(self.cat2id_dict.items())[:50]))

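    # Resulting id layout in cat2id_dict (derived from __init__ and get_cat2id):
    #   0..12   -> the 13 numeric columns, one id each
    #   13..38  -> one OOV bucket per categorical column
    #   39..    -> frequent "cat_i_<value>" categories, appended in counting order
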
    def map_cat2id(self, values, cats):
        """Encode one sample into a (feature id list, feature weight list) pair."""

        def minmax_scale_value(i, val):
            # Scale by the column max only; the min-max variant below was left
            # commented out in the original source.
            # min_v = float(self.val_min_dict["val_{}".format(i + 1)])
            max_v = float(self.val_max_dict["val_{}".format(i + 1)])
            # return (float(val) - min_v) * 1.0 / (max_v - min_v)
            return float(val) * 1.0 / max_v

        id_list = []
        weight_list = []
        for i, val in enumerate(values):
            if val == "":
                # Missing numeric value: keep the column id, zero out the weight.
                id_list.append(i)
                weight_list.append(0)
            else:
                key = "val_{}".format(i + 1)
                id_list.append(self.cat2id_dict[key])
                weight_list.append(minmax_scale_value(i, float(val)))
        for i, cat_str in enumerate(cats):
            key = "cat_{}".format(i + 1) + "_" + cat_str
            if key in self.cat2id_dict:
                id_list.append(self.cat2id_dict[key])
            else:
                id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)])
            weight_list.append(1.0)
        return id_list, weight_list

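# Illustrative usage of the class above (a sketch, not part of the original flow):
#   stats = DataStatsDict()
#   stats.load_dict(dict_path="stats_dict/")    # after statsdata() has run
#   stats.get_cat2id(threshold=100)
#   ids, wts = stats.map_cat2id(values, cats)   # 39 ids + 39 weights per sample
# Numeric features are weighted by val/max_val; categorical features by 1.0.
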
def mkdir_path(file_path):
    if not os.path.exists(file_path):
        os.makedirs(file_path)

def statsdata(data_source_path, output_path, data_stats1):
    """Pass 1: scan the raw file, collect min/max and category counts, save them as pickles."""
    with open(data_source_path, encoding="utf-8") as file_in:
        errorline_list = []
        count = 0
        for line in file_in:
            count += 1
            line = line.strip("\n")
            items = line.split("\t")
            if len(items) != 40:
                errorline_list.append(count)
                print("line: {}".format(line))
                continue
            if count % 1000000 == 0:
                print("Processed {}w lines.".format(count // 10000))
            values = items[1:14]
            cats = items[14:]
            assert len(values) == 13, "values.size: {}".format(len(values))
            assert len(cats) == 26, "cats.size: {}".format(len(cats))
            data_stats1.stats_vals(values)
            data_stats1.stats_cats(cats)
    data_stats1.save_dict(output_path)

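# statsdata() leaves three pickles in output_path (val_max_dict.pkl,
# val_min_dict.pkl, cat_count_dict.pkl); load_dict() reads them back before
# the id mapping is built in step 1 of __main__ below.
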
def add_write(file_path, wrt_str):
    # Append-mode write helper; defined in the original source but not used in this script.
    with open(file_path, 'a', encoding="utf-8") as file_out:
        file_out.write(wrt_str + "\n")

def random_split_trans2h5(input_file_path, output_path, data_stats2, part_rows=2000000, test_size=0.1, seed=2020):
    """Pass 2: shuffle row indices, split train/test, encode each row and dump it to HDF5 parts."""
    test_size = int(TRAIN_LINE_COUNT * test_size)
    all_indices = [i for i in range(TRAIN_LINE_COUNT)]
    np.random.seed(seed)
    np.random.shuffle(all_indices)
    print("all_indices.size: {}".format(len(all_indices)))
    test_indices_set = set(all_indices[:test_size])
    print("test_indices_set.size: {}".format(len(test_indices_set)))
    print("----------" * 10 + "\n" * 2)

    train_feature_file_name = os.path.join(output_path, "train_input_part_{}.h5")
    train_label_file_name = os.path.join(output_path, "train_output_part_{}.h5")
    test_feature_file_name = os.path.join(output_path, "test_input_part_{}.h5")
    test_label_file_name = os.path.join(output_path, "test_output_part_{}.h5")
    train_feature_list = []
    train_label_list = []
    test_feature_list = []
    test_label_list = []
    with open(input_file_path, encoding="utf-8") as file_in:
        count = 0
        train_part_number = 0
        test_part_number = 0
        for i, line in enumerate(file_in):
            count += 1
            if count % 1000000 == 0:
                print("Processed {}w lines.".format(count // 10000))
            line = line.strip("\n")
            items = line.split("\t")
            if len(items) != 40:
                continue
            label = float(items[0])
            values = items[1:14]
            cats = items[14:]
            assert len(values) == 13, "values.size: {}".format(len(values))
            assert len(cats) == 26, "cats.size: {}".format(len(cats))
            ids, wts = data_stats2.map_cat2id(values, cats)
            if i not in test_indices_set:
                train_feature_list.append(ids + wts)
                train_label_list.append(label)
            else:
                test_feature_list.append(ids + wts)
                test_label_list.append(label)
            # Flush a part file every `part_rows` rows to bound memory usage.
            if train_label_list and (len(train_label_list) % part_rows == 0):
                pd.DataFrame(np.asarray(train_feature_list)).to_hdf(
                    train_feature_file_name.format(train_part_number), key="fixed")
                pd.DataFrame(np.asarray(train_label_list)).to_hdf(
                    train_label_file_name.format(train_part_number), key="fixed")
                train_feature_list = []
                train_label_list = []
                train_part_number += 1
            if test_label_list and (len(test_label_list) % part_rows == 0):
                pd.DataFrame(np.asarray(test_feature_list)).to_hdf(
                    test_feature_file_name.format(test_part_number), key="fixed")
                pd.DataFrame(np.asarray(test_label_list)).to_hdf(
                    test_label_file_name.format(test_part_number), key="fixed")
                test_feature_list = []
                test_label_list = []
                test_part_number += 1
        # Write out whatever remains after the final full part.
        if train_label_list:
            pd.DataFrame(np.asarray(train_feature_list)).to_hdf(
                train_feature_file_name.format(train_part_number), key="fixed")
            pd.DataFrame(np.asarray(train_label_list)).to_hdf(
                train_label_file_name.format(train_part_number), key="fixed")
        if test_label_list:
            pd.DataFrame(np.asarray(test_feature_list)).to_hdf(
                test_feature_file_name.format(test_part_number), key="fixed")
            pd.DataFrame(np.asarray(test_label_list)).to_hdf(
                test_label_file_name.format(test_part_number), key="fixed")

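# Output layout: each train_input_part_N.h5 row is 39 feature ids followed by
# 39 weights (78 columns); train_output_part_N.h5 holds the matching labels.
# A part can be read back with plain pandas, e.g.:
#   features = pd.read_hdf("train_input_part_0.h5", key="fixed")
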
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Get and Process datasets')
    parser.add_argument('--base_path', default="/home/wushuquan/tmp/", help='The path to save dataset')
    parser.add_argument('--output_path', default="/home/wushuquan/tmp/h5dataset/",
                        help='The path to save h5 dataset')
    args, _ = parser.parse_known_args()
    base_path = args.base_path
    data_path = base_path + ""
    # mkdir_path(data_path)
    # if not os.path.exists(base_path + "dac.tar.gz"):
    #     os.system(
    #         "wget -P {} -c https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz"
    #         " --no-check-certificate".format(base_path))
    os.system("tar -zxvf {}dac.tar.gz".format(data_path))
    print("********tar end***********")
    data_stats = DataStatsDict()

    # Step 1: collect vocabulary statistics and value ranges over the raw file.
    data_file_path = "./train.txt"
    stats_output_path = base_path + "stats_dict/"
    mkdir_path(stats_output_path)
    statsdata(data_file_path, stats_output_path, data_stats)
    print("----------" * 10)
    data_stats.load_dict(dict_path=stats_output_path, prefix="")
    data_stats.get_cat2id(threshold=100)

    # Step 2: transform the data to H5 parts (version 2: np.random.shuffle split).
    in_file_path = "./train.txt"
    mkdir_path(args.output_path)
    random_split_trans2h5(in_file_path, args.output_path, data_stats, part_rows=2000000, test_size=0.1, seed=2020)
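
# Example invocation (paths are placeholders -- adjust to your environment):
#   python process_data.py --base_path=/data/criteo/ --output_path=/data/criteo/h5dataset/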