Browse Source

[FIX] run examples/hwf/hwf_example.ipynb after reformat examples

pull/3/head
Gao Enhao 2 years ago
parent
commit
dab27d746c
3 changed files with 11 additions and 13 deletions
  1. +3
    -3
      examples/hwf/datasets/get_hwf.py
  2. +5
    -5
      examples/hwf/hwf_example.ipynb
  3. +3
    -5
      examples/mnist_add/datasets/get_mnist_add.py

+ 3
- 3
examples/hwf/datasets/get_hwf.py View File

@@ -12,7 +12,7 @@ def get_data(file, get_pseudo_label):
if get_pseudo_label:
Z = []
Y = []
img_dir = './datasets/hwf/data/Handwritten_Math_Symbols/'
img_dir = './datasets/data/Handwritten_Math_Symbols/'
with open(file) as f:
data = json.load(f)
for idx in range(len(data)):
@@ -36,9 +36,9 @@ def get_data(file, get_pseudo_label):

def get_hwf(train = True, get_pseudo_label = False):
if(train):
file = './datasets/hwf/data/expr_train.json'
file = './datasets/data/expr_train.json'
else:
file = './datasets/hwf/data/expr_test.json'
file = './datasets/data/expr_test.json'
return get_data(file, get_pseudo_label)



+ 5
- 5
examples/hwf/hwf_example.ipynb View File

@@ -8,7 +8,7 @@
"source": [
"import sys\n",
"\n",
"sys.path.append(\"../\")\n",
"sys.path.append(\"../../\")\n",
"\n",
"import torch.nn as nn\n",
"import torch\n",
@@ -21,8 +21,8 @@
"from abl.models.wabl_models import WABLBasicModel\n",
"\n",
"from models.nn import SymbolNet\n",
"from datasets.hwf.get_hwf import get_hwf\n",
"from abl import framework_hed"
"from datasets.get_hwf import get_hwf\n",
"from abl import framework"
]
},
{
@@ -150,7 +150,7 @@
"outputs": [],
"source": [
"# Train model\n",
"framework_hed.train(\n",
"framework.train(\n",
" model, abducer, train_data, test_data, loop_num=15, sample_num=5000, verbose=1\n",
")\n",
"\n",
@@ -175,7 +175,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.8.16"
},
"orig_nbformat": 4
},


+ 3
- 5
examples/mnist_add/datasets/get_mnist_add.py View File

@@ -1,6 +1,4 @@
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision.transforms import transforms

def get_data(file, img_dataset, get_pseudo_label):
@@ -23,12 +21,12 @@ def get_data(file, img_dataset, get_pseudo_label):

def get_mnist_add(train = True, get_pseudo_label = False):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))])
img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=train, download=True, transform=transform)
img_dataset = torchvision.datasets.MNIST(root='./datasets/', train=train, download=True, transform=transform)
if train:
file = './datasets/mnist_add/train_data.txt'
file = './datasets/train_data.txt'
else:
file = './datasets/mnist_add/test_data.txt'
file = './datasets/test_data.txt'
return get_data(file, img_dataset, get_pseudo_label)


Loading…
Cancel
Save