Browse Source

tflite native api.

tags/TensorFlowOpLayer
Oceania2018 4 years ago
parent
commit
66bb81b121
10 changed files with 292 additions and 0 deletions
  1. +82
    -0
      src/TensorFlowNET.Core/APIs/c_api_lite.cs
  2. +26
    -0
      src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterHandle.cs
  3. +26
    -0
      src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterOptionsHandle.cs
  4. +26
    -0
      src/TensorFlowNET.Core/Lite/SafeTfLiteModelHandle.cs
  5. +12
    -0
      src/TensorFlowNET.Core/Lite/TfLiteQuantizationParams.cs
  6. +31
    -0
      src/TensorFlowNET.Core/Lite/TfLiteStatus.cs
  7. +21
    -0
      src/TensorFlowNET.Core/Lite/TfLiteTensor.cs
  8. +58
    -0
      test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs
  9. BIN
      test/TensorFlowNET.Native.UnitTest/Lite/testdata/add.bin
  10. +10
    -0
      test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj

+ 82
- 0
src/TensorFlowNET.Core/APIs/c_api_lite.cs View File

@@ -0,0 +1,82 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow.Lite;

namespace Tensorflow
{
public class c_api_lite
{
public const string TensorFlowLibName = "tensorflowlite_c";

public static string StringPiece(IntPtr handle)
{
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
}

[DllImport(TensorFlowLibName)]
public static extern IntPtr TfLiteVersion();

[DllImport(TensorFlowLibName)]
public static extern SafeTfLiteModelHandle TfLiteModelCreateFromFile(string model_path);

[DllImport(TensorFlowLibName)]
public static extern void TfLiteModelDelete(IntPtr model);

[DllImport(TensorFlowLibName)]
public static extern SafeTfLiteInterpreterOptionsHandle TfLiteInterpreterOptionsCreate();

[DllImport(TensorFlowLibName)]
public static extern void TfLiteInterpreterOptionsDelete(IntPtr options);

[DllImport(TensorFlowLibName)]
public static extern void TfLiteInterpreterOptionsSetNumThreads(SafeTfLiteInterpreterOptionsHandle options, int num_threads);

[DllImport(TensorFlowLibName)]
public static extern SafeTfLiteInterpreterHandle TfLiteInterpreterCreate(SafeTfLiteModelHandle model, SafeTfLiteInterpreterOptionsHandle optional_options);

[DllImport(TensorFlowLibName)]
public static extern void TfLiteInterpreterDelete(IntPtr interpreter);

[DllImport(TensorFlowLibName)]
public static extern TfLiteStatus TfLiteInterpreterAllocateTensors(SafeTfLiteInterpreterHandle interpreter);

[DllImport(TensorFlowLibName)]
public static extern int TfLiteInterpreterGetInputTensorCount(SafeTfLiteInterpreterHandle interpreter);

[DllImport(TensorFlowLibName)]
public static extern int TfLiteInterpreterGetOutputTensorCount(SafeTfLiteInterpreterHandle interpreter);

[DllImport(TensorFlowLibName)]
public static extern TfLiteStatus TfLiteInterpreterResizeInputTensor(SafeTfLiteInterpreterHandle interpreter,
int input_index, int[] input_dims, int input_dims_size);

[DllImport(TensorFlowLibName)]
public static extern TfLiteTensor TfLiteInterpreterGetInputTensor(SafeTfLiteInterpreterHandle interpreter, int input_index);

[DllImport(TensorFlowLibName)]
public static extern TF_DataType TfLiteTensorType(TfLiteTensor tensor);

[DllImport(TensorFlowLibName)]
public static extern int TfLiteTensorNumDims(TfLiteTensor tensor);

[DllImport(TensorFlowLibName)]
public static extern int TfLiteTensorDim(TfLiteTensor tensor, int dim_index);

[DllImport(TensorFlowLibName)]
public static extern int TfLiteTensorByteSize(TfLiteTensor tensor);

[DllImport(TensorFlowLibName)]
public static extern IntPtr TfLiteTensorData(TfLiteTensor tensor);

[DllImport(TensorFlowLibName)]
public static extern IntPtr TfLiteTensorName(TfLiteTensor tensor);

[DllImport(TensorFlowLibName)]
public static extern TfLiteQuantizationParams TfLiteTensorQuantizationParams(TfLiteTensor tensor);

[DllImport(TensorFlowLibName)]
public static extern TfLiteStatus TfLiteTensorCopyFromBuffer(TfLiteTensor tensor, IntPtr input_data, int input_data_size);
}
}

+ 26
- 0
src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterHandle.cs View File

@@ -0,0 +1,26 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Util;

namespace Tensorflow.Lite
{
public class SafeTfLiteInterpreterHandle : SafeTensorflowHandle
{
protected SafeTfLiteInterpreterHandle()
{
}

public SafeTfLiteInterpreterHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api_lite.TfLiteInterpreterDelete(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
}

+ 26
- 0
src/TensorFlowNET.Core/Lite/SafeTfLiteInterpreterOptionsHandle.cs View File

@@ -0,0 +1,26 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Util;

namespace Tensorflow.Lite
{
public class SafeTfLiteInterpreterOptionsHandle : SafeTensorflowHandle
{
protected SafeTfLiteInterpreterOptionsHandle()
{
}

public SafeTfLiteInterpreterOptionsHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api_lite.TfLiteInterpreterOptionsDelete(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
}

+ 26
- 0
src/TensorFlowNET.Core/Lite/SafeTfLiteModelHandle.cs View File

@@ -0,0 +1,26 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Util;

namespace Tensorflow.Lite
{
public class SafeTfLiteModelHandle : SafeTensorflowHandle
{
protected SafeTfLiteModelHandle()
{
}

public SafeTfLiteModelHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api_lite.TfLiteModelDelete(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
}

+ 12
- 0
src/TensorFlowNET.Core/Lite/TfLiteQuantizationParams.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Lite
{
public struct TfLiteQuantizationParams
{
public float scale;
public int zero_point;
}
}

+ 31
- 0
src/TensorFlowNET.Core/Lite/TfLiteStatus.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Lite
{
public enum TfLiteStatus
{
kTfLiteOk = 0,

// Generally referring to an error in the runtime (i.e. interpreter)
kTfLiteError = 1,

// Generally referring to an error from a TfLiteDelegate itself.
kTfLiteDelegateError = 2,

// Generally referring to an error in applying a delegate due to
// incompatibility between runtime and delegate, e.g., this error is returned
// when trying to apply a TfLite delegate onto a model graph that's already
// immutable.
kTfLiteApplicationError = 3,

// Generally referring to serialized delegate data not being found.
// See tflite::delegates::Serialization.
kTfLiteDelegateDataNotFound = 4,

// Generally referring to data-writing issues in delegate serialization.
// See tflite::delegates::Serialization.
kTfLiteDelegateDataWriteError = 5,
}
}

+ 21
- 0
src/TensorFlowNET.Core/Lite/TfLiteTensor.cs View File

@@ -0,0 +1,21 @@
using System;

namespace Tensorflow.Lite
{
public struct TfLiteTensor
{
IntPtr _handle;

public TfLiteTensor(IntPtr handle)
=> _handle = handle;

public static implicit operator TfLiteTensor(IntPtr handle)
=> new TfLiteTensor(handle);

public static implicit operator IntPtr(TfLiteTensor tensor)
=> tensor._handle;

public override string ToString()
=> $"TfLiteTensor 0x{_handle.ToString("x16")}";
}
}

+ 58
- 0
test/TensorFlowNET.Native.UnitTest/Lite/TfLiteTest.cs View File

@@ -0,0 +1,58 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using Tensorflow.Lite;

namespace Tensorflow.Native.UnitTest
{
[TestClass]
public class TfLiteTest
{
[TestMethod]
public void TfLiteVersion()
{
var ver = c_api_lite.StringPiece(c_api_lite.TfLiteVersion());
Assert.IsNotNull(ver);
}

[TestMethod]
public void SmokeTest()
{
var model = c_api_lite.TfLiteModelCreateFromFile("Lite/testdata/add.bin");
var options = c_api_lite.TfLiteInterpreterOptionsCreate();
c_api_lite.TfLiteInterpreterOptionsSetNumThreads(options, 2);

var interpreter = c_api_lite.TfLiteInterpreterCreate(model, options);

c_api_lite.TfLiteInterpreterOptionsDelete(options.DangerousGetHandle());
c_api_lite.TfLiteModelDelete(model.DangerousGetHandle());

Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter));
Assert.AreEqual(1, c_api_lite.TfLiteInterpreterGetInputTensorCount(interpreter));
Assert.AreEqual(1, c_api_lite.TfLiteInterpreterGetOutputTensorCount(interpreter));

var input_dims = new int[] { 2 };
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterResizeInputTensor(interpreter, 0, input_dims, input_dims.Length));
Assert.AreEqual(TfLiteStatus.kTfLiteOk, c_api_lite.TfLiteInterpreterAllocateTensors(interpreter));

var input_tensor = c_api_lite.TfLiteInterpreterGetInputTensor(interpreter, 0);
Assert.AreEqual(TF_DataType.TF_FLOAT, c_api_lite.TfLiteTensorType(input_tensor));
Assert.AreEqual(1, c_api_lite.TfLiteTensorNumDims(input_tensor));
Assert.AreEqual(2, c_api_lite.TfLiteTensorDim(input_tensor, 0));
Assert.AreEqual(sizeof(float) * 2, c_api_lite.TfLiteTensorByteSize(input_tensor));
Assert.IsNotNull(c_api_lite.TfLiteTensorData(input_tensor));
Assert.AreEqual("input", c_api_lite.StringPiece(c_api_lite.TfLiteTensorName(input_tensor)));

var input_params = c_api_lite.TfLiteTensorQuantizationParams(input_tensor);
Assert.AreEqual(0f, input_params.scale);
Assert.AreEqual(0, input_params.zero_point);

var input = new[] { 1f, 3f };
// c_api_lite.TfLiteTensorCopyFromBuffer(input_tensor, input, 2 * sizeof(float));
}
}
}

BIN
test/TensorFlowNET.Native.UnitTest/Lite/testdata/add.bin View File


+ 10
- 0
test/TensorFlowNET.Native.UnitTest/Tensorflow.Native.UnitTest.csproj View File

@@ -24,6 +24,16 @@
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<ItemGroup>
<None Remove="Lite\testdata\add.bin" />
</ItemGroup>

<ItemGroup>
<Content Include="Lite\testdata\add.bin">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
</ItemGroup>

<ItemGroup>
<PackageReference Include="FluentAssertions" Version="5.10.3" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.11.0-release-20210626-04" />


Loading…
Cancel
Save