| @@ -0,0 +1,19 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public static partial class tf | |||||
| { | |||||
| public static unsafe Tensor add(Tensor a, Tensor b) | |||||
| { | |||||
| return gen_math_ops.add(a, b); | |||||
| } | |||||
| public static unsafe Tensor multiply(Tensor x, Tensor y) | |||||
| { | |||||
| return gen_math_ops.mul(x, y); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -12,13 +12,6 @@ namespace Tensorflow | |||||
| public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null) | public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null) | ||||
| { | { | ||||
| /*var g = ops.get_default_graph(); | |||||
| var op = new Operation(g, "Placeholder", "feed"); | |||||
| var tensor = new Tensor(op, 0, dtype); | |||||
| return tensor;*/ | |||||
| var keywords = new Dictionary<string, object>(); | var keywords = new Dictionary<string, object>(); | ||||
| keywords.Add("dtype", dtype); | keywords.Add("dtype", dtype); | ||||
| keywords.Add("shape", shape); | keywords.Add("shape", shape); | ||||
| @@ -31,8 +24,7 @@ namespace Tensorflow | |||||
| _attrs["dtype"] = _op.get_attr("dtype"); | _attrs["dtype"] = _op.get_attr("dtype"); | ||||
| _attrs["shape"] = _op.get_attr("shape"); | _attrs["shape"] = _op.get_attr("shape"); | ||||
| var tensor = new Tensor(_op, 0, dtype); | |||||
| return tensor; | |||||
| return new Tensor(_op, 0, dtype); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -105,8 +105,8 @@ namespace Tensorflow | |||||
| c_api.TF_SessionRun(_session, | c_api.TF_SessionRun(_session, | ||||
| run_options: null, | run_options: null, | ||||
| inputs: feed_dict.Select(f => f.Key).ToArray(), | inputs: feed_dict.Select(f => f.Key).ToArray(), | ||||
| input_values: new IntPtr[] { }, | |||||
| ninputs: 0, | |||||
| input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), | |||||
| ninputs: feed_dict.Length, | |||||
| outputs: fetch_list, | outputs: fetch_list, | ||||
| output_values: output_values, | output_values: output_values, | ||||
| noutputs: fetch_list.Length, | noutputs: fetch_list.Length, | ||||
| @@ -115,6 +115,8 @@ namespace Tensorflow | |||||
| run_metadata: IntPtr.Zero, | run_metadata: IntPtr.Zero, | ||||
| status: status); | status: status); | ||||
| status.Check(true); | |||||
| object[] result = new object[fetch_list.Length]; | object[] result = new object[fetch_list.Length]; | ||||
| for (int i = 0; i < fetch_list.Length; i++) | for (int i = 0; i < fetch_list.Length; i++) | ||||
| @@ -37,20 +37,9 @@ namespace Tensorflow | |||||
| switch (type) | switch (type) | ||||
| { | { | ||||
| case TF_DataType.TF_INT32: | |||||
| dtype = DataType.DtInt32; | |||||
| break; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| dtype = DataType.DtFloat; | |||||
| break; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| dtype = DataType.DtDouble; | |||||
| break; | |||||
| case TF_DataType.TF_STRING: | |||||
| dtype = DataType.DtString; | |||||
| break; | |||||
| default: | default: | ||||
| throw new Exception("Not Implemented"); | |||||
| Enum.TryParse(((int)type).ToString(), out dtype); | |||||
| break; | |||||
| } | } | ||||
| return dtype; | return dtype; | ||||
| @@ -10,6 +10,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static partial class tf | public static partial class tf | ||||
| { | { | ||||
| public static TF_DataType int16 = TF_DataType.TF_INT16; | |||||
| public static TF_DataType float32 = TF_DataType.TF_FLOAT; | public static TF_DataType float32 = TF_DataType.TF_FLOAT; | ||||
| public static TF_DataType chars = TF_DataType.TF_STRING; | public static TF_DataType chars = TF_DataType.TF_STRING; | ||||
| @@ -22,11 +23,6 @@ namespace Tensorflow | |||||
| return new RefVariable(data, dtype); | return new RefVariable(data, dtype); | ||||
| } | } | ||||
| public static unsafe Tensor add(Tensor a, Tensor b) | |||||
| { | |||||
| return gen_math_ops.add(a, b); | |||||
| } | |||||
| public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null) | public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null) | ||||
| { | { | ||||
| return gen_array_ops.placeholder(dtype, shape); | return gen_array_ops.placeholder(dtype, shape); | ||||
| @@ -11,6 +11,8 @@ namespace TensorFlowNET.Examples | |||||
| /// </summary> | /// </summary> | ||||
| public class BasicOperations : IExample | public class BasicOperations : IExample | ||||
| { | { | ||||
| private Session sess; | |||||
| public void Run() | public void Run() | ||||
| { | { | ||||
| // Basic constant operations | // Basic constant operations | ||||
| @@ -18,14 +20,34 @@ namespace TensorFlowNET.Examples | |||||
| // of the Constant op. | // of the Constant op. | ||||
| var a = tf.constant(2); | var a = tf.constant(2); | ||||
| var b = tf.constant(3); | var b = tf.constant(3); | ||||
| var c = a * b; | |||||
| // Launch the default graph. | // Launch the default graph. | ||||
| using (var sess = tf.Session()) | |||||
| using (sess = tf.Session()) | |||||
| { | { | ||||
| Console.WriteLine("a=2, b=3"); | Console.WriteLine("a=2, b=3"); | ||||
| Console.WriteLine($"Addition with constants: {sess.run(a + b)}"); | Console.WriteLine($"Addition with constants: {sess.run(a + b)}"); | ||||
| Console.WriteLine($"Multiplication with constants: {sess.run(a * b)}"); | Console.WriteLine($"Multiplication with constants: {sess.run(a * b)}"); | ||||
| } | } | ||||
| // Basic Operations with variable as graph input | |||||
| // The value returned by the constructor represents the output | |||||
| // of the Variable op. (define as input when running session) | |||||
| // tf Graph input | |||||
| a = tf.placeholder(tf.int16); | |||||
| b = tf.placeholder(tf.int16); | |||||
| // Define some operations | |||||
| var add = tf.add(a, b); | |||||
| var mul = tf.multiply(a, b); | |||||
| // Launch the default graph. | |||||
| using(sess = tf.Session()) | |||||
| { | |||||
| // var feed_dict = new Dictionary<string, > | |||||
| // Run every operation with variable input | |||||
| // Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict: {a: 2, b: 3})}"); | |||||
| // Console.WriteLine($"Multiplication with variables: {}"); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -78,7 +78,7 @@ namespace TensorFlowNET.UnitTest | |||||
| var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray(); | var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray(); | ||||
| IntPtr targets_ptr = IntPtr.Zero; | IntPtr targets_ptr = IntPtr.Zero; | ||||
| c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, 1, | |||||
| c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, | |||||
| outputs_ptr, output_values_ptr, outputs_.Count, | outputs_ptr, output_values_ptr, outputs_.Count, | ||||
| targets_ptr, targets_.Count, | targets_ptr, targets_.Count, | ||||
| IntPtr.Zero, s); | IntPtr.Zero, s); | ||||
| @@ -35,7 +35,8 @@ namespace TensorFlowNET.UnitTest | |||||
| feed_dict.Add(a, 3.0f); | feed_dict.Add(a, 3.0f); | ||||
| feed_dict.Add(b, 2.0f); | feed_dict.Add(b, 2.0f); | ||||
| //var o = sess.run(c, feed_dict); | |||||
| var o = sess.run(c, feed_dict); | |||||
| Assert.AreEqual(o, 5.0f); | |||||
| } | } | ||||
| } | } | ||||