You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

GraphTest.cs 5.1 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow;
  namespace TensorFlowNET.UnitTest
  {
      [TestClass]
      public class GraphTest
      {
          /// <summary>
          /// Port from c_api_test.cc
          /// `TEST(CAPI, Graph)`
          /// </summary>
          /// <remarks>
          /// Builds the graph <c>add = AddN(feed, three)</c> where <c>feed</c> is an
          /// int32 placeholder and <c>three</c> is a scalar constant. It checks the
          /// operation query APIs (names, types, inputs, consumers, attrs) on each
          /// node, then serializes the graph to a GraphDef and checks that it holds
          /// exactly one node of each kind and nothing else.
          /// </remarks>
          [TestMethod]
          public void c_api_Graph()
          {
              var s = new Status();
              var graph = new Graph();

              // Make a placeholder operation.
              var feed = c_test_util.Placeholder(graph, s);
              Assert.AreEqual(TF_Code.TF_OK, s.Code);
              Assert.AreEqual("feed", feed.Name);
              Assert.AreEqual("Placeholder", feed.OpType);
              Assert.AreEqual("", feed.Device);
              Assert.AreEqual(1, feed.NumOutputs);
              Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType(0));
              Assert.AreEqual(1, feed.OutputListLength("output"));
              Assert.AreEqual(0, feed.NumInputs);
              Assert.AreEqual(0, feed.OutputNumConsumers(0));
              Assert.AreEqual(0, feed.NumControlInputs);
              Assert.AreEqual(0, feed.NumControlOutputs);

              AttrValue attr_value = null;
              Assert.IsTrue(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s), s.Message);
              Assert.AreEqual(TF_Code.TF_OK, s.Code);
              // MSTest convention: expected value first, actual second.
              Assert.AreEqual(DataType.DtInt32, attr_value.Type);

              // Test not found errors in TF_Operation*() query functions.
              Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s));
              Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code);
              Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s));
              Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message);

              // Make a constant oper with the scalar "3".
              var three = c_test_util.ScalarConst(3, graph, s);
              Assert.AreEqual(TF_Code.TF_OK, s.Code);

              // Add oper.
              var add = c_test_util.Add(feed, three, graph, s);
              Assert.AreEqual(TF_Code.TF_OK, s.Code);

              // Test TF_Operation*() query functions.
              Assert.AreEqual("add", add.Name);
              Assert.AreEqual("AddN", add.OpType);
              Assert.AreEqual("", add.Device);
              Assert.AreEqual(1, add.NumOutputs);
              Assert.AreEqual(TF_DataType.TF_INT32, add.OutputType(0));
              Assert.AreEqual(1, add.OutputListLength("sum"));
              Assert.AreEqual(TF_Code.TF_OK, s.Code);
              Assert.AreEqual(2, add.InputListLength("inputs"));
              Assert.AreEqual(TF_Code.TF_OK, s.Code);
              Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(0));
              Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(1));

              var add_in_0 = add.Input(0);
              Assert.AreEqual(feed, add_in_0.oper);
              Assert.AreEqual(0, add_in_0.index);
              var add_in_1 = add.Input(1);
              Assert.AreEqual(three, add_in_1.oper);
              Assert.AreEqual(0, add_in_1.index);

              Assert.AreEqual(0, add.OutputNumConsumers(0));
              Assert.AreEqual(0, add.NumControlInputs);
              Assert.AreEqual(0, add.NumControlOutputs);
              Assert.IsTrue(c_test_util.GetAttrValue(add, "T", ref attr_value, s), s.Message);
              Assert.AreEqual(DataType.DtInt32, attr_value.Type);
              Assert.IsTrue(c_test_util.GetAttrValue(add, "N", ref attr_value, s), s.Message);
              Assert.AreEqual(2, attr_value.I);

              // Placeholder oper now has a consumer.
              Assert.AreEqual(1, feed.OutputNumConsumers(0));
              TF_Input[] feed_port = feed.OutputConsumers(0, 1);
              Assert.AreEqual(1, feed_port.Length);
              Assert.AreEqual(add, feed_port[0].oper);
              Assert.AreEqual(0, feed_port[0].index);

              // The scalar const oper also has a consumer.
              Assert.AreEqual(1, three.OutputNumConsumers(0));
              TF_Input[] three_port = three.OutputConsumers(0, 1);
              // Check the length before indexing, mirroring the feed_port check above.
              Assert.AreEqual(1, three_port.Length);
              Assert.AreEqual(add, three_port[0].oper);
              Assert.AreEqual(1, three_port[0].index);

              // Serialize to GraphDef.
              var graph_def = c_test_util.GetGraphDef(graph);

              // Validate GraphDef is what we expect: exactly one placeholder, one
              // constant and one AddN, with no other nodes.
              bool found_placeholder = false;
              bool found_scalar_const = false;
              bool found_add = false;
              foreach (var n in graph_def.Node)
              {
                  if (c_test_util.IsPlaceholder(n))
                  {
                      Assert.IsFalse(found_placeholder);
                      found_placeholder = true;
                  }
                  else if (n.Op == "Const")
                  {
                      // NOTE(review): the C++ IsScalarConst(n, 3) also checks the
                      // tensor value is 3. Only the op type is checked here.
                      Assert.IsFalse(found_scalar_const);
                      found_scalar_const = true;
                  }
                  else if (n.Op == "AddN" && n.Attr.TryGetValue("N", out var n_attr) && n_attr.I == 2)
                  {
                      Assert.IsFalse(found_add);
                      found_add = true;
                  }
                  else
                  {
                      Assert.Fail($"Unexpected NodeDef: {n}");
                  }
              }

              // Without these, a missing node would pass silently.
              Assert.IsTrue(found_placeholder);
              Assert.IsTrue(found_scalar_const);
              Assert.IsTrue(found_add);
          }
      }
  }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。

Contributors (1)