Browse Source

[ENH] add log in examples

pull/1/head
troyyyyy 2 years ago
parent
commit
95c52ec966
4 changed files with 48 additions and 14 deletions
  1. +11
    -3
      examples/hed/main.py
  2. +11
    -3
      examples/hwf/main.py
  3. +10
    -3
      examples/mnist_add/main.py
  4. +16
    -5
      examples/zoo/main.py

+ 11
- 3
examples/hed/main.py View File

@@ -46,13 +46,20 @@ def main():
)

args = parser.parse_args()
# Build logger
print_log("Abductive Learning on the HED example.", logger="current")

### Working with Data
print_log("Working with Data.", logger="current")
total_train_data = get_dataset(train=True)
train_data, val_data = split_equation(total_train_data, 3, 1)
test_data = get_dataset(train=False)

### Building the Learning Part
print_log("Building the Learning Part.", logger="current")
# Build necessary components for BasicNN
cls = SymbolNet(num_classes=4)
loss_fn = nn.CrossEntropyLoss()
@@ -75,6 +82,8 @@ def main():
model = ABLModel(base_model)

### Building the Reasoning Part
print_log("Building the Reasoning Part.", logger="current")
# Build knowledge base
kb = HedKB()

@@ -82,14 +91,13 @@ def main():
reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=args.max_revision)

### Building Evaluation Metrics
print_log("Building Evaluation Metrics.", logger="current")
metric_list = [ConsistencyMetric(kb=kb)]

### Bridge Learning and Reasoning
print_log("Bridge Learning and Reasoning.", logger="current")
bridge = HedBridge(model, reasoner, metric_list)

# Build logger
print_log("Abductive Learning on the HED example.", logger="current")

# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")


+ 11
- 3
examples/hwf/main.py View File

@@ -113,12 +113,19 @@ def main():
)

args = parser.parse_args()
# Build logger
print_log("Abductive Learning on the HWF example.", logger="current")

### Working with Data
print_log("Working with Data.", logger="current")
train_data = get_dataset(train=True, get_pseudo_label=True)
test_data = get_dataset(train=False, get_pseudo_label=True)

### Building the Learning Part
print_log("Building the Learning Part.", logger="current")
# Build necessary components for BasicNN
cls = SymbolNet(num_classes=13, image_size=(45, 45, 1))
loss_fn = nn.CrossEntropyLoss()
@@ -140,6 +147,8 @@ def main():
model = ABLModel(base_model)

### Building the Reasoning Part
print_log("Building the Reasoning Part.", logger="current")
# Build knowledge base
if args.ground:
kb = HwfGroundKB()
@@ -152,14 +161,13 @@ def main():
)

### Building Evaluation Metrics
print_log("Building Evaluation Metrics.", logger="current")
metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")]

### Bridge Learning and Reasoning
print_log("Bridge Learning and Reasoning.", logger="current")
bridge = SimpleBridge(model, reasoner, metric_list)

# Build logger
print_log("Abductive Learning on the HWF example.", logger="current")

# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")


+ 10
- 3
examples/mnist_add/main.py View File

@@ -78,11 +78,17 @@ def main():

args = parser.parse_args()

# Build logger
print_log("Abductive Learning on the MNIST Addition example.", logger="current")

### Working with Data
print_log("Working with Data.", logger="current")
train_data = get_dataset(train=True, get_pseudo_label=True)
test_data = get_dataset(train=False, get_pseudo_label=True)

### Building the Learning Part
print_log("Building the Learning Part.", logger="current")
# Build necessary components for BasicNN
cls = LeNet5(num_classes=10)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2)
@@ -112,6 +118,8 @@ def main():
model = ABLModel(base_model)

### Building the Reasoning Part
print_log("Building the Reasoning Part.", logger="current")
# Build knowledge base
if args.prolog:
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl")
@@ -126,14 +134,13 @@ def main():
)

### Building Evaluation Metrics
print_log("Building Evaluation Metrics.", logger="current")
metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

### Bridge Learning and Reasoning
print_log("Bridge Learning and Reasoning.", logger="current")
bridge = SimpleBridge(model, reasoner, metric_list)

# Build logger
print_log("Abductive Learning on the MNIST Addition example.", logger="current")

# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")


+ 16
- 5
examples/zoo/main.py View File

@@ -30,8 +30,13 @@ def main():
"--loops", type=int, default=3, help="number of loop iterations (default : 3)"
)
args = parser.parse_args()
# Build logger
print_log("Abductive Learning on the ZOO example.", logger="current")

### Working with Data
print_log("Working with Data.", logger="current")
X, y = load_and_preprocess_dataset(dataset_id=62)
X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)
label_data = transform_tab_data(X_label, y_label)
@@ -39,12 +44,17 @@ def main():
train_data = transform_tab_data(X_unlabel, y_unlabel)

### Building the Learning Part
print_log("Building the Learning Part.", logger="current")
# Build base model
base_model = RandomForestClassifier()

# Build ABLModel
model = ABLModel(base_model)

### Building the Reasoning Part
print_log("Building the Reasoning Part.", logger="current")
# Build knowledge base
kb = ZooKB()
@@ -52,16 +62,17 @@ def main():
reasoner = Reasoner(kb, dist_func=consitency)

### Building Evaluation Metrics
print_log("Building Evaluation Metrics.", logger="current")
metric_list = [SymbolAccuracy(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")]
# Build logger
print_log("Abductive Learning on the ZOO example.", logger="current")
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")
### Bridging learning and reasoning
print_log("Bridge Learning and Reasoning.", logger="current")
bridge = SimpleBridge(model, reasoner, metric_list)
# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")
# Performing training and testing
print_log("------- Use labeled data to pretrain the model -----------", logger="current")
base_model.fit(X_label, y_label)


Loading…
Cancel
Save