diff --git a/src/TensorFlowNET.Core/APIs/c_api_lite.cs b/src/TensorFlowNET.Core/APIs/c_api_lite.cs
index 52373988..5a437d26 100644
--- a/src/TensorFlowNET.Core/APIs/c_api_lite.cs
+++ b/src/TensorFlowNET.Core/APIs/c_api_lite.cs
@@ -56,7 +56,7 @@ namespace Tensorflow
public static extern TfLiteTensor TfLiteInterpreterGetInputTensor(SafeTfLiteInterpreterHandle interpreter, int input_index);
[DllImport(TensorFlowLibName)]
- public static extern TF_DataType TfLiteTensorType(TfLiteTensor tensor);
+ public static extern TfLiteDataType TfLiteTensorType(TfLiteTensor tensor);
[DllImport(TensorFlowLibName)]
public static extern int TfLiteTensorNumDims(TfLiteTensor tensor);
diff --git a/src/TensorFlowNET.Core/Lite/TfLiteDataType.cs b/src/TensorFlowNET.Core/Lite/TfLiteDataType.cs
new file mode 100644
index 00000000..7b3aa102
--- /dev/null
+++ b/src/TensorFlowNET.Core/Lite/TfLiteDataType.cs
@@ -0,0 +1,27 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Lite
+{
+ public enum TfLiteDataType
+ {
+ kTfLiteNoType = 0,
+ kTfLiteFloat32 = 1,
+ kTfLiteInt32 = 2,
+ kTfLiteUInt8 = 3,
+ kTfLiteInt64 = 4,
+ kTfLiteString = 5,
+ kTfLiteBool = 6,
+ kTfLiteInt16 = 7,
+ kTfLiteComplex64 = 8,
+ kTfLiteInt8 = 9,
+ kTfLiteFloat16 = 10,
+ kTfLiteFloat64 = 11,
+ kTfLiteComplex128 = 12,
+ kTfLiteUInt64 = 13,
+ kTfLiteResource = 14,
+ kTfLiteVariant = 15,
+ kTfLiteUInt32 = 16,
+ }
+}
diff --git a/test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs b/test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs
index a1c95ec0..e1665557 100644
--- a/test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs
@@ -40,7 +40,7 @@ namespace Tensorflow.Native.UnitTest
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter));
var input_tensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0);
- Assert.AreEqual(TF_DataType.TF_FLOAT, c_api_lite.TfLiteTensorType(input_tensor));
+ Assert.AreEqual(TfLiteDataType.kTfLiteFloat32, c_api_lite.TfLiteTensorType(input_tensor));
Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(input_tensor));
Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(input_tensor, 0));
Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(input_tensor));
@@ -61,7 +61,7 @@ namespace Tensorflow.Native.UnitTest
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterInvoke(interpreter));
var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0);
- Assert.AreEqual(TF_DataType.TF_FLOAT, c_api_lite.TfLiteTensorType(output_tensor));
+ Assert.AreEqual(TfLiteDataType.kTfLiteFloat32, c_api_lite.TfLiteTensorType(output_tensor));
Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(output_tensor));
Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(output_tensor, 0));
Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(output_tensor));
@@ -83,5 +83,56 @@ namespace Tensorflow.Native.UnitTest
c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle());
}
+
+ [TestMethod]
+ public unsafe void QuantizationParamsTest()
+ {
+ var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add_quantized.bin");
+ var interpreter = c_api_lite.TfLiteInterpreterCreate(model, new SafeTfLiteInterpreterOptionsHandle(IntPtr.Zero));
+ c_api_lite.TfLiteModelDelete(model.DangerousGetHandle());
+ var input_dims = new[] { 2 };
+ Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterResizeInputTensor(interpreter, 0, input_dims, 1));
+ Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter));
+
+ var input_tensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0);
+ Assert.IsNotNull(input_tensor);
+
+ Assert.AreEqual(TfLiteDataType.kTfLiteUInt8, c_api_lite.TfLiteTensorType(input_tensor));
+ Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(input_tensor));
+ Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(input_tensor, 0));
+
+ var input_params = c_api_lite.TfLiteTensorQuantizationParams(input_tensor);
+ Assert.AreEqual((0.003922f, 0), (input_params.scale, input_params.zero_point));
+
+ var input = new byte[] { 1, 3 };
+ fixed (byte* addr = &input[0])
+ {
+ Assert.AreEqual(TfLiteStatus.kTfLiteOk,
+ c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, new IntPtr(addr), 2 * sizeof(byte)));
+ }
+ Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterInvoke(interpreter));
+
+ var output_tensor = c_api_lite.TfLiteInterpreterGetOutputTensor(interpreter, 0);
+ Assert.IsNotNull(output_tensor);
+
+ var output_params = c_api_lite.TfLiteTensorQuantizationParams(output_tensor);
+ Assert.AreEqual((0.003922f, 0), (output_params.scale, output_params.zero_point));
+
+ var output = new byte[2];
+ fixed (byte* addr = &output[0])
+ {
+ Assert.AreEqual(TfLiteStatus.kTfLiteOk,
+ c_api_lite.TfLiteTensorCopyToBuffer(output_tensor, new IntPtr(addr), 2 * sizeof(byte)));
+ }
+ Assert.AreEqual(3f, output[0]);
+ Assert.AreEqual(9f, output[1]);
+
+ var dequantizedOutput0 = output_params.scale * (output[0] - output_params.zero_point);
+ var dequantizedOutput1 = output_params.scale * (output[1] - output_params.zero_point);
+ Assert.AreEqual(0.011766f, dequantizedOutput0);
+ Assert.AreEqual(0.035298f, dequantizedOutput1);
+
+ c_api_lite.TfLiteInterpreterDelete(interpreter.DangerousGetHandle());
+ }
}
}
diff --git a/test/TensorFlowNET.Native.UnitTest/Lite/testdata/add_quantized.bin b/test/TensorFlowNET.Native.UnitTest/Lite/testdata/add_quantized.bin
new file mode 100644
index 00000000..07d48b93
Binary files /dev/null and b/test/TensorFlowNET.Native.UnitTest/Lite/testdata/add_quantized.bin differ
diff --git a/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj b/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj
index 7c0b4288..3d093e1a 100644
--- a/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj
+++ b/test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj
@@ -26,12 +26,16 @@
+
     <None Update="Lite\testdata\add.bin">
       <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
     </None>
+    <None Update="Lite\testdata\add_quantized.bin">
+      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
+    </None>