You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

BaseSession.cs 4.6 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. using NumSharp.Core;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using System.Text;
  7. namespace Tensorflow
  8. {
  9. public class BaseSession : IDisposable
  10. {
  11. private Graph _graph;
  12. private bool _opened;
  13. private bool _closed;
  14. private int _current_version;
  15. private byte[] _target;
  16. private IntPtr _session;
  /// <summary>
  /// Creates a native TensorFlow session for the given graph.
  /// </summary>
  /// <param name="target">Execution target; stored as UTF-8 bytes in <c>_target</c>. It is not passed to the session options here.</param>
  /// <param name="graph">Graph to run. When null, <c>ops.get_default_graph()</c> is used.</param>
  public BaseSession(string target = "", Graph graph = null)
  {
      _graph = graph ?? ops.get_default_graph();
      _target = Encoding.UTF8.GetBytes(target);

      var status = new Status();
      var opts = c_api.TF_NewSessionOptions();
      try
      {
          // NOTE(review): status is not inspected after this call, so a failed
          // session creation is not reported here — confirm the Status API and add a check.
          _session = c_api.TF_NewSession(_graph.Handle, opts, status.Handle);
      }
      finally
      {
          // Release the options even if session creation throws.
          c_api.TF_DeleteSessionOptions(opts);
      }
  }
  /// <summary>
  /// Disposes the session. The body is currently empty, so nothing is released.
  /// </summary>
  /// <remarks>
  /// NOTE(review): the native handle stored in <c>_session</c> by the constructor is
  /// not released here — confirm whether it should be deleted on dispose.
  /// </remarks>
  public void Dispose()
  {
      // Intentionally left without statements, as in the original implementation.
  }
  /// <summary>
  /// Evaluates <paramref name="fetches"/> in this session by delegating to <c>_run</c>.
  /// </summary>
  /// <param name="fetches">Tensor to evaluate.</param>
  /// <param name="feed_dict">Optional feed values; may be null.</param>
  /// <returns>Whatever <c>_run</c> returns for the given fetches.</returns>
  public virtual object run(Tensor fetches, FeedDict feed_dict = null)
      => _run(fetches, feed_dict);
  /// <summary>
  /// Converts the supplied feeds to NDArray values, runs the fetches and builds the result.
  /// </summary>
  /// <param name="fetches">Tensor to evaluate.</param>
  /// <param name="feed_dict">Optional feed values; null means no feeds.</param>
  /// <returns>The value produced by the fetch handler's <c>build_results</c>.</returns>
  /// <exception cref="NotSupportedException">A feed value is neither a float nor an NDArray.</exception>
  private unsafe object _run(Tensor fetches, FeedDict feed_dict = null)
  {
      var feed_dict_tensor = new FeedDict();

      if (feed_dict != null)
      {
          foreach (FeedValue feed in feed_dict)
          {
              // Scoped per feed so an unsupported value can never silently reuse
              // the array converted for a previous feed.
              NDArray np_val;
              switch (feed.feed_val)
              {
                  case float value:
                      np_val = np.asarray(value);
                      break;
                  case NDArray array:
                      // Already in the target representation; pass it through unchanged.
                      np_val = array;
                      break;
                  default:
                      throw new NotSupportedException(
                          $"Unsupported feed value type: {feed.feed_val?.GetType().Name ?? "null"}");
              }

              feed_dict_tensor[feed.feed] = np_val;
          }
      }

      // Create a fetch handler to take care of the structure of fetches.
      var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);

      // Run request and get response.
      // We need to keep the returned movers alive for the following _do_run().
      // These movers are no longer needed when _do_run() completes, and
      // are deleted when `movers` goes out of scope when this _run() ends.
      var _ = _update_with_movers();
      var final_fetches = fetch_handler.fetches();
      // Targets are requested from the handler but not used by the run below.
      var final_targets = fetch_handler.targets();

      // We only want to really perform the run if fetches or targets are provided,
      // or if the call is a partial run that specifies feeds.
      var results = _do_run(final_fetches, feed_dict_tensor);

      return fetch_handler.build_results(null, results);
  }
  /// <summary>
  /// Translates feeds and fetches to their TF_Output form and executes the session run.
  /// </summary>
  /// <param name="fetch_list">Tensors whose values are requested.</param>
  /// <param name="feed_dict">Feeds; each value is wrapped in a new Tensor built from it as an NDArray.</param>
  /// <returns>The values returned by <c>_call_tf_sessionrun</c>.</returns>
  private object[] _do_run(List<Tensor> fetch_list, FeedDict feed_dict)
  {
      // Pair every fed tensor's output handle with a tensor holding the fed data.
      KeyValuePair<TF_Output, Tensor>[] feedPairs = feed_dict.items()
          .Select(item => new KeyValuePair<TF_Output, Tensor>(
              item.Key._as_tf_output(),
              new Tensor(item.Value as NDArray)))
          .ToArray();

      TF_Output[] fetchOutputs = fetch_list
          .Select(tensor => tensor._as_tf_output())
          .ToArray();

      return _call_tf_sessionrun(feedPairs, fetchOutputs);
  }
  /// <summary>
  /// Invokes TF_SessionRun with the given feeds and fetches and reads each output as a float.
  /// </summary>
  /// <param name="feed_dict">Pairs of graph input and the tensor value to feed into it.</param>
  /// <param name="fetch_list">Graph outputs to fetch.</param>
  /// <returns>One boxed float per fetched output, in fetch order.</returns>
  /// <exception cref="InvalidOperationException">The run left an output handle unset.</exception>
  private unsafe object[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list)
  {
      // Ensure any changes to the graph are reflected in the runtime.
      _extend_graph();

      var status = new Status();

      // Inputs and their values must be parallel arrays of the same length; the
      // previous code passed the inputs but no values and a count of zero, so
      // every feed was dropped.
      // NOTE(review): assumes Tensor exposes its native handle as `Handle`, as
      // Graph and Status do in this file — confirm.
      var inputs = feed_dict.Select(f => f.Key).ToArray();
      var input_values = feed_dict.Select(f => f.Value.Handle).ToArray();

      // Elements default to IntPtr.Zero and are filled in by the native call.
      var output_values = new IntPtr[fetch_list.Length];

      c_api.TF_SessionRun(_session,
          run_options: IntPtr.Zero,
          inputs: inputs,
          input_values: input_values,
          ninputs: inputs.Length,
          outputs: fetch_list,
          output_values: output_values,
          noutputs: fetch_list.Length,
          target_opers: new IntPtr[] { },
          ntargets: 0,
          run_metadata: IntPtr.Zero,
          status: status.Handle);

      var result = new object[output_values.Length];
      for (var i = 0; i < output_values.Length; i++)
      {
          // A handle still at zero means no tensor was produced for this fetch;
          // fail explicitly instead of dereferencing a null pointer.
          if (output_values[i] == IntPtr.Zero)
          {
              throw new InvalidOperationException(
                  $"TF_SessionRun did not produce a value for fetch {i}.");
          }

          // NOTE(review): every output is read as a single float regardless of
          // the tensor's actual dtype or shape.
          result[i] = (object)*(float*)c_api.TF_TensorData(output_values[i]);
      }

      return result;
  }
  /// <summary>
  /// Intended to handle tensor handles fed to a device-incompatible placeholder:
  /// move the tensor to the right device, generate a new tensor handle,
  /// and update feed_dict to use the new handle.
  /// The current implementation performs none of this and returns an empty list.
  /// </summary>
  /// <returns>A new, empty list.</returns>
  private List<object> _update_with_movers()
      => new List<object>();
  /// <summary>
  /// Hook called before each session run (see the caller's comment: "Ensure any
  /// changes to the graph are reflected in the runtime"). Currently does nothing.
  /// </summary>
  private void _extend_graph()
  {
      // No statements: the body is empty in the original implementation.
  }
  113. }
  114. }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。

Contributors (1)