using Tensorflow.Keras.Engine;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System;
using System.Collections.Generic;
using System.Linq;
//
// Base class for attention layers that can be used in sequence DNN/CNN models.
// This file follows the terminology of https://arxiv.org/abs/1706.03762 (Figure 2).
// Attention is formed by three tensors: Query, Key and Value.
//
namespace Tensorflow.Keras.Layers
{
    /// <summary>
    /// Base Attention class for Dense networks.
    /// This class is suitable for Dense or CNN networks, and not for RNN networks.
    /// Implementations of attention mechanisms should inherit from this class,
    /// override <c>_calculate_scores()</c>, and reuse the <c>_apply_scores()</c> method.
    /// </summary>
    public class BaseDenseAttention : Layer
    {
        BaseDenseAttentionArgs args;

        /// <summary>When true, a lower-triangular mask is merged into the score mask in <c>call</c>.</summary>
        bool causal { get => args.causal; }

        /// <summary>Dropout rate applied to the softmaxed attention weights in training mode.</summary>
        float dropout { get => args.dropout; }

        protected bool supports_masking;

        public BaseDenseAttention(BaseDenseAttentionArgs args) : base(args)
        {
            this.args = args;
            this.supports_masking = true;
        }

        /// <summary>
        /// Calculates attention scores. Subclasses must override this; the base
        /// implementation throws <see cref="NotImplementedException"/>.
        /// </summary>
        /// <param name="query">Query tensor of shape <c>[batch_size, Tq, dim]</c>.</param>
        /// <param name="key">Key tensor of shape <c>[batch_size, Tv, dim]</c>.</param>
        /// <returns>Tensor of shape <c>[batch_size, Tq, Tv]</c>.</returns>
        public virtual Tensor _calculate_scores(Tensor query, Tensor key) =>
            throw new NotImplementedException("");

        /// <summary>
        /// Applies attention scores to the given value tensor.
        /// To use this method in your attention layer, follow the steps:
        /// <para>1. Use the <c>query</c> tensor of shape <c>[batch_size, Tq, dim]</c> and the
        /// <c>key</c> tensor of shape <c>[batch_size, Tv, dim]</c> to calculate the attention <c>scores</c>.</para>
        /// <para>2. Pass <c>scores</c> and <c>value</c> tensors to this method. The method applies
        /// <c>scores_mask</c>, calculates <c>attention_distribution = softmax(scores)</c>, then
        /// returns <c>matmul(attention_distribution, value)</c>.</para>
        /// <para>3. Apply <c>query_mask</c> and return the result.</para>
        /// </summary>
        /// <param name="scores">Scores float tensor of shape <c>[batch_size, Tq, Tv]</c>.</param>
        /// <param name="value">Value tensor of shape <c>[batch_size, Tv, dim]</c>.</param>
        /// <param name="scores_mask">
        /// Optional boolean mask of shape <c>[batch_size, 1, Tv]</c> or <c>[batch_size, Tq, Tv]</c>.
        /// Scores at positions where <c>scores_mask == false</c> receive a large negative bias
        /// so they do not contribute to the result. It must contain at least one <c>true</c>
        /// value in each line along the last dimension.
        /// </param>
        /// <param name="training">
        /// Whether the layer should behave in training mode (dropout applied to the weights)
        /// or in inference mode (no dropout). A null value is currently treated as false.
        /// </param>
        /// <returns>
        /// A tuple of (result of shape <c>[batch_size, Tq, dim]</c>, attention weights after
        /// masking and softmax with shape <c>[batch_size, Tq, Tv]</c>).
        /// </returns>
        public (Tensor, Tensor) _apply_scores(Tensor scores,
            Tensor value,
            Tensor scores_mask = null,
            bool? training = null)
        {
            if (scores_mask != null)
            {
                var padding_mask = tf.logical_not(scores_mask);
                // Bias so padding positions do not contribute to attention distribution.
                // Note 65504. is the max float16 value.
                if (scores.dtype == tf.float16)
                    scores -= 65504f * tf.cast(padding_mask, dtype: scores.dtype);
                else
                    scores -= 1000000000f * tf.cast(padding_mask, dtype: scores.dtype);
            }

            bool _training;
            training ??= false; // TODO: Delete this line when backend.learning_phase is available
            if (training == null)
                _training = keras.backend.learning_phase() ==
                    Tensorflow.Keras.GraphLearningPhase.train_mode;
            else
                _training = training.Value;

            var weights = tf.nn.softmax(scores);
            Func<Tensor> dropped_weights = () => tf.nn.dropout(weights, rate: this.dropout);
            weights = Tensorflow.Framework.smart_module.smart_cond(_training, dropped_weights, () => tf.identity(weights));
            //return (tf.matmul(weights, value), weights);
            return (tf.linalg.einsum("bij,bjk->bik", (weights, value)), weights);
        }

        /// <summary>
        /// Unpacks the flat input list and forwards it to <c>call</c>.
        /// Accepted tensor layouts:
        /// <para>2: [query, value]</para>
        /// <para>3: [query, value, key]</para>
        /// <para>4: [query, value, query_mask, value_mask]</para>
        /// <para>5: [query, value, key, query_mask, value_mask]</para>
        /// Any layout may be followed by a scalar boolean tensor holding
        /// <c>return_attention_scores</c>. <paramref name="state"/> is not used.
        /// </summary>
        protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
        {
            Tensors _inp;
            Tensors _mask = null;
            int count = inputs.Count();
            if (count < 2 || count > 6) throw new ValueError(
                $"{this.name} layer accepts inputs list of length from 2 to 5, " +
                $"namely [query, value, (key), (query_mask), (value_mask)], " +
                $"optionally followed by a scalar boolean return_attention_scores. " +
                $"Received length: {count}.");

            // Only a trailing *scalar* boolean is the return_attention_scores flag.
            // Boolean masks have rank >= 1 and must not be consumed as the flag
            // (the explicit bool conversion below also requires a scalar).
            var last = inputs[count - 1];
            bool has_bool = last.dtype == TF_DataType.TF_BOOL && last.shape.ndim == 0;
            bool return_attention_scores = false;
            if (has_bool)
            {
                return_attention_scores = (bool)last;
                count--;
            }

            switch (count)
            {
                case 2:
                    // [query, value]
                    _inp = (inputs[0], inputs[1]);
                    break;
                case 3:
                    // [query, value, key]
                    _inp = new[] { inputs[0], inputs[1], inputs[2] };
                    break;
                case 4:
                    // [query, value, query_mask, value_mask]
                    // NOTE(review): masks are documented elsewhere as [batch_size, T] while
                    // query/value are [batch_size, T, dim]; confirm this equality check
                    // matches how callers actually pass masks.
                    if (inputs[0].shape == inputs[2].shape && inputs[1].shape == inputs[3].shape)
                    {
                        _inp = new[] { inputs[0], inputs[1] };
                        _mask = new[] { inputs[2], inputs[3] };
                        break;
                    }
                    throw new ValueError(
                        $"{this.name} layer received 4 inputs, interpreted as " +
                        $"[query, value, query_mask, value_mask], but their shapes are inconsistent: " +
                        $"query {inputs[0].shape} vs query_mask {inputs[2].shape}, " +
                        $"value {inputs[1].shape} vs value_mask {inputs[3].shape}.");
                case 5:
                    // [query, value, key, query_mask, value_mask]
                    _inp = new[] { inputs[0], inputs[1], inputs[2] };
                    _mask = (inputs[3], inputs[4]);
                    break;
                default:
                    throw new ValueError(
                        $"{this.name} layer expects 2 to 5 tensor inputs " +
                        $"(excluding the optional return_attention_scores flag). Received {count}.");
            }
            return call(_inp, _mask, training, return_attention_scores);
        }

        /// <summary>
        /// Computes attention over [query, value, (key)] with optional [query_mask, value_mask].
        /// When no key is given, value is used as key.
        /// </summary>
        /// <returns>The attention result, plus the attention weights when
        /// <paramref name="return_attention_scores"/> is true.</returns>
        protected Tensors call(Tensors inputs, Tensors mask = null, bool? training = null, bool return_attention_scores = false)
        {
            //this._validate_call_args(inputs: inputs, mask: mask);
            var q = inputs[0];
            var v = inputs[1];
            var k = inputs.Count() > 2 ? inputs[2] : v;
            var q_mask = mask != null ? mask[0] : null;
            var v_mask = mask != null ? mask[1] : null;
            var scores = this._calculate_scores(query: q, key: k);
            if (v_mask != null)
                // Mask of shape [batch_size, 1, Tv].
                v_mask = tf.expand_dims(v_mask, axis: -2);

            Tensor causal_mask = null;
            if (this.causal)
            {
                // Creates a lower triangular mask, so position i cannot attend to
                // positions j>i. This prevents the flow of information from the future
                // into the past.
                // The mask shape is [1, ..., 1, Tq, Tv]: a 1 for every leading dimension of
                // scores followed by its last two dimensions (i.e. the equivalent of
                // concat([ones_like(shape[:-2]), shape[-2:]])). It is built from the concrete
                // shape values, because the mask helper takes a static Shape.
                int[] scores_dims = tf.shape(scores).ToArray<int>();
                int leading = scores_dims.Length - 2;
                int[] causal_mask_dims = Enumerable.Repeat(1, leading)
                    .Concat(scores_dims.Skip(leading))
                    .ToArray();
                causal_mask = _lower_triangular_mask(new Shape(causal_mask_dims));
            }

            var scores_mask = _merge_masks(v_mask, causal_mask);
            var (result, attention_scores) = this._apply_scores(scores: scores, value: v, scores_mask: scores_mask, training: training);
            if (q_mask != null)
            {
                // Mask of shape [batch_size, Tq, 1].
                q_mask = tf.expand_dims(q_mask, axis: -1);
                result *= tf.cast(q_mask, dtype: result.dtype);
            }
            if (return_attention_scores)
                return new Tensors(result, attention_scores);
            return result;
        }

        /// <summary>
        /// Validates the arguments, then returns the query mask (as a tensor) or null
        /// when no mask / no query mask is provided.
        /// </summary>
        public Tensor compute_mask(Tensors inputs, Tensors mask = null)
        {
            this._validate_call_args(inputs: inputs, mask: mask);
            if (mask != null)
            {
                var q_mask = mask[0];
                if (q_mask == null)
                    return null;
                return tf.convert_to_tensor(q_mask);
            }
            return null;
        }

        //public Shape compute_output_shape(Shape input_shape) {
        //    // return_attention_scores argument of BaseDenseAttention.call method
        //    // is ignored. Output shape of attention_scores cannot be returned.
        //    return input_shape[0];
        //}

        /// <summary>
        /// Validates arguments of the call method.
        /// </summary>
        /// <exception cref="ValueError">
        /// If <paramref name="inputs"/> does not have 2 or 3 entries, or <paramref name="mask"/>
        /// has fewer than 2 entries or more entries than <paramref name="inputs"/>.
        /// </exception>
        public void _validate_call_args(Tensors inputs, Tensors mask)
        {
            if (inputs.Count() < 2 || inputs.Count() > 3)
                throw new ValueError(
                    $"{this.name} layer accepts inputs list of length 2 or 3, " +
                    $"namely [query, value] or [query, value, key]. Received length: {len(inputs)}.");
            if (mask != null)
                if (mask.Count() < 2 || mask.Count() > inputs.Count())
                    throw new ValueError($"{this.name} layer mask must be a list of length 2, " +
                        $"namely [query_mask, value_mask]. Received length: {len(mask)}.");
        }

        /// <summary>
        /// Builds a boolean tensor of the given shape whose element at [..., i, j] is
        /// true iff i >= j along the last two axes.
        /// </summary>
        public static Tensor _lower_triangular_mask(Shape shape)
        {
            var row_index = tf.cumsum(tf.ones(shape: shape, dtype: tf.int32), axis: -2);
            var col_index = tf.cumsum(tf.ones(shape: shape, dtype: tf.int32), axis: -1);
            return tf.greater_equal(row_index, col_index);
        }

        /// <summary>
        /// Combines two optional boolean masks with logical AND; returns whichever one is
        /// non-null if only one is given, or null if both are null.
        /// </summary>
        public static Tensor _merge_masks(Tensor x, Tensor y)
        {
            if (x == null)
                return y;
            if (y == null)
                return x;
            return tf.logical_and(x, y);
        }

        public override LayerArgs get_config() => this.args;
    }
}