You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

execute.cs 1.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using System.Xml.Linq;
  6. using Tensorflow.Contexts;
  7. using static Tensorflow.ApiDef.Types;
  8. using static Tensorflow.CostGraphDef.Types;
  9. using static Tensorflow.Binding;
  10. using Tensorflow.Gradients;
  namespace Tensorflow.Eager
  {
      /// <summary>
      /// Eager-mode execution helpers: convert inputs to tensors, dispatch ops through
      /// <c>tf.Runner.TFE_Execute</c>, and query/record gradients via the active tape set.
      /// </summary>
      internal static class _execute
      {
          /// <summary>
          /// Converts every value with <c>ops.convert_to_tensor</c> in the given context and
          /// returns the converted tensors together with their dtype enums.
          /// </summary>
          /// <param name="values">Values to convert.</param>
          /// <param name="ctx">Eager context forwarded to <c>ops.convert_to_tensor</c>.</param>
          /// <returns>
          /// A tuple whose first item holds the dtype enum of each converted tensor and whose
          /// second item holds the converted tensors; both arrays are index-aligned with
          /// <paramref name="values"/>.
          /// </returns>
          /// <exception cref="ArgumentNullException"><paramref name="values"/> is null.</exception>
          public static (DataType[], Tensor[]) convert_to_mixed_eager_tensors(Tensor[] values, Context ctx)
          {
              if (values == null)
                  throw new ArgumentNullException(nameof(values));

              // Materialize the conversion exactly once. A lazy Select query enumerated twice
              // (once for dtypes, once for tensors) would run convert_to_tensor twice per value.
              var tensors = values.Select(t => ops.convert_to_tensor(t, ctx: ctx)).ToArray();
              var types = tensors.Select(t => t.dtype.as_datatype_enum()).ToArray();
              return (types, tensors);
          }

          /// <summary>
          /// Misspelled name retained so existing callers keep compiling; forwards to
          /// <see cref="convert_to_mixed_eager_tensors"/>.
          /// </summary>
          /// <param name="values">Values to convert.</param>
          /// <param name="ctx">Eager context forwarded to the conversion.</param>
          /// <returns>Dtype enums and converted tensors, index-aligned with <paramref name="values"/>.</returns>
          public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx)
          {
              return convert_to_mixed_eager_tensors(values, ctx);
          }

          /// <summary>
          /// Executes an operation eagerly. Currently a direct pass-through to <see cref="quick_execute"/>.
          /// </summary>
          /// <param name="op_name">Name of the operation to run.</param>
          /// <param name="num_outputs">Number of outputs requested from the runner.</param>
          /// <param name="inputs">Input tensors.</param>
          /// <param name="attrs">Operation attributes, forwarded unchanged to the runner.</param>
          /// <param name="ctx">Eager context to execute in.</param>
          /// <param name="name">Optional name; forwarded to <see cref="quick_execute"/>.</param>
          /// <returns>The tensors returned by the runner.</returns>
          public static Tensor[] execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
          {
              return quick_execute(op_name, num_outputs, inputs, attrs, ctx, name);
          }

          /// <summary>
          /// Ensures <paramref name="ctx"/> is initialized and runs the op via <c>tf.Runner.TFE_Execute</c>
          /// on the context's current device.
          /// </summary>
          /// <param name="op_name">Name of the operation to run.</param>
          /// <param name="num_outputs">Number of outputs requested from the runner.</param>
          /// <param name="inputs">Input tensors.</param>
          /// <param name="attrs">Operation attributes, forwarded unchanged to the runner.</param>
          /// <param name="ctx">Eager context to execute in.</param>
          /// <param name="name">Accepted for signature compatibility; not used by this implementation.</param>
          /// <returns>The tensors returned by the runner.</returns>
          public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
          {
              // The device name is read before ensure_initialized(); this ordering is kept as-is.
              string device_name = ctx.DeviceName;
              ctx.ensure_initialized();
              return tf.Runner.TFE_Execute(ctx, device_name, op_name, inputs, attrs, num_outputs);
          }

          /// <summary>
          /// Returns true when at least one gradient tape is active, i.e. the tape set is non-empty.
          /// </summary>
          public static bool must_record_gradient()
          {
              return tf.GetTapeSet().Count != 0;
          }

          /// <summary>
          /// Records the gradient for an executed op by delegating to <c>tf.Runner.RecordGradient</c>.
          /// </summary>
          /// <param name="op_name">Name of the executed operation.</param>
          /// <param name="inputs">Inputs the op was executed with.</param>
          /// <param name="attrs">Attributes the op was executed with.</param>
          /// <param name="results">Outputs produced by the op.</param>
          /// <returns>The value returned by <c>tf.Runner.RecordGradient</c>.</returns>
          public static bool record_gradient(string op_name, Tensor[] inputs, object[] attrs, Tensor[] results)
          {
              return tf.Runner.RecordGradient(op_name, inputs, attrs, results);
          }
      }
  }