You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

normalization.py 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. """
  2. Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
  3. Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
  4. """
  5. import re
  6. import jittor as jt
  7. from jittor import init
  8. from jittor import nn
  9. from models.networks.spectral_norm import spectral_norm
  # Factory for normalization layers that do not condition on the semantic map.
def get_nonspade_norm_layer(opt, norm_type='instance'):
    """Return a function that adds a non-SPADE normalization to a layer.

    The returned ``add_norm_layer(layer)`` does not condition on the
    semantic map. ``norm_type`` is an optional ``'spectral'`` prefix
    followed by a sub-norm name:

    * ``'spectral'`` prefix: wrap the layer with ``spectral_norm`` first.
    * ``''`` or ``'none'``: no further normalization; the layer is returned.
    * ``'batch'`` / ``'sync_batch'``: append ``nn.BatchNorm2d(affine=True)``
      (both names build the same layer here).
    * ``'instance'``: append ``nn.InstanceNorm2d(affine=False)``.

    Examples: ``'spectralinstance'``, ``'instance'``, ``'spectral'``.

    Args:
        opt: Options object. It is not read by this function; it is kept
            for interface compatibility with callers.
        norm_type: Normalization spec string as described above.

    Returns:
        A callable ``add_norm_layer(layer)``. It returns either the
        (possibly spectral-normalized) layer itself, or
        ``nn.Sequential(layer, norm_layer)``. It raises ``ValueError`` for
        an unrecognized sub-norm name.
    """

    def get_out_channel(layer):
        # Output channel count of `layer`: use `out_channels` when the layer
        # exposes it, otherwise the first dimension of its weight.
        if hasattr(layer, 'out_channels'):
            return layer.out_channels
        return layer.weight.size(0)

    def add_norm_layer(layer):
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]
        else:
            # No 'spectral' prefix: the whole string names the sub-norm.
            # Without this branch `subnorm_type` would be unbound here.
            subnorm_type = norm_type

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # Remove the bias of the previous layer. It is meaningless because
        # it has no effect after the following normalization.
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            setattr(layer, 'bias', None)

        if subnorm_type in ('batch', 'sync_batch'):
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(
                get_out_channel(layer), affine=False)
        else:
            raise ValueError(
                'normalization layer %s is not recognized' % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer
  # SPADE normalization layer, configured from a "spade<norm><k>x<k>" string.
class SPADE(nn.Module):
    """Spatially-adaptive normalization (SPADE) layer.

    SPADE consists of two steps. First, it normalizes the activations with a
    parameter-free normalization (batch norm or instance norm). Second, it
    applies a scale and a bias to the normalized output, both predicted from
    the segmentation map.

    Args:
        config_text: Config string of the form ``spade<norm><ks>x<ks>``.
            ``<norm>`` selects the parameter-free normalization:
            ``syncbatch``, ``batch`` or ``instance``. Both ``syncbatch`` and
            ``batch`` build ``nn.BatchNorm2d`` here. ``<ks>`` is the kernel
            size of the SPADE convolutions. Examples: ``'spadesyncbatch3x3'``,
            ``'spadeinstance5x5'``.
        norm_nc: Number of channels of the normalized activations, hence the
            output dim of SPADE.
        label_nc: Number of channels of the input semantic map, hence the
            input dim of SPADE.

    Raises:
        ValueError: If ``config_text`` does not start with ``'spade'``, does
            not match the format above, or names an unknown norm type.
    """

    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()
        # Explicit check rather than `assert`, so it is not stripped by -O.
        if not config_text.startswith('spade'):
            raise ValueError(
                '%s is not a SPADE config (expected it to start with "spade")'
                % config_text)
        # Raw string avoids invalid-escape warnings for \D and \d. `\d+`
        # also accepts multi-digit kernel sizes such as 'spadeinstance11x11'.
        parsed = re.search(r'spade(\D+)(\d+)x\d+', config_text)
        if parsed is None:
            raise ValueError(
                '%s is not a valid SPADE config (expected spade<norm><k>x<k>)'
                % config_text)
        param_free_norm_type = parsed.group(1)
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type in ('syncbatch', 'batch'):
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError('%s is not a recognized param-free norm type in SPADE'
                             % param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128
        # Padding of ks // 2 keeps the spatial size for odd ks
        # (assuming the convolutions' default stride of 1).
        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU(),
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def execute(self, x, segmap):
        """Normalize `x` and modulate it with scale/bias predicted from `segmap`.

        Assumes `x` and `segmap` are laid out as (N, C, H, W), since the
        trailing dims of `x` are used as the target spatial size.
        TODO: confirm against callers.
        """
        # Part 1. Generate parameter-free normalized activations.
        normalized = self.param_free_norm(x)

        # Part 2. Produce scale and bias conditioned on the semantic map,
        # after resizing it (nearest neighbour) to the spatial size of x.
        segmap = nn.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # gamma is applied as an offset from 1, so zero gamma and beta leave
        # the normalized activations unchanged.
        return normalized * (1 + gamma) + beta

第三届计图人工智能挑战赛——风格及语义引导的风景图片生成赛道项目，由 Jittor（计图）框架实现。