@@ -55,6 +55,9 @@ namespace Tensorflow.Keras.Layers
         {
             var target_notation = _CHR_IDX.Substring(0, rank);
             // `batch_dims` includes the head dim.
+            // batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
+            // Since range(rank) is an IEnumerable like (0, 1, 2, ...) whose index equals its value,
+            // use IEnumerable.Except instead of np.delete, which is unavailable here.
             var batch_dims = range(rank).Except(attn_axes.as_int_list().concat(new[] { rank - 1 }));
             var letter_offset = rank;
             var source_notation = "";
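As a sanity check on the comment above, here is a minimal standalone sketch (plain C#, no TF.NET types; the rank and axes values are invented for illustration) showing that `Enumerable.Except` over `Range(0, rank)` yields the same index set `np.delete` would:

```csharp
using System;
using System.Linq;

class BatchDimsDemo
{
    static void Main()
    {
        int rank = 4;
        int[] attnAxes = { 2 };  // hypothetical attention axes
        // np.delete(range(rank), attn_axes + (rank - 1,)) removes the attention
        // axes and the last axis; Except does the same because each value in
        // Range(0, rank) equals its own index.
        var batchDims = Enumerable.Range(0, rank)
                                  .Except(attnAxes.Concat(new[] { rank - 1 }));
        Console.WriteLine(string.Join(", ", batchDims));  // prints: 0, 1
    }
}
```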
@@ -68,14 +71,14 @@ namespace Tensorflow.Keras.Layers
                     letter_offset += 1;
                 }
             }
-            var product_notation = "".Insert(0, new string((from i in batch_dims
-                                                            select (char)(int)target_notation[i]).Concat(
-                                                            from i in attn_axes.as_int_list()
-                                                            select (char)(int)target_notation[i]).Concat(
-                                                            from i in attn_axes.as_int_list()
-                                                            select source_notation[i]).ToArray()));
+            var product_notation = new string((from i in batch_dims
+                                               select target_notation[i]).Concat(
+                                               from i in attn_axes.as_int_list()
+                                               select target_notation[i]).Concat(
+                                               from i in attn_axes.as_int_list()
+                                               select source_notation[i]).ToArray());
             var dot_product_equation = $"{source_notation},{target_notation}->{product_notation}";
             var attn_scores_rank = product_notation.Count();
             var combine_equation = $"{product_notation},{source_notation}->{target_notation}";
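For orientation, here is a self-contained sketch of the notation build in this hunk, assuming `_CHR_IDX` is the lowercase alphabet as in Keras. For the typical projected-query rank of 4 ([batch, seq, heads, head_dim]) and attn_axes = (1,), it should print the familiar equations `aecd,abcd->acbe` (dot product) and `acbe,aecd->abcd` (combine):

```csharp
using System;
using System.Linq;

class EinsumNotationDemo
{
    const string ChrIdx = "abcdefghijklmnopqrstuvwxyz";  // assumed _CHR_IDX

    static void Main()
    {
        int rank = 4;
        int[] attnAxes = { 1 };
        string target = ChrIdx.Substring(0, rank);                   // "abcd"
        var batchDims = Enumerable.Range(0, rank)
                                  .Except(attnAxes.Concat(new[] { rank - 1 }))
                                  .ToArray();                        // {0, 2}
        // Source notation reuses target letters on batch dims and the last
        // (feature) dim, and fresh letters on the attention axes.
        int letterOffset = rank;
        var source = "";
        for (int i = 0; i < rank; i++)
            source += (batchDims.Contains(i) || i == rank - 1)
                ? target[i]
                : ChrIdx[letterOffset++];                            // "aecd"
        // Product notation: batch dims, then target attn axes, then source attn axes.
        var product = new string(
            batchDims.Select(i => target[i])
                     .Concat(attnAxes.Select(i => target[i]))
                     .Concat(attnAxes.Select(i => source[i]))
                     .ToArray());                                    // "acbe"
        Console.WriteLine($"{source},{target}->{product}");   // aecd,abcd->acbe
        Console.WriteLine($"{product},{source}->{target}");   // acbe,aecd->abcd
    }
}
```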
@@ -163,7 +166,7 @@ namespace Tensorflow.Keras.Layers
                 this._value_shape.rank - 1, bound_dims: 1, output_dims: 2);
             this._value_dense = _get_dense(einsum_equation,
                                            _get_output_shape(output_rank - 1,
-                                                             (this.args.NumHeads, this.args.ValueDim ?? -1)),
+                                                             (this.args.NumHeads, this.args.ValueDim ?? this.args.KeyDim)),
                                            this.args.UseBias ? bias_axes : null,
                                            "value");
             // Builds the attention computations for multi-head dot product attention.
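The `?? this.args.KeyDim` fallback matches the Keras convention that `value_dim` defaults to `key_dim` when unset; the old `?? -1` let a placeholder leak into the value projection's output shape. A trivial top-level sketch of the semantics (values invented):

```csharp
// Minimal top-level C# sketch: the null-coalescing fallback substitutes
// key_dim when value_dim was not supplied, so the value projection gets a
// real head size instead of a -1 placeholder.
using System;

int? valueDim = null;                    // user did not pass value_dim
int keyDim = 2;
Console.WriteLine(valueDim ?? keyDim);   // prints 2 under the fix, not -1
```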
@@ -235,7 +238,7 @@ namespace Tensorflow.Keras.Layers
             // Note: Applying scalar multiply at the smaller end of einsum improves
             // XLA performance, but may introduce slight numeric differences in
             // the Transformer attention head.
-            query = tf.multiply(query, 1d / Math.Sqrt(this.args.KeyDim));
+            query = tf.multiply(query, 1f / tf.sqrt(tf.convert_to_tensor((float)this.args.KeyDim)));
             // Take the dot product between "query" and "key" to get the raw
             // attention scores.
             var attention_scores = tf.linalg.einsum(this._dot_product_equation, (key, query));
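The replacement computes the 1/√key_dim attention scale as a float32 tensor; the old `1d / Math.Sqrt(...)` produced a double scalar, risking a dtype mismatch against a float32 query. A standalone sketch of just the factor (the key_dim value is invented):

```csharp
using System;

class ScaleDemo
{
    static void Main()
    {
        int keyDim = 2;
        // Computing the factor in float keeps everything in float32,
        // as the tensor-based expression in the fix does.
        float scale = 1f / MathF.Sqrt(keyDim);
        Console.WriteLine(scale);  // 0.7071068
    }
}
```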
@@ -273,7 +276,7 @@ namespace Tensorflow.Keras.Layers
                     _inp = (inputs[0], inputs[1]);
                     break;
                 case 3:
-                    if (inputs[2].shape[-1] != inputs[0].shape[-1])
+                    if (inputs[2].shape[-1] == inputs[1].shape[-1])
                         _inp = new[] { inputs[0], inputs[1], inputs[2] };
                     else
                     {
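The corrected condition tests the third input against the value (`inputs[1]`) rather than the query: under the convention assumed here, a third tensor whose last dimension matches the value's is a key, and anything else is treated as an attention mask. A standalone sketch (shapes invented, `IsKey` is a hypothetical helper):

```csharp
using System;

class ThirdInputDemo
{
    // A third input whose last dim equals the value's last dim is a key tensor.
    static bool IsKey(long[] valueShape, long[] thirdShape)
        => thirdShape[^1] == valueShape[^1];

    static void Main()
    {
        var value = new long[] { 3, 2, 8 };
        Console.WriteLine(IsKey(value, new long[] { 3, 2, 8 }));  // True  -> key
        Console.WriteLine(IsKey(value, new long[] { 3, 4, 2 }));  // False -> mask
    }
}
```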
@@ -228,7 +228,7 @@ namespace Tensorflow.Keras.Layers
                                                           Shape output_shape,
                                                           bool left_elided = false)
         {
-            List<long> bias_shape;
+            List<int> bias_shape;
             Dictionary<char, int> output_dim_map;
             Dictionary<char, int> input_dim_map;
@@ -275,8 +275,8 @@ namespace Tensorflow.Keras.Layers
                     var input_shape_at_dim = input_shape[input_dim_map[dim]];
                     if (output_dim_map.TryGetValue(dim, out int index))
                     {
-                        var output_shape_at_dim = output_shape[index];
-                        if (output_shape_at_dim != input_shape_at_dim)
+                        var output_shape_at_dim = _output_shape[index];
+                        if (output_shape_at_dim != -1 && output_shape_at_dim != input_shape_at_dim)
                             throw new ValueError($"Input shape and output shape do not match at shared dimension '{dim}'. " +
                                                  $"Input shape is {input_shape_at_dim}, " +
                                                  $"and output shape is {output_shape[output_dim_map[dim]]}.");
@@ -299,7 +299,7 @@ namespace Tensorflow.Keras.Layers
                 if (input_dim_map.ContainsKey(dim))
                     weight_shape.append(input_shape[input_dim_map[dim]]);
                 else if (output_dim_map.ContainsKey(dim))
-                    weight_shape.append(output_shape[output_dim_map[dim]]);
+                    weight_shape.append(_output_shape[output_dim_map[dim]]);
                 else throw new ValueError($"Weight dimension '{dim}' did not have a match in " +
                                           $"either the input spec '{input_spec}' " +
                                           $"or the output spec '{output_spec}'. " +
@@ -310,7 +310,7 @@ namespace Tensorflow.Keras.Layers
             {
                 var num_left_elided = left_elided ? elided : 0;
                 var idx_map = output_spec.Select((_char, i) => (i, _char))
-                                         .ToDictionary(_ => _._char, _ => output_shape[_.i + num_left_elided]);
+                                         .ToDictionary(_ => _._char, _ => _output_shape[_.i + num_left_elided]);
                 foreach (var _char in bias_axes)
                     if (!output_spec.Contains(_char))
                         throw new ValueError($"Bias dimension '{_char}' was requested," +
@@ -327,7 +327,7 @@ namespace Tensorflow.Keras.Layers
             else bias_shape = null;
             return (weight_shape.ToArray(),
-                    (bias_shape ?? new List<long>()).ToArray(),
+                    (bias_shape ?? new List<int>()).ToArray(),
                     _output_shape.ToArray());
         }
     }
@@ -151,19 +151,21 @@ namespace TensorFlowNET.Keras.UnitTest
         [TestMethod]
         public void test_masked_attention()
         {
+            var batch_size = 3;
             var query = keras.Input(shape: (4, 8));
             var value = keras.Input(shape: (2, 8));
             var mask_tensor = keras.Input(shape:(4, 2));
             var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
             attention_layer.Apply(new[] { query, value, mask_tensor });
-            var from_data = 10 * np.random.randn(3, 4, 8);
-            var to_data = 10 * np.random.randn(3, 2, 8);
-            var mask_data = np.random.randint(2, size: (3, 4, 2));
+            var from_data = 10 * np.random.randn(batch_size, 4, 8);
+            var to_data = 10 * np.random.randn(batch_size, 2, 8);
+            var mask_data = np.random.randint(2, size: (batch_size, 4, 2));
             var masked_output_data = attention_layer.Apply(new[] { from_data, to_data, mask_data });
-            var null_mask_data = np.ones((3, 4, 2));
+            var null_mask_data = np.ones((batch_size, 4, 2));
             var unmasked_output_data = attention_layer.Apply(new[] { from_data, to_data, null_mask_data });
             Assert.AreNotEqual(masked_output_data, unmasked_output_data);