diff --git a/model_zoo/official/cv/vgg16/src/vgg.py b/model_zoo/official/cv/vgg16/src/vgg.py index a2e4f13f3a..848a9dbcc8 100644 --- a/model_zoo/official/cv/vgg16/src/vgg.py +++ b/model_zoo/official/cv/vgg16/src/vgg.py @@ -127,7 +127,7 @@ cfg = { } -def vgg16(num_classes=1000, args=None, phase="train"): +def vgg16(num_classes=1000, args=None, phase="train", **kwargs): """ Get Vgg16 neural network with batch normalization. @@ -140,11 +140,11 @@ def vgg16(num_classes=1000, args=None, phase="train"): Cell, cell instance of Vgg16 neural network with batch normalization. Examples: - >>> vgg16(num_classes=1000, args=args) + >>> vgg16(num_classes=1000, args=args, **kwargs) """ if args is None: from .config import cifar_cfg args = cifar_cfg - net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase) + net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase, **kwargs) return net