@@ -12,7 +12,8 @@ namespace Tensorflow.Keras.Layers {
            axis = args.axis;
        }
        protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
-           Tensor x = inputs;
+           // Additive mask: masked (0) positions get a large negative bias so they vanish after the softmax.
+           Tensor x = inputs.Length == 2 ? inputs + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * -1e9)
+                                         : inputs;
            Tensor e = tf.exp(tf.sub(x, tf.reduce_max(x, axis: this.axis, keepdims: true)));
            Tensor s = tf.reduce_sum(e, axis: this.axis, keepdims: true);
            return tf.div(e, s);
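
The added mask handling relies on the standard additive-mask trick: masked positions receive a large negative bias, so they contribute essentially zero probability after the softmax. A minimal standalone sketch of the idea (illustrative values, not part of the change):

            // Illustrative only: additive masking before a softmax.
            var scores = tf.constant(np.array(new[,] { { 1.0f, 2.0f, 3.0f } }));
            var mask   = tf.constant(np.array(new[,] { { 1.0f, 1.0f, 0.0f } }));  // 0 marks a masked position
            var adder  = (1.0 - mask) * -1e9;                                     // masked logit becomes ~ -1e9
            var probs  = tf.nn.softmax(scores + adder);                           // last probability ~ 0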
@@ -120,7 +120,7 @@ namespace Tensorflow.Keras.Layers
            int count = inputs.Count();
            if (count < 2 || count > 6) throw new ValueError(
-               $"{ this.name } layer accepts inputs list of length from 2 to 5, " +
+               $"{ this.name } layer accepts inputs list of length from 2 to 6, " +
                $"namely [query, value, (key), (query_mask), (value_mask), (return_attention_scores)]." +
                $"Received length: {count}.");
@@ -0,0 +1,355 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System;
using System.Linq;

namespace Tensorflow.Keras.Layers
{
    public class MultiHeadAttention : Layer
    {
        static readonly string _CHR_IDX = "abcdefghijklmnopqrstuvwxyz";
        MultiHeadAttentionArgs args;
        Shape _query_shape = null;
        Shape _key_shape = null;
        Shape _value_shape = null;
        bool _built_from_signature = false;
        EinsumDense _query_dense = null;
        EinsumDense _key_dense = null;
        EinsumDense _value_dense = null;
        EinsumDense _output_dense = null;
        string _dot_product_equation = "";
        string _combine_equation = "";
        Softmax _softmax = null;
        Dropout _dropout_layer = null;
        /// <summary>
        /// Builds einsum equations for the attention computation.
        /// Query, key, value inputs after projection are expected to have shape
        /// `(bs, [non-attention dims], [attention dims], num_heads, channels)`.
        /// `bs` and `[non-attention dims]` are treated as `[batch dims]`.
        ///
        /// <para>
        /// The attention operations can be generalized:
        /// </para>
        /// <para>
        /// (1) Query-key dot product:
        /// `([batch dims], [query attention dims], num_heads, channels), ([batch dims],
        /// [key attention dims], num_heads, channels) -> ([batch dims],
        /// num_heads, [query attention dims], [key attention dims])`
        /// </para><para>
        /// (2) Combination:
        /// `([batch dims], num_heads, [query attention dims], [key attention dims]),
        /// ([batch dims], [value attention dims], num_heads, channels) -> ([batch dims],
        /// [query attention dims], num_heads, channels)`
        /// </para>
        /// </summary>
        /// <param name="rank">Rank of query, key, value tensors.</param>
        /// <param name="attn_axes">List/tuple of axes, `[-1, rank)`,
        /// that attention will be applied to.</param>
        /// <returns></returns>
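        /// <remarks>
        /// For example, with rank = 4 and attn_axes = (1) (the default for projected tensors of
        /// shape `(bs, seq, num_heads, channels)`), this produces the dot-product equation
        /// `aecd,abcd->acbe` and the combine equation `acbe,aecd->abcd`.
        /// </remarks>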
        public static (string, string, int) _build_attention_equation(int rank, Shape attn_axes)
        {
            var target_notation = _CHR_IDX.Substring(0, rank);
            // `batch_dims` includes the head dim.
            // batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
            // Since range(rank) is an IEnumerable (0, 1, 2, ...) whose elements equal their indices,
            // IEnumerable.Except can stand in for np.delete, which is unavailable here.
            var batch_dims = range(rank).Except(attn_axes.as_int_list().concat(new[] { rank - 1 }));
            var letter_offset = rank;
            var source_notation = "";
            for (int i = 0; i < rank; i++)
            {
                if (batch_dims.Contains(i) || i == rank - 1)
                    source_notation += target_notation[i];
                else
                {
                    source_notation += _CHR_IDX[letter_offset];
                    letter_offset += 1;
                }
            }
            var product_notation = new string((from i in batch_dims
                                               select target_notation[i]).Concat(
                                                from i in attn_axes.as_int_list()
                                                select target_notation[i]).Concat(
                                                from i in attn_axes.as_int_list()
                                                select source_notation[i]).ToArray());
            var dot_product_equation = $"{source_notation},{target_notation}->{product_notation}";
            var attn_scores_rank = product_notation.Count();
            var combine_equation = $"{product_notation},{source_notation}->{target_notation}";
            return (dot_product_equation, combine_equation, attn_scores_rank);
        }
        /// <summary>
        /// Builds an einsum equation for projections inside multi-head attention.
        /// </summary>
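        /// <remarks>
        /// For example, `_build_proj_equation(free_dims: 2, bound_dims: 1, output_dims: 2)` returns
        /// the equation `abc,cde->abde` with bias axes `de` and output rank 4, which is the per-head
        /// query/key/value projection used for rank-3 inputs.
        /// </remarks>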
        public static (string, string, int) _build_proj_equation(int free_dims, int bound_dims, int output_dims)
        {
            char _char;
            var input_str = "";
            var kernel_str = "";
            var output_str = "";
            var bias_axes = "";
            var letter_offset = 0;
            foreach (var i in range(free_dims))
            {
                _char = _CHR_IDX[i + letter_offset];
                input_str += _char;
                output_str += _char;
            }
            letter_offset += free_dims;
            foreach (var i in range(bound_dims))
            {
                _char = _CHR_IDX[i + letter_offset];
                input_str += _char;
                kernel_str += _char;
            }
            letter_offset += bound_dims;
            foreach (var i in range(output_dims))
            {
                _char = _CHR_IDX[i + letter_offset];
                kernel_str += _char;
                output_str += _char;
                bias_axes += _char;
            }
            var equation = $"{input_str},{kernel_str}->{output_str}";
            return (equation, bias_axes, output_str.Count());
        }
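        // For example, _get_output_shape(3, (num_heads, key_dim)) yields (-1, num_heads, key_dim):
        // leading dimensions that are not known yet are filled with -1.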
        static Shape _get_output_shape(int output_rank, Shape known_last_dims)
            => (from _ in range(output_rank - known_last_dims.rank)
                select -1).Concat(known_last_dims.as_int_list()).ToArray();
        public MultiHeadAttention(MultiHeadAttentionArgs args) : base(args)
        {
            this.args = args;
        }

        public void _build_from_signature(Tensor query, Tensor value, Tensor key = null)
            => this._build_from_signature(query.shape, value.shape, key?.shape);

        public void _build_from_signature(Shape query, Shape value, Shape key = null)
        {
            this._built_from_signature = true;
            this._query_shape = query;
            this._value_shape = value;
            if (key == null)
                this._key_shape = this._value_shape;
            else
                this._key_shape = key;
            // Any setup work performed only once should happen in an `init_scope`
            // to avoid creating symbolic Tensors that will later pollute any eager
            // operations.
            tf_with(tf.init_scope(), _ =>
            {
                var free_dims = this._query_shape.rank - 1;
                var (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                    free_dims, bound_dims: 1, output_dims: 2);
                this._query_dense = _get_dense(einsum_equation,
                                               _get_output_shape(output_rank - 1,
                                                   (this.args.NumHeads, this.args.KeyDim)),
                                               this.args.UseBias ? bias_axes : null,
                                               "query");
                (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                    this._key_shape.rank - 1, bound_dims: 1, output_dims: 2);
                this._key_dense = _get_dense(einsum_equation,
                                             _get_output_shape(output_rank - 1,
                                                 (this.args.NumHeads, this.args.KeyDim)),
                                             this.args.UseBias ? bias_axes : null,
                                             "key");
                (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                    this._value_shape.rank - 1, bound_dims: 1, output_dims: 2);
                this._value_dense = _get_dense(einsum_equation,
                                               _get_output_shape(output_rank - 1,
                                                   (this.args.NumHeads, this.args.ValueDim ?? this.args.KeyDim)),
                                               this.args.UseBias ? bias_axes : null,
                                               "value");
                // Builds the attention computations for multi-head dot product attention.
                // These computations could be wrapped into the keras attention layer once
                // it supports multi-head einsum computations.
                this._build_attention(output_rank);
                this._output_dense = _build_output_dense(free_dims, "attention_output");
            });
            this.StackLayers(_query_dense, _key_dense, _value_dense, _output_dense);
        }
        EinsumDense _get_dense(string equation, Shape output_shape, string bias_axes, string name)
            => new EinsumDense(new EinsumDenseArgs()
            {
                Equation = equation,
                OutputShape = output_shape,
                BiasAxes = bias_axes,
                Name = name,
                KernelInitializer = this.args.KernelInitializer,
                BiasInitializer = this.args.BiasInitializer,
                KernelRegularizer = this.args.KernelRegularizer,
                BiasRegularizer = this.args.BiasRegularizer,
                KernelConstraint = this.args.KernelConstraint,
                BiasConstraint = this.args.BiasConstraint
            });

        EinsumDense _build_output_dense(int free_dims, string name)
        {
            if (this.args.OutputShape == null) this.args.OutputShape = new(this._query_shape[-1]);
            var (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                free_dims, bound_dims: 2, output_dims: len(this.args.OutputShape));
            return _get_dense(einsum_equation,
                              _get_output_shape(output_rank - 1, this.args.OutputShape),
                              this.args.UseBias ? bias_axes : null,
                              name);
        }

        void _build_attention(int rank)
        {
            if (this.args.AttentionAxis == null)
                this.args.AttentionAxis = new(range(1, rank - 2).ToArray());
            int attn_scores_rank;
            (this._dot_product_equation, this._combine_equation, attn_scores_rank)
                = _build_attention_equation(rank, this.args.AttentionAxis);
            var norm_axes = range(attn_scores_rank - len(this.args.AttentionAxis),
                                  attn_scores_rank).ToArray();
            this._softmax = new Softmax(new SoftmaxArgs { axis = norm_axes });
            this._dropout_layer = new Dropout(new DropoutArgs { Rate = this.args.Dropout });
        }
        Tensor _masked_softmax(Tensor attention_scores, Tensor attention_mask = null)
        {
            if (attention_mask != null)
            {
                var mask_expansion_axis = -len(this.args.AttentionAxis) * 2 - 1;
                for (int i = 0; i < len(attention_scores.shape) - len(attention_mask.shape); i++)
                    attention_mask = tf.expand_dims(attention_mask, axis: mask_expansion_axis);
            }
            return this._softmax.Apply(attention_mask == null ? attention_scores : (attention_scores, attention_mask));
        }
        public Tensors _compute_attention(
            Tensor query,
            Tensor key,
            Tensor value,
            Tensor attention_mask = null,
            bool training = false)
        {
            // Note: Applying scalar multiply at the smaller end of einsum improves
            // XLA performance, but may introduce slight numeric differences in
            // the Transformer attention head.
            query = tf.multiply(query, 1f / tf.sqrt(tf.convert_to_tensor((float)this.args.KeyDim)));
            // Take the dot product between "query" and "key" to get the raw
            // attention scores.
            var attention_scores = tf.linalg.einsum(this._dot_product_equation, (key, query));
            attention_scores = this._masked_softmax(attention_scores, attention_mask);
            // This is actually dropping out entire tokens to attend to, which might
            // seem a bit unusual, but is taken from the original Transformer paper.
            var attention_scores_dropout = this._dropout_layer.Apply(attention_scores, training: training);
            // `context_layer` = [B, T, N, H]
            var attention_output = tf.linalg.einsum(this._combine_equation, (attention_scores_dropout, value));
            return (attention_output, attention_scores);
        }
        protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
        {
            Tensors _inp;
            Tensor _mask = null;
            int count = inputs.Count();
            if (count < 2 || count > 5) throw new ValueError(
                $"{ this.name } layer accepts inputs list of length from 2 to 5, " +
                $"namely [query, value, (key), (attention_mask), (return_attention_scores)]. " +
                $"Received length: {count}.");
            bool has_bool = inputs[count - 1].dtype == TF_DataType.TF_BOOL;
            bool return_attention_scores = false;
            if (has_bool)
            {
                return_attention_scores = (bool)inputs[count - 1];
                count--;
            }
            switch (count)
            {
                case 2:
                    _inp = (inputs[0], inputs[1]);
                    break;
                case 3:
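                    // If the third tensor's last dimension matches `value`'s, treat it as `key`;
                    // otherwise treat it as the attention mask.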
                    if (inputs[2].shape[-1] == inputs[1].shape[-1])
                        _inp = new[] { inputs[0], inputs[1], inputs[2] };
                    else
                    {
                        _inp = (inputs[0], inputs[1]);
                        _mask = inputs[2];
                    }
                    break;
                case 4:
                    _inp = new[] { inputs[0], inputs[1], inputs[2] };
                    _mask = inputs[3];
                    break;
                default:
                    throw new ValueError(); // TODO: Add a description for this error.
            }
            return call(_inp, _mask, training, return_attention_scores);
        }
        protected Tensors call(Tensors inputs,
                               Tensor attention_mask,
                               bool? training = null,
                               bool return_attention_scores = false)
        {
            var (query, value, key) = (inputs[0], inputs[1], inputs.Length == 3 ? inputs[2] : null);
            if (!this._built_from_signature)
                this._build_from_signature(query: query, value: value, key: key);
            if (key == null)
                key = value;

            // TODO: Add RaggedTensor support
            //var query_is_ragged = query is tf.RaggedTensor;
            //if (query_is_ragged)
            //{
            //    var query_lengths = query.nested_row_lengths();
            //    query = query.to_tensor();
            //}
            //var key_is_ragged = key is tf.RaggedTensor;
            //var value_is_ragged = value is tf.RaggedTensor;
            //if (key_is_ragged && value_is_ragged)
            //{
            //    // Ensure they have the same shape.
            //    var bounding_shape = tf.math.maximum(key.bounding_shape(), value.bounding_shape());
            //    key = key.to_tensor(shape: bounding_shape);
            //    value = value.to_tensor(shape: bounding_shape);
            //}
            //else if (key_is_ragged)
            //{
            //    key = key.to_tensor(shape: tf.shape(value));
            //}
            //else if (value_is_ragged)
            //{
            //    value = value.to_tensor(shape: tf.shape(key));
            //}

            // N = `num_attention_heads`
            // H = `size_per_head`
            // `query` = [B, T, N, H]
            query = this._query_dense.Apply(query);
            // `key` = [B, S, N, H]
            key = this._key_dense.Apply(key);
            // `value` = [B, S, N, H]
            value = this._value_dense.Apply(value);
            var (attention_output, attention_scores) = this._compute_attention(query, key, value, attention_mask, training ?? false);
            attention_output = this._output_dense.Apply(attention_output);
            //if (query_is_ragged)
            //{
            //    attention_output = tf.RaggedTensor.from_tensor(attention_output, lengths: query_lengths);
            //}
            if (return_attention_scores)
                return (attention_output, attention_scores);
            return attention_output;
        }
    }
}
@@ -228,7 +228,7 @@ namespace Tensorflow.Keras.Layers
            Shape output_shape,
            bool left_elided = false)
        {
-           List<long> bias_shape;
+           List<int> bias_shape;
            Dictionary<char, int> output_dim_map;
            Dictionary<char, int> input_dim_map;
@@ -275,8 +275,8 @@ namespace Tensorflow.Keras.Layers
                var input_shape_at_dim = input_shape[input_dim_map[dim]];
                if (output_dim_map.TryGetValue(dim, out int index))
                {
-                   var output_shape_at_dim = output_shape[index];
-                   if (output_shape_at_dim != input_shape_at_dim)
+                   var output_shape_at_dim = _output_shape[index];
+                   if (output_shape_at_dim != -1 && output_shape_at_dim != input_shape_at_dim)
                        throw new ValueError($"Input shape and output shape do not match at shared dimension '{dim}'. " +
                            $"Input shape is {input_shape_at_dim}, " +
                            $"and output shape is {output_shape[output_dim_map[dim]]}.");
@@ -299,7 +299,7 @@ namespace Tensorflow.Keras.Layers
                if (input_dim_map.ContainsKey(dim))
                    weight_shape.append(input_shape[input_dim_map[dim]]);
                else if (output_dim_map.ContainsKey(dim))
-                   weight_shape.append(output_shape[output_dim_map[dim]]);
+                   weight_shape.append(_output_shape[output_dim_map[dim]]);
                else throw new ValueError($"Weight dimension '{dim}' did not have a match in " +
                    $"either the input spec '{input_spec}' " +
                    $"or the output spec '{output_spec}'. " +
@@ -310,7 +310,7 @@ namespace Tensorflow.Keras.Layers
            {
                var num_left_elided = left_elided ? elided : 0;
                var idx_map = output_spec.Select((_char, i) => (i, _char))
-                   .ToDictionary(_ => _._char, _ => output_shape[_.i + num_left_elided]);
+                   .ToDictionary(_ => _._char, _ => _output_shape[_.i + num_left_elided]);
                foreach (var _char in bias_axes)
                    if (!output_spec.Contains(_char))
                        throw new ValueError($"Bias dimension '{_char}' was requested," +
@@ -327,7 +327,7 @@ namespace Tensorflow.Keras.Layers
            else bias_shape = null;
            return (weight_shape.ToArray(),
-               (bias_shape ?? new List<long>()).ToArray(),
+               (bias_shape ?? new List<int>()).ToArray(),
                _output_shape.ToArray());
        }
    }
@@ -21,5 +21,36 @@ namespace Tensorflow.Keras.Layers
                causal = causal,
                dropout = dropout
            });
+       public MultiHeadAttention MultiHeadAttention(int num_heads,
+           int key_dim,
+           int? value_dim = null,
+           float dropout = 0f,
+           bool use_bias = true,
+           Shape output_shape = null,
+           Shape attention_axes = null,
+           IInitializer kernel_initializer = null,
+           IInitializer bias_initializer = null,
+           IRegularizer kernel_regularizer = null,
+           IRegularizer bias_regularizer = null,
+           IRegularizer activity_regularizer = null,
+           Action kernel_constraint = null,
+           Action bias_constraint = null) =>
+               new MultiHeadAttention(new MultiHeadAttentionArgs
+               {
+                   NumHeads = num_heads,
+                   KeyDim = key_dim,
+                   ValueDim = value_dim,
+                   Dropout = dropout,
+                   UseBias = use_bias,
+                   OutputShape = output_shape,
+                   AttentionAxis = attention_axes,
+                   KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer,
+                   BiasInitializer = bias_initializer ?? tf.zeros_initializer,
+                   KernelRegularizer = kernel_regularizer,
+                   BiasRegularizer = bias_regularizer,
+                   ActivityRegularizer = activity_regularizer,
+                   KernelConstraint = kernel_constraint,
+                   BiasConstraint = bias_constraint,
+               });
    }
}
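
With the factory above, the layer is reachable through the usual keras.layers API. A minimal usage sketch (shapes chosen only for illustration):

    var layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
    var query = keras.Input(shape: (8, 16));
    var value = keras.Input(shape: (4, 16));
    // The output should come out as (batch, 8, 16): OutputShape defaults to the query's last dimension.
    var output = layer.Apply(new[] { query, value });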
@@ -15,45 +15,6 @@ namespace TensorFlowNET.Keras.UnitTest
    public class AttentionTest : EagerModeTestBase
    {
        #region BaseDenseAttention
-       [TestMethod]
-       public void test_one_dim_with_mask()
-       {
-           // Scores tensor of shape [1, 1, 1]
-           var scores = np.array(new[, ,] { { { 1.1f } } }, dtype: np.float32);
-           // Value tensor of shape [1, 1, 1]
-           var v = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
-           // Scores mask tensor of shape [1, 1, 1]
-           var scores_mask = np.array(new[, ,] { { { true } } }, dtype: np.@bool);
-           var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v, scores_mask: scores_mask);
-           var actual = _tup_1.Item1;
-           var actual_scores = _tup_1.Item2;
-           // Expected softmax_scores = [[[1]]]
-           var expected_scores = np.array(new[, ,] { { { 1f } } }, dtype: np.float32);
-           Assert.AreEqual(expected_scores, actual_scores.numpy());
-           // Expected tensor of shape [1, 1, 1].
-           // expected000 = softmax_scores[0, 0] * 1.6 = 1.6
-           var expected = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
-           Assert.AreEqual(expected, actual.numpy());
-       }
-       [TestMethod]
-       public void test_one_dim_no_mask()
-       {
-           // Scores tensor of shape [1, 1, 1]
-           var scores = np.array(new[, ,] { { { 1.1f } } }, dtype: np.float32);
-           // Value tensor of shape [1, 1, 1]
-           var v = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
-           var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v);
-           var actual = _tup_1.Item1;
-           var actual_scores = _tup_1.Item2;
-           // Expected softmax_scores = [[[1]]]
-           var expected_scores = np.array(new[, ,] { { { 1f } } }, dtype: np.float32);
-           Assert.AreEqual(expected_scores, actual_scores.numpy());
-           // Expected tensor of shape [1, 1, 1].
-           // expected000 = softmax_scores[0, 0] * 1.6 = 1.6
-           var expected = np.array(new[, ,] { { { 1.6f } } }, dtype: np.float32);
-           Assert.AreEqual(expected, actual.numpy());
-       }
        [TestMethod]
        public void test_multi_dim_with_mask()
@@ -81,35 +42,6 @@ namespace TensorFlowNET.Keras.UnitTest
            var expected = np.array(new[, ,] { { { 1.3579528f } } }, dtype: np.float32);
            Assert.AreEqual(expected, actual.numpy());
        }
-       [TestMethod]
-       public void test_multi_dim_no_mask()
-       {
-           // Scores tensor of shape [1, 1, 3]
-           var scores = np.array(new[, ,] { { { 1f, 0f, 1f } } }, dtype: np.float32);
-           // Value tensor of shape [1, 3, 1]
-           var v = np.array(new[, ,] { { { 1.6f }, { 0.7f }, { -0.8f } } }, dtype: np.float32);
-           var _tup_1 = new BaseDenseAttention(new())._apply_scores(scores: scores, value: v);
-           var actual = _tup_1.Item1;
-           var actual_scores = _tup_1.Item2;
-           // Expected softmax_scores = softmax(scores).
-           // => softmax_scores000 = exp(1)/(exp(1) + exp(0) + exp(1))
-           //                      = 0.42231879825
-           //    softmax_scores001 = exp(0)/(exp(1) + exp(0) + exp(1))
-           //                      = 0.15536240349
-           //    softmax_scores002 = exp(1)/(exp(1) + exp(0) + exp(1))
-           //                      = 0.42231879825
-           // Actually the output is 0.42231882, 0.15536241, 0.42231882
-           var expected_scores = np.array(new[, ,] { { { 0.42231882f, 0.15536241f, 0.42231882f } } }, dtype: np.float32);
-           Assert.AreEqual(expected_scores, actual_scores.numpy());
-           // Expected tensor of shape [1, 1, 1].
-           // expected000 = 0.42231879825 * 1.6 + 0.15536240349 * 0.7
-           //             - 0.42231879825 * 0.8
-           //             = 0.44660872104
-           // Actually the output is 0.44660875
-           var expected = np.array(new[, ,] { { { 0.44660875f } } }, dtype: np.float32);
-           Assert.AreEqual(expected, actual.numpy());
-       }
        [TestMethod]
        public void test_one_dim_batch_size_two()
@@ -132,101 +64,10 @@ namespace TensorFlowNET.Keras.UnitTest
            var expected = np.array(new[, ,] { { { 1.6f } }, { { 2.6f } } }, dtype: np.float32);
            Assert.AreEqual(expected, actual.numpy());
        }
-       [TestMethod]
-       public void test_shape_with_dropout()
-       {
-           // scores: Scores float tensor of shape `[batch_size, tq, tv]`.
-           // value: Value tensor of shape `[batch_size, tv, dim]`.
-           var batch_size = 4;
-           var tq = 5;
-           var tv = 6;
-           var dim = 7;
-           var scores = np.ones((batch_size, tq, tv));
-           var value = np.ones((batch_size, tv, dim));
-           var _tup_1 = new BaseDenseAttention(new BaseDenseAttentionArgs { dropout = 0.1f })
-               ._apply_scores(scores: scores, value: value, training: false);
-           var actual = _tup_1.Item1;
-           var actual_scores = _tup_1.Item2;
-           // Expected Tensor of shape `[batch_size, tq, tv]`.
-           var expected_scores_shape = new[] {
-               batch_size,
-               tq,
-               tv
-           };
-           Assert.AreEqual(expected_scores_shape, tf.shape(actual_scores).numpy());
-           // Expected Tensor of shape `[batch_size, tq, dim]`.
-           var expected_shape = new[] {
-               batch_size,
-               tq,
-               dim
-           };
-           Assert.AreEqual(expected_shape, tf.shape(actual).numpy());
-       }
        #endregion
        // ------------------------------------------------------------------
        #region Attention
-       [TestMethod]
-       public void test_example()
-       {
-           // Variable-length int sequences.
-           var query_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
-           var value_input = keras.Input((1000), dtype: TF_DataType.TF_INT32);
-           // Embedding lookup.
-           var token_embedding = keras.layers.Embedding(input_dim: 1000, output_dim: 64);
-           // Query embeddings of shape [batch_size, Tq, dimension].
-           var query_embeddings = token_embedding.Apply(query_input);
-           // Value embeddings of shape [batch_size, Tv, dimension].
-           var value_embeddings = token_embedding.Apply(value_input);
-           // CNN layer.
-           var cnn_layer = keras.layers.Conv1D(
-               filters: 100,
-               kernel_size: 4,
-               // Use 'same' padding so outputs have the same shape as inputs.
-               padding: "same",
-               activation: "relu");
-           var cnn_layer2 = keras.layers.Conv1D(
-               filters: 100,
-               kernel_size: 4,
-               // Use 'same' padding so outputs have the same shape as inputs.
-               padding: "same",
-               activation: "relu");
-           // Query encoding of shape [batch_size, Tq, filters].
-           var query_seq_encoding = cnn_layer.Apply(query_embeddings);
-           // Value encoding of shape [batch_size, Tv, filters].
-           var value_seq_encoding = cnn_layer2.Apply(value_embeddings);
-           // Query-value attention of shape [batch_size, Tq, filters].
-           var query_value_attention_seq = keras.layers.Attention().Apply(
-               (query_seq_encoding, value_seq_encoding));
-           // Reduce over the sequence axis to produce encodings of shape
-           // [batch_size, filters].
-           var query_encoding = keras.layers.GlobalAveragePooling1D().Apply(
-               query_seq_encoding);
-           var query_value_attention = keras.layers.GlobalAveragePooling1D().Apply(
-               query_value_attention_seq);
-           // Concatenate query and document encodings to produce a DNN input layer.
-           var input_layer = keras.layers.Concatenate().Apply(
-               (query_encoding, query_value_attention));
-           // Add DNN layers, and create Model.
-           // ...
-       }
-       [TestMethod]
-       public void test_calculate_scores_one_dim()
-       {
-           // Query tensor of shape [1, 1, 1]
-           var q = np.array(new[,,] { { { 1.1f } } }, dtype: np.float32);
-           // Key tensor of shape [1, 1, 1]
-           var k = np.array(new[,,] { { { 1.6f } } }, dtype: np.float32);
-           var attention_layer = keras.layers.Attention();
-           //attention_layer.build((1));
-           var actual = attention_layer._calculate_scores(query: q, key: k);
-           // Expected tensor of shape [1, 1, 1].
-           // expected000 = 1.1*1.6 = 1.76
-           // Actually the output is 1.7600001
-           var expected = np.array(new[,,] { { { 1.7600001f } } }, dtype: np.float32);
-           Assert.AreEqual(expected, actual.numpy());
-       }
        [TestMethod]
        public void test_calculate_scores_multi_dim()
@@ -305,6 +146,31 @@ namespace TensorFlowNET.Keras.UnitTest
            Assert.AreEqual(expected, actual.numpy());
        }
        #endregion
+       // ------------------------------------------------------------------
+       #region MultiHeadAttention
+       [TestMethod]
+       public void test_masked_attention()
+       {
+           var batch_size = 3;
+           var query = keras.Input(shape: (4, 8));
+           var value = keras.Input(shape: (2, 8));
+           var mask_tensor = keras.Input(shape: (4, 2));
+           var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
+           attention_layer.Apply(new[] { query, value, mask_tensor });
+           var from_data = 10 * np.random.randn(batch_size, 4, 8);
+           var to_data = 10 * np.random.randn(batch_size, 2, 8);
+           var mask_data = np.random.randint(2, size: (batch_size, 4, 2));
+           var masked_output_data = attention_layer.Apply(new[] { from_data, to_data, mask_data });
+           var null_mask_data = np.ones((batch_size, 4, 2));
+           var unmasked_output_data = attention_layer.Apply(new[] { from_data, to_data, null_mask_data });
+           Assert.AreNotEqual(masked_output_data, unmasked_output_data);
+       }
+       #endregion
    }
}