| @@ -1,3 +1,6 @@ | |||
| *.pyc | |||
| /results | |||
| raw/ | |||
| raw/ | |||
| *.jpg | |||
| *.png | |||
| *.pk | |||
| @@ -26,9 +26,16 @@ import time | |||
| class AbducerBase(abc.ABC): | |||
| def __init__(self, kb, dist_func='confidence', zoopt=False, multiple_predictions=False, cache=True): | |||
| def __init__( | |||
| self, | |||
| kb, | |||
| dist_func="confidence", | |||
| zoopt=False, | |||
| multiple_predictions=False, | |||
| cache=True, | |||
| ): | |||
| self.kb = kb | |||
| assert dist_func == 'hamming' or dist_func == 'confidence' | |||
| assert dist_func == "hamming" or dist_func == "confidence" | |||
| self.dist_func = dist_func | |||
| self.zoopt = zoopt | |||
| self.multiple_predictions = multiple_predictions | |||
| @@ -39,11 +46,18 @@ class AbducerBase(abc.ABC): | |||
| self.cache_candidates = {} | |||
| def _get_cost_list(self, pred_res, pred_res_prob, candidates): | |||
| if self.dist_func == 'hamming': | |||
| if self.dist_func == "hamming": | |||
| return hamming_dist(pred_res, candidates) | |||
| elif self.dist_func == 'confidence': | |||
| mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) | |||
| return confidence_dist(pred_res_prob, [list(map(lambda x: mapping[x], c)) for c in candidates]) | |||
| elif self.dist_func == "confidence": | |||
| mapping = dict( | |||
| zip( | |||
| self.kb.pseudo_label_list, | |||
| list(range(len(self.kb.pseudo_label_list))), | |||
| ) | |||
| ) | |||
| return confidence_dist( | |||
| pred_res_prob, [list(map(lambda x: mapping[x], c)) for c in candidates] | |||
| ) | |||
| def _get_one_candidate(self, pred_res, pred_res_prob, candidates): | |||
| if len(candidates) == 0: | |||
| @@ -60,9 +74,13 @@ class AbducerBase(abc.ABC): | |||
| def _zoopt_score_multiple(self, pred_res, key, solution): | |||
| all_address_flag = reform_idx(solution, pred_res) | |||
| score = 0 | |||
| for idx in enumerate(len(pred_res)): | |||
| address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] | |||
| candidate = self.kb.address_by_idx([pred_res[idx]], key[idx], address_idx, True) | |||
| for idx in range(len(pred_res)): | |||
| address_idx = [ | |||
| i for i, flag in enumerate(all_address_flag[idx]) if flag != 0 | |||
| ] | |||
| candidate = self.kb.address_by_idx( | |||
| [pred_res[idx]], key[idx], address_idx, True | |||
| ) | |||
| if len(candidate) > 0: | |||
| score += 1 | |||
| return score | |||
| @@ -70,7 +88,9 @@ class AbducerBase(abc.ABC): | |||
| def _zoopt_address_score(self, pred_res, key, sol): | |||
| if not self.multiple_predictions: | |||
| address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0] | |||
| candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) | |||
| candidates = self.kb.address_by_idx( | |||
| pred_res, key, address_idx, self.multiple_predictions | |||
| ) | |||
| return 1 if len(candidates) > 0 else 0 | |||
| else: | |||
| return self._zoopt_score_multiple(pred_res, key, sol.get_x()) | |||
| @@ -98,7 +118,11 @@ class AbducerBase(abc.ABC): | |||
| pred_res = flatten(pred_res) | |||
| key = tuple(key) | |||
| if (tuple(pred_res), key) in self.cache_min_address_num: | |||
| address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address) | |||
| address_num = min( | |||
| max_address_num, | |||
| self.cache_min_address_num[(tuple(pred_res), key)] | |||
| + require_more_address, | |||
| ) | |||
| if (tuple(pred_res), key, address_num) in self.cache_candidates: | |||
| candidates = self.cache_candidates[(tuple(pred_res), key, address_num)] | |||
| if self.zoopt: | |||
| @@ -127,12 +151,18 @@ class AbducerBase(abc.ABC): | |||
| if self.zoopt: | |||
| solution = self.zoopt_get_solution(pred_res, key, max_address_num) | |||
| address_idx = [idx for idx, i in enumerate(solution) if i != 0] | |||
| candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) | |||
| candidates = self.kb.address_by_idx( | |||
| pred_res, key, address_idx, self.multiple_predictions | |||
| ) | |||
| address_num = int(solution.sum()) | |||
| min_address_num = address_num | |||
| else: | |||
| candidates, min_address_num, address_num = self.kb.abduce_candidates( | |||
| pred_res, key, max_address_num, require_more_address, self.multiple_predictions | |||
| pred_res, | |||
| key, | |||
| max_address_num, | |||
| require_more_address, | |||
| self.multiple_predictions, | |||
| ) | |||
| candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates) | |||
| @@ -147,20 +177,31 @@ class AbducerBase(abc.ABC): | |||
| def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0): | |||
| if self.multiple_predictions: | |||
| return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address) | |||
| return self.abduce( | |||
| (Z["cls"], Z["prob"], Y), max_address_num, require_more_address | |||
| ) | |||
| else: | |||
| return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)] | |||
| return [ | |||
| self.abduce((z, prob, y), max_address_num, require_more_address) | |||
| for z, prob, y in zip(Z["cls"], Z["prob"], Y) | |||
| ] | |||
| def __call__(self, Z, Y, max_address_num=-1, require_more_address=0): | |||
| return self.batch_abduce(Z, Y, max_address_num, require_more_address) | |||
| if __name__ == '__main__': | |||
| prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||
| prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||
| if __name__ == "__main__": | |||
| prob1 = [ | |||
| [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ] | |||
| prob2 = [ | |||
| [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ] | |||
| kb = add_KB() | |||
| abd = AbducerBase(kb, 'confidence') | |||
| abd = AbducerBase(kb, "confidence") | |||
| res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) | |||
| print(res) | |||
| res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) | |||
| @@ -174,7 +215,7 @@ if __name__ == '__main__': | |||
| print() | |||
| kb = add_prolog_KB() | |||
| abd = AbducerBase(kb, 'confidence') | |||
| abd = AbducerBase(kb, "confidence") | |||
| res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) | |||
| print(res) | |||
| res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) | |||
| @@ -188,7 +229,7 @@ if __name__ == '__main__': | |||
| print() | |||
| kb = add_prolog_KB() | |||
| abd = AbducerBase(kb, 'confidence', zoopt=True) | |||
| abd = AbducerBase(kb, "confidence", zoopt=True) | |||
| res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) | |||
| print(res) | |||
| res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) | |||
| @@ -202,24 +243,38 @@ if __name__ == '__main__': | |||
| print() | |||
| kb = HWF_KB(len_list=[1, 3, 5]) | |||
| abd = AbducerBase(kb, 'hamming') | |||
| res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0) | |||
| abd = AbducerBase(kb, "hamming") | |||
| res = abd.abduce( | |||
| (["5", "+", "2"], None, 3), max_address_num=2, require_more_address=0 | |||
| ) | |||
| print(res) | |||
| res = abd.abduce((['5', '+', '2'], None, 64), max_address_num=3, require_more_address=0) | |||
| res = abd.abduce( | |||
| (["5", "+", "2"], None, 64), max_address_num=3, require_more_address=0 | |||
| ) | |||
| print(res) | |||
| res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0) | |||
| res = abd.abduce( | |||
| (["5", "+", "2"], None, 1.67), max_address_num=3, require_more_address=0 | |||
| ) | |||
| print(res) | |||
| res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3) | |||
| res = abd.abduce( | |||
| (["5", "8", "8", "8", "8"], None, 3.17), | |||
| max_address_num=5, | |||
| require_more_address=3, | |||
| ) | |||
| print(res) | |||
| print() | |||
| kb = HED_prolog_KB() | |||
| abd = AbducerBase(kb, zoopt=True, multiple_predictions=True) | |||
| consist_exs = [[1, '+', 0, '=', 0], [1, '+', 1, '=', 0], [0, '+', 0, '=', 1, 1]] | |||
| consist_exs2 = [[1, '+', 0, '=', 0], [1, '+', 1, '=', 0], [0, '+', 1, '=', 1, 1]] # not consistent with rules | |||
| inconsist_exs = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]] | |||
| consist_exs = [[1, "+", 0, "=", 0], [1, "+", 1, "=", 0], [0, "+", 0, "=", 1, 1]] | |||
| consist_exs2 = [ | |||
| [1, "+", 0, "=", 0], | |||
| [1, "+", 1, "=", 0], | |||
| [0, "+", 1, "=", 1, 1], | |||
| ] # not consistent with rules | |||
| inconsist_exs = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] | |||
| # inconsist_exs = [[1, '+', 0, '=', 0], ['=', '=', '=', '=', 0], ['=', '=', 0, '=', '=', '=']] | |||
| rules = ['my_op([0], [0], [1, 1])', 'my_op([1], [1], [0])', 'my_op([1], [0], [0])'] | |||
| rules = ["my_op([0], [0], [1, 1])", "my_op([1], [1], [0])", "my_op([1], [0], [0])"] | |||
| print(kb.logic_forward(consist_exs), kb.logic_forward(inconsist_exs)) | |||
| print(kb.consist_rule(consist_exs, rules), kb.consist_rule(consist_exs2, rules)) | |||
| @@ -43,16 +43,29 @@ class KBBase(ABC): | |||
| def address(self, address_num, pred_res, key, multiple_predictions=False): | |||
| new_candidates = [] | |||
| if not multiple_predictions: | |||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| address_idx_list = list( | |||
| combinations(list(range(len(pred_res))), address_num) | |||
| ) | |||
| else: | |||
| address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num)) | |||
| address_idx_list = list( | |||
| combinations(list(range(len(flatten(pred_res)))), address_num) | |||
| ) | |||
| for address_idx in address_idx_list: | |||
| candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) | |||
| candidates = self.address_by_idx( | |||
| pred_res, key, address_idx, multiple_predictions | |||
| ) | |||
| new_candidates += candidates | |||
| return new_candidates | |||
| def abduction(self, pred_res, key, max_address_num, require_more_address, multiple_predictions=False): | |||
| def abduction( | |||
| self, | |||
| pred_res, | |||
| key, | |||
| max_address_num, | |||
| require_more_address, | |||
| multiple_predictions=False, | |||
| ): | |||
| candidates = [] | |||
| for address_num in range(len(pred_res) + 1): | |||
| @@ -60,7 +73,9 @@ class KBBase(ABC): | |||
| if abs(self.logic_forward(pred_res) - key) <= 1e-3: | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self.address(address_num, pred_res, key, multiple_predictions) | |||
| new_candidates = self.address( | |||
| address_num, pred_res, key, multiple_predictions | |||
| ) | |||
| candidates += new_candidates | |||
| if len(candidates) > 0: | |||
| @@ -70,10 +85,14 @@ class KBBase(ABC): | |||
| if address_num >= max_address_num: | |||
| return [], 0, 0 | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| for address_num in range( | |||
| min_address_num + 1, min_address_num + require_more_address + 1 | |||
| ): | |||
| if address_num > max_address_num: | |||
| return candidates, min_address_num, address_num - 1 | |||
| new_candidates = self.address(address_num, pred_res, key, multiple_predictions) | |||
| new_candidates = self.address( | |||
| address_num, pred_res, key, multiple_predictions | |||
| ) | |||
| candidates += new_candidates | |||
| return candidates, min_address_num, address_num | |||
| @@ -98,7 +117,9 @@ class ClsKB(KBBase): | |||
| else: | |||
| self.all_address_candidate_dict = {} | |||
| for address_num in range(max(self.len_list) + 1): | |||
| self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat=address_num)) | |||
| self.all_address_candidate_dict[address_num] = list( | |||
| product(self.pseudo_label_list, repeat=address_num) | |||
| ) | |||
| # For parallel version of _get_GKB | |||
| def _get_XY_list(self, args): | |||
| @@ -143,11 +164,26 @@ class ClsKB(KBBase): | |||
| def logic_forward(self): | |||
| pass | |||
| def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): | |||
| def abduce_candidates( | |||
| self, | |||
| pred_res, | |||
| key, | |||
| max_address_num=-1, | |||
| require_more_address=0, | |||
| multiple_predictions=False, | |||
| ): | |||
| if self.GKB_flag: | |||
| return self.abduce_from_GKB(pred_res, key, max_address_num, require_more_address) | |||
| return self.abduce_from_GKB( | |||
| pred_res, key, max_address_num, require_more_address | |||
| ) | |||
| else: | |||
| return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| return self.abduction( | |||
| pred_res, | |||
| key, | |||
| max_address_num, | |||
| require_more_address, | |||
| multiple_predictions, | |||
| ) | |||
| def abduce_from_GKB(self, pred_res, key, max_address_num, require_more_address): | |||
| if self.base == {} or len(pred_res) not in self.len_list: | |||
| @@ -211,7 +247,24 @@ class add_KB(ClsKB): | |||
| class HWF_KB(ClsKB): | |||
| def __init__( | |||
| self, GKB_flag=False, pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], len_list=[1, 3, 5, 7] | |||
| self, | |||
| GKB_flag=False, | |||
| pseudo_label_list=[ | |||
| "1", | |||
| "2", | |||
| "3", | |||
| "4", | |||
| "5", | |||
| "6", | |||
| "7", | |||
| "8", | |||
| "9", | |||
| "+", | |||
| "-", | |||
| "times", | |||
| "div", | |||
| ], | |||
| len_list=[1, 3, 5, 7], | |||
| ): | |||
| super().__init__(GKB_flag, pseudo_label_list, len_list) | |||
| @@ -219,9 +272,19 @@ class HWF_KB(ClsKB): | |||
| if len(formula) % 2 == 0: | |||
| return False | |||
| for i in range(len(formula)): | |||
| if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: | |||
| if i % 2 == 0 and formula[i] not in [ | |||
| "1", | |||
| "2", | |||
| "3", | |||
| "4", | |||
| "5", | |||
| "6", | |||
| "7", | |||
| "8", | |||
| "9", | |||
| ]: | |||
| return False | |||
| if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: | |||
| if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||
| return False | |||
| return True | |||
| @@ -229,22 +292,22 @@ class HWF_KB(ClsKB): | |||
| if not self.valid_candidate(formula): | |||
| return np.inf | |||
| mapping = { | |||
| '1': '1', | |||
| '2': '2', | |||
| '3': '3', | |||
| '4': '4', | |||
| '5': '5', | |||
| '6': '6', | |||
| '7': '7', | |||
| '8': '8', | |||
| '9': '9', | |||
| '+': '+', | |||
| '-': '-', | |||
| 'times': '*', | |||
| 'div': '/', | |||
| "1": "1", | |||
| "2": "2", | |||
| "3": "3", | |||
| "4": "4", | |||
| "5": "5", | |||
| "6": "6", | |||
| "7": "7", | |||
| "8": "8", | |||
| "9": "9", | |||
| "+": "+", | |||
| "-": "-", | |||
| "times": "*", | |||
| "div": "/", | |||
| } | |||
| formula = [mapping[f] for f in formula] | |||
| return round(eval(''.join(formula)), 2) | |||
| return round(eval("".join(formula)), 2) | |||
| class prolog_KB(KBBase): | |||
| @@ -256,8 +319,12 @@ class prolog_KB(KBBase): | |||
| def logic_forward(self): | |||
| pass | |||
| def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||
| return self.abduction(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| def abduce_candidates( | |||
| self, pred_res, key, max_address_num, require_more_address, multiple_predictions | |||
| ): | |||
| return self.abduction( | |||
| pred_res, key, max_address_num, require_more_address, multiple_predictions | |||
| ) | |||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | |||
| candidates = [] | |||
| @@ -265,7 +332,9 @@ class prolog_KB(KBBase): | |||
| if not multiple_predictions: | |||
| query_string = self.get_query_string(pred_res, key, address_idx) | |||
| else: | |||
| query_string = self.get_query_string_need_flatten(pred_res, key, address_idx) | |||
| query_string = self.get_query_string_need_flatten( | |||
| pred_res, key, address_idx | |||
| ) | |||
| if multiple_predictions: | |||
| save_pred_res = pred_res | |||
| @@ -289,24 +358,28 @@ class add_prolog_KB(prolog_KB): | |||
| super().__init__(pseudo_label_list) | |||
| for i in self.pseudo_label_list: | |||
| self.prolog.assertz("pseudo_label(%s)" % i) | |||
| self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2") | |||
| self.prolog.assertz( | |||
| "addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2" | |||
| ) | |||
| def logic_forward(self, nums): | |||
| return list(self.prolog.query("addition(%s, %s, Res)." % (nums[0], nums[1])))[0]['Res'] | |||
| return list(self.prolog.query("addition(%s, %s, Res)." % (nums[0], nums[1])))[ | |||
| 0 | |||
| ]["Res"] | |||
| def get_query_string(self, pred_res, key, address_idx): | |||
| query_string = "addition(" | |||
| for idx, i in enumerate(pred_res): | |||
| tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ',' | |||
| tmp = "Z" + str(idx) + "," if idx in address_idx else str(i) + "," | |||
| query_string += tmp | |||
| query_string += "%s)." % key | |||
| return query_string | |||
| class HED_prolog_KB(prolog_KB): | |||
| def __init__(self, pseudo_label_list=[0, 1, '+', '=']): | |||
| def __init__(self, pseudo_label_list=[0, 1, "+", "="]): | |||
| super().__init__(pseudo_label_list) | |||
| self.prolog.consult('./datasets/hed/learn_add.pl') | |||
| self.prolog.consult("./datasets/hed/learn_add.pl") | |||
| # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` | |||
| def logic_forward(self, exs): | |||
| @@ -318,7 +391,7 @@ class HED_prolog_KB(prolog_KB): | |||
| # add variables for prolog | |||
| for idx in range(len(flatten_pred_res)): | |||
| if idx in address_idx: | |||
| flatten_pred_res[idx] = 'X' + str(idx) | |||
| flatten_pred_res[idx] = "X" + str(idx) | |||
| # unflatten | |||
| new_pred_res = reform_idx(flatten_pred_res, pred_res) | |||
| @@ -326,12 +399,30 @@ class HED_prolog_KB(prolog_KB): | |||
| return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='") | |||
| def consist_rule(self, exs, rules): | |||
| rule_str = "%s" % rules | |||
| rule_str = rule_str.replace("'", "") | |||
| return len(list(self.prolog.query("consistent_inst_feature(%s, %s)." % (exs, rule_str)))) != 0 | |||
| consist = False | |||
| for rule in rules: | |||
| # print(rule) | |||
| if ( | |||
| len( | |||
| list( | |||
| self.prolog.query( | |||
| "consistent_inst_feature(%s, [%s])." % (exs, rule) | |||
| ) | |||
| ) | |||
| ) | |||
| != 0 | |||
| ): | |||
| consist = True | |||
| break | |||
| return consist | |||
| def abduce_rules(self, pred_res): | |||
| prolog_rules = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))[0]['X'] | |||
| prolog_result = list( | |||
| self.prolog.query("consistent_inst_feature(%s, X)." % pred_res) | |||
| ) | |||
| if len(prolog_result) == 0: | |||
| return None | |||
| prolog_rules = prolog_result[0]["X"] | |||
| rules = [] | |||
| for rule in prolog_rules: | |||
| rules.append(rule.value) | |||
| @@ -0,0 +1,83 @@ | |||
| :- use_module(library(apply)). | |||
| :- use_module(library(lists)). | |||
| % :- use_module(library(tabling)). | |||
| % :- table valid_rules/2, op_rule/2. | |||
| %%%%%%%%%%%%%%%%%%%%%%%%%%%% | |||
| %% DCG parser for equations | |||
| %%%%%%%%%%%%%%%%%%%%%%%%%%%% | |||
| %% symbols to be mapped | |||
| digit(1). | |||
| digit(0). | |||
| % digits | |||
| digits([D]) --> [D], { digit(D) }. % empty list [] is not a digit | |||
| digits([D | T]) --> [D], !, digits(T), { digit(D) }. | |||
| digits(X):- | |||
| phrase(digits(X), X). | |||
| % More integrity constraints 1: | |||
| % This two clauses forbid the first digit to be 0. | |||
| % You may uncomment them to prune the search space | |||
| % length(X, L), | |||
| % (L > 1 -> X \= [0 | _]; true). | |||
| % Equation definition | |||
| eq_arg([D]) --> [D], { \+ D == '+', \+ D == '=' }. | |||
| eq_arg([D | T]) --> [D], !, eq_arg(T), { \+ D == '+', \+ D == '=' }. | |||
| equation(eq(X, Y, Z)) --> | |||
| eq_arg(X), [+], eq_arg(Y), [=], eq_arg(Z). | |||
| % More integrity constraints 2: | |||
| % This clause restricts the length of arguments to be sane, | |||
| % You may uncomment them to prune the search space | |||
| % { length(X, LX), length(Y, LY), length(Z, LZ), | |||
| % LZ =< max(LX, LY) + 1, LZ >= max(LX, LY) }. | |||
| parse_eq(List_of_Terms, Eq) :- | |||
| phrase(equation(Eq), List_of_Terms). | |||
| %%%%%%%%%%%%%%%%%%%%%% | |||
| %% Bit-wise operation | |||
| %%%%%%%%%%%%%%%%%%%%%% | |||
| % Abductive calculation with given pseudo-labels, abduces pseudo-labels as well as operation rules | |||
| calc(Rules, Pseudo) :- | |||
| calc([], Rules, Pseudo). | |||
| calc(Rules0, Rules1, Pseudo) :- | |||
| parse_eq(Pseudo, eq(X,Y,Z)), | |||
| bitwise_calc(Rules0, Rules1, X, Y, Z). | |||
| % Bit-wise calculation that handles carrying | |||
| bitwise_calc(Rules, Rules1, X, Y, Z) :- | |||
| reverse(X, X1), reverse(Y, Y1), reverse(Z, Z1), | |||
| bitwise_calc_r(Rules, Rules1, X1, Y1, Z1), | |||
| maplist(digits, [X,Y,Z]). | |||
| bitwise_calc_r(Rs, Rs, [], Y, Y). | |||
| bitwise_calc_r(Rs, Rs, X, [], X). | |||
| bitwise_calc_r(Rules, Rules1, [D1 | X], [D2 | Y], [D3 | Z]) :- | |||
| abduce_op_rule(my_op([D1],[D2],Sum), Rules, Rules2), | |||
| ((Sum = [D3], Carry = []); (Sum = [C, D3], Carry = [C])), | |||
| bitwise_calc_r(Rules2, Rules3, X, Carry, X_carried), | |||
| bitwise_calc_r(Rules3, Rules1, X_carried, Y, Z). | |||
| %%%%%%%%%%%%%%%%%%%%%%%%% | |||
| % Abduce operation rules | |||
| %%%%%%%%%%%%%%%%%%%%%%%%% | |||
| % Get an existed rule | |||
| abduce_op_rule(R, Rules, Rules) :- | |||
| member(R, Rules). | |||
| % Add a new rule | |||
| abduce_op_rule(R, Rules, [R|Rules]) :- | |||
| op_rule(R), | |||
| valid_rules(Rules, R). | |||
| % Integrity Constraints | |||
| valid_rules([], _). | |||
| valid_rules([my_op([X1],[Y1],_)|Rs], my_op([X],[Y],Z)) :- | |||
| op_rule(my_op([X],[Y],Z)), | |||
| [X,Y] \= [X1,Y1], | |||
| [X,Y] \= [Y1,X1], | |||
| valid_rules(Rs, my_op([X],[Y],Z)). | |||
| valid_rules([my_op([Y],[X],Z)|Rs], my_op([X],[Y],Z)) :- | |||
| X \= Y, | |||
| valid_rules(Rs, my_op([X],[Y],Z)). | |||
| op_rule(my_op([X],[Y],[Z])) :- digit(X), digit(Y), digit(Z). | |||
| op_rule(my_op([X],[Y],[Z1,Z2])) :- digit(X), digit(Y), digits([Z1,Z2]). | |||
| @@ -0,0 +1,266 @@ | |||
| import os | |||
| import itertools | |||
| import random | |||
| import numpy as np | |||
| from PIL import Image | |||
| import pickle | |||
| def get_sign_path_list(data_dir, sign_names): | |||
| sign_num = len(sign_names) | |||
| index_dict = dict(zip(sign_names, list(range(sign_num)))) | |||
| ret = [[] for _ in range(sign_num)] | |||
| for path in os.listdir(data_dir): | |||
| if path in sign_names: | |||
| index = index_dict[path] | |||
| sign_path = os.path.join(data_dir, path) | |||
| for p in os.listdir(sign_path): | |||
| ret[index].append(os.path.join(sign_path, p)) | |||
| return ret | |||
| def split_pool_by_rate(pools, rate, seed=None): | |||
| if seed is not None: | |||
| random.seed(seed) | |||
| ret1 = [] | |||
| ret2 = [] | |||
| for pool in pools: | |||
| random.shuffle(pool) | |||
| num = int(len(pool) * rate) | |||
| ret1.append(pool[:num]) | |||
| ret2.append(pool[num:]) | |||
| return ret1, ret2 | |||
| def int_to_system_form(num, system_num): | |||
| if num is 0: | |||
| return "0" | |||
| ret = "" | |||
| while num > 0: | |||
| ret += str(num % system_num) | |||
| num //= system_num | |||
| return ret[::-1] | |||
| def generator_equations( | |||
| left_opt_len, right_opt_len, res_opt_len, system_num, label, generate_type | |||
| ): | |||
| expr_len = left_opt_len + right_opt_len | |||
| num_list = "".join([str(i) for i in range(system_num)]) | |||
| ret = [] | |||
| if generate_type == "all": | |||
| candidates = itertools.product(num_list, repeat=expr_len) | |||
| else: | |||
| candidates = ["".join(random.sample(["0", "1"] * expr_len, expr_len))] | |||
| random.shuffle(candidates) | |||
| for nums in candidates: | |||
| left_num = "".join(nums[:left_opt_len]) | |||
| right_num = "".join(nums[left_opt_len:]) | |||
| left_value = int(left_num, system_num) | |||
| right_value = int(right_num, system_num) | |||
| result_value = left_value + right_value | |||
| if label == "negative": | |||
| result_value += random.randint(-result_value, result_value) | |||
| if left_value + right_value == result_value: | |||
| continue | |||
| result_num = int_to_system_form(result_value, system_num) | |||
| # leading zeros | |||
| if res_opt_len != len(result_num): | |||
| continue | |||
| if (left_opt_len > 1 and left_num[0] == "0") or ( | |||
| right_opt_len > 1 and right_num[0] == "0" | |||
| ): | |||
| continue | |||
| # add leading zeros | |||
| if res_opt_len < len(result_num): | |||
| continue | |||
| while len(result_num) < res_opt_len: | |||
| result_num = "0" + result_num | |||
| # continue | |||
| ret.append( | |||
| left_num + "+" + right_num + "=" + result_num | |||
| ) # current only consider '+' and '=' | |||
| # print(ret[-1]) | |||
| return ret | |||
| def generator_equation_by_len(equation_len, system_num=2, label=0, require_num=1): | |||
| generate_type = "one" | |||
| ret = [] | |||
| equation_sign_num = 2 # '+' and '=' | |||
| while len(ret) < require_num: | |||
| left_opt_len = random.randint(1, equation_len - 1 - equation_sign_num) | |||
| right_opt_len = random.randint( | |||
| 1, equation_len - left_opt_len - equation_sign_num | |||
| ) | |||
| res_opt_len = equation_len - left_opt_len - right_opt_len - equation_sign_num | |||
| ret.extend( | |||
| generator_equations( | |||
| left_opt_len, | |||
| right_opt_len, | |||
| res_opt_len, | |||
| system_num, | |||
| label, | |||
| generate_type, | |||
| ) | |||
| ) | |||
| return ret | |||
| def generator_equations_by_len( | |||
| equation_len, system_num=2, label=0, repeat_times=1, keep=1, generate_type="all" | |||
| ): | |||
| ret = [] | |||
| equation_sign_num = 2 # '+' and '=' | |||
| for left_opt_len in range(1, equation_len - (2 + equation_sign_num) + 1): | |||
| for right_opt_len in range( | |||
| 1, equation_len - left_opt_len - (1 + equation_sign_num) + 1 | |||
| ): | |||
| res_opt_len = ( | |||
| equation_len - left_opt_len - right_opt_len - equation_sign_num | |||
| ) | |||
| for i in range(repeat_times): # generate more equations | |||
| if random.random() > keep ** (equation_len): | |||
| continue | |||
| ret.extend( | |||
| generator_equations( | |||
| left_opt_len, | |||
| right_opt_len, | |||
| res_opt_len, | |||
| system_num, | |||
| label, | |||
| generate_type, | |||
| ) | |||
| ) | |||
| return ret | |||
| def generator_equations_by_max_len( | |||
| max_equation_len, | |||
| system_num=2, | |||
| label=0, | |||
| repeat_times=1, | |||
| keep=1, | |||
| generate_type="all", | |||
| num_per_len=None, | |||
| ): | |||
| ret = [] | |||
| equation_sign_num = 2 # '+' and '=' | |||
| for equation_len in range(3 + equation_sign_num, max_equation_len + 1): | |||
| if num_per_len is None: | |||
| ret.extend( | |||
| generator_equations_by_len( | |||
| equation_len, system_num, label, repeat_times, keep, generate_type | |||
| ) | |||
| ) | |||
| else: | |||
| ret.extend( | |||
| generator_equation_by_len( | |||
| equation_len, system_num, label, require_num=num_per_len | |||
| ) | |||
| ) | |||
| return ret | |||
| def generator_equation_images(image_pools, equations, signs, shape, seed, is_color): | |||
| if seed is not None: | |||
| random.seed(seed) | |||
| ret = [] | |||
| sign_num = len(signs) | |||
| sign_index_dict = dict(zip(signs, list(range(sign_num)))) | |||
| for equation in equations: | |||
| data = [] | |||
| for sign in equation: | |||
| index = sign_index_dict[sign] | |||
| pick = random.randint(0, len(image_pools[index]) - 1) | |||
| if is_color: | |||
| image = ( | |||
| Image.open(image_pools[index][pick]).convert("RGB").resize(shape) | |||
| ) | |||
| else: | |||
| image = Image.open(image_pools[index][pick]).convert("I").resize(shape) | |||
| image_array = np.array(image) | |||
| image_array = (image_array - 127) * (1.0 / 128) | |||
| data.append(image_array) | |||
| ret.append(np.array(data)) | |||
| return ret | |||
| def get_equation_std_data( | |||
| data_dir, | |||
| sign_dir_lists, | |||
| sign_output_lists, | |||
| shape=(28, 28), | |||
| train_max_equation_len=10, | |||
| test_max_equation_len=10, | |||
| system_num=2, | |||
| tmp_file_prev=None, | |||
| seed=None, | |||
| train_num_per_len=10, | |||
| test_num_per_len=10, | |||
| is_color=False, | |||
| ): | |||
| tmp_file = "" | |||
| if tmp_file_prev is not None: | |||
| tmp_file = "%s_train_len_%d_test_len_%d_sys_%d_.pk" % ( | |||
| tmp_file_prev, | |||
| train_max_equation_len, | |||
| test_max_equation_len, | |||
| system_num, | |||
| ) | |||
| if os.path.exists(tmp_file): | |||
| return pickle.load(open(tmp_file, "rb")) | |||
| image_pools = get_sign_path_list(data_dir, sign_dir_lists) | |||
| train_pool, test_pool = split_pool_by_rate(image_pools, 0.8, seed) | |||
| ret = {} | |||
| for label in ["positive", "negative"]: | |||
| print("Generating equations.") | |||
| train_equations = generator_equations_by_max_len( | |||
| train_max_equation_len, system_num, label, num_per_len=train_num_per_len | |||
| ) | |||
| test_equations = generator_equations_by_max_len( | |||
| test_max_equation_len, system_num, label, num_per_len=test_num_per_len | |||
| ) | |||
| print(train_equations) | |||
| print(test_equations) | |||
| print("Generated equations.") | |||
| print("Generating equation image data.") | |||
| ret["train:%s" % (label)] = generator_equation_images( | |||
| train_pool, train_equations, sign_output_lists, shape, seed, is_color | |||
| ) | |||
| ret["test:%s" % (label)] = generator_equation_images( | |||
| test_pool, test_equations, sign_output_lists, shape, seed, is_color | |||
| ) | |||
| print("Generated equation image data.") | |||
| if tmp_file_prev is not None: | |||
| pickle.dump(ret, open(tmp_file, "wb")) | |||
| return ret | |||
| if __name__ == "__main__": | |||
| data_dirs = [ | |||
| "./dataset/mnist_images", | |||
| "./dataset/random_images", | |||
| ] # , "../dataset/cifar10_images"] | |||
| tmp_file_prevs = [ | |||
| "mnist_equation_data", | |||
| "random_equation_data", | |||
| ] # , "cifar10_equation_data"] | |||
| for data_dir, tmp_file_prev in zip(data_dirs, tmp_file_prevs): | |||
| data = get_equation_std_data( | |||
| data_dir=data_dir, | |||
| sign_dir_lists=["0", "1", "10", "11"], | |||
| sign_output_lists=["0", "1", "+", "="], | |||
| shape=(28, 28), | |||
| train_max_equation_len=26, | |||
| test_max_equation_len=26, | |||
| system_num=2, | |||
| tmp_file_prev=tmp_file_prev, | |||
| train_num_per_len=300, | |||
| test_num_per_len=300, | |||
| is_color=False, | |||
| ) | |||
| @@ -0,0 +1,130 @@ | |||
| import os | |||
| import cv2 | |||
| import torch | |||
| import torchvision | |||
| import pickle | |||
| import numpy as np | |||
| import random | |||
| from collections import defaultdict | |||
| from torch.utils.data import Dataset | |||
| from torchvision.transforms import transforms | |||
def get_data(img_dataset, train):
    """Extract equations and binary labels from a pickled HED dataset dict.

    Fix: removed the unused local ``transform`` (a torchvision Compose that was
    built on every call but never applied).

    Parameters
    ----------
    img_dataset : dict with keys "train:positive", "train:negative",
        "test:positive", "test:negative", each a list of stacked equation
        images (one 2-D array per equation; presumably one row block per
        symbol — TODO confirm layout against the generator).
    train : bool
        Select the train (True) or test (False) split.

    Returns
    -------
    (X, None, Y) : X is a list of per-equation lists of float32 image slices
        (via np.vsplit), Y is 1 for positive and 0 for negative equations.
        The middle element is always None (no per-symbol pseudo-labels here).
    """
    split = "train" if train else "test"
    positive = img_dataset["%s:positive" % split]
    negative = img_dataset["%s:negative" % split]
    X = []
    Y = []
    # Positives first, then negatives — preserves the original ordering.
    for label, equations in ((1, positive), (0, negative)):
        for equation in equations:
            equation = equation.astype(np.float32)
            X.append(np.vsplit(equation, equation.shape[0]))
            Y.append(label)
    return X, None, Y
def get_pretrain_data(
    labels,
    image_size=(28, 28, 1),
    data_dir="./datasets/hed/dataset/mnist_images",
):
    """Load symbol images for autoencoder pretraining.

    Generalization: the image root directory, previously hard-coded, is now the
    ``data_dir`` parameter; the default preserves the old behavior.

    Parameters
    ----------
    labels : iterable of sub-directory names (one per symbol class) under data_dir.
    image_size : (H, W, C) target size; images are read grayscale and resized to (H, W).
    data_dir : root directory containing one sub-directory per label.

    Returns
    -------
    (X, Y) : X is a list of ToTensor-transformed, roughly [-1, 1]-normalized
        images; Y is the matching list of flattened float32 reconstruction
        targets (the normalized image itself).
    """
    transform = transforms.Compose([transforms.ToTensor()])
    raw_imgs = []
    for label in labels:
        label_path = os.path.join(data_dir, label)
        for img_name in os.listdir(label_path):
            # NOTE(review): cv2.imread returns None for unreadable files, which
            # would crash cv2.resize — assumes the directories contain only
            # valid images; confirm.
            img = cv2.imread(os.path.join(label_path, img_name), cv2.IMREAD_GRAYSCALE)
            img = cv2.resize(img, (image_size[1], image_size[0]))
            raw_imgs.append(np.array(img, dtype=np.float32))
    # Normalize to roughly [-1, 1] and add a trailing channel axis.
    X = [((img[:, :, np.newaxis] - 127) / 128.0) for img in raw_imgs]
    # Autoencoder target: the flattened normalized image.
    Y = [img.copy().reshape(image_size[0] * image_size[1] * image_size[2]) for img in X]
    X = [transform(img) for img in X]
    return X, Y
| # def get_pretrain_data(train_data, image_size=(28, 28, 1)): | |||
| # X = [] | |||
| # for label in [0, 1]: | |||
| # for _, equation_list in train_data[label].items(): | |||
| # for equation in equation_list: | |||
| # X = X + equation | |||
| # X = np.array(X) | |||
| # index = np.array(list(range(len(X)))) | |||
| # np.random.shuffle(index) | |||
| # X = X[index] | |||
| # X = [img for img in X] | |||
| # Y = [img.copy().reshape(image_size[0] * image_size[1] * image_size[2]) for img in X] | |||
| # return X, Y | |||
def divide_equations_by_len(equations, labels):
    """Group equations first by their 0/1 label, then by equation length.

    Returns {1: {length: [equations]}, 0: {length: [equations]}} where the
    inner mappings are defaultdicts of lists.
    """
    grouped = {1: defaultdict(list), 0: defaultdict(list)}
    for equation, label in zip(equations, labels):
        grouped[label][len(equation)].append(equation)
    return grouped
def split_equation(equations_by_len, prop_train, prop_val):
    """Split the equations of each length into train/validation sets by the
    given proportions.

    Note: shuffles each per-length list in place before slicing.
    """
    train_split = {1: dict(), 0: dict()}
    val_split = {1: dict(), 0: dict()}
    for label in (0, 1):
        for eq_len, eqs in equations_by_len[label].items():
            random.shuffle(eqs)
            cut = len(eqs) // (prop_train + prop_val) * prop_train
            train_split[label][eq_len] = eqs[:cut]
            val_split[label][eq_len] = eqs[cut:]
    return train_split, val_split
def get_hed(dataset="mnist", train=True):
    """Load a pre-generated HED equation pickle and group equations by label
    and length.

    dataset: "mnist" or "random" — selects which pickle file to read.
    train: choose the train (True) or test (False) split.
    Returns the structure produced by divide_equations_by_len:
    {1: {length: [equations]}, 0: {length: [equations]}}.
    Raises Exception for an unknown dataset name.
    """
    if dataset == "mnist":
        with open(
            "./datasets/hed/mnist_equation_data_train_len_26_test_len_26_sys_2_.pk",
            "rb",
        ) as f:
            img_dataset = pickle.load(f)
    elif dataset == "random":
        with open(
            "./datasets/hed/random_equation_data_train_len_26_test_len_26_sys_2_.pk",
            "rb",
        ) as f:
            img_dataset = pickle.load(f)
    else:
        raise Exception("Undefined dataset")
    # Y is the equation-level 0/1 label; the discarded middle value is always None.
    X, _, Y = get_data(img_dataset, train)
    equations_by_len = divide_equations_by_len(X, Y)
    return equations_by_len
# Smoke test: load the default (MNIST) equation dataset.
if __name__ == "__main__":
    get_hed()
| @@ -0,0 +1,81 @@ | |||
%% Abduction predicates for the HED (handwritten equation decipherment) task.
%% Relies on BK.pl for parse_eq/2, bitwise_calc/5, calc/3, op_rule/1 and
%% valid_rules/2 (not defined in this file).
:- ensure_loaded(['BK.pl']).
:- thread_setconcurrency(_, 8).   % allow up to 8 worker threads
:- use_module(library(thread)).
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% For propositionalisation
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
eval_inst_feature(Ex, Feature):-
    eval_eq(Ex, Feature).
%% Evaluate instance given feature
eval_eq(Ex, Feature):-
    parse_eq(Ex, eq(X,Y,Z)),
    bitwise_calc(Feature,_,X,Y,Z), !.
%%%%%%%%%%%%%%
%% Abduction
%%%%%%%%%%%%%%
% Make abduction when given examples that have been interpreted as pseudo-labels
% abduce/2 wraps abduce/3 with an empty initial rule set Delta_C.
abduce(Exs, Delta_C) :-
    abduce(Exs, [], Delta_C).
abduce([], Delta_C, Delta_C).
% Fold each example into the accumulated rule set via calc/3 (from BK.pl).
abduce([E|Exs], Delta_C0, Delta_C1) :-
    calc(Delta_C0, Delta_C2, E),
    abduce(Exs, Delta_C2, Delta_C1).
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Abduce pseudo-labels only
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Succeeds iff some Delta_C makes every example consistent (Delta_C discarded).
abduce_consistent_insts(Exs):-
    abduce(Exs, _), !.
    % (Experimental) Uncomment to use parallel abduction
    % abduce_consistent_exs_concurrent(Exs), !.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Abduce Delta_C given pseudo-labels
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
consistent_inst_feature(Exs, Delta_C):-
    abduce(Exs, Delta_C), !.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% (Experimental) Parallel abduction
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
abduce_consistent_exs_concurrent(Exs) :-
    % Split the current data batch into grounding examples and variable examples (which need to be revised)
    split_exs(Exs, Ground_Exs, Var_Exs),
    % Find the simplest Delta_C for grounding examples.
    abduce(Ground_Exs, Ground_Delta_C), !,
    % Extend Ground Delta_C into all possible variations
    extend_op_rule(Ground_Delta_C, Possible_Deltas),
    % Concurrently abduce the variable examples
    maplist(append([abduce2, Var_Exs, Ground_Exs]), [[Possible_Deltas]], Call_List),
    maplist(=.., Goals, Call_List),
    % writeln(Goals),
    % first_solution/3 races the goals and keeps the first binding of Var_Exs.
    first_solution(Var_Exs, Goals, [local(inf)]).
% split_exs(+Exs, -GroundExs, -VarExs): partition by groundness.
split_exs([],[],[]).
split_exs([E | Exs], [E | G_Exs], V_Exs):-
    ground(E), !,
    split_exs(Exs, G_Exs, V_Exs).
split_exs([E | Exs], G_Exs, [E | V_Exs]):-
    split_exs(Exs, G_Exs, V_Exs).
% Tabled so repeated extensions of the same rule set are not recomputed.
:- table extend_op_rule/2.
% A rule set of exactly 4 operator rules is complete.
extend_op_rule(Rules, Rules) :-
    length(Rules, 4).
extend_op_rule(Rules, Ext) :-
    op_rule(R),
    valid_rules(Rules, R),
    extend_op_rule([R|Rules], Ext).
% abduction without learning new Delta_C (Because they have been extended!)
abduce2([], _, _).
abduce2([E|Exs], Ground_Exs, Delta_C) :-
    % abduce by finding ground examples
    member(E, Ground_Exs),
    abduce2(Exs, Ground_Exs, Delta_C).
abduce2([E|Exs], Ground_Exs, Delta_C) :-
    eval_inst_feature(E, Delta_C),
    abduce2(Exs, Ground_Exs, Delta_C).
| @@ -1,60 +1,102 @@ | |||
| # coding: utf-8 | |||
| #================================================================# | |||
| # ================================================================# | |||
| # Copyright (C) 2021 Freecss All rights reserved. | |||
| # | |||
| # | |||
| # File Name :share_example.py | |||
| # Author :freecss | |||
| # Email :karlfreecss@gmail.com | |||
| # Created Date :2021/06/07 | |||
| # Description : | |||
| # | |||
| #================================================================# | |||
| # ================================================================# | |||
| from utils.plog import logger | |||
| import framework | |||
| from utils.plog import logger, INFO | |||
| import framework_hed | |||
| import torch.nn as nn | |||
| import torch | |||
| from models.lenet5 import LeNet5, SymbolNet | |||
| from models.basic_model import BasicModel | |||
| from models.nn import LeNet5, SymbolNet, SymbolNetAutoencoder | |||
| from models.basic_model import BasicModel, BasicDataset | |||
| from models.wabl_models import WABLBasicModel | |||
| from multiprocessing import Pool | |||
| import os | |||
| from abducer.abducer_base import AbducerBase | |||
| from abducer.kb import add_KB, hwf_KB | |||
| from abducer.kb import add_KB, HWF_KB, HED_prolog_KB | |||
| from datasets.mnist_add.get_mnist_add import get_mnist_add | |||
| from datasets.hwf.get_hwf import get_hwf | |||
| from datasets.hed.get_hed import get_hed, get_pretrain_data, split_equation | |||
def run_test():
    """End-to-end HED training: pretrain a symbol autoencoder if no saved
    weights exist, then train the perception model with rule abduction.

    Fix: removed interleaved merge residue — duplicate ``kb``/``abducer``
    assignments (old hwf_KB path), an old Adam optimizer alongside the new
    RMSprop one, a duplicate single-line BasicModel construction, and a stale
    ``framework.train`` call — which contradicted the HED configuration below.
    """
    # kb = add_KB(True)
    # kb = hwf_KB(True)
    # abducer = AbducerBase(kb)
    kb = HED_prolog_KB()
    abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True)
    recorder = logger()

    # train_X, train_Z, train_Y = get_mnist_add(train=True, get_pseudo_label=True)
    # test_X, test_Z, test_Y = get_mnist_add(train=False, get_pseudo_label=True)
    # train_data = get_hwf(train=True, get_pseudo_label=True)
    # test_data = get_hwf(train=False, get_pseudo_label=True)
    total_train_data = get_hed(train=True)
    train_data, val_data = split_equation(total_train_data, 3, 1)
    # NOTE(review): test_data is currently unused by train_with_rule.
    test_data = get_hed(train=False)

    # cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:]))
    cls_autoencoder = SymbolNetAutoencoder(num_classes=len(kb.pseudo_label_list))
    cls = SymbolNet(num_classes=len(kb.pseudo_label_list))

    # Pretrain the symbol network as an autoencoder once; reuse saved weights after.
    if not os.path.exists("./weights/pretrain_weights.pth"):
        pretrain_data_X, pretrain_data_Y = get_pretrain_data(["0", "1", "10", "11"])
        pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y)
        pretrain_data_loader = torch.utils.data.DataLoader(
            pretrain_data,
            batch_size=64,
            shuffle=True,
        )
        framework_hed.pretrain(cls_autoencoder, pretrain_data_loader, recorder)
        torch.save(
            cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth"
        )
    cls.load_state_dict(torch.load("./weights/pretrain_weights.pth"))

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(
        cls.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6
    )
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    base_model = BasicModel(
        cls,
        criterion,
        optimizer,
        device,
        save_interval=1,
        save_dir=recorder.save_dir,
        batch_size=32,
        num_epochs=10,
        recorder=recorder,
    )

    model = WABLBasicModel(base_model, kb.pseudo_label_list)
    res = framework_hed.train_with_rule(
        model, abducer, train_data, val_data, recorder=recorder
    )
    INFO(res)
    recorder.dump()
    return True
# Entry point: run the full HED training pipeline.
if __name__ == "__main__":
    run_test()
| @@ -1,261 +1,395 @@ | |||
| # coding: utf-8 | |||
| # ================================================================# | |||
| # Copyright (C) 2021 Freecss All rights reserved. | |||
| # | |||
| # File Name :framework.py | |||
| # Author :freecss | |||
| # Email :karlfreecss@gmail.com | |||
| # Created Date :2021/06/07 | |||
| # Description : | |||
| # | |||
| # ================================================================# | |||
| import pickle as pk | |||
| import numpy as np | |||
| from utils.utils import flatten, reform_idx | |||
def get_rules_from_data(equations_true):
    # NOTE(review): stale merge residue — this older definition is shadowed by
    # the later get_rules_from_data and computes values it never uses or returns.
    SAMPLES_PER_RULE = 3
    select_index = np.random.randint(len(equations_true), size=SAMPLES_PER_RULE)
    select_equations = np.array(equations_true)[select_index]
def get_consist_idx(exs, abducer):
    """Return the indices and values of examples whose pseudo-labels are
    already consistent with the knowledge base (abducer.kb.logic_forward)."""
    consistent_ex_idx = []
    label = []
    for idx, ex in enumerate(exs):
        if not abducer.kb.logic_forward([ex]):
            continue
        consistent_ex_idx.append(idx)
        label.append(ex)
    return consistent_ex_idx, label
def get_label(exs, solution, abducer):
    """For each example, revise the positions flagged by `solution` and keep the
    (index, abduced label) pairs for which the KB yields a consistent candidate.

    `solution` is a flat flag vector reshaped per-example via reform_idx.
    """
    flags_per_ex = reform_idx(solution, exs)
    consistent_ex_idx = []
    label = []
    for idx, ex in enumerate(exs):
        revise_positions = [pos for pos, flag in enumerate(flags_per_ex[idx]) if flag != 0]
        candidates = abducer.kb.address_by_idx([ex], 1, revise_positions, True)
        if not candidates:
            continue
        consistent_ex_idx.append(idx)
        label.append(candidates[0][0])
    return consistent_ex_idx, label
def get_percentage_precision(select_X, consistent_ex_idx, equation_label):
    # NOTE(review): `model` is undefined in this scope (see the ## TODO below) —
    # this function appears to be stale merge residue and cannot run as written.
    # The `(a == b).sum()` comparison also assumes flatten() returns numpy
    # arrays, not plain lists — confirm before reviving this code.
    images = []
    for idx in consistent_ex_idx:
        images.append(select_X[idx])
    ## TODO
    model_labels = model.predict(images)
    assert(len(flatten(model_labels)) == len(flatten(equation_label)))
    return (flatten(model_labels) == flatten(equation_label)).sum() / len(flatten(model_labels))
| def abduce_and_train(model, abducer, train_X_true, select_num): | |||
| import random | |||
| random_seed = random.randint(0, 10000) | |||
| print("Selected random seed is : ", random_seed) | |||
| np.random.seed(random_seed) | |||
| random.seed(random_seed) | |||
| from models.nn import MLP | |||
| from models.basic_model import BasicModel, BasicDataset | |||
| import torch.nn as nn | |||
| import torch | |||
| from utils.plog import INFO, DEBUG, clocker | |||
| from utils.utils import flatten, reform_idx, block_sample | |||
| from sklearn.tree import DecisionTreeClassifier | |||
def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag):
    """Compute evaluation statistics for predicted pseudo-label sequences.

    Fix: guards against ZeroDivisionError on empty inputs (empty Y, or empty
    Z when char_acc_flag is set); accuracy is reported as 0.0 in that case.

    Parameters
    ----------
    pred_Z : predicted pseudo-label sequences.
    Z : ground-truth pseudo-label sequences (used only when char_acc_flag).
    Y : ground-truth final results, compared against logic_forward(pred_z).
    logic_forward : callable evaluating a pseudo-label sequence to a result.
    char_acc_flag : also compute symbol-level accuracy when True.

    Returns
    -------
    dict with "ABL accuracy" and, optionally, "Character level accuracy".
    """
    result = {}
    if char_acc_flag:
        char_acc_num = 0
        char_num = 0
        for pred_z, z in zip(pred_Z, Z):
            char_num += len(z)
            for zidx in range(len(z)):
                if pred_z[zidx] == z[zidx]:
                    char_acc_num += 1
        result["Character level accuracy"] = char_acc_num / char_num if char_num else 0.0
    abl_acc_num = sum(1 for pred_z, y in zip(pred_Z, Y) if logic_forward(pred_z) == y)
    result["ABL accuracy"] = abl_acc_num / len(Y) if Y else 0.0
    return result
def filter_data(X, abduced_Z):
    """Keep only the (x, z) pairs whose abduced label list is non-empty.

    Fix: the original tested ``abduced_z is not []``, which compares identity
    against a fresh list and is therefore always True — nothing was ever
    filtered. Truthiness is the intended check.

    Returns (finetune_X, finetune_Z), the filtered inputs and labels.
    """
    finetune_X = []
    finetune_Z = []
    for abduced_x, abduced_z in zip(X, abduced_Z):
        if abduced_z:  # non-empty abduced label sequence
            finetune_X.append(abduced_x)
            finetune_Z.append(abduced_z)
    return finetune_X, finetune_Z
def pretrain(net, pretrain_data_loader, recorder):
    """Pretrain `net` (an autoencoder) with MSE reconstruction loss for 10
    epochs, on GPU when available.

    net: model whose forward output is compared against the flattened target
    images produced by get_pretrain_data.
    pretrain_data_loader: DataLoader yielding (image, flattened-target) pairs.
    recorder: logger providing save_dir and progress recording.
    """
    INFO("Pretrain Start")
    criterion = nn.MSELoss()
    optimizer = torch.optim.RMSprop(
        net.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6
    )
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    pretrain_model = BasicModel(
        net,
        criterion,
        optimizer,
        device,
        save_interval=1,
        save_dir=recorder.save_dir,
        num_epochs=10,
        recorder=recorder,
    )
    pretrain_model.fit(pretrain_data_loader)
def get_char_acc(model, X, consistent_pred_res):
    """Symbol-level accuracy of the model's predictions against the abduced
    pseudo-labels.

    Fix: the original re-flattened ``consistent_pred_res`` on every loop
    iteration (quadratic work); the flattened target is now computed once.
    """
    pred_res = flatten(model.predict(X)["cls"])
    target = flatten(consistent_pred_res)
    assert len(pred_res) == len(target)
    return sum(p == t for p, t in zip(pred_res, target)) / len(pred_res)
def gen_mappings(chars, symbs):
    """Enumerate every bijective mapping from `chars` onto `symbs` as dicts.

    Returns None (implicitly) when the two collections differ in size.
    """
    if len(chars) != len(symbs):
        INFO("Characters and symbols size dosen't match.")
        return
    from itertools import permutations

    return [dict(zip(chars, list(p))) for p in permutations(symbs)]
def map_res(original_pred_res, m):
    """Apply the symbol mapping `m` to every formula in the batch."""
    return [[m[sym] for sym in formula] for formula in original_pred_res]
| select_index = np.random.randint(len(train_X_true), size=select_num) | |||
| select_X = train_X_true[select_index] | |||
| exs = select_X.predict() | |||
| # e.g. when select_num == 10, exs = [[1, '+', 0, '=', 1, 0], [1, '+', 0, '=', 1, 0], [1, '+', 0, '=', 1, 0], [0, '+', 0, '=', 0], [1, '+', 0, '=', 1, 0],\ | |||
| # [1, '+', 0, '=', 1, 0], [1, '+', 0, '=', 1, 0], [1, '+', 0, '=', 1, 0], [0, '+', 0, '=', 0], [1, '+', 0, '=', 1, 0]] | |||
| print("This is the model's current label:", exs) | |||
| # 1. Check if it can abduce rules without changing any labels | |||
| consistent_ex_idx, equation_label = get_consist_idx(exs) | |||
| max_abduce_num = 10 | |||
| if len(consistent_ex_idx) == 0: | |||
| # 2. Find the possible wrong position in symbols and Abduce the right symbol through logic module | |||
| solution = abducer.zoopt_get_solution(exs, [1] * len(exs), max_abduce_num) | |||
| consistent_ex_idx, equation_label = get_label(exs, solution, abducer) | |||
| # Still cannot find | |||
| if len(consistent_ex_idx) == 0: | |||
| return 0, 0 | |||
| ## TODO: train | |||
| # train_pool_X = np.concatenate(select_X[consistent_ex_idx]).reshape( | |||
| # -1, h, w, d) | |||
| # train_pool_Y = np_utils.to_categorical( | |||
| # flatten(exs[consistent_ex_idx]), | |||
| # num_classes=len(labels)) # Convert the symbol to network output | |||
| # assert (len(train_pool_X) == len(train_pool_Y)) | |||
| # print("\nTrain pool size is :", len(train_pool_X)) | |||
| # print("Training...") | |||
| # base_model.fit(train_pool_X, | |||
| # train_pool_Y, | |||
| # batch_size=BATCHSIZE, | |||
| # epochs=NN_EPOCHS, | |||
| # verbose=0) | |||
| # consistent_percentage, batch_label_model_precision = get_percentage_precision( | |||
| # base_model, select_equations, consist_re, shape) | |||
| consistent_percentage = len(consistent_ex_idx) / select_num | |||
| batch_label_model_precision = get_percentage_precision(exs, consistent_ex_idx, equation_label) | |||
| return consistent_percentage, batch_label_model_precision | |||
def get_rules(exs):
    # NOTE(review): stale merge residue — `abducer` is not defined in this
    # scope, and get_consist_idx now requires an abducer argument, so this
    # function cannot run as written. The live logic lives in
    # get_rules_from_data below.
    consistent_ex_idx, equation_label = get_consist_idx(exs)
    consist_exs = []
    for idx in consistent_ex_idx:
        consist_exs.append(exs[idx])
    if len(consist_exs) == 0:
        return None
    else:
        return abducer.abduce_rule(consist_exs)
| def get_rules_from_data(train_X_true, samples_per_rule, logic_output_dim): | |||
def abduce_and_train(model, abducer, train_X_true, select_num):
    """Sample `select_num` positive equations, try every char->symbol mapping,
    abduce consistent pseudo-labels under the best mapping, and train the model
    on them.

    Returns (consistent_acc, char_acc, mapping); (0, 0, None) when no mapping
    yields any consistent example.
    """
    select_idx = np.random.randint(len(train_X_true), size=select_num)
    X = []
    for idx in select_idx:
        X.append(train_X_true[idx])
    pred_res = model.predict(X)["cls"]
    # All bijections between the network's raw classes and the KB's symbols.
    maps = gen_mappings(["+", "=", 0, 1], ["+", "=", 0, 1])
    consistent_idx = []
    consistent_pred_res = []
    import copy
    original_pred_res = copy.deepcopy(pred_res)
    mapping = None
    for m in maps:
        pred_res = map_res(original_pred_res, m)
        # Inverse mapping, used to express abduced candidates back in the
        # network's own label space for training.
        remapping = {}
        for key, value in m.items():
            remapping[value] = key
        max_abduce_num = 10
        # Zoopt search for which positions to revise (flags over all symbols).
        solution = abducer.zoopt_get_solution(
            pred_res, [1] * len(pred_res), max_abduce_num
        )
        all_address_flag = reform_idx(solution, pred_res)
        consistent_idx_tmp = []
        consistent_pred_res_tmp = []
        for idx in range(len(pred_res)):
            address_idx = [
                i for i, flag in enumerate(all_address_flag[idx]) if flag != 0
            ]
            candidate = abducer.kb.address_by_idx([pred_res[idx]], 1, address_idx, True)
            if len(candidate) > 0:
                consistent_idx_tmp.append(idx)
                consistent_pred_res_tmp.append(
                    [remapping[symbol] for symbol in candidate[0][0]]
                )
        # Keep the mapping that makes the most examples consistent.
        if len(consistent_idx_tmp) > len(consistent_idx):
            consistent_idx = consistent_idx_tmp
            consistent_pred_res = consistent_pred_res_tmp
            mapping = m
    if len(consistent_idx) == 0:
        return 0, 0, None
    INFO("Consistent predict results are:", map_res(consistent_pred_res, mapping))
    INFO("Train pool size is:", len(flatten(consistent_pred_res)))
    INFO("Start to use abduced pseudo label to train model...")
    model.train([X[idx] for idx in consistent_idx], consistent_pred_res)
    consistent_acc = len(consistent_idx) / select_num
    char_acc = get_char_acc(
        model, [X[idx] for idx in consistent_idx], consistent_pred_res
    )
    INFO("consistent_acc is %s, char_acc is %s" % (consistent_acc, char_acc))
    return consistent_acc, char_acc, mapping
def get_rules_from_data(
    model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim
):
    """Learn `logic_output_dim` rules, each abduced from `samples_per_rule`
    randomly sampled positive equations whose (mapped) predictions are KB-consistent.

    NOTE(review): the inner `while True` retries until a rule is found — it can
    loop indefinitely if the model's predictions are rarely consistent.
    """
    rules = []
    for _ in range(logic_output_dim):
        while True:
            select_idx = np.random.randint(len(train_X_true), size=samples_per_rule)
            X = []
            for idx in select_idx:
                X.append(train_X_true[idx])
            pred_res = model.predict(X)["cls"]
            # Translate raw class ids into KB symbols via the learned mapping.
            pred_res = [[mapping[symbol] for symbol in formula] for formula in pred_res]
            consistent_idx = []
            consistent_pred_res = []
            for idx in range(len(pred_res)):
                if abducer.kb.logic_forward([pred_res[idx]]):
                    consistent_idx.append(idx)
                    consistent_pred_res.append(pred_res[idx])
            if len(consistent_pred_res) != 0:
                rule = abducer.abduce_rules(consistent_pred_res)
                if rule != None:
                    break
        rules.append(rule)
    INFO('Learned rules from data:')
    for rule in rules:
        INFO(rule)
    return rules
| def get_mlp_vector(X, rules): | |||
| ## TODO | |||
| exs = np.argmax(model.predict(X)) | |||
def get_mlp_vector(model, abducer, mapping, X, rules):
    """Encode one example as a 0/1 feature vector: bit i records whether the
    model's (mapped) prediction for X is consistent with rules[i].

    Fix: removed a stale duplicated ``if`` line left by a bad merge, which
    referenced the undefined name ``exs`` and made the function unrunnable.
    """
    pred_res = model.predict([X])["cls"]
    pred_res = [[mapping[symbol] for symbol in formula] for formula in pred_res]
    vector = []
    for rule in rules:
        if abducer.kb.consist_rule(pred_res, rule):
            vector.append(1)
        else:
            vector.append(0)
    return vector
def get_mlp_data(model, abducer, mapping, X_true, X_false, rules):
    """Build the MLP training set: one rule-consistency feature vector per
    equation, labeled 1 for positive equations and 0 for negative ones.

    Fix: removed interleaved merge residue (an older def line, duplicate
    append statements, and a duplicate return) that made the block invalid.

    Returns (float32 feature matrix, int64 label vector) as numpy arrays.
    """
    mlp_vectors = []
    mlp_labels = []
    for X in X_true:
        mlp_vectors.append(get_mlp_vector(model, abducer, mapping, X, rules))
        mlp_labels.append(1)
    for X in X_false:
        mlp_vectors.append(get_mlp_vector(model, abducer, mapping, X, rules))
        mlp_labels.append(0)
    return np.array(mlp_vectors, dtype=np.float32), np.array(mlp_labels, dtype=np.int64)
def validation(
    model,
    abducer,
    mapping,
    train_X_true,
    train_X_false,
    val_X_true,
    val_X_false,
    recorder,
):
    """Learn rules from the training equations, train an MLP on the resulting
    rule-consistency vectors, and return its best validation accuracy (best of
    three training attempts).

    Fix: removed interleaved merge residue — an older def line, a duplicated
    print, stale commented TODO code, and a reference to the undefined name
    ``result`` — that made the block invalid.
    """
    INFO("Now checking if we can go to next course")
    samples_per_rule = 3
    logic_output_dim = 50
    rules = get_rules_from_data(
        model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim
    )
    mlp_train_vectors, mlp_train_labels = get_mlp_data(
        model, abducer, mapping, train_X_true, train_X_false, rules
    )
    # Shuffle features and labels together.
    idx = np.array(list(range(len(mlp_train_labels))))
    np.random.shuffle(idx)
    mlp_train_vectors = mlp_train_vectors[idx]
    mlp_train_labels = mlp_train_labels[idx]
    best_accuracy = 0
    # Try three times to find the best mlp
    for _ in range(3):
        INFO("Training mlp...")
        mlp = MLP(input_dim=logic_output_dim)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, betas=(0.9, 0.999))
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        mlp_model = BasicModel(
            mlp,
            criterion,
            optimizer,
            device,
            batch_size=128,
            num_epochs=60,
            recorder=recorder,
        )
        mlp_train_data = BasicDataset(mlp_train_vectors, mlp_train_labels)
        mlp_train_data_loader = torch.utils.data.DataLoader(
            mlp_train_data,
            batch_size=128,
            shuffle=True,
        )
        loss = mlp_model.fit(mlp_train_data_loader)
        INFO("mlp training loss is %f" % loss)
        mlp_val_vectors, mlp_val_labels = get_mlp_data(
            model, abducer, mapping, val_X_true, val_X_false, rules
        )
        # Get MLP validation result
        mlp_val_data = BasicDataset(mlp_val_vectors, mlp_val_labels)
        mlp_val_data_loader = torch.utils.data.DataLoader(
            mlp_val_data,
            batch_size=64,
            shuffle=True,
        )
        accuracy = mlp_model.val(mlp_val_data_loader)
        if accuracy > best_accuracy:
            best_accuracy = accuracy
    return best_accuracy
def train_with_rule(
    model,
    abducer,
    train_data,
    val_data,
    select_num=10,
    recorder=None,
):
    """Curriculum training on equations of increasing length: repeatedly abduce
    pseudo-labels and train; once abduction is stably accurate, validate with a
    rule-based MLP and either advance to the next length (saving weights) or
    reload earlier weights and restart the current course.

    Fix: removed interleaved merge residue — the old ``train_HED`` def line and
    data unpacking, stale threshold constants (cp/blmp/cnt/acc_threshold)
    unused by the new literals, duplicate conditionals, ``model.set_weights``/
    ``load_weights`` TODO stubs, and a stray print — that made the block invalid.
    """
    train_X = train_data
    val_X = val_data
    min_len = 5
    max_len = 8
    # Start training / for each length of equations
    for equation_len in range(min_len, max_len):
        INFO("============== equation_len:%d ================" % (equation_len))
        train_X_true = train_X[1][equation_len]
        train_X_false = train_X[0][equation_len]
        val_X_true = val_X[1][equation_len]
        val_X_false = val_X[0][equation_len]
        # Also include equations one symbol longer in the current course.
        train_X_true.extend(train_X[1][equation_len + 1])
        train_X_false.extend(train_X[0][equation_len + 1])
        val_X_true.extend(val_X[1][equation_len + 1])
        val_X_false.extend(val_X[0][equation_len + 1])
        condition_cnt = 0
        while True:
            # Abduce and train NN
            consistent_acc, char_acc, mapping = abduce_and_train(
                model, abducer, train_X_true, select_num
            )
            # NOTE(review): if abduction keeps failing this loops forever —
            # confirm an upstream guarantee or add a retry cap.
            if consistent_acc == 0:
                continue
            # Test if we can use mlp to evaluate
            if consistent_acc >= 0.9 and char_acc >= 0.9:
                condition_cnt += 1
            else:
                condition_cnt = 0
            # The condition has been satisfied continuously five times
            if condition_cnt >= 5:
                best_accuracy = validation(
                    model,
                    abducer,
                    mapping,
                    train_X_true,
                    train_X_false,
                    val_X_true,
                    val_X_false,
                    recorder,
                )
                INFO("best_accuracy is %f" % (best_accuracy))
                # decide next course or restart
                if best_accuracy > 0.86:
                    # Save model and go to next course
                    torch.save(
                        model.cls_list[0].model.state_dict(),
                        "./weights/train_weights_%d.pth" % equation_len,
                    )
                    break
                else:
                    # Restart current course: reload model
                    if equation_len == min_len:
                        model.cls_list[0].model.load_state_dict(
                            torch.load("./weights/pretrain_weights.pth")
                        )
                    else:
                        model.cls_list[0].model.load_state_dict(
                            torch.load(
                                "./weights/train_weights_%d.pth" % (equation_len - 1)
                            )
                        )
                    condition_cnt = 0
    return model
| @@ -1,16 +1,17 @@ | |||
| # coding: utf-8 | |||
| #================================================================# | |||
| # ================================================================# | |||
| # Copyright (C) 2020 Freecss All rights reserved. | |||
| # | |||
| # | |||
| # File Name :basic_model.py | |||
| # Author :freecss | |||
| # Email :karlfreecss@gmail.com | |||
| # Created Date :2020/11/21 | |||
| # Description : | |||
| # | |||
| #================================================================# | |||
| # ================================================================# | |||
| import sys | |||
| sys.path.append("..") | |||
| import torch | |||
| @@ -19,6 +20,24 @@ from torch.utils.data import Dataset | |||
| import os | |||
| from multiprocessing import Pool | |||
class BasicDataset(Dataset):
    """Minimal dataset wrapping parallel sequences of inputs and targets."""

    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        assert index < len(self), "index range error"
        return (self.X[index], self.Y[index])
| class XYDataset(Dataset): | |||
| def __init__(self, X, Y, transform=None): | |||
| self.X = X | |||
| @@ -31,8 +50,8 @@ class XYDataset(Dataset): | |||
| return len(self.X) | |||
| def __getitem__(self, index): | |||
| assert index < len(self), 'index range error' | |||
| assert index < len(self), "index range error" | |||
| img = self.X[index] | |||
| if self.transform is not None: | |||
| img = self.transform(img) | |||
| @@ -41,31 +60,35 @@ class XYDataset(Dataset): | |||
| return (img, label) | |||
class FakeRecorder:
    """No-op recorder used when no real recorder is supplied.

    Fix: removed a duplicated class-header line (``class FakeRecorder():``
    immediately followed by ``class FakeRecorder:``) left by a bad merge.
    """

    def __init__(self):
        pass

    def print(self, *x):
        # Intentionally discards everything.
        pass
| class BasicModel(): | |||
| def __init__(self, | |||
| model, | |||
| criterion, | |||
| optimizer, | |||
| device, | |||
| batch_size = 1, | |||
| num_epochs = 1, | |||
| stop_loss = 0.01, | |||
| num_workers = 0, | |||
| save_interval = None, | |||
| save_dir = None, | |||
| transform = None, | |||
| collate_fn = None, | |||
| recorder = None): | |||
| class BasicModel: | |||
| def __init__( | |||
| self, | |||
| model, | |||
| criterion, | |||
| optimizer, | |||
| device, | |||
| batch_size=1, | |||
| num_epochs=1, | |||
| stop_loss=0.01, | |||
| num_workers=0, | |||
| save_interval=None, | |||
| save_dir=None, | |||
| transform=None, | |||
| collate_fn=None, | |||
| recorder=None, | |||
| ): | |||
| self.model = model.to(device) | |||
| self.batch_size = batch_size | |||
| self.num_epochs = num_epochs | |||
| self.stop_loss = stop_loss | |||
| @@ -103,9 +126,7 @@ class BasicModel(): | |||
| recorder.print("Model fitted, minimal loss is ", min_loss) | |||
| return loss_value | |||
| def fit(self, data_loader = None, | |||
| X = None, | |||
| y = None): | |||
| def fit(self, data_loader=None, X=None, y=None): | |||
| if data_loader is None: | |||
| data_loader = self._data_loader(X, y) | |||
| return self._fit(data_loader, self.num_epochs, self.stop_loss) | |||
| @@ -115,7 +136,7 @@ class BasicModel(): | |||
| criterion = self.criterion | |||
| optimizer = self.optimizer | |||
| device = self.device | |||
| model.train() | |||
| total_loss, total_num = 0.0, 0 | |||
| @@ -136,7 +157,7 @@ class BasicModel(): | |||
| def _predict(self, data_loader): | |||
| model = self.model | |||
| device = self.device | |||
| model.eval() | |||
| with torch.no_grad(): | |||
| @@ -145,20 +166,20 @@ class BasicModel(): | |||
| data = data.to(device) | |||
| out = model(data) | |||
| results.append(out) | |||
| return torch.cat(results, axis=0) | |||
| def predict(self, data_loader = None, X = None, print_prefix = ""): | |||
| def predict(self, data_loader=None, X=None, print_prefix=""): | |||
| recorder = self.recorder | |||
| recorder.print('Start Predict Class ', print_prefix) | |||
| recorder.print("Start Predict Class ", print_prefix) | |||
| if data_loader is None: | |||
| data_loader = self._data_loader(X) | |||
| return self._predict(data_loader).argmax(axis=1).cpu().numpy() | |||
| def predict_proba(self, data_loader = None, X = None, print_prefix = ""): | |||
| def predict_proba(self, data_loader=None, X=None, print_prefix=""): | |||
| recorder = self.recorder | |||
| recorder.print('Start Predict Probability ', print_prefix) | |||
| # recorder.print('Start Predict Probability ', print_prefix) | |||
| if data_loader is None: | |||
| data_loader = self._data_loader(X) | |||
| @@ -168,7 +189,7 @@ class BasicModel(): | |||
| model = self.model | |||
| criterion = self.criterion | |||
| device = self.device | |||
| model.eval() | |||
| total_correct_num, total_num, total_loss = 0, 0, 0.0 | |||
| @@ -179,32 +200,37 @@ class BasicModel(): | |||
| out = model(data) | |||
| correct_num = sum(target == out.argmax(axis=1)).item() | |||
| if len(out.shape) > 1: | |||
| correct_num = sum(target == out.argmax(axis=1)).item() | |||
| else: | |||
| correct_num = sum(target == (out > 0.5)).item() | |||
| loss = criterion(out, target) | |||
| total_loss += loss.item() * data.size(0) | |||
| total_correct_num += correct_num | |||
| total_num += data.size(0) | |||
| mean_loss = total_loss / total_num | |||
| accuracy = total_correct_num / total_num | |||
| return mean_loss, accuracy | |||
| def val(self, data_loader = None, X = None, y = None, print_prefix = ""): | |||
| def val(self, data_loader=None, X=None, y=None, print_prefix=""): | |||
| recorder = self.recorder | |||
| recorder.print('Start val ', print_prefix) | |||
| recorder.print("Start val ", print_prefix) | |||
| if data_loader is None: | |||
| data_loader = self._data_loader(X, y) | |||
| mean_loss, accuracy = self._val(data_loader) | |||
| recorder.print('[%s] Val loss: %f, accuray: %f' % (print_prefix, mean_loss, accuracy)) | |||
| recorder.print( | |||
| "[%s] Val loss: %f, accuray: %f" % (print_prefix, mean_loss, accuracy) | |||
| ) | |||
| return accuracy | |||
| def score(self, data_loader = None, X = None, y = None, print_prefix = ""): | |||
| def score(self, data_loader=None, X=None, y=None, print_prefix=""): | |||
| return self.val(data_loader, X, y, print_prefix) | |||
| def _data_loader(self, X, y = None): | |||
| def _data_loader(self, X, y=None): | |||
| collate_fn = self.collate_fn | |||
| transform = self.transform | |||
| @@ -212,9 +238,14 @@ class BasicModel(): | |||
| y = [0] * len(X) | |||
| dataset = XYDataset(X, y, transform=transform) | |||
| sampler = None | |||
| data_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, \ | |||
| shuffle=False, sampler=sampler, num_workers=int(self.num_workers), \ | |||
| collate_fn=collate_fn) | |||
| data_loader = torch.utils.data.DataLoader( | |||
| dataset, | |||
| batch_size=self.batch_size, | |||
| shuffle=False, | |||
| sampler=sampler, | |||
| num_workers=int(self.num_workers), | |||
| collate_fn=collate_fn, | |||
| ) | |||
| return data_loader | |||
| def save(self, epoch_id, save_dir): | |||
| @@ -237,7 +268,6 @@ class BasicModel(): | |||
| load_path = os.path.join(load_dir, str(epoch_id) + "_opt.pth") | |||
| self.optimizer.load_state_dict(torch.load(load_path)) | |||
| if __name__ == "__main__": | |||
| pass | |||
| @@ -0,0 +1,152 @@ | |||
| # coding: utf-8 | |||
| # ================================================================# | |||
| # Copyright (C) 2021 Freecss All rights reserved. | |||
| # | |||
| # File Name :lenet5.py | |||
| # Author :freecss | |||
| # Email :karlfreecss@gmail.com | |||
| # Created Date :2021/03/03 | |||
| # Description : | |||
| # | |||
| # ================================================================# | |||
| import sys | |||
| sys.path.append("..") | |||
| import torchvision | |||
| import torch | |||
| from torch import nn | |||
| from torch.nn import functional as F | |||
| from torch.autograd import Variable | |||
| import torchvision.transforms as transforms | |||
| import numpy as np | |||
| from models.basic_model import BasicModel | |||
| import utils.plog as plog | |||
class LeNet5(nn.Module):
    """LeNet-style CNN: three conv layers followed by three FC layers.

    Returns raw logits (no softmax) of shape (batch, num_classes).
    """

    def __init__(self, num_classes=10, image_size=(28, 28)):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 3, padding=1)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.conv3 = nn.Conv2d(16, 16, 3)
        # Spatial size after conv1(+pool /2), conv2 (-2, pool /2), conv3 (-2).
        fmap = (np.array(image_size) // 2 - 2) // 2 - 2
        flat_dim = 16 * fmap[0] * fmap[1]
        self.fc1 = nn.Linear(flat_dim, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Forward pass: conv stack -> flatten -> FC stack."""
        out = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        out = F.max_pool2d(F.relu(self.conv2(out)), (2, 2))
        out = F.relu(self.conv3(out))
        out = out.view(-1, self.num_flat_features(out))
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        return self.fc3(out)

    def num_flat_features(self, x):
        """Feature count per sample when flattening x (batch dim excluded)."""
        count = 1
        for dim in x.size()[1:]:
            count *= dim
        return count
| # class SymbolNet(nn.Module): | |||
| # def __init__(self, num_classes=4, image_size=(28, 28, 1)): | |||
| # super(SymbolNet, self).__init__() | |||
| # self.conv1 = nn.Sequential( | |||
| # nn.Conv2d(1, 32, 3, stride=1, padding=1), | |||
| # nn.ReLU(inplace=True), | |||
| # nn.BatchNorm2d(32), | |||
| # ) | |||
| # self.conv2 = nn.Sequential( | |||
| # nn.Conv2d(32, 64, 3, stride=1, padding=1), | |||
| # nn.ReLU(inplace=True), | |||
| # nn.MaxPool2d(kernel_size=2, stride=2), | |||
| # nn.BatchNorm2d(64), | |||
| # nn.Dropout(0.25), | |||
| # ) | |||
| # num_features = 64 * (image_size[0] // 2) * (image_size[1] // 2) | |||
| # self.fc1 = nn.Sequential( | |||
| # nn.Linear(num_features, 128), nn.ReLU(inplace=True), nn.Dropout(0.5) | |||
| # ) | |||
| # self.fc2 = nn.Sequential(nn.Linear(128, num_classes), nn.Softmax(dim=1)) | |||
| # def forward(self, x): | |||
| # x = self.conv1(x) | |||
| # x = self.conv2(x) | |||
| # x = torch.flatten(x, 1) | |||
| # x = self.fc1(x) | |||
| # x = self.fc2(x) | |||
| # return x | |||
class SymbolNet(nn.Module):
    """CNN classifier for single-channel symbol images.

    Output is a probability distribution over ``num_classes`` (Softmax head).
    """

    def __init__(self, num_classes=4, image_size=(28, 28, 1)):
        super(SymbolNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(32, momentum=0.99, eps=0.001),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(64, momentum=0.99, eps=0.001),
        )
        # Flattened size after the two conv+pool stages.
        flat_dim = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1)
        self.fc1 = nn.Sequential(nn.Linear(flat_dim, 120), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU())
        self.fc3 = nn.Sequential(nn.Linear(84, num_classes), nn.Softmax(dim=1))

    def forward(self, x):
        out = x
        for stage in (self.conv1, self.conv2):
            out = stage(out)
        out = torch.flatten(out, 1)
        for head in (self.fc1, self.fc2, self.fc3):
            out = head(out)
        return out
class SymbolNetAutoencoder(nn.Module):
    """SymbolNet followed by two FC layers producing a flat image-sized vector."""

    def __init__(self, num_classes=4, image_size=(28, 28, 1)):
        super(SymbolNetAutoencoder, self).__init__()
        self.base_model = SymbolNet(num_classes, image_size)
        self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU())
        self.fc2 = nn.Sequential(
            nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()
        )

    def forward(self, x):
        out = self.base_model(x)
        out = self.fc1(out)
        return self.fc2(out)
class MLP(nn.Module):
    """Two-layer perceptron: sqrt(input_dim) hidden units, softmax output."""

    def __init__(self, input_dim=50, num_classes=2):
        super(MLP, self).__init__()
        assert input_dim > 0
        # Hidden width shrinks with the square root of the input size.
        hidden_dim = int(np.sqrt(input_dim))
        self.fc1 = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(hidden_dim, num_classes), nn.Softmax(dim=1))

    def forward(self, x):
        return self.fc2(self.fc1(x))
| @@ -1,14 +1,14 @@ | |||
| # coding: utf-8 | |||
| #================================================================# | |||
| # ================================================================# | |||
| # Copyright (C) 2020 Freecss All rights reserved. | |||
| # | |||
| # | |||
| # File Name :models.py | |||
| # Author :freecss | |||
| # Email :karlfreecss@gmail.com | |||
| # Created Date :2020/04/02 | |||
| # Description : | |||
| # | |||
| #================================================================# | |||
| # ================================================================# | |||
| from itertools import chain | |||
| from sklearn.tree import DecisionTreeClassifier | |||
| @@ -29,14 +29,17 @@ import random | |||
| from sklearn.neighbors import KNeighborsClassifier | |||
| import numpy as np | |||
def get_part_data(X, i):
    """Return the ``i``-th element of every sequence in ``X``."""
    return list(map(lambda x: x[i], X))
def merge_data(X):
    """Flatten one nesting level of ``X``.

    Returns ``(flat_list, marks)`` where ``marks`` holds each sublist's
    length so the flat result can later be re-split (see ``reshape_data``).
    """
    ret_mark = list(map(lambda x: len(x), X))
    ret_X = list(chain(*X))
    return ret_X, ret_mark
| def reshape_data(Y, marks): | |||
| begin_mark = 0 | |||
| ret_Y = [] | |||
| @@ -48,42 +51,44 @@ def reshape_data(Y, marks): | |||
class WABLBasicModel:
    """Wraps a single base classifier and maps pseudo-labels <-> class indices.

    ``X``/``Y`` arguments are nested lists of examples; they are flattened
    with ``merge_data`` before being handed to the underlying classifier.
    """

    def __init__(self, base_model, pseudo_label_list):
        self.cls_list = []
        self.cls_list.append(base_model)
        self.pseudo_label_list = pseudo_label_list
        # Pseudo-label -> index, and index -> pseudo-label lookups.
        self.mapping = dict(
            zip(pseudo_label_list, list(range(len(pseudo_label_list))))
        )
        self.remapping = dict(
            zip(list(range(len(pseudo_label_list))), pseudo_label_list)
        )

    def predict(self, X):
        """Predict pseudo-labels and probabilities, re-nested to match ``X``."""
        data_X, marks = merge_data(X)
        prob = self.cls_list[0].predict_proba(X=data_X)
        _cls = prob.argmax(axis=1)
        cls = list(map(lambda x: self.remapping[x], _cls))
        prob = reshape_data(prob, marks)
        cls = reshape_data(cls, marks)
        return {"cls": cls, "prob": prob}

    def valid(self, X, Y):
        """Score the classifier on flattened (X, Y); returns (score, [score])."""
        data_X, _ = merge_data(X)
        _data_Y, _ = merge_data(Y)
        data_Y = list(map(lambda y: self.mapping[y], _data_Y))
        score = self.cls_list[0].score(X=data_X, y=data_Y)
        return score, [score]

    def train(self, X, Y):
        """Fit the classifier on flattened (X, Y) with index-mapped labels."""
        # self.label_lists = []
        data_X, _ = merge_data(X)
        _data_Y, _ = merge_data(Y)
        data_Y = list(map(lambda y: self.mapping[y], _data_Y))
        self.cls_list[0].fit(X=data_X, y=data_Y)
| class DecisionTree(WABLBasicModel): | |||
| def __init__(self, code_len, label_lists, share = False): | |||
| def __init__(self, code_len, label_lists, share=False): | |||
| self.code_len = code_len | |||
| self._set_label_lists(label_lists) | |||
| @@ -91,14 +96,19 @@ class DecisionTree(WABLBasicModel): | |||
| self.share = share | |||
| if share: | |||
| # 本质上是同一个分类器 | |||
| self.cls_list.append(DecisionTreeClassifier(random_state = 0, min_samples_leaf = 3)) | |||
| self.cls_list.append( | |||
| DecisionTreeClassifier(random_state=0, min_samples_leaf=3) | |||
| ) | |||
| self.cls_list = self.cls_list * self.code_len | |||
| else: | |||
| for _ in range(code_len): | |||
| self.cls_list.append(DecisionTreeClassifier(random_state = 0, min_samples_leaf = 3)) | |||
| self.cls_list.append( | |||
| DecisionTreeClassifier(random_state=0, min_samples_leaf=3) | |||
| ) | |||
| class KNN(WABLBasicModel): | |||
| def __init__(self, code_len, label_lists, share = False, k = 3): | |||
| def __init__(self, code_len, label_lists, share=False, k=3): | |||
| self.code_len = code_len | |||
| self._set_label_lists(label_lists) | |||
| @@ -106,14 +116,15 @@ class KNN(WABLBasicModel): | |||
| self.share = share | |||
| if share: | |||
| # 本质上是同一个分类器 | |||
| self.cls_list.append(KNeighborsClassifier(n_neighbors = k)) | |||
| self.cls_list.append(KNeighborsClassifier(n_neighbors=k)) | |||
| self.cls_list = self.cls_list * self.code_len | |||
| else: | |||
| for _ in range(code_len): | |||
| self.cls_list.append(KNeighborsClassifier(n_neighbors = k)) | |||
| self.cls_list.append(KNeighborsClassifier(n_neighbors=k)) | |||
| class CNN(WABLBasicModel): | |||
| def __init__(self, base_model, code_len, label_lists, share = True): | |||
| def __init__(self, base_model, code_len, label_lists, share=True): | |||
| assert share == True, "Not implemented" | |||
| label_lists = [sorted(list(set(label_list))) for label_list in label_lists] | |||
| @@ -126,27 +137,28 @@ class CNN(WABLBasicModel): | |||
| if share: | |||
| self.cls_list.append(base_model) | |||
| def train(self, X, Y, n_epoch = 100): | |||
| #self.label_lists = [] | |||
| def train(self, X, Y, n_epoch=100): | |||
| # self.label_lists = [] | |||
| if self.share: | |||
| # 因为是同一个分类器,所以只需要把数据放在一起,然后训练其中任意一个即可 | |||
| data_X, _ = merge_data(X) | |||
| data_Y, _ = merge_data(Y) | |||
| self.cls_list[0].fit(X = data_X, y = data_Y, n_epoch = n_epoch) | |||
| #self.label_lists = [sorted(list(set(data_Y)))] * self.code_len | |||
| self.cls_list[0].fit(X=data_X, y=data_Y, n_epoch=n_epoch) | |||
| # self.label_lists = [sorted(list(set(data_Y)))] * self.code_len | |||
| else: | |||
| for i in range(self.code_len): | |||
| data_X = get_part_data(X, i) | |||
| data_Y = get_part_data(Y, i) | |||
| self.cls_list[i].fit(data_X, data_Y) | |||
| #self.label_lists.append(sorted(list(set(data_Y)))) | |||
| # self.label_lists.append(sorted(list(set(data_Y)))) | |||
| if __name__ == "__main__": | |||
| #data_path = "utils/hamming_data/generated_data/hamming_7_3_0.20.pk" | |||
| # data_path = "utils/hamming_data/generated_data/hamming_7_3_0.20.pk" | |||
| data_path = "datasets/generated_data/0_code_7_2_0.00.pk" | |||
| codes, data, labels = pk.load(open(data_path, "rb")) | |||
| cls = KNN(7, False, k = 3) | |||
| cls = KNN(7, False, k=3) | |||
| cls.train(data, labels) | |||
| print(cls.valid(data, labels)) | |||
| for res in cls.predict_proba(data): | |||
| @@ -157,4 +169,3 @@ if __name__ == "__main__": | |||
| print(res) | |||
| break | |||
| print("Trained") | |||
| @@ -1,14 +1,14 @@ | |||
| # coding: utf-8 | |||
| #================================================================# | |||
| # ================================================================# | |||
| # Copyright (C) 2020 Freecss All rights reserved. | |||
| # | |||
| # | |||
| # File Name :plog.py | |||
| # Author :freecss | |||
| # Email :karlfreecss@gmail.com | |||
| # Created Date :2020/10/23 | |||
| # Description : | |||
| # | |||
| #================================================================# | |||
| # ================================================================# | |||
| import time | |||
| import logging | |||
| @@ -19,13 +19,14 @@ import functools | |||
| global recorder | |||
| recorder = None | |||
| class ResultRecorder: | |||
| def __init__(self): | |||
| logging.basicConfig(level=logging.DEBUG, filemode='a') | |||
| logging.basicConfig(level=logging.DEBUG, filemode="a") | |||
| self.result = {} | |||
| self.set_savefile() | |||
| logging.info("===========================================================") | |||
| logging.info("============= Result Recorder Version: 0.03 ===============") | |||
| logging.info("===========================================================\n") | |||
| @@ -33,25 +34,25 @@ class ResultRecorder: | |||
| pass | |||
| def set_savefile(self): | |||
| local_time = time.strftime("%Y%m%d_%H_%M_%S", time.localtime()) | |||
| local_time = time.strftime("%Y%m%d_%H_%M_%S", time.localtime()) | |||
| save_dir = os.path.join("results", local_time) | |||
| if not os.path.exists(save_dir): | |||
| os.makedirs(save_dir) | |||
| save_file_path = os.path.join(save_dir, "result.pk") | |||
| save_file = open(save_file_path, "wb") | |||
| self.save_dir = save_dir | |||
| self.save_file = save_file | |||
| filename = os.path.join(save_dir, "log.txt") | |||
| file_handler = logging.FileHandler(filename) | |||
| file_handler.setLevel(logging.DEBUG) | |||
| formatter = logging.Formatter('%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') | |||
| formatter = logging.Formatter("%(asctime)s - %(levelname)s: %(message)s") | |||
| file_handler.setFormatter(formatter) | |||
| logging.getLogger().addHandler(file_handler) | |||
| def print(self, *argv, screen = False): | |||
| def print(self, *argv, screen=False): | |||
| info = "" | |||
| for data in argv: | |||
| info += str(data) | |||
| @@ -62,9 +63,8 @@ class ResultRecorder: | |||
| def print_result(self, *argv): | |||
| for data in argv: | |||
| info = "#Result# %s" % str(data) | |||
| #print(info) | |||
| logging.info(info) | |||
| def store(self, *argv): | |||
| for data in argv: | |||
| if data.find(":") < 0: | |||
| @@ -81,11 +81,11 @@ class ResultRecorder: | |||
| self.result[label].append(data) | |||
| def write_kv(self, label, data): | |||
| self.print_result({label : data}) | |||
| #self.print_result(label + ":" + str(data)) | |||
| self.print_result({label: data}) | |||
| # self.print_result(label + ":" + str(data)) | |||
| self.store_kv(label, data) | |||
| def dump(self, save_file = None): | |||
| def dump(self, save_file=None): | |||
| if save_file is None: | |||
| save_file = self.save_file | |||
| pk.dump(self.result, save_file) | |||
| @@ -104,38 +104,43 @@ class ResultRecorder: | |||
| self.write_kv("func:", context) | |||
| return result | |||
| return clocked | |||
    def __del__(self):
        # Best-effort flush of collected results to the pickle file when the
        # recorder is garbage-collected. NOTE(review): during interpreter
        # shutdown, module globals used by dump() may already be cleared.
        self.dump()
def clocker(*argv):
    """Delegate to the module-wide recorder's ``clock``, creating it lazily."""
    global recorder
    if recorder is None:
        recorder = ResultRecorder()
    return recorder.clock(*argv)
def INFO(*argv, screen=False):
    """Print/log ``argv`` through the shared recorder, creating it on first use."""
    global recorder
    if recorder is None:
        recorder = ResultRecorder()
    return recorder.print(*argv, screen=screen)
def DEBUG(*argv, screen=False):
    """Print/log ``argv`` through the shared recorder, creating it on first use."""
    global recorder
    if recorder is None:
        recorder = ResultRecorder()
    return recorder.print(*argv, screen=screen)
def logger():
    """Return the module-wide ResultRecorder, instantiating it lazily."""
    global recorder
    recorder = recorder or ResultRecorder()
    return recorder
if __name__ == "__main__":
    recorder = ResultRecorder()
    recorder.write_kv("test", 1)
    # Bug fix: set_savefile() takes no arguments; the old call passed
    # pk_dir="haha" and raised TypeError.
    recorder.set_savefile()
    recorder.write_kv("test", 1)
| @@ -1,9 +1,17 @@ | |||
| import numpy as np | |||
| from utils.plog import INFO | |||
| from collections import OrderedDict | |||
| # for multiple predictions, modify from `learn_add.py` | |||
def flatten(l):
    """Recursively flatten arbitrarily nested lists into one flat list.

    A non-list input is wrapped in a single-element list.
    """
    return (
        [item for sublist in l for item in flatten(sublist)]
        if isinstance(l, list)
        else [l]
    )
| # for multiple predictions, modify from `learn_add.py` | |||
| def reform_idx(flatten_pred_res, save_pred_res): | |||
| re = [] | |||
| @@ -18,10 +26,25 @@ def reform_idx(flatten_pred_res, save_pred_res): | |||
| i = i + j | |||
| return re | |||
def block_sample(X, Z, Y, sample_num, epoch_idx):
    """Return the current epoch's block of up to ``sample_num`` items.

    The data is viewed as consecutive blocks of ``sample_num`` items and the
    chosen block index cycles with ``epoch_idx``.
    """
    part_num = max(len(X) // sample_num, 1)
    seg_idx = epoch_idx % part_num
    INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X))
    start = sample_num * seg_idx
    stop = sample_num * (seg_idx + 1)
    return X[start:stop], Z[start:stop], Y[start:stop]
def hamming_dist(A, B):
    """Hamming distance from sequence ``A`` to each candidate row of ``B``.

    Returns a 1-D array holding, for every row of B, the number of positions
    where that row differs from A.
    """
    B = np.array(B)
    A = np.expand_dims(A, axis=0).repeat(axis=0, repeats=(len(B)))
    return np.sum(A != B, axis=1)
| def confidence_dist(A, B): | |||
| B = np.array(B) | |||
| @@ -29,7 +52,16 @@ def confidence_dist(A, B): | |||
| A = np.expand_dims(A, axis=0) | |||
| A = A.repeat(axis=0, repeats=(len(B))) | |||
| rows = np.array(range(len(B))) | |||
| rows = np.expand_dims(rows, axis = 1).repeat(axis = 1, repeats = len(B[0])) | |||
| rows = np.expand_dims(rows, axis=1).repeat(axis=1, repeats=len(B[0])) | |||
| cols = np.array(range(len(B[0]))) | |||
| cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B)) | |||
| return 1 - np.prod(A[rows, cols, B], axis = 1) | |||
| cols = np.expand_dims(cols, axis=0).repeat(axis=0, repeats=len(B)) | |||
| return 1 - np.prod(A[rows, cols, B], axis=1) | |||
def copy_state_dict(state_dict):
    """Copy entries whose keys start with ``base_model``, dropping that prefix.

    Keys without the prefix are omitted; insertion order is preserved.
    """
    stripped = OrderedDict()
    for key, value in state_dict.items():
        if not key.startswith("base_model"):
            continue
        stripped[".".join(key.split(".")[1:])] = value
    return stripped