Browse Source

[MNT] enable save of cache_dict for abl_cache

pull/4/head
Gao Enhao 2 years ago
parent
commit
aae657a338
2 changed files with 70 additions and 40 deletions
  1. +12
    -4
      abl/utils/cache.py
  2. +58
    -36
      examples/mnist_add/mnist_add_example.ipynb

+ 12
- 4
abl/utils/cache.py View File

@@ -1,9 +1,9 @@
import pickle
from os import PathLike
from typing import Callable, Generic, Hashable, TypeVar, Union
import os
import os.path as osp
from typing import Callable, Generic, TypeVar

from .logger import print_log
from .utils import to_hashable
from .logger import print_log, ABLLogger

K = TypeVar("K")
T = TypeVar("T")
@@ -98,6 +98,14 @@ class Cache(Generic[K, T]):
last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link
if isinstance(self.max_size, int):
self.full = len(self.cache_dict) >= self.max_size
if self.full:
log_dir = ABLLogger.get_current_instance().log_dir
cache_dir = osp.join(log_dir, "cache")
os.makedirs(cache_dir, exist_ok=True)
cache_path = osp.join(cache_dir, "cache.pth")
with open(cache_path, "wb") as file:
pickle.dump(self.cache_dict, file, protocol=pickle.HIGHEST_PROTOCOL)
print_log(f"Cache will be saved to {cache_path}", logger="current")
return result




+ 58
- 36
examples/mnist_add/mnist_add_example.ipynb View File

@@ -6,6 +6,8 @@
"metadata": {},
"outputs": [],
"source": [
"import os.path as osp\n",
"\n",
"import torch.nn as nn\n",
"import torch\n",
"\n",
@@ -14,7 +16,7 @@
"from abl.learning import BasicNN, ABLModel\n",
"from abl.bridge import SimpleBridge\n",
"from abl.evaluation import SymbolMetric\n",
"from abl.utils import ABLLogger\n",
"from abl.utils import ABLLogger, print_log\n",
"\n",
"from examples.models.nn import LeNet5\n",
"from examples.mnist_add.datasets.get_mnist_add import get_mnist_add"
@@ -24,10 +26,22 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"11/15 21:35:55 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Abductive Learning on the MNIST Add example.\n"
]
}
],
"source": [
"# Initialize logger\n",
"logger = ABLLogger.get_instance(\"abl\")"
"print_log(\"Abductive Learning on the MNIST Add example.\", logger=\"current\")\n",
"\n",
"# Retrieve the directory of the Log file and define the directory for saving the model weights.\n",
"log_dir = ABLLogger.get_current_instance().log_dir\n",
"weights_dir = osp.join(log_dir, \"weights\")"
]
},
{
@@ -46,13 +60,10 @@
"source": [
"# Initialize knowledge base and abducer\n",
"class add_KB(KBBase):\n",
" def __init__(self, pseudo_label_list=list(range(10)), max_err=0, use_cache=True):\n",
" super().__init__(pseudo_label_list, max_err, use_cache)\n",
"\n",
" def logic_forward(self, nums):\n",
" return sum(nums)\n",
"\n",
"kb = add_KB()\n",
"kb = add_KB(pseudo_label_list=list(range(10)))\n",
"\n",
"# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n",
"abducer = ReasonerBase(kb, dist_func=\"confidence\")"
@@ -92,7 +103,6 @@
" criterion,\n",
" optimizer,\n",
" device,\n",
" save_interval=1,\n",
" batch_size=32,\n",
" num_epochs=1,\n",
")"
@@ -182,45 +192,57 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"11/15 13:36:00 - abl - WARNING - Transform used in the training phase will be used in prediction.\n"
]
},
{
"ename": "TypeError",
"evalue": "Input must be of type list.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/home/huwc/ABL-Package/examples/mnist_add/mnist_add_example.ipynb 单元格 17\u001b[0m line \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell://ssh-remote%2B210.28.135.93/home/huwc/ABL-Package/examples/mnist_add/mnist_add_example.ipynb#X22sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m bridge\u001b[39m.\u001b[39;49mtrain(train_data, loops\u001b[39m=\u001b[39;49m\u001b[39m5\u001b[39;49m, segment_size\u001b[39m=\u001b[39;49m\u001b[39m10000\u001b[39;49m)\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2B210.28.135.93/home/huwc/ABL-Package/examples/mnist_add/mnist_add_example.ipynb#X22sdnNjb2RlLXJlbW90ZQ%3D%3D?line=1'>2</a>\u001b[0m bridge\u001b[39m.\u001b[39mtest(test_data)\n",
"File \u001b[0;32m~/ABL-Package/abl/bridge/simple_bridge.py:92\u001b[0m, in \u001b[0;36mSimpleBridge.train\u001b[0;34m(self, train_data, loops, segment_size, eval_interval, save_interval, save_dir)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpredict(sub_data_samples)\n\u001b[1;32m 91\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39midx_to_pseudo_label(sub_data_samples)\n\u001b[0;32m---> 92\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mabduce_pseudo_label(sub_data_samples)\n\u001b[1;32m 93\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpseudo_label_to_idx(sub_data_samples)\n\u001b[1;32m 94\u001b[0m loss \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel\u001b[39m.\u001b[39mtrain(sub_data_samples)\n",
"File \u001b[0;32m~/ABL-Package/abl/bridge/simple_bridge.py:36\u001b[0m, in \u001b[0;36mSimpleBridge.abduce_pseudo_label\u001b[0;34m(self, data_samples, max_revision, require_more_revision)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mabduce_pseudo_label\u001b[39m(\n\u001b[1;32m 31\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 32\u001b[0m data_samples: ListData,\n\u001b[1;32m 33\u001b[0m max_revision: \u001b[39mint\u001b[39m \u001b[39m=\u001b[39m \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m,\n\u001b[1;32m 34\u001b[0m require_more_revision: \u001b[39mint\u001b[39m \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m,\n\u001b[1;32m 35\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m List[List[Any]]:\n\u001b[0;32m---> 36\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mabducer\u001b[39m.\u001b[39;49mbatch_abduce(data_samples, max_revision, require_more_revision)\n\u001b[1;32m 37\u001b[0m \u001b[39mreturn\u001b[39;00m data_samples[\u001b[39m\"\u001b[39m\u001b[39mabduced_pseudo_label\u001b[39m\u001b[39m\"\u001b[39m]\n",
"File \u001b[0;32m~/ABL-Package/abl/reasoning/reasoner.py:246\u001b[0m, in \u001b[0;36mReasonerBase.batch_abduce\u001b[0;34m(self, data_samples, max_revision, require_more_revision)\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mbatch_abduce\u001b[39m(\n\u001b[1;32m 240\u001b[0m \u001b[39mself\u001b[39m, data_samples, max_revision\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, require_more_revision\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m\n\u001b[1;32m 241\u001b[0m ):\n\u001b[1;32m 242\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 243\u001b[0m \u001b[39m Perform abductive reasoning on the given prediction data in batches.\u001b[39;00m\n\u001b[1;32m 244\u001b[0m \u001b[39m For detailed information, refer to `abduce`.\u001b[39;00m\n\u001b[1;32m 245\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 246\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[1;32m 247\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mabduce(data_sample, max_revision, require_more_revision)\n\u001b[1;32m 248\u001b[0m \u001b[39mfor\u001b[39;00m data_sample \u001b[39min\u001b[39;00m data_samples\n\u001b[1;32m 249\u001b[0m ]\n",
"File \u001b[0;32m~/ABL-Package/abl/reasoning/reasoner.py:247\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mbatch_abduce\u001b[39m(\n\u001b[1;32m 240\u001b[0m \u001b[39mself\u001b[39m, data_samples, max_revision\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, require_more_revision\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m\n\u001b[1;32m 241\u001b[0m ):\n\u001b[1;32m 242\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 243\u001b[0m \u001b[39m Perform abductive reasoning on the given prediction data in batches.\u001b[39;00m\n\u001b[1;32m 244\u001b[0m \u001b[39m For detailed information, refer to `abduce`.\u001b[39;00m\n\u001b[1;32m 245\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[1;32m 246\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[0;32m--> 247\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mabduce(data_sample, max_revision, require_more_revision)\n\u001b[1;32m 248\u001b[0m \u001b[39mfor\u001b[39;00m data_sample \u001b[39min\u001b[39;00m data_samples\n\u001b[1;32m 249\u001b[0m ]\n",
"File \u001b[0;32m~/ABL-Package/abl/reasoning/reasoner.py:222\u001b[0m, in \u001b[0;36mReasonerBase.abduce\u001b[0;34m(self, pred_prob, pred_pseudo_label, y, max_revision, require_more_revision)\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mabduce\u001b[39m(\n\u001b[1;32m 194\u001b[0m \u001b[39mself\u001b[39m, pred_prob, pred_pseudo_label, y, max_revision\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, require_more_revision\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m\n\u001b[1;32m 195\u001b[0m ):\n\u001b[1;32m 196\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 197\u001b[0m \u001b[39m Perform abductive reasoning on the given prediction data.\u001b[39;00m\n\u001b[1;32m 198\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[39m knowledge base.\u001b[39;00m\n\u001b[1;32m 221\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 222\u001b[0m symbol_num \u001b[39m=\u001b[39m \u001b[39mlen\u001b[39m(flatten(pred_pseudo_label))\n\u001b[1;32m 223\u001b[0m max_revision_num \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_get_max_revision_num(max_revision, symbol_num)\n\u001b[1;32m 225\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39muse_zoopt:\n",
"File \u001b[0;32m~/ABL-Package/abl/utils/utils.py:26\u001b[0m, in \u001b[0;36mflatten\u001b[0;34m(nested_list)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 8\u001b[0m \u001b[39mFlattens a nested list.\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[39m If the input object is not a list.\u001b[39;00m\n\u001b[1;32m 24\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 25\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(nested_list, \u001b[39mlist\u001b[39m):\n\u001b[0;32m---> 26\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mInput must be of type list.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 28\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m nested_list \u001b[39mor\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(nested_list[\u001b[39m0\u001b[39m], (\u001b[39mlist\u001b[39m, \u001b[39mtuple\u001b[39m)):\n\u001b[1;32m 29\u001b[0m \u001b[39mreturn\u001b[39;00m nested_list\n",
"\u001b[0;31mTypeError\u001b[0m: Input must be of type list."
"11/15 21:36:18 - abl - \u001b[5m\u001b[4m\u001b[33mWARNING\u001b[0m - Transform used in the training phase will be used in prediction.\n",
"11/15 21:36:21 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/5] segment(train) [1/3] model loss is 1.80390\n",
"11/15 21:36:24 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/5] segment(train) [2/3] model loss is 1.41898\n",
"11/15 21:36:26 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/5] segment(train) [3/3] model loss is 1.08221\n",
"11/15 21:36:26 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [1]\n",
"11/15 21:36:27 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.590 \n",
"11/15 21:36:27 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [1]\n",
"11/15 21:36:27 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_1.pth\n",
"11/15 21:36:29 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/5] segment(train) [1/3] model loss is 0.65210\n",
"11/15 21:36:31 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/5] segment(train) [2/3] model loss is 0.13546\n",
"11/15 21:36:32 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/5] segment(train) [3/3] model loss is 0.08060\n",
"11/15 21:36:32 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [2]\n",
"11/15 21:36:34 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.982 \n",
"11/15 21:36:34 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [2]\n",
"11/15 21:36:34 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_2.pth\n",
"11/15 21:36:35 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/5] segment(train) [1/3] model loss is 0.06446\n",
"11/15 21:36:37 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/5] segment(train) [2/3] model loss is 0.05224\n",
"11/15 21:36:39 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/5] segment(train) [3/3] model loss is 0.05119\n",
"11/15 21:36:39 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [3]\n",
"11/15 21:36:40 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.989 \n",
"11/15 21:36:40 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [3]\n",
"11/15 21:36:40 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_3.pth\n",
"11/15 21:36:42 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [4/5] segment(train) [1/3] model loss is 0.04667\n",
"11/15 21:36:44 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [4/5] segment(train) [2/3] model loss is 0.04027\n",
"11/15 21:36:45 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [4/5] segment(train) [3/3] model loss is 0.03672\n",
"11/15 21:36:45 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [4]\n",
"11/15 21:36:46 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.990 \n",
"11/15 21:36:46 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [4]\n",
"11/15 21:36:46 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_4.pth\n",
"11/15 21:36:48 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [5/5] segment(train) [1/3] model loss is 0.03381\n",
"11/15 21:36:50 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [5/5] segment(train) [2/3] model loss is 0.03333\n",
"11/15 21:36:52 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [5/5] segment(train) [3/3] model loss is 0.03195\n",
"11/15 21:36:52 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [5]\n",
"11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.992 \n",
"11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [5]\n",
"11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_5.pth\n",
"11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.988 \n"
]
}
],
"source": [
"bridge.train(train_data, loops=5, segment_size=10000)\n",
"bridge.train(train_data, loops=5, segment_size=10000, save_interval=1, save_dir=weights_dir)\n",
"bridge.test(test_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -239,7 +261,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.8.16"
},
"orig_nbformat": 4,
"vscode": {


Loading…
Cancel
Save