diff --git a/src/TensorFlowNET.Core/APIs/c_api_lite.cs b/src/TensorFlowNET.Core/APIs/c_api_lite.cs index 45ead620..52373988 100644 --- a/src/TensorFlowNET.Core/APIs/c_api_lite.cs +++ b/src/TensorFlowNET.Core/APIs/c_api_lite.cs @@ -78,5 +78,14 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern TfLiteStatus TfLiteTensorCopyFromBuffer(TfLiteTensor tensor, IntPtr input_data, int input_data_size); + + [DllImport(TensorFlowLibName)] + public static extern TfLiteStatus TfLiteInterpreterInvoke(SafeTfLiteInterpreterHandle interpreter); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TfLiteInterpreterGetOutputTensor(SafeTfLiteInterpreterHandle interpreter, int output_index); + + [DllImport(TensorFlowLibName)] + public static extern TfLiteStatus TfLiteTensorCopyToBuffer(TfLiteTensor output_tensor, IntPtr output_data, int output_data_size); } } diff --git a/test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs b/test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs index abf41af5..a1c95ec0 100644 --- a/test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs +++ b/test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs @@ -20,7 +20,7 @@ namespace Tensorflow.Native.UnitTest } [TestMethod] - public void SmokeTest() + public unsafe void SmokeTest() { var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add.bin"); var options = c_api_lite.TfLiteInterpreterOptionsCreate(); @@ -52,7 +52,36 @@ namespace Tensorflow.Native.UnitTest Assert.AreEqual(0, input_params.zero_point); var input = new[] { 1f, 3f }; - // c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, input, 2 * sizeof(float)); + fixed (float* addr = &input[0]) + { + Assert.AreEqual(TfLiteStatus.kTfLiteOk, + c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, new IntPtr(addr), 2 * sizeof(float))); + } + + Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterInvoke(interpreter)); + + var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0); + Assert.AreEqual(TF_DataType.TF_FLOAT, c_api_lite.TfLiteTensorType(output_tensor)); + Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(output_tensor)); + Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(output_tensor, 0)); + Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(output_tensor)); + Assert.IsNotNull(c_api_lite.TfLiteTensorData(output_tensor)); + Assert.AreEqual("output", c_api_lite.StringPiece(c_api_lite.TfLiteTensorName(output_tensor))); + + var output_params = c_api_lite.TfLiteTensorQuantizationParams(output_tensor); + Assert.AreEqual(0f, output_params.scale); + Assert.AreEqual(0, output_params.zero_point); + + var output = new float[2]; + fixed (float* addr = &output[0]) + { + Assert.AreEqual(TfLiteStatus.kTfLiteOk, + c_api_lite.TfLiteTensorCopyToBuffer(output_tensor, new IntPtr(addr), 2 * sizeof(float))); + } + Assert.AreEqual(3f, output[0]); + Assert.AreEqual(9f, output[1]); + + c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle()); } } }