|
|
|
@@ -20,9 +20,9 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt |
|
|
|
|
|
|
|
import os |
|
|
|
import argparse |
|
|
|
from dataset import create_dataset |
|
|
|
from config import mnist_cfg as cfg |
|
|
|
from lenet import LeNet5 |
|
|
|
from src.dataset import create_dataset |
|
|
|
from src.config import mnist_cfg as cfg |
|
|
|
from src.lenet import LeNet5 |
|
|
|
import mindspore.nn as nn |
|
|
|
from mindspore import context |
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net |
|
|
|
@@ -32,10 +32,10 @@ from mindspore.nn.metrics import Accuracy |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
parser = argparse.ArgumentParser(description='MindSpore MNIST Example') |
|
|
|
parser = argparse.ArgumentParser(description='MindSpore Lenet Example') |
|
|
|
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], |
|
|
|
help='device where the code will be implemented (default: Ascend)') |
|
|
|
parser.add_argument('--data_path', type=str, default="./MNIST_Data", |
|
|
|
parser.add_argument('--data_path', type=str, default="./Data", |
|
|
|
help='path where the dataset is saved') |
|
|
|
parser.add_argument('--ckpt_path', type=str, default="", help='if mode is test, must provide\ |
|
|
|
path where the trained ckpt file') |
|
|
|
|