You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

builder.py 1.4 kB

2 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. from mmcv.cnn import MODELS as MMCV_MODELS
  4. from mmcv.utils import Registry
  5. MODELS = Registry('models', parent=MMCV_MODELS)
  6. BACKBONES = MODELS
  7. NECKS = MODELS
  8. ROI_EXTRACTORS = MODELS
  9. SHARED_HEADS = MODELS
  10. HEADS = MODELS
  11. LOSSES = MODELS
  12. DETECTORS = MODELS
  13. def build_backbone(cfg):
  14. """Build backbone."""
  15. return BACKBONES.build(cfg)
  16. def build_neck(cfg):
  17. """Build neck."""
  18. return NECKS.build(cfg)
  19. def build_roi_extractor(cfg):
  20. """Build roi extractor."""
  21. return ROI_EXTRACTORS.build(cfg)
  22. def build_shared_head(cfg):
  23. """Build shared head."""
  24. return SHARED_HEADS.build(cfg)
  25. def build_head(cfg):
  26. """Build head."""
  27. return HEADS.build(cfg)
  28. def build_loss(cfg):
  29. """Build loss."""
  30. return LOSSES.build(cfg)
  31. def build_detector(cfg, train_cfg=None, test_cfg=None):
  32. """Build detector."""
  33. if train_cfg is not None or test_cfg is not None:
  34. warnings.warn(
  35. 'train_cfg and test_cfg is deprecated, '
  36. 'please specify them in model', UserWarning)
  37. assert cfg.get('train_cfg') is None or train_cfg is None, \
  38. 'train_cfg specified in both outer field and model field '
  39. assert cfg.get('test_cfg') is None or test_cfg is None, \
  40. 'test_cfg specified in both outer field and model field '
  41. return DETECTORS.build(
  42. cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))

No Description

Contributors (3)