| @@ -171,6 +171,34 @@ namespace Tensorflow | |||
| return $"{name}_{_names_in_use[name_key]}"; | |||
| } | |||
| public TF_Output[] ReturnOutputs(IntPtr results) | |||
| { | |||
| IntPtr return_output_handle = IntPtr.Zero; | |||
| int num_return_outputs = 0; | |||
| c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle); | |||
| TF_Output[] return_outputs = new TF_Output[num_return_outputs]; | |||
| for (int i = 0; i < num_return_outputs; i++) | |||
| { | |||
| return_outputs[i] = Marshal.PtrToStructure<TF_Output>(return_output_handle + (Marshal.SizeOf<TF_Output>() * i)); | |||
| } | |||
| return return_outputs; | |||
| } | |||
| public Operation[] ReturnOperations(IntPtr results) | |||
| { | |||
| IntPtr return_oper_handle = IntPtr.Zero; | |||
| int num_return_opers = 0; | |||
| c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_opers, ref return_oper_handle); | |||
| Operation[] return_opers = new Operation[num_return_opers]; | |||
| for (int i = 0; i < num_return_opers; i++) | |||
| { | |||
| // return_opers[i] = Marshal.PtrToStructure<TF_Output>(return_oper_handle + (Marshal.SizeOf<TF_Output>() * i)); | |||
| } | |||
| return return_opers; | |||
| } | |||
| public Operation[] get_operations() | |||
| { | |||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
| @@ -15,6 +15,13 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_DeleteGraph(IntPtr graph); | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="opts">TF_ImportGraphDefOptions*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_DeleteImportGraphDefOptions(IntPtr opts); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | |||
| @@ -31,6 +38,29 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | |||
| /// <summary> | |||
| /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | |||
| /// a bad status on error. Otherwise, returns a populated | |||
| /// TF_ImportGraphDefResults instance. The returned instance must be deleted via | |||
| /// TF_DeleteImportGraphDefResults(). | |||
| /// </summary> | |||
| /// <param name="graph">TF_Graph*</param> | |||
| /// <param name="graph_def">const TF_Buffer*</param> | |||
| /// <param name="options">const TF_ImportGraphDefOptions*</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| /// <returns>TF_ImportGraphDefResults*</returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status); | |||
| /// <summary> | |||
| /// Import the graph serialized in `graph_def` into `graph`. | |||
| /// </summary> | |||
| /// <param name="graph">TF_Graph*</param> | |||
| /// <param name="graph_def">TF_Buffer*</param> | |||
| /// <param name="options">TF_ImportGraphDefOptions*</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status); | |||
| /// <summary> | |||
| /// Iterate through the operations of a graph. | |||
| /// </summary> | |||
| @@ -80,7 +110,96 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status); | |||
| /// <summary> | |||
| /// Set any imported nodes with input `src_name:src_index` to have that input | |||
| /// replaced with `dst`. `src_name` refers to a node in the graph to be imported, | |||
| /// `dst` references a node already existing in the graph being imported into. | |||
| /// `src_name` is copied and has no lifetime requirements. | |||
| /// </summary> | |||
| /// <param name="opts">TF_ImportGraphDefOptions*</param> | |||
| /// <param name="src_name">const char*</param> | |||
| /// <param name="src_index">int</param> | |||
| /// <param name="dst">TF_Output</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsAddInputMapping(IntPtr opts, string src_name, int src_index, TF_Output dst); | |||
| /// <summary> | |||
| /// Add an operation in `graph_def` to be returned via the `return_opers` output | |||
| /// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no | |||
| // lifetime requirements. | |||
| /// </summary> | |||
| /// <param name="opts">TF_ImportGraphDefOptions* opts</param> | |||
| /// <param name="oper_name">const char*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsAddReturnOperation(IntPtr opts, string oper_name); | |||
| /// <summary> | |||
| /// Add an output in `graph_def` to be returned via the `return_outputs` output | |||
| /// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input | |||
| /// mapping, the corresponding existing tensor in `graph` will be returned. | |||
| /// `oper_name` is copied and has no lifetime requirements. | |||
| /// </summary> | |||
| /// <param name="opts">TF_ImportGraphDefOptions*</param> | |||
| /// <param name="oper_name">const char*</param> | |||
| /// <param name="index">int</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsAddReturnOutput(IntPtr opts, string oper_name, int index); | |||
| /// <summary> | |||
| /// Returns the number of return operations added via | |||
| /// TF_ImportGraphDefOptionsAddReturnOperation(). | |||
| /// </summary> | |||
| /// <param name="opts"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_ImportGraphDefOptionsNumReturnOperations(IntPtr opts); | |||
| /// <summary> | |||
| /// Returns the number of return outputs added via | |||
| /// TF_ImportGraphDefOptionsAddReturnOutput(). | |||
| /// </summary> | |||
| /// <param name="opts">const TF_ImportGraphDefOptions*</param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(IntPtr opts); | |||
| /// <summary> | |||
| /// Set the prefix to be prepended to the names of nodes in `graph_def` that will | |||
| /// be imported into `graph`. `prefix` is copied and has no lifetime | |||
| /// requirements. | |||
| /// </summary> | |||
| /// <param name="ops"></param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefOptionsSetPrefix(IntPtr ops, string prefix); | |||
| /// <summary> | |||
| /// Fetches the return operations requested via | |||
| /// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched | |||
| /// operations is returned in `num_opers`. The array of return operations is | |||
| /// returned in `opers`. `*opers` is owned by and has the lifetime of `results`. | |||
| /// </summary> | |||
| /// <param name="results">TF_ImportGraphDefResults*</param> | |||
| /// <param name="num_opers">int*</param> | |||
| /// <param name="opers">TF_Operation***</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref IntPtr opers); | |||
| /// <summary> | |||
| /// Fetches the return outputs requested via | |||
| /// TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs is | |||
| /// returned in `num_outputs`. The array of return outputs is returned in | |||
| /// `outputs`. `*outputs` is owned by and has the lifetime of `results`. | |||
| /// </summary> | |||
| /// <param name="results">TF_ImportGraphDefResults* results</param> | |||
| /// <param name="num_outputs">int*</param> | |||
| /// <param name="outputs">TF_Output**</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_NewGraph(); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static unsafe extern IntPtr TF_NewGraph(); | |||
| public static extern IntPtr TF_NewImportGraphDefOptions(); | |||
| } | |||
| } | |||
| @@ -41,8 +41,25 @@ namespace Tensorflow | |||
| } | |||
| public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | |||
| public Operation[] ControlInputs(int max_control_inputs) | |||
| { | |||
| var control_inputs = new Operation[NumControlInputs]; | |||
| var control_input_handles = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs); | |||
| c_api.TF_OperationGetControlInputs(_handle, control_input_handles, max_control_inputs); | |||
| return control_inputs; | |||
| } | |||
| public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | |||
| public Operation[] ControlOutputs(int max_control_outputs) | |||
| { | |||
| var control_outputs = new Operation[NumControlOutputs]; | |||
| var control_output_handles = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs); | |||
| c_api.TF_OperationGetControlInputs(_handle, control_output_handles, max_control_outputs); | |||
| return control_outputs; | |||
| } | |||
| private Tensor[] _outputs; | |||
| public Tensor[] outputs => _outputs; | |||
| public Tensor[] inputs; | |||
| @@ -49,6 +49,35 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status); | |||
| /// <summary> | |||
| /// Get list of all control inputs to an operation. `control_inputs` must | |||
| /// point to an array of length `max_control_inputs` (ideally set to | |||
| /// TF_OperationNumControlInputs(oper)). Returns the number of control | |||
| /// inputs (should match TF_OperationNumControlInputs(oper)). | |||
| /// </summary> | |||
| /// <param name="oper">TF_Operation*</param> | |||
| /// <param name="control_inputs">TF_Operation**</param> | |||
| /// <param name="max_control_inputs"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_OperationGetControlInputs(IntPtr oper, IntPtr control_inputs, int max_control_inputs); | |||
| /// <summary> | |||
| /// Get the list of operations that have `*oper` as a control input. | |||
| /// `control_outputs` must point to an array of length at least | |||
| /// `max_control_outputs` (ideally set to | |||
| /// TF_OperationNumControlOutputs(oper)). Beware that a concurrent | |||
| /// modification of the graph can increase the number of control | |||
| /// outputs. Returns the number of control outputs (should match | |||
| /// TF_OperationNumControlOutputs(oper)). | |||
| /// </summary> | |||
| /// <param name="oper">TF_Operation*</param> | |||
| /// <param name="control_outputs">TF_Operation**</param> | |||
| /// <param name="max_control_outputs"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_OperationGetControlOutputs(IntPtr oper, IntPtr control_outputs, int max_control_outputs); | |||
| /// <summary> | |||
| /// TF_Output producer = TF_OperationInput(consumer); | |||
| /// There is an edge from producer.oper's output (given by | |||
| @@ -105,14 +134,19 @@ namespace Tensorflow | |||
| /// <summary> | |||
| /// Get list of all current consumers of a specific output of an | |||
| /// operation. | |||
| /// operation. `consumers` must point to an array of length at least | |||
| /// `max_consumers` (ideally set to | |||
| /// TF_OperationOutputNumConsumers(oper_out)). Beware that a concurrent | |||
| /// modification of the graph can increase the number of consumers of | |||
| /// an operation. Returns the number of output consumers (should match | |||
| /// TF_OperationOutputNumConsumers(oper_out)). | |||
| /// </summary> | |||
| /// <param name="oper_out"></param> | |||
| /// <param name="consumers"></param> | |||
| /// <param name="max_consumers"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input * consumers, int max_consumers); | |||
| public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, int max_consumers); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | |||
| @@ -0,0 +1,25 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| public class CApiTest | |||
| { | |||
| public void EXPECT_EQ(object expected, object actual) | |||
| { | |||
| Assert.AreEqual(expected, actual); | |||
| } | |||
| public void ASSERT_EQ(object expected, object actual) | |||
| { | |||
| Assert.AreEqual(expected, actual); | |||
| } | |||
| public void ASSERT_TRUE(bool condition) | |||
| { | |||
| Assert.IsTrue(condition); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,13 +1,15 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| using Buffer = Tensorflow.Buffer; | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class GraphTest | |||
| public class GraphTest : CApiTest | |||
| { | |||
| /// <summary> | |||
| /// Port from c_api_test.cc | |||
| @@ -21,74 +23,74 @@ namespace TensorFlowNET.UnitTest | |||
| // Make a placeholder operation. | |||
| var feed = c_test_util.Placeholder(graph, s); | |||
| Assert.AreEqual("feed", feed.Name); | |||
| Assert.AreEqual("Placeholder", feed.OpType); | |||
| Assert.AreEqual("", feed.Device); | |||
| Assert.AreEqual(1, feed.NumOutputs); | |||
| Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType(0)); | |||
| Assert.AreEqual(1, feed.OutputListLength("output")); | |||
| Assert.AreEqual(0, feed.NumInputs); | |||
| Assert.AreEqual(0, feed.OutputNumConsumers(0)); | |||
| Assert.AreEqual(0, feed.NumControlInputs); | |||
| Assert.AreEqual(0, feed.NumControlOutputs); | |||
| EXPECT_EQ("feed", feed.Name); | |||
| EXPECT_EQ("Placeholder", feed.OpType); | |||
| EXPECT_EQ("", feed.Device); | |||
| EXPECT_EQ(1, feed.NumOutputs); | |||
| EXPECT_EQ(TF_DataType.TF_INT32, feed.OutputType(0)); | |||
| EXPECT_EQ(1, feed.OutputListLength("output")); | |||
| EXPECT_EQ(0, feed.NumInputs); | |||
| EXPECT_EQ(0, feed.OutputNumConsumers(0)); | |||
| EXPECT_EQ(0, feed.NumControlInputs); | |||
| EXPECT_EQ(0, feed.NumControlOutputs); | |||
| AttrValue attr_value = null; | |||
| Assert.IsTrue(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s)); | |||
| Assert.AreEqual(attr_value.Type, DataType.DtInt32); | |||
| ASSERT_TRUE(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s)); | |||
| EXPECT_EQ(attr_value.Type, DataType.DtInt32); | |||
| // Test not found errors in TF_Operation*() query functions. | |||
| Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | |||
| Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); | |||
| EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | |||
| EXPECT_EQ(TF_Code.TF_INVALID_ARGUMENT, s.Code); | |||
| Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | |||
| Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message); | |||
| EXPECT_EQ("Operation 'feed' has no attr named 'missing'.", s.Message); | |||
| // Make a constant oper with the scalar "3". | |||
| var three = c_test_util.ScalarConst(3, graph, s); | |||
| Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Add oper. | |||
| var add = c_test_util.Add(feed, three, graph, s); | |||
| Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Test TF_Operation*() query functions. | |||
| Assert.AreEqual("add", add.Name); | |||
| Assert.AreEqual("AddN", add.OpType); | |||
| Assert.AreEqual("", add.Device); | |||
| Assert.AreEqual(1, add.NumOutputs); | |||
| Assert.AreEqual(TF_DataType.TF_INT32, add.OutputType(0)); | |||
| Assert.AreEqual(1, add.OutputListLength("sum")); | |||
| Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||
| Assert.AreEqual(2, add.InputListLength("inputs")); | |||
| Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||
| Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(0)); | |||
| Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(1)); | |||
| EXPECT_EQ("add", add.Name); | |||
| EXPECT_EQ("AddN", add.OpType); | |||
| EXPECT_EQ("", add.Device); | |||
| EXPECT_EQ(1, add.NumOutputs); | |||
| EXPECT_EQ(TF_DataType.TF_INT32, add.OutputType(0)); | |||
| EXPECT_EQ(1, add.OutputListLength("sum")); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| EXPECT_EQ(2, add.InputListLength("inputs")); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(0)); | |||
| EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(1)); | |||
| var add_in_0 = add.Input(0); | |||
| Assert.AreEqual(feed, add_in_0.oper); | |||
| Assert.AreEqual(0, add_in_0.index); | |||
| EXPECT_EQ(feed, add_in_0.oper); | |||
| EXPECT_EQ(0, add_in_0.index); | |||
| var add_in_1 = add.Input(1); | |||
| Assert.AreEqual(three, add_in_1.oper); | |||
| Assert.AreEqual(0, add_in_1.index); | |||
| Assert.AreEqual(0, add.OutputNumConsumers(0)); | |||
| Assert.AreEqual(0, add.NumControlInputs); | |||
| Assert.AreEqual(0, add.NumControlOutputs); | |||
| EXPECT_EQ(three, add_in_1.oper); | |||
| EXPECT_EQ(0, add_in_1.index); | |||
| EXPECT_EQ(0, add.OutputNumConsumers(0)); | |||
| EXPECT_EQ(0, add.NumControlInputs); | |||
| EXPECT_EQ(0, add.NumControlOutputs); | |||
| Assert.IsTrue(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); | |||
| Assert.AreEqual(DataType.DtInt32, attr_value.Type); | |||
| Assert.IsTrue(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); | |||
| Assert.AreEqual(2, attr_value.I); | |||
| ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); | |||
| EXPECT_EQ(DataType.DtInt32, attr_value.Type); | |||
| ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); | |||
| EXPECT_EQ(2, attr_value.I); | |||
| // Placeholder oper now has a consumer. | |||
| Assert.AreEqual(1, feed.OutputNumConsumers(0)); | |||
| EXPECT_EQ(1, feed.OutputNumConsumers(0)); | |||
| TF_Input[] feed_port = feed.OutputConsumers(0, 1); | |||
| Assert.AreEqual(1, feed_port.Length); | |||
| Assert.AreEqual(add, feed_port[0].oper); | |||
| Assert.AreEqual(0, feed_port[0].index); | |||
| EXPECT_EQ(1, feed_port.Length); | |||
| EXPECT_EQ(add, feed_port[0].oper); | |||
| EXPECT_EQ(0, feed_port[0].index); | |||
| // The scalar const oper also has a consumer. | |||
| Assert.AreEqual(1, three.OutputNumConsumers(0)); | |||
| EXPECT_EQ(1, three.OutputNumConsumers(0)); | |||
| TF_Input[] three_port = three.OutputConsumers(0, 1); | |||
| Assert.AreEqual(add, three_port[0].oper); | |||
| Assert.AreEqual(1, three_port[0].index); | |||
| EXPECT_EQ(add, three_port[0].oper); | |||
| EXPECT_EQ(1, three_port[0].index); | |||
| // Serialize to GraphDef. | |||
| var graph_def = c_test_util.GetGraphDef(graph); | |||
| @@ -119,38 +121,38 @@ namespace TensorFlowNET.UnitTest | |||
| Assert.Fail($"Unexpected NodeDef: {n}"); | |||
| } | |||
| } | |||
| Assert.IsTrue(found_placeholder); | |||
| Assert.IsTrue(found_scalar_const); | |||
| Assert.IsTrue(found_add); | |||
| ASSERT_TRUE(found_placeholder); | |||
| ASSERT_TRUE(found_scalar_const); | |||
| ASSERT_TRUE(found_add); | |||
| // Add another oper to the graph. | |||
| var neg = c_test_util.Neg(add, graph, s); | |||
| Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Serialize to NodeDef. | |||
| var node_def = c_test_util.GetNodeDef(neg); | |||
| // Validate NodeDef is what we expect. | |||
| Assert.IsTrue(c_test_util.IsNeg(node_def, "add")); | |||
| ASSERT_TRUE(c_test_util.IsNeg(node_def, "add")); | |||
| // Serialize to GraphDef. | |||
| var graph_def2 = c_test_util.GetGraphDef(graph); | |||
| // Compare with first GraphDef + added NodeDef. | |||
| graph_def.Node.Add(node_def); | |||
| Assert.AreEqual(graph_def.ToString(), graph_def2.ToString()); | |||
| EXPECT_EQ(graph_def.ToString(), graph_def2.ToString()); | |||
| // Look up some nodes by name. | |||
| Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); | |||
| Assert.AreEqual(neg, neg2); | |||
| EXPECT_EQ(neg, neg2); | |||
| var node_def2 = c_test_util.GetNodeDef(neg2); | |||
| Assert.AreEqual(node_def.ToString(), node_def2.ToString()); | |||
| EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | |||
| Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); | |||
| Assert.AreEqual(feed, feed2); | |||
| EXPECT_EQ(feed, feed2); | |||
| node_def = c_test_util.GetNodeDef(feed); | |||
| node_def2 = c_test_util.GetNodeDef(feed2); | |||
| Assert.AreEqual(node_def.ToString(), node_def2.ToString()); | |||
| EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | |||
| // Test iterating through the nodes of a graph. | |||
| found_placeholder = false; | |||
| @@ -189,13 +191,106 @@ namespace TensorFlowNET.UnitTest | |||
| } | |||
| } | |||
| Assert.IsTrue(found_placeholder); | |||
| Assert.IsTrue(found_scalar_const); | |||
| Assert.IsTrue(found_add); | |||
| Assert.IsTrue(found_neg); | |||
| ASSERT_TRUE(found_placeholder); | |||
| ASSERT_TRUE(found_scalar_const); | |||
| ASSERT_TRUE(found_add); | |||
| ASSERT_TRUE(found_neg); | |||
| graph.Dispose(); | |||
| s.Dispose(); | |||
| } | |||
| /// <summary> | |||
| /// Port from c_api_test.cc | |||
| /// `TEST(CAPI, ImportGraphDef)` | |||
| /// </summary> | |||
| [TestMethod] | |||
| public void c_api_ImportGraphDef() | |||
| { | |||
| var s = new Status(); | |||
| var graph = new Graph(); | |||
| // Create a simple graph. | |||
| c_test_util.Placeholder(graph, s); | |||
| var oper = c_test_util.ScalarConst(3, graph, s); | |||
| c_test_util.Neg(oper, graph, s); | |||
| // Export to a GraphDef. | |||
| var graph_def = new Buffer(); | |||
| c_api.TF_GraphToGraphDef(graph, graph_def, s); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| // Import it, with a prefix, in a fresh graph. | |||
| graph.Dispose(); | |||
| graph = new Graph(); | |||
| var opts = c_api.TF_NewImportGraphDefOptions(); | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | |||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| Operation scalar = c_api.TF_GraphOperationByName(graph, "imported/scalar"); | |||
| Operation feed = c_api.TF_GraphOperationByName(graph, "imported/feed"); | |||
| Operation neg = c_api.TF_GraphOperationByName(graph, "imported/neg"); | |||
| // Test basic structure of the imported graph. | |||
| EXPECT_EQ(0, scalar.NumInputs); | |||
| EXPECT_EQ(0, feed.NumInputs); | |||
| EXPECT_EQ(1, neg.NumInputs); | |||
| var neg_input = neg.Input(0); | |||
| EXPECT_EQ(scalar, neg_input.oper); | |||
| EXPECT_EQ(0, neg_input.index); | |||
| // Test that we can't see control edges involving the source and sink nodes. | |||
| EXPECT_EQ(0, scalar.NumControlInputs); | |||
| EXPECT_EQ(0, scalar.ControlInputs(100).Length); | |||
| EXPECT_EQ(0, scalar.NumControlOutputs); | |||
| EXPECT_EQ(0, scalar.ControlOutputs(100).Length); | |||
| EXPECT_EQ(0, feed.NumControlInputs); | |||
| EXPECT_EQ(0, feed.ControlInputs(100).Length); | |||
| EXPECT_EQ(0, feed.NumControlOutputs); | |||
| EXPECT_EQ(0, feed.ControlOutputs(100).Length); | |||
| EXPECT_EQ(0, neg.NumControlInputs); | |||
| EXPECT_EQ(0, neg.ControlInputs(100).Length); | |||
| EXPECT_EQ(0, neg.NumControlOutputs); | |||
| EXPECT_EQ(0, neg.ControlOutputs(100).Length); | |||
| // Import it again, with an input mapping, return outputs, and a return | |||
| // operation, into the same graph. | |||
| c_api.TF_DeleteImportGraphDefOptions(opts); | |||
| opts = c_api.TF_NewImportGraphDefOptions(); | |||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); | |||
| c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0)); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); | |||
| EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); | |||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); | |||
| EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | |||
| var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); | |||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
| Operation scalar2 = c_api.TF_GraphOperationByName(graph, "imported2/scalar"); | |||
| Operation feed2 = c_api.TF_GraphOperationByName(graph, "imported2/feed"); | |||
| Operation neg2 = c_api.TF_GraphOperationByName(graph, "imported2/neg"); | |||
| // Check input mapping | |||
| neg_input = neg.Input(0); | |||
| EXPECT_EQ(scalar, neg_input.oper); | |||
| EXPECT_EQ(0, neg_input.index); | |||
| // Check return outputs | |||
| var return_outputs = graph.ReturnOutputs(results); | |||
| ASSERT_EQ(2, return_outputs.Length); | |||
| EXPECT_EQ(feed2, return_outputs[0].oper); | |||
| EXPECT_EQ(0, return_outputs[0].index); | |||
| EXPECT_EQ(scalar, return_outputs[1].oper); // remapped | |||
| EXPECT_EQ(0, return_outputs[1].index); | |||
| // Check return operation | |||
| var num_return_opers = graph.ReturnOperations(results); | |||
| ASSERT_EQ(1, num_return_opers); | |||
| } | |||
| } | |||
| } | |||
| @@ -10,7 +10,7 @@ using Tensorflow; | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class TensorTest | |||
| public class TensorTest : CApiTest | |||
| { | |||
| /// <summary> | |||
| /// Port from c_api_test.cc | |||
| @@ -22,10 +22,10 @@ namespace TensorFlowNET.UnitTest | |||
| ulong num_bytes = 6 * sizeof(float); | |||
| long[] dims = { 2, 3 }; | |||
| Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); | |||
| Assert.AreEqual(TF_DataType.TF_FLOAT, t.dtype); | |||
| Assert.AreEqual(2, t.NDims); | |||
| EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); | |||
| EXPECT_EQ(2, t.NDims); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, t.shape)); | |||
| Assert.AreEqual(num_bytes, t.bytesize); | |||
| EXPECT_EQ(num_bytes, t.bytesize); | |||
| t.Dispose(); | |||
| } | |||
| @@ -41,11 +41,11 @@ namespace TensorFlowNET.UnitTest | |||
| var tensor = new Tensor(nd); | |||
| var array = tensor.Data<float>(); | |||
| Assert.AreEqual(tensor.dtype, TF_DataType.TF_FLOAT); | |||
| Assert.AreEqual(tensor.rank, nd.ndim); | |||
| Assert.AreEqual(tensor.shape[0], nd.shape[0]); | |||
| Assert.AreEqual(tensor.shape[1], nd.shape[1]); | |||
| Assert.AreEqual(tensor.bytesize, (uint)nd.size * sizeof(float)); | |||
| EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); | |||
| EXPECT_EQ(tensor.rank, nd.ndim); | |||
| EXPECT_EQ(tensor.shape[0], nd.shape[0]); | |||
| EXPECT_EQ(tensor.shape[1], nd.shape[1]); | |||
| EXPECT_EQ(tensor.bytesize, (uint)nd.size * sizeof(float)); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array)); | |||
| } | |||
| @@ -66,20 +66,20 @@ namespace TensorFlowNET.UnitTest | |||
| int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| Assert.AreEqual(-1, num_dims); | |||
| EXPECT_EQ(-1, num_dims); | |||
| // Set the shape to be unknown, expect no change. | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
| Assert.AreEqual(-1, num_dims); | |||
| EXPECT_EQ(-1, num_dims); | |||
| // Set the shape to be 2 x Unknown | |||
| long[] dims = { 2, -1 }; | |||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
| Assert.AreEqual(2, num_dims); | |||
| EXPECT_EQ(2, num_dims); | |||
| // Get the dimension vector appropriately. | |||
| var returned_dims = new long[dims.Length]; | |||
| @@ -103,9 +103,9 @@ namespace TensorFlowNET.UnitTest | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| Assert.AreEqual(2, num_dims); | |||
| Assert.AreEqual(2, returned_dims[0]); | |||
| Assert.AreEqual(3, returned_dims[1]); | |||
| EXPECT_EQ(2, num_dims); | |||
| EXPECT_EQ(2, returned_dims[0]); | |||
| EXPECT_EQ(3, returned_dims[1]); | |||
| // Try to set 'unknown' with same rank on the shape and see that | |||
| // it doesn't change. | |||
| @@ -115,9 +115,9 @@ namespace TensorFlowNET.UnitTest | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| Assert.AreEqual(2, num_dims); | |||
| Assert.AreEqual(2, returned_dims[0]); | |||
| Assert.AreEqual(3, returned_dims[1]); | |||
| EXPECT_EQ(2, num_dims); | |||
| EXPECT_EQ(2, returned_dims[0]); | |||
| EXPECT_EQ(3, returned_dims[1]); | |||
| // Try to fetch a shape with the wrong num_dims | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); | |||
| @@ -135,7 +135,7 @@ namespace TensorFlowNET.UnitTest | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| Assert.AreEqual(0, num_dims); | |||
| EXPECT_EQ(0, num_dims); | |||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s); | |||
| //Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||