From 518095db750e677a26887475563585cd6a12d1cc Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 13 Dec 2018 07:04:56 -0600 Subject: [PATCH] update tf.constant --- src/TensorFlowNET.Core/Graph.cs | 7 ++++--- src/TensorFlowNET.Core/TensorFlowNET.Core.csproj | 4 ++++ src/TensorFlowNET.Core/Tensorflow.cs | 9 ++++++--- src/TensorFlowNET.Core/tensor_util.cs | 15 +++++++++++++++ test/TensorFlowNET.Examples/HelloWorld.cs | 1 + 5 files changed, 30 insertions(+), 6 deletions(-) create mode 100644 src/TensorFlowNET.Core/tensor_util.cs diff --git a/src/TensorFlowNET.Core/Graph.cs b/src/TensorFlowNET.Core/Graph.cs index 696cf38a..3518ac7a 100644 --- a/src/TensorFlowNET.Core/Graph.cs +++ b/src/TensorFlowNET.Core/Graph.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Runtime.InteropServices; using System.Text; @@ -26,7 +27,7 @@ namespace TensorFlowNET.Core _nodes_by_name = new Dictionary(); } - public unsafe Operation create_op(object inputs, string op_type = "", string name = "") + public unsafe Operation create_op(string op_type, object inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = "") { if (String.IsNullOrEmpty(name)) { @@ -51,9 +52,9 @@ namespace TensorFlowNET.Core return ++_next_id_counter; } - public void get_operations() + public Operation[] get_operations() { - + return _nodes_by_name.Values.Select(x => x).ToArray(); } } } diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 43821acd..c95a170c 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -9,6 +9,10 @@ DEBUG;TRACE + + + + PreserveNewest diff --git a/src/TensorFlowNET.Core/Tensorflow.cs b/src/TensorFlowNET.Core/Tensorflow.cs index 1fc847c0..e83e8a3a 100644 --- a/src/TensorFlowNET.Core/Tensorflow.cs +++ b/src/TensorFlowNET.Core/Tensorflow.cs @@ -3,8 +3,6 @@ using System.Collections.Generic; using System.Runtime.InteropServices; using System.Text; - - namespace TensorFlowNET.Core { public static class Tensorflow @@ -14,7 +12,7 @@ namespace TensorFlowNET.Core public static unsafe Tensor constant(object value) { var g = ops.get_default_graph(); - g.create_op(value, "Const"); + g.create_op("Const", value, new TF_DataType[] { TF_DataType.TF_DOUBLE }); return new Tensor(); } @@ -29,6 +27,11 @@ namespace TensorFlowNET.Core public static string VERSION => Marshal.PtrToStringAnsi(c_api.TF_Version()); + public static Graph get_default_graph() + { + return ops.get_default_graph(); + } + public static Graph Graph() { Graph g = new Graph(c_api.TF_NewGraph()); diff --git a/src/TensorFlowNET.Core/tensor_util.cs b/src/TensorFlowNET.Core/tensor_util.cs new file mode 100644 index 00000000..c3cc2cf4 --- /dev/null +++ b/src/TensorFlowNET.Core/tensor_util.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; +using np = NumSharp.Core.NumPy; + +namespace TensorFlowNET.Core +{ + public static class tensor_util + { + public static void make_tensor_proto(object values, Type dtype = null) + { + var nparray = np.array(values as Array, dtype); + } + } +} diff --git a/test/TensorFlowNET.Examples/HelloWorld.cs b/test/TensorFlowNET.Examples/HelloWorld.cs index dedcff7f..a32700e4 100644 --- a/test/TensorFlowNET.Examples/HelloWorld.cs +++ b/test/TensorFlowNET.Examples/HelloWorld.cs @@ -18,6 +18,7 @@ namespace TensorFlowNET.Examples The value returned by the constructor represents the output of the Constant op.*/ + var graph = tf.get_default_graph(); var hello = tf.constant(4.0); //var hello = tf.constant("Hello, TensorFlow!");