You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CSession.cs 3.4 kB

7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. using NumSharp.Core;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Runtime.InteropServices;
  5. using System.Text;
  6. using Tensorflow;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Unit-test helper that owns a native session handle and drives <c>TF_SessionRun</c>.
      /// C# port of the <c>CSession</c> helper in tensorflow\c\c_test_util.cc.
      /// Typical use: <see cref="SetInputs"/>, <see cref="SetOutputs"/>, <see cref="Run"/>,
      /// then read results with <see cref="output_tensor"/>.
      /// </summary>
      public class CSession
      {
          // Native session handle obtained from the managed Session wrapper in the constructor.
          private readonly IntPtr session_;
          // Feed endpoints and the tensor handles fed into them; the two lists are index-aligned.
          private readonly List<TF_Output> inputs_ = new List<TF_Output>();
          private readonly List<IntPtr> input_values_ = new List<IntPtr>();
          // Fetch endpoints and the tensor handles produced by the last Run; index-aligned.
          private readonly List<TF_Output> outputs_ = new List<TF_Output>();
          private readonly List<IntPtr> output_values_ = new List<IntPtr>();
          // Target operations (run without fetching). Nothing in this class populates it.
          private readonly List<IntPtr> targets_ = new List<IntPtr>();

          /// <summary>
          /// Creates a session over <paramref name="graph"/> using default session options.
          /// </summary>
          /// <param name="graph">Graph the session executes.</param>
          /// <param name="s">Status passed to session creation.</param>
          /// <param name="user_XLA">
          /// Currently ignored. NOTE(review): kept for parity with the C++ helper's <c>use_XLA</c>
          /// flag, but XLA compilation is never enabled here.
          /// </param>
          public CSession(Graph graph, Status s, bool user_XLA = false)
          {
              var opts = new SessionOptions();
              // Relies on Session converting implicitly to its native IntPtr handle.
              // NOTE(review): confirm the managed Session wrapper does not release that handle
              // when it is garbage-collected, since only the raw pointer is kept here.
              session_ = new Session(graph, opts, s);
          }

          /// <summary>
          /// Replaces the feeds used by the next <see cref="Run"/>.
          /// </summary>
          /// <param name="inputs">
          /// Maps each input operation handle (fed at output index 0) to the tensor handle to feed.
          /// </param>
          public void SetInputs(Dictionary<IntPtr, IntPtr> inputs)
          {
              DeleteInputValues();
              inputs_.Clear();
              foreach (var input in inputs)
              {
                  // No unmanaged copy is needed: Run hands the managed arrays to the native call.
                  inputs_.Add(new TF_Output(input.Key, 0));
                  input_values_.Add(input.Value);
              }
          }

          /// <summary>Drops the tensor handles registered by the previous <see cref="SetInputs"/>.</summary>
          private void DeleteInputValues()
          {
              // NOTE(review): disposing these tensors was deliberately commented out in this port,
              // so ownership appears to stay with the caller; only the references are dropped.
              input_values_.Clear();
          }

          /// <summary>
          /// Replaces the fetches used by the next <see cref="Run"/>.
          /// </summary>
          /// <param name="outputs">Operation handles whose output index 0 is fetched.</param>
          public void SetOutputs(List<IntPtr> outputs)
          {
              ResetOutputValues();
              outputs_.Clear();
              foreach (var output in outputs)
              {
                  outputs_.Add(new TF_Output(output, 0));
                  // Placeholder slot; Run overwrites it with the fetched tensor handle.
                  output_values_.Add(IntPtr.Zero);
              }
          }

          /// <summary>Drops the tensor handles fetched by a previous <see cref="Run"/>.</summary>
          private void ResetOutputValues()
          {
              // NOTE(review): disposing fetched tensors was deliberately commented out in this port;
              // only the references are dropped here.
              output_values_.Clear();
          }

          /// <summary>
          /// Runs the session with the current feeds and fetches, then stores every fetched
          /// tensor handle so it can be read with <see cref="output_tensor"/>.
          /// </summary>
          /// <param name="s">Status passed to <c>TF_SessionRun</c>; <c>s.Check()</c> is called afterwards.</param>
          public void Run(Status s)
          {
              var inputs_ptr = inputs_.ToArray();
              var input_values_ptr = input_values_.ToArray();
              var outputs_ptr = outputs_.ToArray();
              var output_values_ptr = output_values_.ToArray();
              // targets_ is never populated, so no target array is passed.
              IntPtr targets_ptr = IntPtr.Zero;

              // The input count must match the arrays actually passed (previously hard-coded to 1).
              c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_.Count,
                  outputs_ptr, output_values_ptr, outputs_.Count,
                  targets_ptr, targets_.Count,
                  IntPtr.Zero, s);
              s.Check();

              // Copy back every fetched handle, not only the first; this is also safe when no
              // outputs were requested.
              for (var i = 0; i < output_values_ptr.Length; ++i)
              {
                  output_values_[i] = output_values_ptr[i];
              }
          }

          /// <summary>
          /// Returns the tensor handle fetched for output <paramref name="i"/> by the last
          /// <see cref="Run"/> (<see cref="IntPtr.Zero"/> if Run has not been called since
          /// <see cref="SetOutputs"/>).
          /// </summary>
          public IntPtr output_tensor(int i)
          {
              return output_values_[i];
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。

Contributors (1)