
basic_layers.py
#! /usr/bin/python
# -*- coding: utf-8 -*-

import tensorflow as tf
import tensorlayer as tl

__all__ = [
    'activation_module',
    'conv_module',
    'dense_module',
]
def activation_module(layer, activation_fn, leaky_relu_alpha=0.2, name=None):
    """Append the requested activation on top of `layer` and return the resulting layer."""

    act_name = name + "/activation" if name is not None else "activation"

    if activation_fn not in ["ReLU", "ReLU6", "Leaky_ReLU", "PReLU", "PReLU6", "PTReLU6", "CReLU", "ELU", "SELU",
                             "tanh", "sigmoid", "softmax", None]:
        raise Exception("Unknown 'activation_fn': %s" % activation_fn)
    elif activation_fn == "ReLU":
        layer = tl.layers.LambdaLayer(prev_layer=layer, fn=tf.nn.relu, name=act_name)
    elif activation_fn == "ReLU6":
        layer = tl.layers.LambdaLayer(prev_layer=layer, fn=tf.nn.relu6, name=act_name)
    elif activation_fn == "Leaky_ReLU":
        layer = tl.layers.LambdaLayer(
            prev_layer=layer, fn=tf.nn.leaky_relu, fn_args={'alpha': leaky_relu_alpha}, name=act_name
        )
    elif activation_fn == "PReLU":
        layer = tl.layers.PReluLayer(prev_layer=layer, channel_shared=False, name=act_name)
    elif activation_fn == "PReLU6":
        layer = tl.layers.PRelu6Layer(prev_layer=layer, channel_shared=False, name=act_name)
    elif activation_fn == "PTReLU6":
        layer = tl.layers.PTRelu6Layer(prev_layer=layer, channel_shared=False, name=act_name)
    elif activation_fn == "CReLU":
        layer = tl.layers.LambdaLayer(prev_layer=layer, fn=tf.nn.crelu, name=act_name)
    elif activation_fn == "ELU":
        layer = tl.layers.LambdaLayer(prev_layer=layer, fn=tf.nn.elu, name=act_name)
    elif activation_fn == "SELU":
        layer = tl.layers.LambdaLayer(prev_layer=layer, fn=tf.nn.selu, name=act_name)
    elif activation_fn == "tanh":
        layer = tl.layers.LambdaLayer(prev_layer=layer, fn=tf.nn.tanh, name=act_name)
    elif activation_fn == "sigmoid":
        layer = tl.layers.LambdaLayer(prev_layer=layer, fn=tf.nn.sigmoid, name=act_name)
    elif activation_fn == "softmax":
        layer = tl.layers.LambdaLayer(prev_layer=layer, fn=tf.nn.softmax, name=act_name)

    return layer
def conv_module(
    prev_layer, n_out_channel, filter_size, strides, padding, is_train=True, use_batchnorm=True, activation_fn=None,
    conv_init=tl.initializers.random_uniform(),
    batch_norm_init=tl.initializers.truncated_normal(mean=1., stddev=0.02),
    bias_init=tf.zeros_initializer(),
    name=None
):
    """Conv2d, optional batch norm, then activation; returns (activated layer, pre-activation tensor)."""

    if activation_fn not in ["ReLU", "ReLU6", "Leaky_ReLU", "PReLU", "PReLU6", "PTReLU6", "CReLU", "ELU", "SELU",
                             "tanh", "sigmoid", "softmax", None]:
        raise Exception("Unknown 'activation_fn': %s" % activation_fn)

    conv_name = 'conv2d' if name is None else name
    bn_name = 'batch_norm' if name is None else name + '/BatchNorm'

    layer = tl.layers.Conv2d(
        prev_layer,
        n_filter=n_out_channel,
        filter_size=filter_size,
        strides=strides,
        padding=padding,
        act=None,
        W_init=conv_init,
        b_init=None if use_batchnorm else bias_init,  # the bias is redundant when the convolution is batch normalized
        name=conv_name
    )

    if use_batchnorm:
        layer = tl.layers.BatchNormLayer(layer, act=None, is_train=is_train, gamma_init=batch_norm_init, name=bn_name)

    logits = layer.outputs
    layer = activation_module(layer, activation_fn, name=conv_name)

    return layer, logits
def dense_module(
    prev_layer, n_units, is_train, use_batchnorm=True, activation_fn=None,
    dense_init=tl.initializers.random_uniform(),
    batch_norm_init=tl.initializers.truncated_normal(mean=1., stddev=0.02),
    bias_init=tf.zeros_initializer(),
    name=None
):
    """Dense layer (input flattened if needed), optional batch norm, then activation; returns (activated layer, pre-activation tensor)."""

    if activation_fn not in ["ReLU", "ReLU6", "Leaky_ReLU", "PReLU", "PReLU6", "PTReLU6", "CReLU", "ELU", "SELU",
                             "tanh", "sigmoid", "softmax", None]:
        raise Exception("Unknown 'activation_fn': %s" % activation_fn)

    # Flatten: Conv to FC
    if prev_layer.outputs.get_shape().__len__() != 2:  # DenseLayer expects a rank-2 input [batch, features]
        layer = tl.layers.FlattenLayer(prev_layer, name='flatten')
    else:
        layer = prev_layer

    layer = tl.layers.DenseLayer(
        layer,
        n_units=n_units,
        act=None,
        W_init=dense_init,
        b_init=None if use_batchnorm else bias_init,  # the bias is redundant when the dense layer is batch normalized
        name='dense' if name is None else name
    )

    if use_batchnorm:
        layer = tl.layers.BatchNormLayer(
            layer, act=None, is_train=is_train, gamma_init=batch_norm_init, name='batch_norm'
        )

    logits = layer.outputs
    layer = activation_module(layer, activation_fn)

    return layer, logits
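
For reference, a minimal usage sketch of these helpers. This is an illustrative assumption rather than part of the file: it presumes a TensorFlow 1.x graph-mode setup and the TensorLayer prev_layer API used above, and the placeholder shape and layer names ('conv1', 'fc1') are hypothetical.

# Hypothetical usage sketch, assuming a TF 1.x graph and the TensorLayer API used above.
x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name='x')
net = tl.layers.InputLayer(x, name='input')
net, _ = conv_module(
    net, n_out_channel=32, filter_size=(3, 3), strides=(1, 1), padding='SAME',
    is_train=True, activation_fn='ReLU', name='conv1'
)
net, logits = dense_module(net, n_units=10, is_train=True, activation_fn='softmax', name='fc1')
# net.outputs holds the softmax probabilities; `logits` is the pre-activation tensor,
# which can be paired with tf.nn.softmax_cross_entropy_with_logits for the loss.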

TensorLayer 3.0 is a deep learning library that supports multiple deep learning frameworks as computational backends. Planned backends: TensorFlow, PyTorch, MindSpore, and Paddle.
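
If the multi-backend mode works as described (an assumption here, not something this file demonstrates), the compute backend would typically be chosen through an environment variable before the library is imported, roughly:

import os
os.environ['TL_BACKEND'] = 'tensorflow'  # assumed selector; the exact variable and backend strings may differ per release
import tensorlayer as tl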