From 8ae2feb5fbca339407dfacb6038e6402ff5fe07f Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Tue, 12 Feb 2019 07:10:14 -0600 Subject: [PATCH] Malformed TF_STRING tensor; element 0 out of range --- src/TensorFlowNET.Core/Graphs/Graph.cs | 8 +- .../Sessions/BaseSession.cs | 5 +- .../Tensors/Tensor.Creation.cs | 111 ++++++++++++++++++ src/TensorFlowNET.Core/Tensors/Tensor.cs | 80 ------------- .../Tensors/c_api.tensor.cs | 4 +- src/TensorFlowNET.Core/Train/Saving/Saver.cs | 4 +- test/TensorFlowNET.UnitTest/ConstantTest.cs | 12 +- test/TensorFlowNET.UnitTest/TrainSaverTest.cs | 1 - 8 files changed, 131 insertions(+), 94 deletions(-) create mode 100644 src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 79e9fbe5..8fba13d5 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -76,12 +76,12 @@ namespace Tensorflow obj = temp_obj; // If obj appears to be a name... - if (obj is String str) + if (obj is string name) { - if(str.Contains(":") && allow_tensor) + if(name.Contains(":") && allow_tensor) { - string op_name = str.Split(':')[0]; - int out_n = int.Parse(str.Split(':')[1]); + string op_name = name.Split(':')[0]; + int out_n = int.Parse(name.Split(':')[1]); if (_nodes_by_name.ContainsKey(op_name)) return _nodes_by_name[op_name].outputs[out_n]; diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 8044e904..656ab344 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -67,7 +67,7 @@ namespace Tensorflow default: throw new NotImplementedException("_run subfeed"); } - feed_map[subfeed_t.name] = new Tuple(subfeed_t, subfeed.Value); + feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); } } @@ -178,7 +178,8 @@ namespace Tensorflow case TF_DataType.TF_STRING: var bytes = tensor.Data(); // wired, don't know why we have to start from offset 9. - var str = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9); + // length in the begin + var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); nd = np.array(str).reshape(); break; case TF_DataType.TF_INT16: diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs new file mode 100644 index 00000000..4755454e --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -0,0 +1,111 @@ +using NumSharp.Core; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using static Tensorflow.c_api; + +namespace Tensorflow +{ + public partial class Tensor + { + /// + /// if original buffer is free. + /// + private bool deallocator_called; + + public Tensor(IntPtr handle) + { + _handle = handle; + } + + public Tensor(NDArray nd) + { + _handle = Allocate(nd); + } + + private IntPtr Allocate(NDArray nd) + { + IntPtr dotHandle = IntPtr.Zero; + ulong size = 0; + + if (nd.dtype.Name != "String") + { + dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); + size = (ulong)(nd.size * nd.dtypesize); + } + + switch (nd.dtype.Name) + { + case "Int16": + Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); + break; + case "Int32": + Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); + break; + case "Single": + Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); + break; + case "Double": + Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); + break; + case "String": + /*var value = nd.Data()[0]; + var bytes = Encoding.UTF8.GetBytes(value); + dotHandle = Marshal.AllocHGlobal(bytes.Length + 1); + Marshal.Copy(bytes, 0, dotHandle, bytes.Length); + size = (ulong)bytes.Length;*/ + + var str = nd.Data()[0]; + ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length); + //dotHandle = Marshal.AllocHGlobal((int)dst_len); + //size = c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status); + + var dataType1 = ToTFDataType(nd.dtype); + // shape + var dims1 = nd.shape.Select(x => (long)x).ToArray(); + + var tfHandle1 = c_api.TF_AllocateTensor(dataType1, + dims1, + nd.ndim, + dst_len); + + dotHandle = c_api.TF_TensorData(tfHandle1); + c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status); + return tfHandle1; + break; + default: + throw new NotImplementedException("Marshal.Copy failed."); + } + + var dataType = ToTFDataType(nd.dtype); + // shape + var dims = nd.shape.Select(x => (long)x).ToArray(); + // Free the original buffer and set flag + Deallocator deallocator = (IntPtr values, IntPtr len, ref bool closure) => + { + Marshal.FreeHGlobal(dotHandle); + closure = true; + }; + + var tfHandle = c_api.TF_NewTensor(dataType, + dims, + nd.ndim, + dotHandle, + size, + deallocator, + ref deallocator_called); + + return tfHandle; + } + + public Tensor(Operation op, int value_index, TF_DataType dtype) + { + this.op = op; + this.value_index = value_index; + this._dtype = dtype; + _id = ops.uid(); + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index d64086c1..96111cd9 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -95,86 +95,6 @@ namespace Tensorflow public int NDims => rank; - /// - /// if original buffer is free. - /// - private bool deallocator_called; - - public Tensor(IntPtr handle) - { - _handle = handle; - } - - public Tensor(NDArray nd) - { - _handle = Allocate(nd); - } - - private IntPtr Allocate(NDArray nd) - { - IntPtr dotHandle = IntPtr.Zero; - ulong size = 0; - - if (nd.dtype.Name != "String") - { - dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); - size = (ulong)(nd.size * nd.dtypesize); - } - - switch (nd.dtype.Name) - { - case "Int16": - Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); - break; - case "Int32": - Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); - break; - case "Single": - Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); - break; - case "Double": - Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); - break; - case "String": - var value = nd.Data()[0]; - var bytes = Encoding.UTF8.GetBytes(value); - dotHandle = Marshal.AllocHGlobal(bytes.Length + 1); - Marshal.Copy(bytes, 0, dotHandle, bytes.Length); - size = (ulong)bytes.Length; - break; - default: - throw new NotImplementedException("Marshal.Copy failed."); - } - - var dataType = ToTFDataType(nd.dtype); - // shape - var dims = nd.shape.Select(x => (long)x).ToArray(); - // Free the original buffer and set flag - Deallocator deallocator = (IntPtr values, IntPtr len, ref bool closure) => - { - Marshal.FreeHGlobal(dotHandle); - closure = true; - }; - - var tfHandle = c_api.TF_NewTensor(dataType, - dims, - nd.ndim, - dotHandle, - size, - deallocator, - ref deallocator_called); - - return tfHandle; - } - - public Tensor(Operation op, int value_index, TF_DataType dtype) - { - this.op = op; - this.value_index = value_index; - this._dtype = dtype; - _id = ops.uid(); - } - public Operation[] Consumers => consumers(); public string Device => op.Device; diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index a06af129..78b1016f 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -120,7 +120,7 @@ namespace Tensorflow /// TF_Status* /// On success returns the size in bytes of the encoded string. [DllImport(TensorFlowLibName)] - public static extern ulong TF_StringEncode(string src, ulong src_len, string dst, ulong dst_len, IntPtr status); + public static extern ulong TF_StringEncode(string src, ulong src_len, IntPtr dst, ulong dst_len, IntPtr status); /// /// Decode a string encoded using TF_StringEncode. @@ -132,6 +132,6 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern ulong TF_StringDecode(string src, ulong src_len, IntPtr dst, ref ulong dst_len, IntPtr status); + public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, IntPtr status); } } diff --git a/src/TensorFlowNET.Core/Train/Saving/Saver.cs b/src/TensorFlowNET.Core/Train/Saving/Saver.cs index 90f0fcd0..7ec46172 100644 --- a/src/TensorFlowNET.Core/Train/Saving/Saver.cs +++ b/src/TensorFlowNET.Core/Train/Saving/Saver.cs @@ -138,12 +138,14 @@ namespace Tensorflow public string save(Session sess, string save_path, string global_step = "", + string latest_filename = "", string meta_graph_suffix = "meta", bool write_meta_graph = true, bool write_state = true, bool strip_default_attrs = false) { - string latest_filename = "checkpoint"; + if (string.IsNullOrEmpty(latest_filename)) + latest_filename = "checkpoint"; string model_checkpoint_path = ""; string checkpoint_file = ""; diff --git a/test/TensorFlowNET.UnitTest/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ConstantTest.cs index 8385f42c..b3e222a2 100644 --- a/test/TensorFlowNET.UnitTest/ConstantTest.cs +++ b/test/TensorFlowNET.UnitTest/ConstantTest.cs @@ -3,6 +3,7 @@ using NumSharp.Core; using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.InteropServices; using System.Text; using Tensorflow; @@ -104,11 +105,14 @@ namespace TensorFlowNET.UnitTest string str = "Hello, TensorFlow.NET!"; ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length); Assert.AreEqual(dst_len, (ulong)23); - string dst = ""; - c_api.TF_StringEncode(str, (ulong)str.Length, dst, dst_len, status); + IntPtr dst = Marshal.AllocHGlobal((int)dst_len); + ulong encoded_len = c_api.TF_StringEncode(str, (ulong)str.Length, dst, dst_len, status); + Assert.AreEqual((ulong)23, encoded_len); Assert.AreEqual(status.Code, TF_Code.TF_OK); - - //c_api.TF_StringDecode(str, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); + string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte)); + Assert.AreEqual(encoded_str, str); + Assert.AreEqual(str.Length, Marshal.ReadByte(dst)); + //c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); } /// diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs index ef00dc91..9e86e8d5 100644 --- a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs +++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs @@ -45,7 +45,6 @@ namespace TensorFlowNET.UnitTest }); } - [TestMethod] public void Save2() { var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);