| @@ -76,12 +76,12 @@ namespace Tensorflow | |||||
| obj = temp_obj; | obj = temp_obj; | ||||
| // If obj appears to be a name... | // If obj appears to be a name... | ||||
| if (obj is String str) | |||||
| if (obj is string name) | |||||
| { | { | ||||
| if(str.Contains(":") && allow_tensor) | |||||
| if(name.Contains(":") && allow_tensor) | |||||
| { | { | ||||
| string op_name = str.Split(':')[0]; | |||||
| int out_n = int.Parse(str.Split(':')[1]); | |||||
| string op_name = name.Split(':')[0]; | |||||
| int out_n = int.Parse(name.Split(':')[1]); | |||||
| if (_nodes_by_name.ContainsKey(op_name)) | if (_nodes_by_name.ContainsKey(op_name)) | ||||
| return _nodes_by_name[op_name].outputs[out_n]; | return _nodes_by_name[op_name].outputs[out_n]; | ||||
| @@ -67,7 +67,7 @@ namespace Tensorflow | |||||
| default: | default: | ||||
| throw new NotImplementedException("_run subfeed"); | throw new NotImplementedException("_run subfeed"); | ||||
| } | } | ||||
| feed_map[subfeed_t.name] = new Tuple<object, object>(subfeed_t, subfeed.Value); | |||||
| feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); | |||||
| } | } | ||||
| } | } | ||||
| @@ -178,7 +178,8 @@ namespace Tensorflow | |||||
| case TF_DataType.TF_STRING: | case TF_DataType.TF_STRING: | ||||
| var bytes = tensor.Data(); | var bytes = tensor.Data(); | ||||
| // wired, don't know why we have to start from offset 9. | // wired, don't know why we have to start from offset 9. | ||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9); | |||||
| // length in the begin | |||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||||
| nd = np.array(str).reshape(); | nd = np.array(str).reshape(); | ||||
| break; | break; | ||||
| case TF_DataType.TF_INT16: | case TF_DataType.TF_INT16: | ||||
| @@ -0,0 +1,111 @@ | |||||
| using NumSharp.Core; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | |||||
| using static Tensorflow.c_api; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class Tensor | |||||
| { | |||||
| /// <summary> | |||||
| /// if original buffer is free. | |||||
| /// </summary> | |||||
| private bool deallocator_called; | |||||
| public Tensor(IntPtr handle) | |||||
| { | |||||
| _handle = handle; | |||||
| } | |||||
| public Tensor(NDArray nd) | |||||
| { | |||||
| _handle = Allocate(nd); | |||||
| } | |||||
| private IntPtr Allocate(NDArray nd) | |||||
| { | |||||
| IntPtr dotHandle = IntPtr.Zero; | |||||
| ulong size = 0; | |||||
| if (nd.dtype.Name != "String") | |||||
| { | |||||
| dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); | |||||
| size = (ulong)(nd.size * nd.dtypesize); | |||||
| } | |||||
| switch (nd.dtype.Name) | |||||
| { | |||||
| case "Int16": | |||||
| Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "Int32": | |||||
| Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "Single": | |||||
| Marshal.Copy(nd.Data<float>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "Double": | |||||
| Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "String": | |||||
| /*var value = nd.Data<string>()[0]; | |||||
| var bytes = Encoding.UTF8.GetBytes(value); | |||||
| dotHandle = Marshal.AllocHGlobal(bytes.Length + 1); | |||||
| Marshal.Copy(bytes, 0, dotHandle, bytes.Length); | |||||
| size = (ulong)bytes.Length;*/ | |||||
| var str = nd.Data<string>()[0]; | |||||
| ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length); | |||||
| //dotHandle = Marshal.AllocHGlobal((int)dst_len); | |||||
| //size = c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status); | |||||
| var dataType1 = ToTFDataType(nd.dtype); | |||||
| // shape | |||||
| var dims1 = nd.shape.Select(x => (long)x).ToArray(); | |||||
| var tfHandle1 = c_api.TF_AllocateTensor(dataType1, | |||||
| dims1, | |||||
| nd.ndim, | |||||
| dst_len); | |||||
| dotHandle = c_api.TF_TensorData(tfHandle1); | |||||
| c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status); | |||||
| return tfHandle1; | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("Marshal.Copy failed."); | |||||
| } | |||||
| var dataType = ToTFDataType(nd.dtype); | |||||
| // shape | |||||
| var dims = nd.shape.Select(x => (long)x).ToArray(); | |||||
| // Free the original buffer and set flag | |||||
| Deallocator deallocator = (IntPtr values, IntPtr len, ref bool closure) => | |||||
| { | |||||
| Marshal.FreeHGlobal(dotHandle); | |||||
| closure = true; | |||||
| }; | |||||
| var tfHandle = c_api.TF_NewTensor(dataType, | |||||
| dims, | |||||
| nd.ndim, | |||||
| dotHandle, | |||||
| size, | |||||
| deallocator, | |||||
| ref deallocator_called); | |||||
| return tfHandle; | |||||
| } | |||||
| public Tensor(Operation op, int value_index, TF_DataType dtype) | |||||
| { | |||||
| this.op = op; | |||||
| this.value_index = value_index; | |||||
| this._dtype = dtype; | |||||
| _id = ops.uid(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -95,86 +95,6 @@ namespace Tensorflow | |||||
| public int NDims => rank; | public int NDims => rank; | ||||
| /// <summary> | |||||
| /// if original buffer is free. | |||||
| /// </summary> | |||||
| private bool deallocator_called; | |||||
| public Tensor(IntPtr handle) | |||||
| { | |||||
| _handle = handle; | |||||
| } | |||||
| public Tensor(NDArray nd) | |||||
| { | |||||
| _handle = Allocate(nd); | |||||
| } | |||||
| private IntPtr Allocate(NDArray nd) | |||||
| { | |||||
| IntPtr dotHandle = IntPtr.Zero; | |||||
| ulong size = 0; | |||||
| if (nd.dtype.Name != "String") | |||||
| { | |||||
| dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); | |||||
| size = (ulong)(nd.size * nd.dtypesize); | |||||
| } | |||||
| switch (nd.dtype.Name) | |||||
| { | |||||
| case "Int16": | |||||
| Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "Int32": | |||||
| Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "Single": | |||||
| Marshal.Copy(nd.Data<float>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "Double": | |||||
| Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "String": | |||||
| var value = nd.Data<string>()[0]; | |||||
| var bytes = Encoding.UTF8.GetBytes(value); | |||||
| dotHandle = Marshal.AllocHGlobal(bytes.Length + 1); | |||||
| Marshal.Copy(bytes, 0, dotHandle, bytes.Length); | |||||
| size = (ulong)bytes.Length; | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("Marshal.Copy failed."); | |||||
| } | |||||
| var dataType = ToTFDataType(nd.dtype); | |||||
| // shape | |||||
| var dims = nd.shape.Select(x => (long)x).ToArray(); | |||||
| // Free the original buffer and set flag | |||||
| Deallocator deallocator = (IntPtr values, IntPtr len, ref bool closure) => | |||||
| { | |||||
| Marshal.FreeHGlobal(dotHandle); | |||||
| closure = true; | |||||
| }; | |||||
| var tfHandle = c_api.TF_NewTensor(dataType, | |||||
| dims, | |||||
| nd.ndim, | |||||
| dotHandle, | |||||
| size, | |||||
| deallocator, | |||||
| ref deallocator_called); | |||||
| return tfHandle; | |||||
| } | |||||
| public Tensor(Operation op, int value_index, TF_DataType dtype) | |||||
| { | |||||
| this.op = op; | |||||
| this.value_index = value_index; | |||||
| this._dtype = dtype; | |||||
| _id = ops.uid(); | |||||
| } | |||||
| public Operation[] Consumers => consumers(); | public Operation[] Consumers => consumers(); | ||||
| public string Device => op.Device; | public string Device => op.Device; | ||||
| @@ -120,7 +120,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns>On success returns the size in bytes of the encoded string.</returns> | /// <returns>On success returns the size in bytes of the encoded string.</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern ulong TF_StringEncode(string src, ulong src_len, string dst, ulong dst_len, IntPtr status); | |||||
| public static extern ulong TF_StringEncode(string src, ulong src_len, IntPtr dst, ulong dst_len, IntPtr status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Decode a string encoded using TF_StringEncode. | /// Decode a string encoded using TF_StringEncode. | ||||
| @@ -132,6 +132,6 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern ulong TF_StringDecode(string src, ulong src_len, IntPtr dst, ref ulong dst_len, IntPtr status); | |||||
| public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, IntPtr status); | |||||
| } | } | ||||
| } | } | ||||
| @@ -138,12 +138,14 @@ namespace Tensorflow | |||||
| public string save(Session sess, | public string save(Session sess, | ||||
| string save_path, | string save_path, | ||||
| string global_step = "", | string global_step = "", | ||||
| string latest_filename = "", | |||||
| string meta_graph_suffix = "meta", | string meta_graph_suffix = "meta", | ||||
| bool write_meta_graph = true, | bool write_meta_graph = true, | ||||
| bool write_state = true, | bool write_state = true, | ||||
| bool strip_default_attrs = false) | bool strip_default_attrs = false) | ||||
| { | { | ||||
| string latest_filename = "checkpoint"; | |||||
| if (string.IsNullOrEmpty(latest_filename)) | |||||
| latest_filename = "checkpoint"; | |||||
| string model_checkpoint_path = ""; | string model_checkpoint_path = ""; | ||||
| string checkpoint_file = ""; | string checkpoint_file = ""; | ||||
| @@ -3,6 +3,7 @@ using NumSharp.Core; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| @@ -104,11 +105,14 @@ namespace TensorFlowNET.UnitTest | |||||
| string str = "Hello, TensorFlow.NET!"; | string str = "Hello, TensorFlow.NET!"; | ||||
| ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length); | ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length); | ||||
| Assert.AreEqual(dst_len, (ulong)23); | Assert.AreEqual(dst_len, (ulong)23); | ||||
| string dst = ""; | |||||
| c_api.TF_StringEncode(str, (ulong)str.Length, dst, dst_len, status); | |||||
| IntPtr dst = Marshal.AllocHGlobal((int)dst_len); | |||||
| ulong encoded_len = c_api.TF_StringEncode(str, (ulong)str.Length, dst, dst_len, status); | |||||
| Assert.AreEqual((ulong)23, encoded_len); | |||||
| Assert.AreEqual(status.Code, TF_Code.TF_OK); | Assert.AreEqual(status.Code, TF_Code.TF_OK); | ||||
| //c_api.TF_StringDecode(str, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); | |||||
| string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte)); | |||||
| Assert.AreEqual(encoded_str, str); | |||||
| Assert.AreEqual(str.Length, Marshal.ReadByte(dst)); | |||||
| //c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -45,7 +45,6 @@ namespace TensorFlowNET.UnitTest | |||||
| }); | }); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void Save2() | public void Save2() | ||||
| { | { | ||||
| var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | ||||