You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Graph.cs 16 kB

7 years ago
7 years ago
7 years ago
7 years ago
6 years ago
7 years ago
7 years ago
6 years ago
7 years ago
6 years ago
6 years ago
7 years ago
7 years ago
7 years ago
6 years ago
7 years ago
7 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
6 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
6 years ago
6 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Runtime.InteropServices;
  5. using System.Text;
  namespace Tensorflow
  {
      /// <summary>
      /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations.
      /// This leads to a low-level programming model in which you first define the dataflow graph,
      /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
      /// https://www.tensorflow.org/guide/graphs
      /// </summary>
      /// <remarks>
      /// A graph contains a set of <see cref="Operation"/> objects, which represent units of computation,
      /// and <see cref="Tensor"/> objects, which represent the units of data that flow between operations.
      ///
      /// As in the Python API this class mirrors, a default graph is always registered, and
      /// <see cref="as_default"/> makes this instance the default graph.
      ///
      /// Important: this class is NOT thread-safe for graph construction. All operations should be
      /// created from a single thread, or external synchronization must be provided. Unless otherwise
      /// specified, no method is thread-safe.
      ///
      /// A graph also holds an arbitrary number of named "collections" (see <see cref="add_to_collection{T}"/>).
      /// They store groups of related objects; for example, variables are registered in the
      /// GLOBAL_VARIABLES collection. Callers may define additional collections simply by using a new name.
      /// </remarks>
      public partial class Graph : IPython, IDisposable
      {
          // Native TF_Graph handle; set to IntPtr.Zero once disposed.
          private IntPtr _handle;
          private Dictionary<int, ITensorOrOperation> _nodes_by_id;
          public Dictionary<string, ITensorOrOperation> _nodes_by_name;
          // Lower-cased name -> number of times that name has been handed out by unique_name.
          private Dictionary<string, int> _names_in_use;
          public int _version;
          private int _next_id_counter;
          private List<Operation> _unfetchable_ops = new List<Operation>();
          private List<Tensor> _unfeedable_tensors = new List<Tensor>();
          // Current name scope, without a trailing '/'; "" means the root scope.
          public string _name_stack = "";
          public string _graph_key;
          public Status Status { get; }

          /// <summary>
          /// True if the graph is considered "finalized". In that case no
          /// new operations can be added.
          /// </summary>
          private bool _finalized = false;

          /// <summary>
          /// Arbitrary collections of objects, keyed by collection name. Every value stored by this
          /// class is a <c>List&lt;T&gt;</c> for some element type T.
          /// </summary>
          private Dictionary<string, object> _collections = new Dictionary<string, object>();

          public bool building_function;

          /// <summary>Creates a new, empty native graph.</summary>
          public Graph()
              : this(c_api.TF_NewGraph())
          {
          }

          /// <summary>
          /// Wraps an existing native graph handle. The handle is deleted by <see cref="Dispose"/>.
          /// </summary>
          /// <param name="handle">Native TF_Graph handle.</param>
          public Graph(IntPtr handle)
          {
              _handle = handle;
              Status = new Status();
              _nodes_by_id = new Dictionary<int, ITensorOrOperation>();
              _nodes_by_name = new Dictionary<string, ITensorOrOperation>();
              _names_in_use = new Dictionary<string, int>();
              // NOTE(review): the "grap-key-" spelling is kept unchanged in case other code relies on it.
              _graph_key = $"grap-key-{ops.uid()}/";
          }

          /// <summary>
          /// Returns the graph element (tensor or operation) that <paramref name="obj"/> refers to.
          /// </summary>
          /// <param name="obj">A <see cref="Tensor"/>, an <see cref="Operation"/>, a <see cref="RefVariable"/>,
          /// or a name: "op_name" for an operation, "op_name:output_index" for a tensor.</param>
          /// <param name="allow_tensor">Whether a tensor may be returned.</param>
          /// <param name="allow_operation">Whether an operation may be returned.</param>
          /// <returns>The matching tensor or operation from this graph.</returns>
          /// <exception cref="KeyError">A name refers to an operation or output that is not in the graph.</exception>
          /// <exception cref="ArgumentException">The object is malformed, belongs to another graph,
          /// or is of a kind that is not allowed.</exception>
          public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true)
          {
              return _as_graph_element_locked(obj, allow_tensor, allow_operation);
          }

          /// <summary>Makes this graph the default graph and returns it.</summary>
          public Graph as_default() => ops.set_default_graph(this);

          /// <summary>Unwraps objects that expose their own graph element (currently only <see cref="RefVariable"/>).</summary>
          /// <returns>The unwrapped tensor, or null when <paramref name="obj"/> is not such an object.</returns>
          private Tensor _as_graph_element(object obj)
          {
              if (obj is RefVariable var)
                  return var._as_graph_element();
              return null;
          }

          /// <summary>Implementation of <see cref="as_graph_element"/>.</summary>
          private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true)
          {
              string types_str;
              if (allow_tensor && allow_operation)
                  types_str = "Tensor or Operation";
              else if (allow_tensor)
                  types_str = "Tensor";
              else if (allow_operation)
                  types_str = "Operation";
              else
                  throw new ArgumentException("allow_tensor and allow_operation can't both be False.");

              var temp_obj = _as_graph_element(obj);
              if (temp_obj != null)
                  obj = temp_obj;

              // If obj appears to be a name...
              if (obj is string name)
              {
                  bool looks_like_tensor_name = name.Contains(":");
                  if (looks_like_tensor_name)
                  {
                      if (!allow_tensor)
                          throw new ArgumentException($"Name {name} appears to refer to a Tensor, not a {types_str}.");
                      return _tensor_by_name(name);
                  }

                  if (allow_operation)
                  {
                      if (!_nodes_by_name.TryGetValue(name, out var node))
                          throw new KeyError($"The name {name} refers to an Operation not in the graph.");
                      return node;
                  }

                  // An operation-style name, but only tensors were requested.
                  string err_msg = _nodes_by_name.ContainsKey(name)
                      ? $"The name {name} refers to an Operation, not a {types_str}."
                      : $"The name {name} looks like an (invalid) Operation name, not a {types_str}.";
                  throw new ArgumentException(err_msg + " Tensor names must be of the form \"<op_name>:<output_index>\".");
              }

              if (obj is Tensor tensor && allow_tensor)
              {
                  if (!tensor.graph.Equals(this))
                      throw new ArgumentException($"Tensor {obj} is not an element of this graph.");
                  return tensor;
              }

              if (obj is Operation op && allow_operation)
              {
                  if (!op.graph.Equals(this))
                      throw new ArgumentException($"Operation {obj} is not an element of this graph.");
                  return op;
              }

              string type_name = obj == null ? "null" : obj.GetType().Name;
              throw new ArgumentException($"Can not convert a {type_name} into a {types_str}.");
          }

          /// <summary>Resolves a tensor name of the form "op_name:output_index" to that output.</summary>
          /// <exception cref="ArgumentException">The name is not of the form "op_name:output_index".</exception>
          /// <exception cref="KeyError">The operation does not exist or has too few outputs.</exception>
          private ITensorOrOperation _tensor_by_name(string name)
          {
              string[] parts = name.Split(':');
              if (parts.Length != 2
                  || !int.TryParse(parts[1], System.Globalization.NumberStyles.None,
                                   System.Globalization.CultureInfo.InvariantCulture, out int out_n))
              {
                  throw new ArgumentException($"The name {name} looks like a Tensor name, but is not a valid one. " +
                                              "Tensor names must be of the form \"<op_name>:<output_index>\".");
              }

              string op_name = parts[0];
              if (!_nodes_by_name.TryGetValue(op_name, out var node))
                  throw new KeyError($"The name {name} refers to a Tensor which does not exist. " +
                                     $"The operation, {op_name}, does not exist in the graph.");

              var outputs = node.outputs;
              // NumberStyles.None rejects a sign, so out_n is never negative here.
              if (out_n >= outputs.Length)
                  throw new KeyError($"The name {name} refers to a Tensor which does not exist. " +
                                     $"The operation, {op_name}, exists but only has {outputs.Length} outputs.");
              return outputs[out_n];
          }

          /// <summary>
          /// Appends <paramref name="value"/> to the collection named <paramref name="name"/>,
          /// creating a <c>List&lt;T&gt;</c> when the collection does not exist yet.
          /// </summary>
          /// <exception cref="RuntimeError">The graph is finalized.</exception>
          /// <exception cref="ArgumentException">The existing collection's element type cannot hold <paramref name="value"/>.</exception>
          public void add_to_collection<T>(string name, T value)
          {
              _check_not_finalized();
              if (!_collections.TryGetValue(name, out var collection))
              {
                  _collections[name] = new List<T> { value };
              }
              else if (collection is List<T> typed)
              {
                  typed.Add(value);
              }
              else if (collection is System.Collections.IList untyped)
              {
                  // The collection was created with another element type, e.g. List<object> by
                  // get_collection_ref. IList.Add accepts compatible values and throws ArgumentException otherwise.
                  untyped.Add(value);
              }
              else
              {
                  throw new InvalidOperationException($"Collection {name} is not a list and cannot be appended to.");
              }
          }

          /// <summary>Appends <paramref name="value"/> to every collection in <paramref name="names"/>.</summary>
          public void add_to_collections<T>(List<string> names, T value)
          {
              foreach (string name in names)
                  add_to_collection(name, value);
          }

          /// <exception cref="RuntimeError">The graph has been finalized.</exception>
          private void _check_not_finalized()
          {
              if (_finalized)
                  throw new RuntimeError("Graph is finalized and cannot be modified.");
          }

          /// <summary>Creates an operation in this graph.</summary>
          /// <param name="op_type">Operation type, e.g. "Const"; also the default name.</param>
          /// <param name="inputs">Input tensors; null is treated as no inputs.</param>
          /// <param name="dtypes">Output data types.</param>
          /// <param name="input_types">Optional expected input types.</param>
          /// <param name="name">Optional name. A name ending in '/' is a name scope and is used as-is
          /// (minus the trailing '/'); any other name is made unique within the current name scope.</param>
          /// <param name="attrs">Optional attribute values.</param>
          /// <param name="op_def">Optional operation definition.</param>
          /// <returns>The created operation.</returns>
          public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
              TF_DataType[] input_types = null, string name = null,
              Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)
          {
              if (inputs == null)
                  inputs = Array.Empty<Tensor>();

              if (string.IsNullOrEmpty(name))
                  name = op_type;

              // If a name ends with a '/' it is a "name scope" and we use it as-is,
              // after removing the trailing '/'.
              name = name.EndsWith("/", StringComparison.Ordinal) ? ops._name_from_scope_name(name) : unique_name(name);
              var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);

              var input_ops = inputs.Select(x => x.op).ToArray();
              var control_inputs = _control_dependencies_for_inputs(input_ops);

              var op = new Operation(node_def,
                  this,
                  inputs: inputs,
                  output_types: dtypes,
                  control_inputs: control_inputs,
                  input_types: input_types,
                  original_op: null,
                  op_def: op_def);

              _create_op_helper(op, true);
              return op;
          }

          /// <summary>Post-creation bookkeeping for a new operation (records it for control dependencies).</summary>
          private void _create_op_helper(Operation op, bool compute_device = true)
          {
              _record_op_seen_by_control_dependencies(op);
          }

          /// <summary>
          /// Registers <paramref name="op"/> under a fresh id and its name, and advances <see cref="_version"/>.
          /// </summary>
          public void _add_op(Operation op)
          {
              op._id_value = _next_id();
              _nodes_by_id[op._id] = op;
              _nodes_by_name[op.name] = op;
              _version = Math.Max(_version, op._id);
          }

          /// <summary>Returns the next operation id; ids start at 1.</summary>
          public int _next_id()
          {
              return ++_next_id_counter;
          }

          /// <summary>
          /// Returns true unless the operation (or, for a tensor, its producing operation)
          /// was marked with <see cref="prevent_fetching"/>. Any other type returns false.
          /// </summary>
          public bool is_fetchable<T>(T tensor_or_op)
          {
              switch (tensor_or_op)
              {
                  case Tensor tensor:
                      return !_unfetchable_ops.Contains(tensor.op);
                  case Operation op:
                      return !_unfetchable_ops.Contains(op);
                  default:
                      return false;
              }
          }

          /// <summary>Returns the current name scope ("" at the root).</summary>
          public string get_name_scope()
          {
              return _name_stack;
          }

          /// <summary>
          /// Enters a name scope and returns its prefix.
          /// </summary>
          /// <param name="name">Null or empty resets to the root scope. A name ending in '/' is used as-is
          /// (minus the '/'). Any other name is made unique via <see cref="unique_name"/>.</param>
          /// <returns>The new scope followed by '/', or "" for the root scope.</returns>
          /// <remarks>The previous scope is not restored here; callers are responsible for that.</remarks>
          public string name_scope(string name)
          {
              string new_stack;
              if (string.IsNullOrEmpty(name))
                  new_stack = "";
              else if (name.EndsWith("/", StringComparison.Ordinal))
                  new_stack = ops._name_from_scope_name(name);
              else
                  new_stack = unique_name(name);

              _name_stack = new_stack;
              return string.IsNullOrEmpty(new_stack) ? "" : new_stack + "/";
          }

          /// <summary>
          /// Return a unique operation name for `name`.
          ///
          /// Note: You rarely need to call `unique_name()` directly. Most of
          /// the time you just need to create name scopes to generate structured names.
          ///
          /// `unique_name` is used to generate structured names, separated by
          /// `"/"`, to help identify operations when debugging a graph.
          /// Operation names are displayed in error messages reported by the
          /// TensorFlow runtime, and in various visualization tools such as
          /// TensorBoard.
          ///
          /// If `mark_as_used` is set to `true`, which is the default, a new
          /// unique name is created and marked as in use. If it's set to `false`,
          /// the unique name is returned without actually being marked as used.
          /// This is useful when the caller simply wants to know what the name
          /// to be created will be.
          /// </summary>
          /// <param name="name">The name for an operation.</param>
          /// <param name="mark_as_used">Whether to mark this name as being used.</param>
          /// <returns>A string to be passed to `create_op()` that will be used
          /// to name the operation being created.</returns>
          public string unique_name(string name, bool mark_as_used = true)
          {
              if (!string.IsNullOrEmpty(_name_stack))
                  name = _name_stack + "/" + name;

              // For the sake of checking for names in use, we treat names as case
              // insensitive (e.g. foo = Foo). Invariant lowering keeps the key independent
              // of the current thread culture (e.g. Turkish dotted/dotless I).
              var name_key = name.ToLowerInvariant();
              _names_in_use.TryGetValue(name_key, out int i);

              // Increment the number for "name_key".
              if (mark_as_used)
                  _names_in_use[name_key] = i + 1;

              if (i > 0)
              {
                  // Make sure the composed name key is not already used.
                  var base_name_key = name_key;
                  while (_names_in_use.ContainsKey(name_key))
                  {
                      name_key = $"{base_name_key}_{i}";
                      i += 1;
                  }

                  // Mark the composed name_key as used in case someone wants
                  // to call unique_name("name_1").
                  if (mark_as_used)
                      _names_in_use[name_key] = 1;

                  // Return the new name with the original capitalization of the given name.
                  name = $"{name}_{i - 1}";
              }

              return name;
          }

          /// <summary>
          /// Copies the return outputs of a graph import into a managed array.
          /// </summary>
          /// <param name="results">Import-results handle passed to TF_ImportGraphDefResultsReturnOutputs.</param>
          public TF_Output[] ReturnOutputs(IntPtr results)
          {
              IntPtr return_output_handle = IntPtr.Zero;
              int num_return_outputs = 0;
              c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle);

              // The native side hands back a contiguous array of TF_Output structs.
              int element_size = Marshal.SizeOf<TF_Output>();
              var return_outputs = new TF_Output[num_return_outputs];
              for (int i = 0; i < num_return_outputs; i++)
              {
                  var handle = return_output_handle + element_size * i;
                  return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle);
              }

              return return_outputs;
          }

          /// <summary>
          /// Wraps the return operations of a graph import as <see cref="Operation"/> objects.
          /// </summary>
          /// <param name="results">Import-results handle passed to TF_ImportGraphDefResultsReturnOperations.</param>
          public unsafe Operation[] ReturnOperations(IntPtr results)
          {
              TF_Operation return_oper_handle = new TF_Operation();
              int num_return_opers = 0;
              c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle);
              Operation[] return_opers = new Operation[num_return_opers];
              for (int i = 0; i < num_return_opers; i++)
              {
                  // NOTE(review): assumes the native call stores the address of an array of operation
                  // pointers in return_oper_handle.node, and that TF_Operation is pointer-sized.
                  var handle = return_oper_handle.node + Marshal.SizeOf<TF_Operation>() * i;
                  return_opers[i] = new Operation(*(IntPtr*)handle);
              }

              return return_opers;
          }

          /// <summary>Looks up an operation in the native graph by name.</summary>
          public Operation OperationByName(string operName)
          {
              return c_api.TF_GraphOperationByName(_handle, operName);
          }

          /// <summary>Returns a snapshot of every node registered by name in this graph.</summary>
          public ITensorOrOperation[] get_operations()
          {
              return _nodes_by_name.Values.ToArray();
          }

          /// <summary>Returns the names of all collections, excluding those starting with "__".</summary>
          public string[] get_all_collection_keys()
          {
              return _collections.Keys.Where(x => !x.StartsWith("__", StringComparison.Ordinal)).ToArray();
          }

          /// <summary>
          /// Returns the collection named <paramref name="name"/> (the stored list itself, not a copy),
          /// or null when it does not exist.
          /// </summary>
          /// <param name="scope">Currently ignored; no scope filtering is applied.</param>
          public object get_collection(string name, string scope = null)
          {
              return _collections.TryGetValue(name, out var collection) ? collection : null;
          }

          /// <summary>
          /// Returns the collection named <paramref name="name"/>, creating an empty
          /// <c>List&lt;object&gt;</c> for it when it does not exist.
          /// </summary>
          public object get_collection_ref(string name)
          {
              if (!_collections.TryGetValue(name, out var collection))
              {
                  collection = new List<object>();
                  _collections[name] = collection;
              }

              return collection;
          }

          /// <summary>Marks <paramref name="tensor"/> as not feedable.</summary>
          public void prevent_feeding(Tensor tensor)
          {
              _unfeedable_tensors.Add(tensor);
          }

          /// <summary>Marks <paramref name="op"/> (and thereby its output tensors) as not fetchable.</summary>
          public void prevent_fetching(Operation op)
          {
              _unfetchable_ops.Add(op);
          }

          /// <summary>Deletes the native graph. Safe to call more than once.</summary>
          public void Dispose()
          {
              // Without this guard a second Dispose would hand an already-deleted handle back to TF_DeleteGraph.
              if (_handle == IntPtr.Zero)
                  return;

              c_api.TF_DeleteGraph(_handle);
              _handle = IntPtr.Zero;
          }

          /// <summary>Required by <see cref="IPython"/>; intentionally empty in this class.</summary>
          public void __enter__()
          {
          }

          /// <summary>Required by <see cref="IPython"/>; intentionally empty in this class.</summary>
          public void __exit__()
          {
          }

          /// <summary>Exposes the native TF_Graph handle.</summary>
          public static implicit operator IntPtr(Graph graph)
          {
              return graph._handle;
          }
      }
  }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。