- Ignored all unit tests related to CSession as it does not use TF.NET's api directly and unable to be tested with other tests parallely.tags/v0.12
| @@ -22,58 +22,54 @@ using System.Linq; | |||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | |||||
| /// <summary> | |||||
| /// Represents a graph node that performs computation on tensors. | |||||
| /// | |||||
| /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or | |||||
| /// more `Tensor` objects as input, and produces zero or more `Tensor` | |||||
| /// objects as output. Objects of type `Operation` are created by | |||||
| /// calling an op constructor(such as `tf.matmul`) | |||||
| /// or `tf.Graph.create_op`. | |||||
| /// | |||||
| /// For example `c = tf.matmul(a, b)` creates an `Operation` of type | |||||
| /// "MatMul" that takes tensors `a` and `b` as input, and produces `c` | |||||
| /// as output. | |||||
| /// | |||||
| /// After the graph has been launched in a session, an `Operation` can | |||||
| /// be executed by passing it to | |||||
| /// `tf.Session.run`. | |||||
| /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. | |||||
| { | |||||
| /// <summary> | |||||
| /// Represents a graph node that performs computation on tensors. | |||||
| /// | |||||
| /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or | |||||
| /// more `Tensor` objects as input, and produces zero or more `Tensor` | |||||
| /// objects as output. Objects of type `Operation` are created by | |||||
| /// calling an op constructor(such as `tf.matmul`) | |||||
| /// or `tf.Graph.create_op`. | |||||
| /// | |||||
| /// For example `c = tf.matmul(a, b)` creates an `Operation` of type | |||||
| /// "MatMul" that takes tensors `a` and `b` as input, and produces `c` | |||||
| /// as output. | |||||
| /// | |||||
| /// After the graph has been launched in a session, an `Operation` can | |||||
| /// be executed by passing it to | |||||
| /// `tf.Session.run`. | |||||
| /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. | |||||
| /// </summary> | /// </summary> | ||||
| public partial class Operation : ITensorOrOperation | public partial class Operation : ITensorOrOperation | ||||
| { | { | ||||
| private readonly IntPtr _handle; // _c_op in python | private readonly IntPtr _handle; // _c_op in python | ||||
| private readonly IntPtr _operDesc; | |||||
| private readonly IntPtr _operDesc; | |||||
| private readonly Graph _graph; | |||||
| private NodeDef _node_def; | |||||
| private Graph _graph; | |||||
| public string type => OpType; | public string type => OpType; | ||||
| public Graph graph => _graph; | public Graph graph => _graph; | ||||
| public int _id => _id_value; | public int _id => _id_value; | ||||
| public int _id_value; | public int _id_value; | ||||
| public Operation op => this; | public Operation op => this; | ||||
| public TF_DataType dtype => TF_DataType.DtInvalid; | public TF_DataType dtype => TF_DataType.DtInvalid; | ||||
| public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); | public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); | ||||
| public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | ||||
| public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | ||||
| private NodeDef _node_def; | |||||
| public NodeDef node_def | public NodeDef node_def | ||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| if(_node_def == null) | |||||
| if (_node_def == null) | |||||
| _node_def = GetNodeDef(); | _node_def = GetNodeDef(); | ||||
| return _node_def; | return _node_def; | ||||
| } | } | ||||
| } | } | ||||
| public Operation(IntPtr handle, Graph g=null) | |||||
| public Operation(IntPtr handle, Graph g = null) | |||||
| { | { | ||||
| if (handle == IntPtr.Zero) | if (handle == IntPtr.Zero) | ||||
| return; | return; | ||||
| @@ -97,14 +93,15 @@ namespace Tensorflow | |||||
| _operDesc = c_api.TF_NewOperation(g, opType, oper_name); | _operDesc = c_api.TF_NewOperation(g, opType, oper_name); | ||||
| c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | ||||
| using (var status = new Status()) | |||||
| { | |||||
| _handle = c_api.TF_FinishOperation(_operDesc, status); | |||||
| status.Check(true); | |||||
| } | |||||
| // Dict mapping op name to file and line information for op colocation | |||||
| // context managers. | |||||
| lock (Locks.ProcessWide) | |||||
| using (var status = new Status()) | |||||
| { | |||||
| _handle = c_api.TF_FinishOperation(_operDesc, status); | |||||
| status.Check(true); | |||||
| } | |||||
| // Dict mapping op name to file and line information for op colocation | |||||
| // context managers. | |||||
| _control_flow_context = graph._get_control_flow_context(); | _control_flow_context = graph._get_control_flow_context(); | ||||
| } | } | ||||
| @@ -133,9 +130,9 @@ namespace Tensorflow | |||||
| // Build the list of control inputs. | // Build the list of control inputs. | ||||
| var control_input_ops = new List<Operation>(); | var control_input_ops = new List<Operation>(); | ||||
| if(control_inputs != null) | |||||
| if (control_inputs != null) | |||||
| { | { | ||||
| foreach(var c in control_inputs) | |||||
| foreach (var c in control_inputs) | |||||
| { | { | ||||
| switch (c) | switch (c) | ||||
| { | { | ||||
| @@ -196,15 +193,13 @@ namespace Tensorflow | |||||
| { | { | ||||
| if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | ||||
| { | { | ||||
| input_len = (int)attrs[input_arg.NumberAttr].I; | |||||
| input_len = (int) attrs[input_arg.NumberAttr].I; | |||||
| is_sequence = true; | is_sequence = true; | ||||
| } | |||||
| else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | |||||
| } else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | |||||
| { | { | ||||
| input_len = attrs[input_arg.TypeListAttr].List.Type.Count; | input_len = attrs[input_arg.TypeListAttr].List.Type.Count; | ||||
| is_sequence = true; | is_sequence = true; | ||||
| } | |||||
| else | |||||
| } else | |||||
| { | { | ||||
| input_len = 1; | input_len = 1; | ||||
| is_sequence = false; | is_sequence = false; | ||||
| @@ -225,22 +220,21 @@ namespace Tensorflow | |||||
| { | { | ||||
| AttrValue x = null; | AttrValue x = null; | ||||
| using (var status = new Status()) | |||||
| using (var buf = new Buffer()) | |||||
| { | |||||
| unsafe | |||||
| lock (Locks.ProcessWide) | |||||
| using (var status = new Status()) | |||||
| using (var buf = new Buffer()) | |||||
| { | { | ||||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | ||||
| status.Check(true); | status.Check(true); | ||||
| x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); | x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); | ||||
| } | } | ||||
| } | |||||
| string oneof_value = x.ValueCase.ToString(); | string oneof_value = x.ValueCase.ToString(); | ||||
| if (string.IsNullOrEmpty(oneof_value)) | if (string.IsNullOrEmpty(oneof_value)) | ||||
| return null; | return null; | ||||
| if(oneof_value == "list") | |||||
| if (oneof_value == "list") | |||||
| throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); | throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); | ||||
| if (oneof_value == "type") | if (oneof_value == "type") | ||||
| @@ -259,60 +253,63 @@ namespace Tensorflow | |||||
| private NodeDef GetNodeDef() | private NodeDef GetNodeDef() | ||||
| { | { | ||||
| using (var s = new Status()) | |||||
| using (var buffer = new Buffer()) | |||||
| { | |||||
| c_api.TF_OperationToNodeDef(_handle, buffer, s); | |||||
| s.Check(); | |||||
| return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Update the input to this operation at the given index. | |||||
| /// | |||||
| /// NOTE: This is for TF internal use only.Please don't use it. | |||||
| /// </summary> | |||||
| /// <param name="index">the index of the input to update.</param> | |||||
| /// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | |||||
| public void _update_input(int index, Tensor tensor) | |||||
| { | |||||
| _assert_same_graph(tensor); | |||||
| var input = _tf_input(index); | |||||
| var output = tensor._as_tf_output(); | |||||
| // Reset cached inputs. | |||||
| _inputs = null; | |||||
| // after the c_api call next time _inputs is accessed | |||||
| // the updated inputs are reloaded from the c_api | |||||
| using (var status = new Status()) | |||||
| { | |||||
| c_api.UpdateEdge(_graph, output, input, status); | |||||
| //var updated_inputs = inputs; | |||||
| status.Check(); | |||||
| } | |||||
| } | |||||
| private void _assert_same_graph(Tensor tensor) | |||||
| { | |||||
| //TODO: implement | |||||
| } | |||||
| /// <summary> | |||||
| /// Create and return a new TF_Output for output_idx'th output of this op. | |||||
| /// </summary> | |||||
| public TF_Output _tf_output(int output_idx) | |||||
| { | |||||
| return new TF_Output(op, output_idx); | |||||
| } | |||||
| /// <summary> | |||||
| /// Create and return a new TF_Input for input_idx'th input of this op. | |||||
| /// </summary> | |||||
| public TF_Input _tf_input(int input_idx) | |||||
| { | |||||
| return new TF_Input(op, input_idx); | |||||
| } | |||||
| } | |||||
| } | |||||
| lock (Locks.ProcessWide) | |||||
| using (var s = new Status()) | |||||
| using (var buffer = new Buffer()) | |||||
| { | |||||
| c_api.TF_OperationToNodeDef(_handle, buffer, s); | |||||
| s.Check(); | |||||
| return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Update the input to this operation at the given index. | |||||
| /// | |||||
| /// NOTE: This is for TF internal use only.Please don't use it. | |||||
| /// </summary> | |||||
| /// <param name="index">the index of the input to update.</param> | |||||
| /// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | |||||
| public void _update_input(int index, Tensor tensor) | |||||
| { | |||||
| _assert_same_graph(tensor); | |||||
| var input = _tf_input(index); | |||||
| var output = tensor._as_tf_output(); | |||||
| // Reset cached inputs. | |||||
| _inputs = null; | |||||
| // after the c_api call next time _inputs is accessed | |||||
| // the updated inputs are reloaded from the c_api | |||||
| lock (Locks.ProcessWide) | |||||
| using (var status = new Status()) | |||||
| { | |||||
| c_api.UpdateEdge(_graph, output, input, status); | |||||
| //var updated_inputs = inputs; | |||||
| status.Check(); | |||||
| } | |||||
| } | |||||
| private void _assert_same_graph(Tensor tensor) | |||||
| { | |||||
| //TODO: implement | |||||
| } | |||||
| /// <summary> | |||||
| /// Create and return a new TF_Output for output_idx'th output of this op. | |||||
| /// </summary> | |||||
| public TF_Output _tf_output(int output_idx) | |||||
| { | |||||
| return new TF_Output(op, output_idx); | |||||
| } | |||||
| /// <summary> | |||||
| /// Create and return a new TF_Input for input_idx'th input of this op. | |||||
| /// </summary> | |||||
| public TF_Input _tf_input(int input_idx) | |||||
| { | |||||
| return new TF_Input(op, input_idx); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -36,23 +36,20 @@ namespace Tensorflow | |||||
| protected byte[] _target; | protected byte[] _target; | ||||
| public Graph graph => _graph; | public Graph graph => _graph; | ||||
| public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | |||||
| public BaseSession(string target = "", Graph g = null, SessionOptions opts = null, Status status = null) | |||||
| { | { | ||||
| _graph = g ?? ops.get_default_graph(); | _graph = g ?? ops.get_default_graph(); | ||||
| _graph.as_default(); | _graph.as_default(); | ||||
| _target = Encoding.UTF8.GetBytes(target); | _target = Encoding.UTF8.GetBytes(target); | ||||
| SessionOptions newOpts = opts ?? new SessionOptions(); | |||||
| SessionOptions lopts = opts ?? new SessionOptions(); | |||||
| var status = new Status(); | |||||
| _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); | |||||
| // dispose opts only if not provided externally. | |||||
| if (opts == null) | |||||
| newOpts.Dispose(); | |||||
| status.Check(true); | |||||
| lock (Locks.ProcessWide) | |||||
| { | |||||
| status = status ?? new Status(); | |||||
| _handle = c_api.TF_NewSession(_graph, opts ?? lopts, status); | |||||
| status.Check(true); | |||||
| } | |||||
| } | } | ||||
| public virtual void run(Operation op, params FeedItem[] feed_dict) | public virtual void run(Operation op, params FeedItem[] feed_dict) | ||||
| @@ -72,19 +69,19 @@ namespace Tensorflow | |||||
| public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | ||||
| { | { | ||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||||
| var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4}, feed_dict); | |||||
| return (results[0], results[1], results[2], results[3]); | return (results[0], results[1], results[2], results[3]); | ||||
| } | } | ||||
| public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | ||||
| { | { | ||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||||
| var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3}, feed_dict); | |||||
| return (results[0], results[1], results[2]); | return (results[0], results[1], results[2]); | ||||
| } | } | ||||
| public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | ||||
| { | { | ||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||||
| var results = _run(new object[] {fetches.Item1, fetches.Item2}, feed_dict); | |||||
| return (results[0], results[1]); | return (results[0], results[1]); | ||||
| } | } | ||||
| @@ -95,8 +92,7 @@ namespace Tensorflow | |||||
| public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | ||||
| { | { | ||||
| var feed_items = feed_dict == null ? new FeedItem[0] : | |||||
| feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||||
| var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||||
| return _run(fetches, feed_items); | return _run(fetches, feed_items); | ||||
| } | } | ||||
| @@ -130,7 +126,7 @@ namespace Tensorflow | |||||
| // We only want to really perform the run if fetches or targets are provided, | // We only want to really perform the run if fetches or targets are provided, | ||||
| // or if the call is a partial run that specifies feeds. | // or if the call is a partial run that specifies feeds. | ||||
| var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||||
| var results = _do_run(final_targets.Select(x => (Operation) x).ToList(), final_fetches, feed_dict_tensor); | |||||
| return fetch_handler.build_results(this, results); | return fetch_handler.build_results(this, results); | ||||
| } | } | ||||
| @@ -150,7 +146,6 @@ namespace Tensorflow | |||||
| /// </returns> | /// </returns> | ||||
| private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | ||||
| { | { | ||||
| var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count]; | var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count]; | ||||
| int i = 0; | int i = 0; | ||||
| foreach (var x in feed_dict) | foreach (var x in feed_dict) | ||||
| @@ -159,16 +154,25 @@ namespace Tensorflow | |||||
| { | { | ||||
| switch (x.Value) | switch (x.Value) | ||||
| { | { | ||||
| case Tensor v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); break; | |||||
| case NDArray v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); break; | |||||
| case IntPtr v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case Tensor v: | |||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); | |||||
| break; | |||||
| case NDArray v: | |||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); | |||||
| break; | |||||
| case IntPtr v: | |||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| break; | |||||
| #if _REGEN | #if _REGEN | ||||
| // @formatter:off — disable formatter after this line | |||||
| %types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | %types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | ||||
| %foreach types% | %foreach types% | ||||
| case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
| case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
| % | % | ||||
| // @formatter:on — enable formatter after this line | |||||
| #else | #else | ||||
| // @formatter:off — disable formatter after this line | |||||
| case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
| case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
| case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
| @@ -191,9 +195,14 @@ namespace Tensorflow | |||||
| case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
| case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
| case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
| // @formatter:on — enable formatter after this line | |||||
| #endif | #endif | ||||
| case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); break; | |||||
| case string v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
| case bool v: | |||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); | |||||
| break; | |||||
| case string v: | |||||
| feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
| break; | |||||
| default: | default: | ||||
| throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}"); | throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}"); | ||||
| } | } | ||||
| @@ -217,12 +226,12 @@ namespace Tensorflow | |||||
| c_api.TF_SessionRun(_handle, | c_api.TF_SessionRun(_handle, | ||||
| run_options: null, | run_options: null, | ||||
| inputs: feed_dict.Select(f => f.Key).ToArray(), | inputs: feed_dict.Select(f => f.Key).ToArray(), | ||||
| input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), | |||||
| input_values: feed_dict.Select(f => (IntPtr) f.Value).ToArray(), | |||||
| ninputs: feed_dict.Length, | ninputs: feed_dict.Length, | ||||
| outputs: fetch_list, | outputs: fetch_list, | ||||
| output_values: output_values, | output_values: output_values, | ||||
| noutputs: fetch_list.Length, | noutputs: fetch_list.Length, | ||||
| target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||||
| target_opers: target_list.Select(f => (IntPtr) f).ToArray(), | |||||
| ntargets: target_list.Count, | ntargets: target_list.Count, | ||||
| run_metadata: IntPtr.Zero, | run_metadata: IntPtr.Zero, | ||||
| status: status); | status: status); | ||||
| @@ -253,7 +262,7 @@ namespace Tensorflow | |||||
| ret = NDArray.Scalar(*(bool*) srcAddress); | ret = NDArray.Scalar(*(bool*) srcAddress); | ||||
| break; | break; | ||||
| case TF_DataType.TF_STRING: | case TF_DataType.TF_STRING: | ||||
| using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) | |||||
| using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long) tensor.bytesize))) | |||||
| ret = NDArray.FromString(reader.ReadString()); | ret = NDArray.FromString(reader.ReadString()); | ||||
| break; | break; | ||||
| case TF_DataType.TF_UINT8: | case TF_DataType.TF_UINT8: | ||||
| @@ -318,81 +327,95 @@ namespace Tensorflow | |||||
| #endregion | #endregion | ||||
| #else | #else | ||||
| #region Compute | |||||
| switch (tensor.dtype) | |||||
| { | |||||
| case TF_DataType.TF_BOOL: | |||||
| { | |||||
| #region Compute | |||||
| switch (tensor.dtype) | |||||
| { | |||||
| case TF_DataType.TF_BOOL: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Boolean, ndims, false); | ret = new NDArray(NPTypeCode.Boolean, ndims, false); | ||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_UINT8: | |||||
| { | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_UINT8: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Byte, ndims, false); | ret = new NDArray(NPTypeCode.Byte, ndims, false); | ||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_INT16: | |||||
| { | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_INT16: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Int16, ndims, false); | ret = new NDArray(NPTypeCode.Int16, ndims, false); | ||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_UINT16: | |||||
| { | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_UINT16: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.UInt16, ndims, false); | ret = new NDArray(NPTypeCode.UInt16, ndims, false); | ||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_INT32: | |||||
| { | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_INT32: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Int32, ndims, false); | ret = new NDArray(NPTypeCode.Int32, ndims, false); | ||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_UINT32: | |||||
| { | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_UINT32: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.UInt32, ndims, false); | ret = new NDArray(NPTypeCode.UInt32, ndims, false); | ||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_INT64: | |||||
| { | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_INT64: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Int64, ndims, false); | ret = new NDArray(NPTypeCode.Int64, ndims, false); | ||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_UINT64: | |||||
| { | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_UINT64: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.UInt64, ndims, false); | ret = new NDArray(NPTypeCode.UInt64, ndims, false); | ||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| { | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Double, ndims, false); | ret = new NDArray(NPTypeCode.Double, ndims, false); | ||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_FLOAT: | |||||
| { | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_FLOAT: | |||||
| { | |||||
| ret = new NDArray(NPTypeCode.Single, ndims, false); | ret = new NDArray(NPTypeCode.Single, ndims, false); | ||||
| System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
| break; | |||||
| } | |||||
| break; | |||||
| } | |||||
| case TF_DataType.TF_STRING: | case TF_DataType.TF_STRING: | ||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| //TODO:! This is not the way to handle string[], it should be done with TF_DecodeString | //TODO:! This is not the way to handle string[], it should be done with TF_DecodeString | ||||
| using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) | |||||
| using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long) tensor.bytesize))) | |||||
| ret = NDArray.FromString(reader.ReadString()); | ret = NDArray.FromString(reader.ReadString()); | ||||
| break; | break; | ||||
| } | } | ||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #endif | #endif | ||||
| } | } | ||||
| } | } | ||||
| @@ -411,9 +434,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| private void _extend_graph() | private void _extend_graph() | ||||
| { | |||||
| } | |||||
| { } | |||||
| public void close() | public void close() | ||||
| { | { | ||||
| @@ -422,11 +443,12 @@ namespace Tensorflow | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
| { | { | ||||
| using (var status = new Status()) | |||||
| { | |||||
| c_api.TF_DeleteSession(handle, status); | |||||
| status.Check(true); | |||||
| } | |||||
| lock (Locks.ProcessWide) | |||||
| using (var status = new Status()) | |||||
| { | |||||
| c_api.TF_DeleteSession(handle, status); | |||||
| status.Check(true); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -21,24 +21,16 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class Session : BaseSession, IObjectLife | public class Session : BaseSession, IObjectLife | ||||
| { | { | ||||
| public Session(string target = "", Graph g = null) | |||||
| : base(target, g, null) | |||||
| { | |||||
| } | |||||
| public Session(string target = "", Graph g = null) : base(target, g, null) | |||||
| { } | |||||
| public Session(IntPtr handle, Graph g = null) | |||||
| : base("", g, null) | |||||
| public Session(IntPtr handle, Graph g = null) : base("", g, null) | |||||
| { | { | ||||
| _handle = handle; | _handle = handle; | ||||
| } | } | ||||
| public Session(Graph g, SessionOptions opts = null, Status s = null) | |||||
| : base("", g, opts) | |||||
| { | |||||
| if (s == null) | |||||
| s = new Status(); | |||||
| } | |||||
| public Session(Graph g, SessionOptions opts = null, Status s = null) : base("", g, opts, s) | |||||
| { } | |||||
| public Session as_default() | public Session as_default() | ||||
| { | { | ||||
| @@ -21,6 +21,7 @@ using Google.Protobuf; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Threading; | using System.Threading; | ||||
| using NumSharp; | using NumSharp; | ||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -207,47 +208,49 @@ namespace Tensorflow | |||||
| /// <returns>A wrapped TF_Operation*.</returns> | /// <returns>A wrapped TF_Operation*.</returns> | ||||
| public static (IntPtr, IntPtr) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | public static (IntPtr, IntPtr) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | ||||
| { | { | ||||
| var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||||
| //TODO: Implement TF_SetDevice | |||||
| //if node_def.device: | |||||
| // c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device)) | |||||
| // Add inputs | |||||
| foreach (var op_input in inputs) | |||||
| lock (Locks.ProcessWide) | |||||
| { | { | ||||
| if (op_input is Tensor[] op_inputs) | |||||
| c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); | |||||
| else if (op_input is Tensor op_input1) | |||||
| var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||||
| //TODO: Implement TF_SetDevice | |||||
| //if node_def.device: | |||||
| // c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device)) | |||||
| // Add inputs | |||||
| foreach (var op_input in inputs) | |||||
| { | { | ||||
| c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); | |||||
| if (op_input is Tensor[] op_inputs) | |||||
| c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); | |||||
| else if (op_input is Tensor op_input1) | |||||
| { | |||||
| c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); | |||||
| } else | |||||
| throw new NotImplementedException("_create_c_op"); | |||||
| } | } | ||||
| else | |||||
| throw new NotImplementedException("_create_c_op"); | |||||
| } | |||||
| var status = new Status(); | |||||
| var status = new Status(); | |||||
| // Add control inputs | |||||
| foreach (var control_input in control_inputs) | |||||
| c_api.TF_AddControlInput(op_desc, control_input); | |||||
| // Add control inputs | |||||
| foreach (var control_input in control_inputs) | |||||
| c_api.TF_AddControlInput(op_desc, control_input); | |||||
| // Add attrs | |||||
| foreach (var attr in node_def.Attr) | |||||
| { | |||||
| var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||||
| var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak | |||||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | |||||
| uint len = (uint)bytes.Length; | |||||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | |||||
| // Add attrs | |||||
| foreach (var attr in node_def.Attr) | |||||
| { | |||||
| var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||||
| var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak | |||||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | |||||
| uint len = (uint) bytes.Length; | |||||
| c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | |||||
| status.Check(true); | |||||
| } | |||||
| status.Check(true); | |||||
| } | |||||
| var c_op = c_api.TF_FinishOperation(op_desc, status); | |||||
| var c_op = c_api.TF_FinishOperation(op_desc, status); | |||||
| status.Check(true); | |||||
| status.Check(true); | |||||
| return (c_op, op_desc); | |||||
| return (c_op, op_desc); | |||||
| } | |||||
| } | } | ||||
| public static OpDef _get_op_def(Graph graph, string type) | public static OpDef _get_op_def(Graph graph, string type) | ||||
| @@ -11,7 +11,7 @@ namespace TensorFlowNET.UnitTest | |||||
| /// tensorflow\c\c_api_test.cc | /// tensorflow\c\c_api_test.cc | ||||
| /// `class CApiGradientsTest` | /// `class CApiGradientsTest` | ||||
| /// </summary> | /// </summary> | ||||
| [TestClass] | |||||
| [TestClass, Ignore] | |||||
| public class CApiGradientsTest : CApiTest, IDisposable | public class CApiGradientsTest : CApiTest, IDisposable | ||||
| { | { | ||||
| private Graph graph_ = new Graph(); | private Graph graph_ = new Graph(); | ||||
| @@ -2,6 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Util; | |||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| { | { | ||||
| @@ -22,9 +23,12 @@ namespace TensorFlowNET.UnitTest | |||||
| public CSession(Graph graph, Status s, bool user_XLA = false) | public CSession(Graph graph, Status s, bool user_XLA = false) | ||||
| { | { | ||||
| var opts = new SessionOptions(); | |||||
| opts.SetConfig(new ConfigProto { InterOpParallelismThreads = 4 }); | |||||
| session_ = new Session(graph, opts, s); | |||||
| lock (Locks.ProcessWide) | |||||
| { | |||||
| var opts = new SessionOptions(); | |||||
| opts.SetConfig(new ConfigProto {InterOpParallelismThreads = 4}); | |||||
| session_ = new Session(graph, opts, s); | |||||
| } | |||||
| } | } | ||||
| public void SetInputs(Dictionary<Operation, Tensor> inputs) | public void SetInputs(Dictionary<Operation, Tensor> inputs) | ||||
| @@ -64,13 +68,13 @@ namespace TensorFlowNET.UnitTest | |||||
| public unsafe void Run(Status s) | public unsafe void Run(Status s) | ||||
| { | { | ||||
| var inputs_ptr = inputs_.ToArray(); | var inputs_ptr = inputs_.ToArray(); | ||||
| var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray(); | |||||
| var input_values_ptr = input_values_.Select(x => (IntPtr) x).ToArray(); | |||||
| var outputs_ptr = outputs_.ToArray(); | var outputs_ptr = outputs_.ToArray(); | ||||
| var output_values_ptr = output_values_.Select(x => IntPtr.Zero).ToArray(); | var output_values_ptr = output_values_.Select(x => IntPtr.Zero).ToArray(); | ||||
| IntPtr[] targets_ptr = new IntPtr[0]; | IntPtr[] targets_ptr = new IntPtr[0]; | ||||
| c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, | c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, | ||||
| outputs_ptr, output_values_ptr, outputs_.Count, | |||||
| outputs_ptr, output_values_ptr, outputs_.Count, | |||||
| targets_ptr, targets_.Count, | targets_ptr, targets_.Count, | ||||
| IntPtr.Zero, s); | IntPtr.Zero, s); | ||||
| @@ -90,4 +94,4 @@ namespace TensorFlowNET.UnitTest | |||||
| ResetOutputValues(); | ResetOutputValues(); | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -207,7 +207,7 @@ namespace TensorFlowNET.UnitTest | |||||
| public void ImportGraphDef() | public void ImportGraphDef() | ||||
| { | { | ||||
| var s = new Status(); | var s = new Status(); | ||||
| var graph = new Graph(); | |||||
| var graph = new Graph().as_default(); | |||||
| // Create a simple graph. | // Create a simple graph. | ||||
| c_test_util.Placeholder(graph, s); | c_test_util.Placeholder(graph, s); | ||||
| @@ -221,7 +221,7 @@ namespace TensorFlowNET.UnitTest | |||||
| // Import it, with a prefix, in a fresh graph. | // Import it, with a prefix, in a fresh graph. | ||||
| graph.Dispose(); | graph.Dispose(); | ||||
| graph = new Graph(); | |||||
| graph = new Graph().as_default(); | |||||
| var opts = c_api.TF_NewImportGraphDefOptions(); | var opts = c_api.TF_NewImportGraphDefOptions(); | ||||
| c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | ||||
| c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | ||||
| @@ -359,7 +359,7 @@ namespace TensorFlowNET.UnitTest | |||||
| public void ImportGraphDef_WithReturnOutputs() | public void ImportGraphDef_WithReturnOutputs() | ||||
| { | { | ||||
| var s = new Status(); | var s = new Status(); | ||||
| var graph = new Graph(); | |||||
| var graph = new Graph().as_default(); | |||||
| // Create a graph with two nodes: x and 3 | // Create a graph with two nodes: x and 3 | ||||
| c_test_util.Placeholder(graph, s); | c_test_util.Placeholder(graph, s); | ||||
| @@ -375,7 +375,7 @@ namespace TensorFlowNET.UnitTest | |||||
| // Import it in a fresh graph with return outputs. | // Import it in a fresh graph with return outputs. | ||||
| graph.Dispose(); | graph.Dispose(); | ||||
| graph = new Graph(); | |||||
| graph = new Graph().as_default(); | |||||
| var opts = new ImportGraphDefOptions(); | var opts = new ImportGraphDefOptions(); | ||||
| opts.AddReturnOutput("feed", 0); | opts.AddReturnOutput("feed", 0); | ||||
| opts.AddReturnOutput("scalar", 0); | opts.AddReturnOutput("scalar", 0); | ||||
| @@ -4,6 +4,7 @@ using System.Runtime.InteropServices; | |||||
| using FluentAssertions; | using FluentAssertions; | ||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| @@ -14,7 +15,7 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestMethod] | [TestMethod] | ||||
| public void SessionCreation() | public void SessionCreation() | ||||
| { | { | ||||
| tf.Session(); //create one to increase next id to 1. | |||||
| ops.uid(); //increment id by one | |||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | MultiThreadedUnitTestExecuter.Run(8, Core); | ||||
| @@ -23,6 +24,28 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| tf.peak_default_graph().Should().BeNull(); | tf.peak_default_graph().Should().BeNull(); | ||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var default_graph = tf.peak_default_graph(); | |||||
| var sess_graph = sess.GetPrivate<Graph>("_graph"); | |||||
| sess_graph.Should().NotBeNull(); | |||||
| default_graph.Should().NotBeNull() | |||||
| .And.BeEquivalentTo(sess_graph); | |||||
| } | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void SessionCreation_x2() | |||||
| { | |||||
| ops.uid(); //increment id by one | |||||
| MultiThreadedUnitTestExecuter.Run(16, Core); | |||||
| //the core method | |||||
| void Core(int tid) | |||||
| { | |||||
| tf.peak_default_graph().Should().BeNull(); | |||||
| //tf.Session created an other graph | //tf.Session created an other graph | ||||
| using (var sess = tf.Session()) | using (var sess = tf.Session()) | ||||
| { | { | ||||
| @@ -38,7 +61,7 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestMethod] | [TestMethod] | ||||
| public void GraphCreation() | public void GraphCreation() | ||||
| { | { | ||||
| tf.Graph(); //create one to increase next id to 1. | |||||
| ops.uid(); //increment id by one | |||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | MultiThreadedUnitTestExecuter.Run(8, Core); | ||||
| @@ -47,7 +70,7 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| tf.peak_default_graph().Should().BeNull(); | tf.peak_default_graph().Should().BeNull(); | ||||
| var beforehand = tf.get_default_graph(); //this should create default automatically. | var beforehand = tf.get_default_graph(); //this should create default automatically. | ||||
| beforehand.graph_key.Should().NotContain("0", "Already created a graph in an other thread."); | |||||
| beforehand.graph_key.Should().NotContain("-0/", "Already created a graph in an other thread."); | |||||
| tf.peak_default_graph().Should().NotBeNull(); | tf.peak_default_graph().Should().NotBeNull(); | ||||
| using (var sess = tf.Session()) | using (var sess = tf.Session()) | ||||
| @@ -67,5 +90,174 @@ namespace TensorFlowNET.UnitTest | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void Marshal_AllocHGlobal() | |||||
| { | |||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
| //the core method | |||||
| void Core(int tid) | |||||
| { | |||||
| for (int i = 0; i < 100; i++) | |||||
| { | |||||
| Marshal.FreeHGlobal(Marshal.AllocHGlobal(sizeof(int))); | |||||
| } | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void TensorCreation() | |||||
| { | |||||
| //lock (Locks.ProcessWide) | |||||
| // tf.Session(); //create one to increase next id to 1. | |||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
| //the core method | |||||
| void Core(int tid) | |||||
| { | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| Tensor t = null; | |||||
| for (int i = 0; i < 100; i++) | |||||
| { | |||||
| t = new Tensor(1); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void TensorCreation_Array() | |||||
| { | |||||
| //lock (Locks.ProcessWide) | |||||
| // tf.Session(); //create one to increase next id to 1. | |||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
| //the core method | |||||
| void Core(int tid) | |||||
| { | |||||
| //tf.Session created an other graph | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| Tensor t = null; | |||||
| for (int i = 0; i < 100; i++) | |||||
| { | |||||
| t = new Tensor(new int[] {1, 2, 3}); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void TensorCreation_Undressed() | |||||
| { | |||||
| //lock (Locks.ProcessWide) | |||||
| // tf.Session(); //create one to increase next id to 1. | |||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
| //the core method | |||||
| unsafe void Core(int tid) | |||||
| { | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| Tensor t = null; | |||||
| for (int i = 0; i < 100; i++) | |||||
| { | |||||
| var v = (int*) Marshal.AllocHGlobal(sizeof(int)); | |||||
| c_api.DeallocatorArgs _deallocatorArgs = new c_api.DeallocatorArgs(); | |||||
| var handle = c_api.TF_NewTensor(typeof(int).as_dtype(), dims: new long[0], num_dims: 0, | |||||
| data: (IntPtr) v, len: (UIntPtr) sizeof(int), | |||||
| deallocator: (IntPtr data, IntPtr size, ref c_api.DeallocatorArgs args) => Marshal.FreeHGlobal(data), | |||||
| ref _deallocatorArgs); | |||||
| c_api.TF_DeleteTensor(handle); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void SessionRun() | |||||
| { | |||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
| //the core method | |||||
| void Core(int tid) | |||||
| { | |||||
| tf.peak_default_graph().Should().BeNull(); | |||||
| //graph is created automatically to perform create these operations | |||||
| var a1 = tf.constant(new[] {2f}, shape: new[] {1}); | |||||
| var a2 = tf.constant(new[] {3f}, shape: new[] {1}); | |||||
| var math = a1 + a2; | |||||
| for (int i = 0; i < 100; i++) | |||||
| { | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| sess.run(math).GetAtIndex<float>(0).Should().Be(5); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void SessionRun_InsideSession() | |||||
| { | |||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
| //the core method | |||||
| void Core(int tid) | |||||
| { | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| tf.peak_default_graph().Should().NotBeNull(); | |||||
| //graph is created automatically to perform create these operations | |||||
| var a1 = tf.constant(new[] {2f}, shape: new[] {1}); | |||||
| var a2 = tf.constant(new[] {3f}, shape: new[] {1}); | |||||
| var math = a1 + a2; | |||||
| var result = sess.run(math); | |||||
| result[0].GetAtIndex<float>(0).Should().Be(5); | |||||
| } | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void SessionRun_Initialization() | |||||
| { | |||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
| //the core method | |||||
| void Core(int tid) | |||||
| { | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| tf.peak_default_graph().Should().NotBeNull(); | |||||
| //graph is created automatically to perform create these operations | |||||
| var a1 = tf.constant(new[] {2f}, shape: new[] {1}); | |||||
| var a2 = tf.constant(new[] {3f}, shape: new[] {1}); | |||||
| var math = a1 + a2; | |||||
| } | |||||
| } | |||||
| } | |||||
| [TestMethod] | |||||
| public void SessionRun_Initialization_OutsideSession() | |||||
| { | |||||
| MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
| //the core method | |||||
| void Core(int tid) | |||||
| { | |||||
| tf.peak_default_graph().Should().BeNull(); | |||||
| //graph is created automatically to perform create these operations | |||||
| var a1 = tf.constant(new[] {2f}, shape: new[] {1}); | |||||
| var a2 = tf.constant(new[] {3f}, shape: new[] {1}); | |||||
| var math = a1 + a2; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -8,6 +8,7 @@ using System.Text; | |||||
| using FluentAssertions; | using FluentAssertions; | ||||
| using Google.Protobuf; | using Google.Protobuf; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Util; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| @@ -19,13 +20,13 @@ namespace TensorFlowNET.UnitTest | |||||
| /// tensorflow\c\c_api_test.cc | /// tensorflow\c\c_api_test.cc | ||||
| /// `TEST(CAPI, Session)` | /// `TEST(CAPI, Session)` | ||||
| /// </summary> | /// </summary> | ||||
| [TestMethod] | |||||
| [TestMethod, Ignore] | |||||
| public void Session() | public void Session() | ||||
| { | { | ||||
| lock (this) | |||||
| lock (Locks.ProcessWide) | |||||
| { | { | ||||
| var s = new Status(); | var s = new Status(); | ||||
| var graph = new Graph(); | |||||
| var graph = new Graph().as_default(); | |||||
| // Make a placeholder operation. | // Make a placeholder operation. | ||||
| var feed = c_test_util.Placeholder(graph, s); | var feed = c_test_util.Placeholder(graph, s); | ||||
| @@ -117,7 +117,7 @@ namespace TensorFlowNET.UnitTest | |||||
| public void SetShape() | public void SetShape() | ||||
| { | { | ||||
| var s = new Status(); | var s = new Status(); | ||||
| var graph = new Graph(); | |||||
| var graph = new Graph().as_default(); | |||||
| var feed = c_test_util.Placeholder(graph, s); | var feed = c_test_util.Placeholder(graph, s); | ||||
| var feed_out_0 = new TF_Output(feed, 0); | var feed_out_0 = new TF_Output(feed, 0); | ||||
| @@ -12,42 +12,51 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") | public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") | ||||
| { | { | ||||
| var desc = c_api.TF_NewOperation(graph, "AddN", name); | |||||
| var inputs = new TF_Output[] | |||||
| lock (Locks.ProcessWide) | |||||
| { | { | ||||
| new TF_Output(l, 0), | |||||
| new TF_Output(r, 0), | |||||
| }; | |||||
| var desc = c_api.TF_NewOperation(graph, "AddN", name); | |||||
| c_api.TF_AddInputList(desc, inputs, inputs.Length); | |||||
| var inputs = new TF_Output[] | |||||
| { | |||||
| new TF_Output(l, 0), | |||||
| new TF_Output(r, 0), | |||||
| }; | |||||
| var op = c_api.TF_FinishOperation(desc, s); | |||||
| s.Check(); | |||||
| c_api.TF_AddInputList(desc, inputs, inputs.Length); | |||||
| return op; | |||||
| var op = c_api.TF_FinishOperation(desc, s); | |||||
| s.Check(); | |||||
| return op; | |||||
| } | |||||
| } | } | ||||
| [SuppressMessage("ReSharper", "RedundantAssignment")] | [SuppressMessage("ReSharper", "RedundantAssignment")] | ||||
| public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | ||||
| { | { | ||||
| using (var buffer = new Buffer()) | |||||
| lock (Locks.ProcessWide) | |||||
| { | { | ||||
| c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||||
| attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | |||||
| using (var buffer = new Buffer()) | |||||
| { | |||||
| c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||||
| attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | |||||
| return s.Code == TF_Code.TF_OK; | |||||
| return s.Code == TF_Code.TF_OK; | |||||
| } | |||||
| } | } | ||||
| public static GraphDef GetGraphDef(Graph graph) | public static GraphDef GetGraphDef(Graph graph) | ||||
| { | { | ||||
| using (var s = new Status()) | |||||
| using (var buffer = new Buffer()) | |||||
| lock (Locks.ProcessWide) | |||||
| { | { | ||||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | |||||
| s.Check(); | |||||
| return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| using (var s = new Status()) | |||||
| using (var buffer = new Buffer()) | |||||
| { | |||||
| c_api.TF_GraphToGraphDef(graph, buffer, s); | |||||
| s.Check(); | |||||
| return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -58,6 +67,7 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| return false; | return false; | ||||
| } | } | ||||
| bool found_t = false; | bool found_t = false; | ||||
| bool found_n = false; | bool found_n = false; | ||||
| foreach (var attr in node_def.Attr) | foreach (var attr in node_def.Attr) | ||||
| @@ -67,19 +77,16 @@ namespace TensorFlowNET.UnitTest | |||||
| if (attr.Value.Type == DataType.DtInt32) | if (attr.Value.Type == DataType.DtInt32) | ||||
| { | { | ||||
| found_t = true; | found_t = true; | ||||
| } | |||||
| else | |||||
| } else | |||||
| { | { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | |||||
| else if (attr.Key == "N") | |||||
| } else if (attr.Key == "N") | |||||
| { | { | ||||
| if (attr.Value.I == n) | if (attr.Value.I == n) | ||||
| { | { | ||||
| found_n = true; | found_n = true; | ||||
| } | |||||
| else | |||||
| } else | |||||
| { | { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -92,7 +99,7 @@ namespace TensorFlowNET.UnitTest | |||||
| public static bool IsNeg(NodeDef node_def, string input) | public static bool IsNeg(NodeDef node_def, string input) | ||||
| { | { | ||||
| return node_def.Op == "Neg" && node_def.Name == "neg" && | return node_def.Op == "Neg" && node_def.Name == "neg" && | ||||
| node_def.Input.Count == 1 && node_def.Input[0] == input; | |||||
| node_def.Input.Count == 1 && node_def.Input[0] == input; | |||||
| } | } | ||||
| public static bool IsPlaceholder(NodeDef node_def) | public static bool IsPlaceholder(NodeDef node_def) | ||||
| @@ -111,13 +118,11 @@ namespace TensorFlowNET.UnitTest | |||||
| if (attr.Value.Type == DataType.DtInt32) | if (attr.Value.Type == DataType.DtInt32) | ||||
| { | { | ||||
| found_dtype = true; | found_dtype = true; | ||||
| } | |||||
| else | |||||
| } else | |||||
| { | { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | |||||
| else if (attr.Key == "shape") | |||||
| } else if (attr.Key == "shape") | |||||
| { | { | ||||
| found_shape = true; | found_shape = true; | ||||
| } | } | ||||
| @@ -132,72 +137,82 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| return false; | return false; | ||||
| } | } | ||||
| bool found_dtype = false; | bool found_dtype = false; | ||||
| bool found_value = false; | bool found_value = false; | ||||
| foreach (var attr in node_def.Attr) { | |||||
| foreach (var attr in node_def.Attr) | |||||
| { | |||||
| if (attr.Key == "dtype") | if (attr.Key == "dtype") | ||||
| { | { | ||||
| if (attr.Value.Type == DataType.DtInt32) | if (attr.Value.Type == DataType.DtInt32) | ||||
| { | { | ||||
| found_dtype = true; | found_dtype = true; | ||||
| } | |||||
| else | |||||
| } else | |||||
| { | { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | |||||
| else if (attr.Key == "value") | |||||
| } else if (attr.Key == "value") | |||||
| { | { | ||||
| if (attr.Value.Tensor != null && | if (attr.Value.Tensor != null && | ||||
| attr.Value.Tensor.IntVal.Count == 1 && | attr.Value.Tensor.IntVal.Count == 1 && | ||||
| attr.Value.Tensor.IntVal[0] == v) | attr.Value.Tensor.IntVal[0] == v) | ||||
| { | { | ||||
| found_value = true; | found_value = true; | ||||
| } | |||||
| else | |||||
| } else | |||||
| { | { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return found_dtype && found_value; | return found_dtype && found_value; | ||||
| } | } | ||||
| public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg") | public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg") | ||||
| { | { | ||||
| OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); | |||||
| var neg_input = new TF_Output(n, 0); | |||||
| c_api.TF_AddInput(desc, neg_input); | |||||
| var op = c_api.TF_FinishOperation(desc, s); | |||||
| s.Check(); | |||||
| lock (Locks.ProcessWide) | |||||
| { | |||||
| OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); | |||||
| var neg_input = new TF_Output(n, 0); | |||||
| c_api.TF_AddInput(desc, neg_input); | |||||
| var op = c_api.TF_FinishOperation(desc, s); | |||||
| s.Check(); | |||||
| return op; | |||||
| return op; | |||||
| } | |||||
| } | } | ||||
| public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) | public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) | ||||
| { | { | ||||
| var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | |||||
| c_api.TF_SetAttrType(desc, "dtype", dtype); | |||||
| if (dims != null) | |||||
| lock (Locks.ProcessWide) | |||||
| { | { | ||||
| c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); | |||||
| } | |||||
| var op = c_api.TF_FinishOperation(desc, s); | |||||
| s.Check(); | |||||
| var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | |||||
| c_api.TF_SetAttrType(desc, "dtype", dtype); | |||||
| if (dims != null) | |||||
| { | |||||
| c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); | |||||
| } | |||||
| var op = c_api.TF_FinishOperation(desc, s); | |||||
| s.Check(); | |||||
| return op; | |||||
| return op; | |||||
| } | |||||
| } | } | ||||
| public static Operation Const(Tensor t, Graph graph, Status s, string name) | public static Operation Const(Tensor t, Graph graph, Status s, string name) | ||||
| { | { | ||||
| var desc = c_api.TF_NewOperation(graph, "Const", name); | |||||
| c_api.TF_SetAttrTensor(desc, "value", t, s); | |||||
| s.Check(); | |||||
| c_api.TF_SetAttrType(desc, "dtype", t.dtype); | |||||
| var op = c_api.TF_FinishOperation(desc, s); | |||||
| s.Check(); | |||||
| return op; | |||||
| lock (Locks.ProcessWide) | |||||
| { | |||||
| var desc = c_api.TF_NewOperation(graph, "Const", name); | |||||
| c_api.TF_SetAttrTensor(desc, "value", t, s); | |||||
| s.Check(); | |||||
| c_api.TF_SetAttrType(desc, "dtype", t.dtype); | |||||
| var op = c_api.TF_FinishOperation(desc, s); | |||||
| s.Check(); | |||||
| return op; | |||||
| } | |||||
| } | } | ||||
| public static Operation ScalarConst(int v, Graph graph, Status s, string name = "scalar") | public static Operation ScalarConst(int v, Graph graph, Status s, string name = "scalar") | ||||
| @@ -205,4 +220,4 @@ namespace TensorFlowNET.UnitTest | |||||
| return Const(new Tensor(v), graph, s, name); | return Const(new Tensor(v), graph, s, name); | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| } | |||||