| @@ -1,266 +0,0 @@ | |||
| import os | |||
| import itertools | |||
| import random | |||
| import numpy as np | |||
| from PIL import Image | |||
| import pickle | |||
| def get_sign_path_list(data_dir, sign_names): | |||
| sign_num = len(sign_names) | |||
| index_dict = dict(zip(sign_names, list(range(sign_num)))) | |||
| ret = [[] for _ in range(sign_num)] | |||
| for path in os.listdir(data_dir): | |||
| if path in sign_names: | |||
| index = index_dict[path] | |||
| sign_path = os.path.join(data_dir, path) | |||
| for p in os.listdir(sign_path): | |||
| ret[index].append(os.path.join(sign_path, p)) | |||
| return ret | |||
| def split_pool_by_rate(pools, rate, seed=None): | |||
| if seed is not None: | |||
| random.seed(seed) | |||
| ret1 = [] | |||
| ret2 = [] | |||
| for pool in pools: | |||
| random.shuffle(pool) | |||
| num = int(len(pool) * rate) | |||
| ret1.append(pool[:num]) | |||
| ret2.append(pool[num:]) | |||
| return ret1, ret2 | |||
| def int_to_system_form(num, system_num): | |||
| if num is 0: | |||
| return "0" | |||
| ret = "" | |||
| while num > 0: | |||
| ret += str(num % system_num) | |||
| num //= system_num | |||
| return ret[::-1] | |||
| def generator_equations( | |||
| left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type | |||
| ): | |||
| expr_len = left_opt_len + right_opt_len | |||
| num_list = "".join([str(i) for i in range(system_num)]) | |||
| ret = [] | |||
| if generate_type == "all": | |||
| candidates = itertools.product(num_list, repeat=expr_len) | |||
| else: | |||
| candidates = ["".join(random.sample(["0", "1"] * expr_len, expr_len))] | |||
| random.shuffle(candidates) | |||
| for nums in candidates: | |||
| left_num = "".join(nums[:left_opt_len]) | |||
| right_num = "".join(nums[left_opt_len:]) | |||
| left_value = int(left_num, system_num) | |||
| right_value = int(right_num, system_num) | |||
| result_value = left_value + right_value | |||
| if label == "negative": | |||
| result_value += random.randint(-result_value, result_value) | |||
| if left_value + right_value == result_value: | |||
| continue | |||
| result_num = int_to_system_form(result_value, system_num) | |||
| # leading zeros | |||
| if res_opt_len != len(result_num): | |||
| continue | |||
| if (left_opt_len > 1 and left_num[0] == "0") or ( | |||
| right_opt_len > 1 and right_num[0] == "0" | |||
| ): | |||
| continue | |||
| # add leading zeros | |||
| if res_opt_len < len(result_num): | |||
| continue | |||
| while len(result_num) < res_opt_len: | |||
| result_num = "0" + result_num | |||
| # continue | |||
| ret.append( | |||
| left_num + "+" + right_num + "=" + result_num | |||
| ) # current only consider '+' and '=' | |||
| # print(ret[-1]) | |||
| return ret | |||
| def generator_equation_by_len(equation_len, system_num=2, label=0, require_num=1): | |||
| generate_type = "one" | |||
| ret = [] | |||
| equation_sign_num = 2 # '+' and '=' | |||
| while len(ret) < require_num: | |||
| left_opt_len = random.randint(1, equation_len - 1 - equation_sign_num) | |||
| right_opt_len = random.randint( | |||
| 1, equation_len - left_opt_len - equation_sign_num | |||
| ) | |||
| res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num | |||
| ret.extend( | |||
| generator_equations( | |||
| left_opt_len, | |||
| right_opt_len, | |||
| res_opt_len, | |||
| system_num, | |||
| label, | |||
| generate_type, | |||
| ) | |||
| ) | |||
| return ret | |||
| def generator_equations_by_len( | |||
| equation_len, system_num=2, label=0, repeat_times=1, keep=1, generate_type="all" | |||
| ): | |||
| ret = [] | |||
| equation_sign_num = 2 # '+' and '=' | |||
| for left_opt_len in range(1, equation_len - (2 + equation_sign_num) + 1): | |||
| for right_opt_len in range( | |||
| 1, equation_len - left_opt_len - (1 + equation_sign_num) + 1 | |||
| ): | |||
| res_opt_len = ( | |||
| equation_len - left_opt_len - right_opt_len - equation_sign_num | |||
| ) | |||
| for i in range(repeat_times): # generate more equations | |||
| if random.random() > keep ** (equation_len): | |||
| continue | |||
| ret.extend( | |||
| generator_equations( | |||
| left_opt_len, | |||
| right_opt_len, | |||
| res_opt_len, | |||
| system_num, | |||
| label, | |||
| generate_type, | |||
| ) | |||
| ) | |||
| return ret | |||
| def generator_equations_by_max_len( | |||
| max_equation_len, | |||
| system_num=2, | |||
| label=0, | |||
| repeat_times=1, | |||
| keep=1, | |||
| generate_type="all", | |||
| num_per_len=None, | |||
| ): | |||
| ret = [] | |||
| equation_sign_num = 2 # '+' and '=' | |||
| for equation_len in range(3 + equation_sign_num, max_equation_len + 1): | |||
| if num_per_len is None: | |||
| ret.extend( | |||
| generator_equations_by_len( | |||
| equation_len, system_num, label, repeat_times, keep, generate_type | |||
| ) | |||
| ) | |||
| else: | |||
| ret.extend( | |||
| generator_equation_by_len( | |||
| equation_len, system_num, label, require_num=num_per_len | |||
| ) | |||
| ) | |||
| return ret | |||
| def generator_equation_images(image_pools, equations, signs, shape, seed, is_color): | |||
| if seed is not None: | |||
| random.seed(seed) | |||
| ret = [] | |||
| sign_num = len(signs) | |||
| sign_index_dict = dict(zip(signs, list(range(sign_num)))) | |||
| for equation in equations: | |||
| data = [] | |||
| for sign in equation: | |||
| index = sign_index_dict[sign] | |||
| pick = random.randint(0, len(image_pools[index]) - 1) | |||
| if is_color: | |||
| image = ( | |||
| Image.open(image_pools[index][pick]).convert("RGB").resize(shape) | |||
| ) | |||
| else: | |||
| image = Image.open(image_pools[index][pick]).convert("I").resize(shape) | |||
| image_array = np.array(image) | |||
| image_array = (image_array - 127) * (1.0 / 128) | |||
| data.append(image_array) | |||
| ret.append(np.array(data)) | |||
| return ret | |||
| def get_equation_std_data( | |||
| data_dir, | |||
| sign_dir_lists, | |||
| sign_output_lists, | |||
| shape=(28, 28), | |||
| train_max_equation_len=10, | |||
| test_max_equation_len=10, | |||
| system_num=2, | |||
| tmp_file_prev=None, | |||
| seed=None, | |||
| train_num_per_len=10, | |||
| test_num_per_len=10, | |||
| is_color=False, | |||
| ): | |||
| tmp_file = "" | |||
| if tmp_file_prev is not None: | |||
| tmp_file = "%s_train_len_%d_test_len_%d_sys_%d_.pk" % ( | |||
| tmp_file_prev, | |||
| train_max_equation_len, | |||
| test_max_equation_len, | |||
| system_num, | |||
| ) | |||
| if os.path.exists(tmp_file): | |||
| return pickle.load(open(tmp_file, "rb")) | |||
| image_pools = get_sign_path_list(data_dir, sign_dir_lists) | |||
| train_pool, test_pool = split_pool_by_rate(image_pools, 0.8, seed) | |||
| ret = {} | |||
| for label in ["positive", "negative"]: | |||
| print("Generating equations.") | |||
| train_equations = generator_equations_by_max_len( | |||
| train_max_equation_len, system_num, label, num_per_len=train_num_per_len | |||
| ) | |||
| test_equations = generator_equations_by_max_len( | |||
| test_max_equation_len, system_num, label, num_per_len=test_num_per_len | |||
| ) | |||
| print(train_equations) | |||
| print(test_equations) | |||
| print("Generated equations.") | |||
| print("Generating equation image data.") | |||
| ret["train:%s" % (label)] = generator_equation_images( | |||
| train_pool, train_equations, sign_output_lists, shape, seed, is_color | |||
| ) | |||
| ret["test:%s" % (label)] = generator_equation_images( | |||
| test_pool, test_equations, sign_output_lists, shape, seed, is_color | |||
| ) | |||
| print("Generated equation image data.") | |||
| if tmp_file_prev is not None: | |||
| pickle.dump(ret, open(tmp_file, "wb")) | |||
| return ret | |||
| if __name__ == "__main__": | |||
| data_dirs = [ | |||
| "./dataset/mnist_images", | |||
| "./dataset/random_images", | |||
| ] # , "../dataset/cifar10_images"] | |||
| tmp_file_prevs = [ | |||
| "mnist_equation_data", | |||
| "random_equation_data", | |||
| ] # , "cifar10_equation_data"] | |||
| for data_dir, tmp_file_prev in zip(data_dirs, tmp_file_prevs): | |||
| data = get_equation_std_data( | |||
| data_dir=data_dir, | |||
| sign_dir_lists=["0", "1", "10", "11"], | |||
| sign_output_lists=["0", "1", "+", "="], | |||
| shape=(28, 28), | |||
| train_max_equation_len=26, | |||
| test_max_equation_len=26, | |||
| system_num=2, | |||
| tmp_file_prev=tmp_file_prev, | |||
| train_num_per_len=300, | |||
| test_num_per_len=300, | |||
| is_color=False, | |||
| ) | |||