You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

BaseDenseAttention.cs 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. using Tensorflow.Keras.Engine;
  2. using Tensorflow.Keras.ArgsDefinition;
  3. using static Tensorflow.Binding;
  4. using static Tensorflow.KerasApi;
  5. using System;
  6. using System.Collections.Generic;
  7. using System.Linq;
/// <summary>
/// Base class for attention layers that can be used in sequence DNN/CNN models.
/// This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
/// Attention is formed by three tensors: Query, Key and Value.
/// </summary>
  13. namespace Tensorflow.Keras.Layers
  14. {
  15. /// <summary>
  16. /// Base Attention class for Dense networks.
  17. /// This class is suitable for Dense or CNN networks, and not for RNN networks.
  18. /// Implementations of attention mechanisms should inherit from this class, and
  19. /// reuse the `apply_attention_scores()` method.
  20. /// </summary>
  21. public class BaseDenseAttention : Layer
  22. {
  23. BaseDenseAttentionArgs args;
  24. bool causal { get => args.causal; }
  25. float dropout { get => args.dropout; }
  26. protected bool supports_masking;
  27. public BaseDenseAttention(BaseDenseAttentionArgs args) : base(args)
  28. {
  29. this.args = args;
  30. this.supports_masking = true;
  31. }
  32. /// <summary>
  33. /// Calculates attention scores.
  34. /// </summary>
  35. /// <param name="query">query: Query tensor of shape `[batch_size, Tq, dim]`.</param>
  36. /// <param name="key">key: Key tensor of shape `[batch_size, Tv, dim]`.</param>
  37. /// <returns>Tensor of shape `[batch_size, Tq, Tv]`.</returns>
  38. public virtual Tensor _calculate_scores(Tensor query, Tensor key) =>
  39. throw new NotImplementedException("");
  40. /// <summary>
  41. /// Applies attention scores to the given value tensor.
  42. /// To use this method in your attention layer, follow the steps:
  43. /// <para>
  44. /// * Use `query` tensor of shape `[batch_size, Tq]` and `key` tensor of shape
  45. /// `[batch_size, Tv]` to calculate the attention `scores`.
  46. /// </para>
  47. /// <para>
  48. /// * Pass `scores` and `value` tensors to this method. The method applies
  49. /// `scores_mask`, calculates `attention_distribution = softmax(scores)`, then
  50. /// returns `matmul(attention_distribution, value).
  51. /// </para>
  52. /// <para>
  53. /// * Apply `query_mask` and return the result.
  54. /// </para>
  55. /// </summary>
  56. /// <param name="scores">Scores float tensor of shape `[batch_size, Tq, Tv]`.</param>
  57. /// <param name="value">Value tensor of shape `[batch_size, Tv, dim]`.</param>
  58. /// <param name="scores_mask">
  59. /// A boolean mask `Tensor` of shape `[batch_size, 1, Tv]` or
  60. /// [batch_size, Tq, Tv]`. If given, scores at positions where
  61. /// `scores_mask==False` do not contribute to the result. It must contain
  62. /// at least one `True` value in each line along the last dimension.
  63. /// </param>
  64. /// <param name="training">
  65. /// Boolean indicating whether the layer should behave in
  66. /// training mode (adding dropout) or in inference mode (no dropout).
  67. /// </param>
  68. /// <returns>
  69. /// <para>
  70. /// Tensor of shape `[batch_size, Tq, dim]`.
  71. /// </para>
  72. /// <para>
  73. /// Attention scores after masking and softmax with shape
  74. /// [batch_size, Tq, Tv]`.
  75. /// </para>
  76. /// </returns>
  77. public (Tensor, Tensor) _apply_scores(Tensor scores,
  78. Tensor value,
  79. Tensor scores_mask = null,
  80. bool? training = null)
  81. {
  82. if (scores_mask != null)
  83. {
  84. var padding_mask = tf.logical_not(scores_mask);
  85. // Bias so padding positions do not contribute to attention distribution.
  86. // Note 65504. is the max float16 value.
  87. if (scores.dtype == tf.float16)
  88. scores -= 65504f * tf.cast(padding_mask, dtype: scores.dtype);
  89. else
  90. scores -= 1000000000f * tf.cast(padding_mask, dtype: scores.dtype);
  91. }
  92. bool _training;
  93. training ??= false; // TODO: Delete this line when backend.learning_phase is available
  94. if (training == null)
  95. _training = keras.backend.learning_phase() ==
  96. Tensorflow.Keras.GraphLearningPhase.train_mode ?
  97. true : false;
  98. else _training = training.Value;
  99. var weights = tf.nn.softmax(scores);
  100. Func<Tensor> dropped_weights = () => tf.nn.dropout(weights, rate: this.dropout);
  101. weights = Tensorflow.Framework.smart_module.smart_cond(_training, dropped_weights, () => tf.identity(weights));
  102. //return (tf.matmul(weights, value), weights);
  103. return (tf.linalg.einsum("bij,bjk->bik", (weights, value)), weights);
  104. }
  105. protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
  106. {
  107. Tensors _inp;
  108. Tensors _mask = null;
  109. int count = inputs.Count();
  110. if (count < 2 || count > 6) throw new ValueError(
  111. $"{ this.name } layer accepts inputs list of length from 2 to 5, " +
  112. $"namely [query, value, (key), (query_mask), (value_mask), (return_attention_scores)]." +
  113. $"Received length: {count}.");
  114. bool has_bool = inputs[count - 1].dtype == TF_DataType.TF_BOOL;
  115. bool return_attention_scores = false;
  116. if (has_bool)
  117. {
  118. return_attention_scores = (bool)inputs[count - 1];
  119. count--;
  120. }
  121. switch (count)
  122. {
  123. case 2:
  124. _inp = (inputs[0], inputs[1]);
  125. break;
  126. case 3:
  127. _inp = new[] { inputs[0], inputs[1], inputs[2] };
  128. break;
  129. case 4:
  130. if (inputs[0].shape == inputs[2].shape)
  131. if (inputs[1].shape == inputs[3].shape)
  132. {
  133. _inp = new[] { inputs[0], inputs[1] };
  134. _mask = new[] { inputs[2], inputs[3] };
  135. break;
  136. }
  137. throw new ValueError(); //TODO:Add discriptions for this err
  138. case 5:
  139. _inp = new[] { inputs[0], inputs[1], inputs[2] };
  140. _mask = (inputs[3], inputs[4]);
  141. break;
  142. default:
  143. throw new ValueError(); //TODO:Add discriptions for this err
  144. }
  145. return call(_inp, _mask, training, return_attention_scores);
  146. }
  147. protected Tensors call(Tensors inputs, Tensors mask = null, bool? training = null, bool return_attention_scores = false)
  148. {
  149. Tensor causal_mask;
  150. //this._validate_call_args(inputs: inputs, mask: mask);
  151. var q = inputs[0];
  152. var v = inputs[1];
  153. var k = inputs.Count() > 2 ? inputs[2] : v;
  154. var q_mask = mask != null ? mask[0] : null;
  155. var v_mask = mask != null ? mask[1] : null;
  156. var scores = this._calculate_scores(query: q, key: k);
  157. if (v_mask != null)
  158. // Mask of shape [batch_size, 1, Tv].
  159. v_mask = tf.expand_dims(v_mask, axis: -2);
  160. if (this.causal)
  161. {
  162. // Creates a lower triangular mask, so position i cannot attend to
  163. // positions j>i. This prevents the flow of information from the future
  164. // into the past.
  165. var scores_shape = tf.shape(scores);
  166. // causal_mask_shape = [1, Tq, Tv].
  167. var causal_mask_shape = tf.concat(new List<Tensor> {
  168. tf.ones_like(tf.slice(scores_shape, new[]{0}, new[]{-2})),
  169. tf.concat(new[]{scores_shape[-2], scores_shape[-1]}, 0)
  170. }, axis: 0);
  171. var _causal_mask_shape = new Shape(causal_mask_shape.ToArray<int>());
  172. causal_mask = _lower_triangular_mask(_causal_mask_shape);
  173. }
  174. else
  175. causal_mask = null;
  176. var scores_mask = _merge_masks(v_mask, causal_mask);
  177. var (result, attention_scores) = this._apply_scores(scores: scores, value: v, scores_mask: scores_mask, training: training);
  178. if (q_mask != null)
  179. {
  180. // Mask of shape [batch_size, Tq, 1].
  181. q_mask = tf.expand_dims(q_mask, axis: -1);
  182. result *= tf.cast(q_mask, dtype: result.dtype);
  183. }
  184. if (return_attention_scores)
  185. return new Tensors(result, attention_scores);
  186. return result;
  187. }
  188. public Tensor compute_mask(Tensors inputs, Tensors mask = null)
  189. {
  190. this._validate_call_args(inputs: inputs, mask: mask);
  191. if (mask != null)
  192. {
  193. var q_mask = mask[0];
  194. if (q_mask == null)
  195. return null;
  196. return tf.convert_to_tensor(q_mask);
  197. }
  198. return null;
  199. }
  200. //public Shape compute_output_shape(Shape input_shape) {
  201. // // return_attention_scores argument of BaseDenseAttention.call method
  202. // // is ignored. Output shape of attention_scores cannot be returned.
  203. // return input_shape[0];
  204. //}
  205. /// <summary>
  206. /// Validates arguments of the call method.
  207. /// </summary>
  208. public void _validate_call_args(Tensors inputs, Tensors mask)
  209. {
  210. if (inputs.Count() < 2 || inputs.Count() > 3)
  211. throw new ValueError(
  212. $"{this.name} layer accepts inputs list of length 2 or 3, " +
  213. $"namely [query, value] or [query, value, key]. Received length: {len(inputs)}.");
  214. if (mask != null)
  215. if (mask.Count() < 2 || mask.Count() > inputs.Count())
  216. throw new ValueError($"{this.name} layer mask must be a list of length 2, " +
  217. $"namely [query_mask, value_mask]. Received length: {len(mask)}.");
  218. }
  219. public static Tensor _lower_triangular_mask(Shape shape)
  220. {
  221. var row_index = tf.cumsum(tf.ones(shape: shape, dtype: tf.int32), axis: -2);
  222. var col_index = tf.cumsum(tf.ones(shape: shape, dtype: tf.int32), axis: -1);
  223. return tf.greater_equal(row_index, col_index);
  224. }
  225. public static Tensor _merge_masks(Tensor x, Tensor y)
  226. {
  227. if (x == null)
  228. return y;
  229. if (y == null)
  230. return x;
  231. return tf.logical_and(x, y);
  232. }
  233. public override LayerArgs get_config() => this.args;
  234. }
  235. }