You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

Flatten.cs 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. using System;
  2. using System.Linq;
  3. using Tensorflow.Common.Types;
  4. using Tensorflow.Framework;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Engine;
  7. using Tensorflow.Keras.Utils;
  8. using static Tensorflow.Binding;
  9. namespace Tensorflow.Keras.Layers
  10. {
  11. public class Flatten : Layer
  12. {
  13. FlattenArgs args;
  14. InputSpec input_spec;
  15. bool _channels_first;
  16. public Flatten(FlattenArgs args)
  17. : base(args)
  18. {
  19. this.args = args;
  20. args.DataFormat = conv_utils.normalize_data_format(args.DataFormat);
  21. input_spec = new InputSpec(min_ndim: 1);
  22. _channels_first = args.DataFormat == "channels_first";
  23. }
  24. protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
  25. {
  26. if (_channels_first)
  27. {
  28. throw new NotImplementedException("");
  29. }
  30. if (tf.executing_eagerly())
  31. {
  32. return array_ops.reshape(inputs, new[] { inputs.shape[0], -1 });
  33. }
  34. else
  35. {
  36. var input_shape = inputs.shape;
  37. var rank = inputs.shape.ndim;
  38. if (rank == 1)
  39. return array_ops.expand_dims(inputs, axis: 1);
  40. var batch_dim = tensor_shape.dimension_value(input_shape[0]);
  41. if (batch_dim != -1)
  42. {
  43. return array_ops.reshape(inputs, new[] { batch_dim, -1 });
  44. }
  45. var non_batch_dims = ((int[])input_shape).Skip(1).ToArray();
  46. var num = 1;
  47. if (non_batch_dims.Length > 0)
  48. {
  49. for (var i = 0; i < non_batch_dims.Length; i++)
  50. {
  51. num *= non_batch_dims[i];
  52. }
  53. }
  54. return array_ops.reshape(inputs, new[] { inputs.shape[0], num });
  55. }
  56. }
  57. }
  58. }