MultiHeadAttention.cs
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Core;
using Tensorflow.Keras.Engine;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System;
using System.Linq;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
    public class MultiHeadAttention : Layer
    {
        static readonly string _CHR_IDX = "abcdefghijklmnopqrstuvwxyz";
        MultiHeadAttentionArgs args;
        Shape _query_shape = null;
        Shape _key_shape = null;
        Shape _value_shape = null;
        bool _built_from_signature = false;
        EinsumDense _query_dense = null;
        EinsumDense _key_dense = null;
        EinsumDense _value_dense = null;
        EinsumDense _output_dense = null;
        string _dot_product_equation = "";
        string _combine_equation = "";
        Softmax _softmax = null;
        Dropout _dropout_layer = null;
        /// <summary>
        /// Builds einsum equations for the attention computation.
        /// Query, key, value inputs after projection are expected to have the shape
        /// `(bs, [non-attention dims], [attention dims], num_heads, channels)`.
        /// `bs` and `[non-attention dims]` are treated as `[batch dims]`.
        ///
        /// <para>
        /// The attention operations can be generalized:
        /// </para>
        /// <para>
        /// (1) Query-key dot product:
        /// `([batch dims], [query attention dims], num_heads, channels), ([batch dims],
        /// [key attention dims], num_heads, channels) -> ([batch dims],
        /// num_heads, [query attention dims], [key attention dims])`
        /// </para><para>
        /// (2) Combination:
        /// `([batch dims], num_heads, [query attention dims], [key attention dims]),
        /// ([batch dims], [value attention dims], num_heads, channels) -> ([batch dims],
        /// [query attention dims], num_heads, channels)`
        /// </para>
        /// </summary>
        /// <param name="rank">Rank of query, key, value tensors.</param>
        /// <param name="attn_axes">List/tuple of axes, `[-1, rank)`,
        /// that attention will be applied to.</param>
        /// <returns></returns>
        public static (string, string, int) _build_attention_equation(int rank, Shape attn_axes)
        {
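            // Worked example (illustrative, not from the original source): for rank = 4
            // with the default attn_axes = (1,), target_notation is "abcd"
            // (batch, query dim, num_heads, channels) and source_notation becomes "aecd",
            // giving
            //   dot_product_equation = "aecd,abcd->acbe"
            //   combine_equation     = "acbe,aecd->abcd"
            // i.e. the per-head query-key dot product and the attention-weighted value sum.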
            var target_notation = _CHR_IDX.Substring(0, rank);
            // `batch_dims` includes the head dim.
            // batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
            // Since range(rank) yields (0, 1, 2, ...) where each element equals its index,
            // use IEnumerable.Except instead of np.delete, which is unavailable here.
            var batch_dims = range(rank).Except(attn_axes.as_int_list().concat(new[] { rank - 1 }));
            var letter_offset = rank;
            var source_notation = "";
            for (int i = 0; i < rank; i++)
            {
                if (batch_dims.Contains(i) || i == rank - 1)
                    source_notation += target_notation[i];
                else
                {
                    source_notation += _CHR_IDX[letter_offset];
                    letter_offset += 1;
                }
            }
            var product_notation = new string((from i in batch_dims
                                               select target_notation[i]).Concat(
                                                   from i in attn_axes.as_int_list()
                                                   select target_notation[i]).Concat(
                                                   from i in attn_axes.as_int_list()
                                                   select source_notation[i]).ToArray());
            var dot_product_equation = $"{source_notation},{target_notation}->{product_notation}";
            var attn_scores_rank = product_notation.Count();
            var combine_equation = $"{product_notation},{source_notation}->{target_notation}";
            return (dot_product_equation, combine_equation, attn_scores_rank);
        }
        /// <summary>
        /// Builds an einsum equation for projections inside multi-head attention.
        /// </summary>
        public static (string, string, int) _build_proj_equation(int free_dims, int bound_dims, int output_dims)
        {
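            // Worked example (illustrative): free_dims = 2, bound_dims = 1, output_dims = 2
            // yields equation "abc,cde->abde" with bias_axes "de" and output rank 4,
            // i.e. a (batch, seq, features) input projected to (batch, seq, num_heads, head_dim).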
            char _char;
            var input_str = "";
            var kernel_str = "";
            var output_str = "";
            var bias_axes = "";
            var letter_offset = 0;
            foreach (var i in range(free_dims))
            {
                _char = _CHR_IDX[i + letter_offset];
                input_str += _char;
                output_str += _char;
            }
            letter_offset += free_dims;
            foreach (var i in range(bound_dims))
            {
                _char = _CHR_IDX[i + letter_offset];
                input_str += _char;
                kernel_str += _char;
            }
            letter_offset += bound_dims;
            foreach (var i in range(output_dims))
            {
                _char = _CHR_IDX[i + letter_offset];
                kernel_str += _char;
                output_str += _char;
                bias_axes += _char;
            }
            var equation = $"{input_str},{kernel_str}->{output_str}";
            return (equation, bias_axes, output_str.Count());
        }
        static Shape _get_output_shape(int output_rank, Shape known_last_dims)
            => (from _ in range(output_rank - known_last_dims.rank)
                select -1).Concat(known_last_dims.as_int_list()).ToArray();

        public MultiHeadAttention(MultiHeadAttentionArgs args) : base(args)
        {
            this.args = args;
        }
        public void _build_from_signature(Tensor query, Tensor value, Tensor key = null)
            => this._build_from_signature(query.shape, value.shape, key?.shape);

        public void _build_from_signature(Shape query, Shape value, Shape key = null)
        {
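            // Deferred build: the query/key/value/output projection layers and the attention
            // equations are created here, once concrete input shapes are known, rather than
            // in the constructor.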
            this._built_from_signature = true;
            this._query_shape = query;
            this._value_shape = value;
            if (key == null)
                this._key_shape = this._value_shape;
            else
                this._key_shape = key;
            // Any setup work performed only once should happen in an `init_scope`
            // to avoid creating symbolic Tensors that will later pollute any eager
            // operations.
            tf_with(tf.init_scope(), _ =>
            {
                var free_dims = this._query_shape.rank - 1;
                var (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                    free_dims, bound_dims: 1, output_dims: 2);
                this._query_dense = _get_dense(einsum_equation,
                                               _get_output_shape(output_rank - 1,
                                                   (this.args.NumHeads, this.args.KeyDim)),
                                               this.args.UseBias ? bias_axes : null,
                                               "query");
                (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                    this._key_shape.rank - 1, bound_dims: 1, output_dims: 2);
                this._key_dense = _get_dense(einsum_equation,
                                             _get_output_shape(output_rank - 1,
                                                 (this.args.NumHeads, this.args.KeyDim)),
                                             this.args.UseBias ? bias_axes : null,
                                             "key");
                (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                    this._value_shape.rank - 1, bound_dims: 1, output_dims: 2);
                this._value_dense = _get_dense(einsum_equation,
                                               _get_output_shape(output_rank - 1,
                                                   (this.args.NumHeads, this.args.ValueDim ?? this.args.KeyDim)),
                                               this.args.UseBias ? bias_axes : null,
                                               "value");
                // Builds the attention computations for multi-head dot product attention.
                // These computations could be wrapped into the keras attention layer once
                // it supports multi-head einsum computations.
                this._build_attention(output_rank);
                this._output_dense = _build_output_dense(free_dims, "attention_output");
            });
            this.StackLayers(_query_dense, _key_dense, _value_dense, _output_dense);
        }
        EinsumDense _get_dense(string equation, Shape output_shape, string bias_axes, string name)
            => new EinsumDense(new EinsumDenseArgs()
            {
                Equation = equation,
                OutputShape = output_shape,
                BiasAxes = bias_axes,
                Name = name,
                KernelInitializer = this.args.KernelInitializer,
                BiasInitializer = this.args.BiasInitializer,
                KernelRegularizer = this.args.KernelRegularizer,
                BiasRegularizer = this.args.BiasRegularizer,
                KernelConstraint = this.args.KernelConstraint,
                BiasConstraint = this.args.BiasConstraint
            });

        EinsumDense _build_output_dense(int free_dims, string name)
        {
            if (this.args.OutputShape == null) this.args.OutputShape = new(this._query_shape[-1]);
            var (einsum_equation, bias_axes, output_rank) = _build_proj_equation(
                free_dims, bound_dims: 2, output_dims: len(this.args.OutputShape));
            return _get_dense(einsum_equation,
                              _get_output_shape(output_rank - 1, this.args.OutputShape),
                              this.args.UseBias ? bias_axes : null,
                              name);
        }
        void _build_attention(int rank)
        {
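            // By default, attention is applied over every axis between the batch axis and
            // the trailing (num_heads, channels) axes, i.e. axes 1 .. rank - 3.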
            if (this.args.AttentionAxis == null)
                this.args.AttentionAxis = new(range(1, rank - 2).ToArray());
            int attn_scores_rank;
            (this._dot_product_equation, this._combine_equation, attn_scores_rank)
                = _build_attention_equation(rank, this.args.AttentionAxis);
            var norm_axes = range(attn_scores_rank - len(this.args.AttentionAxis),
                                  attn_scores_rank).ToArray();
            this._softmax = new Softmax(new SoftmaxArgs { axis = norm_axes });
            this._dropout_layer = new Dropout(new DropoutArgs { Rate = this.args.Dropout });
        }
        Tensor _masked_softmax(Tensor attention_scores, Tensor attention_mask = null)
        {
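            // The mask is broadcast against the score tensor: singleton axes are inserted
            // ahead of the attention axes until the mask rank matches the score rank, and
            // the Softmax layer then applies the mask when normalizing the scores.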
            if (attention_mask != null)
            {
                var mask_expansion_axis = -len(this.args.AttentionAxis) * 2 - 1;
                for (int i = 0; i < len(attention_scores.shape) - len(attention_mask.shape); i++)
                    attention_mask = tf.expand_dims(attention_mask, axis: mask_expansion_axis);
            }
            return this._softmax.Apply(attention_mask == null ? attention_scores : (attention_scores, attention_mask));
        }
        public Tensors _compute_attention(
            Tensor query,
            Tensor key,
            Tensor value,
            Tensor attention_mask = null,
            bool training = false)
        {
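            // Scaled dot-product attention, computed per head via einsum:
            //   scores = softmax(Q · K^T / sqrt(key_dim));  output = scores · V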
            // Note: Applying scalar multiply at the smaller end of einsum improves
            // XLA performance, but may introduce slight numeric differences in
            // the Transformer attention head.
            query = tf.multiply(query, 1f / tf.sqrt(tf.convert_to_tensor((float)this.args.KeyDim)));
            // Take the dot product between "query" and "key" to get the raw
            // attention scores.
            var attention_scores = tf.linalg.einsum(this._dot_product_equation, (key, query));
            attention_scores = this._masked_softmax(attention_scores, attention_mask);
            // This is actually dropping out entire tokens to attend to, which might
            // seem a bit unusual, but is taken from the original Transformer paper.
            var attention_scores_dropout = this._dropout_layer.Apply(attention_scores, training: training);
            // `context_layer` = [B, T, N, H]
            var attention_output = tf.linalg.einsum(this._combine_equation, (attention_scores_dropout, value));
            return (attention_output, attention_scores);
        }
        protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            Tensors _inp;
            Tensor _mask = null;
            int count = inputs.Count();
            if (count < 2 || count > 5) throw new ValueError(
                $"{this.name} layer accepts inputs list of length from 2 to 5, " +
                $"namely [query, value, (key), (attention_mask), (return_attention_scores)]. " +
                $"Received length: {count}.");
            bool has_bool = inputs[count - 1].dtype == TF_DataType.TF_BOOL;
            bool return_attention_scores = false;
            if (has_bool)
            {
                return_attention_scores = (bool)inputs[count - 1];
                count--;
            }
            switch (count)
            {
                case 2:
                    _inp = (inputs[0], inputs[1]);
                    break;
                case 3:
                    // The third tensor is treated as `key` when its last dimension matches
                    // `value`'s; otherwise it is interpreted as the attention mask.
                    if (inputs[2].shape[-1] == inputs[1].shape[-1])
                        _inp = new[] { inputs[0], inputs[1], inputs[2] };
                    else
                    {
                        _inp = (inputs[0], inputs[1]);
                        _mask = inputs[2];
                    }
                    break;
                case 4:
                    _inp = new[] { inputs[0], inputs[1], inputs[2] };
                    _mask = inputs[3];
                    break;
                default:
                    throw new ValueError(); // TODO: add a description for this error
            }
            return call(_inp, _mask, training, return_attention_scores);
        }
        protected Tensors call(Tensors inputs,
                               Tensor attention_mask,
                               bool? training = null,
                               bool return_attention_scores = false)
        {
            var (query, value, key) = (inputs[0], inputs[1], inputs.Length == 3 ? inputs[2] : null);
            if (!this._built_from_signature)
                this._build_from_signature(query: query, value: value, key: key);
            if (key == null)
                key = value;
            // TODO: Add RaggedTensor support
            //var query_is_ragged = query is tf.RaggedTensor;
            //if (query_is_ragged)
            //{
            //    var query_lengths = query.nested_row_lengths();
            //    query = query.to_tensor();
            //}
            //var key_is_ragged = key is tf.RaggedTensor;
            //var value_is_ragged = value is tf.RaggedTensor;
            //if (key_is_ragged && value_is_ragged)
            //{
            //    // Ensure they have the same shape.
            //    var bounding_shape = tf.math.maximum(key.bounding_shape(), value.bounding_shape());
            //    key = key.to_tensor(shape: bounding_shape);
            //    value = value.to_tensor(shape: bounding_shape);
            //}
            //else if (key_is_ragged)
            //{
            //    key = key.to_tensor(shape: tf.shape(value));
            //}
            //else if (value_is_ragged)
            //{
            //    value = value.to_tensor(shape: tf.shape(key));
            //}
            // N = `num_attention_heads`
            // H = `size_per_head`
            // `query` = [B, T, N, H]
            query = this._query_dense.Apply(query);
            // `key` = [B, S, N, H]
            key = this._key_dense.Apply(key);
            // `value` = [B, S, N, H]
            value = this._value_dense.Apply(value);
            var (attention_output, attention_scores) = this._compute_attention(query, key, value, attention_mask, training ?? false);
            attention_output = this._output_dense.Apply(attention_output);
            //if (query_is_ragged)
            //{
            //    attention_output = tf.RaggedTensor.from_tensor(attention_output, lengths: query_lengths);
            //}
            if (return_attention_scores)
                return (attention_output, attention_scores.Single);
            return attention_output;
        }
    }
}
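
// Minimal usage sketch (illustrative, not part of the original file); it assumes the usual
// TensorFlow.NET Layer.Apply entry point, the Tensors(params Tensor[]) constructor, and
// tuple-to-Shape conversions behave as in other Keras layers in this repository:
//
//   var mha = new Tensorflow.Keras.Layers.MultiHeadAttention(new MultiHeadAttentionArgs
//   {
//       NumHeads = 2,
//       KeyDim = 16
//   });
//   var query = tf.random.normal((4, 8, 32));   // (batch, target_seq, features)
//   var value = tf.random.normal((4, 8, 32));   // (batch, source_seq, features)
//   var output = mha.Apply(new Tensors(query, value));   // -> shape (4, 8, 32)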