diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 69408446..7ea8085d 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -171,6 +171,34 @@ namespace Tensorflow return $"{name}_{_names_in_use[name_key]}"; } + public TF_Output[] ReturnOutputs(IntPtr results) + { + IntPtr return_output_handle = IntPtr.Zero; + int num_return_outputs = 0; + c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle); + TF_Output[] return_outputs = new TF_Output[num_return_outputs]; + for (int i = 0; i < num_return_outputs; i++) + { + return_outputs[i] = Marshal.PtrToStructure(return_output_handle + (Marshal.SizeOf() * i)); + } + + return return_outputs; + } + + public Operation[] ReturnOperations(IntPtr results) + { + IntPtr return_oper_handle = IntPtr.Zero; + int num_return_opers = 0; + c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_opers, ref return_oper_handle); + Operation[] return_opers = new Operation[num_return_opers]; + for (int i = 0; i < num_return_opers; i++) + { + // return_opers[i] = Marshal.PtrToStructure(return_oper_handle + (Marshal.SizeOf() * i)); + } + + return return_opers; + } + public Operation[] get_operations() { return _nodes_by_name.Values.Select(x => x).ToArray(); diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 9a836d10..4c68f68b 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -15,6 +15,13 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_DeleteGraph(IntPtr graph); + /// + /// + /// + /// TF_ImportGraphDefOptions* + [DllImport(TensorFlowLibName)] + public static extern void TF_DeleteImportGraphDefOptions(IntPtr opts); + [DllImport(TensorFlowLibName)] public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); @@ -31,6 +38,29 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); + /// + /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and + /// a bad status on error. Otherwise, returns a populated + /// TF_ImportGraphDefResults instance. The returned instance must be deleted via + /// TF_DeleteImportGraphDefResults(). + /// + /// TF_Graph* + /// const TF_Buffer* + /// const TF_ImportGraphDefOptions* + /// TF_Status* + /// TF_ImportGraphDefResults* + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status); + + /// + /// Import the graph serialized in `graph_def` into `graph`. + /// + /// TF_Graph* + /// TF_Buffer* + /// TF_ImportGraphDefOptions* + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status); /// /// Iterate through the operations of a graph. /// @@ -80,7 +110,96 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status); + /// + /// Set any imported nodes with input `src_name:src_index` to have that input + /// replaced with `dst`. `src_name` refers to a node in the graph to be imported, + /// `dst` references a node already existing in the graph being imported into. + /// `src_name` is copied and has no lifetime requirements. + /// + /// TF_ImportGraphDefOptions* + /// const char* + /// int + /// TF_Output + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsAddInputMapping(IntPtr opts, string src_name, int src_index, TF_Output dst); + + /// + /// Add an operation in `graph_def` to be returned via the `return_opers` output + /// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no + // lifetime requirements. + /// + /// TF_ImportGraphDefOptions* opts + /// const char* + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsAddReturnOperation(IntPtr opts, string oper_name); + + /// + /// Add an output in `graph_def` to be returned via the `return_outputs` output + /// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input + /// mapping, the corresponding existing tensor in `graph` will be returned. + /// `oper_name` is copied and has no lifetime requirements. + /// + /// TF_ImportGraphDefOptions* + /// const char* + /// int + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsAddReturnOutput(IntPtr opts, string oper_name, int index); + + /// + /// Returns the number of return operations added via + /// TF_ImportGraphDefOptionsAddReturnOperation(). + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_ImportGraphDefOptionsNumReturnOperations(IntPtr opts); + + /// + /// Returns the number of return outputs added via + /// TF_ImportGraphDefOptionsAddReturnOutput(). + /// + /// const TF_ImportGraphDefOptions* + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(IntPtr opts); + + /// + /// Set the prefix to be prepended to the names of nodes in `graph_def` that will + /// be imported into `graph`. `prefix` is copied and has no lifetime + /// requirements. + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefOptionsSetPrefix(IntPtr ops, string prefix); + + /// + /// Fetches the return operations requested via + /// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched + /// operations is returned in `num_opers`. The array of return operations is + /// returned in `opers`. `*opers` is owned by and has the lifetime of `results`. + /// + /// TF_ImportGraphDefResults* + /// int* + /// TF_Operation*** + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref IntPtr opers); + + /// + /// Fetches the return outputs requested via + /// TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs is + /// returned in `num_outputs`. The array of return outputs is returned in + /// `outputs`. `*outputs` is owned by and has the lifetime of `results`. + /// + /// TF_ImportGraphDefResults* results + /// int* + /// TF_Output** + [DllImport(TensorFlowLibName)] + public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_NewGraph(); + [DllImport(TensorFlowLibName)] - public static unsafe extern IntPtr TF_NewGraph(); + public static extern IntPtr TF_NewImportGraphDefOptions(); } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 1bd440b3..a41974f2 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -41,8 +41,25 @@ namespace Tensorflow } public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); + + public Operation[] ControlInputs(int max_control_inputs) + { + var control_inputs = new Operation[NumControlInputs]; + var control_input_handles = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlInputs); + c_api.TF_OperationGetControlInputs(_handle, control_input_handles, max_control_inputs); + return control_inputs; + } + public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); + public Operation[] ControlOutputs(int max_control_outputs) + { + var control_outputs = new Operation[NumControlOutputs]; + var control_output_handles = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlOutputs); + c_api.TF_OperationGetControlInputs(_handle, control_output_handles, max_control_outputs); + return control_outputs; + } + private Tensor[] _outputs; public Tensor[] outputs => _outputs; public Tensor[] inputs; diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 0a090cbc..0fb55730 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -49,6 +49,35 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status); + /// + /// Get list of all control inputs to an operation. `control_inputs` must + /// point to an array of length `max_control_inputs` (ideally set to + /// TF_OperationNumControlInputs(oper)). Returns the number of control + /// inputs (should match TF_OperationNumControlInputs(oper)). + /// + /// TF_Operation* + /// TF_Operation** + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationGetControlInputs(IntPtr oper, IntPtr control_inputs, int max_control_inputs); + + /// + /// Get the list of operations that have `*oper` as a control input. + /// `control_outputs` must point to an array of length at least + /// `max_control_outputs` (ideally set to + /// TF_OperationNumControlOutputs(oper)). Beware that a concurrent + /// modification of the graph can increase the number of control + /// outputs. Returns the number of control outputs (should match + /// TF_OperationNumControlOutputs(oper)). + /// + /// TF_Operation* + /// TF_Operation** + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationGetControlOutputs(IntPtr oper, IntPtr control_outputs, int max_control_outputs); + /// /// TF_Output producer = TF_OperationInput(consumer); /// There is an edge from producer.oper's output (given by @@ -105,14 +134,19 @@ namespace Tensorflow /// /// Get list of all current consumers of a specific output of an - /// operation. + /// operation. `consumers` must point to an array of length at least + /// `max_consumers` (ideally set to + /// TF_OperationOutputNumConsumers(oper_out)). Beware that a concurrent + /// modification of the graph can increase the number of consumers of + /// an operation. Returns the number of output consumers (should match + /// TF_OperationOutputNumConsumers(oper_out)). /// /// /// /// /// [DllImport(TensorFlowLibName)] - public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input * consumers, int max_consumers); + public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, int max_consumers); [DllImport(TensorFlowLibName)] public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); diff --git a/test/TensorFlowNET.UnitTest/CApiTest.cs b/test/TensorFlowNET.UnitTest/CApiTest.cs new file mode 100644 index 00000000..fea2c6e4 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/CApiTest.cs @@ -0,0 +1,25 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Text; + +namespace TensorFlowNET.UnitTest +{ + public class CApiTest + { + public void EXPECT_EQ(object expected, object actual) + { + Assert.AreEqual(expected, actual); + } + + public void ASSERT_EQ(object expected, object actual) + { + Assert.AreEqual(expected, actual); + } + + public void ASSERT_TRUE(bool condition) + { + Assert.IsTrue(condition); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index e97644df..515b503c 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -1,13 +1,15 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using System.Collections.Generic; +using System.Runtime.InteropServices; using System.Text; using Tensorflow; +using Buffer = Tensorflow.Buffer; namespace TensorFlowNET.UnitTest { [TestClass] - public class GraphTest + public class GraphTest : CApiTest { /// /// Port from c_api_test.cc @@ -21,74 +23,74 @@ namespace TensorFlowNET.UnitTest // Make a placeholder operation. var feed = c_test_util.Placeholder(graph, s); - Assert.AreEqual("feed", feed.Name); - Assert.AreEqual("Placeholder", feed.OpType); - Assert.AreEqual("", feed.Device); - Assert.AreEqual(1, feed.NumOutputs); - Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType(0)); - Assert.AreEqual(1, feed.OutputListLength("output")); - Assert.AreEqual(0, feed.NumInputs); - Assert.AreEqual(0, feed.OutputNumConsumers(0)); - Assert.AreEqual(0, feed.NumControlInputs); - Assert.AreEqual(0, feed.NumControlOutputs); + EXPECT_EQ("feed", feed.Name); + EXPECT_EQ("Placeholder", feed.OpType); + EXPECT_EQ("", feed.Device); + EXPECT_EQ(1, feed.NumOutputs); + EXPECT_EQ(TF_DataType.TF_INT32, feed.OutputType(0)); + EXPECT_EQ(1, feed.OutputListLength("output")); + EXPECT_EQ(0, feed.NumInputs); + EXPECT_EQ(0, feed.OutputNumConsumers(0)); + EXPECT_EQ(0, feed.NumControlInputs); + EXPECT_EQ(0, feed.NumControlOutputs); AttrValue attr_value = null; - Assert.IsTrue(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s)); - Assert.AreEqual(attr_value.Type, DataType.DtInt32); + ASSERT_TRUE(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s)); + EXPECT_EQ(attr_value.Type, DataType.DtInt32); // Test not found errors in TF_Operation*() query functions. - Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); - Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); + EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); + EXPECT_EQ(TF_Code.TF_INVALID_ARGUMENT, s.Code); Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); - Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message); + EXPECT_EQ("Operation 'feed' has no attr named 'missing'.", s.Message); // Make a constant oper with the scalar "3". var three = c_test_util.ScalarConst(3, graph, s); - Assert.AreEqual(TF_Code.TF_OK, s.Code); + EXPECT_EQ(TF_Code.TF_OK, s.Code); // Add oper. var add = c_test_util.Add(feed, three, graph, s); - Assert.AreEqual(TF_Code.TF_OK, s.Code); + EXPECT_EQ(TF_Code.TF_OK, s.Code); // Test TF_Operation*() query functions. - Assert.AreEqual("add", add.Name); - Assert.AreEqual("AddN", add.OpType); - Assert.AreEqual("", add.Device); - Assert.AreEqual(1, add.NumOutputs); - Assert.AreEqual(TF_DataType.TF_INT32, add.OutputType(0)); - Assert.AreEqual(1, add.OutputListLength("sum")); - Assert.AreEqual(TF_Code.TF_OK, s.Code); - Assert.AreEqual(2, add.InputListLength("inputs")); - Assert.AreEqual(TF_Code.TF_OK, s.Code); - Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(0)); - Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(1)); + EXPECT_EQ("add", add.Name); + EXPECT_EQ("AddN", add.OpType); + EXPECT_EQ("", add.Device); + EXPECT_EQ(1, add.NumOutputs); + EXPECT_EQ(TF_DataType.TF_INT32, add.OutputType(0)); + EXPECT_EQ(1, add.OutputListLength("sum")); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + EXPECT_EQ(2, add.InputListLength("inputs")); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(0)); + EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(1)); var add_in_0 = add.Input(0); - Assert.AreEqual(feed, add_in_0.oper); - Assert.AreEqual(0, add_in_0.index); + EXPECT_EQ(feed, add_in_0.oper); + EXPECT_EQ(0, add_in_0.index); var add_in_1 = add.Input(1); - Assert.AreEqual(three, add_in_1.oper); - Assert.AreEqual(0, add_in_1.index); - Assert.AreEqual(0, add.OutputNumConsumers(0)); - Assert.AreEqual(0, add.NumControlInputs); - Assert.AreEqual(0, add.NumControlOutputs); + EXPECT_EQ(three, add_in_1.oper); + EXPECT_EQ(0, add_in_1.index); + EXPECT_EQ(0, add.OutputNumConsumers(0)); + EXPECT_EQ(0, add.NumControlInputs); + EXPECT_EQ(0, add.NumControlOutputs); - Assert.IsTrue(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); - Assert.AreEqual(DataType.DtInt32, attr_value.Type); - Assert.IsTrue(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); - Assert.AreEqual(2, attr_value.I); + ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); + EXPECT_EQ(DataType.DtInt32, attr_value.Type); + ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); + EXPECT_EQ(2, attr_value.I); // Placeholder oper now has a consumer. - Assert.AreEqual(1, feed.OutputNumConsumers(0)); + EXPECT_EQ(1, feed.OutputNumConsumers(0)); TF_Input[] feed_port = feed.OutputConsumers(0, 1); - Assert.AreEqual(1, feed_port.Length); - Assert.AreEqual(add, feed_port[0].oper); - Assert.AreEqual(0, feed_port[0].index); + EXPECT_EQ(1, feed_port.Length); + EXPECT_EQ(add, feed_port[0].oper); + EXPECT_EQ(0, feed_port[0].index); // The scalar const oper also has a consumer. - Assert.AreEqual(1, three.OutputNumConsumers(0)); + EXPECT_EQ(1, three.OutputNumConsumers(0)); TF_Input[] three_port = three.OutputConsumers(0, 1); - Assert.AreEqual(add, three_port[0].oper); - Assert.AreEqual(1, three_port[0].index); + EXPECT_EQ(add, three_port[0].oper); + EXPECT_EQ(1, three_port[0].index); // Serialize to GraphDef. var graph_def = c_test_util.GetGraphDef(graph); @@ -119,38 +121,38 @@ namespace TensorFlowNET.UnitTest Assert.Fail($"Unexpected NodeDef: {n}"); } } - Assert.IsTrue(found_placeholder); - Assert.IsTrue(found_scalar_const); - Assert.IsTrue(found_add); + ASSERT_TRUE(found_placeholder); + ASSERT_TRUE(found_scalar_const); + ASSERT_TRUE(found_add); // Add another oper to the graph. var neg = c_test_util.Neg(add, graph, s); - Assert.AreEqual(TF_Code.TF_OK, s.Code); + EXPECT_EQ(TF_Code.TF_OK, s.Code); // Serialize to NodeDef. var node_def = c_test_util.GetNodeDef(neg); // Validate NodeDef is what we expect. - Assert.IsTrue(c_test_util.IsNeg(node_def, "add")); + ASSERT_TRUE(c_test_util.IsNeg(node_def, "add")); // Serialize to GraphDef. var graph_def2 = c_test_util.GetGraphDef(graph); // Compare with first GraphDef + added NodeDef. graph_def.Node.Add(node_def); - Assert.AreEqual(graph_def.ToString(), graph_def2.ToString()); + EXPECT_EQ(graph_def.ToString(), graph_def2.ToString()); // Look up some nodes by name. Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); - Assert.AreEqual(neg, neg2); + EXPECT_EQ(neg, neg2); var node_def2 = c_test_util.GetNodeDef(neg2); - Assert.AreEqual(node_def.ToString(), node_def2.ToString()); + EXPECT_EQ(node_def.ToString(), node_def2.ToString()); Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); - Assert.AreEqual(feed, feed2); + EXPECT_EQ(feed, feed2); node_def = c_test_util.GetNodeDef(feed); node_def2 = c_test_util.GetNodeDef(feed2); - Assert.AreEqual(node_def.ToString(), node_def2.ToString()); + EXPECT_EQ(node_def.ToString(), node_def2.ToString()); // Test iterating through the nodes of a graph. found_placeholder = false; @@ -189,13 +191,106 @@ namespace TensorFlowNET.UnitTest } } - Assert.IsTrue(found_placeholder); - Assert.IsTrue(found_scalar_const); - Assert.IsTrue(found_add); - Assert.IsTrue(found_neg); + ASSERT_TRUE(found_placeholder); + ASSERT_TRUE(found_scalar_const); + ASSERT_TRUE(found_add); + ASSERT_TRUE(found_neg); graph.Dispose(); s.Dispose(); } + + /// + /// Port from c_api_test.cc + /// `TEST(CAPI, ImportGraphDef)` + /// + [TestMethod] + public void c_api_ImportGraphDef() + { + var s = new Status(); + var graph = new Graph(); + + // Create a simple graph. + c_test_util.Placeholder(graph, s); + var oper = c_test_util.ScalarConst(3, graph, s); + c_test_util.Neg(oper, graph, s); + + // Export to a GraphDef. + var graph_def = new Buffer(); + c_api.TF_GraphToGraphDef(graph, graph_def, s); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + + // Import it, with a prefix, in a fresh graph. + graph.Dispose(); + graph = new Graph(); + var opts = c_api.TF_NewImportGraphDefOptions(); + c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); + c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + + Operation scalar = c_api.TF_GraphOperationByName(graph, "imported/scalar"); + Operation feed = c_api.TF_GraphOperationByName(graph, "imported/feed"); + Operation neg = c_api.TF_GraphOperationByName(graph, "imported/neg"); + + // Test basic structure of the imported graph. + EXPECT_EQ(0, scalar.NumInputs); + EXPECT_EQ(0, feed.NumInputs); + EXPECT_EQ(1, neg.NumInputs); + + var neg_input = neg.Input(0); + EXPECT_EQ(scalar, neg_input.oper); + EXPECT_EQ(0, neg_input.index); + + // Test that we can't see control edges involving the source and sink nodes. + EXPECT_EQ(0, scalar.NumControlInputs); + EXPECT_EQ(0, scalar.ControlInputs(100).Length); + EXPECT_EQ(0, scalar.NumControlOutputs); + EXPECT_EQ(0, scalar.ControlOutputs(100).Length); + + EXPECT_EQ(0, feed.NumControlInputs); + EXPECT_EQ(0, feed.ControlInputs(100).Length); + EXPECT_EQ(0, feed.NumControlOutputs); + EXPECT_EQ(0, feed.ControlOutputs(100).Length); + + EXPECT_EQ(0, neg.NumControlInputs); + EXPECT_EQ(0, neg.ControlInputs(100).Length); + EXPECT_EQ(0, neg.NumControlOutputs); + EXPECT_EQ(0, neg.ControlOutputs(100).Length); + + // Import it again, with an input mapping, return outputs, and a return + // operation, into the same graph. + c_api.TF_DeleteImportGraphDefOptions(opts); + opts = c_api.TF_NewImportGraphDefOptions(); + c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); + c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0)); + c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); + c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); + EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); + c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); + EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); + var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); + EXPECT_EQ(TF_Code.TF_OK, s.Code); + + Operation scalar2 = c_api.TF_GraphOperationByName(graph, "imported2/scalar"); + Operation feed2 = c_api.TF_GraphOperationByName(graph, "imported2/feed"); + Operation neg2 = c_api.TF_GraphOperationByName(graph, "imported2/neg"); + + // Check input mapping + neg_input = neg.Input(0); + EXPECT_EQ(scalar, neg_input.oper); + EXPECT_EQ(0, neg_input.index); + + // Check return outputs + var return_outputs = graph.ReturnOutputs(results); + ASSERT_EQ(2, return_outputs.Length); + EXPECT_EQ(feed2, return_outputs[0].oper); + EXPECT_EQ(0, return_outputs[0].index); + EXPECT_EQ(scalar, return_outputs[1].oper); // remapped + EXPECT_EQ(0, return_outputs[1].index); + + // Check return operation + var num_return_opers = graph.ReturnOperations(results); + ASSERT_EQ(1, num_return_opers); + } } } diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index 201e1411..65bab728 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -10,7 +10,7 @@ using Tensorflow; namespace TensorFlowNET.UnitTest { [TestClass] - public class TensorTest + public class TensorTest : CApiTest { /// /// Port from c_api_test.cc @@ -22,10 +22,10 @@ namespace TensorFlowNET.UnitTest ulong num_bytes = 6 * sizeof(float); long[] dims = { 2, 3 }; Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); - Assert.AreEqual(TF_DataType.TF_FLOAT, t.dtype); - Assert.AreEqual(2, t.NDims); + EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); + EXPECT_EQ(2, t.NDims); Assert.IsTrue(Enumerable.SequenceEqual(dims, t.shape)); - Assert.AreEqual(num_bytes, t.bytesize); + EXPECT_EQ(num_bytes, t.bytesize); t.Dispose(); } @@ -41,11 +41,11 @@ namespace TensorFlowNET.UnitTest var tensor = new Tensor(nd); var array = tensor.Data(); - Assert.AreEqual(tensor.dtype, TF_DataType.TF_FLOAT); - Assert.AreEqual(tensor.rank, nd.ndim); - Assert.AreEqual(tensor.shape[0], nd.shape[0]); - Assert.AreEqual(tensor.shape[1], nd.shape[1]); - Assert.AreEqual(tensor.bytesize, (uint)nd.size * sizeof(float)); + EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); + EXPECT_EQ(tensor.rank, nd.ndim); + EXPECT_EQ(tensor.shape[0], nd.shape[0]); + EXPECT_EQ(tensor.shape[1], nd.shape[1]); + EXPECT_EQ(tensor.bytesize, (uint)nd.size * sizeof(float)); Assert.IsTrue(Enumerable.SequenceEqual(nd.Data(), array)); } @@ -66,20 +66,20 @@ namespace TensorFlowNET.UnitTest int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); - Assert.AreEqual(-1, num_dims); + EXPECT_EQ(-1, num_dims); // Set the shape to be unknown, expect no change. c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); - Assert.AreEqual(-1, num_dims); + EXPECT_EQ(-1, num_dims); // Set the shape to be 2 x Unknown long[] dims = { 2, -1 }; c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); - Assert.AreEqual(2, num_dims); + EXPECT_EQ(2, num_dims); // Get the dimension vector appropriately. var returned_dims = new long[dims.Length]; @@ -103,9 +103,9 @@ namespace TensorFlowNET.UnitTest Assert.IsTrue(s.Code == TF_Code.TF_OK); c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); - Assert.AreEqual(2, num_dims); - Assert.AreEqual(2, returned_dims[0]); - Assert.AreEqual(3, returned_dims[1]); + EXPECT_EQ(2, num_dims); + EXPECT_EQ(2, returned_dims[0]); + EXPECT_EQ(3, returned_dims[1]); // Try to set 'unknown' with same rank on the shape and see that // it doesn't change. @@ -115,9 +115,9 @@ namespace TensorFlowNET.UnitTest Assert.IsTrue(s.Code == TF_Code.TF_OK); c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); - Assert.AreEqual(2, num_dims); - Assert.AreEqual(2, returned_dims[0]); - Assert.AreEqual(3, returned_dims[1]); + EXPECT_EQ(2, num_dims); + EXPECT_EQ(2, returned_dims[0]); + EXPECT_EQ(3, returned_dims[1]); // Try to fetch a shape with the wrong num_dims c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); @@ -135,7 +135,7 @@ namespace TensorFlowNET.UnitTest num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); - Assert.AreEqual(0, num_dims); + EXPECT_EQ(0, num_dims); c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s); //Assert.IsTrue(s.Code == TF_Code.TF_OK);