| @@ -20,6 +20,38 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_DeleteContextOptions(IntPtr options); | |||
| /// <summary> | |||
| /// Returns the length (number of tensors) of the input argument `input_name` | |||
| /// found in the provided `op`. | |||
| /// </summary> | |||
| /// <param name="op">TFE_Op*</param> | |||
| /// <param name="input_name">const char*</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TFE_OpGetInputLength(IntPtr op, string input_name, IntPtr status); | |||
| /// <summary> | |||
| /// Returns the length (number of tensors) of the output argument `output_name` | |||
| /// found in the provided `op`. | |||
| /// </summary> | |||
| /// <param name="op"></param> | |||
| /// <param name="input_name"></param> | |||
| /// <param name="status"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TFE_OpGetOutputLength(IntPtr op, string input_name, IntPtr status); | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="op">TFE_Op*</param> | |||
| /// <param name="inputs">TFE_TensorHandle**</param> | |||
| /// <param name="num_inputs">int</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, IntPtr status); | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| @@ -22,6 +22,9 @@ namespace TensorFlowNET.UnitTest | |||
| protected void EXPECT_NE(object expected, object actual, string msg = "") | |||
| => Assert.AreNotEqual(expected, actual, msg); | |||
| protected void CHECK_NE(object expected, object actual, string msg = "") | |||
| => Assert.AreNotEqual(expected, actual, msg); | |||
| protected void EXPECT_GE(int expected, int actual, string msg = "") | |||
| => Assert.IsTrue(expected >= actual, msg); | |||
| @@ -106,6 +109,15 @@ namespace TensorFlowNET.UnitTest | |||
| protected void TFE_DeleteContextOptions(IntPtr opts) | |||
| => c_api.TFE_DeleteContextOptions(opts); | |||
| protected int TFE_OpGetInputLength(IntPtr op, string input_name, IntPtr status) | |||
| => c_api.TFE_OpGetInputLength(op, input_name, status); | |||
| protected int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, IntPtr status) | |||
| => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); | |||
| protected int TFE_OpGetOutputLength(IntPtr op, string input_name, IntPtr status) | |||
| => c_api.TFE_OpGetOutputLength(op, input_name, status); | |||
| protected void TFE_DeleteTensorHandle(IntPtr h) | |||
| => c_api.TFE_DeleteTensorHandle(h); | |||
| @@ -0,0 +1,64 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using Tensorflow; | |||
| using Tensorflow.Eager; | |||
| using Buffer = System.Buffer; | |||
| namespace TensorFlowNET.UnitTest.Eager | |||
| { | |||
| public partial class CApiEagerTest | |||
| { | |||
| /// <summary> | |||
| /// TEST(CAPI, TestTFE_OpGetInputAndOutputLengths) | |||
| /// </summary> | |||
| [TestMethod] | |||
| public unsafe void OpGetInputAndOutputLengths() | |||
| { | |||
| var status = TF_NewStatus(); | |||
| var opts = TFE_NewContextOptions(); | |||
| var ctx = TFE_NewContext(opts, status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| TFE_DeleteContextOptions(opts); | |||
| var input1 = TestMatrixTensorHandle(); | |||
| var input2 = TestMatrixTensorHandle(); | |||
| var identityOp = TFE_NewOp(ctx, "IdentityN", status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| // Try to retrieve lengths before building the attributes (should fail) | |||
| EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status)); | |||
| CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); | |||
| CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| var inputs = new IntPtr[] { input1, input2 }; | |||
| TFE_OpAddInputList(identityOp, inputs, 2, status); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| // Try to retrieve lengths before executing the op (should work) | |||
| EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| var retvals = new IntPtr[2]; | |||
| int num_retvals = 2; | |||
| TFE_Execute(identityOp, retvals, ref num_retvals, status); | |||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| // Try to retrieve lengths after executing the op (should work) | |||
| EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||
| TF_DeleteStatus(status); | |||
| TFE_DeleteOp(identityOp); | |||
| TFE_DeleteTensorHandle(input1); | |||
| TFE_DeleteTensorHandle(input2); | |||
| TFE_DeleteTensorHandle(retvals[0]); | |||
| TFE_DeleteTensorHandle(retvals[1]); | |||
| TFE_DeleteContext(ctx); | |||
| } | |||
| } | |||
| } | |||