You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Permute.cs 1.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Tensorflow.Keras.Engine;
  5. using Tensorflow.Keras.Utils;
  6. using static Tensorflow.Binding;
  7. using Tensorflow.Keras.ArgsDefinition;
  8. namespace Tensorflow.Keras.Layers {
  /// <summary>
  /// Layer that reorders the non-batch axes of its input according to <c>PermuteArgs.dims</c>.
  /// Axis 0 is never moved: the full transpose pattern is <c>[0, dims...]</c>.
  /// </summary>
  public class Permute : Layer
  {
      // Permutation of the non-batch axes as supplied by the caller; values index into the
      // full input shape, so they are expected to be 1-based (axis 0 is excluded).
      int[] dims;

      // Full transpose pattern including axis 0 (permute[0] == 0); populated in build().
      int[] permute;

      /// <summary>Creates the layer from its argument bag; only <c>args.dims</c> is read here.</summary>
      public Permute(PermuteArgs args) : base(args)
      {
          this.dims = args.dims;
      }

      /// <summary>
      /// Validates <c>dims</c> against the input rank and precomputes the transpose pattern.
      /// </summary>
      /// <param name="input_shape">Shape of the input, including axis 0.</param>
      /// <exception cref="ValueError">
      /// Thrown when <c>dims</c> does not have exactly <c>rank - 1</c> entries, or when it is
      /// not a permutation of <c>1..rank-1</c>.
      /// </exception>
      public override void build(Shape input_shape)
      {
          var rank = input_shape.rank;
          if (dims.Length != rank - 1)
          {
              throw new ValueError("Dimensions must match.");
          }

          // Reject duplicates, 0 (the batch axis) and out-of-range axes up front instead of
          // letting tf.transpose fail later with a less helpful error.
          if (!dims.OrderBy(d => d).SequenceEqual(Enumerable.Range(1, dims.Length)))
          {
              throw new ValueError("Invalid permutation argument `dims` for Permute layer. " +
                  "The set of indices in `dims` must be consecutive and start from 1.");
          }

          permute = new int[rank];
          // Copy from index 1 so permute[0] keeps its default 0 and axis 0 stays in place.
          dims.CopyTo(permute, 1);
          built = true;
      }

      /// <summary>Transposes the input with the pattern computed in <see cref="build"/>.</summary>
      protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
      {
          Tensor outputs = inputs;
          return tf.transpose(outputs, new Axis(permute));
      }

      /// <summary>
      /// Computes the permuted shape: axis 0 is kept, and output axis <c>i + 1</c> takes the size
      /// of input axis <c>dims[i]</c>.
      /// </summary>
      /// <param name="input_shape">Shape of the input, including axis 0. It is not modified.</param>
      /// <returns>A new <see cref="Shape"/> with the non-batch axes reordered.</returns>
      public override Shape ComputeOutputShape(Shape input_shape)
      {
          // Work on a private copy of the dimension array. The Shape constructor may wrap the
          // array without copying it; writing into output_shape would then also overwrite
          // input_shape while it is still being read below.
          Shape output_shape = new Shape(input_shape.dims.ToArray());
          for (int i = 0; i < dims.Length; i += 1)
          {
              var d = dims[i];
              var target_dim = input_shape[d];
              output_shape[i + 1] = target_dim;
          }
          return output_shape;
      }
  }
  44. }