| @@ -20,6 +20,38 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_DeleteContextOptions(IntPtr options); | public static extern void TFE_DeleteContextOptions(IntPtr options); | ||||
| /// <summary> | |||||
| /// Returns the length (number of tensors) of the input argument `input_name` | |||||
| /// found in the provided `op`. | |||||
| /// </summary> | |||||
| /// <param name="op">TFE_Op*</param> | |||||
| /// <param name="input_name">const char*</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern int TFE_OpGetInputLength(IntPtr op, string input_name, IntPtr status); | |||||
| /// <summary> | |||||
| /// Returns the length (number of tensors) of the output argument `output_name` | |||||
| /// found in the provided `op`. | |||||
| /// </summary> | |||||
| /// <param name="op"></param> | |||||
| /// <param name="input_name"></param> | |||||
| /// <param name="status"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern int TFE_OpGetOutputLength(IntPtr op, string input_name, IntPtr status); | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="op">TFE_Op*</param> | |||||
| /// <param name="inputs">TFE_TensorHandle**</param> | |||||
| /// <param name="num_inputs">int</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, IntPtr status); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| @@ -22,6 +22,9 @@ namespace TensorFlowNET.UnitTest | |||||
| protected void EXPECT_NE(object expected, object actual, string msg = "") | protected void EXPECT_NE(object expected, object actual, string msg = "") | ||||
| => Assert.AreNotEqual(expected, actual, msg); | => Assert.AreNotEqual(expected, actual, msg); | ||||
| protected void CHECK_NE(object expected, object actual, string msg = "") | |||||
| => Assert.AreNotEqual(expected, actual, msg); | |||||
| protected void EXPECT_GE(int expected, int actual, string msg = "") | protected void EXPECT_GE(int expected, int actual, string msg = "") | ||||
| => Assert.IsTrue(expected >= actual, msg); | => Assert.IsTrue(expected >= actual, msg); | ||||
| @@ -106,6 +109,15 @@ namespace TensorFlowNET.UnitTest | |||||
| protected void TFE_DeleteContextOptions(IntPtr opts) | protected void TFE_DeleteContextOptions(IntPtr opts) | ||||
| => c_api.TFE_DeleteContextOptions(opts); | => c_api.TFE_DeleteContextOptions(opts); | ||||
| protected int TFE_OpGetInputLength(IntPtr op, string input_name, IntPtr status) | |||||
| => c_api.TFE_OpGetInputLength(op, input_name, status); | |||||
| protected int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, IntPtr status) | |||||
| => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); | |||||
| protected int TFE_OpGetOutputLength(IntPtr op, string input_name, IntPtr status) | |||||
| => c_api.TFE_OpGetOutputLength(op, input_name, status); | |||||
| protected void TFE_DeleteTensorHandle(IntPtr h) | protected void TFE_DeleteTensorHandle(IntPtr h) | ||||
| => c_api.TFE_DeleteTensorHandle(h); | => c_api.TFE_DeleteTensorHandle(h); | ||||
| @@ -0,0 +1,64 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using Tensorflow; | |||||
| using Tensorflow.Eager; | |||||
| using Buffer = System.Buffer; | |||||
| namespace TensorFlowNET.UnitTest.Eager | |||||
| { | |||||
| public partial class CApiEagerTest | |||||
| { | |||||
| /// <summary> | |||||
| /// TEST(CAPI, TestTFE_OpGetInputAndOutputLengths) | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public unsafe void OpGetInputAndOutputLengths() | |||||
| { | |||||
| var status = TF_NewStatus(); | |||||
| var opts = TFE_NewContextOptions(); | |||||
| var ctx = TFE_NewContext(opts, status); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| TFE_DeleteContextOptions(opts); | |||||
| var input1 = TestMatrixTensorHandle(); | |||||
| var input2 = TestMatrixTensorHandle(); | |||||
| var identityOp = TFE_NewOp(ctx, "IdentityN", status); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| // Try to retrieve lengths before building the attributes (should fail) | |||||
| EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status)); | |||||
| CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status)); | |||||
| CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| var inputs = new IntPtr[] { input1, input2 }; | |||||
| TFE_OpAddInputList(identityOp, inputs, 2, status); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| // Try to retrieve lengths before executing the op (should work) | |||||
| EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| var retvals = new IntPtr[2]; | |||||
| int num_retvals = 2; | |||||
| TFE_Execute(identityOp, retvals, ref num_retvals, status); | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| // Try to retrieve lengths after executing the op (should work) | |||||
| EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status)); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status)); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| TF_DeleteStatus(status); | |||||
| TFE_DeleteOp(identityOp); | |||||
| TFE_DeleteTensorHandle(input1); | |||||
| TFE_DeleteTensorHandle(input2); | |||||
| TFE_DeleteTensorHandle(retvals[0]); | |||||
| TFE_DeleteTensorHandle(retvals[1]); | |||||
| TFE_DeleteContext(ctx); | |||||
| } | |||||
| } | |||||
| } | |||||