diff --git a/.gitignore b/.gitignore index 8ebaf8d..8ef59e4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ *.pyc /results -raw/ \ No newline at end of file +raw/ +*.jpg +*.png +*.pk \ No newline at end of file diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index 72884f9..c0fd102 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -26,9 +26,16 @@ import time class AbducerBase(abc.ABC): - def __init__(self, kb, dist_func='confidence', zoopt=False, multiple_predictions=False, cache=True): + def __init__( + self, + kb, + dist_func="confidence", + zoopt=False, + multiple_predictions=False, + cache=True, + ): self.kb = kb - assert dist_func == 'hamming' or dist_func == 'confidence' + assert dist_func == "hamming" or dist_func == "confidence" self.dist_func = dist_func self.zoopt = zoopt self.multiple_predictions = multiple_predictions @@ -39,11 +46,18 @@ class AbducerBase(abc.ABC): self.cache_candidates = {} def _get_cost_list(self, pred_res, pred_res_prob, candidates): - if self.dist_func == 'hamming': + if self.dist_func == "hamming": return hamming_dist(pred_res, candidates) - elif self.dist_func == 'confidence': - mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) - return confidence_dist(pred_res_prob, [list(map(lambda x: mapping[x], c)) for c in candidates]) + elif self.dist_func == "confidence": + mapping = dict( + zip( + self.kb.pseudo_label_list, + list(range(len(self.kb.pseudo_label_list))), + ) + ) + return confidence_dist( + pred_res_prob, [list(map(lambda x: mapping[x], c)) for c in candidates] + ) def _get_one_candidate(self, pred_res, pred_res_prob, candidates): if len(candidates) == 0: @@ -60,9 +74,13 @@ class AbducerBase(abc.ABC): def _zoopt_score_multiple(self, pred_res, key, solution): all_address_flag = reform_idx(solution, pred_res) score = 0 - for idx in enumerate(len(pred_res)): - address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] - 
candidate = self.kb.address_by_idx([pred_res[idx]], key[idx], address_idx, True) + for idx in range(len(pred_res)): + address_idx = [ + i for i, flag in enumerate(all_address_flag[idx]) if flag != 0 + ] + candidate = self.kb.address_by_idx( + [pred_res[idx]], key[idx], address_idx, True + ) if len(candidate) > 0: score += 1 return score @@ -70,7 +88,9 @@ class AbducerBase(abc.ABC): def _zoopt_address_score(self, pred_res, key, sol): if not self.multiple_predictions: address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0] - candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) + candidates = self.kb.address_by_idx( + pred_res, key, address_idx, self.multiple_predictions + ) return 1 if len(candidates) > 0 else 0 else: return self._zoopt_score_multiple(pred_res, key, sol.get_x()) @@ -98,7 +118,11 @@ class AbducerBase(abc.ABC): pred_res = flatten(pred_res) key = tuple(key) if (tuple(pred_res), key) in self.cache_min_address_num: - address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address) + address_num = min( + max_address_num, + self.cache_min_address_num[(tuple(pred_res), key)] + + require_more_address, + ) if (tuple(pred_res), key, address_num) in self.cache_candidates: candidates = self.cache_candidates[(tuple(pred_res), key, address_num)] if self.zoopt: @@ -127,12 +151,18 @@ class AbducerBase(abc.ABC): if self.zoopt: solution = self.zoopt_get_solution(pred_res, key, max_address_num) address_idx = [idx for idx, i in enumerate(solution) if i != 0] - candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) + candidates = self.kb.address_by_idx( + pred_res, key, address_idx, self.multiple_predictions + ) address_num = int(solution.sum()) min_address_num = address_num else: candidates, min_address_num, address_num = self.kb.abduce_candidates( - pred_res, key, max_address_num, require_more_address, self.multiple_predictions + pred_res, + 
key, + max_address_num, + require_more_address, + self.multiple_predictions, ) candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates) @@ -147,20 +177,31 @@ class AbducerBase(abc.ABC): def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0): if self.multiple_predictions: - return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address) + return self.abduce( + (Z["cls"], Z["prob"], Y), max_address_num, require_more_address + ) else: - return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)] + return [ + self.abduce((z, prob, y), max_address_num, require_more_address) + for z, prob, y in zip(Z["cls"], Z["prob"], Y) + ] def __call__(self, Z, Y, max_address_num=-1, require_more_address=0): return self.batch_abduce(Z, Y, max_address_num, require_more_address) -if __name__ == '__main__': - prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] - prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] +if __name__ == "__main__": + prob1 = [ + [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ] + prob2 = [ + [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], + [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + ] kb = add_KB() - abd = AbducerBase(kb, 'confidence') + abd = AbducerBase(kb, "confidence") res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) print(res) res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) @@ -174,7 +215,7 @@ if __name__ == '__main__': print() kb = add_prolog_KB() - abd = AbducerBase(kb, 'confidence') + abd = AbducerBase(kb, "confidence") res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) print(res) res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) @@ -188,7 +229,7 @@ if __name__ == '__main__': print() 
kb = add_prolog_KB() - abd = AbducerBase(kb, 'confidence', zoopt=True) + abd = AbducerBase(kb, "confidence", zoopt=True) res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) print(res) res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) @@ -202,24 +243,38 @@ if __name__ == '__main__': print() kb = HWF_KB(len_list=[1, 3, 5]) - abd = AbducerBase(kb, 'hamming') - res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0) + abd = AbducerBase(kb, "hamming") + res = abd.abduce( + (["5", "+", "2"], None, 3), max_address_num=2, require_more_address=0 + ) print(res) - res = abd.abduce((['5', '+', '2'], None, 64), max_address_num=3, require_more_address=0) + res = abd.abduce( + (["5", "+", "2"], None, 64), max_address_num=3, require_more_address=0 + ) print(res) - res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0) + res = abd.abduce( + (["5", "+", "2"], None, 1.67), max_address_num=3, require_more_address=0 + ) print(res) - res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3) + res = abd.abduce( + (["5", "8", "8", "8", "8"], None, 3.17), + max_address_num=5, + require_more_address=3, + ) print(res) print() kb = HED_prolog_KB() abd = AbducerBase(kb, zoopt=True, multiple_predictions=True) - consist_exs = [[1, '+', 0, '=', 0], [1, '+', 1, '=', 0], [0, '+', 0, '=', 1, 1]] - consist_exs2 = [[1, '+', 0, '=', 0], [1, '+', 1, '=', 0], [0, '+', 1, '=', 1, 1]] # not consistent with rules - inconsist_exs = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]] + consist_exs = [[1, "+", 0, "=", 0], [1, "+", 1, "=", 0], [0, "+", 0, "=", 1, 1]] + consist_exs2 = [ + [1, "+", 0, "=", 0], + [1, "+", 1, "=", 0], + [0, "+", 1, "=", 1, 1], + ] # not consistent with rules + inconsist_exs = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] # inconsist_exs = [[1, '+', 0, '=', 0], ['=', '=', '=', '=', 
0], ['=', '=', 0, '=', '=', '=']] - rules = ['my_op([0], [0], [1, 1])', 'my_op([1], [1], [0])', 'my_op([1], [0], [0])'] + rules = ["my_op([0], [0], [1, 1])", "my_op([1], [1], [0])", "my_op([1], [0], [0])"] print(kb.logic_forward(consist_exs), kb.logic_forward(inconsist_exs)) print(kb.consist_rule(consist_exs, rules), kb.consist_rule(consist_exs2, rules)) diff --git a/abducer/kb.py b/abducer/kb.py index 6e5f0d3..04a9f68 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -43,16 +43,29 @@ class KBBase(ABC): def address(self, address_num, pred_res, key, multiple_predictions=False): new_candidates = [] if not multiple_predictions: - address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) + address_idx_list = list( + combinations(list(range(len(pred_res))), address_num) + ) else: - address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num)) + address_idx_list = list( + combinations(list(range(len(flatten(pred_res)))), address_num) + ) for address_idx in address_idx_list: - candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) + candidates = self.address_by_idx( + pred_res, key, address_idx, multiple_predictions + ) new_candidates += candidates return new_candidates - def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions=False): + def abduction( + self, + pred_res, + key, + max_address_num, + require_more_address, + multiple_predictions=False, + ): candidates = [] for address_num in range(len(pred_res) + 1): @@ -60,7 +73,9 @@ class KBBase(ABC): if abs(self.logic_forward(pred_res) - key) <= 1e-3: candidates.append(pred_res) else: - new_candidates = self.address(address_num, pred_res, key, multiple_predictions) + new_candidates = self.address( + address_num, pred_res, key, multiple_predictions + ) candidates += new_candidates if len(candidates) > 0: @@ -70,10 +85,14 @@ class KBBase(ABC): if address_num >= max_address_num: return [], 0, 0 - for address_num 
in range(min_address_num + 1, min_address_num + require_more_address + 1): + for address_num in range( + min_address_num + 1, min_address_num + require_more_address + 1 + ): if address_num > max_address_num: return candidates, min_address_num, address_num - 1 - new_candidates = self.address(address_num, pred_res, key, multiple_predictions) + new_candidates = self.address( + address_num, pred_res, key, multiple_predictions + ) candidates += new_candidates return candidates, min_address_num, address_num @@ -98,7 +117,9 @@ class ClsKB(KBBase): else: self.all_address_candidate_dict = {} for address_num in range(max(self.len_list) + 1): - self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat=address_num)) + self.all_address_candidate_dict[address_num] = list( + product(self.pseudo_label_list, repeat=address_num) + ) # For parallel version of _get_GKB def _get_XY_list(self, args): @@ -143,11 +164,26 @@ class ClsKB(KBBase): def logic_forward(self): pass - def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): + def abduce_candidates( + self, + pred_res, + key, + max_address_num=-1, + require_more_address=0, + multiple_predictions=False, + ): if self.GKB_flag: - return self.abduce_from_GKB(pred_res, key, max_address_num, require_more_address) + return self.abduce_from_GKB( + pred_res, key, max_address_num, require_more_address + ) else: - return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions) + return self.abduction( + pred_res, + key, + max_address_num, + require_more_address, + multiple_predictions, + ) def abduce_from_GKB(self, pred_res, key, max_address_num, require_more_address): if self.base == {} or len(pred_res) not in self.len_list: @@ -211,7 +247,24 @@ class add_KB(ClsKB): class HWF_KB(ClsKB): def __init__( - self, GKB_flag=False, pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], 
len_list=[1, 3, 5, 7] + self, + GKB_flag=False, + pseudo_label_list=[ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "+", + "-", + "times", + "div", + ], + len_list=[1, 3, 5, 7], ): super().__init__(GKB_flag, pseudo_label_list, len_list) @@ -219,9 +272,19 @@ class HWF_KB(ClsKB): if len(formula) % 2 == 0: return False for i in range(len(formula)): - if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: + if i % 2 == 0 and formula[i] not in [ + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + ]: return False - if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: + if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: return False return True @@ -229,22 +292,22 @@ class HWF_KB(ClsKB): if not self.valid_candidate(formula): return np.inf mapping = { - '1': '1', - '2': '2', - '3': '3', - '4': '4', - '5': '5', - '6': '6', - '7': '7', - '8': '8', - '9': '9', - '+': '+', - '-': '-', - 'times': '*', - 'div': '/', + "1": "1", + "2": "2", + "3": "3", + "4": "4", + "5": "5", + "6": "6", + "7": "7", + "8": "8", + "9": "9", + "+": "+", + "-": "-", + "times": "*", + "div": "/", } formula = [mapping[f] for f in formula] - return round(eval(''.join(formula)), 2) + return round(eval("".join(formula)), 2) class prolog_KB(KBBase): @@ -256,8 +319,12 @@ class prolog_KB(KBBase): def logic_forward(self): pass - def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): - return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions) + def abduce_candidates( + self, pred_res, key, max_address_num, require_more_address, multiple_predictions + ): + return self.abduction( + pred_res, key, max_address_num, require_more_address, multiple_predictions + ) def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): candidates = [] @@ -265,7 +332,9 @@ class prolog_KB(KBBase): if not multiple_predictions: query_string 
= self.get_query_string(pred_res, key, address_idx) else: - query_string = self.get_query_string_need_flatten(pred_res, key, address_idx) + query_string = self.get_query_string_need_flatten( + pred_res, key, address_idx + ) if multiple_predictions: save_pred_res = pred_res @@ -289,24 +358,28 @@ class add_prolog_KB(prolog_KB): super().__init__(pseudo_label_list) for i in self.pseudo_label_list: self.prolog.assertz("pseudo_label(%s)" % i) - self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2") + self.prolog.assertz( + "addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2" + ) def logic_forward(self, nums): - return list(self.prolog.query("addition(%s, %s, Res)." % (nums[0], nums[1])))[0]['Res'] + return list(self.prolog.query("addition(%s, %s, Res)." % (nums[0], nums[1])))[ + 0 + ]["Res"] def get_query_string(self, pred_res, key, address_idx): query_string = "addition(" for idx, i in enumerate(pred_res): - tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ',' + tmp = "Z" + str(idx) + "," if idx in address_idx else str(i) + "," query_string += tmp query_string += "%s)." 
% key return query_string class HED_prolog_KB(prolog_KB): - def __init__(self, pseudo_label_list=[0, 1, '+', '=']): + def __init__(self, pseudo_label_list=[0, 1, "+", "="]): super().__init__(pseudo_label_list) - self.prolog.consult('./datasets/hed/learn_add.pl') + self.prolog.consult("./datasets/hed/learn_add.pl") # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` def logic_forward(self, exs): @@ -318,7 +391,7 @@ class HED_prolog_KB(prolog_KB): # add variables for prolog for idx in range(len(flatten_pred_res)): if idx in address_idx: - flatten_pred_res[idx] = 'X' + str(idx) + flatten_pred_res[idx] = "X" + str(idx) # unflatten new_pred_res = reform_idx(flatten_pred_res, pred_res) @@ -326,12 +399,30 @@ class HED_prolog_KB(prolog_KB): return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='") def consist_rule(self, exs, rules): - rule_str = "%s" % rules - rule_str = rule_str.replace("'", "") - return len(list(self.prolog.query("consistent_inst_feature(%s, %s)." % (exs, rule_str)))) != 0 + consist = False + for rule in rules: + # print(rule) + if ( + len( + list( + self.prolog.query( + "consistent_inst_feature(%s, [%s])." % (exs, rule) + ) + ) + ) + != 0 + ): + consist = True + break + return consist def abduce_rules(self, pred_res): - prolog_rules = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))[0]['X'] + prolog_result = list( + self.prolog.query("consistent_inst_feature(%s, X)." % pred_res) + ) + if len(prolog_result) == 0: + return None + prolog_rules = prolog_result[0]["X"] rules = [] for rule in prolog_rules: rules.append(rule.value) diff --git a/datasets/hed/BK.pl b/datasets/hed/BK.pl new file mode 100644 index 0000000..441df4e --- /dev/null +++ b/datasets/hed/BK.pl @@ -0,0 +1,83 @@ +:- use_module(library(apply)). +:- use_module(library(lists)). +% :- use_module(library(tabling)). +% :- table valid_rules/2, op_rule/2. 
+ +%%%%%%%%%%%%%%%%%%%%%%%%%%%% +%% DCG parser for equations +%%%%%%%%%%%%%%%%%%%%%%%%%%%% +%% symbols to be mapped +digit(1). +digit(0). + +% digits +digits([D]) --> [D], { digit(D) }. % empty list [] is not a digit +digits([D | T]) --> [D], !, digits(T), { digit(D) }. +digits(X):- + phrase(digits(X), X). +% More integrity constraints 1: +% This two clauses forbid the first digit to be 0. +% You may uncomment them to prune the search space +% length(X, L), +% (L > 1 -> X \= [0 | _]; true). + +% Equation definition +eq_arg([D]) --> [D], { \+ D == '+', \+ D == '=' }. +eq_arg([D | T]) --> [D], !, eq_arg(T), { \+ D == '+', \+ D == '=' }. +equation(eq(X, Y, Z)) --> + eq_arg(X), [+], eq_arg(Y), [=], eq_arg(Z). +% More integrity constraints 2: +% This clause restricts the length of arguments to be sane, +% You may uncomment them to prune the search space +% { length(X, LX), length(Y, LY), length(Z, LZ), +% LZ =< max(LX, LY) + 1, LZ >= max(LX, LY) }. +parse_eq(List_of_Terms, Eq) :- + phrase(equation(Eq), List_of_Terms). + +%%%%%%%%%%%%%%%%%%%%%% +%% Bit-wise operation +%%%%%%%%%%%%%%%%%%%%%% +% Abductive calculation with given pseudo-labels, abduces pseudo-labels as well as operation rules +calc(Rules, Pseudo) :- + calc([], Rules, Pseudo). +calc(Rules0, Rules1, Pseudo) :- + parse_eq(Pseudo, eq(X,Y,Z)), + bitwise_calc(Rules0, Rules1, X, Y, Z). + +% Bit-wise calculation that handles carrying +bitwise_calc(Rules, Rules1, X, Y, Z) :- + reverse(X, X1), reverse(Y, Y1), reverse(Z, Z1), + bitwise_calc_r(Rules, Rules1, X1, Y1, Z1), + maplist(digits, [X,Y,Z]). +bitwise_calc_r(Rs, Rs, [], Y, Y). +bitwise_calc_r(Rs, Rs, X, [], X). +bitwise_calc_r(Rules, Rules1, [D1 | X], [D2 | Y], [D3 | Z]) :- + abduce_op_rule(my_op([D1],[D2],Sum), Rules, Rules2), + ((Sum = [D3], Carry = []); (Sum = [C, D3], Carry = [C])), + bitwise_calc_r(Rules2, Rules3, X, Carry, X_carried), + bitwise_calc_r(Rules3, Rules1, X_carried, Y, Z). 
+ +%%%%%%%%%%%%%%%%%%%%%%%%% +% Abduce operation rules +%%%%%%%%%%%%%%%%%%%%%%%%% +% Get an existed rule +abduce_op_rule(R, Rules, Rules) :- + member(R, Rules). +% Add a new rule +abduce_op_rule(R, Rules, [R|Rules]) :- + op_rule(R), + valid_rules(Rules, R). + +% Integrity Constraints +valid_rules([], _). +valid_rules([my_op([X1],[Y1],_)|Rs], my_op([X],[Y],Z)) :- + op_rule(my_op([X],[Y],Z)), + [X,Y] \= [X1,Y1], + [X,Y] \= [Y1,X1], + valid_rules(Rs, my_op([X],[Y],Z)). +valid_rules([my_op([Y],[X],Z)|Rs], my_op([X],[Y],Z)) :- + X \= Y, + valid_rules(Rs, my_op([X],[Y],Z)). + +op_rule(my_op([X],[Y],[Z])) :- digit(X), digit(Y), digit(Z). +op_rule(my_op([X],[Y],[Z1,Z2])) :- digit(X), digit(Y), digits([Z1,Z2]). diff --git a/datasets/hed/equation_generator.py b/datasets/hed/equation_generator.py new file mode 100644 index 0000000..3cf735b --- /dev/null +++ b/datasets/hed/equation_generator.py @@ -0,0 +1,266 @@ +import os +import itertools +import random +import numpy as np +from PIL import Image +import pickle + + +def get_sign_path_list(data_dir, sign_names): + sign_num = len(sign_names) + index_dict = dict(zip(sign_names, list(range(sign_num)))) + ret = [[] for _ in range(sign_num)] + for path in os.listdir(data_dir): + if path in sign_names: + index = index_dict[path] + sign_path = os.path.join(data_dir, path) + for p in os.listdir(sign_path): + ret[index].append(os.path.join(sign_path, p)) + return ret + + +def split_pool_by_rate(pools, rate, seed=None): + if seed is not None: + random.seed(seed) + ret1 = [] + ret2 = [] + for pool in pools: + random.shuffle(pool) + num = int(len(pool) * rate) + ret1.append(pool[:num]) + ret2.append(pool[num:]) + return ret1, ret2 + + +def int_to_system_form(num, system_num): + if num is 0: + return "0" + ret = "" + while num > 0: + ret += str(num % system_num) + num //= system_num + return ret[::-1] + + +def generator_equations( + left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type +): + expr_len = left_opt_len + 
right_opt_len + num_list = "".join([str(i) for i in range(system_num)]) + ret = [] + if generate_type == "all": + candidates = itertools.product(num_list, repeat=expr_len) + else: + candidates = ["".join(random.sample(["0", "1"] * expr_len, expr_len))] + random.shuffle(candidates) + for nums in candidates: + left_num = "".join(nums[:left_opt_len]) + right_num = "".join(nums[left_opt_len:]) + left_value = int(left_num, system_num) + right_value = int(right_num, system_num) + result_value = left_value + right_value + if label == "negative": + result_value += random.randint(-result_value, result_value) + if left_value + right_value == result_value: + continue + result_num = int_to_system_form(result_value, system_num) + # leading zeros + if res_opt_len != len(result_num): + continue + if (left_opt_len > 1 and left_num[0] == "0") or ( + right_opt_len > 1 and right_num[0] == "0" + ): + continue + + # add leading zeros + if res_opt_len < len(result_num): + continue + while len(result_num) < res_opt_len: + result_num = "0" + result_num + # continue + ret.append( + left_num + "+" + right_num + "=" + result_num + ) # current only consider '+' and '=' + # print(ret[-1]) + return ret + + +def generator_equation_by_len(equation_len, system_num=2, label=0, require_num=1): + generate_type = "one" + ret = [] + equation_sign_num = 2 # '+' and '=' + while len(ret) < require_num: + left_opt_len = random.randint(1, equation_len - 1 - equation_sign_num) + right_opt_len = random.randint( + 1, equation_len - left_opt_len - equation_sign_num + ) + res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num + ret.extend( + generator_equations( + left_opt_len, + right_opt_len, + res_opt_len, + system_num, + label, + generate_type, + ) + ) + return ret + + +def generator_equations_by_len( + equation_len, system_num=2, label=0, repeat_times=1, keep=1, generate_type="all" +): + ret = [] + equation_sign_num = 2 # '+' and '=' + for left_opt_len in range(1, equation_len - (2 + 
equation_sign_num) + 1): + for right_opt_len in range( + 1, equation_len - left_opt_len - (1 + equation_sign_num) + 1 + ): + res_opt_len = ( + equation_len - left_opt_len - right_opt_len - equation_sign_num + ) + for i in range(repeat_times): # generate more equations + if random.random() > keep ** (equation_len): + continue + ret.extend( + generator_equations( + left_opt_len, + right_opt_len, + res_opt_len, + system_num, + label, + generate_type, + ) + ) + return ret + + +def generator_equations_by_max_len( + max_equation_len, + system_num=2, + label=0, + repeat_times=1, + keep=1, + generate_type="all", + num_per_len=None, +): + ret = [] + equation_sign_num = 2 # '+' and '=' + for equation_len in range(3 + equation_sign_num, max_equation_len + 1): + if num_per_len is None: + ret.extend( + generator_equations_by_len( + equation_len, system_num, label, repeat_times, keep, generate_type + ) + ) + else: + ret.extend( + generator_equation_by_len( + equation_len, system_num, label, require_num=num_per_len + ) + ) + return ret + + +def generator_equation_images(image_pools, equations, signs, shape, seed, is_color): + if seed is not None: + random.seed(seed) + ret = [] + sign_num = len(signs) + sign_index_dict = dict(zip(signs, list(range(sign_num)))) + for equation in equations: + data = [] + for sign in equation: + index = sign_index_dict[sign] + pick = random.randint(0, len(image_pools[index]) - 1) + if is_color: + image = ( + Image.open(image_pools[index][pick]).convert("RGB").resize(shape) + ) + else: + image = Image.open(image_pools[index][pick]).convert("I").resize(shape) + image_array = np.array(image) + image_array = (image_array - 127) * (1.0 / 128) + data.append(image_array) + ret.append(np.array(data)) + return ret + + +def get_equation_std_data( + data_dir, + sign_dir_lists, + sign_output_lists, + shape=(28, 28), + train_max_equation_len=10, + test_max_equation_len=10, + system_num=2, + tmp_file_prev=None, + seed=None, + train_num_per_len=10, + 
test_num_per_len=10, + is_color=False, +): + tmp_file = "" + if tmp_file_prev is not None: + tmp_file = "%s_train_len_%d_test_len_%d_sys_%d_.pk" % ( + tmp_file_prev, + train_max_equation_len, + test_max_equation_len, + system_num, + ) + if os.path.exists(tmp_file): + return pickle.load(open(tmp_file, "rb")) + + image_pools = get_sign_path_list(data_dir, sign_dir_lists) + train_pool, test_pool = split_pool_by_rate(image_pools, 0.8, seed) + + ret = {} + for label in ["positive", "negative"]: + print("Generating equations.") + train_equations = generator_equations_by_max_len( + train_max_equation_len, system_num, label, num_per_len=train_num_per_len + ) + test_equations = generator_equations_by_max_len( + test_max_equation_len, system_num, label, num_per_len=test_num_per_len + ) + print(train_equations) + print(test_equations) + print("Generated equations.") + print("Generating equation image data.") + ret["train:%s" % (label)] = generator_equation_images( + train_pool, train_equations, sign_output_lists, shape, seed, is_color + ) + ret["test:%s" % (label)] = generator_equation_images( + test_pool, test_equations, sign_output_lists, shape, seed, is_color + ) + print("Generated equation image data.") + + if tmp_file_prev is not None: + pickle.dump(ret, open(tmp_file, "wb")) + return ret + + +if __name__ == "__main__": + data_dirs = [ + "./dataset/mnist_images", + "./dataset/random_images", + ] # , "../dataset/cifar10_images"] + tmp_file_prevs = [ + "mnist_equation_data", + "random_equation_data", + ] # , "cifar10_equation_data"] + for data_dir, tmp_file_prev in zip(data_dirs, tmp_file_prevs): + data = get_equation_std_data( + data_dir=data_dir, + sign_dir_lists=["0", "1", "10", "11"], + sign_output_lists=["0", "1", "+", "="], + shape=(28, 28), + train_max_equation_len=26, + test_max_equation_len=26, + system_num=2, + tmp_file_prev=tmp_file_prev, + train_num_per_len=300, + test_num_per_len=300, + is_color=False, + ) diff --git a/datasets/hed/get_hed.py 
b/datasets/hed/get_hed.py new file mode 100644 index 0000000..8b61b62 --- /dev/null +++ b/datasets/hed/get_hed.py @@ -0,0 +1,130 @@ +import os +import cv2 +import torch +import torchvision +import pickle +import numpy as np +import random +from collections import defaultdict +from torch.utils.data import Dataset +from torchvision.transforms import transforms + + +def get_data(img_dataset, train): + transform = transforms.Compose([transforms.ToTensor()]) + X = [] + Y = [] + if train: + positive = img_dataset["train:positive"] + negative = img_dataset["train:negative"] + else: + positive = img_dataset["test:positive"] + negative = img_dataset["test:negative"] + + for equation in positive: + equation = equation.astype(np.float32) + img_list = np.vsplit(equation, equation.shape[0]) + X.append(img_list) + Y.append(1) + + for equation in negative: + equation = equation.astype(np.float32) + img_list = np.vsplit(equation, equation.shape[0]) + X.append(img_list) + Y.append(0) + + return X, None, Y + + +def get_pretrain_data(labels, image_size=(28, 28, 1)): + transform = transforms.Compose([transforms.ToTensor()]) + X = [] + for label in labels: + label_path = os.path.join( + "./datasets/hed/dataset/mnist_images", label + ) + img_path_list = os.listdir(label_path) + for img_path in img_path_list: + img = cv2.imread( + os.path.join(label_path, img_path), cv2.IMREAD_GRAYSCALE + ) + img = cv2.resize(img, (image_size[1], image_size[0])) + X.append(np.array(img, dtype=np.float32)) + + X = [((img[:, :, np.newaxis] - 127) / 128.0) for img in X] + Y = [img.copy().reshape(image_size[0] * image_size[1] * image_size[2]) for img in X] + + X = [transform(img) for img in X] + return X, Y + + +# def get_pretrain_data(train_data, image_size=(28, 28, 1)): +# X = [] +# for label in [0, 1]: +# for _, equation_list in train_data[label].items(): +# for equation in equation_list: +# X = X + equation + +# X = np.array(X) +# index = np.array(list(range(len(X)))) +# np.random.shuffle(index) +# X = 
X[index] + +# X = [img for img in X] +# Y = [img.copy().reshape(image_size[0] * image_size[1] * image_size[2]) for img in X] + +# return X, Y + + +def divide_equations_by_len(equations, labels): + equations_by_len = {1: defaultdict(list), 0: defaultdict(list)} + for i, equation in enumerate(equations): + equations_by_len[labels[i]][len(equation)].append(equation) + return equations_by_len + + +def split_equation(equations_by_len, prop_train, prop_val): + """ + Split the equations in each length to training and validation data according to the proportion + """ + train_equations_by_len = {1: dict(), 0: dict()} + val_equations_by_len = {1: dict(), 0: dict()} + + for label in range(2): + for equation_len, equations in equations_by_len[label].items(): + random.shuffle(equations) + train_equations_by_len[label][equation_len] = equations[ + : len(equations) // (prop_train + prop_val) * prop_train + ] + val_equations_by_len[label][equation_len] = equations[ + len(equations) // (prop_train + prop_val) * prop_train : + ] + + return train_equations_by_len, val_equations_by_len + + +def get_hed(dataset="mnist", train=True): + + if dataset == "mnist": + with open( + "./datasets/hed/mnist_equation_data_train_len_26_test_len_26_sys_2_.pk", + "rb", + ) as f: + img_dataset = pickle.load(f) + elif dataset == "random": + with open( + "./datasets/hed/random_equation_data_train_len_26_test_len_26_sys_2_.pk", + "rb", + ) as f: + img_dataset = pickle.load(f) + else: + raise Exception("Undefined dataset") + + X, _, Y = get_data(img_dataset, train) + equations_by_len = divide_equations_by_len(X, Y) + + return equations_by_len + + +if __name__ == "__main__": + get_hed() diff --git a/datasets/hed/learn_add.pl b/datasets/hed/learn_add.pl new file mode 100644 index 0000000..af1d6bb --- /dev/null +++ b/datasets/hed/learn_add.pl @@ -0,0 +1,81 @@ +:- ensure_loaded(['BK.pl']). +:- thread_setconcurrency(_, 8). +:- use_module(library(thread)). 
+ +%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +%% For propositionalisation +%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +eval_inst_feature(Ex, Feature):- + eval_eq(Ex, Feature). + +%% Evaluate instance given feature +eval_eq(Ex, Feature):- + parse_eq(Ex, eq(X,Y,Z)), + bitwise_calc(Feature,_,X,Y,Z), !. + +%%%%%%%%%%%%%% +%% Abduction +%%%%%%%%%%%%%% +% Make abduction when given examples that have been interpreted as pseudo-labels +abduce(Exs, Delta_C) :- + abduce(Exs, [], Delta_C). +abduce([], Delta_C, Delta_C). +abduce([E|Exs], Delta_C0, Delta_C1) :- + calc(Delta_C0, Delta_C2, E), + abduce(Exs, Delta_C2, Delta_C1). + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +%% Abduce pseudo-labels only +%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +abduce_consistent_insts(Exs):- + abduce(Exs, _), !. +% (Experimental) Uncomment to use parallel abduction +% abduce_consistent_exs_concurrent(Exs), !. + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +%% Abduce Delta_C given pseudo-labels +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +consistent_inst_feature(Exs, Delta_C):- + abduce(Exs, Delta_C), !. + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +%% (Experimental) Parallel abduction +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +abduce_consistent_exs_concurrent(Exs) :- + % Split the current data batch into grounding examples and variable examples (which need to be revised) + split_exs(Exs, Ground_Exs, Var_Exs), + % Find the simplest Delta_C for grounding examples. + abduce(Ground_Exs, Ground_Delta_C), !, + % Extend Ground Delta_C into all possible variations + extend_op_rule(Ground_Delta_C, Possible_Deltas), + % Concurrently abduce the variable examples + maplist(append([abduce2, Var_Exs, Ground_Exs]), [[Possible_Deltas]], Call_List), + maplist(=.., Goals, Call_List), + % writeln(Goals), + first_solution(Var_Exs, Goals, [local(inf)]). + +split_exs([],[],[]). +split_exs([E | Exs], [E | G_Exs], V_Exs):- + ground(E), !, + split_exs(Exs, G_Exs, V_Exs). +split_exs([E | Exs], G_Exs, [E | V_Exs]):- + split_exs(Exs, G_Exs, V_Exs). + +:- table extend_op_rule/2. 
+ +extend_op_rule(Rules, Rules) :- + length(Rules, 4). +extend_op_rule(Rules, Ext) :- + op_rule(R), + valid_rules(Rules, R), + extend_op_rule([R|Rules], Ext). + +% abduction without learning new Delta_C (Because they have been extended!) +abduce2([], _, _). +abduce2([E|Exs], Ground_Exs, Delta_C) :- + % abduce by finding ground examples + member(E, Ground_Exs), + abduce2(Exs, Ground_Exs, Delta_C). +abduce2([E|Exs], Ground_Exs, Delta_C) :- + eval_inst_feature(E, Delta_C), + abduce2(Exs, Ground_Exs, Delta_C). diff --git a/example.py b/example.py index 7ceaeff..795d455 100644 --- a/example.py +++ b/example.py @@ -1,60 +1,102 @@ # coding: utf-8 -#================================================================# +# ================================================================# # Copyright (C) 2021 Freecss All rights reserved. -# +# # File Name :share_example.py # Author :freecss # Email :karlfreecss@gmail.com # Created Date :2021/06/07 # Description : # -#================================================================# +# ================================================================# -from utils.plog import logger -import framework +from utils.plog import logger, INFO +import framework_hed import torch.nn as nn import torch -from models.lenet5 import LeNet5, SymbolNet -from models.basic_model import BasicModel + +from models.nn import LeNet5, SymbolNet, SymbolNetAutoencoder +from models.basic_model import BasicModel, BasicDataset from models.wabl_models import WABLBasicModel from multiprocessing import Pool import os from abducer.abducer_base import AbducerBase -from abducer.kb import add_KB, hwf_KB +from abducer.kb import add_KB, HWF_KB, HED_prolog_KB from datasets.mnist_add.get_mnist_add import get_mnist_add from datasets.hwf.get_hwf import get_hwf +from datasets.hed.get_hed import get_hed, get_pretrain_data, split_equation + def run_test(): # kb = add_KB(True) - kb = hwf_KB(True) - abducer = AbducerBase(kb) + + # kb = hwf_KB(True) + # abducer = AbducerBase(kb) 
+ + kb = HED_prolog_KB() + abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True) recorder = logger() - # train_X, train_Z, train_Y = get_mnist_add(train = True, get_pseudo_label = True) - # test_X, test_Z, test_Y = get_mnist_add(train = False, get_pseudo_label = True) - - train_data = get_hwf(train = True, get_pseudo_label = True) - test_data = get_hwf(train = False, get_pseudo_label = True) + # train_X, train_Z, train_Y = get_mnist_add(train=True, get_pseudo_label=True) + # test_X, test_Z, test_Y = get_mnist_add(train=False, get_pseudo_label=True) + + # train_data = get_hwf(train=True, get_pseudo_label=True) + # test_data = get_hwf(train=False, get_pseudo_label=True) + + total_train_data = get_hed(train=True) + train_data, val_data = split_equation(total_train_data, 3, 1) + test_data = get_hed(train=False) # cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:])) + cls_autoencoder = SymbolNetAutoencoder(num_classes=len(kb.pseudo_label_list)) cls = SymbolNet(num_classes=len(kb.pseudo_label_list)) - + + if not os.path.exists("./weights/pretrain_weights.pth"): + pretrain_data_X, pretrain_data_Y = get_pretrain_data(["0", "1", "10", "11"]) + pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y) + pretrain_data_loader = torch.utils.data.DataLoader( + pretrain_data, + batch_size=64, + shuffle=True, + ) + framework_hed.pretrain(cls_autoencoder, pretrain_data_loader, recorder) + torch.save( + cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth" + ) + cls.load_state_dict(torch.load("./weights/pretrain_weights.pth")) + criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) + optimizer = torch.optim.RMSprop( + cls.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6 + ) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - base_model = BasicModel(cls, criterion, optimizer, device, save_interval=1, 
save_dir=recorder.save_dir, num_epochs=1, recorder=recorder) + + base_model = BasicModel( + cls, + criterion, + optimizer, + device, + save_interval=1, + save_dir=recorder.save_dir, + batch_size=32, + num_epochs=10, + recorder=recorder, + ) + model = WABLBasicModel(base_model, kb.pseudo_label_list) - res = framework.train(model, abducer, train_data, test_data, sample_num = 10000, verbose = 1) - recorder.print(res) - + res = framework_hed.train_with_rule( + model, abducer, train_data, val_data, recorder=recorder + ) + INFO(res) + recorder.dump() return True + if __name__ == "__main__": run_test() diff --git a/framework_hed.py b/framework_hed.py index 1942e5c..04b41b8 100644 --- a/framework_hed.py +++ b/framework_hed.py @@ -1,261 +1,395 @@ +# coding: utf-8 +# ================================================================# +# Copyright (C) 2021 Freecss All rights reserved. +# +# File Name :framework.py +# Author :freecss +# Email :karlfreecss@gmail.com +# Created Date :2021/06/07 +# Description : +# +# ================================================================# + +import pickle as pk + import numpy as np -from utils.utils import flatten, reform_idx - - -def get_rules_from_data(equations_true): - SAMPLES_PER_RULE = 3 - - select_index = np.random.randint(len(equations_true), size=SAMPLES_PER_RULE) - select_equations = np.array(equations_true)[select_index] - - -def get_consist_idx(exs, abducer): - consistent_ex_idx = [] - label = [] - for idx, e in enumerate(exs): - if abducer.kb.logic_forward([e]): - consistent_ex_idx.append(idx) - label.append(e) - return consistent_ex_idx, label - -def get_label(exs, solution, abducer): - all_address_flag = reform_idx(solution, exs) - consistent_ex_idx = [] - label = [] - for idx, ex in enumerate(exs): - address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] - candidate = abducer.kb.address_by_idx([ex], 1, address_idx, True) - if len(candidate) > 0: - consistent_ex_idx.append(idx) - 
label.append(candidate[0][0]) - return consistent_ex_idx, label - - -def get_percentage_precision(select_X, consistent_ex_idx, equation_label): - - images = [] - for idx in consistent_ex_idx: - images.append(select_X[idx]) - - ## TODO - model_labels = model.predict(images) - - assert(len(flatten(model_labels)) == len(flatten(equation_label))) - return (flatten(model_labels) == flatten(equation_label)).sum() / len(flatten(model_labels)) - - - - -def abduce_and_train(model, abducer, train_X_true, select_num): +import random + +random_seed = random.randint(0, 10000) +print("Selected random seed is : ", random_seed) +np.random.seed(random_seed) +random.seed(random_seed) + +from models.nn import MLP +from models.basic_model import BasicModel, BasicDataset +import torch.nn as nn +import torch + +from utils.plog import INFO, DEBUG, clocker +from utils.utils import flatten, reform_idx, block_sample + +from sklearn.tree import DecisionTreeClassifier + + +def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag): + result = {} + if char_acc_flag: + char_acc_num = 0 + char_num = 0 + for pred_z, z in zip(pred_Z, Z): + char_num += len(z) + for zidx in range(len(z)): + if pred_z[zidx] == z[zidx]: + char_acc_num += 1 + char_acc = char_acc_num / char_num + result["Character level accuracy"] = char_acc + + abl_acc_num = 0 + for pred_z, y in zip(pred_Z, Y): + if logic_forward(pred_z) == y: + abl_acc_num += 1 + abl_acc = abl_acc_num / len(Y) + result["ABL accuracy"] = abl_acc + + return result + + +def filter_data(X, abduced_Z): + finetune_Z = [] + finetune_X = [] + for abduced_x, abduced_z in zip(X, abduced_Z): + if abduced_z is not []: + finetune_X.append(abduced_x) + finetune_Z.append(abduced_z) + return finetune_X, finetune_Z + + +def pretrain(net, pretrain_data_loader, recorder): + INFO("Pretrain Start") + + criterion = nn.MSELoss() + optimizer = torch.optim.RMSprop( + net.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6 + ) + device = torch.device("cuda:0" if 
torch.cuda.is_available() else "cpu") + + pretrain_model = BasicModel( + net, + criterion, + optimizer, + device, + save_interval=1, + save_dir=recorder.save_dir, + num_epochs=10, + recorder=recorder, + ) + + pretrain_model.fit(pretrain_data_loader) + + +def get_char_acc(model, X, consistent_pred_res): + pred_res = flatten(model.predict(X)["cls"]) + assert len(pred_res) == len(flatten(consistent_pred_res)) + return sum( + [ + pred_res[idx] == flatten(consistent_pred_res)[idx] + for idx in range(len(pred_res)) + ] + ) / len(pred_res) + + +def gen_mappings(chars, symbs): + n_char = len(chars) + n_symbs = len(symbs) + if n_char != n_symbs: + INFO("Characters and symbols size dosen't match.") + return + from itertools import permutations + + mappings = [] + perms = permutations(symbs) + for p in perms: + mappings.append(dict(zip(chars, list(p)))) + return mappings + + +def map_res(original_pred_res, m): + pred_res = [[m[symbol] for symbol in formula] for formula in original_pred_res] + return pred_res + - select_index = np.random.randint(len(train_X_true), size=select_num) - select_X = train_X_true[select_index] - - - - exs = select_X.predict() - # e.g. when select_num == 10, exs = [[1, '+', 0, '=', 1, 0], [1, '+', 0, '=', 1, 0], [1, '+', 0, '=', 1, 0], [0, '+', 0, '=', 0], [1, '+', 0, '=', 1, 0],\ - # [1, '+', 0, '=', 1, 0], [1, '+', 0, '=', 1, 0], [1, '+', 0, '=', 1, 0], [0, '+', 0, '=', 0], [1, '+', 0, '=', 1, 0]] - - print("This is the model's current label:", exs) - - # 1. Check if it can abduce rules without changing any labels - consistent_ex_idx, equation_label = get_consist_idx(exs) - - - max_abduce_num = 10 - if len(consistent_ex_idx) == 0: - - # 2. 
Find the possible wrong position in symbols and Abduce the right symbol through logic module - solution = abducer.zoopt_get_solution(exs, [1] * len(exs), max_abduce_num) - consistent_ex_idx, equation_label = get_label(exs, solution, abducer) - - # Still cannot find - if len(consistent_ex_idx) == 0: - return 0, 0 - - - ## TODO: train - # train_pool_X = np.concatenate(select_X[consistent_ex_idx]).reshape( - # -1, h, w, d) - # train_pool_Y = np_utils.to_categorical( - # flatten(exs[consistent_ex_idx]), - # num_classes=len(labels)) # Convert the symbol to network output - # assert (len(train_pool_X) == len(train_pool_Y)) - # print("\nTrain pool size is :", len(train_pool_X)) - # print("Training...") - # base_model.fit(train_pool_X, - # train_pool_Y, - # batch_size=BATCHSIZE, - # epochs=NN_EPOCHS, - # verbose=0) - - # consistent_percentage, batch_label_model_precision = get_percentage_precision( - # base_model, select_equations, consist_re, shape) - - consistent_percentage = len(consistent_ex_idx) / select_num - batch_label_model_precision = get_percentage_precision(exs, consistent_ex_idx, equation_label) - - return consistent_percentage, batch_label_model_precision - -def get_rules(exs): - consistent_ex_idx, equation_label = get_consist_idx(exs) - consist_exs = [] - for idx in consistent_ex_idx: - consist_exs.append(exs[idx]) - if len(consist_exs) == 0: - return None - else: - return abducer.abduce_rule(consist_exs) - - - -def get_rules_from_data(train_X_true, samples_per_rule, logic_output_dim): +def abduce_and_train(model, abducer, train_X_true, select_num): + select_idx = np.random.randint(len(train_X_true), size=select_num) + X = [] + for idx in select_idx: + X.append(train_X_true[idx]) + + pred_res = model.predict(X)["cls"] + + maps = gen_mappings(["+", "=", 0, 1], ["+", "=", 0, 1]) + + consistent_idx = [] + consistent_pred_res = [] + + import copy + + original_pred_res = copy.deepcopy(pred_res) + mapping = None + + for m in maps: + pred_res = 
map_res(original_pred_res, m) + remapping = {} + for key, value in m.items(): + remapping[value] = key + + max_abduce_num = 10 + solution = abducer.zoopt_get_solution( + pred_res, [1] * len(pred_res), max_abduce_num + ) + all_address_flag = reform_idx(solution, pred_res) + + consistent_idx_tmp = [] + consistent_pred_res_tmp = [] + + for idx in range(len(pred_res)): + address_idx = [ + i for i, flag in enumerate(all_address_flag[idx]) if flag != 0 + ] + candidate = abducer.kb.address_by_idx([pred_res[idx]], 1, address_idx, True) + if len(candidate) > 0: + consistent_idx_tmp.append(idx) + consistent_pred_res_tmp.append( + [remapping[symbol] for symbol in candidate[0][0]] + ) + + if len(consistent_idx_tmp) > len(consistent_idx): + consistent_idx = consistent_idx_tmp + consistent_pred_res = consistent_pred_res_tmp + mapping = m + + if len(consistent_idx) == 0: + return 0, 0, None + + INFO("Consistent predict results are:", map_res(consistent_pred_res, mapping)) + INFO("Train pool size is:", len(flatten(consistent_pred_res))) + INFO("Start to use abduced pseudo label to train model...") + model.train([X[idx] for idx in consistent_idx], consistent_pred_res) + + consistent_acc = len(consistent_idx) / select_num + char_acc = get_char_acc( + model, [X[idx] for idx in consistent_idx], consistent_pred_res + ) + INFO("consistent_acc is %s, char_acc is %s" % (consistent_acc, char_acc)) + return consistent_acc, char_acc, mapping + + +def get_rules_from_data( + model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim +): rules = [] for _ in range(logic_output_dim): while True: - select_index = np.random.randint(len(train_X_true), size=samples_per_rule) - select_X = train_X_true[select_index] - - ## TODO - exs = select_X.predict() - rule = get_rules(exs) - if rule != None: - break + select_idx = np.random.randint(len(train_X_true), size=samples_per_rule) + X = [] + for idx in select_idx: + X.append(train_X_true[idx]) + pred_res = model.predict(X)["cls"] + pred_res 
= [[mapping[symbol] for symbol in formula] for formula in pred_res] + + consistent_idx = [] + consistent_pred_res = [] + for idx in range(len(pred_res)): + if abducer.kb.logic_forward([pred_res[idx]]): + consistent_idx.append(idx) + consistent_pred_res.append(pred_res[idx]) + + if len(consistent_pred_res) != 0: + rule = abducer.abduce_rules(consistent_pred_res) + if rule != None: + break + rules.append(rule) + INFO('Learned rules from data:') + for rule in rules: + INFO(rule) return rules -def get_mlp_vector(X, rules): - - ## TODO - exs = np.argmax(model.predict(X)) - +def get_mlp_vector(model, abducer, mapping, X, rules): + pred_res = model.predict([X])["cls"] + pred_res = [[mapping[symbol] for symbol in formula] for formula in pred_res] vector = [] for rule in rules: - if abducer.kb.consist_rule(exs, rule): + if abducer.kb.consist_rule(pred_res, rule): vector.append(1) else: vector.append(0) return vector -def get_mlp_data(X_true, X_false, rules): + +def get_mlp_data(model, abducer, mapping, X_true, X_false, rules): mlp_vectors = [] mlp_labels = [] for X in X_true: - mlp_vectors.append(get_mlp_vector(X, rules)) + mlp_vectors.append(get_mlp_vector(model, abducer, mapping, X, rules)) mlp_labels.append(1) for X in X_false: - mlp_vectors.append(get_mlp_vector(X, rules)) + mlp_vectors.append(get_mlp_vector(model, abducer, mapping, X, rules)) mlp_labels.append(0) - - return np.array(mlp_vectors), np.array(mlp_labels) + + return np.array(mlp_vectors, dtype=np.float32), np.array(mlp_labels, dtype=np.int64) -def validation(train_X_true, train_X_false, val_X_true, val_X_false): - print("Now checking if we can go to next course") +def validation( + model, + abducer, + mapping, + train_X_true, + train_X_false, + val_X_true, + val_X_false, + recorder, +): + INFO("Now checking if we can go to next course") samples_per_rule = 3 logic_output_dim = 50 - print("Now checking if we can go to next course") - rules = get_rules_from_data(train_X_true, samples_per_rule, 
logic_output_dim) - mlp_train_vectors, mlp_train_labels = get_mlp_data(train_X_true, train_X_false, rules) - - index = np.array(list(range(len(mlp_train_labels)))) - np.random.shuffle(index) - mlp_train_vectors = mlp_train_vectors[index] - mlp_train_labels = mlp_train_labels[index] - + rules = get_rules_from_data( + model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim + ) + + mlp_train_vectors, mlp_train_labels = get_mlp_data( + model, abducer, mapping, train_X_true, train_X_false, rules + ) + + idx = np.array(list(range(len(mlp_train_labels)))) + np.random.shuffle(idx) + mlp_train_vectors = mlp_train_vectors[idx] + mlp_train_labels = mlp_train_labels[idx] + best_accuracy = 0 - - #Try three times to find the best mlp + + # Try three times to find the best mlp for _ in range(3): - print("Training mlp...") - - ### TODO - # mlp_model = get_mlp_net(logic_output_dim) - # mlp_model.compile(loss='binary_crossentropy', - # optimizer='rmsprop', - # metrics=['accuracy']) - # mlp_model.fit(mlp_train_vectors, - # mlp_train_labels, - # epochs=MLP_EPOCHS, - # batch_size=MLP_BATCHSIZE, - # verbose=0) - #Prepare MLP validation data - - mlp_val_vectors, mlp_val_labels = get_mlp_data(val_X_true, val_X_false, rules) - - ## TODO - #Get MLP validation result - # result = mlp_model.evaluate(mlp_val_vectors, - # mlp_val_labels, - # batch_size=MLP_BATCHSIZE, - # verbose=0) - print("MLP validation result:", result) - accuracy = result[1] + INFO("Training mlp...") + mlp = MLP(input_dim=logic_output_dim) + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, betas=(0.9, 0.999)) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + mlp_model = BasicModel( + mlp, + criterion, + optimizer, + device, + batch_size=128, + num_epochs=60, + recorder=recorder, + ) + mlp_train_data = BasicDataset(mlp_train_vectors, mlp_train_labels) + mlp_train_data_loader = torch.utils.data.DataLoader( + mlp_train_data, + 
batch_size=128, + shuffle=True, + ) + loss = mlp_model.fit(mlp_train_data_loader) + INFO("mlp training loss is %f" % loss) + + mlp_val_vectors, mlp_val_labels = get_mlp_data( + model, abducer, mapping, val_X_true, val_X_false, rules + ) + + # Get MLP validation result + mlp_val_data = BasicDataset(mlp_val_vectors, mlp_val_labels) + mlp_val_data_loader = torch.utils.data.DataLoader( + mlp_val_data, + batch_size=64, + shuffle=True, + ) + + accuracy = mlp_model.val(mlp_val_data_loader) if accuracy > best_accuracy: best_accuracy = accuracy return best_accuracy - -def train_HED(model, abducer, train_data, test_data, epochs=50, select_num=10, verbose=-1): - train_X, train_Z, train_Y = train_data - test_X, test_Z, test_Y = test_data +def train_with_rule( + model, + abducer, + train_data, + val_data, + select_num=10, + recorder=None +): + train_X = train_data + val_X = val_data min_len = 5 max_len = 8 - cp_threshold = 0.9 - blmp_threshold = 0.9 - - cnt_threshold = 5 - acc_threshold = 0.86 - # Start training / for each length of equations for equation_len in range(min_len, max_len): - - ### TODO: get_data, e.g. 
- # train_X_true = train_X['True'][equation_len] - # train_X_true.append(train_X['True'][equation_len + 1]) - + INFO("============== equation_len:%d ================" % (equation_len)) + train_X_true = train_X[1][equation_len] + train_X_false = train_X[0][equation_len] + val_X_true = val_X[1][equation_len] + val_X_false = val_X[0][equation_len] + train_X_true.extend(train_X[1][equation_len + 1]) + train_X_false.extend(train_X[0][equation_len + 1]) + val_X_true.extend(val_X[1][equation_len + 1]) + val_X_false.extend(val_X[0][equation_len + 1]) + condition_cnt = 0 while True: # Abduce and train NN - consistent_percentage, batch_label_model_precision = abduce_and_train(model, abducer, train_X_true, select_num) - if consistent_percentage == 0: + consistent_acc, char_acc, mapping = abduce_and_train( + model, abducer, train_X_true, select_num + ) + if consistent_acc == 0: continue # Test if we can use mlp to evaluate - if consistent_percentage >= cp_threshold and batch_label_model_precision >= blmp_threshold: + if consistent_acc >= 0.9 and char_acc >= 0.9: condition_cnt += 1 else: condition_cnt = 0 + # The condition has been satisfied continuously five times - if condition_cnt >= cnt_threshold: - best_accuracy = validation(train_X_true, train_X_false, val_X_true, val_X_false) + if condition_cnt >= 5: + best_accuracy = validation( + model, + abducer, + mapping, + train_X_true, + train_X_false, + val_X_true, + val_X_false, + recorder, + ) + + INFO("best_accuracy is %f" % (best_accuracy)) # decide next course or restart - if best_accuracy > acc_threshold: - # Save model and go to next course - ## TODO: model.save_weights() + if best_accuracy > 0.86: + torch.save( + model.cls_list[0].model.state_dict(), + "./weights/train_weights_%d.pth" % equation_len, + ) break - else: - # Restart current course: reload model if equation_len == min_len: - ## TODO: model.set_weights(pretrain_model.get_weights()) - model.set_weights() + model.cls_list[0].model.load_state_dict( + 
torch.load("./weights/pretrain_weights.pth") + ) else: - ## TODO: model.load_weights() - model.load_weights() - print("Failed! Reload model.") + model.cls_list[0].model.load_state_dict( + torch.load( + "./weights/train_weights_%d.pth" % (equation_len - 1) + ) + ) condition_cnt = 0 - - return model diff --git a/models/basic_model.py b/models/basic_model.py index 11063a8..8c137e2 100644 --- a/models/basic_model.py +++ b/models/basic_model.py @@ -1,16 +1,17 @@ # coding: utf-8 -#================================================================# +# ================================================================# # Copyright (C) 2020 Freecss All rights reserved. -# +# # File Name :basic_model.py # Author :freecss # Email :karlfreecss@gmail.com # Created Date :2020/11/21 # Description : # -#================================================================# +# ================================================================# import sys + sys.path.append("..") import torch @@ -19,6 +20,24 @@ from torch.utils.data import Dataset import os from multiprocessing import Pool + +class BasicDataset(Dataset): + def __init__(self, X, Y): + self.X = X + self.Y = Y + + def __len__(self): + return len(self.X) + + def __getitem__(self, index): + assert index < len(self), "index range error" + + img = self.X[index] + label = self.Y[index] + + return (img, label) + + class XYDataset(Dataset): def __init__(self, X, Y, transform=None): self.X = X @@ -31,8 +50,8 @@ class XYDataset(Dataset): return len(self.X) def __getitem__(self, index): - assert index < len(self), 'index range error' - + assert index < len(self), "index range error" + img = self.X[index] if self.transform is not None: img = self.transform(img) @@ -41,31 +60,35 @@ class XYDataset(Dataset): return (img, label) -class FakeRecorder(): + +class FakeRecorder: def __init__(self): pass def print(self, *x): pass -class BasicModel(): - def __init__(self, - model, - criterion, - optimizer, - device, - batch_size = 1, - num_epochs = 1, 
- stop_loss = 0.01, - num_workers = 0, - save_interval = None, - save_dir = None, - transform = None, - collate_fn = None, - recorder = None): + +class BasicModel: + def __init__( + self, + model, + criterion, + optimizer, + device, + batch_size=1, + num_epochs=1, + stop_loss=0.01, + num_workers=0, + save_interval=None, + save_dir=None, + transform=None, + collate_fn=None, + recorder=None, + ): self.model = model.to(device) - + self.batch_size = batch_size self.num_epochs = num_epochs self.stop_loss = stop_loss @@ -103,9 +126,7 @@ class BasicModel(): recorder.print("Model fitted, minimal loss is ", min_loss) return loss_value - def fit(self, data_loader = None, - X = None, - y = None): + def fit(self, data_loader=None, X=None, y=None): if data_loader is None: data_loader = self._data_loader(X, y) return self._fit(data_loader, self.num_epochs, self.stop_loss) @@ -115,7 +136,7 @@ class BasicModel(): criterion = self.criterion optimizer = self.optimizer device = self.device - + model.train() total_loss, total_num = 0.0, 0 @@ -136,7 +157,7 @@ class BasicModel(): def _predict(self, data_loader): model = self.model device = self.device - + model.eval() with torch.no_grad(): @@ -145,20 +166,20 @@ class BasicModel(): data = data.to(device) out = model(data) results.append(out) - + return torch.cat(results, axis=0) - def predict(self, data_loader = None, X = None, print_prefix = ""): + def predict(self, data_loader=None, X=None, print_prefix=""): recorder = self.recorder - recorder.print('Start Predict Class ', print_prefix) + recorder.print("Start Predict Class ", print_prefix) if data_loader is None: data_loader = self._data_loader(X) return self._predict(data_loader).argmax(axis=1).cpu().numpy() - def predict_proba(self, data_loader = None, X = None, print_prefix = ""): + def predict_proba(self, data_loader=None, X=None, print_prefix=""): recorder = self.recorder - recorder.print('Start Predict Probability ', print_prefix) + # recorder.print('Start Predict Probability ', 
print_prefix) if data_loader is None: data_loader = self._data_loader(X) @@ -168,7 +189,7 @@ class BasicModel(): model = self.model criterion = self.criterion device = self.device - + model.eval() total_correct_num, total_num, total_loss = 0, 0, 0.0 @@ -179,32 +200,37 @@ class BasicModel(): out = model(data) - correct_num = sum(target == out.argmax(axis=1)).item() + if len(out.shape) > 1: + correct_num = sum(target == out.argmax(axis=1)).item() + else: + correct_num = sum(target == (out > 0.5)).item() loss = criterion(out, target) total_loss += loss.item() * data.size(0) total_correct_num += correct_num total_num += data.size(0) - + mean_loss = total_loss / total_num accuracy = total_correct_num / total_num return mean_loss, accuracy - def val(self, data_loader = None, X = None, y = None, print_prefix = ""): + def val(self, data_loader=None, X=None, y=None, print_prefix=""): recorder = self.recorder - recorder.print('Start val ', print_prefix) + recorder.print("Start val ", print_prefix) if data_loader is None: data_loader = self._data_loader(X, y) mean_loss, accuracy = self._val(data_loader) - recorder.print('[%s] Val loss: %f, accuray: %f' % (print_prefix, mean_loss, accuracy)) + recorder.print( + "[%s] Val loss: %f, accuray: %f" % (print_prefix, mean_loss, accuracy) + ) return accuracy - def score(self, data_loader = None, X = None, y = None, print_prefix = ""): + def score(self, data_loader=None, X=None, y=None, print_prefix=""): return self.val(data_loader, X, y, print_prefix) - def _data_loader(self, X, y = None): + def _data_loader(self, X, y=None): collate_fn = self.collate_fn transform = self.transform @@ -212,9 +238,14 @@ class BasicModel(): y = [0] * len(X) dataset = XYDataset(X, y, transform=transform) sampler = None - data_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, \ - shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \ - collate_fn=collate_fn) + data_loader = torch.utils.data.DataLoader( + dataset, + 
batch_size=self.batch_size, + shuffle=False, + sampler=sampler, + num_workers=int(self.num_workers), + collate_fn=collate_fn, + ) return data_loader def save(self, epoch_id, save_dir): @@ -237,7 +268,6 @@ class BasicModel(): load_path = os.path.join(load_dir, str(epoch_id) + "_opt.pth") self.optimizer.load_state_dict(torch.load(load_path)) + if __name__ == "__main__": pass - - diff --git a/models/nn.py b/models/nn.py new file mode 100644 index 0000000..cecbeea --- /dev/null +++ b/models/nn.py @@ -0,0 +1,152 @@ +# coding: utf-8 +# ================================================================# +# Copyright (C) 2021 Freecss All rights reserved. +# +# File Name :lenet5.py +# Author :freecss +# Email :karlfreecss@gmail.com +# Created Date :2021/03/03 +# Description : +# +# ================================================================# + +import sys + +sys.path.append("..") + +import torchvision + +import torch +from torch import nn +from torch.nn import functional as F +from torch.autograd import Variable +import torchvision.transforms as transforms +import numpy as np + +from models.basic_model import BasicModel +import utils.plog as plog + + +class LeNet5(nn.Module): + def __init__(self, num_classes=10, image_size=(28, 28)): + super().__init__() + self.conv1 = nn.Conv2d(1, 6, 3, padding=1) + self.conv2 = nn.Conv2d(6, 16, 3) + self.conv3 = nn.Conv2d(16, 16, 3) + + feature_map_size = (np.array(image_size) // 2 - 2) // 2 - 2 + num_features = 16 * feature_map_size[0] * feature_map_size[1] + + self.fc1 = nn.Linear(num_features, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, num_classes) + + def forward(self, x): + """前向传播函数""" + x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) + x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2)) + x = F.relu(self.conv3(x)) + x = x.view(-1, self.num_flat_features(x)) + # print(x.size()) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + def num_flat_features(self, x): + size = 
x.size()[1:] + num_features = 1 + for s in size: + num_features *= s + return num_features + + +# class SymbolNet(nn.Module): +# def __init__(self, num_classes=4, image_size=(28, 28, 1)): +# super(SymbolNet, self).__init__() +# self.conv1 = nn.Sequential( +# nn.Conv2d(1, 32, 3, stride=1, padding=1), +# nn.ReLU(inplace=True), +# nn.BatchNorm2d(32), +# ) +# self.conv2 = nn.Sequential( +# nn.Conv2d(32, 64, 3, stride=1, padding=1), +# nn.ReLU(inplace=True), +# nn.MaxPool2d(kernel_size=2, stride=2), +# nn.BatchNorm2d(64), +# nn.Dropout(0.25), +# ) +# num_features = 64 * (image_size[0] // 2) * (image_size[1] // 2) +# self.fc1 = nn.Sequential( +# nn.Linear(num_features, 128), nn.ReLU(inplace=True), nn.Dropout(0.5) +# ) +# self.fc2 = nn.Sequential(nn.Linear(128, num_classes), nn.Softmax(dim=1)) + +# def forward(self, x): +# x = self.conv1(x) +# x = self.conv2(x) +# x = torch.flatten(x, 1) +# x = self.fc1(x) +# x = self.fc2(x) +# return x + + +class SymbolNet(nn.Module): + def __init__(self, num_classes=4, image_size=(28, 28, 1)): + super(SymbolNet, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(1, 32, 5, stride=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.BatchNorm2d(32, momentum=0.99, eps=0.001), + ) + self.conv2 = nn.Sequential( + nn.Conv2d(32, 64, 5, padding=2, stride=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.BatchNorm2d(64, momentum=0.99, eps=0.001), + ) + + num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1) + self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) + self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) + self.fc3 = nn.Sequential(nn.Linear(84, num_classes), nn.Softmax(dim=1)) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + return x + + +class SymbolNetAutoencoder(nn.Module): + def __init__(self, num_classes=4, image_size=(28, 28, 1)): + super(SymbolNetAutoencoder, 
self).__init__() + self.base_model = SymbolNet(num_classes, image_size) + self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU()) + self.fc2 = nn.Sequential( + nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU() + ) + + def forward(self, x): + x = self.base_model(x) + x = self.fc1(x) + x = self.fc2(x) + return x + + +class MLP(nn.Module): + def __init__(self, input_dim=50, num_classes=2): + super(MLP, self).__init__() + assert input_dim > 0 + hidden_dim = int(np.sqrt(input_dim)) + self.fc1 = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU()) + self.fc2 = nn.Sequential(nn.Linear(hidden_dim, num_classes), nn.Softmax(dim=1)) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x diff --git a/models/wabl_models.py b/models/wabl_models.py index b854039..9c97bbc 100644 --- a/models/wabl_models.py +++ b/models/wabl_models.py @@ -1,14 +1,14 @@ # coding: utf-8 -#================================================================# +# ================================================================# # Copyright (C) 2020 Freecss All rights reserved. 
-# +# # File Name :models.py # Author :freecss # Email :karlfreecss@gmail.com # Created Date :2020/04/02 # Description : # -#================================================================# +# ================================================================# from itertools import chain from sklearn.tree import DecisionTreeClassifier @@ -29,14 +29,17 @@ import random from sklearn.neighbors import KNeighborsClassifier import numpy as np + def get_part_data(X, i): - return list(map(lambda x : x[i], X)) + return list(map(lambda x: x[i], X)) + def merge_data(X): - ret_mark = list(map(lambda x : len(x), X)) + ret_mark = list(map(lambda x: len(x), X)) ret_X = list(chain(*X)) return ret_X, ret_mark + def reshape_data(Y, marks): begin_mark = 0 ret_Y = [] @@ -48,42 +51,44 @@ def reshape_data(Y, marks): class WABLBasicModel: - def __init__(self, base_model, pseudo_label_list): self.cls_list = [] self.cls_list.append(base_model) self.pseudo_label_list = pseudo_label_list self.mapping = dict(zip(pseudo_label_list, list(range(len(pseudo_label_list))))) - self.remapping = dict(zip(list(range(len(pseudo_label_list))), pseudo_label_list)) + self.remapping = dict( + zip(list(range(len(pseudo_label_list))), pseudo_label_list) + ) def predict(self, X): data_X, marks = merge_data(X) - prob = self.cls_list[0].predict_proba(X = data_X) - _cls = prob.argmax(axis = 1) - cls = list(map(lambda x : self.remapping[x], _cls)) + prob = self.cls_list[0].predict_proba(X=data_X) + _cls = prob.argmax(axis=1) + cls = list(map(lambda x: self.remapping[x], _cls)) prob = reshape_data(prob, marks) cls = reshape_data(cls, marks) - return {"cls" : cls, "prob" : prob} + return {"cls": cls, "prob": prob} def valid(self, X, Y): data_X, _ = merge_data(X) _data_Y, _ = merge_data(Y) - data_Y = list(map(lambda y : self.mapping[y], _data_Y)) - score = self.cls_list[0].score(X = data_X, y = data_Y) + data_Y = list(map(lambda y: self.mapping[y], _data_Y)) + score = self.cls_list[0].score(X=data_X, y=data_Y) return 
score, [score] def train(self, X, Y): - #self.label_lists = [] + # self.label_lists = [] data_X, _ = merge_data(X) _data_Y, _ = merge_data(Y) - data_Y = list(map(lambda y : self.mapping[y], _data_Y)) - self.cls_list[0].fit(X = data_X, y = data_Y) + data_Y = list(map(lambda y: self.mapping[y], _data_Y)) + self.cls_list[0].fit(X=data_X, y=data_Y) + class DecisionTree(WABLBasicModel): - def __init__(self, code_len, label_lists, share = False): + def __init__(self, code_len, label_lists, share=False): self.code_len = code_len self._set_label_lists(label_lists) @@ -91,14 +96,19 @@ class DecisionTree(WABLBasicModel): self.share = share if share: # 本质上是同一个分类器 - self.cls_list.append(DecisionTreeClassifier(random_state = 0, min_samples_leaf = 3)) + self.cls_list.append( + DecisionTreeClassifier(random_state=0, min_samples_leaf=3) + ) self.cls_list = self.cls_list * self.code_len else: for _ in range(code_len): - self.cls_list.append(DecisionTreeClassifier(random_state = 0, min_samples_leaf = 3)) + self.cls_list.append( + DecisionTreeClassifier(random_state=0, min_samples_leaf=3) + ) + class KNN(WABLBasicModel): - def __init__(self, code_len, label_lists, share = False, k = 3): + def __init__(self, code_len, label_lists, share=False, k=3): self.code_len = code_len self._set_label_lists(label_lists) @@ -106,14 +116,15 @@ class KNN(WABLBasicModel): self.share = share if share: # 本质上是同一个分类器 - self.cls_list.append(KNeighborsClassifier(n_neighbors = k)) + self.cls_list.append(KNeighborsClassifier(n_neighbors=k)) self.cls_list = self.cls_list * self.code_len else: for _ in range(code_len): - self.cls_list.append(KNeighborsClassifier(n_neighbors = k)) + self.cls_list.append(KNeighborsClassifier(n_neighbors=k)) + class CNN(WABLBasicModel): - def __init__(self, base_model, code_len, label_lists, share = True): + def __init__(self, base_model, code_len, label_lists, share=True): assert share == True, "Not implemented" label_lists = [sorted(list(set(label_list))) for label_list in 
label_lists] @@ -126,27 +137,28 @@ class CNN(WABLBasicModel): if share: self.cls_list.append(base_model) - def train(self, X, Y, n_epoch = 100): - #self.label_lists = [] + def train(self, X, Y, n_epoch=100): + # self.label_lists = [] if self.share: # 因为是同一个分类器,所以只需要把数据放在一起,然后训练其中任意一个即可 data_X, _ = merge_data(X) data_Y, _ = merge_data(Y) - self.cls_list[0].fit(X = data_X, y = data_Y, n_epoch = n_epoch) - #self.label_lists = [sorted(list(set(data_Y)))] * self.code_len + self.cls_list[0].fit(X=data_X, y=data_Y, n_epoch=n_epoch) + # self.label_lists = [sorted(list(set(data_Y)))] * self.code_len else: for i in range(self.code_len): data_X = get_part_data(X, i) data_Y = get_part_data(Y, i) self.cls_list[i].fit(data_X, data_Y) - #self.label_lists.append(sorted(list(set(data_Y)))) + # self.label_lists.append(sorted(list(set(data_Y)))) + if __name__ == "__main__": - #data_path = "utils/hamming_data/generated_data/hamming_7_3_0.20.pk" + # data_path = "utils/hamming_data/generated_data/hamming_7_3_0.20.pk" data_path = "datasets/generated_data/0_code_7_2_0.00.pk" codes, data, labels = pk.load(open(data_path, "rb")) - - cls = KNN(7, False, k = 3) + + cls = KNN(7, False, k=3) cls.train(data, labels) print(cls.valid(data, labels)) for res in cls.predict_proba(data): @@ -157,4 +169,3 @@ if __name__ == "__main__": print(res) break print("Trained") - diff --git a/utils/plog.py b/utils/plog.py index 0b8f5b5..3f091cb 100644 --- a/utils/plog.py +++ b/utils/plog.py @@ -1,14 +1,14 @@ # coding: utf-8 -#================================================================# +# ================================================================# # Copyright (C) 2020 Freecss All rights reserved. 
-# +# # File Name :plog.py # Author :freecss # Email :karlfreecss@gmail.com # Created Date :2020/10/23 # Description : # -#================================================================# +# ================================================================# import time import logging @@ -19,13 +19,14 @@ import functools global recorder recorder = None + class ResultRecorder: def __init__(self): - logging.basicConfig(level=logging.DEBUG, filemode='a') + logging.basicConfig(level=logging.DEBUG, filemode="a") self.result = {} self.set_savefile() - + logging.info("===========================================================") logging.info("============= Result Recorder Version: 0.03 ===============") logging.info("===========================================================\n") @@ -33,25 +34,25 @@ class ResultRecorder: pass def set_savefile(self): - local_time = time.strftime("%Y%m%d_%H_%M_%S", time.localtime()) + local_time = time.strftime("%Y%m%d_%H_%M_%S", time.localtime()) save_dir = os.path.join("results", local_time) if not os.path.exists(save_dir): os.makedirs(save_dir) save_file_path = os.path.join(save_dir, "result.pk") save_file = open(save_file_path, "wb") - + self.save_dir = save_dir self.save_file = save_file filename = os.path.join(save_dir, "log.txt") file_handler = logging.FileHandler(filename) file_handler.setLevel(logging.DEBUG) - formatter = logging.Formatter('%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') + formatter = logging.Formatter("%(asctime)s - %(levelname)s: %(message)s") file_handler.setFormatter(formatter) logging.getLogger().addHandler(file_handler) - def print(self, *argv, screen = False): + def print(self, *argv, screen=False): info = "" for data in argv: info += str(data) @@ -62,9 +63,8 @@ class ResultRecorder: def print_result(self, *argv): for data in argv: info = "#Result# %s" % str(data) - #print(info) logging.info(info) - + def store(self, *argv): for data in argv: if data.find(":") < 0: @@ -81,11 +81,11 
@@ class ResultRecorder: self.result[label].append(data) def write_kv(self, label, data): - self.print_result({label : data}) - #self.print_result(label + ":" + str(data)) + self.print_result({label: data}) + # self.print_result(label + ":" + str(data)) self.store_kv(label, data) - def dump(self, save_file = None): + def dump(self, save_file=None): if save_file is None: save_file = self.save_file pk.dump(self.result, save_file) @@ -104,38 +104,43 @@ class ResultRecorder: self.write_kv("func:", context) return result + return clocked def __del__(self): self.dump() + def clocker(*argv): global recorder if recorder is None: recorder = ResultRecorder() return recorder.clock(*argv) -def INFO(*argv, screen = False): + +def INFO(*argv, screen=False): global recorder if recorder is None: recorder = ResultRecorder() - return recorder.print(*argv, screen = screen) + return recorder.print(*argv, screen=screen) + -def DEBUG(*argv, screen = False): +def DEBUG(*argv, screen=False): global recorder if recorder is None: recorder = ResultRecorder() - return recorder.print(*argv, screen = screen) + return recorder.print(*argv, screen=screen) + def logger(): global recorder if recorder is None: recorder = ResultRecorder() return recorder - + + if __name__ == "__main__": recorder = ResultRecorder() recorder.write_kv("test", 1) - recorder.set_savefile(pk_dir = "haha") + recorder.set_savefile(pk_dir="haha") recorder.write_kv("test", 1) - diff --git a/utils/utils.py b/utils/utils.py index 44ba6a1..57dea31 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -1,9 +1,17 @@ import numpy as np +from utils.plog import INFO +from collections import OrderedDict + # for multiple predictions, modify from `learn_add.py` def flatten(l): - return [item for sublist in l for item in flatten(sublist)] if isinstance(l, list) else [l] - + return ( + [item for sublist in l for item in flatten(sublist)] + if isinstance(l, list) + else [l] + ) + + # for multiple predictions, modify from `learn_add.py` def 
reform_idx(flatten_pred_res, save_pred_res):
     re = []
@@ -18,10 +26,25 @@ def reform_idx(flatten_pred_res, save_pred_res):
         i = i + j
     return re
 
+
+def block_sample(X, Z, Y, sample_num, epoch_idx):
+    part_num = len(X) // sample_num
+    if part_num == 0:
+        part_num = 1
+    seg_idx = epoch_idx % part_num
+    INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X))
+    X = X[sample_num * seg_idx : sample_num * (seg_idx + 1)]
+    Z = Z[sample_num * seg_idx : sample_num * (seg_idx + 1)]
+    Y = Y[sample_num * seg_idx : sample_num * (seg_idx + 1)]
+
+    return X, Z, Y
+
+
 def hamming_dist(A, B):
     B = np.array(B)
-    A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B)))
-    return np.sum(A != B, axis = 1)
+    A = np.expand_dims(A, axis=0).repeat(axis=0, repeats=(len(B)))
+    return np.sum(A != B, axis=1)
+
 
 def confidence_dist(A, B):
     B = np.array(B)
@@ -29,7 +52,16 @@ def confidence_dist(A, B):
     A = np.expand_dims(A, axis=0)
     A = A.repeat(axis=0, repeats=(len(B)))
     rows = np.array(range(len(B)))
-    rows = np.expand_dims(rows, axis = 1).repeat(axis = 1, repeats = len(B[0]))
+    rows = np.expand_dims(rows, axis=1).repeat(axis=1, repeats=len(B[0]))
     cols = np.array(range(len(B[0])))
-    cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B))
-    return 1 - np.prod(A[rows, cols, B], axis = 1)
+    cols = np.expand_dims(cols, axis=0).repeat(axis=0, repeats=len(B))
+    return 1 - np.prod(A[rows, cols, B], axis=1)
+
+
+def copy_state_dict(state_dict):
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        if k.startswith("base_model"):
+            name = ".".join(k.split(".")[1:])
+            new_state_dict[name] = v
+    return new_state_dict