| @@ -4,6 +4,8 @@ TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tenso | |||||
| TensorFlow.NET is a member project of SciSharp stack. | TensorFlow.NET is a member project of SciSharp stack. | ||||
|  | |||||
| ### How to use | ### How to use | ||||
| ```cs | ```cs | ||||
| using tf = TensorFlowNET.Core.Tensorflow; | using tf = TensorFlowNET.Core.Tensorflow; | ||||
| @@ -14,7 +16,7 @@ namespace TensorFlowNET.Examples | |||||
| { | { | ||||
| public void Run() | public void Run() | ||||
| { | { | ||||
| var hello = tf.constant("Hello, TensorFlow!"); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -14,10 +14,16 @@ namespace TensorFlowNET.Core | |||||
| public class Graph | public class Graph | ||||
| { | { | ||||
| public IntPtr handle; | public IntPtr handle; | ||||
| private Dictionary<int, Operation> _nodes_by_id; | |||||
| private Dictionary<string, Operation> _nodes_by_name; | |||||
| public int _version; | |||||
| private int _next_id_counter; | |||||
| public Graph(IntPtr graph) | public Graph(IntPtr graph) | ||||
| { | { | ||||
| this.handle = graph; | this.handle = graph; | ||||
| _nodes_by_id = new Dictionary<int, Operation>(); | |||||
| _nodes_by_name = new Dictionary<string, Operation>(); | |||||
| } | } | ||||
| public unsafe Operation create_op(object inputs, string op_type = "", string name = "") | public unsafe Operation create_op(object inputs, string op_type = "", string name = "") | ||||
| @@ -28,8 +34,26 @@ namespace TensorFlowNET.Core | |||||
| } | } | ||||
| var op = new Operation(this, inputs); | var op = new Operation(this, inputs); | ||||
| op.name = name; | |||||
| return op; | return op; | ||||
| } | } | ||||
| public void _add_op(Operation op) | |||||
| { | |||||
| _nodes_by_id[op._id] = op; | |||||
| //_nodes_by_name[op.name] = op; | |||||
| _version = Math.Max(_version, op._id); | |||||
| } | |||||
| public int _next_id() | |||||
| { | |||||
| return ++_next_id_counter; | |||||
| } | |||||
| public void get_operations() | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -8,12 +8,17 @@ namespace TensorFlowNET.Core | |||||
| { | { | ||||
| private Graph _graph; | private Graph _graph; | ||||
| private IntPtr _c_op; | private IntPtr _c_op; | ||||
| public int _id => _id_value; | |||||
| private int _id_value; | |||||
| public string name; | |||||
| public Operation(Graph g, object inputs) | public Operation(Graph g, object inputs) | ||||
| { | { | ||||
| _graph = g; | _graph = g; | ||||
| _id_value = _graph._next_id(); | |||||
| _c_op = ops._create_c_op(g, inputs); | _c_op = ops._create_c_op(g, inputs); | ||||
| _graph._add_op(this); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -31,7 +31,7 @@ namespace TensorFlowNET.Core | |||||
| public static unsafe extern TF_Status TF_NewStatus(); | public static unsafe extern TF_Status TF_NewStatus(); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe TF_Tensor TF_NewTensor(TF_DataType dataType, IntPtr zeroDims, int num_dims, IntPtr data, size_t len, Deallocator deallocator, IntPtr deallocator_arg); | |||||
| public static extern unsafe TF_Tensor TF_NewTensor(TF_DataType dataType, Int64 dims, int num_dims, IntPtr data, size_t len, Deallocator deallocator, IntPtr deallocator_arg); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, TF_Tensor value, TF_Status status); | public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, TF_Tensor value, TF_Status status); | ||||
| @@ -26,7 +26,7 @@ namespace TensorFlowNET.Core | |||||
| case double value: | case double value: | ||||
| var v = (double*)Marshal.AllocHGlobal(sizeof(double)); | var v = (double*)Marshal.AllocHGlobal(sizeof(double)); | ||||
| *v = value; | *v = value; | ||||
| tensor = c_api.TF_NewTensor(TF_DataType.TF_DOUBLE, IntPtr.Zero, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); | |||||
| tensor = c_api.TF_NewTensor(TF_DataType.TF_DOUBLE, 0, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); | |||||
| c_api.TF_SetAttrType(op_desc, "dtype", TF_DataType.TF_DOUBLE); | c_api.TF_SetAttrType(op_desc, "dtype", TF_DataType.TF_DOUBLE); | ||||
| break; | break; | ||||
| } | } | ||||