diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index e3d7f6ae..6017d510 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -25,9 +25,11 @@ https://tensorflownet.readthedocs.io * Eager Mode is added finally. * tf.keras is partially working. * tf.data is added. -* autograph works partially. +* Autograph works partially. +* Improve memory usage. -TensorFlow .NET v0.3x is focused on making more Keras API works +TensorFlow .NET v0.3x is focused on making more Keras API works. +Keras API is a separate package released as TensorFlow.Keras. 0.33.0.0 LICENSE true @@ -83,7 +85,7 @@ TensorFlow .NET v0.3x is focused on making more Keras API works - + diff --git a/src/TensorFlowNET.Core/Tensors/TF_TString_Type.cs b/src/TensorFlowNET.Core/Tensors/TF_TString_Type.cs new file mode 100644 index 00000000..233b16e5 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/TF_TString_Type.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public enum TF_TString_Type + { + TF_TSTR_SMALL = 0, + TF_TSTR_LARGE = 1, + TF_TSTR_OFFSET = 2, + TF_TSTR_VIEW = 3 + } +} diff --git a/src/TensorFlowNET.Core/Tensors/TStringHandle.cs b/src/TensorFlowNET.Core/Tensors/TStringHandle.cs new file mode 100644 index 00000000..13077ec3 --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/TStringHandle.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Util; + +namespace Tensorflow +{ + public class TStringHandle : SafeTensorflowHandle + { + protected override bool ReleaseHandle() + { + c_api.TF_StringDealloc(handle); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs index e331dc1a..abe07c75 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs @@ -8,7 +8,29 @@ namespace Tensorflow { public partial class Tensor { - public unsafe IntPtr StringTensor(string[] strings, TensorShape shape) + const ulong TF_TSRING_SIZE = 24; + + public IntPtr StringTensor25(string[] strings, TensorShape shape) + { + var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, + shape.dims.Select(x => (long)x).ToArray(), + shape.ndim, + (ulong)shape.size * TF_TSRING_SIZE); + + var data = c_api.TF_TensorData(handle); + var tstr = c_api.TF_StringInit(handle); + // AllocationHandle = tstr; + // AllocationType = AllocationType.Tensorflow; + for (int i = 0; i< strings.Length; i++) + { + c_api.TF_StringCopy(tstr, strings[i], strings[i].Length); + tstr += (int)TF_TSRING_SIZE; + } + // c_api.TF_StringDealloc(tstr); + return handle; + } + + public IntPtr StringTensor(string[] strings, TensorShape shape) { // convert string array to byte[][] var buffer = new byte[strings.Length][]; @@ -61,11 +83,27 @@ namespace Tensorflow return handle; } + public string[] StringData25() + { + string[] strings = new string[c_api.TF_Dim(_handle, 0)]; + var tstrings = TensorDataPointer; + for (int i = 0; i< strings.Length; i++) + { + var tstringData = c_api.TF_StringGetDataPointer(tstrings); + /*var size = c_api.TF_StringGetSize(tstrings); + var capacity = c_api.TF_StringGetCapacity(tstrings); + var type = c_api.TF_StringGetType(tstrings);*/ + strings[i] = c_api.StringPiece(tstringData); + tstrings += (int)TF_TSRING_SIZE; + } + return strings; + } + /// /// Extracts string array from current Tensor. /// /// When != TF_DataType.TF_STRING - public unsafe string[] StringData() + public string[] StringData() { var buffer = StringBytes(); diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index d5efb75d..0fd2527e 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -181,6 +181,30 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern unsafe ulong TF_StringEncode(byte* src, ulong src_len, byte* dst, ulong dst_len, SafeStatusHandle status); + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_StringInit(IntPtr t); + + [DllImport(TensorFlowLibName)] + public static extern void TF_StringCopy(IntPtr dst, string text, long size); + + [DllImport(TensorFlowLibName)] + public static extern void TF_StringAssignView(IntPtr dst, IntPtr text, long size); + + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_StringGetDataPointer(IntPtr tst); + + [DllImport(TensorFlowLibName)] + public static extern TF_TString_Type TF_StringGetType(IntPtr tst); + + [DllImport(TensorFlowLibName)] + public static extern ulong TF_StringGetSize(IntPtr tst); + + [DllImport(TensorFlowLibName)] + public static extern ulong TF_StringGetCapacity(IntPtr tst); + + [DllImport(TensorFlowLibName)] + public static extern void TF_StringDealloc(IntPtr tst); + /// /// Decode a string encoded using TF_StringEncode. /// diff --git a/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs index b7a208e4..7f1591e9 100644 --- a/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs +++ b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs @@ -107,6 +107,32 @@ namespace Tensorflow.Native.UnitTest.Tensors Assert.IsTrue(Enumerable.SequenceEqual(nd.Data(), new float[] { 1, 2, 3, 4, 5, 6 })); } + /// + /// Port from c_api_test.cc + /// `TEST_F(CApiAttributesTest, StringTensor)` + /// + [TestMethod, Ignore("Waiting for PR https://github.com/tensorflow/tensorflow/pull/46804")] + public void StringTensor() + { + string text = "Hello world!."; + + var tensor = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, + null, + 0, + 1 * 24); + var tstr = c_api.TF_StringInit(tensor); + var data = c_api.TF_StringGetDataPointer(tstr); + c_api.TF_StringCopy(tstr, text, text.Length); + + Assert.AreEqual((ulong)text.Length, c_api.TF_StringGetSize(tstr)); + Assert.AreEqual(text, c_api.StringPiece(data)); + Assert.AreEqual((ulong)text.Length, c_api.TF_TensorByteSize(tensor)); + Assert.AreEqual(0, c_api.TF_NumDims(tensor)); + + TF_DeleteTensor(tensor); + c_api.TF_StringDealloc(tstr); + } + /// /// Port from tensorflow\c\c_api_test.cc /// `TEST(CAPI, SetShape)`