diff --git a/abl/utils/cache.py b/abl/utils/cache.py index 418c93f..ff8ad1a 100644 --- a/abl/utils/cache.py +++ b/abl/utils/cache.py @@ -1,9 +1,9 @@ import pickle -from os import PathLike -from typing import Callable, Generic, Hashable, TypeVar, Union +import os +import os.path as osp +from typing import Callable, Generic, TypeVar -from .logger import print_log -from .utils import to_hashable +from .logger import print_log, ABLLogger K = TypeVar("K") T = TypeVar("T") @@ -98,6 +98,14 @@ class Cache(Generic[K, T]): last[NEXT] = self.root[PREV] = self.cache_dict[cache_key] = link if isinstance(self.max_size, int): self.full = len(self.cache_dict) >= self.max_size + if self.full: + log_dir = ABLLogger.get_current_instance().log_dir + cache_dir = osp.join(log_dir, "cache") + os.makedirs(cache_dir, exist_ok=True) + cache_path = osp.join(cache_dir, "cache.pth") + with open(cache_path, "wb") as file: + pickle.dump(self.cache_dict, file, protocol=pickle.HIGHEST_PROTOCOL) + print_log(f"Cache will be saved to {cache_path}", logger="current") return result diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index fc06d34..e94a347 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -6,6 +6,8 @@ "metadata": {}, "outputs": [], "source": [ + "import os.path as osp\n", + "\n", "import torch.nn as nn\n", "import torch\n", "\n", @@ -14,7 +16,7 @@ "from abl.learning import BasicNN, ABLModel\n", "from abl.bridge import SimpleBridge\n", "from abl.evaluation import SymbolMetric\n", - "from abl.utils import ABLLogger\n", + "from abl.utils import ABLLogger, print_log\n", "\n", "from examples.models.nn import LeNet5\n", "from examples.mnist_add.datasets.get_mnist_add import get_mnist_add" @@ -24,10 +26,22 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11/15 21:35:55 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Abductive Learning on the MNIST Add example.\n" + ] + } + ], "source": [ "# Initialize logger\n", - "logger = ABLLogger.get_instance(\"abl\")" + "print_log(\"Abductive Learning on the MNIST Add example.\", logger=\"current\")\n", + "\n", + "# Retrieve the directory of the Log file and define the directory for saving the model weights.\n", + "log_dir = ABLLogger.get_current_instance().log_dir\n", + "weights_dir = osp.join(log_dir, \"weights\")" ] }, { @@ -46,13 +60,10 @@ "source": [ "# Initialize knowledge base and abducer\n", "class add_KB(KBBase):\n", - " def __init__(self, pseudo_label_list=list(range(10)), max_err=0, use_cache=True):\n", - " super().__init__(pseudo_label_list, max_err, use_cache)\n", - "\n", " def logic_forward(self, nums):\n", " return sum(nums)\n", "\n", - "kb = add_KB()\n", + "kb = add_KB(pseudo_label_list=list(range(10)))\n", "\n", "# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n", "abducer = ReasonerBase(kb, dist_func=\"confidence\")" @@ -92,7 +103,6 @@ " criterion,\n", " optimizer,\n", " device,\n", - " save_interval=1,\n", " batch_size=32,\n", " num_epochs=1,\n", ")" @@ -182,45 +192,57 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "11/15 13:36:00 - abl - WARNING - Transform used in the training phase will be used in prediction.\n" - ] - }, - { - "ename": "TypeError", - "evalue": "Input must be of type list.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/home/huwc/ABL-Package/examples/mnist_add/mnist_add_example.ipynb 单元格 17\u001b[0m line \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m bridge\u001b[39m.\u001b[39;49mtrain(train_data, loops\u001b[39m=\u001b[39;49m\u001b[39m5\u001b[39;49m, segment_size\u001b[39m=\u001b[39;49m\u001b[39m10000\u001b[39;49m)\n\u001b[1;32m 2\u001b[0m bridge\u001b[39m.\u001b[39mtest(test_data)\n", - "File \u001b[0;32m~/ABL-Package/abl/bridge/simple_bridge.py:92\u001b[0m, in \u001b[0;36mSimpleBridge.train\u001b[0;34m(self, train_data, loops, segment_size, eval_interval, save_interval, save_dir)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpredict(sub_data_samples)\n\u001b[1;32m 91\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39midx_to_pseudo_label(sub_data_samples)\n\u001b[0;32m---> 92\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mabduce_pseudo_label(sub_data_samples)\n\u001b[1;32m 93\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpseudo_label_to_idx(sub_data_samples)\n\u001b[1;32m 94\u001b[0m loss \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel\u001b[39m.\u001b[39mtrain(sub_data_samples)\n", - "File \u001b[0;32m~/ABL-Package/abl/bridge/simple_bridge.py:36\u001b[0m, in \u001b[0;36mSimpleBridge.abduce_pseudo_label\u001b[0;34m(self, data_samples, max_revision, require_more_revision)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mabduce_pseudo_label\u001b[39m(\n\u001b[1;32m 31\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 32\u001b[0m data_samples: ListData,\n\u001b[1;32m 33\u001b[0m max_revision: \u001b[39mint\u001b[39m \u001b[39m=\u001b[39m \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m,\n\u001b[1;32m 34\u001b[0m require_more_revision: \u001b[39mint\u001b[39m \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m,\n\u001b[1;32m 35\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m List[List[Any]]:\n\u001b[0;32m---> 36\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mabducer\u001b[39m.\u001b[39;49mbatch_abduce(data_samples, max_revision, require_more_revision)\n\u001b[1;32m 37\u001b[0m \u001b[39mreturn\u001b[39;00m data_samples[\u001b[39m\"\u001b[39m\u001b[39mabduced_pseudo_label\u001b[39m\u001b[39m\"\u001b[39m]\n", - "File \u001b[0;32m~/ABL-Package/abl/reasoning/reasoner.py:246\u001b[0m, in \u001b[0;36mReasonerBase.batch_abduce\u001b[0;34m(self, data_samples, max_revision, require_more_revision)\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mbatch_abduce\u001b[39m(\n\u001b[1;32m 240\u001b[0m \u001b[39mself\u001b[39m, data_samples, max_revision\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, require_more_revision\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m\n\u001b[1;32m 241\u001b[0m ):\n\u001b[1;32m 242\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 243\u001b[0m \u001b[39m Perform abductive reasoning on the given prediction data in batches.\u001b[39;00m\n\u001b[1;32m 244\u001b[0m \u001b[39m For detailed information, refer to `abduce`.\u001b[39;00m\n\u001b[1;32m 245\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 246\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[1;32m 247\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mabduce(data_sample, max_revision, require_more_revision)\n\u001b[1;32m 248\u001b[0m \u001b[39mfor\u001b[39;00m data_sample \u001b[39min\u001b[39;00m data_samples\n\u001b[1;32m 249\u001b[0m ]\n", - "File \u001b[0;32m~/ABL-Package/abl/reasoning/reasoner.py:247\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mbatch_abduce\u001b[39m(\n\u001b[1;32m 240\u001b[0m \u001b[39mself\u001b[39m, data_samples, max_revision\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, require_more_revision\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m\n\u001b[1;32m 241\u001b[0m ):\n\u001b[1;32m 242\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 243\u001b[0m \u001b[39m Perform abductive reasoning on the given prediction data in batches.\u001b[39;00m\n\u001b[1;32m 244\u001b[0m \u001b[39m For detailed information, refer to `abduce`.\u001b[39;00m\n\u001b[1;32m 245\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[1;32m 246\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[0;32m--> 247\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mabduce(data_sample, max_revision, require_more_revision)\n\u001b[1;32m 248\u001b[0m \u001b[39mfor\u001b[39;00m data_sample \u001b[39min\u001b[39;00m data_samples\n\u001b[1;32m 249\u001b[0m ]\n", - "File \u001b[0;32m~/ABL-Package/abl/reasoning/reasoner.py:222\u001b[0m, in \u001b[0;36mReasonerBase.abduce\u001b[0;34m(self, pred_prob, pred_pseudo_label, y, max_revision, require_more_revision)\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mabduce\u001b[39m(\n\u001b[1;32m 194\u001b[0m \u001b[39mself\u001b[39m, pred_prob, pred_pseudo_label, y, max_revision\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, require_more_revision\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m\n\u001b[1;32m 195\u001b[0m ):\n\u001b[1;32m 196\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 197\u001b[0m \u001b[39m Perform abductive reasoning on the given prediction data.\u001b[39;00m\n\u001b[1;32m 198\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[39m knowledge base.\u001b[39;00m\n\u001b[1;32m 221\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 222\u001b[0m symbol_num \u001b[39m=\u001b[39m \u001b[39mlen\u001b[39m(flatten(pred_pseudo_label))\n\u001b[1;32m 223\u001b[0m max_revision_num \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_get_max_revision_num(max_revision, symbol_num)\n\u001b[1;32m 225\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39muse_zoopt:\n", - "File \u001b[0;32m~/ABL-Package/abl/utils/utils.py:26\u001b[0m, in \u001b[0;36mflatten\u001b[0;34m(nested_list)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 8\u001b[0m \u001b[39mFlattens a nested list.\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[39m If the input object is not a list.\u001b[39;00m\n\u001b[1;32m 24\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 25\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(nested_list, \u001b[39mlist\u001b[39m):\n\u001b[0;32m---> 26\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mInput must be of type list.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 28\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m nested_list \u001b[39mor\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(nested_list[\u001b[39m0\u001b[39m], (\u001b[39mlist\u001b[39m, \u001b[39mtuple\u001b[39m)):\n\u001b[1;32m 29\u001b[0m \u001b[39mreturn\u001b[39;00m nested_list\n", - "\u001b[0;31mTypeError\u001b[0m: Input must be of type list." + "11/15 21:36:18 - abl - \u001b[5m\u001b[4m\u001b[33mWARNING\u001b[0m - Transform used in the training phase will be used in prediction.\n", + "11/15 21:36:21 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/5] segment(train) [1/3] model loss is 1.80390\n", + "11/15 21:36:24 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/5] segment(train) [2/3] model loss is 1.41898\n", + "11/15 21:36:26 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [1/5] segment(train) [3/3] model loss is 1.08221\n", + "11/15 21:36:26 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [1]\n", + "11/15 21:36:27 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.590 \n", + "11/15 21:36:27 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [1]\n", + "11/15 21:36:27 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_1.pth\n", + "11/15 21:36:29 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/5] segment(train) [1/3] model loss is 0.65210\n", + "11/15 21:36:31 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/5] segment(train) [2/3] model loss is 0.13546\n", + "11/15 21:36:32 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [2/5] segment(train) [3/3] model loss is 0.08060\n", + "11/15 21:36:32 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [2]\n", + "11/15 21:36:34 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.982 \n", + "11/15 21:36:34 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [2]\n", + "11/15 21:36:34 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_2.pth\n", + "11/15 21:36:35 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/5] segment(train) [1/3] model loss is 0.06446\n", + "11/15 21:36:37 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/5] segment(train) [2/3] model loss is 0.05224\n", + "11/15 21:36:39 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [3/5] segment(train) [3/3] model loss is 0.05119\n", + "11/15 21:36:39 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [3]\n", + "11/15 21:36:40 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.989 \n", + "11/15 21:36:40 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [3]\n", + "11/15 21:36:40 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_3.pth\n", + "11/15 21:36:42 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [4/5] segment(train) [1/3] model loss is 0.04667\n", + "11/15 21:36:44 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [4/5] segment(train) [2/3] model loss is 0.04027\n", + "11/15 21:36:45 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [4/5] segment(train) [3/3] model loss is 0.03672\n", + "11/15 21:36:45 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [4]\n", + "11/15 21:36:46 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.990 \n", + "11/15 21:36:46 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [4]\n", + "11/15 21:36:46 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_4.pth\n", + "11/15 21:36:48 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [5/5] segment(train) [1/3] model loss is 0.03381\n", + "11/15 21:36:50 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [5/5] segment(train) [2/3] model loss is 0.03333\n", + "11/15 21:36:52 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - loop(train) [5/5] segment(train) [3/3] model loss is 0.03195\n", + "11/15 21:36:52 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation start: loop(val) [5]\n", + "11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.992 \n", + "11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Saving model: loop(save) [5]\n", + "11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_5.pth\n", + "11/15 21:36:53 - abl - \u001b[4m\u001b[37mINFO\u001b[0m - Evaluation ended, mnist_add/character_accuracy: 0.988 \n" ] } ], "source": [ - "bridge.train(train_data, loops=5, segment_size=10000)\n", + "bridge.train(train_data, loops=5, segment_size=10000, save_interval=1, save_dir=weights_dir)\n", "bridge.test(test_data)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -239,7 +261,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.16" }, "orig_nbformat": 4, "vscode": {