| @@ -9,21 +9,19 @@ namespace Tensorflow | |||||
| { | { | ||||
| private IntPtr _handle; | private IntPtr _handle; | ||||
| public IntPtr Handle => _handle; | public IntPtr Handle => _handle; | ||||
| //public TF_Buffer buffer => Marshal.PtrToStructure<TF_Buffer>(_handle); | |||||
| public unsafe Buffer() | |||||
| { | |||||
| _handle = Marshal.AllocHGlobal(sizeof(TF_Buffer)); | |||||
| } | |||||
| private TF_Buffer buffer; | |||||
| public byte[] GetBuffer() | |||||
| { | |||||
| var buffer = Marshal.PtrToStructure<TF_Buffer>(_handle); | |||||
| public byte[] Data; | |||||
| var data = Marshal.AllocHGlobal(buffer.length); | |||||
| //var bytes = c_api.TF_GetBuffer(buffer.data); | |||||
| public int Length => (int)buffer.length; | |||||
| return null; | |||||
| public unsafe Buffer(IntPtr handle) | |||||
| { | |||||
| _handle = handle; | |||||
| buffer = Marshal.PtrToStructure<TF_Buffer>(_handle); | |||||
| Data = new byte[buffer.length]; | |||||
| Marshal.Copy(buffer.data, Data, 0, (int)buffer.length); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -9,7 +9,7 @@ namespace Tensorflow | |||||
| public struct TF_Buffer | public struct TF_Buffer | ||||
| { | { | ||||
| public IntPtr data; | public IntPtr data; | ||||
| public int length; | |||||
| public ulong length; | |||||
| public IntPtr data_deallocator; | public IntPtr data_deallocator; | ||||
| } | } | ||||
| } | } | ||||
| @@ -8,6 +8,6 @@ namespace Tensorflow | |||||
| public static partial class c_api | public static partial class c_api | ||||
| { | { | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern string TF_GetBuffer(IntPtr buffer); | |||||
| public static extern IntPtr TF_GetBuffer(TF_Buffer buffer); | |||||
| } | } | ||||
| } | } | ||||
| @@ -15,8 +15,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public class Graph | public class Graph | ||||
| { | { | ||||
| private IntPtr _c_graph; | |||||
| public IntPtr Handle => _c_graph; | |||||
| private IntPtr _handle; | |||||
| private Dictionary<int, Operation> _nodes_by_id; | private Dictionary<int, Operation> _nodes_by_id; | ||||
| private Dictionary<string, Operation> _nodes_by_name; | private Dictionary<string, Operation> _nodes_by_name; | ||||
| private Dictionary<string, int> _names_in_use; | private Dictionary<string, int> _names_in_use; | ||||
| @@ -28,7 +27,7 @@ namespace Tensorflow | |||||
| public Graph(IntPtr graph) | public Graph(IntPtr graph) | ||||
| { | { | ||||
| this._c_graph = graph; | |||||
| _handle = graph; | |||||
| _nodes_by_id = new Dictionary<int, Operation>(); | _nodes_by_id = new Dictionary<int, Operation>(); | ||||
| _nodes_by_name = new Dictionary<string, Operation>(); | _nodes_by_name = new Dictionary<string, Operation>(); | ||||
| _names_in_use = new Dictionary<string, int>(); | _names_in_use = new Dictionary<string, int>(); | ||||
| @@ -171,5 +170,10 @@ namespace Tensorflow | |||||
| { | { | ||||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | return _nodes_by_name.Values.Select(x => x).ToArray(); | ||||
| } | } | ||||
| public static implicit operator IntPtr(Graph graph) | |||||
| { | |||||
| return graph._handle; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -10,6 +10,39 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | ||||
| /// <summary> | |||||
| /// Returns the shape of the Tensor referenced by `output` in `graph` | |||||
| /// into `dims`. `dims` must be an array large enough to hold `num_dims` | |||||
| /// entries (e.g., the return value of TF_GraphGetTensorNumDims). | |||||
| /// </summary> | |||||
| /// <param name="graph"></param> | |||||
| /// <param name="output"></param> | |||||
| /// <param name="dims"></param> | |||||
| /// <param name="num_dims"></param> | |||||
| /// <param name="status"></param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status); | |||||
| /// <summary> | |||||
| /// Sets the shape of the Tensor referenced by `output` in `graph` to | |||||
| /// the shape described by `dims` and `num_dims`. | |||||
| /// </summary> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status); | |||||
| /// <summary> | |||||
| /// Returns the number of dimensions of the Tensor referenced by `output` | |||||
| /// in `graph`. | |||||
| /// | |||||
| /// If the number of dimensions in the shape is unknown, returns -1. | |||||
| /// </summary> | |||||
| /// <param name="graph"></param> | |||||
| /// <param name="output"></param> | |||||
| /// <param name="status"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern IntPtr TF_NewGraph(); | public static unsafe extern IntPtr TF_NewGraph(); | ||||
| } | } | ||||
| @@ -28,9 +28,6 @@ namespace Tensorflow | |||||
| { | { | ||||
| var op_def = _ops[op_type_name]; | var op_def = _ops[op_type_name]; | ||||
| var status = new Status(); | |||||
| var buffer = new Buffer(); | |||||
| var g = ops.get_default_graph(); | var g = ops.get_default_graph(); | ||||
| if (String.IsNullOrEmpty(name)) | if (String.IsNullOrEmpty(name)) | ||||
| @@ -1,12 +1,13 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using TF_DataType = Tensorflow.DataType; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class Operation | public class Operation | ||||
| { | { | ||||
| public IntPtr Handle { get; } | |||||
| private Graph _graph; | private Graph _graph; | ||||
| public Graph graph => _graph; | public Graph graph => _graph; | ||||
| public IntPtr _c_op; | public IntPtr _c_op; | ||||
| @@ -17,15 +18,20 @@ namespace Tensorflow | |||||
| public Tensor[] outputs => _outputs; | public Tensor[] outputs => _outputs; | ||||
| public Tensor[] inputs; | public Tensor[] inputs; | ||||
| public Operation(IntPtr handle) | |||||
| { | |||||
| Handle = handle; | |||||
| } | |||||
| public Operation(Graph g, string opType, string oper_name) | public Operation(Graph g, string opType, string oper_name) | ||||
| { | { | ||||
| _graph = g; | _graph = g; | ||||
| var status = new Status(); | var status = new Status(); | ||||
| var desc = c_api.TF_NewOperation(g.Handle, opType, oper_name); | |||||
| var desc = c_api.TF_NewOperation(g, opType, oper_name); | |||||
| c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32); | c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32); | ||||
| c_api.TF_FinishOperation(desc, status.Handle); | |||||
| c_api.TF_FinishOperation(desc, status); | |||||
| } | } | ||||
| public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | ||||
| @@ -7,30 +7,37 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static partial class c_api | public static partial class c_api | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Get the OpList of all OpDefs defined in this address space. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static unsafe extern IntPtr TF_GetAllOpList(); | |||||
| /// <summary> | /// <summary> | ||||
| /// For inputs that take a single tensor. | /// For inputs that take a single tensor. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="desc"></param> | /// <param name="desc"></param> | ||||
| /// <param name="input"></param> | /// <param name="input"></param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern void TF_AddInput(TF_OperationDescription desc, TF_Output input); | |||||
| public static unsafe extern void TF_AddInput(IntPtr desc, TF_Output input); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern IntPtr TF_FinishOperation(TF_OperationDescription desc, IntPtr status); | |||||
| public static unsafe extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern TF_OperationDescription TF_NewOperation(IntPtr graph, string opType, string oper_name); | |||||
| public static unsafe extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe int TF_OperationNumOutputs(IntPtr oper); | public static extern unsafe int TF_OperationNumOutputs(IntPtr oper); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe void TF_SetAttrValueProto(TF_OperationDescription desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status); | |||||
| public static extern unsafe void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, IntPtr value, IntPtr status); | |||||
| public static extern unsafe void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); | |||||
| public static extern unsafe void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value); | |||||
| } | } | ||||
| } | } | ||||
| @@ -24,7 +24,7 @@ namespace Tensorflow | |||||
| public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs) | public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs) | ||||
| { | { | ||||
| var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name); | |||||
| var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name); | |||||
| // Add inputs | // Add inputs | ||||
| if(inputs != null) | if(inputs != null) | ||||
| @@ -45,12 +45,12 @@ namespace Tensorflow | |||||
| var bytes = attr.Value.ToByteArray(); | var bytes = attr.Value.ToByteArray(); | ||||
| var proto = Marshal.AllocHGlobal(bytes.Length); | var proto = Marshal.AllocHGlobal(bytes.Length); | ||||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | Marshal.Copy(bytes, 0, proto, bytes.Length); | ||||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle); | |||||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status); | |||||
| if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message); | if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message); | ||||
| } | } | ||||
| var c_op = c_api.TF_FinishOperation(op_desc, status.Handle); | |||||
| var c_op = c_api.TF_FinishOperation(op_desc, status); | |||||
| if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message); | if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message); | ||||
| @@ -30,7 +30,7 @@ namespace Tensorflow | |||||
| _target = UTF8Encoding.UTF8.GetBytes(target); | _target = UTF8Encoding.UTF8.GetBytes(target); | ||||
| var opts = c_api.TF_NewSessionOptions(); | var opts = c_api.TF_NewSessionOptions(); | ||||
| var status = new Status(); | var status = new Status(); | ||||
| _session = c_api.TF_NewSession(_graph.Handle, opts, status.Handle); | |||||
| _session = c_api.TF_NewSession(_graph, opts, status); | |||||
| c_api.TF_DeleteSessionOptions(opts); | c_api.TF_DeleteSessionOptions(opts); | ||||
| } | } | ||||
| @@ -40,30 +40,30 @@ namespace Tensorflow | |||||
| } | } | ||||
| public virtual object run(Tensor fetches, FeedDict feed_dict = null) | |||||
| public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
| { | { | ||||
| var result = _run(fetches, feed_dict); | var result = _run(fetches, feed_dict); | ||||
| return result; | return result; | ||||
| } | } | ||||
| private unsafe object _run(Tensor fetches, FeedDict feed_dict = null) | |||||
| private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null) | |||||
| { | { | ||||
| var feed_dict_tensor = new FeedDict(); | |||||
| var feed_dict_tensor = new Dictionary<Tensor, object>(); | |||||
| if (feed_dict != null) | if (feed_dict != null) | ||||
| { | { | ||||
| NDArray np_val = null; | NDArray np_val = null; | ||||
| foreach (FeedValue feed in feed_dict) | |||||
| foreach (var feed in feed_dict) | |||||
| { | { | ||||
| switch (feed.feed_val) | |||||
| switch (feed.Value) | |||||
| { | { | ||||
| case float value: | case float value: | ||||
| np_val = np.asarray(value); | np_val = np.asarray(value); | ||||
| break; | break; | ||||
| } | } | ||||
| feed_dict_tensor[feed.feed] = np_val; | |||||
| feed_dict_tensor[feed.Key] = np_val; | |||||
| } | } | ||||
| } | } | ||||
| @@ -85,9 +85,9 @@ namespace Tensorflow | |||||
| return fetch_handler.build_results(null, results); | return fetch_handler.build_results(null, results); | ||||
| } | } | ||||
| private object[] _do_run(List<Tensor> fetch_list, FeedDict feed_dict) | |||||
| private object[] _do_run(List<Tensor> fetch_list, Dictionary<Tensor, object> feed_dict) | |||||
| { | { | ||||
| var feeds = feed_dict.items().Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray(); | |||||
| var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray(); | |||||
| var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | ||||
| return _call_tf_sessionrun(feeds, fetches); | return _call_tf_sessionrun(feeds, fetches); | ||||
| @@ -113,7 +113,7 @@ namespace Tensorflow | |||||
| target_opers: new IntPtr[] { }, | target_opers: new IntPtr[] { }, | ||||
| ntargets: 0, | ntargets: 0, | ||||
| run_metadata: IntPtr.Zero, | run_metadata: IntPtr.Zero, | ||||
| status: status.Handle); | |||||
| status: status); | |||||
| var result = output_values.Select(x => c_api.TF_TensorData(x)) | var result = output_values.Select(x => c_api.TF_TensorData(x)) | ||||
| .Select(x => (object)*(float*)x) | .Select(x => (object)*(float*)x) | ||||
| @@ -1,59 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class FeedDict : IEnumerable | |||||
| { | |||||
| private Dictionary<Tensor, object> feed_dict; | |||||
| public FeedDict() | |||||
| { | |||||
| feed_dict = new Dictionary<Tensor, object>(); | |||||
| } | |||||
| public object this[Tensor feed] | |||||
| { | |||||
| get | |||||
| { | |||||
| return feed_dict[feed]; | |||||
| } | |||||
| set | |||||
| { | |||||
| feed_dict[feed] = value; | |||||
| } | |||||
| } | |||||
| public FeedDict Add(Tensor feed, object value) | |||||
| { | |||||
| feed_dict.Add(feed, value); | |||||
| return this; | |||||
| } | |||||
| public IEnumerator GetEnumerator() | |||||
| { | |||||
| foreach (KeyValuePair<Tensor, object> feed in feed_dict) | |||||
| { | |||||
| yield return new FeedValue | |||||
| { | |||||
| feed = feed.Key, | |||||
| feed_val = feed.Value | |||||
| }; | |||||
| } | |||||
| } | |||||
| public Dictionary<Tensor, object> items() | |||||
| { | |||||
| return feed_dict; | |||||
| } | |||||
| } | |||||
| public struct FeedValue | |||||
| { | |||||
| public Tensor feed { get; set; } | |||||
| public object feed_val { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -15,7 +15,7 @@ namespace Tensorflow | |||||
| private List<Tensor> _final_fetches = new List<Tensor>(); | private List<Tensor> _final_fetches = new List<Tensor>(); | ||||
| private List<object> _targets = new List<object>(); | private List<object> _targets = new List<object>(); | ||||
| public _FetchHandler(Graph graph, Tensor fetches, FeedDict feeds = null, object feed_handles = null) | |||||
| public _FetchHandler(Graph graph, Tensor fetches, Dictionary<Tensor, object> feeds = null, object feed_handles = null) | |||||
| { | { | ||||
| _fetch_mapper = new _FetchMapper().for_fetch(fetches); | _fetch_mapper = new _FetchMapper().for_fetch(fetches); | ||||
| foreach(var fetch in _fetch_mapper.unique_fetches()) | foreach(var fetch in _fetch_mapper.unique_fetches()) | ||||
| @@ -4,10 +4,13 @@ using System.Text; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class Status : IDisposable | |||||
| /// <summary> | |||||
| /// TF_Status holds error information. It either has an OK code, or | |||||
| /// else an error code with an associated error message. | |||||
| /// </summary> | |||||
| public class Status | |||||
| { | { | ||||
| private readonly IntPtr _handle; | private readonly IntPtr _handle; | ||||
| public IntPtr Handle => _handle; | |||||
| /// <summary> | /// <summary> | ||||
| /// Error message | /// Error message | ||||
| @@ -29,6 +32,23 @@ namespace Tensorflow | |||||
| c_api.TF_SetStatus(_handle, code, msg); | c_api.TF_SetStatus(_handle, code, msg); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Check status | |||||
| /// Throw exception with error message if code != TF_OK | |||||
| /// </summary> | |||||
| public void Check() | |||||
| { | |||||
| if(Code != TF_Code.TF_OK) | |||||
| { | |||||
| throw new Exception(Message); | |||||
| } | |||||
| } | |||||
| public static implicit operator IntPtr(Status status) | |||||
| { | |||||
| return status._handle; | |||||
| } | |||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| c_api.TF_DeleteStatus(_handle); | c_api.TF_DeleteStatus(_handle); | ||||
| @@ -13,6 +13,8 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public class Tensor | public class Tensor | ||||
| { | { | ||||
| public IntPtr Handle { get; } | |||||
| public Graph graph => op.graph; | public Graph graph => op.graph; | ||||
| public Operation op { get; } | public Operation op { get; } | ||||
| @@ -21,7 +23,6 @@ namespace Tensorflow | |||||
| public int value_index { get; } | public int value_index { get; } | ||||
| public TF_DataType dtype { get; } | public TF_DataType dtype { get; } | ||||
| public IntPtr handle { get; } | |||||
| public ulong bytesize { get; } | public ulong bytesize { get; } | ||||
| public ulong dataTypeSize { get;} | public ulong dataTypeSize { get;} | ||||
| public ulong size => bytesize / dataTypeSize; | public ulong size => bytesize / dataTypeSize; | ||||
| @@ -45,7 +46,7 @@ namespace Tensorflow | |||||
| public Tensor(IntPtr handle) | public Tensor(IntPtr handle) | ||||
| { | { | ||||
| this.handle = handle; | |||||
| Handle = handle; | |||||
| dtype = c_api.TF_TensorType(handle); | dtype = c_api.TF_TensorType(handle); | ||||
| rank = c_api.TF_NumDims(handle); | rank = c_api.TF_NumDims(handle); | ||||
| bytesize = c_api.TF_TensorByteSize(handle); | bytesize = c_api.TF_TensorByteSize(handle); | ||||
| @@ -59,33 +60,52 @@ namespace Tensorflow | |||||
| public Tensor(NDArray nd) | public Tensor(NDArray nd) | ||||
| { | { | ||||
| var data = Marshal.AllocHGlobal(sizeof(float) * nd.size); | |||||
| Marshal.Copy(nd.Data<float>(), 0, data, nd.size); | |||||
| var dataType = ToTFDataType(nd.dtype); | |||||
| Handle = Allocate(nd); | |||||
| dtype = c_api.TF_TensorType(Handle); | |||||
| rank = c_api.TF_NumDims(Handle); | |||||
| bytesize = c_api.TF_TensorByteSize(Handle); | |||||
| buffer = c_api.TF_TensorData(Handle); | |||||
| dataTypeSize = c_api.TF_DataTypeSize(dtype); | |||||
| shape = new long[rank]; | |||||
| for (int i = 0; i < rank; i++) | |||||
| shape[i] = c_api.TF_Dim(Handle, i); | |||||
| } | |||||
| private IntPtr Allocate(NDArray nd) | |||||
| { | |||||
| var dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); | |||||
| switch (nd.dtype.Name) | |||||
| { | |||||
| case "Int32": | |||||
| Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "Single": | |||||
| Marshal.Copy(nd.Data<float>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| case "Double": | |||||
| Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException("Marshal.Copy failed."); | |||||
| } | |||||
| var handle = c_api.TF_NewTensor(dataType, | |||||
| var dataType = ToTFDataType(nd.dtype); | |||||
| var tfHandle = c_api.TF_NewTensor(dataType, | |||||
| nd.shape.Select(x => (long)x).ToArray(), // shape | nd.shape.Select(x => (long)x).ToArray(), // shape | ||||
| nd.ndim, | nd.ndim, | ||||
| data, | |||||
| (UIntPtr)(nd.size * sizeof(float)), | |||||
| dotHandle, | |||||
| (UIntPtr)(nd.size * nd.dtypesize), | |||||
| (IntPtr values, IntPtr len, ref bool closure) => | (IntPtr values, IntPtr len, ref bool closure) => | ||||
| { | { | ||||
| // Free the original buffer and set flag | // Free the original buffer and set flag | ||||
| Marshal.FreeHGlobal(data); | |||||
| Marshal.FreeHGlobal(dotHandle); | |||||
| closure = true; | closure = true; | ||||
| }, | }, | ||||
| ref deallocator_called); | ref deallocator_called); | ||||
| this.handle = handle; | |||||
| dtype = c_api.TF_TensorType(handle); | |||||
| rank = c_api.TF_NumDims(handle); | |||||
| bytesize = c_api.TF_TensorByteSize(handle); | |||||
| buffer = c_api.TF_TensorData(handle); | |||||
| dataTypeSize = c_api.TF_DataTypeSize(dtype); | |||||
| shape = new long[rank]; | |||||
| for (int i = 0; i < rank; i++) | |||||
| shape[i] = c_api.TF_Dim(handle, i); | |||||
| return tfHandle; | |||||
| } | } | ||||
| public Tensor(Operation op, int value_index, TF_DataType dtype) | public Tensor(Operation op, int value_index, TF_DataType dtype) | ||||
| @@ -129,11 +149,20 @@ namespace Tensorflow | |||||
| { | { | ||||
| switch (type.Name) | switch (type.Name) | ||||
| { | { | ||||
| case "Int32": | |||||
| return TF_DataType.TF_INT32; | |||||
| case "Single": | case "Single": | ||||
| return TF_DataType.TF_FLOAT; | return TF_DataType.TF_FLOAT; | ||||
| case "Double": | |||||
| return TF_DataType.TF_DOUBLE; | |||||
| } | } | ||||
| return TF_DataType.DtInvalid; | return TF_DataType.DtInvalid; | ||||
| } | } | ||||
| public static implicit operator IntPtr(Tensor tensor) | |||||
| { | |||||
| return tensor.Handle; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -10,12 +10,22 @@ namespace Tensorflow | |||||
| /// | /// | ||||
| /// The API leans towards simplicity and uniformity instead of convenience | /// The API leans towards simplicity and uniformity instead of convenience | ||||
| /// since most usage will be by language specific wrappers. | /// since most usage will be by language specific wrappers. | ||||
| /// | |||||
| /// The params type mapping between .net and c_api | |||||
| /// TF_XX** => ref IntPtr (TF_Operation** op) => (ref IntPtr op) | |||||
| /// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) | |||||
| /// struct => struct (TF_Output output) => (TF_Output output) | |||||
| /// const char* => string | |||||
| /// int32_t => int | |||||
| /// int64_t* => long[] | |||||
| /// size_t* => unlong[] | |||||
| /// void* => IntPtr | |||||
| /// </summary> | /// </summary> | ||||
| public static partial class c_api | public static partial class c_api | ||||
| { | { | ||||
| public const string TensorFlowLibName = "tensorflow"; | public const string TensorFlowLibName = "tensorflow"; | ||||
| public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocatorData); | |||||
| public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocator); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern IntPtr TF_Version(); | public static unsafe extern IntPtr TF_Version(); | ||||
| @@ -3,12 +3,21 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Buffer = Tensorflow.Buffer; | |||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| { | { | ||||
| [TestClass] | [TestClass] | ||||
| public class OperationsTest | public class OperationsTest | ||||
| { | { | ||||
| [TestMethod] | |||||
| public void GetAllOpList() | |||||
| { | |||||
| var handle = c_api.TF_GetAllOpList(); | |||||
| var buffer = new Buffer(handle); | |||||
| Assert.IsTrue(buffer.Length == buffer.Data.Length); | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void addInPlaceholder() | public void addInPlaceholder() | ||||
| { | { | ||||
| @@ -18,9 +27,9 @@ namespace TensorFlowNET.UnitTest | |||||
| using(var sess = tf.Session()) | using(var sess = tf.Session()) | ||||
| { | { | ||||
| var feed_dict = new FeedDict() | |||||
| .Add(a, 3.0f) | |||||
| .Add(b, 2.0f); | |||||
| var feed_dict = new Dictionary<Tensor, object>(); | |||||
| feed_dict.Add(a, 3.0f); | |||||
| feed_dict.Add(b, 2.0f); | |||||
| var o = sess.run(c, feed_dict); | var o = sess.run(c, feed_dict); | ||||
| } | } | ||||
| @@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest | |||||
| public class TensorTest | public class TensorTest | ||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public unsafe void NewTensor() | |||||
| public void NewTensor() | |||||
| { | { | ||||
| var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | ||||
| @@ -27,5 +27,69 @@ namespace TensorFlowNET.UnitTest | |||||
| Assert.AreEqual(tensor.bytesize, (uint)nd.size * sizeof(float)); | Assert.AreEqual(tensor.bytesize, (uint)nd.size * sizeof(float)); | ||||
| Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array)); | Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array)); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Port from tensorflow\c\c_api_test.cc | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void SetShape() | |||||
| { | |||||
| var s = new Status(); | |||||
| var graph = tf.get_default_graph(); | |||||
| var desc = c_api.TF_NewOperation(graph, "Placeholder", ""); | |||||
| c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_FLOAT); | |||||
| //if (!dims.empty()) | |||||
| { | |||||
| //TF_SetAttrShape(desc, "shape", dims.data(), dims.size()); | |||||
| } | |||||
| var op = c_api.TF_FinishOperation(desc, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| Assert.IsNotNull(op); | |||||
| // Fetch the shape, it should be completely unknown. | |||||
| var feed_out_0 = new TF_Output { oper = op, index = 0 }; | |||||
| int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| Assert.AreEqual(-1, num_dims); | |||||
| // Set the shape to be unknown, expect no change. | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, new int[0], -1, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||||
| Assert.AreEqual(-1, num_dims); | |||||
| // Set the shape to be 2 x Unknown | |||||
| var dims = new int[] { 2, -1 }; | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||||
| Assert.AreEqual(2, num_dims); | |||||
| // Get the dimension vector appropriately. | |||||
| var returned_dims = new int[dims.Length]; | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||||
| // Set to a new valid shape: [2, 3] | |||||
| dims[1] = 3; | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | |||||
| //Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| // Fetch and see that the new value is returned. | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||||
| //Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| //Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||||
| // Test for a scalar. | |||||
| var three = c_test_util.ScalarConst(3, graph, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| var three_out_0 = new TF_Output { oper = three.Handle }; | |||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | |||||
| Assert.AreEqual(0, num_dims); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,37 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | |||||
| using Tensorflow; | |||||
| namespace TensorFlowNET.UnitTest | |||||
| { | |||||
| public static class c_test_util | |||||
| { | |||||
| public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref IntPtr op) | |||||
| { | |||||
| var desc = c_api.TF_NewOperation(graph, "Const", name); | |||||
| c_api.TF_SetAttrTensor(desc, "value", t.Handle, s); | |||||
| s.Check(); | |||||
| c_api.TF_SetAttrType(desc, "dtype", t.dtype); | |||||
| op = c_api.TF_FinishOperation(desc, s); | |||||
| s.Check(); | |||||
| if(op == null) | |||||
| { | |||||
| throw new Exception("c_api.TF_FinishOperation failed."); | |||||
| } | |||||
| } | |||||
| public static Operation Const(Tensor t, Graph graph, Status s, string name) | |||||
| { | |||||
| IntPtr op = IntPtr.Zero; | |||||
| ConstHelper(t, graph, s, name, ref op); | |||||
| return new Operation(op); | |||||
| } | |||||
| public static Operation ScalarConst(int v, Graph graph, Status s, string name = "Const") | |||||
| { | |||||
| return Const(new Tensor(v), graph, s, name); | |||||
| } | |||||
| } | |||||
| } | |||||