using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

// Base class for attention layers that can be used in sequence DNN/CNN models.
// This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
// Attention is formed by three tensors: Query, Key and Value.

namespace Tensorflow.Keras.Layers
{
    /// <summary>
    /// Base Attention class for Dense networks.
    /// This class is suitable for Dense or CNN networks, and not for RNN networks.
    /// Implementations of attention mechanisms should inherit from this class, and
    /// reuse the <c>_apply_scores()</c> method.
    /// </summary>
    public class BaseDenseAttention : Layer
    {
        BaseDenseAttentionArgs args;

        // When true, a lower-triangular mask is applied so position i cannot attend
        // to positions j > i (no information flow from the future into the past).
        bool causal { get => args.causal; }

        // Dropout rate applied to the attention distribution while training.
        float dropout { get => args.dropout; }

        protected bool supports_masking;

        public BaseDenseAttention(BaseDenseAttentionArgs args) : base(args)
        {
            this.args = args;
            this.supports_masking = true;
        }

        /// <summary>
        /// Calculates attention scores. Must be overridden by concrete attention layers.
        /// </summary>
        /// <param name="query">Query tensor of shape <c>[batch_size, Tq, dim]</c>.</param>
        /// <param name="key">Key tensor of shape <c>[batch_size, Tv, dim]</c>.</param>
        /// <returns>Tensor of shape <c>[batch_size, Tq, Tv]</c>.</returns>
        public virtual Tensor _calculate_scores(Tensor query, Tensor key)
            => throw new NotImplementedException("");

        /// <summary>
        /// Applies attention scores to the given value tensor.
        /// To use this method in your attention layer, follow the steps:
        /// <list type="bullet">
        /// <item>Use `query` tensor of shape `[batch_size, Tq]` and `key` tensor of shape
        /// `[batch_size, Tv]` to calculate the attention `scores`.</item>
        /// <item>Pass `scores` and `value` tensors to this method. The method applies
        /// `scores_mask`, calculates `attention_distribution = softmax(scores)`, then
        /// returns `matmul(attention_distribution, value)`.</item>
        /// <item>Apply `query_mask` and return the result.</item>
        /// </list>
        /// </summary>
        /// <param name="scores">Scores float tensor of shape `[batch_size, Tq, Tv]`.</param>
        /// <param name="value">Value tensor of shape `[batch_size, Tv, dim]`.</param>
        /// <param name="scores_mask">
        /// A boolean mask `Tensor` of shape `[batch_size, 1, Tv]` or
        /// `[batch_size, Tq, Tv]`. If given, scores at positions where
        /// `scores_mask==False` do not contribute to the result. It must contain
        /// at least one `True` value in each line along the last dimension.
        /// </param>
        /// <param name="training">
        /// Boolean indicating whether the layer should behave in
        /// training mode (adding dropout) or in inference mode (no dropout).
        /// </param>
        /// <returns>
        /// A tuple of:
        /// (1) Tensor of shape `[batch_size, Tq, dim]`;
        /// (2) attention scores after masking and softmax with shape `[batch_size, Tq, Tv]`.
        /// </returns>
        public (Tensor, Tensor) _apply_scores(Tensor scores,
                                              Tensor value,
                                              Tensor scores_mask = null,
                                              bool? training = null)
        {
            if (scores_mask != null)
            {
                var padding_mask = tf.logical_not(scores_mask);
                // Bias so padding positions do not contribute to attention distribution.
                // Note 65504. is the max float16 value.
                if (scores.dtype == tf.float16)
                    scores -= 65504f * tf.cast(padding_mask, dtype: scores.dtype);
                else
                    scores -= 1000000000f * tf.cast(padding_mask, dtype: scores.dtype);
            }
            bool _training;
            // TODO: Delete this line when backend.learning_phase is available.
            // NOTE(review): this default makes the learning_phase branch below
            // unreachable — `training` can never be null afterwards. Preserved as-is
            // because the TODO indicates it is a temporary stand-in.
            training ??= false;
            if (training == null)
                _training = keras.backend.learning_phase() == Tensorflow.Keras.GraphLearningPhase.train_mode ? true : false;
            else
                _training = training.Value;
            var weights = tf.nn.softmax(scores);
            Func<Tensor> dropped_weights = () => tf.nn.dropout(weights, rate: this.dropout);
            // Only apply dropout when in training mode; otherwise pass weights through.
            weights = Tensorflow.Framework.smart_module.smart_cond(_training, dropped_weights, () => tf.identity(weights));
            // Equivalent to tf.matmul(weights, value).
            //return (tf.matmul(weights, value), weights);
            return (tf.linalg.einsum("bij,bjk->bik", (weights, value)), weights);
        }

        /// <summary>
        /// Unpacks the flat `inputs` list into (query, value, key) tensors and
        /// optional (query_mask, value_mask) masks, plus an optional trailing
        /// boolean `return_attention_scores` flag, then delegates to <c>call</c>.
        /// </summary>
        protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
        {
            Tensors _inp;
            Tensors _mask = null;
            int count = inputs.Count();
            // 2..5 tensors, plus an optional trailing bool flag => at most 6 entries.
            if (count < 2 || count > 6)
                throw new ValueError(
                    $"{this.name} layer accepts inputs list of length from 2 to 5 " +
                    $"(namely [query, value, (key), (query_mask), (value_mask)]) plus an " +
                    $"optional trailing return_attention_scores flag. " +
                    $"Received length: {count}.");
            bool has_bool = inputs[count - 1].dtype == TF_DataType.TF_BOOL;
            bool return_attention_scores = false;
            if (has_bool)
            {
                return_attention_scores = (bool)inputs[count - 1];
                count--;
            }
            switch (count)
            {
                case 2:
                    // [query, value]
                    _inp = (inputs[0], inputs[1]);
                    break;
                case 3:
                    // [query, value, key]
                    _inp = new[] { inputs[0], inputs[1], inputs[2] };
                    break;
                case 4:
                    // [query, value, query_mask, value_mask]: each mask must have the
                    // same shape as the tensor it masks.
                    if (inputs[0].shape == inputs[2].shape)
                        if (inputs[1].shape == inputs[3].shape)
                        {
                            _inp = new[] { inputs[0], inputs[1] };
                            _mask = new[] { inputs[2], inputs[3] };
                            break;
                        }
                    throw new ValueError(
                        $"{this.name} layer received 4 inputs but could not interpret them as " +
                        $"[query, value, query_mask, value_mask]: each mask's shape must match " +
                        $"the shape of the tensor it masks.");
                case 5:
                    // [query, value, key, query_mask, value_mask]
                    _inp = new[] { inputs[0], inputs[1], inputs[2] };
                    _mask = (inputs[3], inputs[4]);
                    break;
                default:
                    throw new ValueError(
                        $"{this.name} layer: unexpected inputs length {count} after removing " +
                        $"the return_attention_scores flag.");
            }
            return call(_inp, _mask, training, return_attention_scores);
        }

        /// <summary>
        /// Core attention computation on unpacked inputs.
        /// </summary>
        /// <param name="inputs">[query, value] or [query, value, key] tensors.</param>
        /// <param name="mask">Optional [query_mask, value_mask].</param>
        /// <param name="training">Training-mode flag forwarded to <c>_apply_scores</c>.</param>
        /// <param name="return_attention_scores">
        /// When true, returns (result, attention_scores) instead of just the result.
        /// </param>
        protected Tensors call(Tensors inputs, Tensors mask = null, bool? training = null, bool return_attention_scores = false)
        {
            Tensor causal_mask;
            //this._validate_call_args(inputs: inputs, mask: mask);
            var q = inputs[0];
            var v = inputs[1];
            var k = inputs.Count() > 2 ? inputs[2] : v;
            var q_mask = mask != null ? mask[0] : null;
            var v_mask = mask != null ? mask[1] : null;
            var scores = this._calculate_scores(query: q, key: k);
            if (v_mask != null)
                // Mask of shape [batch_size, 1, Tv].
                v_mask = tf.expand_dims(v_mask, axis: -2);
            if (this.causal)
            {
                // Creates a lower triangular mask, so position i cannot attend to
                // positions j>i. This prevents the flow of information from the future
                // into the past.
                var scores_shape = tf.shape(scores);
                // causal_mask_shape = [1, Tq, Tv].
                // NOTE(review): tf.slice with size -2 looks suspect — TF only treats
                // size -1 as "to the end". Verify this against the intended Python
                // equivalent `tf.ones_like(tf.shape(scores)[:-2])`.
                var causal_mask_shape = tf.concat(new List<Tensor>
                {
                    tf.ones_like(tf.slice(scores_shape, new[] { 0 }, new[] { -2 })),
                    tf.concat(new[] { scores_shape[-2], scores_shape[-1] }, 0)
                }, axis: 0);
                var _causal_mask_shape = new Shape(causal_mask_shape.ToArray());
                causal_mask = _lower_triangular_mask(_causal_mask_shape);
            }
            else
                causal_mask = null;
            var scores_mask = _merge_masks(v_mask, causal_mask);
            var (result, attention_scores) = this._apply_scores(scores: scores, value: v, scores_mask: scores_mask, training: training);
            if (q_mask != null)
            {
                // Mask of shape [batch_size, Tq, 1].
                q_mask = tf.expand_dims(q_mask, axis: -1);
                result *= tf.cast(q_mask, dtype: result.dtype);
            }
            if (return_attention_scores)
                return new Tensors(result, attention_scores);
            return result;
        }

        /// <summary>
        /// Propagates the query mask: the output mask is the query mask (if any).
        /// </summary>
        public Tensor compute_mask(Tensors inputs, Tensors mask = null)
        {
            this._validate_call_args(inputs: inputs, mask: mask);
            if (mask != null)
            {
                var q_mask = mask[0];
                if (q_mask == null)
                    return null;
                return tf.convert_to_tensor(q_mask);
            }
            return null;
        }

        //public Shape compute_output_shape(Shape input_shape) {
        //    // return_attention_scores argument of BaseDenseAttention.call method
        //    // is ignored. Output shape of attention_scores cannot be returned.
        //    return input_shape[0];
        //}

        /// <summary>
        /// Validates arguments of the call method.
        /// </summary>
        /// <exception cref="ValueError">
        /// Thrown when `inputs` is not of length 2 or 3, or when `mask` has an
        /// invalid length relative to `inputs`.
        /// </exception>
        public void _validate_call_args(Tensors inputs, Tensors mask)
        {
            if (inputs.Count() < 2 || inputs.Count() > 3)
                throw new ValueError(
                    $"{this.name} layer accepts inputs list of length 2 or 3, " +
                    $"namely [query, value] or [query, value, key]. Received length: {len(inputs)}.");
            if (mask != null)
                if (mask.Count() < 2 || mask.Count() > inputs.Count())
                    throw new ValueError($"{this.name} layer mask must be a list of length 2, " +
                        $"namely [query_mask, value_mask]. Received length: {len(mask)}.");
        }

        /// <summary>
        /// Builds a boolean lower-triangular mask of the given shape:
        /// true where row index >= column index (along the last two axes).
        /// </summary>
        public static Tensor _lower_triangular_mask(Shape shape)
        {
            // cumsum of ones gives 1-based row/column indices along each axis.
            var row_index = tf.cumsum(tf.ones(shape: shape, dtype: tf.int32), axis: -2);
            var col_index = tf.cumsum(tf.ones(shape: shape, dtype: tf.int32), axis: -1);
            return tf.greater_equal(row_index, col_index);
        }

        /// <summary>
        /// Combines two boolean masks with logical AND; a null mask is treated as
        /// "no constraint" (the other mask is returned unchanged).
        /// </summary>
        public static Tensor _merge_masks(Tensor x, Tensor y)
        {
            if (x == null)
                return y;
            if (y == null)
                return x;
            return tf.logical_and(x, y);
        }

        public override LayerArgs get_config() => this.args;
    }
}