| @@ -171,6 +171,34 @@ namespace Tensorflow | |||||
| return $"{name}_{_names_in_use[name_key]}"; | return $"{name}_{_names_in_use[name_key]}"; | ||||
| } | } | ||||
| public TF_Output[] ReturnOutputs(IntPtr results) | |||||
| { | |||||
| IntPtr return_output_handle = IntPtr.Zero; | |||||
| int num_return_outputs = 0; | |||||
| c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle); | |||||
| TF_Output[] return_outputs = new TF_Output[num_return_outputs]; | |||||
| for (int i = 0; i < num_return_outputs; i++) | |||||
| { | |||||
| return_outputs[i] = Marshal.PtrToStructure<TF_Output>(return_output_handle + (Marshal.SizeOf<TF_Output>() * i)); | |||||
| } | |||||
| return return_outputs; | |||||
| } | |||||
| public Operation[] ReturnOperations(IntPtr results) | |||||
| { | |||||
| IntPtr return_oper_handle = IntPtr.Zero; | |||||
| int num_return_opers = 0; | |||||
| c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_opers, ref return_oper_handle); | |||||
| Operation[] return_opers = new Operation[num_return_opers]; | |||||
| for (int i = 0; i < num_return_opers; i++) | |||||
| { | |||||
| // return_opers[i] = Marshal.PtrToStructure<TF_Output>(return_oper_handle + (Marshal.SizeOf<TF_Output>() * i)); | |||||
| } | |||||
| return return_opers; | |||||
| } | |||||
| public Operation[] get_operations() | public Operation[] get_operations() | ||||
| { | { | ||||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | return _nodes_by_name.Values.Select(x => x).ToArray(); | ||||
| @@ -15,6 +15,13 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_DeleteGraph(IntPtr graph); | public static extern void TF_DeleteGraph(IntPtr graph); | ||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="opts">TF_ImportGraphDefOptions*</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_DeleteImportGraphDefOptions(IntPtr opts); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | ||||
| @@ -31,6 +38,29 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | ||||
| /// <summary> | |||||
| /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | |||||
| /// a bad status on error. Otherwise, returns a populated | |||||
| /// TF_ImportGraphDefResults instance. The returned instance must be deleted via | |||||
| /// TF_DeleteImportGraphDefResults(). | |||||
| /// </summary> | |||||
| /// <param name="graph">TF_Graph*</param> | |||||
| /// <param name="graph_def">const TF_Buffer*</param> | |||||
| /// <param name="options">const TF_ImportGraphDefOptions*</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| /// <returns>TF_ImportGraphDefResults*</returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TF_GraphImportGraphDefWithResults(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status); | |||||
| /// <summary> | |||||
| /// Import the graph serialized in `graph_def` into `graph`. | |||||
| /// </summary> | |||||
| /// <param name="graph">TF_Graph*</param> | |||||
| /// <param name="graph_def">TF_Buffer*</param> | |||||
| /// <param name="options">TF_ImportGraphDefOptions*</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Iterate through the operations of a graph. | /// Iterate through the operations of a graph. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -80,7 +110,96 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status); | public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status); | ||||
| /// <summary> | |||||
| /// Set any imported nodes with input `src_name:src_index` to have that input | |||||
| /// replaced with `dst`. `src_name` refers to a node in the graph to be imported, | |||||
| /// `dst` references a node already existing in the graph being imported into. | |||||
| /// `src_name` is copied and has no lifetime requirements. | |||||
| /// </summary> | |||||
| /// <param name="opts">TF_ImportGraphDefOptions*</param> | |||||
| /// <param name="src_name">const char*</param> | |||||
| /// <param name="src_index">int</param> | |||||
| /// <param name="dst">TF_Output</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_ImportGraphDefOptionsAddInputMapping(IntPtr opts, string src_name, int src_index, TF_Output dst); | |||||
| /// <summary> | |||||
| /// Add an operation in `graph_def` to be returned via the `return_opers` output | |||||
| /// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no | |||||
| // lifetime requirements. | |||||
| /// </summary> | |||||
| /// <param name="opts">TF_ImportGraphDefOptions* opts</param> | |||||
| /// <param name="oper_name">const char*</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_ImportGraphDefOptionsAddReturnOperation(IntPtr opts, string oper_name); | |||||
| /// <summary> | |||||
| /// Add an output in `graph_def` to be returned via the `return_outputs` output | |||||
| /// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input | |||||
| /// mapping, the corresponding existing tensor in `graph` will be returned. | |||||
| /// `oper_name` is copied and has no lifetime requirements. | |||||
| /// </summary> | |||||
| /// <param name="opts">TF_ImportGraphDefOptions*</param> | |||||
| /// <param name="oper_name">const char*</param> | |||||
| /// <param name="index">int</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_ImportGraphDefOptionsAddReturnOutput(IntPtr opts, string oper_name, int index); | |||||
| /// <summary> | |||||
| /// Returns the number of return operations added via | |||||
| /// TF_ImportGraphDefOptionsAddReturnOperation(). | |||||
| /// </summary> | |||||
| /// <param name="opts"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern int TF_ImportGraphDefOptionsNumReturnOperations(IntPtr opts); | |||||
| /// <summary> | |||||
| /// Returns the number of return outputs added via | |||||
| /// TF_ImportGraphDefOptionsAddReturnOutput(). | |||||
| /// </summary> | |||||
| /// <param name="opts">const TF_ImportGraphDefOptions*</param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern int TF_ImportGraphDefOptionsNumReturnOutputs(IntPtr opts); | |||||
| /// <summary> | |||||
| /// Set the prefix to be prepended to the names of nodes in `graph_def` that will | |||||
| /// be imported into `graph`. `prefix` is copied and has no lifetime | |||||
| /// requirements. | |||||
| /// </summary> | |||||
| /// <param name="ops"></param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_ImportGraphDefOptionsSetPrefix(IntPtr ops, string prefix); | |||||
| /// <summary> | |||||
| /// Fetches the return operations requested via | |||||
| /// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched | |||||
| /// operations is returned in `num_opers`. The array of return operations is | |||||
| /// returned in `opers`. `*opers` is owned by and has the lifetime of `results`. | |||||
| /// </summary> | |||||
| /// <param name="results">TF_ImportGraphDefResults*</param> | |||||
| /// <param name="num_opers">int*</param> | |||||
| /// <param name="opers">TF_Operation***</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_ImportGraphDefResultsReturnOperations(IntPtr results, ref int num_opers, ref IntPtr opers); | |||||
| /// <summary> | |||||
| /// Fetches the return outputs requested via | |||||
| /// TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs is | |||||
| /// returned in `num_outputs`. The array of return outputs is returned in | |||||
| /// `outputs`. `*outputs` is owned by and has the lifetime of `results`. | |||||
| /// </summary> | |||||
| /// <param name="results">TF_ImportGraphDefResults* results</param> | |||||
| /// <param name="num_outputs">int*</param> | |||||
| /// <param name="outputs">TF_Output**</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TF_NewGraph(); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern IntPtr TF_NewGraph(); | |||||
| public static extern IntPtr TF_NewImportGraphDefOptions(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -41,8 +41,25 @@ namespace Tensorflow | |||||
| } | } | ||||
| public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | ||||
| public Operation[] ControlInputs(int max_control_inputs) | |||||
| { | |||||
| var control_inputs = new Operation[NumControlInputs]; | |||||
| var control_input_handles = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs); | |||||
| c_api.TF_OperationGetControlInputs(_handle, control_input_handles, max_control_inputs); | |||||
| return control_inputs; | |||||
| } | |||||
| public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | ||||
| public Operation[] ControlOutputs(int max_control_outputs) | |||||
| { | |||||
| var control_outputs = new Operation[NumControlOutputs]; | |||||
| var control_output_handles = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs); | |||||
| c_api.TF_OperationGetControlInputs(_handle, control_output_handles, max_control_outputs); | |||||
| return control_outputs; | |||||
| } | |||||
| private Tensor[] _outputs; | private Tensor[] _outputs; | ||||
| public Tensor[] outputs => _outputs; | public Tensor[] outputs => _outputs; | ||||
| public Tensor[] inputs; | public Tensor[] inputs; | ||||
| @@ -49,6 +49,35 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status); | public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, IntPtr output_attr_value, IntPtr status); | ||||
| /// <summary> | |||||
| /// Get list of all control inputs to an operation. `control_inputs` must | |||||
| /// point to an array of length `max_control_inputs` (ideally set to | |||||
| /// TF_OperationNumControlInputs(oper)). Returns the number of control | |||||
| /// inputs (should match TF_OperationNumControlInputs(oper)). | |||||
| /// </summary> | |||||
| /// <param name="oper">TF_Operation*</param> | |||||
| /// <param name="control_inputs">TF_Operation**</param> | |||||
| /// <param name="max_control_inputs"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern int TF_OperationGetControlInputs(IntPtr oper, IntPtr control_inputs, int max_control_inputs); | |||||
| /// <summary> | |||||
| /// Get the list of operations that have `*oper` as a control input. | |||||
| /// `control_outputs` must point to an array of length at least | |||||
| /// `max_control_outputs` (ideally set to | |||||
| /// TF_OperationNumControlOutputs(oper)). Beware that a concurrent | |||||
| /// modification of the graph can increase the number of control | |||||
| /// outputs. Returns the number of control outputs (should match | |||||
| /// TF_OperationNumControlOutputs(oper)). | |||||
| /// </summary> | |||||
| /// <param name="oper">TF_Operation*</param> | |||||
| /// <param name="control_outputs">TF_Operation**</param> | |||||
| /// <param name="max_control_outputs"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern int TF_OperationGetControlOutputs(IntPtr oper, IntPtr control_outputs, int max_control_outputs); | |||||
| /// <summary> | /// <summary> | ||||
| /// TF_Output producer = TF_OperationInput(consumer); | /// TF_Output producer = TF_OperationInput(consumer); | ||||
| /// There is an edge from producer.oper's output (given by | /// There is an edge from producer.oper's output (given by | ||||
| @@ -105,14 +134,19 @@ namespace Tensorflow | |||||
| /// <summary> | /// <summary> | ||||
| /// Get list of all current consumers of a specific output of an | /// Get list of all current consumers of a specific output of an | ||||
| /// operation. | |||||
| /// operation. `consumers` must point to an array of length at least | |||||
| /// `max_consumers` (ideally set to | |||||
| /// TF_OperationOutputNumConsumers(oper_out)). Beware that a concurrent | |||||
| /// modification of the graph can increase the number of consumers of | |||||
| /// an operation. Returns the number of output consumers (should match | |||||
| /// TF_OperationOutputNumConsumers(oper_out)). | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="oper_out"></param> | /// <param name="oper_out"></param> | ||||
| /// <param name="consumers"></param> | /// <param name="consumers"></param> | ||||
| /// <param name="max_consumers"></param> | /// <param name="max_consumers"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input * consumers, int max_consumers); | |||||
| public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, int max_consumers); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | ||||
| @@ -0,0 +1,25 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace TensorFlowNET.UnitTest | |||||
| { | |||||
| public class CApiTest | |||||
| { | |||||
| public void EXPECT_EQ(object expected, object actual) | |||||
| { | |||||
| Assert.AreEqual(expected, actual); | |||||
| } | |||||
| public void ASSERT_EQ(object expected, object actual) | |||||
| { | |||||
| Assert.AreEqual(expected, actual); | |||||
| } | |||||
| public void ASSERT_TRUE(bool condition) | |||||
| { | |||||
| Assert.IsTrue(condition); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,13 +1,15 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Buffer = Tensorflow.Buffer; | |||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class GraphTest | |||||
| public class GraphTest : CApiTest | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Port from c_api_test.cc | /// Port from c_api_test.cc | ||||
| @@ -21,74 +23,74 @@ namespace TensorFlowNET.UnitTest | |||||
| // Make a placeholder operation. | // Make a placeholder operation. | ||||
| var feed = c_test_util.Placeholder(graph, s); | var feed = c_test_util.Placeholder(graph, s); | ||||
| Assert.AreEqual("feed", feed.Name); | |||||
| Assert.AreEqual("Placeholder", feed.OpType); | |||||
| Assert.AreEqual("", feed.Device); | |||||
| Assert.AreEqual(1, feed.NumOutputs); | |||||
| Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType(0)); | |||||
| Assert.AreEqual(1, feed.OutputListLength("output")); | |||||
| Assert.AreEqual(0, feed.NumInputs); | |||||
| Assert.AreEqual(0, feed.OutputNumConsumers(0)); | |||||
| Assert.AreEqual(0, feed.NumControlInputs); | |||||
| Assert.AreEqual(0, feed.NumControlOutputs); | |||||
| EXPECT_EQ("feed", feed.Name); | |||||
| EXPECT_EQ("Placeholder", feed.OpType); | |||||
| EXPECT_EQ("", feed.Device); | |||||
| EXPECT_EQ(1, feed.NumOutputs); | |||||
| EXPECT_EQ(TF_DataType.TF_INT32, feed.OutputType(0)); | |||||
| EXPECT_EQ(1, feed.OutputListLength("output")); | |||||
| EXPECT_EQ(0, feed.NumInputs); | |||||
| EXPECT_EQ(0, feed.OutputNumConsumers(0)); | |||||
| EXPECT_EQ(0, feed.NumControlInputs); | |||||
| EXPECT_EQ(0, feed.NumControlOutputs); | |||||
| AttrValue attr_value = null; | AttrValue attr_value = null; | ||||
| Assert.IsTrue(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s)); | |||||
| Assert.AreEqual(attr_value.Type, DataType.DtInt32); | |||||
| ASSERT_TRUE(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s)); | |||||
| EXPECT_EQ(attr_value.Type, DataType.DtInt32); | |||||
| // Test not found errors in TF_Operation*() query functions. | // Test not found errors in TF_Operation*() query functions. | ||||
| Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | |||||
| Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); | |||||
| EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | |||||
| EXPECT_EQ(TF_Code.TF_INVALID_ARGUMENT, s.Code); | |||||
| Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | ||||
| Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message); | |||||
| EXPECT_EQ("Operation 'feed' has no attr named 'missing'.", s.Message); | |||||
| // Make a constant oper with the scalar "3". | // Make a constant oper with the scalar "3". | ||||
| var three = c_test_util.ScalarConst(3, graph, s); | var three = c_test_util.ScalarConst(3, graph, s); | ||||
| Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||||
| // Add oper. | // Add oper. | ||||
| var add = c_test_util.Add(feed, three, graph, s); | var add = c_test_util.Add(feed, three, graph, s); | ||||
| Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||||
| // Test TF_Operation*() query functions. | // Test TF_Operation*() query functions. | ||||
| Assert.AreEqual("add", add.Name); | |||||
| Assert.AreEqual("AddN", add.OpType); | |||||
| Assert.AreEqual("", add.Device); | |||||
| Assert.AreEqual(1, add.NumOutputs); | |||||
| Assert.AreEqual(TF_DataType.TF_INT32, add.OutputType(0)); | |||||
| Assert.AreEqual(1, add.OutputListLength("sum")); | |||||
| Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||||
| Assert.AreEqual(2, add.InputListLength("inputs")); | |||||
| Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||||
| Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(0)); | |||||
| Assert.AreEqual(TF_DataType.TF_INT32, add.InputType(1)); | |||||
| EXPECT_EQ("add", add.Name); | |||||
| EXPECT_EQ("AddN", add.OpType); | |||||
| EXPECT_EQ("", add.Device); | |||||
| EXPECT_EQ(1, add.NumOutputs); | |||||
| EXPECT_EQ(TF_DataType.TF_INT32, add.OutputType(0)); | |||||
| EXPECT_EQ(1, add.OutputListLength("sum")); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||||
| EXPECT_EQ(2, add.InputListLength("inputs")); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||||
| EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(0)); | |||||
| EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(1)); | |||||
| var add_in_0 = add.Input(0); | var add_in_0 = add.Input(0); | ||||
| Assert.AreEqual(feed, add_in_0.oper); | |||||
| Assert.AreEqual(0, add_in_0.index); | |||||
| EXPECT_EQ(feed, add_in_0.oper); | |||||
| EXPECT_EQ(0, add_in_0.index); | |||||
| var add_in_1 = add.Input(1); | var add_in_1 = add.Input(1); | ||||
| Assert.AreEqual(three, add_in_1.oper); | |||||
| Assert.AreEqual(0, add_in_1.index); | |||||
| Assert.AreEqual(0, add.OutputNumConsumers(0)); | |||||
| Assert.AreEqual(0, add.NumControlInputs); | |||||
| Assert.AreEqual(0, add.NumControlOutputs); | |||||
| EXPECT_EQ(three, add_in_1.oper); | |||||
| EXPECT_EQ(0, add_in_1.index); | |||||
| EXPECT_EQ(0, add.OutputNumConsumers(0)); | |||||
| EXPECT_EQ(0, add.NumControlInputs); | |||||
| EXPECT_EQ(0, add.NumControlOutputs); | |||||
| Assert.IsTrue(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); | |||||
| Assert.AreEqual(DataType.DtInt32, attr_value.Type); | |||||
| Assert.IsTrue(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); | |||||
| Assert.AreEqual(2, attr_value.I); | |||||
| ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s)); | |||||
| EXPECT_EQ(DataType.DtInt32, attr_value.Type); | |||||
| ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s)); | |||||
| EXPECT_EQ(2, attr_value.I); | |||||
| // Placeholder oper now has a consumer. | // Placeholder oper now has a consumer. | ||||
| Assert.AreEqual(1, feed.OutputNumConsumers(0)); | |||||
| EXPECT_EQ(1, feed.OutputNumConsumers(0)); | |||||
| TF_Input[] feed_port = feed.OutputConsumers(0, 1); | TF_Input[] feed_port = feed.OutputConsumers(0, 1); | ||||
| Assert.AreEqual(1, feed_port.Length); | |||||
| Assert.AreEqual(add, feed_port[0].oper); | |||||
| Assert.AreEqual(0, feed_port[0].index); | |||||
| EXPECT_EQ(1, feed_port.Length); | |||||
| EXPECT_EQ(add, feed_port[0].oper); | |||||
| EXPECT_EQ(0, feed_port[0].index); | |||||
| // The scalar const oper also has a consumer. | // The scalar const oper also has a consumer. | ||||
| Assert.AreEqual(1, three.OutputNumConsumers(0)); | |||||
| EXPECT_EQ(1, three.OutputNumConsumers(0)); | |||||
| TF_Input[] three_port = three.OutputConsumers(0, 1); | TF_Input[] three_port = three.OutputConsumers(0, 1); | ||||
| Assert.AreEqual(add, three_port[0].oper); | |||||
| Assert.AreEqual(1, three_port[0].index); | |||||
| EXPECT_EQ(add, three_port[0].oper); | |||||
| EXPECT_EQ(1, three_port[0].index); | |||||
| // Serialize to GraphDef. | // Serialize to GraphDef. | ||||
| var graph_def = c_test_util.GetGraphDef(graph); | var graph_def = c_test_util.GetGraphDef(graph); | ||||
| @@ -119,38 +121,38 @@ namespace TensorFlowNET.UnitTest | |||||
| Assert.Fail($"Unexpected NodeDef: {n}"); | Assert.Fail($"Unexpected NodeDef: {n}"); | ||||
| } | } | ||||
| } | } | ||||
| Assert.IsTrue(found_placeholder); | |||||
| Assert.IsTrue(found_scalar_const); | |||||
| Assert.IsTrue(found_add); | |||||
| ASSERT_TRUE(found_placeholder); | |||||
| ASSERT_TRUE(found_scalar_const); | |||||
| ASSERT_TRUE(found_add); | |||||
| // Add another oper to the graph. | // Add another oper to the graph. | ||||
| var neg = c_test_util.Neg(add, graph, s); | var neg = c_test_util.Neg(add, graph, s); | ||||
| Assert.AreEqual(TF_Code.TF_OK, s.Code); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||||
| // Serialize to NodeDef. | // Serialize to NodeDef. | ||||
| var node_def = c_test_util.GetNodeDef(neg); | var node_def = c_test_util.GetNodeDef(neg); | ||||
| // Validate NodeDef is what we expect. | // Validate NodeDef is what we expect. | ||||
| Assert.IsTrue(c_test_util.IsNeg(node_def, "add")); | |||||
| ASSERT_TRUE(c_test_util.IsNeg(node_def, "add")); | |||||
| // Serialize to GraphDef. | // Serialize to GraphDef. | ||||
| var graph_def2 = c_test_util.GetGraphDef(graph); | var graph_def2 = c_test_util.GetGraphDef(graph); | ||||
| // Compare with first GraphDef + added NodeDef. | // Compare with first GraphDef + added NodeDef. | ||||
| graph_def.Node.Add(node_def); | graph_def.Node.Add(node_def); | ||||
| Assert.AreEqual(graph_def.ToString(), graph_def2.ToString()); | |||||
| EXPECT_EQ(graph_def.ToString(), graph_def2.ToString()); | |||||
| // Look up some nodes by name. | // Look up some nodes by name. | ||||
| Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); | Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg"); | ||||
| Assert.AreEqual(neg, neg2); | |||||
| EXPECT_EQ(neg, neg2); | |||||
| var node_def2 = c_test_util.GetNodeDef(neg2); | var node_def2 = c_test_util.GetNodeDef(neg2); | ||||
| Assert.AreEqual(node_def.ToString(), node_def2.ToString()); | |||||
| EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | |||||
| Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); | Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed"); | ||||
| Assert.AreEqual(feed, feed2); | |||||
| EXPECT_EQ(feed, feed2); | |||||
| node_def = c_test_util.GetNodeDef(feed); | node_def = c_test_util.GetNodeDef(feed); | ||||
| node_def2 = c_test_util.GetNodeDef(feed2); | node_def2 = c_test_util.GetNodeDef(feed2); | ||||
| Assert.AreEqual(node_def.ToString(), node_def2.ToString()); | |||||
| EXPECT_EQ(node_def.ToString(), node_def2.ToString()); | |||||
| // Test iterating through the nodes of a graph. | // Test iterating through the nodes of a graph. | ||||
| found_placeholder = false; | found_placeholder = false; | ||||
| @@ -189,13 +191,106 @@ namespace TensorFlowNET.UnitTest | |||||
| } | } | ||||
| } | } | ||||
| Assert.IsTrue(found_placeholder); | |||||
| Assert.IsTrue(found_scalar_const); | |||||
| Assert.IsTrue(found_add); | |||||
| Assert.IsTrue(found_neg); | |||||
| ASSERT_TRUE(found_placeholder); | |||||
| ASSERT_TRUE(found_scalar_const); | |||||
| ASSERT_TRUE(found_add); | |||||
| ASSERT_TRUE(found_neg); | |||||
| graph.Dispose(); | graph.Dispose(); | ||||
| s.Dispose(); | s.Dispose(); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Port from c_api_test.cc | |||||
| /// `TEST(CAPI, ImportGraphDef)` | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void c_api_ImportGraphDef() | |||||
| { | |||||
| var s = new Status(); | |||||
| var graph = new Graph(); | |||||
| // Create a simple graph. | |||||
| c_test_util.Placeholder(graph, s); | |||||
| var oper = c_test_util.ScalarConst(3, graph, s); | |||||
| c_test_util.Neg(oper, graph, s); | |||||
| // Export to a GraphDef. | |||||
| var graph_def = new Buffer(); | |||||
| c_api.TF_GraphToGraphDef(graph, graph_def, s); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||||
| // Import it, with a prefix, in a fresh graph. | |||||
| graph.Dispose(); | |||||
| graph = new Graph(); | |||||
| var opts = c_api.TF_NewImportGraphDefOptions(); | |||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | |||||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||||
| Operation scalar = c_api.TF_GraphOperationByName(graph, "imported/scalar"); | |||||
| Operation feed = c_api.TF_GraphOperationByName(graph, "imported/feed"); | |||||
| Operation neg = c_api.TF_GraphOperationByName(graph, "imported/neg"); | |||||
| // Test basic structure of the imported graph. | |||||
| EXPECT_EQ(0, scalar.NumInputs); | |||||
| EXPECT_EQ(0, feed.NumInputs); | |||||
| EXPECT_EQ(1, neg.NumInputs); | |||||
| var neg_input = neg.Input(0); | |||||
| EXPECT_EQ(scalar, neg_input.oper); | |||||
| EXPECT_EQ(0, neg_input.index); | |||||
| // Test that we can't see control edges involving the source and sink nodes. | |||||
| EXPECT_EQ(0, scalar.NumControlInputs); | |||||
| EXPECT_EQ(0, scalar.ControlInputs(100).Length); | |||||
| EXPECT_EQ(0, scalar.NumControlOutputs); | |||||
| EXPECT_EQ(0, scalar.ControlOutputs(100).Length); | |||||
| EXPECT_EQ(0, feed.NumControlInputs); | |||||
| EXPECT_EQ(0, feed.ControlInputs(100).Length); | |||||
| EXPECT_EQ(0, feed.NumControlOutputs); | |||||
| EXPECT_EQ(0, feed.ControlOutputs(100).Length); | |||||
| EXPECT_EQ(0, neg.NumControlInputs); | |||||
| EXPECT_EQ(0, neg.ControlInputs(100).Length); | |||||
| EXPECT_EQ(0, neg.NumControlOutputs); | |||||
| EXPECT_EQ(0, neg.ControlOutputs(100).Length); | |||||
| // Import it again, with an input mapping, return outputs, and a return | |||||
| // operation, into the same graph. | |||||
| c_api.TF_DeleteImportGraphDefOptions(opts); | |||||
| opts = c_api.TF_NewImportGraphDefOptions(); | |||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); | |||||
| c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0)); | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); | |||||
| EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); | |||||
| EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | |||||
| var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); | |||||
| EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||||
| Operation scalar2 = c_api.TF_GraphOperationByName(graph, "imported2/scalar"); | |||||
| Operation feed2 = c_api.TF_GraphOperationByName(graph, "imported2/feed"); | |||||
| Operation neg2 = c_api.TF_GraphOperationByName(graph, "imported2/neg"); | |||||
| // Check input mapping | |||||
| neg_input = neg.Input(0); | |||||
| EXPECT_EQ(scalar, neg_input.oper); | |||||
| EXPECT_EQ(0, neg_input.index); | |||||
| // Check return outputs | |||||
| var return_outputs = graph.ReturnOutputs(results); | |||||
| ASSERT_EQ(2, return_outputs.Length); | |||||
| EXPECT_EQ(feed2, return_outputs[0].oper); | |||||
| EXPECT_EQ(0, return_outputs[0].index); | |||||
| EXPECT_EQ(scalar, return_outputs[1].oper); // remapped | |||||
| EXPECT_EQ(0, return_outputs[1].index); | |||||
| // Check return operation | |||||
| var num_return_opers = graph.ReturnOperations(results); | |||||
| ASSERT_EQ(1, num_return_opers); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -10,7 +10,7 @@ using Tensorflow; | |||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class TensorTest | |||||
| public class TensorTest : CApiTest | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Port from c_api_test.cc | /// Port from c_api_test.cc | ||||
| @@ -22,10 +22,10 @@ namespace TensorFlowNET.UnitTest | |||||
| ulong num_bytes = 6 * sizeof(float); | ulong num_bytes = 6 * sizeof(float); | ||||
| long[] dims = { 2, 3 }; | long[] dims = { 2, 3 }; | ||||
| Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); | Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); | ||||
| Assert.AreEqual(TF_DataType.TF_FLOAT, t.dtype); | |||||
| Assert.AreEqual(2, t.NDims); | |||||
| EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); | |||||
| EXPECT_EQ(2, t.NDims); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, t.shape)); | Assert.IsTrue(Enumerable.SequenceEqual(dims, t.shape)); | ||||
| Assert.AreEqual(num_bytes, t.bytesize); | |||||
| EXPECT_EQ(num_bytes, t.bytesize); | |||||
| t.Dispose(); | t.Dispose(); | ||||
| } | } | ||||
| @@ -41,11 +41,11 @@ namespace TensorFlowNET.UnitTest | |||||
| var tensor = new Tensor(nd); | var tensor = new Tensor(nd); | ||||
| var array = tensor.Data<float>(); | var array = tensor.Data<float>(); | ||||
| Assert.AreEqual(tensor.dtype, TF_DataType.TF_FLOAT); | |||||
| Assert.AreEqual(tensor.rank, nd.ndim); | |||||
| Assert.AreEqual(tensor.shape[0], nd.shape[0]); | |||||
| Assert.AreEqual(tensor.shape[1], nd.shape[1]); | |||||
| Assert.AreEqual(tensor.bytesize, (uint)nd.size * sizeof(float)); | |||||
| EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); | |||||
| EXPECT_EQ(tensor.rank, nd.ndim); | |||||
| EXPECT_EQ(tensor.shape[0], nd.shape[0]); | |||||
| EXPECT_EQ(tensor.shape[1], nd.shape[1]); | |||||
| EXPECT_EQ(tensor.bytesize, (uint)nd.size * sizeof(float)); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array)); | Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array)); | ||||
| } | } | ||||
| @@ -66,20 +66,20 @@ namespace TensorFlowNET.UnitTest | |||||
| int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | ||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
| Assert.AreEqual(-1, num_dims); | |||||
| EXPECT_EQ(-1, num_dims); | |||||
| // Set the shape to be unknown, expect no change. | // Set the shape to be unknown, expect no change. | ||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | ||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | ||||
| Assert.AreEqual(-1, num_dims); | |||||
| EXPECT_EQ(-1, num_dims); | |||||
| // Set the shape to be 2 x Unknown | // Set the shape to be 2 x Unknown | ||||
| long[] dims = { 2, -1 }; | long[] dims = { 2, -1 }; | ||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | ||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | ||||
| Assert.AreEqual(2, num_dims); | |||||
| EXPECT_EQ(2, num_dims); | |||||
| // Get the dimension vector appropriately. | // Get the dimension vector appropriately. | ||||
| var returned_dims = new long[dims.Length]; | var returned_dims = new long[dims.Length]; | ||||
| @@ -103,9 +103,9 @@ namespace TensorFlowNET.UnitTest | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | ||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
| Assert.AreEqual(2, num_dims); | |||||
| Assert.AreEqual(2, returned_dims[0]); | |||||
| Assert.AreEqual(3, returned_dims[1]); | |||||
| EXPECT_EQ(2, num_dims); | |||||
| EXPECT_EQ(2, returned_dims[0]); | |||||
| EXPECT_EQ(3, returned_dims[1]); | |||||
| // Try to set 'unknown' with same rank on the shape and see that | // Try to set 'unknown' with same rank on the shape and see that | ||||
| // it doesn't change. | // it doesn't change. | ||||
| @@ -115,9 +115,9 @@ namespace TensorFlowNET.UnitTest | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | ||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
| Assert.AreEqual(2, num_dims); | |||||
| Assert.AreEqual(2, returned_dims[0]); | |||||
| Assert.AreEqual(3, returned_dims[1]); | |||||
| EXPECT_EQ(2, num_dims); | |||||
| EXPECT_EQ(2, returned_dims[0]); | |||||
| EXPECT_EQ(3, returned_dims[1]); | |||||
| // Try to fetch a shape with the wrong num_dims | // Try to fetch a shape with the wrong num_dims | ||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); | c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); | ||||
| @@ -135,7 +135,7 @@ namespace TensorFlowNET.UnitTest | |||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | ||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
| Assert.AreEqual(0, num_dims); | |||||
| EXPECT_EQ(0, num_dims); | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s); | c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s); | ||||
| //Assert.IsTrue(s.Code == TF_Code.TF_OK); | //Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||