Browse Source

TestTFE_OpGetInputAndOutputLengths

tags/v0.20
Oceania2018 5 years ago
parent
commit
da41466dbe
3 changed files with 108 additions and 0 deletions
  1. +32
    -0
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  2. +12
    -0
      test/TensorFlowNET.UnitTest/CApiTest.cs
  3. +64
    -0
      test/TensorFlowNET.UnitTest/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs

+ 32
- 0
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -20,6 +20,38 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TFE_DeleteContextOptions(IntPtr options);

/// <summary>
/// Returns the length (number of tensors) of the input argument `input_name`
/// found in the provided `op`.
/// </summary>
/// <param name="op">TFE_Op*</param>
/// <param name="input_name">const char*</param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern int TFE_OpGetInputLength(IntPtr op, string input_name, IntPtr status);

/// <summary>
/// Returns the length (number of tensors) of the output argument `output_name`
/// found in the provided `op`.
/// </summary>
/// <param name="op"></param>
/// <param name="input_name"></param>
/// <param name="status"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TFE_OpGetOutputLength(IntPtr op, string input_name, IntPtr status);

/// <summary>
///
/// </summary>
/// <param name="op">TFE_Op*</param>
/// <param name="inputs">TFE_TensorHandle**</param>
/// <param name="num_inputs">int</param>
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, IntPtr status);

/// <summary>
///
/// </summary>


+ 12
- 0
test/TensorFlowNET.UnitTest/CApiTest.cs View File

@@ -22,6 +22,9 @@ namespace TensorFlowNET.UnitTest
protected void EXPECT_NE(object expected, object actual, string msg = "")
=> Assert.AreNotEqual(expected, actual, msg);

protected void CHECK_NE(object expected, object actual, string msg = "")
=> Assert.AreNotEqual(expected, actual, msg);

protected void EXPECT_GE(int expected, int actual, string msg = "")
=> Assert.IsTrue(expected >= actual, msg);

@@ -106,6 +109,15 @@ namespace TensorFlowNET.UnitTest
protected void TFE_DeleteContextOptions(IntPtr opts)
=> c_api.TFE_DeleteContextOptions(opts);

protected int TFE_OpGetInputLength(IntPtr op, string input_name, IntPtr status)
=> c_api.TFE_OpGetInputLength(op, input_name, status);

protected int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, IntPtr status)
=> c_api.TFE_OpAddInputList(op, inputs, num_inputs, status);

protected int TFE_OpGetOutputLength(IntPtr op, string input_name, IntPtr status)
=> c_api.TFE_OpGetOutputLength(op, input_name, status);

protected void TFE_DeleteTensorHandle(IntPtr h)
=> c_api.TFE_DeleteTensorHandle(h);



+ 64
- 0
test/TensorFlowNET.UnitTest/Eager/CApi.Eager.OpGetInputAndOutputLengths.cs View File

@@ -0,0 +1,64 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.Eager;
using Buffer = System.Buffer;

namespace TensorFlowNET.UnitTest.Eager
{
public partial class CApiEagerTest
{
/// <summary>
/// TEST(CAPI, TestTFE_OpGetInputAndOutputLengths)
/// </summary>
[TestMethod]
public unsafe void OpGetInputAndOutputLengths()
{
var status = TF_NewStatus();
var opts = TFE_NewContextOptions();
var ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
TFE_DeleteContextOptions(opts);

var input1 = TestMatrixTensorHandle();
var input2 = TestMatrixTensorHandle();
var identityOp = TFE_NewOp(ctx, "IdentityN", status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

// Try to retrieve lengths before building the attributes (should fail)
EXPECT_EQ(-1, TFE_OpGetInputLength(identityOp, "input", status));
CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status));
EXPECT_EQ(-1, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_NE(TF_OK, TF_GetCode(status), TF_Message(status));

var inputs = new IntPtr[] { input1, input2 };
TFE_OpAddInputList(identityOp, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

// Try to retrieve lengths before executing the op (should work)
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

var retvals = new IntPtr[2];
int num_retvals = 2;
TFE_Execute(identityOp, retvals, ref num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

// Try to retrieve lengths after executing the op (should work)
EXPECT_EQ(2, TFE_OpGetInputLength(identityOp, "input", status));
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
EXPECT_EQ(2, TFE_OpGetOutputLength(identityOp, "output", status));
CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

TF_DeleteStatus(status);
TFE_DeleteOp(identityOp);
TFE_DeleteTensorHandle(input1);
TFE_DeleteTensorHandle(input2);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(retvals[1]);
TFE_DeleteContext(ctx);
}
}
}

Loading…
Cancel
Save