You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

RnnUtils.cs 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using System.Text;
  5. using Tensorflow.Common.Types;
  6. using Tensorflow.Keras.Layers.Rnn;
  7. using Tensorflow.Common.Extensions;
  8. using System.Linq;
  namespace Tensorflow.Keras.Utils
  {
      /// <summary>
      /// Helper routines for RNN layers: building zero-filled initial states and
      /// splitting packed call arguments back into inputs, initial states and constants.
      /// </summary>
      internal static class RnnUtils
      {
          /// <summary>
          /// Builds zero-filled state tensor(s) whose leading dimension is the batch size.
          /// </summary>
          /// <param name="batch_size_tensor">Tensor holding the batch size (expected to be a scalar);
          /// it becomes the first element of every generated state's shape.</param>
          /// <param name="state_size">Per-sample state shape(s). When it holds more than one shape,
          /// one zero tensor is created per shape.</param>
          /// <param name="dtype">Element type of the generated state tensors.</param>
          /// <returns>One zero tensor of shape <c>[batch_size, state dims...]</c> per state shape.</returns>
          /// <exception cref="ArgumentNullException"><paramref name="batch_size_tensor"/> or
          /// <paramref name="state_size"/> is null.</exception>
          internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype)
          {
              if (batch_size_tensor is null)
              {
                  throw new ArgumentNullException(nameof(batch_size_tensor));
              }
              if (state_size is null)
              {
                  throw new ArgumentNullException(nameof(state_size));
              }

              // Creates one zero tensor whose shape is the batch size followed by the state dims.
              Tensor create_zeros(GeneralizedTensorShape unnested_state_size)
              {
                  var flat_dims = unnested_state_size.ToSingleShape().dims;
                  // NOTE(review): this list mixes a Tensor with plain dimension values and relies on
                  // ops.convert_to_tensor packing them into a single shape tensor. Confirm that both
                  // batch-size dtypes seen here (from tf.shape or from a long) pack cleanly.
                  var init_state_size = new List<object> { batch_size_tensor };
                  foreach (var dim in flat_dims)
                  {
                      init_state_size.Add(dim);
                  }
                  var init_state_size_tensor = ops.convert_to_tensor(init_state_size.ToArray());
                  // Pass the requested dtype explicitly; otherwise zeros() uses its own default
                  // element type, whatever dtype the caller asked for.
                  return array_ops.zeros(init_state_size_tensor, dtype: dtype);
              }

              // TODO(Rinne): map structure with nested tensors.
              if (state_size.Shapes.Length > 1)
              {
                  return new Tensors(state_size.ToShapeArray().Select(s => create_zeros(new GeneralizedTensorShape(s))));
              }
              return create_zeros(state_size);
          }

          /// <summary>
          /// Builds zero-filled initial state(s) for <paramref name="cell"/>.
          /// </summary>
          /// <param name="cell">RNN cell whose <c>StateSize</c> describes the state shape(s).</param>
          /// <param name="inputs">Optional input tensor(s). When non-null, the batch size is taken from
          /// the first entry of their shape and the dtype from their <c>dtype</c>, overriding
          /// <paramref name="batch_size"/> and <paramref name="dtype"/>.</param>
          /// <param name="batch_size">Batch size used when <paramref name="inputs"/> is null.</param>
          /// <param name="dtype">State element type used when <paramref name="inputs"/> is null.</param>
          /// <returns>The zero state tensor(s) produced by <c>generate_zero_filled_state</c>.</returns>
          internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, long batch_size, TF_DataType dtype)
          {
              Tensor batch_size_tensor;
              if (inputs is null)
              {
                  batch_size_tensor = tf.convert_to_tensor(batch_size);
              }
              else
              {
                  // Match the real inputs so the state agrees with them in batch size and dtype.
                  // The fallback constant is only built when needed, so no unused op is created.
                  batch_size_tensor = tf.shape(inputs)[0];
                  dtype = inputs.dtype;
              }
              return generate_zero_filled_state(batch_size_tensor, cell.StateSize, dtype);
          }

          /// <summary>
          /// Standardizes `__call__` arguments into separate inputs, initial states and constants.
          ///
          /// When running a model loaded from a file, the tensors for `initial_state` and
          /// `constants` can be passed to `RNN.__call__()` as part of `inputs` instead of through
          /// the dedicated arguments. This method separates them again so that `initial_state`
          /// and `constants` are each a list of tensors (or null).
          /// </summary>
          /// <param name="inputs">Tensor or list of tensors, which may include constants and initial
          /// states. In that case <paramref name="num_constants"/> must be specified.</param>
          /// <param name="initial_state">Tensor, list of tensors, or null: the initial states.</param>
          /// <param name="constants">Tensor, list of tensors, or null: the constant tensors.</param>
          /// <param name="num_constants">Expected number of constants (if constants are passed as
          /// part of the <paramref name="inputs"/> list).</param>
          /// <returns>
          /// A tuple <c>(inputs, initial_state, constants)</c>. If <paramref name="inputs"/> is null or
          /// holds at most one tensor, all arguments are returned unchanged. Otherwise the trailing
          /// <paramref name="num_constants"/> tensors (when positive) become the constants. If more
          /// than one tensor then remains, the first is kept as the inputs and the rest become the
          /// initial state.
          /// </returns>
          internal static (Tensors, Tensors, Tensors) standardize_args(Tensors inputs, Tensors initial_state, Tensors constants, int num_constants)
          {
              // A null or single-tensor input cannot carry packed states or constants.
              if (inputs is null || inputs.Length <= 1)
              {
                  return (inputs, initial_state, constants);
              }

              // There are several situations here:
              // In graph mode, __call__ is called only once. The initial_state and
              // constants could be in inputs (from file loading).
              // In eager mode, __call__ is called twice: once during
              // rnn_layer(inputs=input_t, constants=c_t, ...), and a second time during
              // model.fit/train_on_batch/predict with real data. In the second case,
              // the inputs contain initial_state and constants as eager tensors.
              //
              // In either case, the real input is the first item in the list, which
              // could itself be a nested structure. It is followed by the initial states,
              // which could be a list of items (or a list of lists if the initial state is
              // a complex structure), and finally by the constants, which form a flat list.
              Debug.Assert(initial_state is null && constants is null,
                  "initial_state and constants must be null when they are packed into inputs.");
              if (num_constants > 0)
              {
                  constants = inputs.TakeLast(num_constants).ToTensors();
                  inputs = inputs.SkipLast(num_constants).ToTensors();
              }
              if (inputs.Length > 1)
              {
                  initial_state = inputs.Skip(1).ToTensors();
                  inputs = inputs.Take(1).ToTensors();
              }
              return (inputs, initial_state, constants);
          }
      }
  }