From 4ca080e56586472be9cc84e440945053b9d9ae7e Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Tue, 30 Jul 2019 17:54:05 -0500 Subject: [PATCH] Hello World works. --- src/TensorFlowNET.Core/Sessions/BaseSession.cs | 2 +- src/TensorFlowNET.Core/TensorFlowNET.Core.csproj | 1 - src/TensorFlowNET.Core/Tensors/dtypes.cs | 9 +++++---- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 2 +- src/TensorFlowNET.Core/Tensors/tf.constant.cs | 8 ++++++++ test/TensorFlowNET.Examples/HelloWorld.cs | 5 +++-- 6 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 8942a0fb..13246e8f 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -302,7 +302,7 @@ namespace Tensorflow // wired, don't know why we have to start from offset 9. // length in the begin var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); - nd = np.array(str).reshape(); + nd = np.array(str); break; case TF_DataType.TF_UINT8: var _bytes = new byte[tensor.size]; diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 3d9f6eee..95908a9e 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -63,7 +63,6 @@ Docs: https://tensorflownet.readthedocs.io - diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 16b09d05..807dc6f5 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -51,12 +51,13 @@ namespace Tensorflow } // "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex" - public static TF_DataType as_dtype(Type type) + public static TF_DataType as_dtype(Type type, TF_DataType? dtype = null) { - TF_DataType dtype = TF_DataType.DtInvalid; - switch (type.Name) { + case "Char": + dtype = dtype ?? TF_DataType.TF_UINT8; + break; case "SByte": dtype = TF_DataType.TF_INT8; break; @@ -100,7 +101,7 @@ namespace Tensorflow throw new Exception("as_dtype Not Implemented"); } - return dtype; + return dtype.Value; } public static DataType as_datatype_enum(this TF_DataType type) diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 7dbc2213..f7089a8e 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -226,7 +226,7 @@ namespace Tensorflow } } - var numpy_dtype = dtypes.as_dtype(nparray.dtype); + var numpy_dtype = dtypes.as_dtype(nparray.dtype, dtype: dtype); if (numpy_dtype == TF_DataType.DtInvalid) throw new TypeError($"Unrecognized data type: {nparray.dtype}"); diff --git a/src/TensorFlowNET.Core/Tensors/tf.constant.cs b/src/TensorFlowNET.Core/Tensors/tf.constant.cs index 61ef232b..ddb450e2 100644 --- a/src/TensorFlowNET.Core/Tensors/tf.constant.cs +++ b/src/TensorFlowNET.Core/Tensors/tf.constant.cs @@ -33,6 +33,14 @@ namespace Tensorflow verify_shape: verify_shape, allow_broadcast: false); + public static Tensor constant(string value, + string name = "Const") => constant_op._constant_impl(value, + tf.@string, + new int[] { 1 }, + name, + verify_shape: false, + allow_broadcast: false); + public static Tensor constant(float value, int shape, string name = "Const") => constant_op._constant_impl(value, diff --git a/test/TensorFlowNET.Examples/HelloWorld.cs b/test/TensorFlowNET.Examples/HelloWorld.cs index e9c91336..ca5e669b 100644 --- a/test/TensorFlowNET.Examples/HelloWorld.cs +++ b/test/TensorFlowNET.Examples/HelloWorld.cs @@ -29,8 +29,9 @@ namespace TensorFlowNET.Examples { // Run the op var result = sess.run(hello); - Console.WriteLine(result.ToString()); - return result.ToString().Equals(str); + string result_string = string.Join("", result.GetData()); + Console.WriteLine(result_string); + return result_string.Equals(str); }); }