You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

GraphTest.cs 8.0 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow;
  namespace TensorFlowNET.UnitTest
  {
      [TestClass]
      public class GraphTest
      {
          /// <summary>
          /// Port from c_api_test.cc
          /// `TEST(CAPI, Graph)`
          /// Builds a small graph (Placeholder "feed", scalar Const 3, AddN "add", and a
          /// Neg op that consumes "add"), then checks the operation query functions,
          /// GraphDef/NodeDef serialization, lookup by name, and iteration over every
          /// operation in the graph.
          /// </summary>
          /// <remarks>
          /// The status and graph are released in a <c>finally</c> block so that a failing
          /// assertion (which throws) does not leak them.
          /// </remarks>
          [TestMethod]
          public void c_api_Graph()
          {
              Status s = null;
              Graph graph = null;
              try
              {
                  s = new Status();
                  graph = new Graph();

                  // Make a placeholder operation.
                  var feed = c_test_util.Placeholder(graph, s);
                  Assert.AreEqual("feed", feed.Name);
                  Assert.AreEqual("Placeholder", feed.OpType);
                  Assert.AreEqual("", feed.Device);
                  Assert.AreEqual(1, feed.NumOutputs);
                  Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType(0));
                  Assert.AreEqual(1, feed.OutputListLength("output"));
                  Assert.AreEqual(0, feed.NumInputs);
                  Assert.AreEqual(0, feed.OutputNumConsumers(0));
                  Assert.AreEqual(0, feed.NumControlInputs);
                  Assert.AreEqual(0, feed.NumControlOutputs);

                  AttrValue attr_value = null;
                  Assert.IsTrue(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s));
                  // Expected value first, so a failure message reports the values correctly.
                  Assert.AreEqual(DataType.DtInt32, attr_value.Type);

                  // Test not found errors in TF_Operation*() query functions.
                  Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s));
                  Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code);
                  Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s));
                  Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message);

                  // Make a constant oper with the scalar "3".
                  var three = c_test_util.ScalarConst(3, graph, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);

                  // Add oper.
                  var add = c_test_util.Add(feed, three, graph, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);

                  // Test TF_Operation*() query functions.
                  Assert.AreEqual("add", add.Name);
                  Assert.AreEqual("AddN", add.OpType);
                  Assert.AreEqual("", add.Device);
                  Assert.AreEqual(1, add.NumOutputs);
                  Assert.AreEqual(TF_DataType.TF_INT32, add.OutputType(0));
                  Assert.AreEqual(1, add.OutputListLength("sum"));
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  Assert.AreEqual(2, add.InputListLength("inputs"));
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);
                  Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(0));
                  Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(1));

                  var add_in_0 = add.Input(0);
                  Assert.AreEqual(feed, add_in_0.oper);
                  Assert.AreEqual(0, add_in_0.index);
                  var add_in_1 = add.Input(1);
                  Assert.AreEqual(three, add_in_1.oper);
                  Assert.AreEqual(0, add_in_1.index);

                  Assert.AreEqual(0, add.OutputNumConsumers(0));
                  Assert.AreEqual(0, add.NumControlInputs);
                  Assert.AreEqual(0, add.NumControlOutputs);
                  Assert.IsTrue(c_test_util.GetAttrValue(add, "T", ref attr_value, s));
                  Assert.AreEqual(DataType.DtInt32, attr_value.Type);
                  Assert.IsTrue(c_test_util.GetAttrValue(add, "N", ref attr_value, s));
                  Assert.AreEqual(2, attr_value.I);

                  // Placeholder oper now has a consumer.
                  Assert.AreEqual(1, feed.OutputNumConsumers(0));
                  TF_Input[] feed_port = feed.OutputConsumers(0, 1);
                  Assert.AreEqual(1, feed_port.Length);
                  Assert.AreEqual(add, feed_port[0].oper);
                  Assert.AreEqual(0, feed_port[0].index);

                  // The scalar const oper also has a consumer.
                  Assert.AreEqual(1, three.OutputNumConsumers(0));
                  TF_Input[] three_port = three.OutputConsumers(0, 1);
                  // Check the length before indexing, as is done for feed_port above.
                  Assert.AreEqual(1, three_port.Length);
                  Assert.AreEqual(add, three_port[0].oper);
                  Assert.AreEqual(1, three_port[0].index);

                  // Serialize to GraphDef.
                  var graph_def = c_test_util.GetGraphDef(graph);

                  // Validate GraphDef is what we expect.
                  bool found_placeholder = false;
                  bool found_scalar_const = false;
                  bool found_add = false;
                  foreach (var n in graph_def.Node)
                  {
                      if (c_test_util.IsPlaceholder(n))
                      {
                          Assert.IsFalse(found_placeholder);
                          found_placeholder = true;
                      }
                      else if (c_test_util.IsScalarConst(n, 3))
                      {
                          Assert.IsFalse(found_scalar_const);
                          found_scalar_const = true;
                      }
                      else if (c_test_util.IsAddN(n, 2))
                      {
                          Assert.IsFalse(found_add);
                          found_add = true;
                      }
                      else
                      {
                          Assert.Fail($"Unexpected NodeDef: {n}");
                      }
                  }
                  Assert.IsTrue(found_placeholder);
                  Assert.IsTrue(found_scalar_const);
                  Assert.IsTrue(found_add);

                  // Add another oper to the graph.
                  var neg = c_test_util.Neg(add, graph, s);
                  Assert.AreEqual(TF_Code.TF_OK, s.Code);

                  // Serialize to NodeDef.
                  var node_def = c_test_util.GetNodeDef(neg);

                  // Validate NodeDef is what we expect.
                  Assert.IsTrue(c_test_util.IsNeg(node_def, "add"));

                  // Serialize to GraphDef.
                  var graph_def2 = c_test_util.GetGraphDef(graph);

                  // Compare with first GraphDef + added NodeDef.
                  graph_def.Node.Add(node_def);
                  Assert.AreEqual(graph_def.ToString(), graph_def2.ToString());

                  // Look up some nodes by name.
                  Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg");
                  Assert.AreEqual(neg, neg2);
                  var node_def2 = c_test_util.GetNodeDef(neg2);
                  Assert.AreEqual(node_def.ToString(), node_def2.ToString());

                  Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed");
                  Assert.AreEqual(feed, feed2);
                  node_def = c_test_util.GetNodeDef(feed);
                  node_def2 = c_test_util.GetNodeDef(feed2);
                  Assert.AreEqual(node_def.ToString(), node_def2.ToString());

                  // Test iterating through the nodes of a graph.
                  found_placeholder = false;
                  found_scalar_const = false;
                  found_add = false;
                  bool found_neg = false;
                  uint pos = 0;
                  Operation oper;
                  // TF_GraphNextOperation advances `pos` and yields a null handle
                  // (compared against IntPtr.Zero) once every operation has been visited.
                  while ((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
                  {
                      if (oper.Equals(feed))
                      {
                          Assert.IsFalse(found_placeholder);
                          found_placeholder = true;
                      }
                      else if (oper.Equals(three))
                      {
                          Assert.IsFalse(found_scalar_const);
                          found_scalar_const = true;
                      }
                      else if (oper.Equals(add))
                      {
                          Assert.IsFalse(found_add);
                          found_add = true;
                      }
                      else if (oper.Equals(neg))
                      {
                          Assert.IsFalse(found_neg);
                          found_neg = true;
                      }
                      else
                      {
                          node_def = c_test_util.GetNodeDef(oper);
                          Assert.Fail($"Unexpected Node: {node_def.ToString()}");
                      }
                  }
                  Assert.IsTrue(found_placeholder);
                  Assert.IsTrue(found_scalar_const);
                  Assert.IsTrue(found_add);
                  Assert.IsTrue(found_neg);
              }
              finally
              {
                  // Same release order as the original test: graph first, then status.
                  // Null-conditional calls cover a constructor that threw before assignment.
                  graph?.Dispose();
                  s?.Dispose();
              }
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。

Contributors (1)