From be0eaa63e369c6aa2ab440b1e5fc8590d0536137 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Mon, 25 Dec 2023 20:15:10 +0800 Subject: [PATCH] [MNT] remove import from examples --- examples/hed/hed.ipynb | 17 ++-- examples/hed/main.py | 16 ++-- examples/hed/models/nn.py | 62 ++++++++++++ examples/hwf/hwf.ipynb | 146 ++++++----------------------- examples/hwf/main.py | 6 +- examples/hwf/models/nn.py | 46 +++++++++ examples/mnist_add/main.py | 5 +- examples/mnist_add/mnist_add.ipynb | 11 ++- examples/mnist_add/models/nn.py | 94 +++++++++++++++++++ 9 files changed, 261 insertions(+), 142 deletions(-) create mode 100644 examples/hed/models/nn.py create mode 100644 examples/hwf/models/nn.py create mode 100644 examples/mnist_add/models/nn.py diff --git a/examples/hed/hed.ipynb b/examples/hed/hed.ipynb index 27bf086..8f9d898 100644 --- a/examples/hed/hed.ipynb +++ b/examples/hed/hed.ipynb @@ -22,14 +22,15 @@ "import torch\n", "import torch.nn as nn\n", "import matplotlib.pyplot as plt\n", - "from examples.hed.datasets import get_dataset, split_equation\n", - "from examples.models.nn import SymbolNet\n", + "\n", "from abl.learning import ABLModel, BasicNN\n", - "from examples.hed.reasoning import HedKB, HedReasoner\n", - "from abl.data.evaluation import SymbolAccuracy\n", - "from examples.hed.consistency_metric import ConsistencyMetric\n", "from abl.utils import ABLLogger, print_log\n", - "from examples.hed.bridge import HedBridge" + "\n", + "from bridge import HedBridge\n", + "from consistency_metric import ConsistencyMetric\n", + "from datasets import get_dataset, split_equation\n", + "from models.nn import SymbolNet\n", + "from reasoning import HedKB, HedReasoner" ] }, { @@ -382,7 +383,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -415,7 +416,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.18" }, "orig_nbformat": 4, "vscode": { diff --git a/examples/hed/main.py b/examples/hed/main.py index f1287db..b3952bd 100644 --- a/examples/hed/main.py +++ b/examples/hed/main.py @@ -4,13 +4,15 @@ import argparse import torch import torch.nn as nn -from examples.hed.datasets import get_dataset, split_equation -from examples.models.nn import SymbolNet from abl.learning import ABLModel, BasicNN -from examples.hed.reasoning import HedKB, HedReasoner from abl.data.evaluation import ReasoningMetric, SymbolAccuracy from abl.utils import ABLLogger, print_log -from examples.hed.bridge import HedBridge + +from bridge import HedBridge +from datasets import get_dataset, split_equation +from models.nn import SymbolNet +from reasoning import HedKB, HedReasoner + def main(): parser = argparse.ArgumentParser(description="Handwritten Equation Decipherment example") @@ -54,7 +56,7 @@ def main(): # Build necessary components for BasicNN cls = SymbolNet(num_classes=4) loss_fn = nn.CrossEntropyLoss() - optimizer = torch.optim.RMSprop(cls.parameters(), lr=args.lr, weight_decay=args.weight_deccay) + optimizer = torch.optim.RMSprop(cls.parameters(), lr=args.lr, weight_decay=args.weight_decay) use_cuda = not args.no_cuda and torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") @@ -63,7 +65,7 @@ def main(): cls, loss_fn, optimizer, - device, + device=device, batch_size=args.batch_size, num_epochs=args.epochs, stop_loss=None, @@ -81,7 +83,7 @@ def main(): ### Building Evaluation Metrics metric_list = [SymbolAccuracy(prefix="hed"), ReasoningMetric(kb=kb, prefix="hed")] - + ### Bridge Learning and Reasoning bridge = HedBridge(model, reasoner, metric_list) diff --git a/examples/hed/models/nn.py b/examples/hed/models/nn.py new file mode 100644 index 0000000..7aa9994 --- /dev/null +++ b/examples/hed/models/nn.py @@ -0,0 +1,62 @@ +# coding: utf-8 +# ================================================================# +# Copyright (C) 2021 Freecss All rights reserved. +# +# File Name :lenet5.py +# Author :freecss +# Email :karlfreecss@gmail.com +# Created Date :2021/03/03 +# Description : +# +# ================================================================# + + +import torch +from torch import nn + + +class SymbolNet(nn.Module): + def __init__(self, num_classes=4, image_size=(28, 28, 1)): + super(SymbolNet, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(1, 32, 5, stride=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.BatchNorm2d(32, momentum=0.99, eps=0.001), + ) + self.conv2 = nn.Sequential( + nn.Conv2d(32, 64, 5, padding=2, stride=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.BatchNorm2d(64, momentum=0.99, eps=0.001), + ) + + num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1) + self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) + self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) + self.fc3 = nn.Sequential(nn.Linear(84, num_classes)) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + return x + + +class SymbolNetAutoencoder(nn.Module): + def __init__(self, num_classes=4, image_size=(28, 28, 1)): + super(SymbolNetAutoencoder, self).__init__() + self.base_model = SymbolNet(num_classes, image_size) + self.softmax = nn.Softmax(dim=1) + self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU()) + self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()) + + def forward(self, x): + x = self.base_model(x) + # x = self.softmax(x) + x = self.fc1(x) + x = self.fc2(x) + return x diff --git a/examples/hwf/hwf.ipynb b/examples/hwf/hwf.ipynb index 67f6a47..2be427a 100644 --- a/examples/hwf/hwf.ipynb +++ b/examples/hwf/hwf.ipynb @@ -13,7 +13,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -23,13 +23,15 @@ "import torch\n", "import torch.nn as nn\n", "import matplotlib.pyplot as plt\n", - "from examples.hwf.datasets import get_dataset\n", - "from examples.models.nn import SymbolNet\n", + "\n", "from abl.learning import ABLModel, BasicNN\n", "from abl.reasoning import KBBase, Reasoner\n", "from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", "from abl.utils import ABLLogger, print_log\n", - "from abl.bridge import SimpleBridge" + "from abl.bridge import SimpleBridge\n", + "\n", + "from datasets import get_dataset\n", + "from models.nn import SymbolNet" ] }, { @@ -43,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -62,24 +64,9 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y\n", - "\n", - "Length of X, gt_pseudo_label, Y in train_data: 10000, 10000, 10000\n", - "Length of X, gt_pseudo_label, Y in test_data: 2000, 2000, 2000\n", - "\n", - "X is a list, with each element being a list of Tensor.\n", - "gt_pseudo_label is a list, with each element being a list of str.\n", - "Y is a list, with each element being a int.\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "print(f\"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y\")\n", "print()\n", @@ -110,55 +97,9 @@ }, { "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "X in the 1001st data example (a list of images):\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAClCAYAAADBAf6NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAANrElEQVR4nO3d2XIjtw4A0J5kZFf+/2fHS5L7NC6I10BgqtWSxXOeZEu9aWGzAIL88e+///67AQDL+uPWJwAA3JbOAAAsTmcAABanMwAAi9MZAIDF6QwAwOJ0BgBgcToDALC4nzMbxXmKfvz40drmn3/++Xj8xx+P1QcZ522K70k2p9PPn/lbH9+rzr7+6zmAvV3aplfb//33358+F9vD9/f3s23++uuv9Llrydrq8fhZe59t/9nf1/ZYd2UA4Mt0BgBgcT9m1ibYO01w6xRCN1STXWv3euL79vr6mu4j22b0559/fnpMgK+o2sDo6ND1tuXt6djuxtD8UfeR+H5kqY1R93539H1RZAAAFqczAACLm6omqEa4ZymEKvR9i9RAN8Qer6EbSouvi9cajzljPKYKAmAPWRvcbd+jt7e3s79Pp9OXtt+285D/8/Pzx+Osbf3s7yNk1xNTBtt2/v6Oz332mnF/RxAZAIDF6QwAwOJ0BgBgce3Swm6529nOm6V4WQ7lmrIykHH8QHbd2ViCaptofG+yMpKqvKQ7OyFAV8zXjzPnxZLop6eni45TjR/4Lm1b1j5X7Xt2X6xmsj2CyAAALE5nAAAW104TxDBHLBUZQ/yd3d1DaWFUheKzUFYV4pqZobGjet/uOZQG3LdL26xLt9+jbPoWM9lm51ndRzJVuvkIIgMAsDidAQBY3NRCRWc7mAjv3HrU5HgO3RkVs7BNt5ogvm6cjTCbBbEKFVVVEAAzYptVpSY7FVDjc92R9Nlzt7hXzBjPs7OI0a3viyIDALA4nQEAWNzUQkVx4okqtNFd3/medEMznfTBuL9qcaROCmJmhCrAf8lC8VV7uOfkaLcOke8hXsM4WVPHra/5e9yhAYCr0RkAgMXpDADA4i4uLRw3z/LiVU7pFjNHRXuXrnQWNOrm/6vz+Y4lN8D9mWlLuov0dMYJ7NHO3VN7WL0H3bb+6OsRGQCAxekMAMDiptIE3UV6YsoglhneOoQDwH6ydn/b8tlRZ9LDl6aUuyWMM2ntuK/39/ezbcb35LPzGc8l7mOmVPGrRAYAYHE6AwCwuKk0QTdU05l56iv7A+A43ZlSo5nF62IaeTzOTFo5u6eMx+nMGthNLczc77qzPR7BnRcAFqczAACLaw9RzBbmqcJI1euybQA41swEN9k9oZt5jiH7bLT9tm3b29vbx+MYyu8uohRVx4n3qKo6oqN6DzqL0m3b8SkEd2EAWJzOAAAsTmcAABZ31RkIo9Pp9PH4169fZ88dMbsSAJ+bWTwnPo7t+zj7XmbvPHhWTnjpTH7j9WTjFrKSwW3L75nV7ffoUkORAQBYnM4AACyunSbohjliqCSWZFTb3NM61AAr65b8RbHdrmYTzNr6cZv4XFaKV4Xvs3TGuK/sWrv3pG7IP0tV3NO9T2QAABanMwAAi2sPreyGQ7KwUgyTjK/pjj4F4Lpi+zyG77NQehb6njlmtY+Yhh4rA+K5ZjPhjrJZcrsL8I3vz2/j+XdSGGOa4OgF/EQGAGBxOgMAsLjdJx16fX39ePz09PT5QSfWuwbgOroh6WxUfBU6z0LkWVh/fF13lH016c9X91Vtk92vXl5ePh4/Pz+n+zs6/N91P2cCANyEzgAALE5nAAAWd9WFirp5G2MGAO7P2DZnJeZV7rtTJlht31mAaNt6ixDFfY37y+5d1f2quwhS51pvvWiRyAAALE5nAAAWN7VQ0dkOivBFts0421QMoXQXyeiEXarZs7LznAnHzISeKt1trlWissf13NMCHJnxOqvQ4G9V2PQoe763M9dTfT+y191TCRWPb2axpag74+4j8QsFgMXpDADA4qaqCaJuKL48iSQkE8OM42nG42Sv6y5+Uc2e1Vl4qTOKdTzP7oxb1axa2ftWhbiy5+J1d6/n7e3t4/HpdEpfd2nIbk+XzmzWXQ/9Frqf9d7pnZn34DukkXgM3bZ6xdRAJDIAAIvTGQCAxV1cTTD+P4ZRs0WLxtHI8bkY0umGobMwbnfUc7avbZsbYZ49l43Kr/Z9qT2Oc82KiKNduiDLvZmZuKU7oUr23e2m0u4phcJaut+9+Nz4vf7t1pMBHUVkAAAWpzMAAIvTGQCAxU2VFnbzlN1ZC+PfcZxAzNVWpxnPJys5HM815ofi8cfrya51Jid+6RiIvWU5su7Mj1G39PMWM9HtOaNjdv3VNnu7tIw0K13dtrlruNfxIbBt9fezW8a9ApEBAFiczgAALG73GQhjmDELtVShyZga+PXr16ev+ezvr57rTBj70nBoTA2Ma3Ffqhvi6lzDd13A56uuOdvkUS6d1bJKi2Wfe0x3bdt5aXDXd5zt7bucJ/1UWvzuxu91bBvG38h3aNtmiAwAwOJ0BgBgcVMzEF46IrvadzZivwpDXzpavbq2zr67CwhdOlq1CmtfGrqK1zaGgbNZHTvpg227r9Hml36Px/TOLdIGM2mt+F2pfvKd39+YFnh5efn0ddX7a3ZCbqGaMbdKDUS3ro66lse5EgBgis4AACxuqpqgCt93JmupwurZ/8fjZOHIKvw4E+LORptWE7dcuiZ81J1EKeqOMM/MTCQ1iqHj5+fn1jZHqD6bmc/t1mmPamGuKPvNzvx+ZyoQ7iG9wpqq726nqmWPibm+g8e8KgCgTWcAABanMwAAi5saM9AtC5opwci2qfK72flUi+d0L3vPxSuyMsPx3LJtujnuvUtfstLP7DXVc7fIsc8csztD3q1npbvme2txI76rS8f/VG21MQMAwEPSGQCAxV28UNHeugtMxFDNTHkVc7LPp0phANzKzGyX3bYtlsx209AzC/gdQWQAABanMwAAi7vracC6M+k96sIR9ygbVV4tJNVdAARgD7FtqlIDnbR0FfKPM2lW7WHHrVOrWmYAWJzOAAAs7u7SBN0JIjqpgXHxn2oSI65HagA4Umxz4oj/mQWyuhPeVdvERe5i2iK+bkxnHD0xl1YaABanMwAAi9MZAIDF3d2YgW6eJOaEshzOTH6IvvgZjLmzbIGnW5fPAI+vex+YKS3sjFEbt4nn8Pr6+ukxb01kAAAWpzMAAIu7u4WKou5iD93FjeLfyt3mZGU61QyE0gTArcTw/RiWz9qpS2e1HRcdykoIY3tandsR3BEBYHE6AwCwuLtOE1SyxW9iaCXO+rRt23Y6na5/YgupFiCSJgDuTZXOzF43viarlOq2c9k2/3Wu1yYyAACL0xkAgMXdXZoghp7jqMtxUptqjerM0Qs/PKLueyg1ANyDqjLg5eXl4/Hz8/N/bv/ZPn7rto3dSoWj21CRAQBYnM4AACxOZwAAFnd3YwaibnlHzLtU+RhjBo6TfSYA96Ize+1M/r9bwpidy7adj4s7og0VGQCAxekMAMDi8oWe70AVWslKMrqlGo+sCnFlpZvRTCnNHWeb4CF0frvblv8uLc72/7J7Qvde0d1+ZuGjrE2tvgfZNp3r8e0AgMXpDADA4u66moC+zuIZ23ZZGKlSHUcVB+zr/f397O8Yes7C0HvM4kquu3BbproVV5UKexEZAIDF6QwAwOLuupqAOd0FhDJVODGGJ5+enj4ej6GrmdGzwLls5PjPn+dNdxY6npm4jTlVejaKbWj8HGdSqNVkRF/9TH0DAGBxOgMAsDidAQBYnNLCB9SdoSrmqGbyh9nCHrP7A+bE31t8PI4tiJT8HqfT7nbHel1rETitNAAsTmcAABantPBBxDBUTA1019WuwvrxuWzf3cWNgDnVbzT+rrPUgN/o/rrp0EtTslG1QNwlx/FtAIDF6QwAwOKkCR7cmBaIlQZZaHEMJ2apgapqobtwEtBTVe90RqhLCxyn+ny6iw5l6YDT6fTpfrctn92ww7cDABanMwAAizPp0AOqJhN5e3v7eBzDTTHkP7MISpVaAC4Xf6NxkbBtq3+Lv8Xf/rad//65XNa2bludUv2tqvaI7XhMBezZzooMAMDidAYAYHE6AwCwOGMGHlC1+EVWglTNWpiVI3YXzFDSBHOyWeSq0rVuftlCRZfL3sPuzK9RbGe37fzzytra6jidcQqRVhoAFqczAACLkya4spkwUjd8NzPLX/Zcd/EL4Da+GvbdtrpdqH7zmW57lp1Dtq9qm+6CO51zq9qyS9Mm1eeTXWu3rT7iNi0yAACL0xkAgMVJE9y5cYRplIUKu6GwbtjRqGO4jUtD5NlrRnHfVZXQTJpxxp5tzp7h/23LK6+6+87a3SqNcwSRAQBYnM4AACzuawse82VZGGkMCcXQUQzTxTDSuIBQlkLoLFpSeX19Pft7XBQFOEZsCy5Zq37b9qkmyGQTmI3PZefTTUdU+80mZapG5WcLtFXX0znP8Tjx72wyoSolfASRAQBYnM4AACxOZwAAFqe08Mqy3NVYJlTlz36LOcNty3NZ1ZiBbJtuCRNwnG75b2dBo7Gpz37n3YXOZmZE7LaH2TZRtU3WhnbHM3THU3Rvn1np5tvbW2ubI2j1AWBxOgMAsDhpggNVpUExXJSVFs7uO5OVM1Zlj93zAfYVf4dV+H5mkZ89VWWCWSlet0ww2++2nbehp9Pp43HVfnXKHrulhdV7Hc+h254e3e6KDADA4nQGAGBx0gQAsDiRAQBYnM4AACxOZwAAFqczAACL0xkAgMXpDADA4nQGAGBxOgMAsDidAQBY3P8AoI/85JZr3lUAAAAASUVORK5CYII=", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "gt_pseudo_label in the 1001st data example (a list of ground truth pseudo-labels): ['5', '-', '3']\n", - "Y in the 1001st data example (the computed result): 2\n", - "\n", - "X in the 3001st data example (a list of images):\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgQAAABpCAYAAABF9zs7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA6v0lEQVR4nO2deXwURfr/Pz1H7vuE3ISEBAhH5Ei4dV0MaGCRU0TBgxVUEHVV9LsCIqCrrsciiqKrooIox6oo93qsKIgKCqJAwpmEQO6QOzM99fsjv25rOt0zPcmcmXq/XrzIdFdXV9f51FPPU8URQggYDAaDwWB4NRpXJ4DBYDAYDIbrYQIBg8FgMBgMJhAwGAwGg8FgAgGDwWAwGAwwgYDBYDAYDAaYQMBgMBgMBgNMIGAwGAwGgwEmEDAYDAaDwYCbCgQmkwmW9ksihIAQgvr6eowdOxZRUVH44YcfxGdNJpOzkspgOBWe58HzfLvrQpuQ4+6774Zer8dHH31kFpbtSeZ4XJnXhJAO9YWEEKxduxYcx4n/XnzxRQek0PGcOXMG6enpZt+i1Wqh1WrNrtH39Ho9NBoNOI6DRqMR/6b/KV2397/rrrsOjY2NTssvndPeZGfOnTuHy5cvIzY2Fv369UNgYKCrk8RgOByO41SHLS0txaVLl+Dv74+BAwciPDy8XRhCiE1xejuEEJw7dw61tbUAAJ1Oh9TUVAQEBLg4ZZ2jrq4OZ86cESdjTU1NGDhwoHg/OjradYlzAEK9lwpqtghvfn5+6NmzJ3Q6xw2jPXv2hEbjvHm7UwUCnudBCBEzUJBebf1gk8mEp556Cnv27MHq1asxcuRIhIaGdiguBsMTsNZW5Ab1d955By+99BKWLFmCxx57DCEhIWZhpe2RYR2j0YgVK1Zgx44dAICoqChs2rQJWVlZYhh6QHGEsEUIAc/z4ixVGMSEWaUSlurQkSNHMHv2bDQ1NYHjOEyfPh27du0S4wwKCrL7d9gbIR+stRGTySRqA6QaZSFv6d9KJCcn4/3330f37t3tWs5CejiOg4+PD/z8/OwWtzWc2hNIM00pE6XX6VnMqVOnUFxcDF9fX/Tp0wfx8fGIiIhgsxxGl6Yj9buhoQHl5eXQ6XSKMzzWbmynpqYGZWVlANr6JoPBIBvOkXmrNm46nNyM+MqVKzh69CgKCwuRkZEhfktqaipiYmI8sn7Q44W/vz9ycnKQmJjY7r5WqxUFACG8kgAgvc9xHFJSUhAfH4+oqCi7p59+pzNxikAgfKBUcrP2wXKF8+qrr2L9+vV45ZVX8I9//AM+Pj7geV6c5bgyMxkMR6FUn9XUd2EtWZjtCc9otVr7J9RLEPJQTsUsN6GRu96Z99JlZ00zQL9fGq6goACzZs1C37598d5778Hf3x8AoNfrO51WZyN8n7DsodFo0K1bN6xbt061LYXQVgT7ALn79Hhmj6Vqaf1wZP2xhlvqCuWktd9++w0nT56En58frr32WiQlJSEoKMiqASKD0dWR6+SPHTsGrVaLyZMno0ePHi5KWddDo9EgNzdX/O3n54fDhw+jrKwMI0eObKdad+UExdKst76+Hvv370dRURGGDRuG3r17IzQ0FL6+vs5Opt2hhV6O41Tbd9BCnpKQZU34cwTOtPPhnHH8sbVGoeb+E088gWeffRbPP/887rjjDvj4+MiuFTEjKYa3Isxe/vWvf2Hx4sVYunQpFi9eDK1Wa9G2pqO2PN4Kz/NinpWXl+Mvf/kLampqsGvXLvTs2VMM54h8VdNXqun/CgoKMG7cOERHR2Pbtm2IioqCj4+P3dLpCqzZbljKO1qrQAtSStptZ0Jr92gcMdY5RUMgt3ZlKzzPo6WlBRzHwdfXV7X9AYPhLRQUFOB///sf6uvrMXv2bAwcOBA6nc5qm2BtxjYEtzWgTbVuNBphNBqdNnvsTLz19fXYvXs3SktLkZeXh7i4OAQHB0Ov13v8cqswzqhNPz3Qyj0jp2FxZh65ojyctmRg6aPkMl5qDCPct7ZexmB0BWztDAghOHDgAO655x7cd999eOONN8xmOFKhXK59MeyLI/JVjXBnaaZ85coVLFu2DEajEbt370ZSUpL4jGBcp7R+7glYU/PLzf4tCQXSOAWtjyPsb6TCjDXthCPKyC1sCCxV4B9++AFffvkltFotHn74YTPfWAajq2LJR9pS5yW4UdEGhLSxkrUNv4RwDHX4+/vjtttuw+XLl7FlyxZ069YNU6dOdYmbnhrjRpPJJLqb0vWIFgS6WvnT9b6oqAhbt25FYmIiJk6caNXl1pqAZe90Wvot4MhlcacKBEodjiWpa//+/Vi8eDGefPJJPPXUU6olM9a5MTwdpbVOAO06MiUDMrWdDP0sazOWofM4KCgIixYtwrlz55CXlwd/f3+MGzfO4QKBXFkJA76wC5/cM1L1Nz24KM1EPcUuy9LALXgfnDlzBsuWLcOYMWMwYcIEVQO+VEMt905naYOseUF0FqdrCOQy1ZLBBp3hloQB1pkxuiL0ZivCLE7K77//jg0bNkCr1WLVqlUYMmSIrPqR2RLYB7l8Cg8Px4MPPojKykqsXr0aycnJmDNnjtmmMkrGYR1BqR9Vir+pqQlvv/02iouLceuttyIqKgrh4eGyKmpbhEh3wtpgLW0/tKZECUsum45GrjwcvaTjEg2BtBCsSbNqMsGaLyeD4YnQ1s+AvPBcWFiI5557DjNnzsRbb71ltnsdjbXOj7WZjhMaGoq77roLR48exbhx45CWloabbrpJFAik/Zk9UBow5MI0NTVhw4YNKCoqwu7du9G7d+924ZylGncWcmMCLRQI3iLWtM6W2o29hDu5uJQESEeWjUt2KrSkghE4cOAA1q9fj5iYGLz22msYMmSIqrgZjK6EMLgLSDsPqQZBwJoqtKt1/q5AzYBMX7fnzE46UCj1pQaDAa+99hp+//13TJ06FbGxsYiNjTUL78pZsCORq/NpaWl46aWXUFZWhgcffBDDhg3DbbfdpnrAd4QmWikua66OjljKcapzpVyFkzNyIoTg1KlTePPNN9HQ0IC5c+ciOzvb5rgZzsFkMsFgMJj9o92weJ43u85Qj9IyAfDHtrkmkwl6vV5xpqOmzTFsQy4PhT5Ip9NBq9WK7oj0fXuhJi6e59Hc3Iw9e/Zg69atGD58OG6++WaEhYUppr8rITcmxMbGYs6cOcjMzMS7776Lr7/+2mp7kApezmo/rhjTXOplYMlogs50V2+YwuwTLPPhhx9i69atZjMOHx8fPProo0hNTcXKlStx+vRpEEKQkJCApUuXIiIiwsWp9nxOnDiBp556ChEREfj3v/+N1NRUs7ai1tCQYTtKeZiUlIS1a9eiqKgICxYswIABA/DQQw/ZfICUrX2OnLp5zZo1+O6775CXl4d58+YhPT0dgO1byHs6Slq1jgy4to5FjjTI7BJGhWqQWr96Gp6c9o5QUFBgdjJaU1MT9Ho9Zs6cifDwcOzbtw9HjhwRD6RqaWlxdZI9FrpeVVZW4tNPP8WYMWPw3HPPefxOc12B0NBQ5Ofn45tvvsGjjz6KhoYG1fvoC3R2Btra2orm5mYcPnwY//3vfzF37lyMHTu2U3F6Gt7WB9sLlwgESjMW+lhK6dopw32ZNWsWhg8fDo1GA5PJhFWrVuGbb77B0qVLERQUhFOnTiEqKgrPPvssMjMzER4ezhqsjXRGS6X0DG1/wLAdR2kOlTSn9D1Lv9etW4dPP/0UEyZMwO23347+/fvbNX3ujqBhFpbQ5IzyBBsbqX2OtbKU5ju9mZMcthjCuwNupyFoaWlBfX09jEYjoqKi3OIcblsLzJ0K2BmkpqaiR48eaGhoQENDA8LCwsDzPH7++WdwHIfQ0FDExcVhxIgRSEtLY4JeJ+F5HrW1tWhoaEBERARCQkI6FA8rB8dhax9gy+CgVG4NDQ1oamrC6dOn8euvv2LevHm4+uqrbUpHV0DNOr+Pjw+ioqKg1+tRUVGBgIAAVScX0uXUEXsCd58IOe34Y3omYsmq8tChQ1i8eDEGDx6MHTt2IDY21hlJtAv29DP2RNasWYOtW7fizJkz4rWQkBCsXr0a2dnZ4pnk3po/HUFuRlJYWIhFixYhIiIC7733Hrp37w6dTqfY2UgPbhHiZJqBzkFrNoE/8lOwg5Jiy2BgyY1aLg6O47BhwwasW7cOM2fOxGeffYbk5GRxExtvgjaupcuCvp6bm4vt27fj22+/xfXXX4+bb74ZDzzwgNW4hbyk46S1NMIYoEZj4I79oNvVlNraWhw5cgR1dXXIzs5GQkKCq5OkCjbbAi5cuIDDhw+juroaHMchNjYWPXr0wIABA5CVldUljlZ1B5qamnDs2DGUlJSgf//+6Nmzp1t2Lt6EtP37+fmJR7SfP38eFRUVDrNQr6urw7lz53Dp0iVUVVUhKioK2dnZCA8Pt/u7ugqhoaHIzs5GcHAwjhw5guLiYqttiOd5lJaWori4GK2trRZ3LpTiTO+EzuAUgUBJYhL2XZeGFaQuT8hA2vWLniF4E7RWRPh2Pz8/PPPMM9i2bRvS09PFPdRtNbDydqQaJ8G+Rmg7Qp7S4ej96oG2WY10JiPXQcm1R4YytOaTnn1yHId+/frh448/xvXXX4+JEyfiueeeE+/JxSFFKE+e58HzvGIaTCYTtm3bJhoN7t27F5MmTQLQVu5Km1R1ReS+UxhP5Fxy1SzT0BqGuro63HvvvZg5c6YoQEg3nNLpdO3GOpPJZOaGLbRPd8Qtti6W3vc06M7YUwQZe8NxHOLj45GVlQUACAgIQEZGBnr06OE1HZKjMRgMOHPmDAoLC2E0GlFfX4/jx48jLi4OPXr0UF336HDS9sg0DfbB398fPXv2xPHjx3H27FlcvnzZ7u+orq7GhQsXUFZWBj8/P8TExDAbHRsJCwtDv3794OPjg2PHjiE2NhYxMTHtwhHSdhpkUVERiouL8euvv8JoNCI1NRV6vd7qe5QMRd2tvTl1yYDuhIA/JFgBYbYjSGTullkC1qRCd023IyGEYOHChdi7dy92796Njz/+WNxMStAQCbNbhnroulZaWorbbrsN8+bNQ01Njbj73PLly0UVJmCuZRPioGf+Sho7Tz721pnY4scu1dzQs0Rh0yIhjFxccrNb4f1ffPEFbrjhBpSWlmLv3r245ZZbZPtObylTW+23OI7D2LFj8fnnnyMyMhLjxo3Dxo0bzcpJaCt0GyovL8ddd92FefPmoaqqyuo7NRqN2Rgh1di5Ey7xMpDLiNraWhw7dgyXLl1Cbm6uuImGO+OOEp4rEBpPaGgoQkNDmTubnTEYDDh+/DgKCgpQVFSEpqYmDB48GD4+PjCZTPD19cV3331ntvkNx3FITk4WDTkt4Y7uT12FyMhIjBgxAmFhYfj222+RmJiI1NRUuxgfNzU14fLlyzAajaLxtZJ2gJWtPAEBAfDz84NWq0VpaSnq6+sBWD7gSa/XIykpCZGRkfjxxx+RmJiIrKws1SfxujMuOctAQFhH0Wq1+P333zFr1iwMHToUH330kegC4o6Drrulx5XIlY/crmkAs2rvKDU1Nbjvvvvw888/o7GxEb1798a7774rDgJ79+7F9OnTYTAYRPsCQgiWLFmChx56qF180vIQZqq27qbnrSi1f7o/ExgyZAi2bduGLVu24MYbb8T8+fPx1FNPiVocKVIfejpuYbYqLPnQGh3BjoROHxPM5ZEKwGq8b2itWlRUFF599VUAwK233oqkpCRs2bJFnAzRcdP9Iy2suesY4jZbFxuNRly5cgWtra0ICQkRd11zdcY1NzfjwIEDIIRg2LBh8Pf3d2l63BE1QgGjc9TX14uzF61Wi+DgYHH/gcTERIwaNUocNISBoLGxEdu3bxfLZ9CgQYiLi2sXNysr+yCXj3q9HjqdDikpKRgzZgw0Gg22b9+OzMxMUQtqyY5DTntTVlaGQ4cO4fLly7j++uuRlZXV7hlWppbpTP5oNBoEBwcjICAAw4cPh1arxd69e5GYmIjBgwdbdDl097JxyfHHtGRGI127cQeqqqrwwAMPgOd5fPbZZ0hMTLSo7rNlfbErIFRyS3swyFmzA0xjoAYhT+m8EuxshHo2ePBgbNq0qd0M5Mknn8TUqVMBtA1M77zzjvibpiuoOl2BNL+lXkb0zHD06NEYOXIk1qxZg2nTpmH58uVYvHixWJbCmrJcnwj8UUaEEBw+fBi33norJkyYgI0bN8LHx0fWm8Bb+iAllPoZNXmipF0RtAkajQZJSUl4/fXXcejQIUyZMgW5ubn48MMP4efnpziOuXt5OH3JQDo4VFdXY+/evbh06RKmTJmC/v37u8VA0drain379qG4uBjDhw9HcHAwAgMDrTYyqdrOG7D2zVL1mTfljb2Qc88V0Gg0okaNnoFkZ2dj5syZYvizZ89i/fr1IKTt8KnrrrsOUVFRTvoC70Nq5CksybS0tMBgMIj31PQnQJsxm9BX3njjjcjJyYGvr2+7LXpZ+2pDqqq3lC99+vTBnDlzYDKZsH79euTk5CAjI6NdOKlxro+PD7p3744pU6bAz88PH3zwAdLS0jBixAjPLAfiYo4ePUri4+PJ6NGjSWVlJeF5XrxnMpmIyWRySbqqq6vJ8OHDSUJCAjl27JhZugSE9LkqjZ4Gy6uOUVZWRgYOHEgAEACkf//+pLS0VDYsnccmk4nwPE94nictLS1kxowZRKPREI1GQ6Kjo8kPP/zgzM/oMqitx0Le0zz//PMEAFmxYoVNcZpMJvLdd9+RyMhIMn78eNLQ0CDGz9qUZYxGIzEaje2uS/Oc53ny97//nWg0GvLaa6+1C19RUUEGDx5MEhMTye+//24WD8/z5PPPPycBAQHklltuEd/naWXjFlZEgruhnFuas6Usnuexfft2nD59GqNHj4afnx+io6M7dHgFw74QN19/cwRERg1sDTmjMp1Oh+uvv160IdBoNPjvf/+Lb775BoQQREZGYsqUKZ06O8RbyseW/R6kDBo0SFyCfPHFF3HttdfadPgQrQYXlgm8Ic87gyWNs1TTBti+qyDtNurOmw6pwS0EAgFHr7dIC1kufoPBgHXr1uHgwYPYsWMHcnNzFePrig2RyKx/0r+l2LJXuiW7CzVhvLHzIzJ7CNCGg5bKSfB11+v10Gq1mD17NmbPng0AqKiowPjx4/Hjjz8CaFOZXnvttbICgVydUPIk8ZbysUUwoxkzZgzGjBmDxx9/HA8++CDWrl2L/v37W41PKuTJlTctADL+wNpSpoAzduq01p+6Go7YIgp1ELlMqK6uxvr161FVVSX6dd50001O3+9e6NyELUCPHj0KrVaLgIAAzJw502POUugsts5CO9IhKTUGJSFEDndtSI6CEILy8nLk5eWJp0f269cPu3btQvfu3S0+R5cHPavlOA5NTU3YuHEjSktLxXD09qq9evXCTTfdBJ1O1+5ZpfdZus/4A0IIvvrqK3z77beiy+fUqVPFXT6llJWVYf369airq4Ner0daWhqmTp1qdYc8Via28+WXX2L//v3iLJ8ul8rKSowbNw5lZWXYvXs3MjMzzZ4tLCzEli1b0NLSAqPRiJycHOTn53tWOdh1AUIBnueJ0Wg0W9s8ffo0SUtLI1lZWYrrofbA0tqcsPZjMpmI0Wgks2bNIgEBAeTzzz+XtRnoyghlpAY636RxGAwGxfyWW1OVKx8hHFsjbUNqQ6CmzSjZvChdP3bsGImPjycACMdxZPz48aSurq6dPYK3l4U9oMtg/vz5RKfTkU2bNsm2BbpscnNzSVVVlSobBtp+hKEeId/mzZsnlouAYEOQlJRETpw4YfYMXR4bNmwgWq2WLFy4ULxPj3/ujFOWDOg1Fvq3sIWjoKpxhKrLkppT+L1582Z89dVX6Nu3L0aPHo2+fft6hjRnR2hVtPDbUlil63IHPClZP9PvIsT60aHehJpyEMLJhVFy/6TLhVbzx8XF4amnnhL3Oqirq8Ojjz4qtssxY8Zg6tSpsmUopMHb2kxHUdr4RtoOqqursWbNGtTU1OCRRx5BQkICAgICVJWBGm0b4w+keWiLx4Y0zJAhQ7BmzRpUVFTgnnvuwfXXX48bbrihnTbVHduL0wQCuWvCdXr/AUeiJBTs378f69atw7vvvoubb77Z4elwR2ztQCyVqXTgUdO4hI6QCQNtKAkEgvua0WgUNyISUNvBCO1AKCedToeIiAjRvoAQgl27dmHVqlVobm4Gx7Wd13HjjTe67R7snorgkkgIEdXUQtnU1dXhgw8+gEajwd69e2U3lRKwRTBkyCMVCOhysaVfSk9PR3p6OtasWYPly5cjNjYW48eP9wih2aW9r3BEqzsceiPYEXibRC0MxAJKg7qANI+kf9MDmbXZPt1A5A4/8rayoKHX/YV8iI6OxvPPP4977rkHS5YswcqVK8UBW01HQx8mRkj7Y3tpBgwYgLfeegvvvvsu3nnnHURFReGWW27B1q1bzcJJZ1QM69D1+s4778Tbb7+N77//HnfccQeOHz8uthtaY6dk7Cbct6R1YFiHznOgfbkcO3ZMDKvUNwr3BPLy8vD+++8DaNvieO/evW7fXlw2CtMV3VUZZDQa0djYCI1Gg4CAAFXHWLojTU1NHf7X0NCAhoYGNDY2orm52eIgLDQEa2E6WunpBqnmXV0dqdYmMDAQN9xwA3Jzc7Fnzx588cUXolGa0rNy162p+YUlhKlTp2LGjBmYPn06wsLCsH37dhw/fhxNTU1m73XXzs0TGDRoEKZNm4aSkhJs2bIFly9fBsdxaGlpQUtLC3x9feHn5yc74Eu1cKwcOgedf3LlohahXNLT0zFjxgwAwIcffoiCggL7JtgBuMTtUDordSSWBI5Nmzbh/fffxzXXXIPNmzejX79+HtmoJk2a1Ok4OI5DWFgYVq1ahZ49eyoOFJb2iZDms1TtLd1K1NI6uVKZqV1b7ypYyge6DVlbk6TvKw0uSkKCRqPBhAkT0KtXLxw4cACTJk3CPffcg4kTJ7YTWrylXDqDpTwymUyorq7GsmXLcOnSJfzf//0f4uPjERER0a6M1a5Js7JRBz1W0O1LTb5Zy2NP0T67TCBQ89teFViusRBCcPr0aezZswc33HADxo4d67ENRvAl7wj0ABAREYHS0lKEh4crhrfFGEZOINBqtQgNDe3wElFdXR0MBoNTyioiIsLh71CDRqNBWFgYAMvnDiiVjS0dkdKgk5ycjOTkZHz33Xdim2HYj8DAQISHh6OpqQnl5eU4duwYampqMGjQIKSlpXnEYOLJWMvfzvQ3/v7+iIiIAM/zqKqqQlBQkLjVuLvhFIFAmtn0WqajbQfUqtro9TpPEwx27tzZqedpI6YXXngBJSUlNj+rdJ02FOQ4Dj169MDq1asRHR0tXlMDIQQtLS1YunQpvv32W6fYnHz//fcOf4c1CCEICQnByy+/DJ7nERsbi9raWjHPpXXWkjeNpU6PFgzl2oFQlkpLQp7WZlyJtIy0Wi3+/ve/44477sC6devwr3/9C3fffTfS09Pb7YOi1Eex/O8ccvWZ1mRKD96Ty3/y//fyEAwSBW6++WaMHj0a27Ztw4QJE7BkyRLk5eW5ZZk5XUPA8zwqKytRXl6OmJgYBAUFOeW0Nbow6+vrUV1dDaDt6Njg4GCHv9+RDB061C7xVFZWoqmpCaWlpVbD2qopEBqRn58fiouL0dzcrOpZWrvQ0tKCixcv2rSe1xXQ6XTi5ihqtkV1hEuTsKQQEhKCxMREGI1GFBcXIyIiAgEBAXZ9l7chLNfxPI+zZ8/i999/R0pKCgYMGGAmTMsZ87rjoOINWDIqlJZJfHw84uLisGPHDly8eBElJSUoLi5GdHQ0/Pz8nJFc1Tjd7fDKlSu4++67UVlZiRUrViAlJUVUzTqyktMF9cknn2DVqlWYMWMGdu7cidjYWIe915MICwvD2rVr0dLSYjUsbfRHGwPKQdsOFBUVYcGCBaiqqjKLSwk6Xp1Oh8ceewwrVqxQ8zldAkGQoo/BpQXojswUra09yy2vcVyb6+GcOXOQn5+PN954A3l5eXjmmWeQn59v83d5M9L8NRqNWLlyJXbs2IHi4mL4+/uLM1S67cjZ59BI7UCkRyszOo6ctkzqJs1xnEXD9HvvvRczZ87ESy+9hJdffhmrV6/GqFGj3KpsXKIhOHv2LGpqapCQkICUlBSnuxzW1NTg1KlTAICMjAy3KhBXotVqkZKSojq82k1paEFPp9OJ57fTZ7jTBnLSRibEodPpkJiYKHssqbdhrzprazxRUVGIiopCa2srTp48iStXrtglHd4MIQQXL14U+yR/f39Vz7B+y7l0Nr9jYmIQFRWFhoYGnDp1StwEzJ1wmg2BMJOkkVqtK81M7IHUv11u4xzGH6jxCJBa5FoSDGjjtI0bN6o+EUx4v5CeqKioDnyN5yNtO2qNzOg21NF6bsk+wVneQl0NqTBNtyW6bG2ZLEnbH22rxVAHXS4CSvYF1vJWqRyltjjuhEtPO1TrzmGvTKuoqEBhYSGam5uRk5ODhIQEhwohno6t+WBNGADa1P7CoTxqDOIY7gNdXikpKcjJyUFNTQ2+//57ZGRkiJ4QDPWcPXsWFy9eREREBIYNGwaTyYTAwECcPXsWWq0WvXv3VqUxYNgXYbDu1asXLl26hIsXL4LneXFWb0s/5Ul9mlPER6kmwJbnlKzYBSxtYCO999VXXyE/Px8lJSXYsWMHbrnlFqvrct6KdM3S2uy/M+VL/2aYI6xVdrRuOipP58+fj48//hg//fQTJk+ejJ9++snsvlybJKRtG9iuplWg+xlL/RCN0GbWrFmDyZMnY/jw4fjss8+wc+dOvPHGG3jllVcwZ84cFBcXy77Tmt0Ow3aEMhL6Mq1Wi8WLF2PDhg3YvHkzbrrpJhQWFpqFlWLNG8HdcamGoKOoMYQSrhFCUFFRgSNHjqC4uBhDhw5FRkaGRV97BsNdEOqwtY7FEQa50mUg+u+AgAD4+PjAYDCguroaBoPBanye2EE6AkIICgoKcObMGeh0OgwdOhRJSUlin8TzPBobG3HlyhXVS2sM+yCtn4GBgdDpdOjXrx+am5vxww8/dPod7izMecQCE722Jud/TWeuMKOinz1y5AhmzJiBn376CR999BHmzp2r+B65dTwGw9nQa5nW7AekftL2er9cW6IRLNmVZsBynZ7QGXrKzm1qkNoBAPKaS/r3Bx98gGnTpiEpKQkffPABhg8fbhaf3N8MxyLnSQAAfn5+eOKJJ/D2228jLS2tXVgaue3X5d7jrkKBR2kILBlIWVqP5nkezc3N4Hke/v7+ivsedJUOitH1oOt3eHi4ePLgxx9/jOTkZIwYMcKuBmTSAU6aBqCt8xs5ciR0Oh1OnjyJhoYGXH311VYNP92xI7QVa1obeqmT7pNOnDiBn376CSaTCZMmTUJmZiYCAgJEjxuO4+Dj44Nx48ahtLQUX3/9NQoKCnDttdcyWwInQ9d/Hx8fhIaGYty4caitrVW1d43Scrc7Tzg9SiAQkBMGaD9t4I+1HFs6SeEZ5rfLcDVK1vxarRaJiYlYs2YN9u/fj8mTJ2P06NEYNmwYdDqdXZcOaMGZ9rumhfK7774bs2fPxvTp03Ho0CF89tlnZgJBV9xZT+0e93LfvHv3bjz00EN44okn8NZbb4knfNIaoaCgICxbtgznz59Hfn4+dDoddu/ezQQCJ2Cp/QQHB2Pp0qUghKg+CE9prHJXDZnHCARSzYCcilL6+9KlS9i+fTuqq6tx1113YdCgQe3c6KRxe3JHxei6yNVvjuNgNBrbrTM7og4rqTk5jmsnhHd1bxFbv43OD8GwEoDFiYdWqxUFPOnSDcMxKM3mgT/KXKfTmd2zlyeWu+AxAgHQfo2HLiy5PQ7Onz+PxYsXY9iwYdi2bZu4TaTcRjgdtZRnMByJ1PJZgJ5l2HtPDUt7TsiFpcNLBQJL6XGEIaQzUMoHpXBSeyd6bw1BmJLzuPG0fPFkhDFBbkdIQdMsLTs1AoHc/ixCXO5Yvh4lENAorW8CbfsNvPfee6irq8P999+Pnj17mkl2gPxueNK43LHAGN6FGuMlwL72L7bUe45r26715ptvxtChQ7Fnzx4cOHAAs2fPFg+wUvMeT9QqqHUv4zgOv/zyCz799FMQQrBs2TKMHj1aVuNiTfPJcAzSwV7OUFQaXm28wB/1W6PRYPLkyUhPT8f333+PX375BbNmzWp3iJWr8AiBQI2boaCG02g0KC8vx/PPP4+kpCTs3LnTbMMUa4O9dMbDYLgjtApfyXjJHsi1BWl7FASC6upq5Ofno7i4GHl5eYiKilLdhjxdILCEIBAsX74cixYtwgsvvGA1TunfTHvpeATvF1pToFS2ctowNeE4jsPkyZMxfvx4TJ48GQcPHsSYMWOYQGAL1hocbTxozdaAfkauk/O0TonR9VHqdFJTU/GPf/wD1dXVePTRRzF48GDMmjXLrnXY0uyIThe9lSutYpVCG+RJ1aiehtLSAZ0/v/76K9avXw8/Pz+88MILuOqqq1TFKRX0mA2BcxBsYqwtj9niNig3rrjrWON5rVCCVG0nPYWK53nFzT3oPQ2kmgF3LCyG92Fp5h8XF4f58+cjKysLa9euxb59+5w2cCitmQtW83KGcI7YL8GVKC1b0v3JmTNnsGbNGhQXF2PBggUYNWpUu/BKcdKDk9TuieEYrPX9agw8Pbl+e4SGAFA+WY/+u7KyEs888wzq6uqwbNkyNDU14ZFHHkGfPn2waNEiM1cRWpUj/JYzAmIwXIlQB6WHTSmFc0SdVTIylBIYGIglS5aguLgYr7zyCoKDg7F48WJERkaKz9BGdEqGXJ6GktZDuCb3tzUIIQgPD8fTTz+NixcvYvny5UhMTMTDDz+MgIAA+ySc0Q65cYYe4JX2sKFRo9E2Go1uKTh4jIZAzWYODQ0N+PTTT/HVV1/hmmuuwcCBA7Fr1y58++23opZASaqXsyxlMNwFV9VJW96r1+uRl5eHvLw8/O9//8Mnn3yChoYGszBy7c/T25og2Ei/q7W1FSaTCb6+vu381i19s5AnAQEBmDhxIkaPHo1du3Zh165daG1tddh3MNrXRzkjz44Ir5bidCc8QkNAS2zWDDg4jkNxcTHmzp2LlJQUrFmzBomJifDx8VEMLz2O0hPXMxldG61W65KOhF7LVjriWmrjIDV2lK6r09oMT9cOADDbWAho+67jx49j+fLliI6Oxvvvv4/k5GSLnk000rwT8p31S45FTjNgzRbNWlxyaDQaVZoGV+ARAgEgb+VMCEFDQ4M4+6+trYXJZEJTUxN++uknaLVaXHPNNQgLC1OlxvH0jsldaGxsVHXYDY1er2eqUAnSOqlUP/V6PUJDQ6HT6VBbWwt/f39xzw1Hpke4JoXjOAQHB8NgMKCurg51dXUIDg5WFAo8vd3R6RcOJiotLcXBgwcxcuRI5OXlyU5IaJSWZSwtRzDsj5IA0FGtgNJSnruWp0cIBHIqfgBoaWnB448/Lp5A1dLSgqKiIiQmJuKll15CamqqWUekBJO8O4bcGrDRaMTTTz+NL774wibp+rrrrsOSJUsUy4LeIMQbsGVGkpOTg08++QSHDh3ChAkTMH36dNx3333twtHbedvaGdHr/9LrUiIjI7F27VpcuHABjzzyCOLi4vDCCy8gJCRE8RlXo6T9UIvw3Pnz53H//fcjNDQU69evR/fu3Ts1G6RtSGjjaDXu0+6Yz65EKc/kyr6jeecI+zOlOB1Rxm4lEKj5wNraWtTX10Oj0aC5uRnnz5/HmTNnUFlZCZPJhMjISKSkpGDo0KHo3r27k1LOAICamhpcuXIFv/zyCw4cOIDIyEj4+vqaDW5GoxGVlZXtPD969Ojh1mtr7kx4eDhyc3Nx8uRJHDhwAEOHDlWczTsjj318fDBw4EAEBATg+PHjKC8vh9FoBCA/U+oKGI1GVFVVoaSkBOfOnUNaWhqGDBmC4OBgVf7qSnBc22FHsbGxCA0NxeXLl8HzPMLDwy0KA95IdXU1WlpaEB4eDl9fX1XPeGteKeEWAoHUjcNSw1m9ejU+/PBDcFzbDmkPP/ww/va3v2H+/PmoqqrCunXr0Lt3b6u7pDE6jzBrFDQFQtkUFxfD19cXTz/9NHJzc8XwghvW/PnzcfnyZZt8eul1VeF3V4aeGRJCFGeZPM9bXV+mXXHpeOjO0NH56e4dr2AHIDdbFPLP0tkDly5dwl133QUAWLt2LeLj48UlMDnvJak1u6X8T0pKwsaNG/Hbb79h9uzZuOqqq/DSSy+1G/Q6owHydIxGI5588kkcOHAAL7/8MoYMGSLek/PQcZRthtyyGP1OermbDi/cl16T++3IPtDpAoFWq0VKSgoqKipQXFwMnU6HuLi4dlsLA20fXlNTg/LycvF3VVWVONvQ6XRISkpCZmYm/P39odfrkZ6ejp49ezr1m7wdoYLzPA+j0Yhu3brB398fmZmZyMrKAtDWYC9cuCB2qsIzgYGBSEhIQLdu3Vz8FZ5PaGgoevXqJR5HHB0djcjIyE7NUL0JuqO2BcGjoKCgAAEBAUhLS0O3bt0U7Stsxc/PDxkZGaiqqkJBQYFYptI0KKXNG8qdEIKioiL89ttv7TxbXIWj8t6RZep0gSAsLAyvvvoqzp8/jwceeAAhISF47733EB0dLfuRW7ZswcqVK8Xff/vb37B3714AbRJfVFSU21QAb4bjOCxatAh33nkngD/KRqCsrAxz587FiRMnUFZWJkrtQ4cOxWuvvYbIyEhx9iq17lU7k+pK0N9Nz2CkGgF6xv/nP/8ZgwYNwqZNm3DdddfhkUcewYIFC6DVattpBmh7DEcKDMIOhtLDYNxV0yM3W6R3YZSDXoqRHjhlybBM7poljZCSwELPdmmNndI7uypyh30Bf5QprSmgbWLsUfelmiWlspbeF8rKksGuM+1EnC4QaDQadOvWDU1NTaioqEB9fb2ZD29NTQ1Onz4tZlJlZSViYmLE5+Pj45GUlGQWJxMI3IOIiAhERESYXeN5HidPnsSZM2dw9uxZlJaWAgCCgoKQlpaG/v37Izk5WfWan7ehds2dEIKgoCAEBQVBp9OhqKgINTU1Tkhh53DVDNZe7zUajThx4gQuXLiA9PR0hIaGtttzQA5rbodyBAYGIjs7GzExMTh8+DC6deuG1NRUxecsqa+7GhzHIS0tDeXl5SgpKcGRI0eQkZHhtp5LHMchPT0dDQ0NKCoqQkBAADIzM1V5BzmyLF1iQyDMcOT8Mb///nvcfvvt4gYct99+O3bu3ClKV/7+/uB53mydTLoOQ0uJ3mKV7kro2RFgnudXrlzBokWL8OOPP6Kurk4sq8zMTGzatAnR0dHthAE1M6mujpI2RK4+C8s19EyWtkGQs6CWi9sRs3Zh/V0uDbRtg7u0U7l1fuG3AP13TU0NFi5ciOrqarz11lvo2bNnOxdLJTsPS+UrZ9/Ru3dvbN26FV999RWmTZuGCRMm4NVXX23Xj1qygeiq6HQ6LF26FNXV1Zg3bx5Onz6NLVu2oF+/fmIYa7t8dgZb6i8hBHq9HitWrEBlZSXuuusuXLp0CVu3bkWvXr2s2vZ0OYEAaFsXGz58OAwGAw4ePIigoCAAwNmzZzFgwADRCj0tLc1M9Uw3Up7n8csvv6CkpAR9+vRBRkYGAgMDxbDeNoi4CqHzpFWaJpMJP//8M86dO4cLFy6Is9WgoCBcddVVGDhwIGJiYsRyByyfaumuKmZHo/a76fvJyckYO3YsCCHYt28fevfujYSEBItCgD1paWnB4cOHUVRUhEGDBiEuLq6dHz79Xe5Wpmq0MjzP48iRIygpKUF8fDwSExMRExOD0NDQds925vvoZ3U6HcLDw+Hn54eqqipcuXLF7NwWQFkb4G557AgE7Vjfvn3h4+ODX375BbW1tbjqqqucpilQ4yIo/B8SEgKe51FfX4/q6up22xm7pMyIC+B5nhgMBtLQ0EAOHTpEUlJSSEBAAAkMDCQ33ngjqa6uJg0NDaShoYG0trYSQggxmUzEZDKZxdPQ0EDy8/NJfHw8+eabb0hjYyMxGo3twjGcT2NjI8nPzyf+/v5Eo9EQAAQAyczMJCdPniSNjY1m4YXyVSo7nue9smyNRiMxGAztrlvKr9bWVtLQ0EAef/xxEhQURN544412YU0mEzEajcRoNNo9zSUlJeSqq64iWVlZ5NSpU6Spqald2o1GI+F53u7vVos035TyUqm+1dXVkXHjxpGkpCRy8OBBVX2PtTpuKR00O3fuJP7+/mTmzJmkubnZLB/p59XE1RVpbm4m58+fJzk5OaRPnz6ksLCQmEwmwvO8Yl7ZA0v1Wu49PM+TiooKkpubS+Li4sjRo0dVtUdHlqlTNQSEkp60Wi38/f0RExOD6667TpxBDho0SJT0aOQMakwmE1paWtDc3AxfX1/4+/u7vXtTV4VIZiYajQY5OTmiZC6US0JCAiIjIxXXyqTLP5aMdLwBa+vActf0er3YfpqamkSvHOlzjspTk8mE5uZmGI1G+Pr6ypa1qzUDat8tZ/R38OBBnDt3Dj169EBCQgJiY2PN+h5iRaujVJbCM9b6sG7dumHKlCmIjY3Ftm3bkJGR0e5YZW/uB319fREWFoarr74aZWVl+Prrr1FYWIhRo0bB39/foe+Wa69EwWhXaSnPUv0QsFbHOorTBQIisaBNTk7Gq6++amb5rLTWZenjhecZzofOe6FC6/V6PPbYY7IW13JGV0oWtvR9el3XWwQE2kJaaB+AurYgzSfpM85Yt1dKhyWVqitQ825CCAwGA1588UXs27cPmzdvxpgxY0ThS4hDWO6krdjp+9b6KWttoX///nj77bexadMmzJ07F/PmzWsnEMg939XbDP2dISEhWLVqFc6ePYv8/Hz4+/tj586dZksH9swP4d3WdlqVtgO6PUs9cpTshaR/2/M7nCoQKK0H22rsYjKZsG/fPpw6dQoDBw7EgAEDzDwRGM5FbhYvZ0SldgYl3GMC3h/Y2ug5jsPQoUPx17/+FRUVFXj99dcxduxYsz06vGWgsIRcHtCdt5wgw/M8eJ6HVquFTqezOvOTIhVs6bpuTRgA/nCD1Gg0aG1tNdv101l2Iu6ItM/QarUICwvD9OnTUVdXhy1btiApKQnjx49X5QnSEaQTI+Ga2jZmLaxUA+HRGgLhA6Q+skoStFLlNhgMePPNN7Fjxw785z//wdixY8X73uh/6w5YmolKr9M7U0o7YvqatxpHAeryQqmdCL8nTpyI/Px83HvvvVi+fDk2bNhgJhAI5WBv63PprEdpwKOvuVO5EtJ2Xr10/wbA/JuU0q3kTUB/L+1hYW1ZiI6DRsnvXs2zXRVp/YqMjMTy5ctx7Ngx3HDDDUhLS8PVV19td4FAqONSjxo1uyEKdUCqGZBqXtX0sZ3FZW6Hln4rXROuS7fmZGvNnoW1snK1+tgdUPP9asOMHz8ekZGROHHihLjJV0BAAKZPn474+PhOp7WjuEMZK/U9Wq1W7Ji/+OILfPfdd+L9vn37YsCAAUhOTu7w++T6MVuEPpPJhKysLDz22GPQarVYuXIlxowZg1GjRtmcpq6MkG+xsbFYsGABGhsbsXr1avTq1QuTJ0+W3SG3M+9S0irR0GXq5+eHOXPmoKioCP/5z3/wzTffYNasWQgJCXFN/6jK9NDNMBgMZMaMGSQoKIjs3r3bpdbKDCJa8NorLjlrbW+1mBZQY6Gu9BwhbW1m+vTpordHZGQkOXTokCOSSoqKikifPn1IRkYGuXDhgkPe0VlsycdHHnlEzDe9Xk+2bt1qVh62lostYaWW8cLztDX76tWrCQCybNmydp4k3thmLHmMfP311yQkJITk5+eTpqYml+SPXH9ZVFRE+vbtSzIzM8n58+ctPutIbyu3ONzIFrZu3YoDBw4gKysLI0aMQEZGhlvMNLwZe6p85WZOwnVvpqPfLzyn0Whw6623YujQoQDaZpjbtm3D7t27MW/evA4dBkYkM5Wmpia8+eabKC4uxuzZsxETE4OwsDAxrKXvkMblaAix/TjtadOmYfjw4ejXr1+n0mrLs0oaDLrNjRw5Ev/85z9RV1eHhx9+GJMmTcKIESNULcF2RSxpl1NTU7Fy5UrU1tbi8ccfN9vHgV7amjJlCoYNG+aU9BErSz7SZx15eJXHCQT79u3Dm2++iffffx/Tp0/3mkruTbAytS/CYJufn4/8/HwAQFVVFcaPH4/Lly9jypQpZpt/qcl/uU6spaUFGzduRFFREfbs2YM+ffrY7yPsjNpOWBAcOI5DXl6eeFaHs1BjU5CdnY3s7GwsX74cL774IpKSkjBs2DCvPPXQGgkJCVi4cCE+++wzTJs2DS0tLQDa53PPnj2Rk5Njds+eedkZ+yhHlqlHCQRCwzSZTHjttdfw5Zdf4r777nPrjofBcAXEgl0NIQQBAQF4/PHHUVJSgpdffhlNTU0A2jrCBx980Mw9ix48LcUpN8jKbWctfcYVRoVqNAM7duzA5s2b8eOPP7rdwCqX/xMnTkRSUhJOnz6Nv/71r5g7dy5ycnK8TjCQ04gQQsyMaAcMGIDXX3/dzMCZDltQUIA777wTHMfBx8cHCxYsEE9utTe0AaIt2gJH4DECAc/zMBgMoovN/v378fPPP2PatGlMIGAwZJBTwwsdjq+vLyZMmICioiK88sorOHXqFACIrorCMdVy5wzIxWkwGNDS0gKdTge9Xm+2f4KlwchV6mw17zt27Bjeeecd6HQ6+Pr6uvVZAIQQDBw4ENnZ2XjwwQfx4Ycf4k9/+pO4RMT446RQrVaLxMREzJ49WzHswoULsXHjRnAch4CAANx4443o1asXgLa642i3RVfhMQLBBx98gM2bN+Po0aPgOA4PPPAARo8ejX79+jGrdAZDBjkXP+n1qKgovPDCC6ivr4dGo0F5eTkWLlwoHi42evRo3H///WauUPQgz3EcDAYDnnnmGfz666+YM2cOkpKSkJCQAKC9SyPtlkW7Y8kdgOQslPJG+HvWrFmYMmUK+vfv7/S0WUNuALn99ttxzTXXIDs7WxTsvAnaFoC+JrdnBCAvOM+dOxd//vOfAbQJEp988gnWrl0LjuMQHR2NpUuXinW8K+ExAsGJEyfw6aefws/PDyEhIRg2bBjy8/PN1CzeVvEZDCWUhAEp/v7++NOf/gSgbfD+4YcfsGLFCtTU1IAQgtDQUFy5ckUc2PV6PXx8fMQ4m5ubUV9fj/379+PHH3/EAw88gNzcXMV0SNuq0Hm7wqhQmj4l+vTpgwkTJjg6SZ2C/o5+/fqZnfLnjciVq6Wd/6T1b8CAARgwYACANmPZ9evX48svv4RGo0F8fDwqKirEg6y0Wq3bHrNsKx4hEAiFptFocPfdd+Mvf/kLevfuLV4TwjBNAYPRHqm1uRR6Fi8cS20wGAAAR48exaRJk8Rn58yZg9tuu00cyP/5z39iz549+O2338zil27fK6RDTu2udN1ZyC2BCBsHuTtyApclv3dvQs13K5U9bXfh6+uLlStXorq6GhzHoaamBk8++SQqKysBtO1L8eyzz5qd3NpRaC2ZNP3SsnXEeOdRAgHHcUhPT8eoUaO8rnIzGM4gODjYzN2qoqICBQUF4gB/4cIFlJWVica9Z8+eRUFBAfR6PUJDQx22tupslIwkGZ6PrQOpRqMx07gUFxfjmWeeQWFhIYA2geHSpUsIDg7ucJoqKythNBqtpktIu6PqpkcIBAJCZliyWPY2i1oGQwnpbEdNZwOYr8Fec8012Llzpxjm448/xvjx48Xf8+bNw6JFiwC0aQNSUlLEv+UGVekMh+d5WcNFR2NpFi2sNzs7TbZA21woeWt0ZK+FroSlmbZUKKCXsKR1V1pXYmNj8e9//1t0WSwsLMStt96KxsbGDqe1tbUV58+fR3JyssUtyWkbHkeMcx4jEFhzyWDSPINhjtCp2TojottSWFgYwsPDxd+ffPIJfv31V3EgDwoKEtdapdCDlfBbLn3uiidNLOTUye6ct+6AJWHVEnq9HhkZGeLvxsZGnDhxAjU1NfZMniKOXBrnCKs1DAaDwWB4Pd6pS2IwGAwGg2EGEwgYDAaDwWAwgYDBYDAYDAYTCBgMBoPBYIAJBAwGg8FgMMAEAgaDwWAwGGACAYPBYDAYDDCBgMFgMBgMBphAwGAwGAwGA8D/AzPZIQFbM+H0AAAAAElFTkSuQmCC", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "gt_pseudo_label in the 3001st data example (a list of ground truth pseudo-labels): ['4', '/', '6', '*', '5']\n", - "Y in the 3001st data example (the computed result): 3.333333333333333\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "X_1000, gt_pseudo_label_1000, Y_1000 = train_X[1000], train_gt_pseudo_label[1000], train_Y[1000]\n", "print(f\"X in the 1001st data example (a list of images):\")\n", @@ -207,7 +148,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -236,18 +177,9 @@ }, { "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Predicted class index for a batch of 32 instances: ndarray with shape (32,)\n", - "Predicted class probabilities for a batch of 32 instances: ndarray with shape (32, 13)\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "data_instances = [torch.randn(1, 45, 45).to(device) for _ in range(32)]\n", "pred_idx = base_model.predict(X=data_instances)\n", @@ -267,7 +199,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -283,21 +215,9 @@ }, { "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Predicted class labels for the 100 data examples: a list of length 2, \n", - "the first element is a ndarray of shape (3,), and the second element is a ndarray of shape (5,).\n", - "\n", - "Predicted class probabilities for the 100 data examples: a list of length 2, \n", - "the first element is a ndarray of shape (3, 13), and the second element is a ndarray of shape (5, 13).\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "from abl.data.structures import ListData\n", "# ListData is a data structure provided by ABL-Package that can be used to organize data examples\n", @@ -335,7 +255,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -371,17 +291,9 @@ }, { "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Reasoning result of pseudo-labels ['1', '-', '2', '*', '5'] is -9.\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "pseudo_labels = [\"1\", \"-\", \"2\", \"*\", \"5\"]\n", "reasoning_result = kb.logic_forward(pseudo_labels)\n", @@ -406,7 +318,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -439,7 +351,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -458,7 +370,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -504,7 +416,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.18" }, "orig_nbformat": 4, "vscode": { diff --git a/examples/hwf/main.py b/examples/hwf/main.py index b20c290..0f9dfe7 100644 --- a/examples/hwf/main.py +++ b/examples/hwf/main.py @@ -1,4 +1,3 @@ -import os import os.path as osp import argparse @@ -6,14 +5,15 @@ import numpy as np import torch from torch import nn -from examples.hwf.datasets import get_dataset -from examples.models.nn import SymbolNet from abl.learning import ABLModel, BasicNN from abl.reasoning import KBBase, GroundKB, Reasoner from abl.data.evaluation import ReasoningMetric, SymbolAccuracy from abl.utils import ABLLogger, print_log from abl.bridge import SimpleBridge +from datasets import get_dataset +from models.nn import SymbolNet + class HwfKB(KBBase): def __init__( diff --git a/examples/hwf/models/nn.py b/examples/hwf/models/nn.py new file mode 100644 index 0000000..2b6da47 --- /dev/null +++ b/examples/hwf/models/nn.py @@ -0,0 +1,46 @@ +# coding: utf-8 +# ================================================================# +# Copyright (C) 2021 Freecss All rights reserved. +# +# File Name :lenet5.py +# Author :freecss +# Email :karlfreecss@gmail.com +# Created Date :2021/03/03 +# Description : +# +# ================================================================# + + +import torch +from torch import nn + + +class SymbolNet(nn.Module): + def __init__(self, num_classes=4, image_size=(28, 28, 1)): + super(SymbolNet, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(1, 32, 5, stride=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.BatchNorm2d(32, momentum=0.99, eps=0.001), + ) + self.conv2 = nn.Sequential( + nn.Conv2d(32, 64, 5, padding=2, stride=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.BatchNorm2d(64, momentum=0.99, eps=0.001), + ) + + num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1) + self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) + self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) + self.fc3 = nn.Sequential(nn.Linear(84, num_classes)) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + return x diff --git a/examples/mnist_add/main.py b/examples/mnist_add/main.py index 53f4f05..4322254 100644 --- a/examples/mnist_add/main.py +++ b/examples/mnist_add/main.py @@ -10,8 +10,9 @@ from abl.data.evaluation import ReasoningMetric, SymbolAccuracy from abl.learning import ABLModel, BasicNN from abl.reasoning import GroundKB, KBBase, PrologKB, Reasoner from abl.utils import ABLLogger, print_log -from examples.mnist_add.datasets import get_dataset -from examples.models.nn import LeNet5 + +from datasets import get_dataset +from models.nn import LeNet5 class AddKB(KBBase): diff --git a/examples/mnist_add/mnist_add.ipynb b/examples/mnist_add/mnist_add.ipynb index 27ec738..c0af396 100644 --- a/examples/mnist_add/mnist_add.ipynb +++ b/examples/mnist_add/mnist_add.ipynb @@ -25,13 +25,14 @@ "\n", "from torch.optim import RMSprop, lr_scheduler\n", "\n", - "from examples.mnist_add.datasets import get_dataset\n", - "from examples.models.nn import LeNet5\n", + "from abl.bridge import SimpleBridge\n", + "from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", "from abl.learning import ABLModel, BasicNN\n", "from abl.reasoning import KBBase, Reasoner\n", - "from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", "from abl.utils import ABLLogger, print_log\n", - "from abl.bridge import SimpleBridge" + "\n", + "from datasets import get_dataset\n", + "from models.nn import LeNet5" ] }, { @@ -425,7 +426,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ diff --git a/examples/mnist_add/models/nn.py b/examples/mnist_add/models/nn.py new file mode 100644 index 0000000..5fb7f3e --- /dev/null +++ b/examples/mnist_add/models/nn.py @@ -0,0 +1,94 @@ +# coding: utf-8 +# ================================================================# +# Copyright (C) 2021 Freecss All rights reserved. +# +# File Name :lenet5.py +# Author :freecss +# Email :karlfreecss@gmail.com +# Created Date :2021/03/03 +# Description : +# +# ================================================================# + + +import numpy as np +import torch +from torch import nn + + +class LeNet5(nn.Module): + def __init__(self, num_classes=10, image_size=(28, 28)): + super(LeNet5, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(1, 6, 3, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + ) + self.conv2 = nn.Sequential( + nn.Conv2d(6, 16, 3), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2) + ) + self.conv3 = nn.Sequential(nn.Conv2d(16, 16, 3), nn.ReLU()) + + feature_map_size = (np.array(image_size) // 2 - 2) // 2 - 2 + num_features = 16 * feature_map_size[0] * feature_map_size[1] + + self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) + self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) + self.fc3 = nn.Linear(84, num_classes) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + return x + + +class SymbolNet(nn.Module): + def __init__(self, num_classes=4, image_size=(28, 28, 1)): + super(SymbolNet, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(1, 32, 5, stride=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.BatchNorm2d(32, momentum=0.99, eps=0.001), + ) + self.conv2 = nn.Sequential( + nn.Conv2d(32, 64, 5, padding=2, stride=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.BatchNorm2d(64, momentum=0.99, eps=0.001), + ) + + num_features = 64 * (image_size[0] // 4 - 1) * (image_size[1] // 4 - 1) + self.fc1 = nn.Sequential(nn.Linear(num_features, 120), nn.ReLU()) + self.fc2 = nn.Sequential(nn.Linear(120, 84), nn.ReLU()) + self.fc3 = nn.Sequential(nn.Linear(84, num_classes)) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + return x + + +class SymbolNetAutoencoder(nn.Module): + def __init__(self, num_classes=4, image_size=(28, 28, 1)): + super(SymbolNetAutoencoder, self).__init__() + self.base_model = SymbolNet(num_classes, image_size) + self.softmax = nn.Softmax(dim=1) + self.fc1 = nn.Sequential(nn.Linear(num_classes, 100), nn.ReLU()) + self.fc2 = nn.Sequential(nn.Linear(100, image_size[0] * image_size[1]), nn.ReLU()) + + def forward(self, x): + x = self.base_model(x) + # x = self.softmax(x) + x = self.fc1(x) + x = self.fc2(x) + return x