You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

stack.py 3.1 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import tensorlayer as tl
  4. from tensorlayer import logging
  5. from tensorlayer.layers.core import Module
  # Public names exported by ``from <this module> import *``.
  __all__ = ['Stack', 'UnStack']
  class Stack(Module):
      """
      The :class:`Stack` class is a layer for stacking a list of rank-R tensors into one rank-(R+1) tensor, see `tf.stack() <https://www.tensorflow.org/api_docs/python/tf/stack>`__.

      Parameters
      ----------
      axis : int
          New dimension along which to stack. Default is 1.
      name : str or None
          A unique layer name. ``None`` (the default) is forwarded unchanged to
          ``Module.__init__``.

      Examples
      --------
      >>> import tensorflow as tf
      >>> import tensorlayer as tl
      >>> ni = tl.layers.Input([None, 784], name='input')
      >>> net1 = tl.layers.Dense(10, name='dense1')(ni)
      >>> net2 = tl.layers.Dense(10, name='dense2')(ni)
      >>> net3 = tl.layers.Dense(10, name='dense3')(ni)
      >>> net = tl.layers.Stack(axis=1, name='stack')([net1, net2, net3])
      >>> # net has shape (?, 3, 10): three (?, 10) tensors stacked along axis 1
      """

      def __init__(
          self,
          axis=1,
          name=None,
      ):
          super().__init__(name)
          self.axis = axis
          # `build` ignores its shape argument, so the layer can be built
          # eagerly here, before any input is seen.
          self.build(None)
          self._built = True
          logging.info("Stack %s: axis: %d" % (self.name, self.axis))

      def __repr__(self):
          """Return e.g. ``Stack(axis=1, name='stack')``."""
          s = '{classname}(axis={axis}'
          if self.name is not None:
              s += ', name=\'{name}\''
          s += ')'
          # Pass the fields explicitly rather than ``**self.__dict__`` so the
          # repr does not depend on how the base class stores its attributes.
          return s.format(classname=self.__class__.__name__, axis=self.axis, name=self.name)

      def build(self, inputs_shape):
          """Create the backend stack op for ``self.axis``.

          ``inputs_shape`` is unused: the op is configured only by the axis.
          """
          self.stack = tl.ops.Stack(axis=self.axis)

      def forward(self, inputs):
          """Stack ``inputs`` (a list of same-shaped tensors) along ``self.axis``."""
          outputs = self.stack(inputs)
          return outputs
  class UnStack(Module):
      """
      The :class:`UnStack` class is a layer for unstacking the given dimension of a rank-R tensor into rank-(R-1) tensors, see `tf.unstack() <https://www.tensorflow.org/api_docs/python/tf/unstack>`__.

      Parameters
      ----------
      num : int or None
          The length of the dimension ``axis``. Automatically inferred if None (the default).
      axis : int
          Dimension along which to unstack. Default is 0.
      name : str or None
          A unique layer name. ``None`` (the default) is forwarded unchanged to
          ``Module.__init__``.

      Returns
      -------
      list of tensors
          The tensors unstacked from the input along ``axis``, as returned by
          ``tl.ops.Unstack``.

      Examples
      --------
      >>> import tensorlayer as tl
      >>> ni = tl.layers.Input([4, 10], name='input')
      >>> nn = tl.layers.Dense(n_units=5)(ni)
      >>> nn = tl.layers.UnStack(axis=1)(nn)  # unstack along axis 1 (the 5 units)
      >>> len(nn)  # 5
      >>> nn[0].shape  # (4,)
      """

      def __init__(self, num=None, axis=0, name=None):
          super().__init__(name)
          self.num = num
          self.axis = axis
          # `build` ignores its shape argument, so the layer can be built
          # eagerly here, before any input is seen.
          self.build(None)
          self._built = True
          logging.info("UnStack %s: num: %s axis: %d" % (self.name, self.num, self.axis))

      def __repr__(self):
          """Return e.g. ``UnStack(num=None, axis=0, name='unstack')``."""
          s = '{classname}(num={num}, axis={axis}'
          if self.name is not None:
              s += ', name=\'{name}\''
          s += ')'
          # Pass the fields explicitly rather than ``**self.__dict__`` so the
          # repr does not depend on how the base class stores its attributes.
          return s.format(
              classname=self.__class__.__name__, num=self.num, axis=self.axis, name=self.name
          )

      def build(self, inputs_shape):
          """Create the backend unstack op for ``self.num`` / ``self.axis``.

          ``inputs_shape`` is unused: the op is configured only by num and axis.
          """
          self.unstack = tl.ops.Unstack(num=self.num, axis=self.axis)

      def forward(self, inputs):
          """Split ``inputs`` along ``self.axis`` using the op created in ``build``."""
          outputs = self.unstack(inputs)
          return outputs

TensorLayer 3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库。计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。