|
|
|
@@ -127,7 +127,7 @@ cfg = { |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
def vgg16(num_classes=1000, args=None, phase="train"): |
|
|
|
def vgg16(num_classes=1000, args=None, phase="train", **kwargs): |
|
|
|
""" |
|
|
|
Get Vgg16 neural network with batch normalization. |
|
|
|
|
|
|
|
@@ -140,11 +140,11 @@ def vgg16(num_classes=1000, args=None, phase="train"): |
|
|
|
Cell, cell instance of Vgg16 neural network with batch normalization. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> vgg16(num_classes=1000, args=args) |
|
|
|
>>> vgg16(num_classes=1000, args=args, **kwargs) |
|
|
|
""" |
|
|
|
|
|
|
|
if args is None: |
|
|
|
from .config import cifar_cfg |
|
|
|
args = cifar_cfg |
|
|
|
net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase) |
|
|
|
net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase, **kwargs) |
|
|
|
return net |