You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Permute.cs 1.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Tensorflow.Keras.Engine;
  5. using Tensorflow.Keras.Utils;
  6. using static Tensorflow.Binding;
  7. using Tensorflow.Keras.ArgsDefinition;
  8. namespace Tensorflow.Keras.Layers {
  /// <summary>
  /// Keras-style <c>Permute</c> layer: reorders the non-batch axes of its input
  /// according to the pattern given in <see cref="PermuteArgs"/>.
  /// </summary>
  /// <remarks>
  /// The pattern is 1-based: axis 0 (the batch axis) is never moved, and the
  /// input is transposed with the permutation <c>[0, dims[0], dims[1], ...]</c>.
  /// For example, <c>dims = {2, 1}</c> swaps the first and second non-batch axes.
  /// </remarks>
  public class Permute : Layer
  {
      // Caller-supplied permutation pattern, excluding the batch axis (1-based).
      readonly int[] dims;
      // Full transpose permutation including the batch axis; populated by build().
      int[] permute;

      /// <summary>Creates the layer from its arguments; <c>args.dims</c> is the permutation pattern.</summary>
      public Permute(PermuteArgs args) : base(args)
      {
          this.dims = args.dims;
      }

      /// <summary>
      /// Validates the permutation pattern against the input rank and precomputes
      /// the full transpose permutation.
      /// </summary>
      /// <param name="input_shape">Shape of the layer input, including the batch axis.</param>
      /// <exception cref="ValueError">
      /// Thrown when the pattern is missing, its length is not <c>rank - 1</c>,
      /// or it is not a permutation of <c>1..rank-1</c>.
      /// </exception>
      public override void build(Shape input_shape)
      {
          var rank = input_shape.rank;
          if (dims == null || dims.Length != rank - 1)
          {
              throw new ValueError("Dimensions must match.");
          }

          // Each non-batch axis must appear exactly once; otherwise the transpose
          // below would be ill-formed and fail far from the real cause.
          if (!dims.OrderBy(d => d).SequenceEqual(Enumerable.Range(1, dims.Length)))
          {
              throw new ValueError(
                  $"Invalid permutation `dims`: [{string.Join(", ", dims)}]. " +
                  $"Values must be a permutation of the integers 1..{dims.Length}.");
          }

          permute = new int[rank];
          // permute[0] stays 0 so the batch axis is left in place.
          dims.CopyTo(permute, 1);
          built = true;
          _buildInputShape = input_shape;
      }

      /// <summary>Transposes the input using the permutation computed in <see cref="build"/>.</summary>
      protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
      {
          Tensor outputs = inputs;
          return tf.transpose(outputs, new Axis(permute));
      }

      /// <summary>
      /// Computes the output shape: the batch axis is kept, and output axis
      /// <c>i + 1</c> takes the size of input axis <c>dims[i]</c>.
      /// </summary>
      /// <param name="input_shape">Shape of the layer input; not modified.</param>
      /// <returns>A new <see cref="Shape"/> with the non-batch axes permuted.</returns>
      public override Shape ComputeOutputShape(Shape input_shape)
      {
          // Work on a copy of the dimension array. Shape may wrap the array it is
          // constructed from, so writing through a Shape built on input_shape.dims
          // could corrupt the caller's shape and feed already-overwritten sizes
          // back into later iterations of this loop.
          var input_dims = input_shape.dims;
          var output_dims = input_dims.ToArray();
          for (int i = 0; i < dims.Length; i++)
          {
              output_dims[i + 1] = input_dims[dims[i]];
          }
          return new Shape(output_dims);
      }
  }
  45. }