
data_generator.py 5.6 kB

# coding: utf-8
#================================================================#
#   Copyright (C) 2020 Freecss All rights reserved.
#
#   File Name    : data_generator.py
#   Author       : freecss
#   Email        : karlfreecss@gmail.com
#   Created Date : 2020/04/02
#   Description  :
#
#================================================================#
from itertools import product
from multiprocessing import Pool
import copy
import os
import pickle as pk
import random

import numpy as np

#def hamming_code_generator(data_len, p_len):
#    ret = []
#    for data in product((0, 1), repeat=data_len):
#        p_idxs = [2 ** i for i in range(p_len)]
#        total_len = data_len + p_len
#        data_idx = 0
#        hamming_code = []
#        for idx in range(total_len):
#            if idx + 1 in p_idxs:
#                hamming_code.append(0)
#            else:
#                hamming_code.append(data[data_idx])
#                data_idx += 1
#
#        for idx in range(total_len):
#            if idx + 1 in p_idxs:
#                for i in range(total_len):
#                    if (i + 1) & (idx + 1) != 0:
#                        hamming_code[idx] ^= hamming_code[i]
#        #hamming_code = "".join([str(x) for x in hamming_code])
#        ret.append(hamming_code)
#    return ret

def code_generator(code_len, code_num, letter_num=2):
    # Randomly sample code_num distinct codes of length code_len over an
    # alphabet of letter_num symbols.
    codes = list(product(range(letter_num), repeat=code_len))
    random.shuffle(codes)
    return codes[:code_num]
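
# Usage sketch (illustrative; the output varies with the RNG state):
#   >>> code_generator(code_len=7, code_num=4)
#   [(0, 1, 1, 0, 1, 0, 0), (1, 0, 0, 1, 1, 0, 1), ...]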

def hamming_distance_static(codes):
    # Hamming-distance statistics over a code set: the average pairwise
    # distance and the average distance from each code to its nearest
    # neighbour. (The original initialised the running minima with
    # len(codes), which is wrong whenever there are fewer codes than the
    # code length; float("inf") is the safe starting value.)
    min_dist = float("inf")
    avg_dist = 0.
    avg_min_dist = 0.
    relation_num = 0.
    for code1 in codes:
        tmp_min_dist = float("inf")
        for code2 in codes:
            if code1 == code2:
                continue
            dist = 0
            relation_num += 1
            for c1, c2 in zip(code1, code2):
                if c1 != c2:
                    dist += 1
            avg_dist += dist
            if tmp_min_dist > dist:
                tmp_min_dist = dist
        avg_min_dist += tmp_min_dist
        if min_dist > tmp_min_dist:
            min_dist = tmp_min_dist
    return avg_dist / relation_num, avg_min_dist / len(codes)
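
# Worked example (sketch): for codes [(0, 0), (0, 1), (1, 1)] the six ordered
# pairs have distances 1, 2, 1, 1, 2, 1 and each code's nearest neighbour is
# at distance 1, so hamming_distance_static returns (8 / 6, 1.0).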

def generate_cosin_data(codes, err, repeat, letter_num):
    # Sample 2-D points uniformly and bucket them into bands between shifted
    # cosine curves, one band per symbol of the alphabet.
    Y = np.random.random(100000) * letter_num * 3 - 3
    X = np.random.random(100000) * 20 - 10
    data_X = np.concatenate((X.reshape(-1, 1), Y.reshape(-1, 1)), axis=1)

    samples = {}
    # sorted() gives a deterministic symbol-to-band mapping; plain set order
    # varies between runs.
    all_sign = sorted(set(sum([[c for c in code] for code in codes], [])))
    for d, sign in enumerate(all_sign):
        mask = np.logical_and(Y < np.cos(X) + 2 * d, Y > np.cos(X) + 2 * d - 2)
        samples[sign] = data_X[mask]

    data = []
    labels = []
    count = 0
    for _ in range(repeat):
        if count > 100000:
            break
        for code in codes:
            tmp = []
            count += 1
            for d in code:
                # With probability err, corrupt the symbol into a different one.
                if random.random() < err:
                    candidates = copy.deepcopy(all_sign)
                    candidates.remove(d)
                    d = candidates[random.randint(0, letter_num - 2)]
                idx = random.randint(0, len(samples[d]) - 1)
                tmp.append(samples[d][idx])
            data.append(tmp)
            labels.append(code)
    data = np.array(data)
    labels = np.array(labels)
    return data, labels
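
# Shape sketch (assuming N generated samples and codes of length L): data has
# shape (N, L, 2), one (x, y) point per symbol, and labels has shape (N, L)
# with the uncorrupted ground-truth code of each sample.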

#codes = """110011001
#100011001
#101101101
#011111001
#100100001
#111111101
#101110001
#111100101
#101000101
#001001101
#111110101
#100101001
#010010101
#110100101
#001111101
#111111001"""
#codes = codes.split()

def generate_data_via_codes(codes, err, letter_num):
    #codes = code_generator(code_len, code_num)
    data, labels = generate_cosin_data(codes, err, 100000, letter_num)
    return data, labels
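
# Usage sketch (hypothetical lexicon): err is the per-symbol corruption
# probability, so err=0 yields clean data:
#   data, labels = generate_data_via_codes(["011", "101"], 0.0, letter_num=2)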

def generate_data(params):
    code_len = params["code_len"]
    times = params["times"]
    p = params["p"]
    code_num = params["code_num"]
    err = p / 20.
    codes = code_generator(code_len, code_num)
    # The original passed only (codes, err), which does not match the
    # four-argument signature above; repeat and letter_num are filled in here
    # to mirror generate_data_via_codes with the binary default alphabet.
    data, labels = generate_cosin_data(codes, err, 100000, 2)
    data_name = "code_%d_%d" % (code_len, code_num)
    pk.dump((codes, data, labels), open("generated_data/%d_%s_%.2f.pk" % (times, data_name, err), "wb"))
    return True

def generate_multi_data():
    pool = Pool(64)
    params_list = []
    #for code_len in [7, 9, 11, 13, 15]:
    for code_len in [7, 11, 15]:
        for times in range(20):
            for p in range(0, 11):
                for code_num_power in range(1, code_len):
                    code_num = 2 ** code_num_power
                    params_list.append({"code_len": code_len, "times": times, "p": p, "code_num": code_num})
    return list(pool.map(generate_data, params_list))
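
# Grid-size sketch: code_len in {7, 11, 15} gives 6 + 10 + 14 = 30
# (code_len, code_num) pairs, so 20 repetitions x 11 noise levels yield
# 20 * 11 * 30 = 6600 parameter sets, one pickle file each under
# generated_data/.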

def read_lexicon(file_path):
    # Read one code word per line and count the distinct symbols used.
    with open(file_path) as fin:
        ret = [s.strip() for s in fin]
    all_sign = list(set(sum([[c for c in s] for s in ret], [])))
    #ret = ["".join(str(all_sign.index(t)) for t in tmp) for tmp in ret]
    return ret, len(all_sign)
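
# Usage sketch (assuming a hypothetical lexicons/toy.txt with one code word
# per line, say "011" and "101"):
#   read_lexicon("lexicons/toy.txt")  ->  (["011", "101"], 2)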

if __name__ == "__main__":
    # Build one clean (err=0) dataset per lexicon file.
    for root, dirs, files in os.walk("lexicons"):
        if root != "lexicons":
            continue
        for file_name in files:
            file_path = os.path.join(root, file_name)
            codes, letter_num = read_lexicon(file_path)
            data, labels = generate_data_via_codes(codes, 0, letter_num)
            save_path = os.path.join("dataset", file_name.split(".")[0] + ".pk")
            pk.dump((data, labels, codes), open(save_path, "wb"))
    #res = read_lexicon("add2.txt")
    #print(res)
    exit(0)
    # Unreachable after exit(0): the random-code pipeline was left disabled in
    # favour of the lexicon-based one above.
    generate_multi_data()
    exit()

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.