diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index 16456ecd..83610c94 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -20,6 +20,38 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TFE_DeleteContextOptions(IntPtr options); + /// + /// Returns the length (number of tensors) of the input argument `input_name` + /// found in the provided `op`. + /// + /// TFE_Op* + /// const char* + /// TF_Status* + [DllImport(TensorFlowLibName)] + public static extern int TFE_OpGetInputLength(IntPtr op, string input_name, IntPtr status); + + /// + /// Returns the length (number of tensors) of the output argument `output_name` + /// found in the provided `op`. + /// + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TFE_OpGetOutputLength(IntPtr op, string input_name, IntPtr status); + + /// + /// + /// + /// TFE_Op* + /// TFE_TensorHandle** + /// int + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, IntPtr status); + /// /// /// diff --git a/test/TensorFlowNET.UnitTest/CApiTest.cs b/test/TensorFlowNET.UnitTest/CApiTest.cs index 2ae410da..07e19109 100644 --- a/test/TensorFlowNET.UnitTest/CApiTest.cs +++ b/test/TensorFlowNET.UnitTest/CApiTest.cs @@ -22,6 +22,9 @@ namespace TensorFlowNET.UnitTest protected void EXPECT_NE(object expected, object actual, string msg = "") => Assert.AreNotEqual(expected, actual, msg); + protected void CHECK_NE(object expected, object actual, string msg = "") + => Assert.AreNotEqual(expected, actual, msg); + protected void EXPECT_GE(int expected, int actual, string msg = "") => Assert.IsTrue(expected >= actual, msg); @@ -106,6 +109,15 @@ namespace TensorFlowNET.UnitTest protected void TFE_DeleteContextOptions(IntPtr opts) => c_api.TFE_DeleteContextOptions(opts); + protected int TFE_OpGetInputLength(IntPtr op, string input_name, IntPtr status) + => c_api.TFE_OpGetInputLength(op, input_name, status); + + protected int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, IntPtr status) + => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); + + protected int TFE_OpGetOutputLength(IntPtr op, string input_name, IntPtr status) + => c_api.TFE_OpGetOutputLength(op, input_name, status); + protected void TFE_DeleteTensorHandle(IntPtr h) => c_api.TFE_DeleteTensorHandle(h); diff --git a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs new file mode 100644 index 00000000..789b4135 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs @@ -0,0 +1,64 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Tensorflow; +using Tensorflow.Eager; +using Buffer = System.Buffer; + +namespace TensorFlowNET.UnitTest.Eager +{ + public partial class CApiEagerTest + { + /// + /// TEST(CAPI, TestTFE_OpGetInputAndOutputLengths) + /// + [TestMethod] + public unsafe void OpGetInputAndOutputLengths() + { + var status = TF_NewStatus(); + var opts = TFE_NewContextOptions(); + var ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + TFE_DeleteContextOptions(opts); + + var input1 = TestMatrixTensorHandle(); + var input2 = TestMatrixTensorHandle(); + var identityOp = TFE_NewOp(ctx, "IdentityN", status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + // Try to retrieve lengths before building the attributes (should fail) + EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status)); + CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); + EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); + CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); + + var inputs = new IntPtr[] { input1, input2 }; + TFE_OpAddInputList(identityOp, inputs, 2, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + // Try to retrieve lengths before executing the op (should work) + EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + var retvals = new IntPtr[2]; + int num_retvals = 2; + TFE_Execute(identityOp, retvals, ref num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + // Try to retrieve lengths after executing the op (should work) + EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + + TF_DeleteStatus(status); + TFE_DeleteOp(identityOp); + TFE_DeleteTensorHandle(input1); + TFE_DeleteTensorHandle(input2); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(retvals[1]); + TFE_DeleteContext(ctx); + } + } +}