Browse Source

TfLiteInterpreterInvoke

tags/TensorFlowOpLayer
Oceania2018 4 years ago
parent
commit
f19808edfc
2 changed files with 40 additions and 2 deletions
  1. +9
    -0
      src/TensorFlowNET.Core/APIs/c_api_lite.cs
  2. +31
    -2
      test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs

+ 9
- 0
src/TensorFlowNET.Core/APIs/c_api_lite.cs View File

@@ -78,5 +78,14 @@ namespace Tensorflow

[DllImport(TensorFlowLibName)]
public static extern TfLiteStatus TfLiteTensorCopyFromBuffer(TfLiteTensor tensor, IntPtr input_data, int input_data_size);

[DllImport(TensorFlowLibName)]
public static extern TfLiteStatus TfLiteInterpreterInvoke(SafeTfLiteInterpreterHandle interpreter);

[DllImport(TensorFlowLibName)]
public static extern IntPtr TfLiteInterpreterGetOutputTensor(SafeTfLiteInterpreterHandle interpreter, int output_index);

[DllImport(TensorFlowLibName)]
public static extern TfLiteStatus TfLiteTensorCopyToBuffer(TfLiteTensor output_tensor, IntPtr output_data, int output_data_size);
}
}

+ 31
- 2
test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs View File

@@ -20,7 +20,7 @@ namespace Tensorflow.Native.UnitTest
}

[TestMethod]
public void SmokeTest()
public unsafe void SmokeTest()
{
var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add.bin");
var options = c_api_lite.TfLiteInterpreterOptionsCreate();
@@ -52,7 +52,36 @@ namespace Tensorflow.Native.UnitTest
Assert.AreEqual(0, input_params.zero_point);

var input = new[] { 1f, 3f };
// c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, input, 2 * sizeof(float));
fixed (float* addr = &input[0])
{
Assert.AreEqual(TfLiteStatus.kTfLiteOk,
c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, new IntPtr(addr), 2 * sizeof(float)));
}

Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterInvoke(interpreter));

var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0);
Assert.AreEqual(TF_DataType.TF_FLOAT, c_api_lite.TfLiteTensorType(output_tensor));
Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(output_tensor));
Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(output_tensor, 0));
Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(output_tensor));
Assert.IsNotNull(c_api_lite.TfLiteTensorData(output_tensor));
Assert.AreEqual("output", c_api_lite.StringPiece(c_api_lite.TfLiteTensorName(output_tensor)));

var output_params = c_api_lite.TfLiteTensorQuantizationParams(output_tensor);
Assert.AreEqual(0f, output_params.scale);
Assert.AreEqual(0, output_params.zero_point);

var output = new float[2];
fixed (float* addr = &output[0])
{
Assert.AreEqual(TfLiteStatus.kTfLiteOk,
c_api_lite.TfLiteTensorCopyToBuffer(output_tensor, new IntPtr(addr), 2 * sizeof(float)));
}
Assert.AreEqual(3f, output[0]);
Assert.AreEqual(9f, output[1]);

c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle());
}
}
}

Loading…
Cancel
Save