diff --git a/src/TensorFlowNET.Core/APIs/c_api_lite.cs b/src/TensorFlowNET.Core/APIs/c_api_lite.cs new file mode 100644 index 00000000..45ead620 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/c_api_lite.cs @@ -0,0 +1,82 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; +using Tensorflow.Lite; + +namespace Tensorflow +{ + public class c_api_lite + { + public const string TensorFlowLibName = "tensorflowlite_c"; + + public static string StringPiece(IntPtr handle) + { + return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); + } + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TfLiteVersion(); + + [DllImport(TensorFlowLibName)] + public static extern SafeTfLiteModelHandle TfLiteModelCreateFromFile(string model_path); + + [DllImport(TensorFlowLibName)] + public static extern void TfLiteModelDelete(IntPtr model); + + [DllImport(TensorFlowLibName)] + public static extern SafeTfLiteInterpreterOptionsHandle TfLiteInterpreterOptionsCreate(); + + [DllImport(TensorFlowLibName)] + public static extern void TfLiteInterpreterOptionsDelete(IntPtr options); + + [DllImport(TensorFlowLibName)] + public static extern void TfLiteInterpreterOptionsSetNumThreads(SafeTfLiteInterpreterOptionsHandle options, int num_threads); + + [DllImport(TensorFlowLibName)] + public static extern SafeTfLiteInterpreterHandle TfLiteInterpreterCreate(SafeTfLiteModelHandle model, SafeTfLiteInterpreterOptionsHandle optional_options); + + [DllImport(TensorFlowLibName)] + public static extern void TfLiteInterpreterDelete(IntPtr interpreter); + + [DllImport(TensorFlowLibName)] + public static extern TfLiteStatus TfLiteInterpreterAllocateTensors(SafeTfLiteInterpreterHandle interpreter); + + [DllImport(TensorFlowLibName)] + public static extern int TfLiteInterpreterGetInputTensorCount(SafeTfLiteInterpreterHandle interpreter); + + [DllImport(TensorFlowLibName)] + public static extern int TfLiteInterpreterGetOutputTensorCount(SafeTfLiteInterpreterHandle interpreter); + + [DllImport(TensorFlowLibName)] + public static extern TfLiteStatus TfLiteInterpreterResizeInputTensor(SafeTfLiteInterpreterHandle interpreter, + int input_index, int[] input_dims, int input_dims_size); + + [DllImport(TensorFlowLibName)] + public static extern TfLiteTensor TfLiteInterpreterGetInputTensor(SafeTfLiteInterpreterHandle interpreter, int input_index); + + [DllImport(TensorFlowLibName)] + public static extern TF_DataType TfLiteTensorType(TfLiteTensor tensor); + + [DllImport(TensorFlowLibName)] + public static extern int TfLiteTensorNumDims(TfLiteTensor tensor); + + [DllImport(TensorFlowLibName)] + public static extern int TfLiteTensorDim(TfLiteTensor tensor, int dim_index); + + [DllImport(TensorFlowLibName)] + public static extern int TfLiteTensorByteSize(TfLiteTensor tensor); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TfLiteTensorData(TfLiteTensor tensor); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TfLiteTensorName(TfLiteTensor tensor); + + [DllImport(TensorFlowLibName)] + public static extern TfLiteQuantizationParams TfLiteTensorQuantizationParams(TfLiteTensor tensor); + + [DllImport(TensorFlowLibName)] + public static extern TfLiteStatus TfLiteTensorCopyFromBuffer(TfLiteTensor tensor, IntPtr input_data, int input_data_size); + } +} diff --git a/src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterHandle.cs b/src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterHandle.cs new file mode 100644 index 00000000..feb65711 --- /dev/null +++ b/src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterHandle.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Util; + +namespace Tensorflow.Lite +{ + public class SafeTfLiteInterpreterHandle : SafeTensorflowHandle + { + protected SafeTfLiteInterpreterHandle() + { + } + + public SafeTfLiteInterpreterHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api_lite.TfLiteInterpreterDelete(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterOptionsHandle.cs b/src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterOptionsHandle.cs new file mode 100644 index 00000000..72893646 --- /dev/null +++ b/src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterOptionsHandle.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Util; + +namespace Tensorflow.Lite +{ + public class SafeTfLiteInterpreterOptionsHandle : SafeTensorflowHandle + { + protected SafeTfLiteInterpreterOptionsHandle() + { + } + + public SafeTfLiteInterpreterOptionsHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api_lite.TfLiteInterpreterOptionsDelete(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Lite/SafeTfLiteModelHandle.cs b/src/TensorFlowNET.Core/Lite/SafeTfLiteModelHandle.cs new file mode 100644 index 00000000..bdae1543 --- /dev/null +++ b/src/TensorFlowNET.Core/Lite/SafeTfLiteModelHandle.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Util; + +namespace Tensorflow.Lite +{ + public class SafeTfLiteModelHandle : SafeTensorflowHandle + { + protected SafeTfLiteModelHandle() + { + } + + public SafeTfLiteModelHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api_lite.TfLiteModelDelete(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Lite/TfLiteQuantizationParams.cs b/src/TensorFlowNET.Core/Lite/TfLiteQuantizationParams.cs new file mode 100644 index 00000000..e564392c --- /dev/null +++ b/src/TensorFlowNET.Core/Lite/TfLiteQuantizationParams.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Lite +{ + public struct TfLiteQuantizationParams + { + public float scale; + public int zero_point; + } +} diff --git a/src/TensorFlowNET.Core/Lite/TfLiteStatus.cs b/src/TensorFlowNET.Core/Lite/TfLiteStatus.cs new file mode 100644 index 00000000..06612125 --- /dev/null +++ b/src/TensorFlowNET.Core/Lite/TfLiteStatus.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Lite +{ + public enum TfLiteStatus + { + kTfLiteOk = 0, + + // Generally referring to an error in the runtime (i.e. interpreter) + kTfLiteError = 1, + + // Generally referring to an error from a TfLiteDelegate itself. + kTfLiteDelegateError = 2, + + // Generally referring to an error in applying a delegate due to + // incompatibility between runtime and delegate, e.g., this error is returned + // when trying to apply a TfLite delegate onto a model graph that's already + // immutable. + kTfLiteApplicationError = 3, + + // Generally referring to serialized delegate data not being found. + // See tflite::delegates::Serialization. + kTfLiteDelegateDataNotFound = 4, + + // Generally referring to data-writing issues in delegate serialization. + // See tflite::delegates::Serialization. + kTfLiteDelegateDataWriteError = 5, + } +} diff --git a/src/TensorFlowNET.Core/Lite/TfLiteTensor.cs b/src/TensorFlowNET.Core/Lite/TfLiteTensor.cs new file mode 100644 index 00000000..5a43f58f --- /dev/null +++ b/src/TensorFlowNET.Core/Lite/TfLiteTensor.cs @@ -0,0 +1,21 @@ +using System; + +namespace Tensorflow.Lite +{ + public struct TfLiteTensor + { + IntPtr _handle; + + public TfLiteTensor(IntPtr handle) + => _handle = handle; + + public static implicit operator TfLiteTensor(IntPtr handle) + => new TfLiteTensor(handle); + + public static implicit operator IntPtr(TfLiteTensor tensor) + => tensor._handle; + + public override string ToString() + => $"TfLiteTensor 0x{_handle.ToString("x16")}"; + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs b/test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs new file mode 100644 index 00000000..abf41af5 --- /dev/null +++ b/test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs @@ -0,0 +1,58 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Tasks; +using Tensorflow.Lite; + +namespace Tensorflow.Native.UnitTest +{ + [TestClass] + public class TfLiteTest + { + [TestMethod] + public void TfLiteVersion() + { + var ver = c_api_lite.StringPiece(c_api_lite.TfLiteVersion()); + Assert.IsNotNull(ver); + } + + [TestMethod] + public void SmokeTest() + { + var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add.bin"); + var options = c_api_lite.TfLiteInterpreterOptionsCreate(); + c_api_lite.TfLiteInterpreterOptionsSetNumThreads(options, 2); + + var interpreter = c_api_lite.TfLiteInterpreterCreate(model, options); + + c_api_lite.TfLiteInterpreterOptionsDelete(options.DangerousGetHandle()); + c_api_lite.TfLiteModelDelete(model.DangerousGetHandle()); + + Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter)); + Assert.AreEqual(1, c_api_lite.TfLiteInterpreterGetInputTensorCount(interpreter)); + Assert.AreEqual(1, c_api_lite.TfLiteInterpreterGetOutputTensorCount(interpreter)); + + var input_dims = new int[] { 2 }; + Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterResizeInputTensor(interpreter, 0, input_dims, input_dims.Length)); + Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter)); + + var input_tensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0); + Assert.AreEqual(TF_DataType.TF_FLOAT, c_api_lite.TfLiteTensorType(input_tensor)); + Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(input_tensor)); + Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(input_tensor, 0)); + Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(input_tensor)); + Assert.IsNotNull(c_api_lite.TfLiteTensorData(input_tensor)); + Assert.AreEqual("input", c_api_lite.StringPiece(c_api_lite.TfLiteTensorName(input_tensor))); + + var input_params = c_api_lite.TfLiteTensorQuantizationParams(input_tensor); + Assert.AreEqual(0f, input_params.scale); + Assert.AreEqual(0, input_params.zero_point); + + var input = new[] { 1f, 3f }; + // c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, input, 2 * sizeof(float)); + } + } +} diff --git a/test/TensorFlowNET.Native.UnitTest/Lite/testdata/add.bin b/test/TensorFlowNET.Native.UnitTest/Lite/testdata/add.bin new file mode 100644 index 00000000..b4c02350 Binary files /dev/null and b/test/TensorFlowNET.Native.UnitTest/Lite/testdata/add.bin differ diff --git a/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj b/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj index 39678f9f..7c0b4288 100644 --- a/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj +++ b/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj @@ -24,6 +24,16 @@ true + + + + + + + PreserveNewest + + +