You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LayersApi.Attention.cs 2.7 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. using System;
  2. using Tensorflow.NumPy;
  3. using System.Collections.Generic;
  4. using Tensorflow.Keras.ArgsDefinition;
  5. using Tensorflow.Keras.Engine;
  6. using static Tensorflow.Binding;
  7. using static Tensorflow.KerasApi;
  namespace Tensorflow.Keras.Layers
  {
      /// <summary>
      /// Attention-related factory methods of the Keras layers API.
      /// </summary>
      public partial class LayersApi
      {
          /// <summary>
          /// Builds an <c>Attention</c> layer from an <c>AttentionArgs</c> populated with the given options.
          /// </summary>
          /// <param name="use_scale">Copied into <c>AttentionArgs.use_scale</c>.</param>
          /// <param name="score_mode">Copied into <c>AttentionArgs.score_mode</c>; defaults to <c>"dot"</c>.</param>
          /// <param name="causal">Copied into <c>AttentionArgs.causal</c>.</param>
          /// <param name="dropout">Copied into <c>AttentionArgs.dropout</c>.</param>
          /// <returns>The newly created layer, exposed as <c>ILayer</c>.</returns>
          public ILayer Attention(bool use_scale = false,
                                  string score_mode = "dot",
                                  bool causal = false,
                                  float dropout = 0f)
          {
              var attentionArgs = new AttentionArgs
              {
                  use_scale = use_scale,
                  score_mode = score_mode,
                  causal = causal,
                  dropout = dropout
              };

              // Inside this class, the type name resolves to the Attention layer type, not to this method.
              return new Attention(attentionArgs);
          }

          /// <summary>
          /// Builds a <c>MultiHeadAttention</c> layer from a <c>MultiHeadAttentionArgs</c>
          /// populated with the given options.
          /// </summary>
          /// <param name="num_heads">Copied into <c>MultiHeadAttentionArgs.NumHeads</c>.</param>
          /// <param name="key_dim">Copied into <c>MultiHeadAttentionArgs.KeyDim</c>.</param>
          /// <param name="value_dim">Copied into <c>MultiHeadAttentionArgs.ValueDim</c>; may be null.</param>
          /// <param name="dropout">Copied into <c>MultiHeadAttentionArgs.Dropout</c>.</param>
          /// <param name="use_bias">Copied into <c>MultiHeadAttentionArgs.UseBias</c>.</param>
          /// <param name="output_shape">Copied into <c>MultiHeadAttentionArgs.OutputShape</c>; may be null.</param>
          /// <param name="attention_axes">Copied into <c>MultiHeadAttentionArgs.AttentionAxis</c>; may be null.</param>
          /// <param name="kernel_initializer">Kernel initializer; when null, <c>tf.glorot_uniform_initializer</c> is used.</param>
          /// <param name="bias_initializer">Bias initializer; when null, <c>tf.zeros_initializer</c> is used.</param>
          /// <param name="kernel_regularizer">Copied into <c>MultiHeadAttentionArgs.KernelRegularizer</c>.</param>
          /// <param name="bias_regularizer">Copied into <c>MultiHeadAttentionArgs.BiasRegularizer</c>.</param>
          /// <param name="activity_regularizer">Copied into <c>MultiHeadAttentionArgs.ActivityRegularizer</c>.</param>
          /// <param name="kernel_constraint">Copied into <c>MultiHeadAttentionArgs.KernelConstraint</c>.</param>
          /// <param name="bias_constraint">Copied into <c>MultiHeadAttentionArgs.BiasConstraint</c>.</param>
          /// <returns>The newly created layer, exposed as <c>ILayer</c>.</returns>
          public ILayer MultiHeadAttention(int num_heads,
                                           int key_dim,
                                           int? value_dim = null,
                                           float dropout = 0f,
                                           bool use_bias = true,
                                           Shape output_shape = null,
                                           Shape attention_axes = null,
                                           IInitializer kernel_initializer = null,
                                           IInitializer bias_initializer = null,
                                           IRegularizer kernel_regularizer = null,
                                           IRegularizer bias_regularizer = null,
                                           IRegularizer activity_regularizer = null,
                                           Action kernel_constraint = null,
                                           Action bias_constraint = null)
          {
              var headArgs = new MultiHeadAttentionArgs
              {
                  NumHeads = num_heads,
                  KeyDim = key_dim,
                  ValueDim = value_dim,
                  Dropout = dropout,
                  UseBias = use_bias,
                  OutputShape = output_shape,
                  AttentionAxis = attention_axes,
                  // Default initializers are only fetched from `tf` when the caller passed none.
                  KernelInitializer = kernel_initializer is null
                      ? tf.glorot_uniform_initializer
                      : kernel_initializer,
                  BiasInitializer = bias_initializer is null
                      ? tf.zeros_initializer
                      : bias_initializer,
                  KernelRegularizer = kernel_regularizer,
                  BiasRegularizer = bias_regularizer,
                  ActivityRegularizer = activity_regularizer,
                  KernelConstraint = kernel_constraint,
                  BiasConstraint = bias_constraint
              };

              return new MultiHeadAttention(headArgs);
          }
      }
  }