| @@ -4,7 +4,7 @@ using System.Text; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class BaseSession | |||||
| public class BaseSession : IDisposable | |||||
| { | { | ||||
| private Graph _graph; | private Graph _graph; | ||||
| private bool _opened; | private bool _opened; | ||||
| @@ -32,18 +32,23 @@ namespace Tensorflow | |||||
| c_api.TF_DeleteSessionOptions(opts); | c_api.TF_DeleteSessionOptions(opts); | ||||
| } | } | ||||
| public virtual byte[] run(Tensor fetches) | |||||
| public void Dispose() | |||||
| { | { | ||||
| return _run(fetches); | |||||
| } | } | ||||
| private unsafe byte[] _run(Tensor fetches) | |||||
| public virtual byte[] run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
| { | |||||
| return _run(fetches, feed_dict); | |||||
| } | |||||
| private unsafe byte[] _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
| { | { | ||||
| var status = new Status(); | var status = new Status(); | ||||
| c_api.TF_SessionRun(_session, | c_api.TF_SessionRun(_session, | ||||
| run_options: null, | run_options: null, | ||||
| inputs: new TF_Input[] { }, | |||||
| inputs: new TF_Output[] { }, | |||||
| input_values: new IntPtr[] { }, | input_values: new IntPtr[] { }, | ||||
| ninputs: 1, | ninputs: 1, | ||||
| outputs: new TF_Output[] { }, | outputs: new TF_Output[] { }, | ||||
| @@ -31,7 +31,7 @@ namespace Tensorflow | |||||
| _names_in_use = new Dictionary<string, int>(); | _names_in_use = new Dictionary<string, int>(); | ||||
| } | } | ||||
| public unsafe Operation create_op(string op_type, object inputs, TF_DataType[] dtypes, | |||||
| public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes, | |||||
| TF_DataType[] input_types = null, string name = "", | TF_DataType[] input_types = null, string name = "", | ||||
| Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | ||||
| { | { | ||||
| @@ -43,9 +43,13 @@ namespace Tensorflow | |||||
| name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); | name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name); | ||||
| var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | ||||
| var op = new Operation(node_def, this, | |||||
| var op = new Operation(node_def, | |||||
| this, | |||||
| inputs: inputs, | inputs: inputs, | ||||
| output_types: dtypes, | output_types: dtypes, | ||||
| control_inputs: new object[] { }, | |||||
| input_types: input_types, | |||||
| original_op: null, | |||||
| op_def: op_def); | op_def: op_def); | ||||
| return op; | return op; | ||||
| @@ -73,6 +77,7 @@ namespace Tensorflow | |||||
| else | else | ||||
| { | { | ||||
| _names_in_use[name_key] = 1; | _names_in_use[name_key] = 1; | ||||
| return name; | |||||
| } | } | ||||
| @@ -47,17 +47,11 @@ namespace Tensorflow | |||||
| } | } | ||||
| var attrs = new Dictionary<string, object>(); | var attrs = new Dictionary<string, object>(); | ||||
| // Perform input type inference | |||||
| var inputs = new List<Tensor>(); | var inputs = new List<Tensor>(); | ||||
| var input_types = new List<DataType>(); | var input_types = new List<DataType>(); | ||||
| foreach (var attr in op_def.Attr) | |||||
| { | |||||
| if (keywords.ContainsKey(attr.Name)) | |||||
| { | |||||
| attrs[attr.Name] = keywords[attr.Name]; | |||||
| } | |||||
| } | |||||
| foreach (var input_arg in op_def.InputArg) | foreach (var input_arg in op_def.InputArg) | ||||
| { | { | ||||
| var input_name = input_arg.Name; | var input_name = input_arg.Name; | ||||
| @@ -70,18 +64,38 @@ namespace Tensorflow | |||||
| { | { | ||||
| attrs[input_arg.TypeAttr] = DataType.DtFloat; | attrs[input_arg.TypeAttr] = DataType.DtFloat; | ||||
| } | } | ||||
| if (input_arg.IsRef) | |||||
| { | |||||
| } | |||||
| else | |||||
| { | |||||
| input_types.Add((keywords[input_name] as Tensor).dtype); | |||||
| } | |||||
| } | } | ||||
| // Process remaining attrs | |||||
| foreach (var attr in op_def.Attr) | |||||
| { | |||||
| if (keywords.ContainsKey(attr.Name)) | |||||
| { | |||||
| attrs[attr.Name] = keywords[attr.Name]; | |||||
| } | |||||
| } | |||||
| // Convert attr values to AttrValue protos. | |||||
| var attr_protos = new Dictionary<string, AttrValue>(); | var attr_protos = new Dictionary<string, AttrValue>(); | ||||
| foreach (var attr_def in op_def.Attr) | foreach (var attr_def in op_def.Attr) | ||||
| { | { | ||||
| var key = attr_def.Name; | var key = attr_def.Name; | ||||
| var value = attrs[key]; | |||||
| var attr_value = new AttrValue(); | var attr_value = new AttrValue(); | ||||
| switch (attr_def.Type) | switch (attr_def.Type) | ||||
| { | { | ||||
| case "type": | case "type": | ||||
| attr_value.Type = (DataType)keywords["dtype"]; | |||||
| attr_value.Type = _MakeType(value, attr_def); | |||||
| break; | break; | ||||
| case "shape": | case "shape": | ||||
| attr_value.Shape = new TensorShapeProto(); | attr_value.Shape = new TensorShapeProto(); | ||||
| @@ -91,6 +105,7 @@ namespace Tensorflow | |||||
| attr_protos[key] = attr_value; | attr_protos[key] = attr_value; | ||||
| } | } | ||||
| // Determine output types (possibly using attrs) | |||||
| var output_types = new List<DataType>(); | var output_types = new List<DataType>(); | ||||
| foreach (var arg in op_def.OutputArg) | foreach (var arg in op_def.OutputArg) | ||||
| @@ -105,6 +120,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| // Add Op to graph | |||||
| var op = g.create_op(op_type_name, inputs, output_types.ToArray(), | var op = g.create_op(op_type_name, inputs, output_types.ToArray(), | ||||
| name: scope, | name: scope, | ||||
| input_types: input_types.ToArray(), | input_types: input_types.ToArray(), | ||||
| @@ -113,5 +129,10 @@ namespace Tensorflow | |||||
| return op; | return op; | ||||
| } | } | ||||
| public DataType _MakeType(Object v, AttrDef attr_def) | |||||
| { | |||||
| return DataType.DtFloat; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -8,7 +8,7 @@ namespace Tensorflow | |||||
| public class Operation | public class Operation | ||||
| { | { | ||||
| private Graph _graph; | private Graph _graph; | ||||
| private IntPtr _c_op; | |||||
| public IntPtr _c_op; | |||||
| public int _id => _id_value; | public int _id => _id_value; | ||||
| private int _id_value; | private int _id_value; | ||||
| public string name; | public string name; | ||||
| @@ -27,7 +27,7 @@ namespace Tensorflow | |||||
| c_api.TF_FinishOperation(desc, status.Handle); | c_api.TF_FinishOperation(desc, status.Handle); | ||||
| } | } | ||||
| public Operation(NodeDef node_def, Graph g, object inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||||
| public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||||
| { | { | ||||
| _graph = g; | _graph = g; | ||||
| @@ -38,7 +38,7 @@ namespace Tensorflow | |||||
| _outputs = new Tensor[num_outputs]; | _outputs = new Tensor[num_outputs]; | ||||
| for (int i = 0; i < num_outputs; i++) | for (int i = 0; i < num_outputs; i++) | ||||
| { | { | ||||
| _outputs[i] = new Tensor(this, i, TF_DataType.DtDouble); | |||||
| _outputs[i] = new Tensor(this, i, TF_DataType.DtFloat); | |||||
| } | } | ||||
| _graph._add_op(this); | _graph._add_op(this); | ||||
| @@ -6,11 +6,5 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class Session : BaseSession | public class Session : BaseSession | ||||
| { | { | ||||
| public override byte[] run(Tensor fetches) | |||||
| { | |||||
| var ret = base.run(fetches); | |||||
| return ret; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -6,9 +6,12 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class Tensor | public class Tensor | ||||
| { | { | ||||
| private Operation _op; | |||||
| private int _value_index; | |||||
| private readonly Operation _op; | |||||
| public Operation op => _op; | |||||
| private readonly int _value_index; | |||||
| public int value_index => _value_index; | |||||
| private DataType _dtype; | private DataType _dtype; | ||||
| public DataType dtype => _dtype; | |||||
| public Tensor(Operation op, int value_index, DataType dtype) | public Tensor(Operation op, int value_index, DataType dtype) | ||||
| { | { | ||||
| @@ -16,5 +19,10 @@ namespace Tensorflow | |||||
| _value_index = value_index; | _value_index = value_index; | ||||
| _dtype = dtype; | _dtype = dtype; | ||||
| } | } | ||||
| public TF_Output _as_tf_output() | |||||
| { | |||||
| return c_api_util.tf_output(_op._c_op, _value_index); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -19,6 +19,14 @@ namespace Tensorflow | |||||
| { | { | ||||
| public const string TensorFlowLibName = "tensorflow"; | public const string TensorFlowLibName = "tensorflow"; | ||||
| /// <summary> | |||||
| /// For inputs that take a single tensor. | |||||
| /// </summary> | |||||
| /// <param name="desc"></param> | |||||
| /// <param name="input"></param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static unsafe extern void TF_AddInput(TF_OperationDescription desc, TF_Output input); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern void TF_DeleteSessionOptions(TF_SessionOptions opts); | public static unsafe extern void TF_DeleteSessionOptions(TF_SessionOptions opts); | ||||
| @@ -60,11 +68,11 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe void TF_SessionRun(TF_Session session, TF_Buffer* run_options, | public static extern unsafe void TF_SessionRun(TF_Session session, TF_Buffer* run_options, | ||||
| TF_Input[] inputs, TF_Tensor[] input_values, | |||||
| int ninputs, TF_Output[] outputs, | |||||
| TF_Tensor[] output_values, int noutputs, | |||||
| TF_Output[] inputs, TF_Tensor[] input_values, int ninputs, | |||||
| TF_Output[] outputs, TF_Tensor[] output_values, int noutputs, | |||||
| TF_Operation[] target_opers, int ntargets, | TF_Operation[] target_opers, int ntargets, | ||||
| TF_Buffer* run_metadata, TF_Status status); | |||||
| TF_Buffer* run_metadata, | |||||
| TF_Status status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | ||||
| @@ -0,0 +1,18 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class c_api_util | |||||
| { | |||||
| public static TF_Output tf_output(IntPtr c_op, int index) | |||||
| { | |||||
| var ret = new TF_Output(); | |||||
| ret.oper = c_op; | |||||
| ret.index = index; | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -16,21 +16,35 @@ namespace Tensorflow | |||||
| return tf.Graph(); | return tf.Graph(); | ||||
| } | } | ||||
| public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs) | |||||
| public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs) | |||||
| { | { | ||||
| var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name); | var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name); | ||||
| // Add inputs | |||||
| foreach(var op_input in inputs) | |||||
| { | |||||
| c_api.TF_AddInput(op_desc, op_input._as_tf_output()); | |||||
| } | |||||
| var status = new Status(); | var status = new Status(); | ||||
| // Add control inputs | |||||
| // Add attrs | |||||
| foreach (var attr in node_def.Attr) | foreach (var attr in node_def.Attr) | ||||
| { | { | ||||
| var bytes = attr.Value.ToByteArray(); | var bytes = attr.Value.ToByteArray(); | ||||
| var proto = Marshal.AllocHGlobal(bytes.Length); | var proto = Marshal.AllocHGlobal(bytes.Length); | ||||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | Marshal.Copy(bytes, 0, proto, bytes.Length); | ||||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle); | c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle); | ||||
| if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message); | |||||
| } | } | ||||
| var c_op = c_api.TF_FinishOperation(op_desc, status.Handle); | var c_op = c_api.TF_FinishOperation(op_desc, status.Handle); | ||||
| if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message); | |||||
| return c_op; | return c_op; | ||||
| } | } | ||||
| @@ -17,7 +17,9 @@ namespace Tensorflow | |||||
| var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); | var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); | ||||
| return null; | |||||
| var tensor = new Tensor(_op, 0, DataType.DtFloat); | |||||
| return tensor; | |||||
| } | } | ||||
| private static OpDefLibrary _InitOpDefLibrary() | private static OpDefLibrary _InitOpDefLibrary() | ||||
| @@ -28,8 +28,14 @@ namespace TensorFlowNET.UnitTest | |||||
| var b = tf.placeholder(tf.float32); | var b = tf.placeholder(tf.float32); | ||||
| var c = tf.add(a, b); | var c = tf.add(a, b); | ||||
| //sess.run(adder_node, { a: 3, b: 4.5}) | |||||
| //sess.run(adder_node, {a: [1,3], b: [2, 4]}) | |||||
| using(var sess = tf.Session()) | |||||
| { | |||||
| var feed_dict = new Dictionary<Tensor, object>(); | |||||
| feed_dict.Add(a, 3); | |||||
| feed_dict.Add(b, 2); | |||||
| sess.run(c, feed_dict); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||