diff --git a/example/vgg16_cifar10/eval.py b/example/vgg16_cifar10/eval.py index b034183373..ca2bbd12eb 100644 --- a/example/vgg16_cifar10/eval.py +++ b/example/vgg16_cifar10/eval.py @@ -39,7 +39,7 @@ if __name__ == '__main__': context.set_context(device_id=args_opt.device_id) context.set_context(enable_mem_reuse=True, enable_hccl=False) - net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes) + net = vgg16(num_classes=cfg.num_classes) opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, weight_decay=cfg.weight_decay) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) diff --git a/example/vgg16_cifar10/train.py b/example/vgg16_cifar10/train.py index 32cd344d50..a4aa587c3d 100644 --- a/example/vgg16_cifar10/train.py +++ b/example/vgg16_cifar10/train.py @@ -64,7 +64,7 @@ if __name__ == '__main__': context.set_context(device_id=args_opt.device_id) context.set_context(enable_mem_reuse=True, enable_hccl=False) - net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes) + net = vgg16(num_classes=cfg.num_classes) lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=50000 // cfg.batch_size) opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) diff --git a/mindspore/model_zoo/vgg.py b/mindspore/model_zoo/vgg.py index 6fcd075cc8..f3532fab13 100644 --- a/mindspore/model_zoo/vgg.py +++ b/mindspore/model_zoo/vgg.py @@ -14,7 +14,6 @@ # ============================================================================ """VGG.""" import mindspore.nn as nn -from mindspore.ops import operations as P from mindspore.common.initializer import initializer import mindspore.common.dtype as mstype @@ -63,8 +62,7 @@ class Vgg(nn.Cell): def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1): super(Vgg, self).__init__() self.layers = _make_layer(base, batch_norm=batch_norm) - self.reshape = P.Reshape() - self.shp = (batch_size, -1) + self.flatten = nn.Flatten() self.classifier = nn.SequentialCell([ nn.Dense(512 * 7 * 7, 4096), nn.ReLU(), @@ -74,7 +72,7 @@ class Vgg(nn.Cell): def construct(self, x): x = self.layers(x) - x = self.reshape(x, self.shp) + x = self.flatten(x) x = self.classifier(x) return x @@ -87,20 +85,19 @@ cfg = { } -def vgg16(batch_size=1, num_classes=1000): +def vgg16(num_classes=1000): """ Get Vgg16 neural network with batch normalization. Args: - batch_size (int): Batch size. Default: 1. num_classes (int): Class numbers. Default: 1000. Returns: Cell, cell instance of Vgg16 neural network with batch normalization. Examples: - >>> vgg16(batch_size=1, num_classes=1000) + >>> vgg16(num_classes=1000) """ - net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True, batch_size=batch_size) + net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True) return net