You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

BaseSession.cs 6.3 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. using NumSharp.Core;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using System.Text;
  7. namespace Tensorflow
  8. {
  /// <summary>
  /// Managed wrapper around a native TensorFlow session handle created through <c>c_api</c>.
  /// Evaluates fetch tensors of a <see cref="Graph"/>, optionally feeding values for other tensors.
  /// </summary>
  public class BaseSession : IDisposable
  {
      // NOTE(review): presumably the 8-byte offset entry plus a 1-byte varint length prefix that the
      // TensorFlow C API places in front of a scalar TF_STRING payload. If so, this only holds for
      // a single string shorter than 128 bytes - confirm against the C API string tensor encoding.
      private const int ScalarStringHeaderSize = 9;

      protected Graph _graph;
      protected bool _opened;
      protected bool _closed;
      protected int _current_version;
      protected byte[] _target;
      protected IntPtr _session;

      /// <summary>
      /// Creates a native session for <paramref name="graph"/>, or for the default graph when none is given.
      /// </summary>
      /// <param name="target">
      /// Stored UTF-8 encoded in <c>_target</c>; it is not passed to the session options here.
      /// A null value is treated as an empty string.
      /// </param>
      /// <param name="graph">Graph to run; when null, <c>ops.get_default_graph()</c> is used.</param>
      public BaseSession(string target = "", Graph graph = null)
      {
          _graph = graph ?? ops.get_default_graph();
          _target = Encoding.UTF8.GetBytes(target ?? string.Empty);

          var opts = c_api.TF_NewSessionOptions();
          var status = new Status();
          try
          {
              _session = c_api.TF_NewSession(_graph, opts, status);
          }
          finally
          {
              // The options are only needed while the session is being created.
              c_api.TF_DeleteSessionOptions(opts);
          }

          // Surface a failed session creation here instead of carrying an unusable handle
          // into the first run() call; mirrors the check done after TF_SessionRun.
          status.Check(true);
      }

      /// <summary>
      /// Currently a no-op.
      /// </summary>
      public void Dispose()
      {
          // NOTE(review): the handle stored in _session by the constructor is never released.
          // Freeing it needs a delete-session binding that is not visible in this file - confirm
          // which c_api entry point should be called and wire it in here.
      }

      /// <summary>
      /// Evaluates <paramref name="fetches"/> in this session.
      /// </summary>
      /// <param name="fetches">Tensor to evaluate.</param>
      /// <param name="feed_dict">Optional values to feed, keyed by the tensor they replace.</param>
      /// <returns>The object produced by the fetch handler's <c>build_results</c>.</returns>
      public virtual object run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null)
      {
          return _run(fetches, feed_dict);
      }

      /// <summary>
      /// Copies the feeds, resolves the fetch structure through a <c>_FetchHandler</c>, runs the
      /// session and lets the handler assemble the final result.
      /// </summary>
      private unsafe object _run(Tensor fetches, Dictionary<Tensor, NDArray> feed_dict = null)
      {
          // Work on a private copy so the caller's dictionary is never touched.
          var feed_dict_tensor = feed_dict is null
              ? new Dictionary<Tensor, NDArray>()
              : new Dictionary<Tensor, NDArray>(feed_dict);

          // Create a fetch handler to take care of the structure of fetches.
          var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);

          // The movers must stay alive until _do_run() has completed; holding them in a local
          // keeps them referenced for the rest of this method.
          var movers = _update_with_movers();

          var final_fetches = fetch_handler.fetches();
          // Targets are queried but not forwarded yet: the native call below is made with
          // no target operations.
          var final_targets = fetch_handler.targets();

          var results = _do_run(final_fetches, feed_dict_tensor);
          return fetch_handler.build_results(null, results);
      }

      /// <summary>
      /// Converts feeds and fetches to their <c>TF_Output</c> form and invokes the native run.
      /// </summary>
      private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict)
      {
          var feeds = feed_dict
              .Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value)))
              .ToArray();
          var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();

          return _call_tf_sessionrun(feeds, fetches);
      }

      /// <summary>
      /// Calls <c>TF_SessionRun</c> and converts every fetched tensor into a managed value.
      /// </summary>
      /// <param name="feed_dict">Input outputs paired with the tensors fed to them.</param>
      /// <param name="fetch_list">Outputs to fetch.</param>
      /// <returns>One entry per fetch, in the order of <paramref name="fetch_list"/>.</returns>
      /// <exception cref="NotImplementedException">
      /// A fetched tensor has a data type other than TF_STRING, TF_FLOAT, TF_INT16 or TF_INT32.
      /// </exception>
      private unsafe object[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list)
      {
          // Ensure any changes to the graph are reflected in the runtime.
          _extend_graph();

          var status = new Status();

          // One slot per fetch, all starting as IntPtr.Zero; the native call fills them in.
          var output_values = new IntPtr[fetch_list.Length];

          c_api.TF_SessionRun(_session,
              run_options: null,
              inputs: feed_dict.Select(f => f.Key).ToArray(),
              input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
              ninputs: feed_dict.Length,
              outputs: fetch_list,
              output_values: output_values,
              noutputs: fetch_list.Length,
              target_opers: IntPtr.Zero,
              ntargets: 0,
              run_metadata: IntPtr.Zero,
              status: status);

          status.Check(true);

          object[] result = new object[fetch_list.Length];

          for (int i = 0; i < fetch_list.Length; i++)
          {
              var handle = output_values[i];
              var tensor = new Tensor(handle);
              var ndims = tensor.shape.Select(x => (int)x).ToArray();

              // Total number of elements; an empty shape (scalar) yields 1.
              var count = ndims.Aggregate(1, (total, dim) => total * dim);

              switch (tensor.dtype)
              {
                  case TF_DataType.TF_STRING:
                      {
                          var bytes = tensor.Data();
                          // Encoding.UTF8 is used explicitly: UTF8Encoding.Default is the inherited
                          // Encoding.Default, which is not UTF-8 on every runtime.
                          var text = Encoding.UTF8.GetString(bytes, ScalarStringHeaderSize, bytes.Length - ScalarStringHeaderSize);
                          result[i] = tensor.NDims == 0 ? text : np.array(text).reshape(ndims);
                      }
                      break;
                  case TF_DataType.TF_FLOAT:
                      {
                          var data = c_api.TF_TensorData(handle);
                          // Scalars are read directly; anything else copies every element before
                          // reshaping, instead of reshaping a single value.
                          result[i] = tensor.NDims == 0 ? *(float*)data : np.array(_copy_floats(data, count)).reshape(ndims);
                      }
                      break;
                  case TF_DataType.TF_INT16:
                      {
                          var data = c_api.TF_TensorData(handle);
                          result[i] = tensor.NDims == 0 ? *(short*)data : np.array(_copy_shorts(data, count)).reshape(ndims);
                      }
                      break;
                  case TF_DataType.TF_INT32:
                      {
                          var data = c_api.TF_TensorData(handle);
                          result[i] = tensor.NDims == 0 ? *(int*)data : np.array(_copy_ints(data, count)).reshape(ndims);
                      }
                      break;
                  default:
                      throw new NotImplementedException("can't get output");
              }
          }

          return result;
      }

      /// <summary>Copies <paramref name="count"/> floats from native memory into a new array.</summary>
      private static float[] _copy_floats(IntPtr data, int count)
      {
          var values = new float[count];
          if (count > 0)
          {
              // Skipped for empty tensors, whose data pointer may be null.
              Marshal.Copy(data, values, 0, count);
          }
          return values;
      }

      /// <summary>Copies <paramref name="count"/> 16-bit integers from native memory into a new array.</summary>
      private static short[] _copy_shorts(IntPtr data, int count)
      {
          var values = new short[count];
          if (count > 0)
          {
              Marshal.Copy(data, values, 0, count);
          }
          return values;
      }

      /// <summary>Copies <paramref name="count"/> 32-bit integers from native memory into a new array.</summary>
      private static int[] _copy_ints(IntPtr data, int count)
      {
          var values = new int[count];
          if (count > 0)
          {
              Marshal.Copy(data, values, 0, count);
          }
          return values;
      }

      /// <summary>
      /// Intended behaviour: if a tensor handle is fed to a device-incompatible placeholder,
      /// move the tensor to the right device, generate a new tensor handle,
      /// and update feed_dict to use the new handle.
      /// The current implementation is a stub that always returns an empty list.
      /// </summary>
      private List<object> _update_with_movers()
      {
          return new List<object>();
      }

      /// <summary>
      /// Hook called before every native run; currently empty.
      /// </summary>
      private void _extend_graph()
      {
      }
  }
  143. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。