using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Core;
using Tensorflow.Keras.Engine;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System;
using System.Linq;
using Tensorflow.Common.Types;
namespace Tensorflow.Keras.Layers
{
public class MultiHeadAttention : Layer
{
// Alphabet from which einsum subscript letters are drawn when building equations.
static readonly string _CHR_IDX = "abcdefghijklmnopqrstuvwxyz";
// Layer configuration supplied to the constructor.
MultiHeadAttentionArgs args;
// Input shapes recorded by _build_from_signature; the key shape falls back to the value shape there.
Shape _query_shape = null;
Shape _key_shape = null;
Shape _value_shape = null;
// Set to true on entry to _build_from_signature; call() checks it so the sub-layers are built only once.
bool _built_from_signature = false;
// Projection layers created in _build_from_signature.
EinsumDense _query_dense = null;
EinsumDense _key_dense = null;
EinsumDense _value_dense = null;
EinsumDense _output_dense = null;
// Einsum equations and helper layers assigned in _build_attention.
string _dot_product_equation = "";
string _combine_equation = "";
Softmax _softmax = null;
Dropout _dropout_layer = null;
/// <summary>
/// Builds einsum equations for the attention computation.
/// Query, key, value inputs after projection are expected to have the shape as:
/// `(bs, [non-attention dims], [attention dims], num_heads, channels)`.
/// `bs` and `[non-attention dims]` are treated as `[batch dims]`.
/// </summary>
/// <remarks>
/// The attention operations can be generalized:
///
/// (1) Query-key dot product:
/// `([batch dims], [query attention dims], num_heads, channels), ([batch dims],
/// [key attention dims], num_heads, channels) -> ([batch dims],
/// num_heads, [query attention dims], [key attention dims])`
///
/// (2) Combination:
/// `([batch dims], num_heads, [query attention dims], [key attention dims]),
/// ([batch dims], [value attention dims], num_heads, channels) -> ([batch dims],
/// [query attention dims], num_heads, channels)`
///
/// NOTE(review): the axes are used directly as string indices below, so negative
/// axes are not supported as written even though the parameter description
/// mentions `[-1, rank)` - confirm callers only pass non-negative axes.
/// </remarks>
/// <param name="rank">Rank of query, key, value tensors.</param>
/// <param name="attn_axes">List/tuple of axes, `[-1, rank)`,
/// that attention will be applied to.</param>
/// <returns>
/// The dot-product einsum equation, the combine einsum equation, and the rank
/// of the attention scores (the length of the product subscript string).
/// </returns>
public static (string, string, int) _build_attention_equation(int rank, Shape attn_axes)
{
    var target_notation = _CHR_IDX.Substring(0, rank);
    // Read the attention axes once; they are needed three times below.
    var attention_dims = attn_axes.as_int_list();
    // `batch_dims` includes the head dim.
    // Python reference: batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
    // range(rank) yields (0, 1, 2 ...), so every value equals its index and
    // Except can stand in for np.delete, which is unavailable here.
    // Materialized with ToArray because it is probed on every loop iteration;
    // a deferred query would re-run Except for each probe.
    var batch_dims = range(rank)
        .Except(attention_dims.concat(new[] { rank - 1 }))
        .ToArray();
    var letter_offset = rank;
    var source_notation = "";
    for (int i = 0; i < rank; i++)
    {
        if (batch_dims.Contains(i) || i == rank - 1)
        {
            // Batch dims and the channel dim share their letter with the target.
            source_notation += target_notation[i];
        }
        else
        {
            // Attention dims get a fresh letter taken from beyond the target's range.
            source_notation += _CHR_IDX[letter_offset];
            letter_offset += 1;
        }
    }
    // Product layout: batch dims, then the target's attention dims, then the source's.
    var product_notation = new string(
        batch_dims.Select(i => target_notation[i])
            .Concat(attention_dims.Select(i => target_notation[i]))
            .Concat(attention_dims.Select(i => source_notation[i]))
            .ToArray());
    var dot_product_equation = $"{source_notation},{target_notation}->{product_notation}";
    var attn_scores_rank = product_notation.Length;
    var combine_equation = $"{product_notation},{source_notation}->{target_notation}";
    return (dot_product_equation, combine_equation, attn_scores_rank);
}
/// <summary>
/// Builds an einsum equation for projections inside multi-head attention.
/// </summary>
/// <param name="free_dims">Number of leading input axes that are copied unchanged to the output.</param>
/// <param name="bound_dims">Number of input axes shared with the kernel; they do not appear in the output.</param>
/// <param name="output_dims">Number of kernel axes that become the trailing output axes and the bias axes.</param>
/// <returns>
/// The equation `input,kernel->output`, the bias axes subscripts, and the
/// length of the output subscript string.
/// </returns>
public static (string, string, int) _build_proj_equation(int free_dims, int bound_dims, int output_dims)
{
    var input_axes = "";
    var kernel_axes = "";
    var output_axes = "";
    var bias_letters = "";
    // Position of the next unused letter; every axis receives a distinct one.
    var next_letter = 0;

    // Free axes: present in the input and in the output.
    foreach (var unused in range(free_dims))
    {
        var letter = _CHR_IDX[next_letter++];
        input_axes += letter;
        output_axes += letter;
    }

    // Bound axes: present in the input and in the kernel.
    foreach (var unused in range(bound_dims))
    {
        var letter = _CHR_IDX[next_letter++];
        input_axes += letter;
        kernel_axes += letter;
    }

    // Output axes: present in the kernel, the output and the bias.
    foreach (var unused in range(output_dims))
    {
        var letter = _CHR_IDX[next_letter++];
        kernel_axes += letter;
        output_axes += letter;
        bias_letters += letter;
    }

    return ($"{input_axes},{kernel_axes}->{output_axes}", bias_letters, output_axes.Length);
}
// Left-pads `known_last_dims` with -1 entries so the result has `output_rank` dimensions.
static Shape _get_output_shape(int output_rank, Shape known_last_dims)
{
    var leading_unknown = range(output_rank - known_last_dims.rank).Select(unused => -1);
    return leading_unknown.Concat(known_last_dims.as_int_list()).ToArray();
}
/// <summary>
/// Creates the layer and keeps a reference to its configuration.
/// </summary>
/// <param name="args">Layer configuration; also forwarded to the base layer constructor.</param>
public MultiHeadAttention(MultiHeadAttentionArgs args) : base(args)
    => this.args = args;
/// <summary>
/// Builds the sub-layers from the shapes of the given tensors.
/// </summary>
/// <param name="query">Query tensor; only its shape is used.</param>
/// <param name="value">Value tensor; only its shape is used.</param>
/// <param name="key">Optional key tensor; a null key is forwarded as a null shape.</param>
public void _build_from_signature(Tensor query, Tensor value, Tensor key = null)
{
    var query_shape = query.shape;
    var value_shape = value.shape;
    var key_shape = key?.shape;
    this._build_from_signature(query_shape, value_shape, key_shape);
}
/// <summary>
/// Records the input shapes and creates the query/key/value/output projection
/// layers together with the attention equations, softmax and dropout layers.
/// </summary>
/// <param name="query">Shape of the query input.</param>
/// <param name="value">Shape of the value input.</param>
/// <param name="key">Shape of the key input; when null the value shape is used.</param>
public void _build_from_signature(Shape query, Shape value, Shape key = null)
{
    this._built_from_signature = true;
    this._query_shape = query;
    this._value_shape = value;
    this._key_shape = key == null ? this._value_shape : key;

    // Any setup work performed only once should happen in an `init_scope`
    // to avoid creating symbolic Tensors that will later pollute any eager
    // operations.
    tf_with(tf.init_scope(), _ =>
    {
        var query_free_dims = this._query_shape.rank - 1;

        // Query projection.
        var (query_equation, query_bias, query_rank) = _build_proj_equation(
            query_free_dims, bound_dims: 1, output_dims: 2);
        this._query_dense = _get_dense(
            query_equation,
            _get_output_shape(query_rank - 1, (this.args.NumHeads, this.args.KeyDim)),
            this.args.UseBias ? query_bias : null,
            "query");

        // Key projection.
        var (key_equation, key_bias, key_rank) = _build_proj_equation(
            this._key_shape.rank - 1, bound_dims: 1, output_dims: 2);
        this._key_dense = _get_dense(
            key_equation,
            _get_output_shape(key_rank - 1, (this.args.NumHeads, this.args.KeyDim)),
            this.args.UseBias ? key_bias : null,
            "key");

        // Value projection; its last dimension falls back to KeyDim when ValueDim is unset.
        var (value_equation, value_bias, value_rank) = _build_proj_equation(
            this._value_shape.rank - 1, bound_dims: 1, output_dims: 2);
        this._value_dense = _get_dense(
            value_equation,
            _get_output_shape(value_rank - 1,
                (this.args.NumHeads, this.args.ValueDim ?? this.args.KeyDim)),
            this.args.UseBias ? value_bias : null,
            "value");

        // Builds the attention computations for multi-head dot product attention.
        // These computations could be wrapped into the keras attention layer once
        // it supports multi-head einsum computations.
        // The rank passed here is the one produced by the value projection.
        this._build_attention(value_rank);
        this._output_dense = _build_output_dense(query_free_dims, "attention_output");
    });

    this.StackLayers(_query_dense, _key_dense, _value_dense, _output_dense);
}
// Creates an EinsumDense layer for `equation`, taking the initializers,
// regularizers and constraints from this layer's own arguments.
EinsumDense _get_dense(string equation, Shape output_shape, string bias_axes, string name)
{
    var config = this.args;
    var dense_args = new EinsumDenseArgs()
    {
        Equation = equation,
        OutputShape = output_shape,
        BiasAxes = bias_axes,
        Name = name,
        KernelInitializer = config.KernelInitializer,
        BiasInitializer = config.BiasInitializer,
        KernelRegularizer = config.KernelRegularizer,
        BiasRegularizer = config.BiasRegularizer,
        KernelConstraint = config.KernelConstraint,
        BiasConstraint = config.BiasConstraint
    };
    return new EinsumDense(dense_args);
}
// Creates the output projection layer. When no output shape was configured,
// the last dimension of the query shape is stored into the args and used.
EinsumDense _build_output_dense(int free_dims, string name)
{
    if (this.args.OutputShape == null)
    {
        this.args.OutputShape = new(this._query_shape[-1]);
    }

    var (equation, bias, rank) = _build_proj_equation(
        free_dims, bound_dims: 2, output_dims: len(this.args.OutputShape));
    var projected_shape = _get_output_shape(rank - 1, this.args.OutputShape);
    var bias_or_none = this.args.UseBias ? bias : null;
    return _get_dense(equation, projected_shape, bias_or_none, name);
}
// Builds the attention equations plus the softmax and dropout layers.
// When no attention axes were configured, axes 1 .. rank - 3 are stored into the args.
void _build_attention(int rank)
{
    if (this.args.AttentionAxis == null)
    {
        this.args.AttentionAxis = new(range(1, rank - 2).ToArray());
    }

    var (dot_equation, combine_equation, scores_rank) =
        _build_attention_equation(rank, this.args.AttentionAxis);
    this._dot_product_equation = dot_equation;
    this._combine_equation = combine_equation;

    // Softmax normalizes over the trailing len(AttentionAxis) axes of the scores.
    var first_norm_axis = scores_rank - len(this.args.AttentionAxis);
    var norm_axes = range(first_norm_axis, scores_rank).ToArray();
    this._softmax = new Softmax(new SoftmaxArgs { axis = norm_axes });
    this._dropout_layer = new Dropout(new DropoutArgs { Rate = this.args.Dropout });
}
/// <summary>
/// Applies the softmax layer to the attention scores, passing the mask along
/// when one is given.
/// </summary>
/// <param name="attention_scores">Raw attention scores.</param>
/// <param name="attention_mask">
/// Optional mask. When its rank is lower than that of the scores, it is
/// expanded with size-1 axes until the ranks match.
/// </param>
/// <returns>The output of the softmax layer.</returns>
Tensor _masked_softmax(Tensor attention_scores, Tensor attention_mask = null)
{
    if (attention_mask != null)
    {
        // Axis at which the size-1 dims are inserted: just in front of the
        // trailing 2 * len(AttentionAxis) axes of the mask.
        var mask_expansion_axis = -len(this.args.AttentionAxis) * 2 - 1;
        // Compute the number of missing dims once, up front. Each expand_dims
        // raises the mask's rank, so re-reading the rank difference in the loop
        // condition would shrink the bound while `i` grows and stop too early
        // whenever more than one dim is missing.
        var missing_dims = len(attention_scores.shape) - len(attention_mask.shape);
        for (int i = 0; i < missing_dims; i++)
        {
            attention_mask = tf.expand_dims(attention_mask, axis: mask_expansion_axis);
        }
    }
    return this._softmax.Apply(attention_mask == null ? attention_scores : (attention_scores, attention_mask));
}
/// <summary>
/// Applies dot-product attention to already projected query, key and value tensors.
/// </summary>
/// <param name="query">Projected query tensor.</param>
/// <param name="key">Projected key tensor.</param>
/// <param name="value">Projected value tensor.</param>
/// <param name="attention_mask">Optional mask forwarded to the masked softmax.</param>
/// <param name="training">Forwarded to the dropout layer.</param>
/// <returns>The attention output followed by the attention scores taken before dropout.</returns>
public Tensors _compute_attention(
    Tensor query,
    Tensor key,
    Tensor value,
    Tensor attention_mask = null,
    bool training = false)
{
    // Note: Applying scalar multiply at the smaller end of einsum improves
    // XLA performance, but may introduce slight numeric differences in
    // the Transformer attention head.
    var inverse_sqrt_key_dim = 1f / tf.sqrt(tf.convert_to_tensor((float)this.args.KeyDim));
    query = tf.multiply(query, inverse_sqrt_key_dim);

    // Take the dot product between "query" and "key" to get the raw
    // attention scores, then normalize them.
    var scores = tf.linalg.einsum(this._dot_product_equation, (key, query));
    scores = this._masked_softmax(scores, attention_mask);

    // This is actually dropping out entire tokens to attend to, which might
    // seem a bit unusual, but is taken from the original Transformer paper.
    var dropped_scores = this._dropout_layer.Apply(scores, training: training);

    // `context_layer` = [B, T, N, H]
    var context = tf.linalg.einsum(this._combine_equation, (dropped_scores, value));
    return (context, scores);
}
/// <summary>
/// Unpacks the flat input list into query/value/key, the attention mask and
/// the return-scores flag, then delegates to <c>call</c>.
/// </summary>
/// <remarks>
/// Accepted layout: [query, value, (key), (attention_mask), (return_attention_scores)].
/// A trailing tensor of boolean dtype is read as the return-scores flag.
/// With three remaining tensors, the third is taken as the key when its last
/// dimension equals that of the value, and as the attention mask otherwise.
///
/// NOTE(review): because the flag is detected by dtype alone, a boolean
/// attention mask placed last would be read as the flag - confirm that callers
/// never pass one in that position.
/// </remarks>
/// <param name="inputs">Between 2 and 5 tensors in the layout described above.</param>
/// <param name="state">Not used by this layer.</param>
/// <param name="training">Forwarded to <c>call</c>.</param>
/// <param name="optional_args">Not used by this layer.</param>
/// <returns>The attention output, followed by the attention scores when the flag is set.</returns>
/// <exception cref="ValueError">The inputs cannot be unpacked into the layout above.</exception>
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
    Tensors _inp;
    Tensor _mask = null;
    int count = inputs.Count();
    if (count < 2 || count > 5) throw new ValueError(
        $"{ this.name } layer accepts inputs list of length from 2 to 5, " +
        $"namely [query, value, (key), (attention_mask), (return_attention_scores)]." +
        $"Received length: {count}.");
    bool has_bool = inputs[count - 1].dtype == TF_DataType.TF_BOOL;
    bool return_attention_scores = false;
    if (has_bool)
    {
        return_attention_scores = (bool)inputs[count - 1];
        // The flag is consumed; only the leading tensors remain to be unpacked.
        count--;
    }
    switch (count)
    {
        case 2:
            _inp = (inputs[0], inputs[1]);
            break;
        case 3:
            // Decide between key and mask by comparing last dimensions with the value.
            if (inputs[2].shape[-1] == inputs[1].shape[-1])
                _inp = new[] { inputs[0], inputs[1], inputs[2] };
            else
            {
                _inp = (inputs[0], inputs[1]);
                _mask = inputs[2];
            }
            break;
        case 4:
            _inp = new[] { inputs[0], inputs[1], inputs[2] };
            _mask = inputs[3];
            break;
        default:
            // Reached with 1 tensor left after removing the flag, or with 5
            // tensors none of which is the boolean flag.
            throw new ValueError(
                $"{ this.name } layer could not unpack [query, value, (key), (attention_mask)] " +
                $"from {count} tensor input(s); expected 2 to 4 tensors besides the optional " +
                $"boolean return_attention_scores flag.");
    }
    return call(_inp, _mask, training, return_attention_scores);
}
/// <summary>
/// Runs multi-head attention: builds the sub-layers on first use, projects
/// query/key/value, computes attention and applies the output projection.
/// </summary>
/// <param name="inputs">[query, value] or [query, value, key]; the value is used as key when no key is given.</param>
/// <param name="attention_mask">Optional mask forwarded to the attention computation.</param>
/// <param name="training">Forwarded to the attention computation; null is treated as false.</param>
/// <param name="return_attention_scores">When true the attention scores are returned after the output.</param>
/// <returns>The attention output, optionally followed by the attention scores.</returns>
protected Tensors call(Tensors inputs,
    Tensor attention_mask,
    bool? training = null,
    bool return_attention_scores = false)
{
    var query = inputs[0];
    var value = inputs[1];
    var key = inputs.Length == 3 ? inputs[2] : null;

    if (!this._built_from_signature)
    {
        this._build_from_signature(query: query, value: value, key: key);
    }
    if (key == null)
    {
        key = value;
    }

    // TODO: Add RaggedTensor support
    //var query_is_ragged = query is tf.RaggedTensor;
    //if (query_is_ragged)
    //{
    //    var query_lengths = query.nested_row_lengths();
    //    query = query.to_tensor();
    //}
    //var key_is_ragged = key is tf.RaggedTensor;
    //var value_is_ragged = value is tf.RaggedTensor;
    //if (key_is_ragged && value_is_ragged)
    //{
    //    // Ensure they have the same shape.
    //    var bounding_shape = tf.math.maximum(key.bounding_shape(), value.bounding_shape());
    //    key = key.to_tensor(shape: bounding_shape);
    //    value = value.to_tensor(shape: bounding_shape);
    //}
    //else if (key_is_ragged)
    //{
    //    key = key.to_tensor(shape: tf.shape(value));
    //}
    //else if (value_is_ragged)
    //{
    //    value = value.to_tensor(shape: tf.shape(key));
    //}

    // N = `num_attention_heads`
    // H = `size_per_head`
    // `query` = [B, T, N ,H]
    query = this._query_dense.Apply(query);
    // `key` = [B, S, N, H]
    key = this._key_dense.Apply(key);
    // `value` = [B, S, N, H]
    value = this._value_dense.Apply(value);

    bool is_training = training ?? false;
    var (context, scores) = this._compute_attention(query, key, value, attention_mask, is_training);
    context = this._output_dense.Apply(context);

    //if (query_is_ragged)
    //{
    //    attention_output = tf.RaggedTensor.from_tensor(attention_output, lengths: query_lengths);
    //}

    if (!return_attention_scores)
    {
        return context;
    }
    return (context, scores.Single);
}
}
}