diff --git a/src/TensorFlowNET.Core/BaseSession.cs b/src/TensorFlowNET.Core/BaseSession.cs index 3be70a9d..c2d4565e 100644 --- a/src/TensorFlowNET.Core/BaseSession.cs +++ b/src/TensorFlowNET.Core/BaseSession.cs @@ -4,7 +4,7 @@ using System.Text; namespace Tensorflow { - public class BaseSession + public class BaseSession : IDisposable { private Graph _graph; private bool _opened; @@ -32,18 +32,23 @@ namespace Tensorflow c_api.TF_DeleteSessionOptions(opts); } - public virtual byte[] run(Tensor fetches) + public void Dispose() { - return _run(fetches); + } - private unsafe byte[] _run(Tensor fetches) + public virtual byte[] run(Tensor fetches, Dictionary feed_dict = null) + { + return _run(fetches, feed_dict); + } + + private unsafe byte[] _run(Tensor fetches, Dictionary feed_dict = null) { var status = new Status(); c_api.TF_SessionRun(_session, run_options: null, - inputs: new TF_Input[] { }, + inputs: new TF_Output[] { }, input_values: new IntPtr[] { }, ninputs: 1, outputs: new TF_Output[] { }, diff --git a/src/TensorFlowNET.Core/Graph.cs b/src/TensorFlowNET.Core/Graph.cs index 93180a61..bc559543 100644 --- a/src/TensorFlowNET.Core/Graph.cs +++ b/src/TensorFlowNET.Core/Graph.cs @@ -31,7 +31,7 @@ namespace Tensorflow _names_in_use = new Dictionary(); } - public unsafe Operation create_op(string op_type, object inputs, TF_DataType[] dtypes, + public unsafe Operation create_op(string op_type, List inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = "", Dictionary attrs = null, OpDef op_def = null) { @@ -43,9 +43,13 @@ namespace Tensorflow name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); - var op = new Operation(node_def, this, + var op = new Operation(node_def, + this, inputs: inputs, output_types: dtypes, + control_inputs: new object[] { }, + input_types: input_types, + original_op: null, op_def: op_def); return op; @@ -73,6 +77,7 @@ namespace Tensorflow else { _names_in_use[name_key] = 1; + return name; } diff --git a/src/TensorFlowNET.Core/OpDefLibrary.cs b/src/TensorFlowNET.Core/OpDefLibrary.cs index 7d0d6e12..b927e8be 100644 --- a/src/TensorFlowNET.Core/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/OpDefLibrary.cs @@ -47,17 +47,11 @@ namespace Tensorflow } var attrs = new Dictionary(); + + // Perform input type inference var inputs = new List(); var input_types = new List(); - - foreach (var attr in op_def.Attr) - { - if (keywords.ContainsKey(attr.Name)) - { - attrs[attr.Name] = keywords[attr.Name]; - } - } - + foreach (var input_arg in op_def.InputArg) { var input_name = input_arg.Name; @@ -70,18 +64,38 @@ namespace Tensorflow { attrs[input_arg.TypeAttr] = DataType.DtFloat; } + + if (input_arg.IsRef) + { + + } + else + { + input_types.Add((keywords[input_name] as Tensor).dtype); + } } + // Process remaining attrs + foreach (var attr in op_def.Attr) + { + if (keywords.ContainsKey(attr.Name)) + { + attrs[attr.Name] = keywords[attr.Name]; + } + } + + // Convert attr values to AttrValue protos. var attr_protos = new Dictionary(); foreach (var attr_def in op_def.Attr) { var key = attr_def.Name; + var value = attrs[key]; var attr_value = new AttrValue(); switch (attr_def.Type) { case "type": - attr_value.Type = (DataType)keywords["dtype"]; + attr_value.Type = _MakeType(value, attr_def); break; case "shape": attr_value.Shape = new TensorShapeProto(); @@ -91,6 +105,7 @@ namespace Tensorflow attr_protos[key] = attr_value; } + // Determine output types (possibly using attrs) var output_types = new List(); foreach (var arg in op_def.OutputArg) @@ -105,6 +120,7 @@ namespace Tensorflow } } + // Add Op to graph var op = g.create_op(op_type_name, inputs, output_types.ToArray(), name: scope, input_types: input_types.ToArray(), @@ -113,5 +129,10 @@ namespace Tensorflow return op; } + + public DataType _MakeType(Object v, AttrDef attr_def) + { + return DataType.DtFloat; + } } } diff --git a/src/TensorFlowNET.Core/Operation.cs b/src/TensorFlowNET.Core/Operation.cs index 2d416944..d2bbca5c 100644 --- a/src/TensorFlowNET.Core/Operation.cs +++ b/src/TensorFlowNET.Core/Operation.cs @@ -8,7 +8,7 @@ namespace Tensorflow public class Operation { private Graph _graph; - private IntPtr _c_op; + public IntPtr _c_op; public int _id => _id_value; private int _id_value; public string name; @@ -27,7 +27,7 @@ namespace Tensorflow c_api.TF_FinishOperation(desc, status.Handle); } - public Operation(NodeDef node_def, Graph g, object inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) + public Operation(NodeDef node_def, Graph g, List inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) { _graph = g; @@ -38,7 +38,7 @@ namespace Tensorflow _outputs = new Tensor[num_outputs]; for (int i = 0; i < num_outputs; i++) { - _outputs[i] = new Tensor(this, i, TF_DataType.DtDouble); + _outputs[i] = new Tensor(this, i, TF_DataType.DtFloat); } _graph._add_op(this); diff --git a/src/TensorFlowNET.Core/Session.cs b/src/TensorFlowNET.Core/Session.cs index 4d0b9ba9..20c31f1f 100644 --- a/src/TensorFlowNET.Core/Session.cs +++ b/src/TensorFlowNET.Core/Session.cs @@ -6,11 +6,5 @@ namespace Tensorflow { public class Session : BaseSession { - public override byte[] run(Tensor fetches) - { - var ret = base.run(fetches); - - return ret; - } } } diff --git a/src/TensorFlowNET.Core/Tensor.cs b/src/TensorFlowNET.Core/Tensor.cs index 0d16fb41..7c01cf71 100644 --- a/src/TensorFlowNET.Core/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensor.cs @@ -6,9 +6,12 @@ namespace Tensorflow { public class Tensor { - private Operation _op; - private int _value_index; + private readonly Operation _op; + public Operation op => _op; + private readonly int _value_index; + public int value_index => _value_index; private DataType _dtype; + public DataType dtype => _dtype; public Tensor(Operation op, int value_index, DataType dtype) { @@ -16,5 +19,10 @@ namespace Tensorflow _value_index = value_index; _dtype = dtype; } + + public TF_Output _as_tf_output() + { + return c_api_util.tf_output(_op._c_op, _value_index); + } } } diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs index b9834340..48490328 100644 --- a/src/TensorFlowNET.Core/c_api.cs +++ b/src/TensorFlowNET.Core/c_api.cs @@ -19,6 +19,14 @@ namespace Tensorflow { public const string TensorFlowLibName = "tensorflow"; + /// + /// For inputs that take a single tensor. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static unsafe extern void TF_AddInput(TF_OperationDescription desc, TF_Output input); + [DllImport(TensorFlowLibName)] public static unsafe extern void TF_DeleteSessionOptions(TF_SessionOptions opts); @@ -60,11 +68,11 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern unsafe void TF_SessionRun(TF_Session session, TF_Buffer* run_options, - TF_Input[] inputs, TF_Tensor[] input_values, - int ninputs, TF_Output[] outputs, - TF_Tensor[] output_values, int noutputs, + TF_Output[] inputs, TF_Tensor[] input_values, int ninputs, + TF_Output[] outputs, TF_Tensor[] output_values, int noutputs, TF_Operation[] target_opers, int ntargets, - TF_Buffer* run_metadata, TF_Status status); + TF_Buffer* run_metadata, + TF_Status status); [DllImport(TensorFlowLibName)] public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); diff --git a/src/TensorFlowNET.Core/c_api_util.cs b/src/TensorFlowNET.Core/c_api_util.cs new file mode 100644 index 00000000..f6d54062 --- /dev/null +++ b/src/TensorFlowNET.Core/c_api_util.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class c_api_util + { + public static TF_Output tf_output(IntPtr c_op, int index) + { + var ret = new TF_Output(); + ret.oper = c_op; + ret.index = index; + + return ret; + } + } +} diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index c24ed135..da3287b8 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -16,21 +16,35 @@ namespace Tensorflow return tf.Graph(); } - public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs) + public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List inputs) { var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name); + + // Add inputs + foreach(var op_input in inputs) + { + c_api.TF_AddInput(op_desc, op_input._as_tf_output()); + } + var status = new Status(); + // Add control inputs + + // Add attrs foreach (var attr in node_def.Attr) { var bytes = attr.Value.ToByteArray(); var proto = Marshal.AllocHGlobal(bytes.Length); Marshal.Copy(bytes, 0, proto, bytes.Length); c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle); + + if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message); } var c_op = c_api.TF_FinishOperation(op_desc, status.Handle); + if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message); + return c_op; } diff --git a/src/TensorFlowNET.Core/ops/gen_math_ops.cs b/src/TensorFlowNET.Core/ops/gen_math_ops.cs index eece48e0..ab5172ba 100644 --- a/src/TensorFlowNET.Core/ops/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/ops/gen_math_ops.cs @@ -17,7 +17,9 @@ namespace Tensorflow var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); - return null; + var tensor = new Tensor(_op, 0, DataType.DtFloat); + + return tensor; } private static OpDefLibrary _InitOpDefLibrary() diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 40e902a0..47d7c8d1 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -28,8 +28,14 @@ namespace TensorFlowNET.UnitTest var b = tf.placeholder(tf.float32); var c = tf.add(a, b); - //sess.run(adder_node, { a: 3, b: 4.5}) - //sess.run(adder_node, {a: [1,3], b: [2, 4]}) + using(var sess = tf.Session()) + { + var feed_dict = new Dictionary(); + feed_dict.Add(a, 3); + feed_dict.Add(b, 2); + + sess.run(c, feed_dict); + } } } }