| @@ -16,19 +16,21 @@ | |||||
| network config setting, will be used in train.py | network config setting, will be used in train.py | ||||
| """ | """ | ||||
| import os | import os | ||||
| os.system('pip install easydict') | |||||
| from easydict import EasyDict as edict | |||||
| class Config: | |||||
| def __init__(self, **entries): | |||||
| self.__dict__.update(entries) | |||||
| mnist_cfg = edict({ | |||||
| 'num_classes': 10, | |||||
| 'lr': 0.01, | |||||
| 'momentum': 0.9, | |||||
| 'epoch_size': 10, | |||||
| 'batch_size': 32, | |||||
| 'buffer_size': 1000, | |||||
| 'image_height': 32, | |||||
| 'image_width': 32, | |||||
| 'save_checkpoint_steps': 1875, | |||||
| 'keep_checkpoint_max': 150, | |||||
| 'air_name': "lenet", | |||||
| }) | |||||
| # 定义配置信息 | |||||
| mnist_cfg = Config( | |||||
| num_classes=10, | |||||
| lr=0.01, | |||||
| momentum=0.9, | |||||
| epoch_size=10, | |||||
| batch_size=32, | |||||
| buffer_size=1000, | |||||
| image_height=32, | |||||
| image_width=32, | |||||
| save_checkpoint_steps=1875, | |||||
| keep_checkpoint_max=150, | |||||
| air_name="lenet" | |||||
| ) | |||||