diff --git a/src/TensorFlowNET.Core/Eager/Context.cs b/src/TensorFlowNET.Core/Eager/Context.cs new file mode 100644 index 00000000..b922c6ae --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/Context.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace TensorFlowNET.Core.Eager +{ + public class Context + { + public static int GRAPH_MODE = 0; + public static int EAGER_MODE = 1; + + public int default_execution_mode; + + + } +} diff --git a/src/TensorFlowNET.Core/Operation.cs b/src/TensorFlowNET.Core/Operation.cs index 9b7d9ff2..76d3c815 100644 --- a/src/TensorFlowNET.Core/Operation.cs +++ b/src/TensorFlowNET.Core/Operation.cs @@ -17,6 +17,17 @@ namespace TensorFlowNET.Core public Tensor[] outputs => _outputs; public Tensor[] inputs; + public Operation(Graph g, string opType, string oper_name) + { + _graph = g; + + var status = new Status(); + + var desc = c_api.TF_NewOperation(g.Handle, opType, oper_name); + c_api.TF_SetAttrType(desc, "dtype", DataType.DtInt32); + c_api.TF_FinishOperation(desc, status.Handle); + } + public Operation(NodeDef node_def, Graph g, object inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) { _graph = g; diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 83bbfc29..60307d04 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -5,6 +5,7 @@ using System.Text; using TF_DataType = Tensorflow.DataType; using attr_value_pb2 = Tensorflow; using Tensorflow; +using TensorFlowNET.Core.Eager; namespace TensorFlowNET.Core { @@ -12,11 +13,25 @@ namespace TensorFlowNET.Core { public static DataType float32 = DataType.DtFloat; + public static Context context = new Context(); + + public static Graph g = new Graph(c_api.TF_NewGraph()); + public delegate void Deallocator(IntPtr data, IntPtr size, IntPtr deallocatorData); + public static unsafe Tensor add(Tensor a, Tensor b) + { + return null; + } + public static unsafe Tensor placeholder(DataType dtype, TensorShape shape = null) { - return gen_array_ops.placeholder(dtype, shape); + var g = ops.get_default_graph(); + var op = new Operation(g, "Placeholder", "feed"); + + var tensor = new Tensor(op, 0, dtype); + //return gen_array_ops.placeholder(dtype, shape); + return tensor; } public static unsafe Tensor constant(object value) @@ -38,6 +53,11 @@ namespace TensorFlowNET.Core return const_tensor; } + public static void enable_eager_execution() + { + context.default_execution_mode = Context.EAGER_MODE; + } + public static Deallocator FreeTensorDataDelegate = FreeTensorData; [MonoPInvokeCallback(typeof(Deallocator))] @@ -55,7 +75,6 @@ namespace TensorFlowNET.Core public static Graph Graph() { - Graph g = new Graph(c_api.TF_NewGraph()); return g; } diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index ce58f81a..cb6f3ddd 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -18,7 +18,18 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void placeholder() { - var x = tf.placeholder(tf.float32, shape: new TensorShape(1024, 1024)); + var x = tf.placeholder(tf.float32); + } + + [TestMethod] + public void add() + { + var a = tf.placeholder(tf.float32); + var b = tf.placeholder(tf.float32); + var c = tf.add(a, b); + + //sess.run(adder_node, { a: 3, b: 4.5}) + //sess.run(adder_node, {a: [1,3], b: [2, 4]}) } } }