You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Execute.cs 2.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. using System.Collections.Generic;
  2. using System;
  3. using System.Linq;
  4. using static Tensorflow.Binding;
  5. namespace Tensorflow.Eager
  6. {
  public class Execute
  {
      /// <summary>
      /// Execute a TensorFlow operation eagerly on the context's current device.
      /// </summary>
      /// <param name="ctx">
      /// The value of context.context(). It is initialized (via <c>ensure_initialized</c>)
      /// before the operation runs.
      /// </param>
      /// <param name="op_name">
      /// Name of the TensorFlow operation (see REGISTER_OP in C++ code) to execute.
      /// </param>
      /// <param name="num_outputs">The number of outputs of the operation to fetch.</param>
      /// <param name="inputs">The input tensors of the operation.</param>
      /// <param name="attrs">
      /// Alternating attr names and attr values for this operation.
      /// </param>
      /// <param name="name">
      /// Customized name for the operation. NOTE(review): this value is currently not
      /// forwarded to <c>TFE_Execute</c>, so it has no effect.
      /// </param>
      /// <returns>The output tensors produced by <c>tf.Runner.TFE_Execute</c>.</returns>
      /// <exception cref="ArgumentNullException"><paramref name="ctx"/> is null.</exception>
      public Tensor[] execute(Context ctx, string op_name, int num_outputs,
          Tensor[] inputs, object[] attrs,
          string name = null)
      {
          if (ctx is null)
              throw new ArgumentNullException(nameof(ctx));

          ctx.ensure_initialized();
          return tf.Runner.TFE_Execute(ctx,
              ctx.device_name,
              op_name,
              inputs,
              attrs,
              num_outputs);
      }

      /// <summary>
      /// Converts <paramref name="args"/> into tensors that all share a single dtype.
      /// </summary>
      /// <param name="ctx">Eager context forwarded to <c>ops.convert_to_tensor</c>.</param>
      /// <param name="default_dtype">
      /// Preferred dtype used when no argument is already a <see cref="Tensor"/>. It is also
      /// the dtype returned for an empty argument list.
      /// </param>
      /// <param name="args">Values to convert; a null array is treated as empty.</param>
      /// <returns>The common dtype and the converted tensors (never a null array).</returns>
      /// <exception cref="ArgumentException">
      /// <paramref name="args"/> is null or empty and <paramref name="default_dtype"/> is
      /// <c>DtInvalid</c>, so no dtype can be inferred.
      /// </exception>
      public (TF_DataType, Tensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null)
      {
          if (args == null || args.Length == 0)
          {
              // Nothing to inspect: the caller-supplied default is the only dtype source.
              if (default_dtype == TF_DataType.DtInvalid)
                  throw new ArgumentException("Cannot infer a dtype from an empty argument list when no default_dtype is given.", nameof(args));
              return (default_dtype, Array.Empty<Tensor>());
          }

          // Fast path: everything is already a tensor; report the first one's dtype unchanged.
          if (args.All(x => x is Tensor))
              return (((Tensor)args[0]).dtype, args.Cast<Tensor>().ToArray());

          // If any argument is already a tensor, the FIRST tensor's dtype wins and every
          // value (tensor or not) is converted to it.
          var tensorDtype = args.OfType<Tensor>()
                                .Select(t => t.dtype)
                                .DefaultIfEmpty(TF_DataType.DtInvalid)
                                .First();
          if (tensorDtype != TF_DataType.DtInvalid)
          {
              var converted = args
                  .Select(x => ops.convert_to_tensor(x, tensorDtype, ctx: ctx) as Tensor)
                  .ToArray();
              return (tensorDtype, converted);
          }

          // No tensors at all: the first conversion (preferring default_dtype) fixes the
          // dtype, and the remaining values are converted to that same dtype.
          var inferred = TF_DataType.DtInvalid;
          var ret = new List<Tensor>(args.Length);
          foreach (var value in args)
          {
              var tensor = ops.convert_to_tensor(value, inferred, preferred_dtype: default_dtype, ctx: ctx) as Tensor;
              ret.Add(tensor);
              if (inferred == TF_DataType.DtInvalid)
                  inferred = tensor.dtype;
          }
          return (inferred, ret.ToArray());
      }
  }
  70. }