
Remove the batch_size parameter from VGG16, since we can use Flatten instead of Reshape.

tags/v0.2.0-alpha
caojian05 committed 5 years ago
commit b36094e327
3 changed files with 7 additions and 10 deletions:
  1. example/vgg16_cifar10/eval.py   (+1, -1)
  2. example/vgg16_cifar10/train.py  (+1, -1)
  3. mindspore/model_zoo/vgg.py      (+5, -8)

example/vgg16_cifar10/eval.py  (+1, -1)

@@ -39,7 +39,7 @@ if __name__ == '__main__':
     context.set_context(device_id=args_opt.device_id)
     context.set_context(enable_mem_reuse=True, enable_hccl=False)

-    net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes)
+    net = vgg16(num_classes=cfg.num_classes)
     opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
                    weight_decay=cfg.weight_decay)
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)


example/vgg16_cifar10/train.py  (+1, -1)

@@ -64,7 +64,7 @@ if __name__ == '__main__':
     context.set_context(device_id=args_opt.device_id)
     context.set_context(enable_mem_reuse=True, enable_hccl=False)

-    net = vgg16(batch_size=cfg.batch_size, num_classes=cfg.num_classes)
+    net = vgg16(num_classes=cfg.num_classes)
     lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=50000 // cfg.batch_size)
     opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)


mindspore/model_zoo/vgg.py  (+5, -8)

@@ -14,7 +14,6 @@
 # ============================================================================
 """VGG."""
 import mindspore.nn as nn
-from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer
 import mindspore.common.dtype as mstype

@@ -63,8 +62,7 @@ class Vgg(nn.Cell):
     def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1):
         super(Vgg, self).__init__()
         self.layers = _make_layer(base, batch_norm=batch_norm)
-        self.reshape = P.Reshape()
-        self.shp = (batch_size, -1)
+        self.flatten = nn.Flatten()
         self.classifier = nn.SequentialCell([
             nn.Dense(512 * 7 * 7, 4096),
             nn.ReLU(),

@@ -74,7 +72,7 @@ class Vgg(nn.Cell):

     def construct(self, x):
         x = self.layers(x)
-        x = self.reshape(x, self.shp)
+        x = self.flatten(x)
         x = self.classifier(x)
         return x

@@ -87,20 +85,19 @@ cfg = {
 }


-def vgg16(batch_size=1, num_classes=1000):
+def vgg16(num_classes=1000):
     """
     Get Vgg16 neural network with batch normalization.

     Args:
-        batch_size (int): Batch size. Default: 1.
         num_classes (int): Class numbers. Default: 1000.

     Returns:
         Cell, cell instance of Vgg16 neural network with batch normalization.

     Examples:
-        >>> vgg16(batch_size=1, num_classes=1000)
+        >>> vgg16(num_classes=1000)
     """

-    net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True, batch_size=batch_size)
+    net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True)
     return net
