Browse Source

Merge branch 'parrial_ab_data' of https://github.com/AbductiveLearning/ABL-Package into parrial_ab_data

pull/4/head
Gao Enhao 2 years ago
parent
commit
f1378c7e82
2 changed files with 3 additions and 31 deletions
  1. +3
    -3
      abl/evaluation/semantics_metric.py
  2. +0
    -28
      abl/utils/utils.py

+ 3
- 3
abl/evaluation/semantics_metric.py View File

@@ -10,10 +10,10 @@ class SemanticsMetric(BaseMetric):
self.kb = kb

def process(self, data_samples: Sequence[dict]) -> None:
pred_psedudo_label_list = data_samples.pred_pseudo_label
pred_pseudo_label_list = data_samples.pred_pseudo_label
y_list = data_samples.Y
for pred_psedudo_label, y in zip(pred_psedudo_label_list, y_list):
if self.kb._check_equal(self.kb.logic_forward(pred_psedudo_label), y):
for pred_pseudo_label, y in zip(pred_pseudo_label_list, y_list):
if self.kb._check_equal(self.kb.logic_forward(pred_pseudo_label), y):
self.results.append(1)
else:
self.results.append(0)


+ 0
- 28
abl/utils/utils.py View File

@@ -144,34 +144,6 @@ def block_sample(X, Z, Y, sample_num, seg_idx):
return (data[start_idx:end_idx] for data in (X, Z, Y))


def check_equal(a, b, max_err=0):
"""
Check whether two numbers a and b are equal within a maximum allowable error.

Parameters
----------
a, b : int or float
The numbers to compare.
max_err : int or float, optional
The maximum allowable absolute difference between a and b for them to be considered equal.
Default is 0, meaning the numbers must be exactly equal.

Returns
-------
bool
True if a and b are equal within the allowable error, False otherwise.

Raises
------
TypeError
If a or b are not of type int or float.
"""
if not (isinstance(a, (int, float)) and isinstance(b, (int, float))):
raise TypeError("Input values must be int or float.")

return abs(a - b) <= max_err


def to_hashable(x):
"""
Convert a nested list to a nested tuple so it is hashable.


Loading…
Cancel
Save