From dab27d746cfa6b16a70db13e4afcdf45ce2a0b79 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Fri, 31 Mar 2023 16:36:13 +0800 Subject: [PATCH] [FIX] run examples/hwf/hwf_example.ipynb after reformat examples --- examples/hwf/datasets/get_hwf.py | 6 +++--- examples/hwf/hwf_example.ipynb | 10 +++++----- examples/mnist_add/datasets/get_mnist_add.py | 8 +++----- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/examples/hwf/datasets/get_hwf.py b/examples/hwf/datasets/get_hwf.py index 87da5cf..299a283 100644 --- a/examples/hwf/datasets/get_hwf.py +++ b/examples/hwf/datasets/get_hwf.py @@ -12,7 +12,7 @@ def get_data(file, get_pseudo_label): if get_pseudo_label: Z = [] Y = [] - img_dir = './datasets/hwf/data/Handwritten_Math_Symbols/' + img_dir = './datasets/data/Handwritten_Math_Symbols/' with open(file) as f: data = json.load(f) for idx in range(len(data)): @@ -36,9 +36,9 @@ def get_data(file, get_pseudo_label): def get_hwf(train = True, get_pseudo_label = False): if(train): - file = './datasets/hwf/data/expr_train.json' + file = './datasets/data/expr_train.json' else: - file = './datasets/hwf/data/expr_test.json' + file = './datasets/data/expr_test.json' return get_data(file, get_pseudo_label) diff --git a/examples/hwf/hwf_example.ipynb b/examples/hwf/hwf_example.ipynb index 5881de9..98f46a6 100644 --- a/examples/hwf/hwf_example.ipynb +++ b/examples/hwf/hwf_example.ipynb @@ -8,7 +8,7 @@ "source": [ "import sys\n", "\n", - "sys.path.append(\"../\")\n", + "sys.path.append(\"../../\")\n", "\n", "import torch.nn as nn\n", "import torch\n", @@ -21,8 +21,8 @@ "from abl.models.wabl_models import WABLBasicModel\n", "\n", "from models.nn import SymbolNet\n", - "from datasets.hwf.get_hwf import get_hwf\n", - "from abl import framework_hed" + "from datasets.get_hwf import get_hwf\n", + "from abl import framework" ] }, { @@ -150,7 +150,7 @@ "outputs": [], "source": [ "# Train model\n", - "framework_hed.train(\n", + "framework.train(\n", " model, abducer, train_data, test_data, loop_num=15, sample_num=5000, verbose=1\n", ")\n", "\n", @@ -175,7 +175,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.16" }, "orig_nbformat": 4 }, diff --git a/examples/mnist_add/datasets/get_mnist_add.py b/examples/mnist_add/datasets/get_mnist_add.py index d00c6ce..46b5f12 100644 --- a/examples/mnist_add/datasets/get_mnist_add.py +++ b/examples/mnist_add/datasets/get_mnist_add.py @@ -1,6 +1,4 @@ -import torch import torchvision -from torch.utils.data import Dataset from torchvision.transforms import transforms def get_data(file, img_dataset, get_pseudo_label): @@ -23,12 +21,12 @@ def get_data(file, img_dataset, get_pseudo_label): def get_mnist_add(train = True, get_pseudo_label = False): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))]) - img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=train, download=True, transform=transform) + img_dataset = torchvision.datasets.MNIST(root='./datasets/', train=train, download=True, transform=transform) if train: - file = './datasets/mnist_add/train_data.txt' + file = './datasets/train_data.txt' else: - file = './datasets/mnist_add/test_data.txt' + file = './datasets/test_data.txt' return get_data(file, img_dataset, get_pseudo_label)