You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

RnnUtils.cs 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using System.Text;
  5. using Tensorflow.Common.Types;
  6. using Tensorflow.Keras.Layers.Rnn;
  7. using Tensorflow.Common.Extensions;
  namespace Tensorflow.Keras.Utils
  {
      /// <summary>
      /// Helpers used by the Keras RNN layers: building zero-filled initial states and
      /// splitting a packed list of call inputs into inputs / initial states / constants.
      /// </summary>
      internal static class RnnUtils
      {
          /// <summary>
          /// Creates zero-filled state tensors whose shape is the batch dimension followed by the
          /// dimensions of each state shape.
          /// </summary>
          /// <param name="batch_size_tensor">Size of the leading (batch) dimension prepended to every state shape.</param>
          /// <param name="state_size">State shape(s), without the batch dimension. When it holds more than one
          /// shape, one zero tensor is created per shape.</param>
          /// <param name="dtype">Element type of the created tensors.</param>
          /// <returns>One zero tensor per state shape, in the order returned by <c>state_size.ToShapeArray()</c>,
          /// or a single zero tensor when <paramref name="state_size"/> holds one shape.</returns>
          internal static Tensors generate_zero_filled_state(long batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype)
          {
              // Prepends the batch dimension to one (flattened) state shape and allocates zeros of that shape.
              Tensor create_zeros(GeneralizedTensorShape unnested_state_size)
              {
                  var flat_dims = unnested_state_size.ToSingleShape().dims;
                  var init_state_size = new long[] { batch_size_tensor }.Concat(flat_dims).ToArray();
                  return array_ops.zeros(new Shape(init_state_size), dtype: dtype);
              }

              // TODO(Rinne): map structure with nested tensors.
              if (state_size.Shapes.Length > 1)
              {
                  return new Tensors(state_size.ToShapeArray().Select(s => create_zeros(new GeneralizedTensorShape(s))));
              }

              return create_zeros(state_size);
          }

          /// <summary>
          /// Creates zero-filled initial states shaped according to <paramref name="cell"/>'s <c>StateSize</c>.
          /// </summary>
          /// <param name="cell">Cell whose <c>StateSize</c> determines the state shapes.</param>
          /// <param name="inputs">Optional inputs. When non-null, the batch size is read from <c>inputs.shape[0]</c>
          /// and the dtype from <c>inputs.dtype</c>, overriding <paramref name="batch_size"/> and <paramref name="dtype"/>.</param>
          /// <param name="batch_size">Batch size used when <paramref name="inputs"/> is null.</param>
          /// <param name="dtype">Element type used when <paramref name="inputs"/> is null.</param>
          /// <returns>The zero-filled state tensor(s) for <paramref name="cell"/>.</returns>
          internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, long batch_size, TF_DataType dtype)
          {
              if (inputs != null)
              {
                  // NOTE(review): this reads the *static* shape. If the batch dimension is not statically
                  // known, this value may not be a valid size for the zeros op. Confirm that callers always
                  // have a known batch size here.
                  batch_size = inputs.shape[0];
                  dtype = inputs.dtype;
              }
              return generate_zero_filled_state(batch_size, cell.StateSize, dtype);
          }

          /// <summary>
          /// Standardizes `__call__` to a single list of tensor inputs.
          ///
          /// When running a model loaded from a file, the input tensors
          /// `initial_state` and `constants` can be passed to `RNN.__call__()` as part
          /// of `inputs` instead of by the dedicated keyword arguments. This method
          /// makes sure the arguments are separated and that `initial_state` and
          /// `constants` are lists of tensors (or null).
          /// </summary>
          /// <param name="inputs">Tensor or list of tensors, which may include constants
          /// and initial states. In that case `num_constants` must be specified.</param>
          /// <param name="initial_state">Tensor or list of tensors or null, initial states.</param>
          /// <param name="constants">Tensor or list of tensors or null, constant tensors.</param>
          /// <param name="num_constants">Expected number of constants (if constants are passed as
          /// part of the `inputs` list).</param>
          /// <returns>
          /// The tuple <c>(inputs, initial_state, constants)</c>. When <paramref name="inputs"/> holds more than
          /// one tensor and <paramref name="num_constants"/> is positive, the last <paramref name="num_constants"/>
          /// entries become the constants. If more than one entry then remains, the first one is kept as the
          /// inputs and the rest become the initial state. When <paramref name="inputs"/> holds at most one tensor,
          /// all arguments are returned unchanged.
          /// </returns>
          internal static (Tensors, Tensors, Tensors) standardize_args(Tensors inputs, Tensors initial_state, Tensors constants, int num_constants)
          {
              if (inputs.Length > 1)
              {
                  // There are several situations here:
                  // In the graph mode, __call__ will be only called once. The initial_state
                  // and constants could be in inputs (from file loading).
                  // In the eager mode, __call__ will be called twice, once during
                  // rnn_layer(inputs=input_t, constants=c_t, ...), and second time will be
                  // model.fit/train_on_batch/predict with real np data. In the second case,
                  // the inputs will contain initial_state and constants as eager tensor.
                  //
                  // For either case, the real input is the first item in the list, which
                  // could be a nested structure itself. Then followed by initial_states, which
                  // could be a list of items, or list of list if the initial_state is complex
                  // structure, and finally followed by constants which is a flat list.

                  // Debug.Assert is only evaluated in DEBUG builds. In Release builds an explicitly passed
                  // initial_state/constants would be silently overwritten by the unpacking below.
                  Debug.Assert(initial_state is null && constants is null,
                      "`initial_state` and `constants` must be null when `inputs` contains more than one tensor.");

                  if (num_constants > 0)
                  {
                      // Constants are packed at the tail of the inputs list.
                      constants = inputs.TakeLast(num_constants).ToTensors();
                      inputs = inputs.SkipLast(num_constants).ToTensors();
                  }

                  if (inputs.Length > 1)
                  {
                      // Everything after the real input is initial state.
                      initial_state = inputs.Skip(1).ToTensors();
                      inputs = inputs.Take(1).ToTensors();
                  }
              }

              return (inputs, initial_state, constants);
          }
      }
  }