Browse Source

amend deeplabv3 hub config

tags/v1.1.0
jzg 5 years ago
parent
commit
a32c5fbc92
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      model_zoo/official/cv/deeplabv3/mindspore_hub_conf.py

+ 3
- 3
model_zoo/official/cv/deeplabv3/mindspore_hub_conf.py View File

@@ -17,11 +17,11 @@ from src.nets import net_factory

def create_network(name, *args, **kwargs):
freeze_bn = True
num_classes = 21
num_classes = kwargs["num_classes"]
if name == 'deeplab_v3_s16':
deeplab_v3_s16_network = net_factory.nets_map["deeplab_v3_s16"]('eval', num_classes, 16, freeze_bn)
return deeplab_v3_s16_network(*args, **kwargs)
return deeplab_v3_s16_network
if name == 'deeplab_v3_s8':
deeplab_v3_s8_network = net_factory.nets_map["deeplab_v3_s8"]('eval', num_classes, 8, freeze_bn)
return deeplab_v3_s8_network(*args, **kwargs)
return deeplab_v3_s8_network
raise NotImplementedError(f"{name} is not implemented in the repo")

Loading…
Cancel
Save