|
|
|
@@ -16,19 +16,21 @@ |
|
|
|
network config setting, will be used in train.py |
|
|
|
""" |
|
|
|
import os |
|
|
|
os.system('pip install easydict') |
|
|
|
from easydict import EasyDict as edict |
|
|
|
class Config: |
|
|
|
def __init__(self, **entries): |
|
|
|
self.__dict__.update(entries) |
|
|
|
|
|
|
|
mnist_cfg = edict({ |
|
|
|
'num_classes': 10, |
|
|
|
'lr': 0.01, |
|
|
|
'momentum': 0.9, |
|
|
|
'epoch_size': 10, |
|
|
|
'batch_size': 32, |
|
|
|
'buffer_size': 1000, |
|
|
|
'image_height': 32, |
|
|
|
'image_width': 32, |
|
|
|
'save_checkpoint_steps': 1875, |
|
|
|
'keep_checkpoint_max': 150, |
|
|
|
'air_name': "lenet", |
|
|
|
}) |
|
|
|
# 定义配置信息 |
|
|
|
mnist_cfg = Config( |
|
|
|
num_classes=10, |
|
|
|
lr=0.01, |
|
|
|
momentum=0.9, |
|
|
|
epoch_size=10, |
|
|
|
batch_size=32, |
|
|
|
buffer_size=1000, |
|
|
|
image_height=32, |
|
|
|
image_width=32, |
|
|
|
save_checkpoint_steps=1875, |
|
|
|
keep_checkpoint_max=150, |
|
|
|
air_name="lenet" |
|
|
|
) |