using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.UnitTest
{
    /// <summary>
    /// Tests for building and inspecting a <c>Graph</c> through the C API bindings.
    /// </summary>
    [TestClass]
    public class GraphTest
    {
        /// <summary>
        /// Port from c_api_test.cc
        /// `TEST(CAPI, Graph)`
        ///
        /// Builds a graph of four operations (placeholder "feed", scalar const 3,
        /// "add" of type AddN, and "neg"), then checks the operation query
        /// functions, GraphDef/NodeDef serialization, lookup by name and
        /// iteration over the graph's operations.
        /// </summary>
        [TestMethod]
        public void c_api_Graph()
        {
            var s = new Status();
            var graph = new Graph();

            // A failing Assert throws, so the cleanup lives in `finally` to make
            // sure the graph and status are released even when the test fails.
            try
            {
                // Make a placeholder operation.
                var feed = c_test_util.Placeholder(graph, s);
                Assert.AreEqual("feed", feed.Name);
                Assert.AreEqual("Placeholder", feed.OpType);
                Assert.AreEqual("", feed.Device);
                Assert.AreEqual(1, feed.NumOutputs);
                Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType(0));
                Assert.AreEqual(1, feed.OutputListLength("output"));
                Assert.AreEqual(0, feed.NumInputs);
                Assert.AreEqual(0, feed.OutputNumConsumers(0));
                Assert.AreEqual(0, feed.NumControlInputs);
                Assert.AreEqual(0, feed.NumControlOutputs);

                AttrValue attr_value = null;
                Assert.IsTrue(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s));
                // Expected value goes first so a failure message reads correctly.
                Assert.AreEqual(DataType.DtInt32, attr_value.Type);

                // Test not found errors in TF_Operation*() query functions.
                Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s));
                Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code);
                Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s));
                Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message);

                // Make a constant oper with the scalar "3".
                var three = c_test_util.ScalarConst(3, graph, s);
                Assert.AreEqual(TF_Code.TF_OK, s.Code);

                // Add oper.
                var add = c_test_util.Add(feed, three, graph, s);
                Assert.AreEqual(TF_Code.TF_OK, s.Code);

                // Test TF_Operation*() query functions.
                Assert.AreEqual("add", add.Name);
                Assert.AreEqual("AddN", add.OpType);
                Assert.AreEqual("", add.Device);
                Assert.AreEqual(1, add.NumOutputs);
                Assert.AreEqual(TF_DataType.TF_INT32, add.OutputType(0));
                Assert.AreEqual(1, add.OutputListLength("sum"));
                Assert.AreEqual(TF_Code.TF_OK, s.Code);
                Assert.AreEqual(2, add.InputListLength("inputs"));
                Assert.AreEqual(TF_Code.TF_OK, s.Code);
                Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(0));
                Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(1));

                var add_in_0 = add.Input(0);
                Assert.AreEqual(feed, add_in_0.oper);
                Assert.AreEqual(0, add_in_0.index);

                var add_in_1 = add.Input(1);
                Assert.AreEqual(three, add_in_1.oper);
                Assert.AreEqual(0, add_in_1.index);

                Assert.AreEqual(0, add.OutputNumConsumers(0));
                Assert.AreEqual(0, add.NumControlInputs);
                Assert.AreEqual(0, add.NumControlOutputs);

                Assert.IsTrue(c_test_util.GetAttrValue(add, "T", ref attr_value, s));
                Assert.AreEqual(DataType.DtInt32, attr_value.Type);
                Assert.IsTrue(c_test_util.GetAttrValue(add, "N", ref attr_value, s));
                Assert.AreEqual(2, attr_value.I);

                // Placeholder oper now has a consumer.
                Assert.AreEqual(1, feed.OutputNumConsumers(0));
                TF_Input[] feed_port = feed.OutputConsumers(0, 1);
                Assert.AreEqual(1, feed_port.Length);
                Assert.AreEqual(add, feed_port[0].oper);
                Assert.AreEqual(0, feed_port[0].index);

                // The scalar const oper also has a consumer.
                Assert.AreEqual(1, three.OutputNumConsumers(0));
                TF_Input[] three_port = three.OutputConsumers(0, 1);
                // Same length check as for feed_port, so indexing [0] below
                // cannot mask a wrong consumer count.
                Assert.AreEqual(1, three_port.Length);
                Assert.AreEqual(add, three_port[0].oper);
                Assert.AreEqual(1, three_port[0].index);

                // Serialize to GraphDef.
                var graph_def = c_test_util.GetGraphDef(graph);

                // Validate GraphDef is what we expect: each of the three nodes
                // must appear exactly once and nothing else may be present.
                bool found_placeholder = false;
                bool found_scalar_const = false;
                bool found_add = false;
                foreach (var n in graph_def.Node)
                {
                    if (c_test_util.IsPlaceholder(n))
                    {
                        Assert.IsFalse(found_placeholder);
                        found_placeholder = true;
                    }
                    else if (c_test_util.IsScalarConst(n, 3))
                    {
                        Assert.IsFalse(found_scalar_const);
                        found_scalar_const = true;
                    }
                    else if (c_test_util.IsAddN(n, 2))
                    {
                        Assert.IsFalse(found_add);
                        found_add = true;
                    }
                    else
                    {
                        Assert.Fail($"Unexpected NodeDef: {n}");
                    }
                }
                Assert.IsTrue(found_placeholder);
                Assert.IsTrue(found_scalar_const);
                Assert.IsTrue(found_add);

                // Add another oper to the graph.
                var neg = c_test_util.Neg(add, graph, s);
                Assert.AreEqual(TF_Code.TF_OK, s.Code);

                // Serialize to NodeDef.
                var node_def = c_test_util.GetNodeDef(neg);

                // Validate NodeDef is what we expect.
                Assert.IsTrue(c_test_util.IsNeg(node_def, "add"));

                // Serialize to GraphDef.
                var graph_def2 = c_test_util.GetGraphDef(graph);

                // Compare with first GraphDef + added NodeDef.
                graph_def.Node.Add(node_def);
                Assert.AreEqual(graph_def.ToString(), graph_def2.ToString());

                // Look up some nodes by name.
                Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg");
                Assert.AreEqual(neg, neg2);
                var node_def2 = c_test_util.GetNodeDef(neg2);
                Assert.AreEqual(node_def.ToString(), node_def2.ToString());

                Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed");
                Assert.AreEqual(feed, feed2);
                node_def = c_test_util.GetNodeDef(feed);
                node_def2 = c_test_util.GetNodeDef(feed2);
                Assert.AreEqual(node_def.ToString(), node_def2.ToString());

                // Test iterating through the nodes of a graph: all four
                // operations must be visited exactly once.
                found_placeholder = false;
                found_scalar_const = false;
                found_add = false;
                bool found_neg = false;
                uint pos = 0;
                Operation oper;
                while ((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
                {
                    if (oper.Equals(feed))
                    {
                        Assert.IsFalse(found_placeholder);
                        found_placeholder = true;
                    }
                    else if (oper.Equals(three))
                    {
                        Assert.IsFalse(found_scalar_const);
                        found_scalar_const = true;
                    }
                    else if (oper.Equals(add))
                    {
                        Assert.IsFalse(found_add);
                        found_add = true;
                    }
                    else if (oper.Equals(neg))
                    {
                        Assert.IsFalse(found_neg);
                        found_neg = true;
                    }
                    else
                    {
                        node_def = c_test_util.GetNodeDef(oper);
                        Assert.Fail($"Unexpected Node: {node_def.ToString()}");
                    }
                }
                Assert.IsTrue(found_placeholder);
                Assert.IsTrue(found_scalar_const);
                Assert.IsTrue(found_add);
                Assert.IsTrue(found_neg);
            }
            finally
            {
                // Same disposal order as before: graph first, then status.
                graph.Dispose();
                s.Dispose();
            }
        }
    }
}