|
|
|
@@ -17,11 +17,11 @@ from src.nets import net_factory |
|
|
|
|
|
|
|
def create_network(name, *args, **kwargs): |
|
|
|
freeze_bn = True |
|
|
|
num_classes = 21 |
|
|
|
num_classes = kwargs["num_classes"] |
|
|
|
if name == 'deeplab_v3_s16': |
|
|
|
deeplab_v3_s16_network = net_factory.nets_map["deeplab_v3_s16"]('eval', num_classes, 16, freeze_bn) |
|
|
|
return deeplab_v3_s16_network(*args, **kwargs) |
|
|
|
return deeplab_v3_s16_network |
|
|
|
if name == 'deeplab_v3_s8': |
|
|
|
deeplab_v3_s8_network = net_factory.nets_map["deeplab_v3_s8"]('eval', num_classes, 8, freeze_bn) |
|
|
|
return deeplab_v3_s8_network(*args, **kwargs) |
|
|
|
return deeplab_v3_s8_network |
|
|
|
raise NotImplementedError(f"{name} is not implemented in the repo") |