diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index c6dd3fd1..c3368120 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -21,6 +21,9 @@ using System.Collections.Generic; using System.Linq; using System.Numerics; using System.Text; +using Google.Protobuf; +using NumSharp.Backends; +using Tensorflow.Util; namespace Tensorflow { @@ -246,111 +249,167 @@ namespace Tensorflow return result; } - private unsafe NDArray fetchValue(IntPtr output) + private static unsafe NDArray fetchValue(IntPtr output) { - var tensor = new Tensor(output); - NDArray nd = null; - Type type = tensor.dtype.as_numpy_dtype(); - var ndims = tensor.shape; - var offset = (byte*) c_api.TF_TensorData(output); - - if(ndims.Length == 0) + NDArray ret; + using (var tensor = new Tensor(output)) { - switch (tensor.dtype) + var ndims = tensor.shape; + var srcAddress = c_api.TF_TensorData(output).ToInt64(); + + if (ndims.Length == 0) { - case TF_DataType.TF_BOOL: - nd = NDArray.Scalar(*(bool*)offset); - break; - case TF_DataType.TF_STRING: - var bytes = tensor.BufferToArray(); - // wired, don't know why we have to start from offset 9. - // length in the begin - var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); - nd = NDArray.FromString(str); - break; - case TF_DataType.TF_UINT8: - nd = NDArray.Scalar(*(byte*)offset); - break; - case TF_DataType.TF_INT16: - nd = NDArray.Scalar(*(short*)offset); - break; - case TF_DataType.TF_INT32: - nd = NDArray.Scalar(*(int*)offset); - break; - case TF_DataType.TF_INT64: - nd = NDArray.Scalar(*(long*)offset); - break; - case TF_DataType.TF_FLOAT: - nd = NDArray.Scalar(*(float*)offset); - break; - case TF_DataType.TF_DOUBLE: - nd = NDArray.Scalar(*(double*)offset); - break; - default: - throw new NotImplementedException("can't fetch output"); - } - } - else - { - switch (tensor.dtype) + switch (tensor.dtype) + { + case TF_DataType.TF_BOOL: + ret = NDArray.Scalar(*(bool*) srcAddress); + break; + case TF_DataType.TF_STRING: + using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) + ret = NDArray.FromString(reader.ReadString()); + break; + case TF_DataType.TF_UINT8: + ret = NDArray.Scalar(*(byte*) srcAddress); + break; + case TF_DataType.TF_INT16: + ret = NDArray.Scalar(*(short*) srcAddress); + break; + case TF_DataType.TF_INT32: + ret = NDArray.Scalar(*(int*) srcAddress); + break; + case TF_DataType.TF_INT64: + ret = NDArray.Scalar(*(long*) srcAddress); + break; + case TF_DataType.TF_UINT16: + ret = NDArray.Scalar(*(ushort*) srcAddress); + break; + case TF_DataType.TF_UINT32: + ret = NDArray.Scalar(*(uint*) srcAddress); + break; + case TF_DataType.TF_UINT64: + ret = NDArray.Scalar(*(ulong*) srcAddress); + break; + case TF_DataType.TF_FLOAT: + ret = NDArray.Scalar(*(float*) srcAddress); + break; + case TF_DataType.TF_DOUBLE: + ret = NDArray.Scalar(*(double*) srcAddress); + break; + default: + throw new NotImplementedException("can't fetch output"); + } + } else { - case TF_DataType.TF_BOOL: - var bools = new bool[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(bools).reshape(ndims); - break; - case TF_DataType.TF_STRING: - var bytes = tensor.BufferToArray(); - // wired, don't know why we have to start from offset 9. - // length in the begin - var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); - nd = np.array(str); - break; - case TF_DataType.TF_UINT8: - var _bytes = new byte[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(_bytes).reshape(ndims); - break; - case TF_DataType.TF_INT16: - var shorts = new short[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(shorts).reshape(ndims); - break; - case TF_DataType.TF_INT32: - var ints = new int[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(ints).reshape(ndims); - break; - case TF_DataType.TF_INT64: - var longs = new long[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(longs).reshape(ndims); - break; - case TF_DataType.TF_FLOAT: - var floats = new float[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(floats).reshape(ndims); - break; - case TF_DataType.TF_DOUBLE: - var doubles = new double[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(doubles).reshape(ndims); - break; - default: - throw new NotImplementedException("can't fetch output"); + //var size = (long) tensor.size; + //var itemsize = (long) tensor.itemsize; + var bytesize = (long) tensor.bytesize; + var src = (void*) srcAddress; + +#if _REGEN + #region Compute + switch (tensor.dtype) + { + %foreach except(supported_dtypes, "Char"),except(supported_dtypes_lowercase, "char"),except(supported_dtypes_TF_DataType,"TF_STRING")% + case TF_DataType.#3: + { + ret = new NDArray(NPTypeCode.#1, ndims, false); + System.Buffer.MemoryCopy(src, #(#3=="TF_STRING"|"(byte*)ret.Unsafe.Address + 8"|"ret.Unsafe.Address"), bytesize, bytesize); + break; + } + % + case TF_DataType.TF_STRING: + { + //TODO:! This is not the way to handle string[], it should be done with TF_DecodeString + using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) + ret = NDArray.FromString(reader.ReadString()); + break; + } + default: + throw new NotSupportedException(); + } + #endregion +#else + + #region Compute + switch (tensor.dtype) + { + case TF_DataType.TF_BOOL: + { + ret = new NDArray(NPTypeCode.Boolean, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + case TF_DataType.TF_UINT8: + { + ret = new NDArray(NPTypeCode.Byte, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + case TF_DataType.TF_INT16: + { + ret = new NDArray(NPTypeCode.Int16, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + case TF_DataType.TF_UINT16: + { + ret = new NDArray(NPTypeCode.UInt16, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + case TF_DataType.TF_INT32: + { + ret = new NDArray(NPTypeCode.Int32, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + case TF_DataType.TF_UINT32: + { + ret = new NDArray(NPTypeCode.UInt32, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + case TF_DataType.TF_INT64: + { + ret = new NDArray(NPTypeCode.Int64, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + case TF_DataType.TF_UINT64: + { + ret = new NDArray(NPTypeCode.UInt64, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + case TF_DataType.TF_DOUBLE: + { + ret = new NDArray(NPTypeCode.Double, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + case TF_DataType.TF_FLOAT: + { + ret = new NDArray(NPTypeCode.Single, ndims, false); + System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); + break; + } + case TF_DataType.TF_STRING: + { + throw new NotImplementedException(); + //TODO:! This is not the way to handle string[], it should be done with TF_DecodeString + using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) + ret = NDArray.FromString(reader.ReadString()); + break; + } + default: + throw new NotSupportedException(); + } + #endregion +#endif } } - - tensor.Dispose(); - return nd; + return ret; } /// diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs index 62d7c63d..45005a59 100644 --- a/test/TensorFlowNET.UnitTest/SessionTest.cs +++ b/test/TensorFlowNET.UnitTest/SessionTest.cs @@ -2,6 +2,11 @@ using NumSharp; using System; using System.Collections.Generic; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; +using FluentAssertions; +using Google.Protobuf; using Tensorflow; using static Tensorflow.Binding; @@ -17,74 +22,113 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void Session() { - var s = new Status(); - var graph = new Graph(); + lock (this) + { + var s = new Status(); + var graph = new Graph(); - // Make a placeholder operation. - var feed = c_test_util.Placeholder(graph, s); + // Make a placeholder operation. + var feed = c_test_util.Placeholder(graph, s); - // Make a constant operation with the scalar "2". - var two = c_test_util.ScalarConst(2, graph, s); + // Make a constant operation with the scalar "2". + var two = c_test_util.ScalarConst(2, graph, s); - // Add operation. - var add = c_test_util.Add(feed, two, graph, s); + // Add operation. + var add = c_test_util.Add(feed, two, graph, s); - var csession = new CSession(graph, s); - ASSERT_EQ(TF_Code.TF_OK, s.Code); + var csession = new CSession(graph, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); - // Run the graph. - var inputs = new Dictionary(); - inputs.Add(feed, new Tensor(3)); - csession.SetInputs(inputs); + // Run the graph. + var inputs = new Dictionary(); + inputs.Add(feed, new Tensor(3)); + csession.SetInputs(inputs); - var outputs = new TF_Output[] { new TF_Output(add, 0) }; - csession.SetOutputs(outputs); + var outputs = new TF_Output[] {new TF_Output(add, 0)}; + csession.SetOutputs(outputs); - csession.Run(s); - Tensor outTensor = csession.output_tensor(0); - EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); - EXPECT_EQ(0, outTensor.NDims); - ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); - var output_contents = outTensor.ToArray(); - EXPECT_EQ(3 + 2, output_contents[0]); + csession.Run(s); + Tensor outTensor = csession.output_tensor(0); + EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); + EXPECT_EQ(0, outTensor.NDims); + ASSERT_EQ((ulong) sizeof(uint), outTensor.bytesize); + var output_contents = outTensor.ToArray(); + EXPECT_EQ(3 + 2, output_contents[0]); - // Add another operation to the graph. - var neg = c_test_util.Neg(add, graph, s); - ASSERT_EQ(TF_Code.TF_OK, s.Code); + // Add another operation to the graph. + var neg = c_test_util.Neg(add, graph, s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); - // Run up to the new operation. - inputs = new Dictionary(); - inputs.Add(feed, new Tensor(7)); - csession.SetInputs(inputs); - outputs = new TF_Output[] { new TF_Output(neg, 0) }; - csession.SetOutputs(outputs); - csession.Run(s); - ASSERT_EQ(TF_Code.TF_OK, s.Code); + // Run up to the new operation. + inputs = new Dictionary(); + inputs.Add(feed, new Tensor(7)); + csession.SetInputs(inputs); + outputs = new TF_Output[] {new TF_Output(neg, 0)}; + csession.SetOutputs(outputs); + csession.Run(s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); - outTensor = csession.output_tensor(0); - ASSERT_TRUE(outTensor != IntPtr.Zero); - EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); - EXPECT_EQ(0, outTensor.NDims); // scalar - ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); - output_contents = outTensor.ToArray(); - EXPECT_EQ(-(7 + 2), output_contents[0]); + outTensor = csession.output_tensor(0); + ASSERT_TRUE(outTensor != IntPtr.Zero); + EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); + EXPECT_EQ(0, outTensor.NDims); // scalar + ASSERT_EQ((ulong) sizeof(uint), outTensor.bytesize); + output_contents = outTensor.ToArray(); + EXPECT_EQ(-(7 + 2), output_contents[0]); - // Clean up - csession.CloseAndDelete(s); - ASSERT_EQ(TF_Code.TF_OK, s.Code); + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_Code.TF_OK, s.Code); + } } [TestMethod] public void EvalTensor() { - var a = constant_op.constant(np.array(3.0).reshape(1, 1)); - var b = constant_op.constant(np.array(2.0).reshape(1, 1)); - var c = math_ops.matmul(a, b, name: "matmul"); - using (var sess = tf.Session()) + lock (this) + { + var a = constant_op.constant(np.array(3.0).reshape(1, 1)); + var b = constant_op.constant(np.array(2.0).reshape(1, 1)); + var c = math_ops.matmul(a, b, name: "matmul"); + using (var sess = tf.Session()) + { + var result = c.eval(sess); + Assert.AreEqual(6, result.Data()[0]); + } + } + } + + [TestMethod] + public void Eval_SmallString_Scalar() + { + lock (this) + { + var a = constant_op.constant("123 heythere 123 ", TF_DataType.TF_STRING); + var c = tf.strings.substr(a, 4, 8); + using (var sess = tf.Session()) + { + var result = (string) c.eval(sess); + Console.WriteLine(result); + result.Should().Be("heythere"); + } + } + } + + [TestMethod] + public void Eval_LargeString_Scalar() + { + lock (this) { - var result = c.eval(sess); - Assert.AreEqual(6, result.Data()[0]); + const int size = 30_000; + var a = constant_op.constant(new string('a', size), TF_DataType.TF_STRING); + var c = tf.strings.substr(a, 0, size - 5000); + using (var sess = tf.Session()) + { + var result = (string) c.eval(sess); + Console.WriteLine((string) result); + result.Should().HaveLength(size - 5000).And.ContainAll("a"); + } } } } -} +} \ No newline at end of file