diff --git a/README.md b/README.md index 8fa2127d..eaf5c30c 100644 --- a/README.md +++ b/README.md @@ -7,17 +7,37 @@ TensorFlow.NET is a member project of SciSharp stack. ![tensors_flowing](docs/assets/tensors_flowing.gif) ### How to use +Download the pre-compiled dll [here](tensorflowlib) and place it in the bin folder. + +```cs +// import tensorflow.net +using using Tensorflow; +``` + ```cs -using TensorFlowNET.Core; +// Create a Constant op +var a = tf.constant(4.0f); +var b = tf.constant(5.0f); +var c = tf.add(a, b); -namespace TensorFlowNET.Examples +using (var sess = tf.Session()) { - public class HelloWorld : IExample - { - public void Run() - { - var hello = tf.constant("Hello, TensorFlow!"); - } - } + var o = sess.run(c); +} +``` + +```cs +// Create a placeholder op +var a = tf.placeholder(tf.float32); +var b = tf.placeholder(tf.float32); +var c = tf.add(a, b); + +using(var sess = tf.Session()) +{ + var feed_dict = new Dictionary(); + feed_dict.Add(a, 3.0f); + feed_dict.Add(b, 2.0f); + + var o = sess.run(c, feed_dict); } ``` \ No newline at end of file diff --git a/src/TensorFlowNET.Core/BaseSession.cs b/src/TensorFlowNET.Core/BaseSession.cs index 4e7c983f..ffb77c12 100644 --- a/src/TensorFlowNET.Core/BaseSession.cs +++ b/src/TensorFlowNET.Core/BaseSession.cs @@ -66,16 +66,16 @@ namespace Tensorflow var status = new Status(); c_api.TF_SessionRun(_session, - run_options: null, + run_options: IntPtr.Zero, inputs: new TF_Output[] { }, input_values: new IntPtr[] { }, - ninputs: 1, + ninputs: 0, outputs: new TF_Output[] { }, output_values: new IntPtr[] { }, noutputs: 1, target_opers: new IntPtr[] { }, ntargets: 1, - run_metadata: null, + run_metadata: IntPtr.Zero, status: status.Handle); return null; diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs index 48490328..abd50bf3 100644 --- a/src/TensorFlowNET.Core/c_api.cs +++ b/src/TensorFlowNET.Core/c_api.cs @@ -67,11 +67,11 @@ namespace Tensorflow public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, TF_Tensor value, TF_Status status); [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_SessionRun(TF_Session session, TF_Buffer* run_options, + public static extern unsafe void TF_SessionRun(TF_Session session, IntPtr run_options, TF_Output[] inputs, TF_Tensor[] input_values, int ninputs, TF_Output[] outputs, TF_Tensor[] output_values, int noutputs, TF_Operation[] target_opers, int ntargets, - TF_Buffer* run_metadata, + IntPtr run_metadata, TF_Status status); [DllImport(TensorFlowLibName)] diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index da3287b8..395410d1 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -21,9 +21,12 @@ namespace Tensorflow var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name); // Add inputs - foreach(var op_input in inputs) + if(inputs != null) { - c_api.TF_AddInput(op_desc, op_input._as_tf_output()); + foreach (var op_input in inputs) + { + c_api.TF_AddInput(op_desc, op_input._as_tf_output()); + } } var status = new Status(); diff --git a/src/TensorFlowNET.Core/tensor_util.cs b/src/TensorFlowNET.Core/tensor_util.cs index ec9b23bb..1c6da7bc 100644 --- a/src/TensorFlowNET.Core/tensor_util.cs +++ b/src/TensorFlowNET.Core/tensor_util.cs @@ -16,6 +16,15 @@ namespace Tensorflow switch (values) { + case float val: + nparray = np.array(new float[] { val }, np.float32); + tensor_proto = new tensor_pb2.TensorProto + { + Dtype = DataType.DtFloat, + TensorShape = tensor_shape.as_shape().as_proto() + }; + tensor_proto.FloatVal.Add(val); + break; case double val: nparray = np.array(new double[] { val }, np.float64); tensor_proto = new tensor_pb2.TensorProto @@ -25,7 +34,6 @@ namespace Tensorflow }; tensor_proto.DoubleVal.Add(val); break; - case string val: nparray = np.array(new string[] { val }, np.chars); tensor_proto = new tensor_pb2.TensorProto diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 47d7c8d1..f781cb5d 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -12,7 +12,13 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void constant() { - tf.constant(4.0); + var a = tf.constant(4.0f); + var b = tf.constant(5.0f); + var c = tf.add(a, b); + using (var sess = tf.Session()) + { + var o = sess.run(c); + } } [TestMethod] @@ -31,10 +37,10 @@ namespace TensorFlowNET.UnitTest using(var sess = tf.Session()) { var feed_dict = new Dictionary(); - feed_dict.Add(a, 3); - feed_dict.Add(b, 2); + feed_dict.Add(a, 3.0f); + feed_dict.Add(b, 2.0f); - sess.run(c, feed_dict); + var o = sess.run(c, feed_dict); } } }