| @@ -0,0 +1,40 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace TensorFlowNET.Core | |||||
| { | |||||
| public class BaseSession | |||||
| { | |||||
| private Graph _graph; | |||||
| private bool _opened; | |||||
| private bool _closed; | |||||
| private int _current_version; | |||||
| private byte[] _target; | |||||
| private IntPtr _session; | |||||
| public BaseSession(string target = "", Graph graph = null) | |||||
| { | |||||
| if(graph is null) | |||||
| { | |||||
| _graph = ops.get_default_graph(); | |||||
| } | |||||
| else | |||||
| { | |||||
| _graph = graph; | |||||
| } | |||||
| _target = UTF8Encoding.UTF8.GetBytes(target); | |||||
| var opts = c_api.TF_NewSessionOptions(); | |||||
| var status = new Status(); | |||||
| _session = c_api.TF_NewSession(_graph.Handle, opts, status.Handle); | |||||
| c_api.TF_DeleteSessionOptions(opts); | |||||
| } | |||||
| public virtual byte[] run(Tensor fetches) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -16,7 +16,8 @@ namespace TensorFlowNET.Core | |||||
| /// </summary> | /// </summary> | ||||
| public class Graph | public class Graph | ||||
| { | { | ||||
| public IntPtr handle; | |||||
| private IntPtr _c_graph; | |||||
| public IntPtr Handle => _c_graph; | |||||
| private Dictionary<int, Operation> _nodes_by_id; | private Dictionary<int, Operation> _nodes_by_id; | ||||
| private Dictionary<string, Operation> _nodes_by_name; | private Dictionary<string, Operation> _nodes_by_name; | ||||
| private Dictionary<string, int> _names_in_use; | private Dictionary<string, int> _names_in_use; | ||||
| @@ -25,7 +26,7 @@ namespace TensorFlowNET.Core | |||||
| public Graph(IntPtr graph) | public Graph(IntPtr graph) | ||||
| { | { | ||||
| this.handle = graph; | |||||
| this._c_graph = graph; | |||||
| _nodes_by_id = new Dictionary<int, Operation>(); | _nodes_by_id = new Dictionary<int, Operation>(); | ||||
| _nodes_by_name = new Dictionary<string, Operation>(); | _nodes_by_name = new Dictionary<string, Operation>(); | ||||
| _names_in_use = new Dictionary<string, int>(); | _names_in_use = new Dictionary<string, int>(); | ||||
| @@ -4,7 +4,13 @@ using System.Text; | |||||
| namespace TensorFlowNET.Core | namespace TensorFlowNET.Core | ||||
| { | { | ||||
| public class Session | |||||
| public class Session : BaseSession | |||||
| { | { | ||||
| public override byte[] run(Tensor fetches) | |||||
| { | |||||
| var ret = base.run(fetches); | |||||
| return ret; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -9,6 +9,8 @@ using TF_OperationDescription = System.IntPtr; | |||||
| using TF_Operation = System.IntPtr; | using TF_Operation = System.IntPtr; | ||||
| using TF_Status = System.IntPtr; | using TF_Status = System.IntPtr; | ||||
| using TF_Tensor = System.IntPtr; | using TF_Tensor = System.IntPtr; | ||||
| using TF_Session = System.IntPtr; | |||||
| using TF_SessionOptions = System.IntPtr; | |||||
| using TF_DataType = Tensorflow.DataType; | using TF_DataType = Tensorflow.DataType; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| @@ -20,6 +22,9 @@ namespace TensorFlowNET.Core | |||||
| { | { | ||||
| public const string TensorFlowLibName = "tensorflow"; | public const string TensorFlowLibName = "tensorflow"; | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static unsafe extern void TF_DeleteSessionOptions(TF_SessionOptions opts); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status); | public static unsafe extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status); | ||||
| @@ -53,6 +58,12 @@ namespace TensorFlowNET.Core | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern TF_Session TF_NewSession(TF_Graph graph, TF_SessionOptions opts, TF_Status status); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern TF_SessionOptions TF_NewSessionOptions(); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern IntPtr TF_Version(); | public static unsafe extern IntPtr TF_Version(); | ||||
| } | } | ||||
| @@ -20,7 +20,7 @@ namespace TensorFlowNET.Core | |||||
| public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs) | public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs) | ||||
| { | { | ||||
| var op_desc = c_api.TF_NewOperation(graph.handle, node_def.Op, node_def.Name); | |||||
| var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name); | |||||
| var status = new Status(); | var status = new Status(); | ||||
| foreach (var attr in node_def.Attr) | foreach (var attr in node_def.Attr) | ||||
| @@ -22,6 +22,9 @@ namespace TensorFlowNET.Examples | |||||
| // Start tf session | // Start tf session | ||||
| var sess = tf.Session(); | var sess = tf.Session(); | ||||
| // Run the op | |||||
| sess.run(hello); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||