You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

builder.py 1.5 kB

2 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch.nn as nn
  3. from mmcv.utils import Registry, build_from_cfg
  4. TRANSFORMER = Registry('Transformer')
  5. LINEAR_LAYERS = Registry('linear layers')
  6. def build_transformer(cfg, default_args=None):
  7. """Builder for Transformer."""
  8. return build_from_cfg(cfg, TRANSFORMER, default_args)
  9. LINEAR_LAYERS.register_module('Linear', module=nn.Linear)
  10. def build_linear_layer(cfg, *args, **kwargs):
  11. """Build linear layer.
  12. Args:
  13. cfg (None or dict): The linear layer config, which should contain:
  14. - type (str): Layer type.
  15. - layer args: Args needed to instantiate an linear layer.
  16. args (argument list): Arguments passed to the `__init__`
  17. method of the corresponding linear layer.
  18. kwargs (keyword arguments): Keyword arguments passed to the `__init__`
  19. method of the corresponding linear layer.
  20. Returns:
  21. nn.Module: Created linear layer.
  22. """
  23. if cfg is None:
  24. cfg_ = dict(type='Linear')
  25. else:
  26. if not isinstance(cfg, dict):
  27. raise TypeError('cfg must be a dict')
  28. if 'type' not in cfg:
  29. raise KeyError('the cfg dict must contain the key "type"')
  30. cfg_ = cfg.copy()
  31. layer_type = cfg_.pop('type')
  32. if layer_type not in LINEAR_LAYERS:
  33. raise KeyError(f'Unrecognized linear type {layer_type}')
  34. else:
  35. linear_layer = LINEAR_LAYERS.get(layer_type)
  36. layer = linear_layer(*args, **kwargs, **cfg_)
  37. return layer

No Description

Contributors (2)