|
|
@@ -24,6 +24,7 @@ from mindspore.model_zoo.mobilenet import mobilenet_v2 |
|
|
from mindspore.train.model import Model |
|
|
from mindspore.train.model import Model |
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net |
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net |
|
|
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits |
|
|
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits |
|
|
|
|
|
from mindspore.common import dtype as mstype |
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='Image classification') |
|
|
parser = argparse.ArgumentParser(description='Image classification') |
|
|
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') |
|
|
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') |
|
|
@@ -39,7 +40,8 @@ context.set_context(enable_mem_reuse=True) |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
if __name__ == '__main__': |
|
|
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') |
|
|
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') |
|
|
net = mobilenet_v2() |
|
|
|
|
|
|
|
|
net = mobilenet_v2(num_classes=config.num_classes) |
|
|
|
|
|
net.to_float(mstype.float16) |
|
|
|
|
|
|
|
|
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size) |
|
|
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size) |
|
|
step_size = dataset.get_dataset_size() |
|
|
step_size = dataset.get_dataset_size() |
|
|
|