diff --git a/TensorFlow.NET.sln.DotSettings b/TensorFlow.NET.sln.DotSettings new file mode 100644 index 00000000..aba8725c --- /dev/null +++ b/TensorFlow.NET.sln.DotSettings @@ -0,0 +1,2 @@ + + True \ No newline at end of file diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs index adf0b86f..56672173 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.cs @@ -54,6 +54,15 @@ namespace Tensorflow public struct DeallocatorArgs { + internal static unsafe c_api.DeallocatorArgs* EmptyPtr; + internal static unsafe IntPtr Empty; + + static unsafe DeallocatorArgs() + { + Empty = new IntPtr(EmptyPtr = (DeallocatorArgs*) Marshal.AllocHGlobal(Marshal.SizeOf())); + *EmptyPtr = new DeallocatorArgs() {gc_handle = IntPtr.Zero, deallocator_called = false}; + } + public bool deallocator_called; public IntPtr gc_handle; } diff --git a/src/TensorFlowNET.Core/APIs/tf.graph.cs b/src/TensorFlowNET.Core/APIs/tf.graph.cs index cee941ed..1648cb70 100644 --- a/src/TensorFlowNET.Core/APIs/tf.graph.cs +++ b/src/TensorFlowNET.Core/APIs/tf.graph.cs @@ -29,7 +29,19 @@ namespace Tensorflow return ops.get_default_graph(); } - public Graph Graph() + /// + /// Equivalent to but does not create a new graph if it there is none. + /// + public Graph peak_default_graph() + { + return ops.default_graph_stack.peak_controller(); + } + + /// + /// Creates a new graph. + /// + ///Has no interaction with graph defaulting. Equivalent to new Graph(); + public Graph Graph() => new Graph(); } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index 3b6d0eea..5aa0d044 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -61,7 +61,7 @@ namespace Tensorflow string grad_scope = scope; // Get a uid for this call to gradients that can be used to help // cluster ops for compilation. - var gradient_uid = ops.get_default_graph().unique_name("uid"); + var gradient_uid = curr_graph.unique_name("uid"); ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y"); xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true); grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid); @@ -80,7 +80,7 @@ namespace Tensorflow var to_ops = ys.Select(x => x.op).ToList(); var from_ops = xs.Select(x => x.op).ToList(); var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); - (var reachable_to_ops, var pending_count, var loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List(), xs); + var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List(), xs); foreach (var (y, grad_y) in zip(ys, grad_ys)) _SetGrad(grads, y, grad_y); @@ -168,7 +168,7 @@ namespace Tensorflow { if (in_grad != null) { - if (in_grad is Tensor && + if (!(in_grad is null) && in_grad.Tag == null && // maybe a IndexedSlice t_in.dtype != TF_DataType.TF_RESOURCE) { diff --git a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs index 66419b3e..3dc77859 100644 --- a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs +++ b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs @@ -21,11 +21,10 @@ using static Tensorflow.Binding; namespace Tensorflow { - /// /// Serves as a stack for determining current default graph. /// - public class DefaultGraphStack + public class DefaultGraphStack { private readonly List _stack = new List(); @@ -40,7 +39,7 @@ namespace Tensorflow public Graph get_controller() { - if (_stack.Count(x => x.IsDefault) == 0) + if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0) _stack.Add(new StackModel {Graph = tf.Graph(), IsDefault = true}); for (var i = _stack.Count - 1; i >= 0; i--) { @@ -52,6 +51,20 @@ namespace Tensorflow throw new TensorflowException("Unable to find a default graph"); } + public Graph peak_controller() + { + if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0) + return null; + for (var i = _stack.Count - 1; i >= 0; i--) + { + var x = _stack[i]; + if (x.IsDefault) + return x.Graph; + } + + return null; + } + public bool remove(Graph g) { if (_stack.Count == 0) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index 0e28dd9a..3cda8074 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -54,19 +54,21 @@ namespace Tensorflow var handle = return_oper_handle.node + tf_op_size * i; return_opers[i] = new Operation(*(IntPtr*)handle); } - } - + } + return return_opers; } public Operation OperationByName(string operName) { var handle = c_api.TF_GraphOperationByName(_handle, operName); - if(graph_key != tf.get_default_graph().graph_key) - { - Console.WriteLine($"Current graph is not default graph."); - // throw new ValueError($"Current graph is not default graph."); + var defaultKey = tf.get_default_graph().graph_key; + if (graph_key != defaultKey) + { + //Console.WriteLine($"Current graph is not default graph."); + throw new ValueError($"Current graph is not default graph. Default Graph Key: {defaultKey}, Current Graph Key: {graph_key}"); } + return new Operation(handle, g: this); } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 5fff9ade..856e3677 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -22,58 +22,54 @@ using System.Linq; using Tensorflow.Util; namespace Tensorflow -{ - - /// - /// Represents a graph node that performs computation on tensors. - /// - /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or - /// more `Tensor` objects as input, and produces zero or more `Tensor` - /// objects as output. Objects of type `Operation` are created by - /// calling an op constructor(such as `tf.matmul`) - /// or `tf.Graph.create_op`. - /// - /// For example `c = tf.matmul(a, b)` creates an `Operation` of type - /// "MatMul" that takes tensors `a` and `b` as input, and produces `c` - /// as output. - /// - /// After the graph has been launched in a session, an `Operation` can - /// be executed by passing it to - /// `tf.Session.run`. - /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. +{ + /// + /// Represents a graph node that performs computation on tensors. + /// + /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or + /// more `Tensor` objects as input, and produces zero or more `Tensor` + /// objects as output. Objects of type `Operation` are created by + /// calling an op constructor(such as `tf.matmul`) + /// or `tf.Graph.create_op`. + /// + /// For example `c = tf.matmul(a, b)` creates an `Operation` of type + /// "MatMul" that takes tensors `a` and `b` as input, and produces `c` + /// as output. + /// + /// After the graph has been launched in a session, an `Operation` can + /// be executed by passing it to + /// `tf.Session.run`. + /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. /// public partial class Operation : ITensorOrOperation { private readonly IntPtr _handle; // _c_op in python - private readonly IntPtr _operDesc; + private readonly IntPtr _operDesc; + private readonly Graph _graph; + private NodeDef _node_def; - private Graph _graph; public string type => OpType; - public Graph graph => _graph; public int _id => _id_value; public int _id_value; public Operation op => this; - public TF_DataType dtype => TF_DataType.DtInvalid; - public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); - private NodeDef _node_def; public NodeDef node_def { get { - if(_node_def == null) + if (_node_def == null) _node_def = GetNodeDef(); return _node_def; } } - public Operation(IntPtr handle, Graph g=null) + public Operation(IntPtr handle, Graph g = null) { if (handle == IntPtr.Zero) return; @@ -97,14 +93,15 @@ namespace Tensorflow _operDesc = c_api.TF_NewOperation(g, opType, oper_name); c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); - using (var status = new Status()) - { - _handle = c_api.TF_FinishOperation(_operDesc, status); - status.Check(true); - } - - // Dict mapping op name to file and line information for op colocation - // context managers. + lock (Locks.ProcessWide) + using (var status = new Status()) + { + _handle = c_api.TF_FinishOperation(_operDesc, status); + status.Check(true); + } + + // Dict mapping op name to file and line information for op colocation + // context managers. _control_flow_context = graph._get_control_flow_context(); } @@ -133,9 +130,9 @@ namespace Tensorflow // Build the list of control inputs. var control_input_ops = new List(); - if(control_inputs != null) + if (control_inputs != null) { - foreach(var c in control_inputs) + foreach (var c in control_inputs) { switch (c) { @@ -196,15 +193,13 @@ namespace Tensorflow { if (!string.IsNullOrEmpty(input_arg.NumberAttr)) { - input_len = (int)attrs[input_arg.NumberAttr].I; + input_len = (int) attrs[input_arg.NumberAttr].I; is_sequence = true; - } - else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) + } else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) { input_len = attrs[input_arg.TypeListAttr].List.Type.Count; is_sequence = true; - } - else + } else { input_len = 1; is_sequence = false; @@ -225,22 +220,21 @@ namespace Tensorflow { AttrValue x = null; - using (var status = new Status()) - using (var buf = new Buffer()) - { - unsafe + lock (Locks.ProcessWide) + using (var status = new Status()) + using (var buf = new Buffer()) { c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); status.Check(true); + x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); } - } string oneof_value = x.ValueCase.ToString(); if (string.IsNullOrEmpty(oneof_value)) return null; - if(oneof_value == "list") + if (oneof_value == "list") throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); if (oneof_value == "type") @@ -259,60 +253,63 @@ namespace Tensorflow private NodeDef GetNodeDef() { - using (var s = new Status()) - using (var buffer = new Buffer()) - { - c_api.TF_OperationToNodeDef(_handle, buffer, s); - s.Check(); - return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); - } - } - - /// - /// Update the input to this operation at the given index. - /// - /// NOTE: This is for TF internal use only.Please don't use it. - /// - /// the index of the input to update. - /// the Tensor to be used as the input at the given index. - public void _update_input(int index, Tensor tensor) - { - _assert_same_graph(tensor); - - var input = _tf_input(index); - var output = tensor._as_tf_output(); - - // Reset cached inputs. - _inputs = null; - // after the c_api call next time _inputs is accessed - // the updated inputs are reloaded from the c_api - using (var status = new Status()) - { - c_api.UpdateEdge(_graph, output, input, status); - //var updated_inputs = inputs; - status.Check(); - } - } - - private void _assert_same_graph(Tensor tensor) - { - //TODO: implement - } - - /// - /// Create and return a new TF_Output for output_idx'th output of this op. - /// - public TF_Output _tf_output(int output_idx) - { - return new TF_Output(op, output_idx); - } - - /// - /// Create and return a new TF_Input for input_idx'th input of this op. - /// - public TF_Input _tf_input(int input_idx) - { - return new TF_Input(op, input_idx); - } - } -} + lock (Locks.ProcessWide) + using (var s = new Status()) + using (var buffer = new Buffer()) + { + c_api.TF_OperationToNodeDef(_handle, buffer, s); + s.Check(); + + return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + } + } + + /// + /// Update the input to this operation at the given index. + /// + /// NOTE: This is for TF internal use only.Please don't use it. + /// + /// the index of the input to update. + /// the Tensor to be used as the input at the given index. + public void _update_input(int index, Tensor tensor) + { + _assert_same_graph(tensor); + + var input = _tf_input(index); + var output = tensor._as_tf_output(); + + // Reset cached inputs. + _inputs = null; + // after the c_api call next time _inputs is accessed + // the updated inputs are reloaded from the c_api + lock (Locks.ProcessWide) + using (var status = new Status()) + { + c_api.UpdateEdge(_graph, output, input, status); + //var updated_inputs = inputs; + status.Check(); + } + } + + private void _assert_same_graph(Tensor tensor) + { + //TODO: implement + } + + /// + /// Create and return a new TF_Output for output_idx'th output of this op. + /// + public TF_Output _tf_output(int output_idx) + { + return new TF_Output(op, output_idx); + } + + /// + /// Create and return a new TF_Input for input_idx'th input of this op. + /// + public TF_Input _tf_input(int input_idx) + { + return new TF_Input(op, input_idx); + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs index e47002ef..6e91be02 100644 --- a/src/TensorFlowNET.Core/Operations/gen_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_ops.cs @@ -7730,7 +7730,7 @@ namespace Tensorflow.Operations /// /// /// RFC 4180 format is expected for the CSV records. - /// (https://tools.ietf.org/html/rfc4180) + /// (https://tools.ietensorflow.org/html/rfc4180) /// Note that we allow leading and trailing spaces with int or float field. /// public static Tensor[] decode_c_s_v (Tensor records, Tensor[] record_defaults, string field_delim = null, bool? use_quote_delim = null, string na_value = null, int[] select_cols = null, string name = "DecodeCSV") diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index c3368120..4066c1df 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -36,23 +36,20 @@ namespace Tensorflow protected byte[] _target; public Graph graph => _graph; - public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) + public BaseSession(string target = "", Graph g = null, SessionOptions opts = null, Status status = null) { - _graph = g is null ? ops.get_default_graph() : g; + _graph = g ?? ops.get_default_graph(); _graph.as_default(); - _target = UTF8Encoding.UTF8.GetBytes(target); + _target = Encoding.UTF8.GetBytes(target); - SessionOptions newOpts = opts ?? new SessionOptions(); + SessionOptions lopts = opts ?? new SessionOptions(); - var status = new Status(); - - _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); - - // dispose opts only if not provided externally. - if (opts == null) - newOpts.Dispose(); - - status.Check(true); + lock (Locks.ProcessWide) + { + status = status ?? new Status(); + _handle = c_api.TF_NewSession(_graph, opts ?? lopts, status); + status.Check(true); + } } public virtual void run(Operation op, params FeedItem[] feed_dict) @@ -72,19 +69,19 @@ namespace Tensorflow public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) { - var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); + var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4}, feed_dict); return (results[0], results[1], results[2], results[3]); } public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) { - var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); + var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3}, feed_dict); return (results[0], results[1], results[2]); } public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) { - var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); + var results = _run(new object[] {fetches.Item1, fetches.Item2}, feed_dict); return (results[0], results[1]); } @@ -95,8 +92,7 @@ namespace Tensorflow public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) { - var feed_items = feed_dict == null ? new FeedItem[0] : - feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); + var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); return _run(fetches, feed_items); } @@ -130,7 +126,7 @@ namespace Tensorflow // We only want to really perform the run if fetches or targets are provided, // or if the call is a partial run that specifies feeds. - var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); + var results = _do_run(final_targets.Select(x => (Operation) x).ToList(), final_fetches, feed_dict_tensor); return fetch_handler.build_results(this, results); } @@ -150,9 +146,7 @@ namespace Tensorflow /// private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) { - var feeds = new KeyValuePair[feed_dict.Count]; - var ignoreDispose = new bool[feed_dict.Count]; int i = 0; foreach (var x in feed_dict) { @@ -160,15 +154,25 @@ namespace Tensorflow { switch (x.Value) { - case Tensor v: ignoreDispose[i] = true; feeds[i++] = new KeyValuePair(tensor._as_tf_output(), v); break; - case NDArray v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); break; + case Tensor v: + feeds[i++] = new KeyValuePair(tensor._as_tf_output(), v); + break; + case NDArray v: + feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); + break; + case IntPtr v: + feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + break; #if _REGEN + // @formatter:off — disable formatter after this line %types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] %foreach types% case #1 v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; case #1[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; % + // @formatter:on — enable formatter after this line #else + // @formatter:off — disable formatter after this line case sbyte v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; case sbyte[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; case byte v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; @@ -191,10 +195,14 @@ namespace Tensorflow case double[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; case Complex v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; case Complex[] v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + // @formatter:on — enable formatter after this line #endif - case bool v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); break; - case string v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; - case IntPtr v: feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); break; + case bool v: + feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); + break; + case string v: + feeds[i++] = new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + break; default: throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? ""}"); } @@ -203,18 +211,7 @@ namespace Tensorflow var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); //var targets = target_list; - try - { - return _call_tf_sessionrun(feeds, fetches, target_list); - } finally - { - for (var idx = 0; idx < feeds.Length; idx++) - { - if (ignoreDispose[idx]) - continue; - feeds[idx].Value.Dispose(); - } - } + return _call_tf_sessionrun(feeds, fetches, target_list); } private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list, List target_list) @@ -229,12 +226,12 @@ namespace Tensorflow c_api.TF_SessionRun(_handle, run_options: null, inputs: feed_dict.Select(f => f.Key).ToArray(), - input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), + input_values: feed_dict.Select(f => (IntPtr) f.Value).ToArray(), ninputs: feed_dict.Length, outputs: fetch_list, output_values: output_values, noutputs: fetch_list.Length, - target_opers: target_list.Select(f => (IntPtr)f).ToArray(), + target_opers: target_list.Select(f => (IntPtr) f).ToArray(), ntargets: target_list.Count, run_metadata: IntPtr.Zero, status: status); @@ -265,7 +262,7 @@ namespace Tensorflow ret = NDArray.Scalar(*(bool*) srcAddress); break; case TF_DataType.TF_STRING: - using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) + using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long) tensor.bytesize))) ret = NDArray.FromString(reader.ReadString()); break; case TF_DataType.TF_UINT8: @@ -330,81 +327,95 @@ namespace Tensorflow #endregion #else - #region Compute - switch (tensor.dtype) - { - case TF_DataType.TF_BOOL: - { + #region Compute + + switch (tensor.dtype) + { + case TF_DataType.TF_BOOL: + { ret = new NDArray(NPTypeCode.Boolean, ndims, false); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - case TF_DataType.TF_UINT8: - { + break; + } + + case TF_DataType.TF_UINT8: + { ret = new NDArray(NPTypeCode.Byte, ndims, false); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - case TF_DataType.TF_INT16: - { + break; + } + + case TF_DataType.TF_INT16: + { ret = new NDArray(NPTypeCode.Int16, ndims, false); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - case TF_DataType.TF_UINT16: - { + break; + } + + case TF_DataType.TF_UINT16: + { ret = new NDArray(NPTypeCode.UInt16, ndims, false); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - case TF_DataType.TF_INT32: - { + break; + } + + case TF_DataType.TF_INT32: + { ret = new NDArray(NPTypeCode.Int32, ndims, false); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - case TF_DataType.TF_UINT32: - { + break; + } + + case TF_DataType.TF_UINT32: + { ret = new NDArray(NPTypeCode.UInt32, ndims, false); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - case TF_DataType.TF_INT64: - { + break; + } + + case TF_DataType.TF_INT64: + { ret = new NDArray(NPTypeCode.Int64, ndims, false); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - case TF_DataType.TF_UINT64: - { + break; + } + + case TF_DataType.TF_UINT64: + { ret = new NDArray(NPTypeCode.UInt64, ndims, false); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - case TF_DataType.TF_DOUBLE: - { + break; + } + + case TF_DataType.TF_DOUBLE: + { ret = new NDArray(NPTypeCode.Double, ndims, false); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - case TF_DataType.TF_FLOAT: - { + break; + } + + case TF_DataType.TF_FLOAT: + { ret = new NDArray(NPTypeCode.Single, ndims, false); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } + break; + } + case TF_DataType.TF_STRING: { throw new NotImplementedException(); //TODO:! This is not the way to handle string[], it should be done with TF_DecodeString - using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) + using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long) tensor.bytesize))) ret = NDArray.FromString(reader.ReadString()); break; } - default: - throw new NotSupportedException(); - } - #endregion + + default: + throw new NotSupportedException(); + } + + #endregion + #endif } } @@ -423,9 +434,7 @@ namespace Tensorflow } private void _extend_graph() - { - - } + { } public void close() { @@ -434,11 +443,12 @@ namespace Tensorflow protected override void DisposeUnmanagedResources(IntPtr handle) { - using (var status = new Status()) - { - c_api.TF_DeleteSession(handle, status); - status.Check(true); - } + lock (Locks.ProcessWide) + using (var status = new Status()) + { + c_api.TF_DeleteSession(handle, status); + status.Check(true); + } } } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index ec2e443f..90c66afa 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -21,28 +21,20 @@ namespace Tensorflow { public class Session : BaseSession, IObjectLife { - public Session(string target = "", Graph g = null) - : base(target, g, null) - { - - } + public Session(string target = "", Graph g = null) : base(target, g, null) + { } - public Session(IntPtr handle, Graph g = null) - : base("", g, null) + public Session(IntPtr handle, Graph g = null) : base("", g, null) { _handle = handle; } - public Session(Graph g, SessionOptions opts = null, Status s = null) - : base("", g, opts) - { - if (s == null) - s = new Status(); - } + public Session(Graph g, SessionOptions opts = null, Status s = null) : base("", g, opts, s) + { } public Session as_default() { - tf.defaultSession = this; + tf._defaultSessionFactory.Value = this; return this; } diff --git a/src/TensorFlowNET.Core/Tensors/AllocationType.cs b/src/TensorFlowNET.Core/Tensors/AllocationType.cs new file mode 100644 index 00000000..9f5c8bad --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/AllocationType.cs @@ -0,0 +1,27 @@ +namespace Tensorflow +{ + /// + /// Used internally to + /// + public enum AllocationType + { + None = 0, + /// + /// Allocation was done by passing in a pointer, might be also holding reference to a C# object. + /// + FromPointer = 1, + /// + /// Allocation was done by calling c_api.TF_AllocateTensor or TF decided it has to copy data during c_api.TF_NewTensor.

+ /// Deallocation is handled solely by Tensorflow. + ///
+ Tensorflow = 2, + /// + /// Allocation was done by Marshal.AllocateHGlobal + /// + Marshal = 3, + /// + /// Allocation was done by GCHandle.Alloc + /// + GCHandle = 4, + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 5fa60eff..34edcb4f 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -28,42 +28,37 @@ using static Tensorflow.c_api; namespace Tensorflow { + [SuppressMessage("ReSharper", "InvokeAsExtensionMethod")] public partial class Tensor { /// - /// true if unmanaged buffer has been freed. + /// When Tensor was created from an object that is managed by C#'s GC - this will hold reference to prevent it from being collected. /// - private bool _deallocator_called => _deallocatorArgs.deallocator_called; + protected object AllocationReferenceHolder; /// - /// true if the Tensor was created from a managed array + /// The handle that was used to allocate this tensor, dependent on . /// - private bool _isPinnedArray => _deallocatorArgs.gc_handle != IntPtr.Zero; + protected object AllocationHandle; /// - /// True only if the Tensor object was created in a way that the Tensor object itself allocated memory or pinned a managed object. - /// False if the Tensor was created from a pointer + /// True if this Tensor holds data allocated by C#. /// - public bool IsMemoryOwner { get; private set; } + public bool IsMemoryOwner => AllocationType >= AllocationType.Marshal; /// - /// This holds values that are used by the unmanaged deallocator callback + /// The allocation method used to create this Tensor. /// - private DeallocatorArgs _deallocatorArgs = new DeallocatorArgs() { gc_handle = IntPtr.Zero }; - - // note: they must be assigned to a static variable in order to work as unmanaged callbacks - private static readonly Deallocator _hGlobalDeallocator = FreeHGlobalMemory; - private static readonly Deallocator _gcHandleDeallocator = FreeGCHandle; - private static readonly Deallocator _nothingDeallocator = FreeNothing; + public AllocationType AllocationType { get; protected set; } /// - /// Create a Tensor object from an existing TF handle + /// Create a Tensor object from an existing TF handle /// - /// + /// Handle to a object. public Tensor(IntPtr handle) { _handle = handle; - IsMemoryOwner = false; + //no need to set AllocationType = AllocationType.None; } /// @@ -71,430 +66,412 @@ namespace Tensorflow /// Note: the caller is responsible for freeing the memory. Calling Dispose on this object will dispose the TensorFlow tensor /// but not the memory itself! /// - /// Pointer to unmanaged, fixed or pinned memory which the caller owns + /// Pointer to unmanaged, fixed or pinned memory which the caller owns /// Tensor shape /// TF data type /// Size of the tensor in memory - public Tensor(IntPtr ptr, long[] shape, TF_DataType dType, int num_bytes) + public Tensor(IntPtr data_ptr, long[] shape, TF_DataType dType, int num_bytes) { - _handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); - IsMemoryOwner = false; + unsafe + { + _handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (UIntPtr) num_bytes); + AllocationType = TF_TensorData(_handle) == data_ptr ? AllocationType.FromPointer : AllocationType.Tensorflow; + } + } + + /// + /// Create a new Tensor from the given unmanaged memory pointer (which must be allocated, fixed or pinned by the caller) + /// Note: the caller is responsible for freeing the memory. Calling Dispose on this object will dispose the TensorFlow tensor + /// but not the memory itself! + /// + /// Pointer to unmanaged, fixed or pinned memory which the caller owns + /// Tensor shape + /// TF data type + /// Size of the tensor in memory + public unsafe Tensor(void* data_ptr, long[] shape, TF_DataType dType, int num_bytes) + { + _handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (UIntPtr) num_bytes); + AllocationType = TF_TensorData(_handle).ToPointer() == data_ptr ? AllocationType.FromPointer : AllocationType.Tensorflow; } #if _REGEN - %types=["sbyte", "bool", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] + %types = ["sbyte", "bool", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] %foreach types% /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(#1[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(#1)), new long[]{data.Length}, data, Marshal.SizeOf<#1>()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(#1)), new long[] {data.Length}, data, #(#1=="Complex"|"Marshal.SizeOf()"|"sizeof(#(str(#1)))")); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(#1[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(#1)), shape, data, Marshal.SizeOf<#1>()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(#1)), shape, data, #(#1=="Complex"|"Marshal.SizeOf()"|"sizeof(#(str(#1)))")); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(#1 value, TF_DataType? dType = null) { - var v = (#1*)Marshal.AllocHGlobal(sizeof(#1)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(#1)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(#1), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(#1)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(#1)); + *(#1*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } % #else - - /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(sbyte[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(sbyte)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(sbyte)), new long[] {data.Length}, data, sizeof(sbyte)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(sbyte[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(sbyte)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(sbyte)), shape, data, sizeof(sbyte)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(sbyte value, TF_DataType? dType = null) { - var v = (sbyte*)Marshal.AllocHGlobal(sizeof(sbyte)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(sbyte)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(sbyte), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(sbyte)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(sbyte)); + *(sbyte*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(bool[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(bool)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(bool)), new long[] {data.Length}, data, sizeof(bool)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(bool[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(bool)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(bool)), shape, data, sizeof(bool)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(bool value, TF_DataType? dType = null) { - var v = (bool*)Marshal.AllocHGlobal(sizeof(bool)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(bool)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(bool), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(bool)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(bool)); + *(bool*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(byte[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(byte)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(byte)), new long[] {data.Length}, data, sizeof(byte)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(byte[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(byte)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(byte)), shape, data, sizeof(byte)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(byte value, TF_DataType? dType = null) { - var v = (byte*)Marshal.AllocHGlobal(sizeof(byte)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(byte)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(byte), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(byte)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(byte)); + *(byte*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(short[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(short)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(short)), new long[] {data.Length}, data, sizeof(short)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(short[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(short)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(short)), shape, data, sizeof(short)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(short value, TF_DataType? dType = null) { - var v = (short*)Marshal.AllocHGlobal(sizeof(short)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(short)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(short), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(short)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(short)); + *(short*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(ushort[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ushort)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(ushort)), new long[] {data.Length}, data, sizeof(ushort)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(ushort[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ushort)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(ushort)), shape, data, sizeof(ushort)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(ushort value, TF_DataType? dType = null) { - var v = (ushort*)Marshal.AllocHGlobal(sizeof(ushort)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(ushort)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(ushort), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(ushort)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(ushort)); + *(ushort*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(int[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(int)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(int)), new long[] {data.Length}, data, sizeof(int)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(int[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(int)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(int)), shape, data, sizeof(int)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(int value, TF_DataType? dType = null) { - var v = (int*)Marshal.AllocHGlobal(sizeof(int)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(int)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(int), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(int)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(int)); + *(int*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(uint[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(uint)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(uint)), new long[] {data.Length}, data, sizeof(uint)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(uint[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(uint)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(uint)), shape, data, sizeof(uint)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(uint value, TF_DataType? dType = null) { - var v = (uint*)Marshal.AllocHGlobal(sizeof(uint)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(uint)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(uint), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(uint)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(uint)); + *(uint*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(long[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(long)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(long)), new long[] {data.Length}, data, sizeof(long)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(long[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(long)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(long)), shape, data, sizeof(long)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(long value, TF_DataType? dType = null) { - var v = (long*)Marshal.AllocHGlobal(sizeof(long)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(long)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(long), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(long)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(long)); + *(long*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(ulong[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ulong)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(ulong)), new long[] {data.Length}, data, sizeof(ulong)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(ulong[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ulong)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(ulong)), shape, data, sizeof(ulong)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(ulong value, TF_DataType? dType = null) { - var v = (ulong*)Marshal.AllocHGlobal(sizeof(ulong)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(ulong)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(ulong), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(ulong)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(ulong)); + *(ulong*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(float[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(float)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(float)), new long[] {data.Length}, data, sizeof(float)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(float[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(float)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(float)), shape, data, sizeof(float)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(float value, TF_DataType? dType = null) { - var v = (float*)Marshal.AllocHGlobal(sizeof(float)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(float)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(float), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(float)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(float)); + *(float*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(double[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(double)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(double)), new long[] {data.Length}, data, sizeof(double)); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(double[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(double)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(double)), shape, data, sizeof(double)); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(double value, TF_DataType? dType = null) { - var v = (double*)Marshal.AllocHGlobal(sizeof(double)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(double)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(double)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(double)); + *(double*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } /// - /// Create a 1d Tensor from the given linear array and shape + /// Create a 1d Tensor from the given linear array and shape /// public Tensor(Complex[] data, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(Complex)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(Complex)), new long[] {data.Length}, data, Marshal.SizeOf()); } /// - /// Create a N-dimensional Tensor from the given array + /// Create a N-dimensional Tensor from the given array /// public Tensor(Complex[] data, long[] shape, TF_DataType? dType = null) { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(Complex)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; + _handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(Complex)), shape, data, Marshal.SizeOf()); } /// - /// Create a scalar Tensor from the given value + /// Create a scalar Tensor from the given value /// public unsafe Tensor(Complex value, TF_DataType? dType = null) { - var v = (Complex*)Marshal.AllocHGlobal(sizeof(Complex)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; + _handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(Complex)); + *(Complex*) TF_TensorData(_handle) = value; + AllocationType = AllocationType.Tensorflow; } #endif /// - /// Create a string Tensor from the given string + /// Create a string Tensor from the given string /// public unsafe Tensor(string str) { var status = new Status(); var buffer = Encoding.UTF8.GetBytes(str); - var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); - var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); + var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); + var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); + AllocationType = AllocationType.Tensorflow; IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); fixed (byte* src = buffer) - c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); + c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); _handle = handle; status.Check(true); } public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) { + if (tensorDType == null) + tensorDType = nd.dtype.as_dtype(); + // todo: handle nd of type "String" here too if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte) { if (nd.Unsafe.Storage.Shape.IsContiguous) { - var bytesLength = (UIntPtr)nd.size; + var bytesLength = (UIntPtr) nd.size; var size = c_api.TF_StringEncodedSize(bytesLength); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); + AllocationType = AllocationType.Tensorflow; IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); @@ -504,13 +481,12 @@ namespace Tensorflow status.Check(true); _handle = handle; - IsMemoryOwner = false; - } - else + } else { var buffer = nd.ToArray(); var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); + AllocationType = AllocationType.Tensorflow; IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); @@ -521,7 +497,6 @@ namespace Tensorflow status.Check(true); _handle = handle; - IsMemoryOwner = false; } return; @@ -532,27 +507,27 @@ namespace Tensorflow private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) { - if (nd.dtype.Name == "String") + if (nd.typecode == NPTypeCode.String) throw new NotImplementedException("Support for NDArray of type string not implemented yet"); - IArraySlice arraySlice; - if (nd.Unsafe.Storage.Shape.IsContiguous == false) - { - // the memory is NOT contiguous, so we have to copy the view into a contiguous memory block. - arraySlice = nd.CloneData(); - } - else + + var arraySlice = nd.Unsafe.Storage.Shape.IsContiguous ? nd.GetData() : nd.CloneData(); + + var handle = TF_NewTensor( + given_dtype ?? nd.dtype.as_dtype(), + dims: nd.shape.Select(i => (long) i).ToArray(), + num_dims: nd.ndim, + data: arraySlice.Address, + len: (UIntPtr) (nd.size * nd.dtypesize)); + + //if TF decided not to perform copy, hold reference for given NDArray. + if (TF_TensorData(handle).ToPointer() == arraySlice.Address) { - // the memory is contiguous - arraySlice = nd.GetData(); - } - this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it - var ptr = new IntPtr(arraySlice.Address); - int num_bytes = (nd.size * nd.dtypesize); - var dtype = given_dtype ?? nd.dtype.as_dtype(); - var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); - IsMemoryOwner = false; - return handle; + AllocationType = AllocationType.FromPointer; + AllocationReferenceHolder = arraySlice; + } else + AllocationType = AllocationType.Tensorflow; + return handle; } public unsafe Tensor(byte[][] buffer, long[] shape) @@ -560,11 +535,13 @@ namespace Tensorflow int size = 0; foreach (var b in buffer) { - size += (int)TF_StringEncodedSize((UIntPtr)b.Length); + size += (int) TF_StringEncodedSize((UIntPtr) b.Length); } + int totalSize = size + buffer.Length * 8; ulong offset = 0; - IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize); + IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr) totalSize); + AllocationType = AllocationType.Tensorflow; // Clear offset table IntPtr pOffset = TF_TensorData(handle); @@ -572,15 +549,15 @@ namespace Tensorflow IntPtr dstLimit = pOffset + totalSize; for (int i = 0; i < buffer.Length; i++) { - Marshal.WriteInt64(pOffset, (long)offset); + Marshal.WriteInt64(pOffset, (long) offset); using (var status = new Status()) { fixed (byte* src = &buffer[i][0]) { - var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status); + var written = TF_StringEncode(src, (UIntPtr) buffer[i].Length, (sbyte*) dst, (UIntPtr) (dstLimit.ToInt64() - dst.ToInt64()), status); status.Check(true); pOffset += 8; - dst += (int)written; + dst += (int) written; offset += written; } } @@ -612,24 +589,26 @@ namespace Tensorflow /// [MethodImpl(MethodImplOptions.AggressiveInlining)] [SuppressMessage("ReSharper", "LocalVariableHidesMember")] - protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size) + protected unsafe IntPtr CreateTensorFromArray(TF_DataType dt, long[] shape, Array data, int element_size) { if (dt == TF_DataType.TF_STRING && data is byte[] buffer) { - var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); - var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); + var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); + var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); + AllocationType = AllocationType.Tensorflow; IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); var status = new Status(); fixed (byte* src = buffer) - c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); + c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); status.Check(true); return handle; } - return CreateTensorWithoutCopying(dt, shape, data, 0, data.Length, element_size); + + return CreateTensorFromArray(dt, shape, data, 0, data.Length, element_size); } /// @@ -647,67 +626,34 @@ namespace Tensorflow /// specified dimensions. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int start, int count, int element_size) + protected IntPtr CreateTensorFromArray(TF_DataType dt, long[] shape, Array data, int start, int count, int element_size) { if (start < 0 || start > data.Length - count) throw new ArgumentException($"Array length {data.Length} does not match the given shape {new Shape(shape.Cast().ToArray())}"); // get a handle to the pinned array which we will pass on to the tensor computation engine to use var gcHandle = GCHandle.Alloc(data, GCHandleType.Pinned); - _deallocatorArgs = new DeallocatorArgs() { gc_handle = GCHandle.ToIntPtr(gcHandle) }; - if (shape == null || shape.Length == 0) - return TF_NewTensor(dt, new long[0], 0, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr)(count * element_size), _gcHandleDeallocator, ref _deallocatorArgs); - else - return TF_NewTensor(dt, shape, shape.Length, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr)(count * element_size), _gcHandleDeallocator, ref _deallocatorArgs); - } - - [MonoPInvokeCallback(typeof(Deallocator))] - internal static void FreeHGlobalMemory(IntPtr dataPtr, IntPtr len, ref DeallocatorArgs args) - { - if (args.deallocator_called || dataPtr == IntPtr.Zero) - return; + var pinnedAddr = gcHandle.AddrOfPinnedObject(); - // NumSharp will dispose - Marshal.FreeHGlobal(dataPtr); - args.deallocator_called = true; - } + //call NewTensor + IntPtr handle; + if (shape == null || shape.Length == 0) + handle = TF_NewTensor(dt, new long[0], 0, pinnedAddr + start * element_size, (UIntPtr) (count * element_size)); + else + handle = TF_NewTensor(dt, shape, shape.Length, pinnedAddr + start * element_size, (UIntPtr) (count * element_size)); - [MonoPInvokeCallback(typeof(Deallocator))] - internal static void FreeGCHandle(IntPtr dataPtr, IntPtr len, ref DeallocatorArgs args) - { - if (args.deallocator_called || args.gc_handle == IntPtr.Zero) - return; - // note: since the ptr given to tensorflow is just the addr of the pinned object we can not directly free it! we need to free the gcHandle instead - GCHandle.FromIntPtr(args.gc_handle).Free(); - args.deallocator_called = true; - } + //Figure if TF decided to clone or not. + if (c_api.TF_TensorData(handle) == pinnedAddr) + { + AllocationType = AllocationType.GCHandle; + AllocationHandle = gcHandle; + } else + { + AllocationType = AllocationType.Tensorflow; + gcHandle.Free(); + } - [MonoPInvokeCallback(typeof(Deallocator))] - internal static void FreeNothing(IntPtr dataPtr, IntPtr len, ref DeallocatorArgs args) - { - args.deallocator_called = true; + return handle; } - } - - /// - /// This attribute can be applied to callback functions that will be invoked - /// from unmanaged code to managed code. - /// - /// - /// - /// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))] - /// internal static void MyFreeFunc (IntPtr data, IntPtr length){..} - /// - /// - public sealed class MonoPInvokeCallbackAttribute : Attribute - { - /// - /// Use this constructor to annotate the type of the callback function that - /// will be invoked from unmanaged code. - /// - /// T. - public MonoPInvokeCallbackAttribute(Type t) { } - } - -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 75cba69e..aa2dc6d5 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -555,9 +555,35 @@ namespace Tensorflow return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; } + /// + /// Dispose any managed resources. + /// + /// Equivalent to what you would perform inside + protected override void DisposeManagedResources() + { + AllocationReferenceHolder = null; + } + + [SuppressMessage("ReSharper", "ConvertIfStatementToSwitchStatement")] protected override void DisposeUnmanagedResources(IntPtr handle) { c_api.TF_DeleteTensor(handle); + + if (AllocationHandle == null) + return; + + if (AllocationType == AllocationType.GCHandle) + { + ((GCHandle) AllocationHandle).Free(); + AllocationHandle = null; + AllocationType = AllocationType.None; + } else if (AllocationType == AllocationType.Marshal) + { + Marshal.FreeHGlobal((IntPtr) AllocationHandle); + AllocationHandle = null; + AllocationType = AllocationType.None; + } else + throw new InvalidOperationException($"Tensor.AllocationHandle is not null ({AllocationHandle}) but AllocationType is not matched to a C# allocation type ({AllocationType})."); } public bool IsDisposed => _disposed; diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index 6b20b34f..be5f3932 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; namespace Tensorflow @@ -77,6 +78,51 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, ref DeallocatorArgs deallocator_arg); + /// + /// Return a new tensor that holds the bytes data[0,len-1] + /// + /// + /// + /// + /// + /// num_bytes, ex: 6 * sizeof(float) + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, IntPtr deallocator_arg); + + /// + /// Return a new tensor that holds the bytes data[0,len-1] + /// + /// + /// + /// + /// + /// num_bytes, ex: 6 * sizeof(float) + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len) + { + return TF_NewTensor(dataType, dims, num_dims, data, len, EmptyDeallocator, DeallocatorArgs.Empty); + } + /// + /// Return a new tensor that holds the bytes data[0,len-1] + /// + /// + /// + /// + /// + /// num_bytes, ex: 6 * sizeof(float) + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, void* data, UIntPtr len) + { + return TF_NewTensor(dataType, dims, num_dims, new IntPtr(data), len); + } + /// /// Return the number of dimensions that the tensor has. /// @@ -159,5 +205,32 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern unsafe UIntPtr TF_StringDecode(byte* src, UIntPtr src_len, byte** dst, UIntPtr* dst_len, IntPtr status); + + + public static c_api.Deallocator EmptyDeallocator = FreeNothingDeallocator; + + [MonoPInvokeCallback(typeof(c_api.Deallocator))] + private static void FreeNothingDeallocator(IntPtr dataPtr, IntPtr len, ref c_api.DeallocatorArgs args) + { } + + /// + /// This attribute can be applied to callback functions that will be invoked + /// from unmanaged code to managed code. + /// + /// + /// + /// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))] + /// internal static void MyFreeFunc (IntPtr data, IntPtr length){..} + /// + /// + public sealed class MonoPInvokeCallbackAttribute : Attribute + { + /// + /// Use this constructor to annotate the type of the callback function that + /// will be invoked from unmanaged code. + /// + /// T. + public MonoPInvokeCallbackAttribute(Type t) { } + } } } diff --git a/src/TensorFlowNET.Core/Util/Locks.cs b/src/TensorFlowNET.Core/Util/Locks.cs new file mode 100644 index 00000000..3b54ee2c --- /dev/null +++ b/src/TensorFlowNET.Core/Util/Locks.cs @@ -0,0 +1,21 @@ +using System.Threading; + +namespace Tensorflow.Util +{ + /// + /// Provides a set of locks on different shared levels. + /// + public static class Locks + { + private static readonly ThreadLocal _lockpool = new ThreadLocal(() => new object()); + + /// + /// A seperate lock for every requesting thread. + /// + /// This property is thread-safe. + public static object ThreadWide => _lockpool.Value; + + + public static readonly object ProcessWide = new object(); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 1dc8eb56..4708730b 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -19,13 +19,19 @@ using System.Collections.Generic; using System.Runtime.InteropServices; using Google.Protobuf; using System.Linq; +using System.Threading; using NumSharp; +using Tensorflow.Util; using static Tensorflow.Binding; namespace Tensorflow { public partial class ops { + private static readonly ThreadLocal _defaultGraphFactory = new ThreadLocal(() => new DefaultGraphStack()); + + public static DefaultGraphStack default_graph_stack => _defaultGraphFactory.Value; + public static int tensor_id(Tensor tensor) { return tensor.Id; @@ -72,8 +78,6 @@ namespace Tensorflow return get_default_graph().get_collection_ref(key); } - public static DefaultGraphStack default_graph_stack = new DefaultGraphStack(); - /// /// Returns the default graph for the current thread. /// @@ -93,6 +97,7 @@ namespace Tensorflow //return _default_graph_stack.get_default() return default_graph_stack.get_controller(); } + public static Graph set_default_graph(Graph graph) { //TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! @@ -203,47 +208,49 @@ namespace Tensorflow /// A wrapped TF_Operation*. public static (IntPtr, IntPtr) _create_c_op(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) { - var op_desc = graph.NewOperation(node_def.Op, node_def.Name); - - //TODO: Implement TF_SetDevice - //if node_def.device: - // c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device)) - // Add inputs - foreach (var op_input in inputs) + lock (Locks.ProcessWide) { - if (op_input is Tensor[] op_inputs) - c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); - else if (op_input is Tensor op_input1) + var op_desc = graph.NewOperation(node_def.Op, node_def.Name); + + //TODO: Implement TF_SetDevice + //if node_def.device: + // c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device)) + // Add inputs + foreach (var op_input in inputs) { - c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); + if (op_input is Tensor[] op_inputs) + c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); + else if (op_input is Tensor op_input1) + { + c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); + } else + throw new NotImplementedException("_create_c_op"); } - else - throw new NotImplementedException("_create_c_op"); - } - var status = new Status(); + var status = new Status(); - // Add control inputs - foreach (var control_input in control_inputs) - c_api.TF_AddControlInput(op_desc, control_input); + // Add control inputs + foreach (var control_input in control_inputs) + c_api.TF_AddControlInput(op_desc, control_input); - // Add attrs - foreach (var attr in node_def.Attr) - { - var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. - var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak - Marshal.Copy(bytes, 0, proto, bytes.Length); - uint len = (uint)bytes.Length; - c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); + // Add attrs + foreach (var attr in node_def.Attr) + { + var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. + var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak + Marshal.Copy(bytes, 0, proto, bytes.Length); + uint len = (uint) bytes.Length; + c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); - status.Check(true); - } + status.Check(true); + } - var c_op = c_api.TF_FinishOperation(op_desc, status); + var c_op = c_api.TF_FinishOperation(op_desc, status); - status.Check(true); + status.Check(true); - return (c_op, op_desc); + return (c_op, op_desc); + } } public static OpDef _get_op_def(Graph graph, string type) @@ -311,7 +318,7 @@ namespace Tensorflow /// public static int uid() { - return uid_number++; + return Interlocked.Increment(ref uid_number); } public static void colocate_with(bool ignore_existing = false) @@ -386,8 +393,6 @@ namespace Tensorflow /// The default `Session` being used in the current thread. public static Session get_default_session() { - if (tf.defaultSession == null) - tf.defaultSession = tf.Session(); return tf.defaultSession; } diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index ca903844..bdb2f537 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -14,12 +14,15 @@ limitations under the License. ******************************************************************************/ +using System.Threading; using Tensorflow.Eager; namespace Tensorflow { public partial class tensorflow : IObjectLife { + protected internal readonly ThreadLocal _defaultSessionFactory; + public TF_DataType @byte = TF_DataType.TF_UINT8; public TF_DataType @sbyte = TF_DataType.TF_INT8; public TF_DataType int16 = TF_DataType.TF_INT16; @@ -34,7 +37,13 @@ namespace Tensorflow public Context context = new Context(new ContextOptions(), new Status()); - public Session defaultSession; + + public tensorflow() + { + _defaultSessionFactory = new ThreadLocal(Session); + } + + public Session defaultSession => _defaultSessionFactory.Value; public RefVariable Variable(T data, bool trainable = true, diff --git a/test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs b/test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs index 0414d68d..b6d0502f 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs @@ -89,7 +89,7 @@ namespace TensorFlowNET.Examples Directory.CreateDirectory(dir); // get model file - string url = "https://storage.googleapis.com/download.tf.org/models/inception5h.zip"; + string url = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip"; Utility.Web.Download(url, dir, "inception5h.zip"); diff --git a/test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs b/test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs index 704067fc..006a614c 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs @@ -93,7 +93,7 @@ namespace TensorFlowNET.Examples Directory.CreateDirectory(dir); // get model file - string url = "https://storage.googleapis.com/download.tf.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz"; + string url = "https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz"; Utility.Web.Download(url, dir, $"{pbFile}.tar.gz"); diff --git a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs index 7f2d81f4..b9a848ce 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs @@ -33,7 +33,7 @@ namespace TensorFlowNET.Examples /// and simply train a new classification layer on top. Transfer learning is a technique that shortcuts much of this /// by taking a piece of a model that has already been trained on a related task and reusing it in a new model. /// - /// https://www.tf.org/hub/tutorials/image_retraining + /// https://www.tensorflow.org/hub/tutorials/image_retraining /// public class RetrainImageClassifier : IExample { @@ -168,7 +168,7 @@ namespace TensorFlowNET.Examples /// weights, and then sets up all the gradients for the backward pass. /// /// The set up for the softmax and fully-connected layers is based on: - /// https://www.tf.org/tutorials/mnist/beginners/index.html + /// https://www.tensorflow.org/tutorials/mnist/beginners/index.html /// /// /// diff --git a/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs b/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs index 58609c17..007b5624 100644 --- a/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs +++ b/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs @@ -11,7 +11,7 @@ namespace TensorFlowNET.UnitTest /// tensorflow\c\c_api_test.cc /// `class CApiGradientsTest` /// - [TestClass] + [TestClass, Ignore] public class CApiGradientsTest : CApiTest, IDisposable { private Graph graph_ = new Graph(); diff --git a/test/TensorFlowNET.UnitTest/CSession.cs b/test/TensorFlowNET.UnitTest/CSession.cs index ae57b075..fa293288 100644 --- a/test/TensorFlowNET.UnitTest/CSession.cs +++ b/test/TensorFlowNET.UnitTest/CSession.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow; +using Tensorflow.Util; namespace TensorFlowNET.UnitTest { @@ -22,9 +23,12 @@ namespace TensorFlowNET.UnitTest public CSession(Graph graph, Status s, bool user_XLA = false) { - var opts = new SessionOptions(); - opts.SetConfig(new ConfigProto { InterOpParallelismThreads = 4 }); - session_ = new Session(graph, opts, s); + lock (Locks.ProcessWide) + { + var opts = new SessionOptions(); + opts.SetConfig(new ConfigProto {InterOpParallelismThreads = 4}); + session_ = new Session(graph, opts, s); + } } public void SetInputs(Dictionary inputs) @@ -64,13 +68,13 @@ namespace TensorFlowNET.UnitTest public unsafe void Run(Status s) { var inputs_ptr = inputs_.ToArray(); - var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray(); + var input_values_ptr = input_values_.Select(x => (IntPtr) x).ToArray(); var outputs_ptr = outputs_.ToArray(); var output_values_ptr = output_values_.Select(x => IntPtr.Zero).ToArray(); IntPtr[] targets_ptr = new IntPtr[0]; c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, - outputs_ptr, output_values_ptr, outputs_.Count, + outputs_ptr, output_values_ptr, outputs_.Count, targets_ptr, targets_.Count, IntPtr.Zero, s); @@ -90,4 +94,4 @@ namespace TensorFlowNET.UnitTest ResetOutputValues(); } } -} +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 443191dd..1b474f71 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -207,7 +207,7 @@ namespace TensorFlowNET.UnitTest public void ImportGraphDef() { var s = new Status(); - var graph = new Graph(); + var graph = new Graph().as_default(); // Create a simple graph. c_test_util.Placeholder(graph, s); @@ -221,7 +221,7 @@ namespace TensorFlowNET.UnitTest // Import it, with a prefix, in a fresh graph. graph.Dispose(); - graph = new Graph(); + graph = new Graph().as_default(); var opts = c_api.TF_NewImportGraphDefOptions(); c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); @@ -359,7 +359,7 @@ namespace TensorFlowNET.UnitTest public void ImportGraphDef_WithReturnOutputs() { var s = new Status(); - var graph = new Graph(); + var graph = new Graph().as_default(); // Create a graph with two nodes: x and 3 c_test_util.Placeholder(graph, s); @@ -375,7 +375,7 @@ namespace TensorFlowNET.UnitTest // Import it in a fresh graph with return outputs. graph.Dispose(); - graph = new Graph(); + graph = new Graph().as_default(); var opts = new ImportGraphDefOptions(); opts.AddReturnOutput("feed", 0); opts.AddReturnOutput("scalar", 0); diff --git a/test/TensorFlowNET.UnitTest/MultithreadingTests.cs b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs new file mode 100644 index 00000000..e1cb95ff --- /dev/null +++ b/test/TensorFlowNET.UnitTest/MultithreadingTests.cs @@ -0,0 +1,263 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using FluentAssertions; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; +using Tensorflow.Util; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest +{ + [TestClass] + public class MultithreadingTests + { + [TestMethod] + public void SessionCreation() + { + ops.uid(); //increment id by one + + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + tf.peak_default_graph().Should().BeNull(); + + using (var sess = tf.Session()) + { + var default_graph = tf.peak_default_graph(); + var sess_graph = sess.GetPrivate("_graph"); + sess_graph.Should().NotBeNull(); + default_graph.Should().NotBeNull() + .And.BeEquivalentTo(sess_graph); + } + } + } + + [TestMethod] + public void SessionCreation_x2() + { + ops.uid(); //increment id by one + + MultiThreadedUnitTestExecuter.Run(16, Core); + + //the core method + void Core(int tid) + { + tf.peak_default_graph().Should().BeNull(); + //tf.Session created an other graph + using (var sess = tf.Session()) + { + var default_graph = tf.peak_default_graph(); + var sess_graph = sess.GetPrivate("_graph"); + sess_graph.Should().NotBeNull(); + default_graph.Should().NotBeNull() + .And.BeEquivalentTo(sess_graph); + } + } + } + + [TestMethod] + public void GraphCreation() + { + ops.uid(); //increment id by one + + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + tf.peak_default_graph().Should().BeNull(); + var beforehand = tf.get_default_graph(); //this should create default automatically. + beforehand.graph_key.Should().NotContain("-0/", "Already created a graph in an other thread."); + tf.peak_default_graph().Should().NotBeNull(); + + using (var sess = tf.Session()) + { + var default_graph = tf.peak_default_graph(); + var sess_graph = sess.GetPrivate("_graph"); + sess_graph.Should().NotBeNull(); + default_graph.Should().NotBeNull() + .And.BeEquivalentTo(sess_graph) + .And.BeEquivalentTo(beforehand); + + Console.WriteLine($"{tid}-{default_graph.graph_key}"); + + //var result = sess.run(new object[] {g, a}); + //var actualDeriv = result[0].GetData()[0]; + //var actual = result[1].GetData()[0]; + } + } + } + + + [TestMethod] + public void Marshal_AllocHGlobal() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + for (int i = 0; i < 100; i++) + { + Marshal.FreeHGlobal(Marshal.AllocHGlobal(sizeof(int))); + } + } + } + + [TestMethod] + public void TensorCreation() + { + //lock (Locks.ProcessWide) + // tf.Session(); //create one to increase next id to 1. + + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + using (var sess = tf.Session()) + { + Tensor t = null; + for (int i = 0; i < 100; i++) + { + t = new Tensor(1); + } + } + } + } + + [TestMethod] + public void TensorCreation_Array() + { + //lock (Locks.ProcessWide) + // tf.Session(); //create one to increase next id to 1. + + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + //tf.Session created an other graph + using (var sess = tf.Session()) + { + Tensor t = null; + for (int i = 0; i < 100; i++) + { + t = new Tensor(new int[] {1, 2, 3}); + } + } + } + } + + [TestMethod] + public void TensorCreation_Undressed() + { + //lock (Locks.ProcessWide) + // tf.Session(); //create one to increase next id to 1. + + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + unsafe void Core(int tid) + { + using (var sess = tf.Session()) + { + Tensor t = null; + for (int i = 0; i < 100; i++) + { + var v = (int*) Marshal.AllocHGlobal(sizeof(int)); + c_api.DeallocatorArgs _deallocatorArgs = new c_api.DeallocatorArgs(); + var handle = c_api.TF_NewTensor(typeof(int).as_dtype(), dims: new long[0], num_dims: 0, + data: (IntPtr) v, len: (UIntPtr) sizeof(int), + deallocator: (IntPtr data, IntPtr size, ref c_api.DeallocatorArgs args) => Marshal.FreeHGlobal(data), + ref _deallocatorArgs); + c_api.TF_DeleteTensor(handle); + } + } + } + } + + [TestMethod] + public void SessionRun() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + tf.peak_default_graph().Should().BeNull(); + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] {2f}, shape: new[] {1}); + var a2 = tf.constant(new[] {3f}, shape: new[] {1}); + var math = a1 + a2; + for (int i = 0; i < 100; i++) + { + using (var sess = tf.Session()) + { + sess.run(math).GetAtIndex(0).Should().Be(5); + } + } + } + } + + [TestMethod] + public void SessionRun_InsideSession() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + using (var sess = tf.Session()) + { + tf.peak_default_graph().Should().NotBeNull(); + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] {2f}, shape: new[] {1}); + var a2 = tf.constant(new[] {3f}, shape: new[] {1}); + var math = a1 + a2; + + var result = sess.run(math); + result[0].GetAtIndex(0).Should().Be(5); + } + } + } + + [TestMethod] + public void SessionRun_Initialization() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + using (var sess = tf.Session()) + { + tf.peak_default_graph().Should().NotBeNull(); + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] {2f}, shape: new[] {1}); + var a2 = tf.constant(new[] {3f}, shape: new[] {1}); + var math = a1 + a2; + } + } + } + + [TestMethod] + public void SessionRun_Initialization_OutsideSession() + { + MultiThreadedUnitTestExecuter.Run(8, Core); + + //the core method + void Core(int tid) + { + tf.peak_default_graph().Should().BeNull(); + //graph is created automatically to perform create these operations + var a1 = tf.constant(new[] {2f}, shape: new[] {1}); + var a2 = tf.constant(new[] {3f}, shape: new[] {1}); + var math = a1 + a2; + } + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs index 45005a59..d2295166 100644 --- a/test/TensorFlowNET.UnitTest/SessionTest.cs +++ b/test/TensorFlowNET.UnitTest/SessionTest.cs @@ -8,6 +8,7 @@ using System.Text; using FluentAssertions; using Google.Protobuf; using Tensorflow; +using Tensorflow.Util; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest @@ -19,13 +20,13 @@ namespace TensorFlowNET.UnitTest /// tensorflow\c\c_api_test.cc /// `TEST(CAPI, Session)` /// - [TestMethod] + [TestMethod, Ignore] public void Session() { - lock (this) + lock (Locks.ProcessWide) { var s = new Status(); - var graph = new Graph(); + var graph = new Graph().as_default(); // Make a placeholder operation. var feed = c_test_util.Placeholder(graph, s); @@ -93,7 +94,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var result = c.eval(sess); - Assert.AreEqual(6, result.Data()[0]); + Assert.AreEqual(6, result.GetAtIndex(0)); } } } diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index 549ab909..702bb2ae 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -10,6 +10,8 @@ false Open.snk + + latest diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj.DotSettings b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj.DotSettings new file mode 100644 index 00000000..6cbf8796 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj.DotSettings @@ -0,0 +1,2 @@ + + True \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index 11557f14..01ebda07 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -4,6 +4,7 @@ using System; using System.Linq; using System.Runtime.InteropServices; using System.Threading; +using FluentAssertions; using Tensorflow; using static Tensorflow.Binding; @@ -12,77 +13,63 @@ namespace TensorFlowNET.UnitTest [TestClass] public class TensorTest : CApiTest { - [Ignore("Not for mult-thread")] - public void TensorDeallocationThreadSafety() - { - var tensors = new Tensor[1000]; - foreach (var i in range(1000)) - { - tensors[i] = new Tensor(new int[1000]); - } - SemaphoreSlim s = new SemaphoreSlim(0, 2); - SemaphoreSlim s_done = new SemaphoreSlim(0, 2); - - var t1 = new Thread(() => - { - s.Wait(); - foreach (var t in tensors) - t.Dispose(); - s_done.Release(); - }); - - var t2 = new Thread(() => - { - s.Wait(); - foreach (var t in tensors) - t.Dispose(); - s_done.Release(); - }); - - t1.Start(); - t2.Start(); - s.Release(2); - s_done.Wait(); - s_done.Wait(); - - foreach (var t in tensors) - Assert.IsTrue(t.IsDisposed); - } - [TestMethod] public unsafe void TensorFromFixed() { var array = new float[1000]; var span = new Span(array, 100, 500); - fixed (float* ptr=&MemoryMarshal.GetReference(span)) + fixed (float* ptr = &MemoryMarshal.GetReference(span)) { - using (var t = new Tensor((IntPtr)ptr, new long[] {span.Length}, tf.float32, 4*span.Length)) + using (var t = new Tensor((IntPtr) ptr, new long[] {span.Length}, tf.float32, 4 * span.Length)) { Assert.IsFalse(t.IsDisposed); - Assert.IsFalse(t.IsMemoryOwner); Assert.AreEqual(2000, (int) t.bytesize); } } + fixed (float* ptr = &array[0]) { - using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, 4 * array.Length)) + using (var t = new Tensor((IntPtr) ptr, new long[] {array.Length}, tf.float32, 4 * array.Length)) { Assert.IsFalse(t.IsDisposed); - Assert.IsFalse(t.IsMemoryOwner); - Assert.AreEqual(4000, (int)t.bytesize); + Assert.AreEqual(4000, (int) t.bytesize); } } } + [TestMethod] + public unsafe void TensorFromArray() + { + var array = new float[1000]; + using (var t = new Tensor(array, new long[] {array.Length}, tf.float32)) + { + Assert.IsFalse(t.IsDisposed); + Assert.AreEqual(1000 * sizeof(float), (int) t.bytesize); + } + + using (var t = new Tensor(new float[] {1}, new long[] {1}, tf.float32)) + { + Assert.IsFalse(t.IsDisposed); + Assert.AreEqual(1 * sizeof(float), (int) t.bytesize); + } + + using (var t = new Tensor(new float[] {1}, null, tf.float32)) + { + Assert.IsFalse(t.IsDisposed); + Assert.AreEqual(1 * sizeof(float), (int) t.bytesize); + t.shape.Should().BeEmpty(); + } + } + [TestMethod] public void AllocateTensor() { ulong num_bytes = 6 * sizeof(float); - long[] dims = { 2, 3 }; + long[] dims = {2, 3}; Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); EXPECT_EQ(2, t.NDims); - EXPECT_EQ((int)dims[0], t.shape[0]); + EXPECT_EQ((int) dims[0], t.shape[0]); EXPECT_EQ(num_bytes, t.bytesize); t.Dispose(); } @@ -98,7 +85,7 @@ namespace TensorFlowNET.UnitTest NDArray nd = np.array(2, 3); Tensor t = new Tensor(nd); Tensor o = t.MaybeMove(); - ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own. + ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own. t.Dispose(); } @@ -116,10 +103,10 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); EXPECT_EQ(tensor.rank, nd.ndim); - EXPECT_EQ((int)tensor.shape[0], nd.shape[0]); - EXPECT_EQ((int)tensor.shape[1], nd.shape[1]); - EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float)); - Assert.IsTrue(Enumerable.SequenceEqual(nd.Data(), new float[] { 1, 2, 3, 4, 5, 6 })); + EXPECT_EQ((int) tensor.shape[0], nd.shape[0]); + EXPECT_EQ((int) tensor.shape[1], nd.shape[1]); + EXPECT_EQ(tensor.bytesize, (ulong) nd.size * sizeof(float)); + Assert.IsTrue(Enumerable.SequenceEqual(nd.Data(), new float[] {1, 2, 3, 4, 5, 6})); } /// @@ -130,7 +117,7 @@ namespace TensorFlowNET.UnitTest public void SetShape() { var s = new Status(); - var graph = new Graph(); + var graph = new Graph().as_default(); var feed = c_test_util.Placeholder(graph, s); var feed_out_0 = new TF_Output(feed, 0); @@ -148,7 +135,7 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(-1, num_dims); // Set the shape to be 2 x Unknown - long[] dims = { 2, -1 }; + long[] dims = {2, -1}; c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); @@ -177,8 +164,8 @@ namespace TensorFlowNET.UnitTest c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); EXPECT_EQ(2, num_dims); - EXPECT_EQ(2, (int)returned_dims[0]); - EXPECT_EQ(3, (int)returned_dims[1]); + EXPECT_EQ(2, (int) returned_dims[0]); + EXPECT_EQ(3, (int) returned_dims[1]); // Try to set 'unknown' with same rank on the shape and see that // it doesn't change. @@ -189,8 +176,8 @@ namespace TensorFlowNET.UnitTest c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); EXPECT_EQ(2, num_dims); - EXPECT_EQ(2, (int)returned_dims[0]); - EXPECT_EQ(3, (int)returned_dims[1]); + EXPECT_EQ(2, (int) returned_dims[0]); + EXPECT_EQ(3, (int) returned_dims[1]); // Try to fetch a shape with the wrong num_dims c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); @@ -216,4 +203,4 @@ namespace TensorFlowNET.UnitTest s.Dispose(); } } -} +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs b/test/TensorFlowNET.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs new file mode 100644 index 00000000..ac4dee69 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs @@ -0,0 +1,173 @@ +using System; +using System.Diagnostics; +using System.Threading; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace TensorFlowNET.UnitTest +{ + public delegate void MultiThreadedTestDelegate(int threadid); + + /// + /// Creates a synchronized eco-system of running code. + /// + public class MultiThreadedUnitTestExecuter : IDisposable + { + public int ThreadCount { get; } + public Thread[] Threads { get; } + public Exception[] Exceptions { get; } + private readonly SemaphoreSlim barrier_threadstarted; + private readonly ManualResetEventSlim barrier_corestart; + private readonly SemaphoreSlim done_barrier2; + + public Action PostRun { get; set; } + + #region Static + + [DebuggerHidden] + public static void Run(int threadCount, MultiThreadedTestDelegate workload) + { + if (workload == null) throw new ArgumentNullException(nameof(workload)); + if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount)); + new MultiThreadedUnitTestExecuter(threadCount).Run(workload); + } + + [DebuggerHidden] + public static void Run(int threadCount, params MultiThreadedTestDelegate[] workloads) + { + if (workloads == null) throw new ArgumentNullException(nameof(workloads)); + if (workloads.Length == 0) throw new ArgumentException("Value cannot be an empty collection.", nameof(workloads)); + if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount)); + new MultiThreadedUnitTestExecuter(threadCount).Run(workloads); + } + + [DebuggerHidden] + public static void Run(int threadCount, MultiThreadedTestDelegate workload, Action postRun) + { + if (workload == null) throw new ArgumentNullException(nameof(workload)); + if (postRun == null) throw new ArgumentNullException(nameof(postRun)); + if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount)); + new MultiThreadedUnitTestExecuter(threadCount) {PostRun = postRun}.Run(workload); + } + + #endregion + + + /// Initializes a new instance of the class. + public MultiThreadedUnitTestExecuter(int threadCount) + { + if (threadCount <= 0) + throw new ArgumentOutOfRangeException(nameof(threadCount)); + ThreadCount = threadCount; + Threads = new Thread[ThreadCount]; + Exceptions = new Exception[ThreadCount]; + done_barrier2 = new SemaphoreSlim(0, threadCount); + barrier_corestart = new ManualResetEventSlim(); + barrier_threadstarted = new SemaphoreSlim(0, threadCount); + } + + [DebuggerHidden] + public void Run(params MultiThreadedTestDelegate[] workloads) + { + if (workloads == null) + throw new ArgumentNullException(nameof(workloads)); + if (workloads.Length != 1 && workloads.Length % ThreadCount != 0) + throw new InvalidOperationException($"Run method must accept either 1 workload or n-threads workloads. Got {workloads.Length} workloads."); + + if (ThreadCount == 1) + { + Exception ex = null; + new Thread(() => + { + try + { + workloads[0](0); + } catch (Exception e) + { + if (Debugger.IsAttached) + throw; + ex = e; + } finally + { + done_barrier2.Release(1); + } + }).Start(); + + done_barrier2.Wait(); + + if (ex != null) + throw new Exception($"Thread 0 has failed: ", ex); + + PostRun?.Invoke(this); + + return; + } + + //thread core + Exception ThreadCore(MultiThreadedTestDelegate core, int threadid) + { + barrier_threadstarted.Release(1); + barrier_corestart.Wait(); + //workload + try + { + core(threadid); + } catch (Exception e) + { + if (Debugger.IsAttached) + throw; + return e; + } finally + { + done_barrier2.Release(1); + } + + return null; + } + + //initialize all threads + if (workloads.Length == 1) + { + var workload = workloads[0]; + for (int i = 0; i < ThreadCount; i++) + { + var i_local = i; + Threads[i] = new Thread(() => Exceptions[i_local] = ThreadCore(workload, i_local)); + } + } else + { + for (int i = 0; i < ThreadCount; i++) + { + var i_local = i; + var workload = workloads[i_local % workloads.Length]; + Threads[i] = new Thread(() => Exceptions[i_local] = ThreadCore(workload, i_local)); + } + } + + //run all threads + for (int i = 0; i < ThreadCount; i++) Threads[i].Start(); + //wait for threads to be started and ready + for (int i = 0; i < ThreadCount; i++) barrier_threadstarted.Wait(); + + //signal threads to start + barrier_corestart.Set(); + + //wait for threads to finish + for (int i = 0; i < ThreadCount; i++) done_barrier2.Wait(); + + //handle fails + for (int i = 0; i < ThreadCount; i++) + if (Exceptions[i] != null) + throw new Exception($"Thread {i} has failed: ", Exceptions[i]); + + //checks after ended + PostRun?.Invoke(this); + } + + public void Dispose() + { + barrier_threadstarted.Dispose(); + barrier_corestart.Dispose(); + done_barrier2.Dispose(); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Utilities/PrivateObject.cs b/test/TensorFlowNET.UnitTest/Utilities/PrivateObject.cs new file mode 100644 index 00000000..acb8c41e --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Utilities/PrivateObject.cs @@ -0,0 +1,914 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.VisualStudio.TestTools.UnitTesting +{ + using System; + using System.Collections.Generic; + //using System.Diagnostics; + //using System.Diagnostics.CodeAnalysis; + using System.Globalization; + using System.Reflection; + + /// + /// This class represents the live NON public INTERNAL object in the system + /// + internal class PrivateObject + { + #region Data + + // bind everything + private const BindingFlags BindToEveryThing = BindingFlags.Default | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Public; + + private static BindingFlags constructorFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.CreateInstance | BindingFlags.NonPublic; + + private object target; // automatically initialized to null + private Type originalType; // automatically initialized to null + + //private Dictionary> methodCache; // automatically initialized to null + + #endregion + + #region Constructors + + ///// + ///// Initializes a new instance of the class that contains + ///// the already existing object of the private class + ///// + ///// object that serves as starting point to reach the private members + ///// the derefrencing string using . that points to the object to be retrived as in m_X.m_Y.m_Z + //[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an object, so 'obj' seems reasonable")] + //public PrivateObject(object obj, string memberToAccess) + //{ + // Helper.CheckParameterNotNull(obj, "obj", string.Empty); + // ValidateAccessString(memberToAccess); + + // PrivateObject temp = obj as PrivateObject; + // if (temp == null) + // { + // temp = new PrivateObject(obj); + // } + + // // Split The access string + // string[] arr = memberToAccess.Split(new char[] { '.' }); + + // for (int i = 0; i < arr.Length; i++) + // { + // object next = temp.InvokeHelper(arr[i], BindToEveryThing | BindingFlags.Instance | BindingFlags.GetField | BindingFlags.GetProperty, null, CultureInfo.InvariantCulture); + // temp = new PrivateObject(next); + // } + + // this.target = temp.target; + // this.originalType = temp.originalType; + //} + + ///// + ///// Initializes a new instance of the class that wraps the + ///// specified type. + ///// + ///// Name of the assembly + ///// fully qualified name + ///// Argmenets to pass to the constructor + //public PrivateObject(string assemblyName, string typeName, params object[] args) + // : this(assemblyName, typeName, null, args) + //{ + //} + + ///// + ///// Initializes a new instance of the class that wraps the + ///// specified type. + ///// + ///// Name of the assembly + ///// fully qualified name + ///// An array of objects representing the number, order, and type of the parameters for the constructor to get + ///// Argmenets to pass to the constructor + //public PrivateObject(string assemblyName, string typeName, Type[] parameterTypes, object[] args) + // : this(Type.GetType(string.Format(CultureInfo.InvariantCulture, "{0}, {1}", typeName, assemblyName), false), parameterTypes, args) + //{ + // Helper.CheckParameterNotNull(assemblyName, "assemblyName", string.Empty); + // Helper.CheckParameterNotNull(typeName, "typeName", string.Empty); + //} + + ///// + ///// Initializes a new instance of the class that wraps the + ///// specified type. + ///// + ///// type of the object to create + ///// Argmenets to pass to the constructor + //public PrivateObject(Type type, params object[] args) + // : this(type, null, args) + //{ + // Helper.CheckParameterNotNull(type, "type", string.Empty); + //} + + ///// + ///// Initializes a new instance of the class that wraps the + ///// specified type. + ///// + ///// type of the object to create + ///// An array of objects representing the number, order, and type of the parameters for the constructor to get + ///// Argmenets to pass to the constructor + //public PrivateObject(Type type, Type[] parameterTypes, object[] args) + //{ + // Helper.CheckParameterNotNull(type, "type", string.Empty); + // object o; + // if (parameterTypes != null) + // { + // ConstructorInfo ci = type.GetConstructor(BindToEveryThing, null, parameterTypes, null); + // if (ci == null) + // { + // throw new ArgumentException(FrameworkMessages.PrivateAccessorConstructorNotFound); + // } + + // try + // { + // o = ci.Invoke(args); + // } + // catch (TargetInvocationException e) + // { + // Debug.Assert(e.InnerException != null, "Inner exception should not be null."); + // if (e.InnerException != null) + // { + // throw e.InnerException; + // } + + // throw; + // } + // } + // else + // { + // o = Activator.CreateInstance(type, constructorFlags, null, args, null); + // } + + // this.ConstructFrom(o); + //} + + /// + /// Initializes a new instance of the class that wraps + /// the given object. + /// + /// object to wrap + //[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an object, so 'obj' seems reasonable")] + public PrivateObject(object obj) + { + Helper.CheckParameterNotNull(obj, "obj", string.Empty); + this.ConstructFrom(obj); + } + + /// + /// Initializes a new instance of the class that wraps + /// the given object. + /// + /// object to wrap + /// PrivateType object + //[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an an object, so 'obj' seems reasonable")] + public PrivateObject(object obj, PrivateType type) + { + Helper.CheckParameterNotNull(type, "type", string.Empty); + this.target = obj; + this.originalType = type.ReferencedType; + } + + #endregion + + ///// + ///// Gets or sets the target + ///// + //public object Target + //{ + // get + // { + // return this.target; + // } + + // set + // { + // Helper.CheckParameterNotNull(value, "Target", string.Empty); + // this.target = value; + // this.originalType = value.GetType(); + // } + //} + + ///// + ///// Gets the type of underlying object + ///// + //public Type RealType + //{ + // get + // { + // return this.originalType; + // } + //} + + //private Dictionary> GenericMethodCache + //{ + // get + // { + // if (this.methodCache == null) + // { + // this.BuildGenericMethodCacheForType(this.originalType); + // } + + // Debug.Assert(this.methodCache != null, "Invalid method cache for type."); + + // return this.methodCache; + // } + //} + + /// + /// returns the hash code of the target object + /// + /// int representing hashcode of the target object + public override int GetHashCode() + { + //Debug.Assert(this.target != null, "target should not be null."); + return this.target.GetHashCode(); + } + + /// + /// Equals + /// + /// Object with whom to compare + /// returns true if the objects are equal. + public override bool Equals(object obj) + { + if (this != obj) + { + //Debug.Assert(this.target != null, "target should not be null."); + if (typeof(PrivateObject) == obj?.GetType()) + { + return this.target.Equals(((PrivateObject) obj).target); + } else + { + return false; + } + } + + return true; + } + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// Arguments to pass to the member to invoke. + ///// Result of method call + //public object Invoke(string name, params object[] args) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.Invoke(name, null, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// An array of objects representing the number, order, and type of the parameters for the method to get. + ///// Arguments to pass to the member to invoke. + ///// Result of method call + //public object Invoke(string name, Type[] parameterTypes, object[] args) + //{ + // return this.Invoke(name, parameterTypes, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// An array of objects representing the number, order, and type of the parameters for the method to get. + ///// Arguments to pass to the member to invoke. + ///// An array of types corresponding to the types of the generic arguments. + ///// Result of method call + //public object Invoke(string name, Type[] parameterTypes, object[] args, Type[] typeArguments) + //{ + // return this.Invoke(name, BindToEveryThing, parameterTypes, args, CultureInfo.InvariantCulture, typeArguments); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// Arguments to pass to the member to invoke. + ///// Culture info + ///// Result of method call + //public object Invoke(string name, object[] args, CultureInfo culture) + //{ + // return this.Invoke(name, null, args, culture); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// An array of objects representing the number, order, and type of the parameters for the method to get. + ///// Arguments to pass to the member to invoke. + ///// Culture info + ///// Result of method call + //public object Invoke(string name, Type[] parameterTypes, object[] args, CultureInfo culture) + //{ + // return this.Invoke(name, BindToEveryThing, parameterTypes, args, culture); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// Arguments to pass to the member to invoke. + ///// Result of method call + //public object Invoke(string name, BindingFlags bindingFlags, params object[] args) + //{ + // return this.Invoke(name, bindingFlags, null, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// An array of objects representing the number, order, and type of the parameters for the method to get. + ///// Arguments to pass to the member to invoke. + ///// Result of method call + //public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args) + //{ + // return this.Invoke(name, bindingFlags, parameterTypes, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// Arguments to pass to the member to invoke. + ///// Culture info + ///// Result of method call + //public object Invoke(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture) + //{ + // return this.Invoke(name, bindingFlags, null, args, culture); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// An array of objects representing the number, order, and type of the parameters for the method to get. + ///// Arguments to pass to the member to invoke. + ///// Culture info + ///// Result of method call + //public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture) + //{ + // return this.Invoke(name, bindingFlags, parameterTypes, args, culture, null); + //} + + ///// + ///// Invokes the specified method + ///// + ///// Name of the method + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// An array of objects representing the number, order, and type of the parameters for the method to get. + ///// Arguments to pass to the member to invoke. + ///// Culture info + ///// An array of types corresponding to the types of the generic arguments. + ///// Result of method call + //public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture, Type[] typeArguments) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // if (parameterTypes != null) + // { + // bindingFlags |= BindToEveryThing | BindingFlags.Instance; + + // // Fix up the parameter types + // MethodInfo member = this.originalType.GetMethod(name, bindingFlags, null, parameterTypes, null); + + // // If the method was not found and type arguments were provided for generic paramaters, + // // attempt to look up a generic method. + // if ((member == null) && (typeArguments != null)) + // { + // // This method may contain generic parameters...if so, the previous call to + // // GetMethod() will fail because it doesn't fully support generic parameters. + + // // Look in the method cache to see if there is a generic method + // // on the incoming type that contains the correct signature. + // member = this.GetGenericMethodFromCache(name, parameterTypes, typeArguments, bindingFlags, null); + // } + + // if (member == null) + // { + // throw new ArgumentException( + // string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); + // } + + // try + // { + // if (member.IsGenericMethodDefinition) + // { + // MethodInfo constructed = member.MakeGenericMethod(typeArguments); + // return constructed.Invoke(this.target, bindingFlags, null, args, culture); + // } + // else + // { + // return member.Invoke(this.target, bindingFlags, null, args, culture); + // } + // } + // catch (TargetInvocationException e) + // { + // Debug.Assert(e.InnerException != null, "Inner exception should not be null."); + // if (e.InnerException != null) + // { + // throw e.InnerException; + // } + + // throw; + // } + // } + // else + // { + // return this.InvokeHelper(name, bindingFlags | BindingFlags.InvokeMethod, args, culture); + // } + //} + + ///// + ///// Gets the array element using array of subsrcipts for each dimension + ///// + ///// Name of the member + ///// the indices of array + ///// An arrya of elements. + //public object GetArrayElement(string name, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.GetArrayElement(name, BindToEveryThing, indices); + //} + + ///// + ///// Sets the array element using array of subsrcipts for each dimension + ///// + ///// Name of the member + ///// Value to set + ///// the indices of array + //public void SetArrayElement(string name, object value, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // this.SetArrayElement(name, BindToEveryThing, value, indices); + //} + + ///// + ///// Gets the array element using array of subsrcipts for each dimension + ///// + ///// Name of the member + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// the indices of array + ///// An arrya of elements. + //public object GetArrayElement(string name, BindingFlags bindingFlags, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // Array arr = (Array)this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture); + // return arr.GetValue(indices); + //} + + ///// + ///// Sets the array element using array of subsrcipts for each dimension + ///// + ///// Name of the member + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// Value to set + ///// the indices of array + //public void SetArrayElement(string name, BindingFlags bindingFlags, object value, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // Array arr = (Array)this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture); + // arr.SetValue(value, indices); + //} + + ///// + ///// Get the field + ///// + ///// Name of the field + ///// The field. + //public object GetField(string name) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.GetField(name, BindToEveryThing); + //} + + ///// + ///// Sets the field + ///// + ///// Name of the field + ///// value to set + //public void SetField(string name, object value) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // this.SetField(name, BindToEveryThing, value); + //} + + ///// + ///// Gets the field + ///// + ///// Name of the field + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// The field. + //public object GetField(string name, BindingFlags bindingFlags) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture); + //} + + ///// + ///// Sets the field + ///// + ///// Name of the field + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// value to set + //public void SetField(string name, BindingFlags bindingFlags, object value) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // this.InvokeHelper(name, BindingFlags.SetField | bindingFlags, new object[] { value }, CultureInfo.InvariantCulture); + //} + + /// + /// Get the field or property + /// + /// Name of the field or property + /// The field or property. + public object GetFieldOrProperty(string name) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + return this.GetFieldOrProperty(name, BindToEveryThing); + } + + /// + /// Sets the field or property + /// + /// Name of the field or property + /// value to set + public void SetFieldOrProperty(string name, object value) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + this.SetFieldOrProperty(name, BindToEveryThing, value); + } + + /// + /// Gets the field or property + /// + /// Name of the field or property + /// A bitmask comprised of one or more that specify how the search is conducted. + /// The field or property. + public object GetFieldOrProperty(string name, BindingFlags bindingFlags) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + return this.InvokeHelper(name, BindingFlags.GetField | BindingFlags.GetProperty | bindingFlags, null, CultureInfo.InvariantCulture); + } + + /// + /// Sets the field or property + /// + /// Name of the field or property + /// A bitmask comprised of one or more that specify how the search is conducted. + /// value to set + public void SetFieldOrProperty(string name, BindingFlags bindingFlags, object value) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + this.InvokeHelper(name, BindingFlags.SetField | BindingFlags.SetProperty | bindingFlags, new object[] {value}, CultureInfo.InvariantCulture); + } + + ///// + ///// Gets the property + ///// + ///// Name of the property + ///// Arguments to pass to the member to invoke. + ///// The property. + //public object GetProperty(string name, params object[] args) + //{ + // return this.GetProperty(name, null, args); + //} + + ///// + ///// Gets the property + ///// + ///// Name of the property + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// Arguments to pass to the member to invoke. + ///// The property. + //public object GetProperty(string name, Type[] parameterTypes, object[] args) + //{ + // return this.GetProperty(name, BindToEveryThing, parameterTypes, args); + //} + + ///// + ///// Set the property + ///// + ///// Name of the property + ///// value to set + ///// Arguments to pass to the member to invoke. + //public void SetProperty(string name, object value, params object[] args) + //{ + // this.SetProperty(name, null, value, args); + //} + + ///// + ///// Set the property + ///// + ///// Name of the property + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// value to set + ///// Arguments to pass to the member to invoke. + //public void SetProperty(string name, Type[] parameterTypes, object value, object[] args) + //{ + // this.SetProperty(name, BindToEveryThing, value, parameterTypes, args); + //} + + ///// + ///// Gets the property + ///// + ///// Name of the property + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// Arguments to pass to the member to invoke. + ///// The property. + //public object GetProperty(string name, BindingFlags bindingFlags, params object[] args) + //{ + // return this.GetProperty(name, bindingFlags, null, args); + //} + + ///// + ///// Gets the property + ///// + ///// Name of the property + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// Arguments to pass to the member to invoke. + ///// The property. + //public object GetProperty(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // if (parameterTypes != null) + // { + // PropertyInfo pi = this.originalType.GetProperty(name, bindingFlags, null, null, parameterTypes, null); + // if (pi == null) + // { + // throw new ArgumentException( + // string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); + // } + + // return pi.GetValue(this.target, args); + // } + // else + // { + // return this.InvokeHelper(name, bindingFlags | BindingFlags.GetProperty, args, null); + // } + //} + + ///// + ///// Sets the property + ///// + ///// Name of the property + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// value to set + ///// Arguments to pass to the member to invoke. + //public void SetProperty(string name, BindingFlags bindingFlags, object value, params object[] args) + //{ + // this.SetProperty(name, bindingFlags, value, null, args); + //} + + ///// + ///// Sets the property + ///// + ///// Name of the property + ///// A bitmask comprised of one or more that specify how the search is conducted. + ///// value to set + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// Arguments to pass to the member to invoke. + //public void SetProperty(string name, BindingFlags bindingFlags, object value, Type[] parameterTypes, object[] args) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + + // if (parameterTypes != null) + // { + // PropertyInfo pi = this.originalType.GetProperty(name, bindingFlags, null, null, parameterTypes, null); + // if (pi == null) + // { + // throw new ArgumentException( + // string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); + // } + + // pi.SetValue(this.target, value, args); + // } + // else + // { + // object[] pass = new object[(args?.Length ?? 0) + 1]; + // pass[0] = value; + // args?.CopyTo(pass, 1); + // this.InvokeHelper(name, bindingFlags | BindingFlags.SetProperty, pass, null); + // } + //} + + #region Private Helpers + + ///// + ///// Validate access string + ///// + ///// access string + //private static void ValidateAccessString(string access) + //{ + // Helper.CheckParameterNotNull(access, "access", string.Empty); + // if (access.Length == 0) + // { + // throw new ArgumentException(FrameworkMessages.AccessStringInvalidSyntax); + // } + + // string[] arr = access.Split('.'); + // foreach (string str in arr) + // { + // if ((str.Length == 0) || (str.IndexOfAny(new char[] { ' ', '\t', '\n' }) != -1)) + // { + // throw new ArgumentException(FrameworkMessages.AccessStringInvalidSyntax); + // } + // } + //} + + /// + /// Invokes the memeber + /// + /// Name of the member + /// Additional attributes + /// Arguments for the invocation + /// Culture + /// Result of the invocation + private object InvokeHelper(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + //Debug.Assert(this.target != null, "Internal Error: Null reference is returned for internal object"); + + // Invoke the actual Method + try + { + return this.originalType.InvokeMember(name, bindingFlags, null, this.target, args, culture); + } catch (TargetInvocationException e) + { + //Debug.Assert(e.InnerException != null, "Inner exception should not be null."); + if (e.InnerException != null) + { + throw e.InnerException; + } + + throw; + } + } + + private void ConstructFrom(object obj) + { + Helper.CheckParameterNotNull(obj, "obj", string.Empty); + this.target = obj; + this.originalType = obj.GetType(); + } + + //private void BuildGenericMethodCacheForType(Type t) + //{ + // Debug.Assert(t != null, "type should not be null."); + // this.methodCache = new Dictionary>(); + + // MethodInfo[] members = t.GetMethods(BindToEveryThing); + // LinkedList listByName; // automatically initialized to null + + // foreach (MethodInfo member in members) + // { + // if (member.IsGenericMethod || member.IsGenericMethodDefinition) + // { + // if (!this.GenericMethodCache.TryGetValue(member.Name, out listByName)) + // { + // listByName = new LinkedList(); + // this.GenericMethodCache.Add(member.Name, listByName); + // } + + // Debug.Assert(listByName != null, "list should not be null."); + // listByName.AddLast(member); + // } + // } + //} + + ///// + ///// Extracts the most appropriate generic method signature from the current private type. + ///// + ///// The name of the method in which to search the signature cache. + ///// An array of types corresponding to the types of the parameters in which to search. + ///// An array of types corresponding to the types of the generic arguments. + ///// to further filter the method signatures. + ///// Modifiers for parameters. + ///// A methodinfo instance. + //private MethodInfo GetGenericMethodFromCache(string methodName, Type[] parameterTypes, Type[] typeArguments, BindingFlags bindingFlags, ParameterModifier[] modifiers) + //{ + // Debug.Assert(!string.IsNullOrEmpty(methodName), "Invalid method name."); + // Debug.Assert(parameterTypes != null, "Invalid parameter type array."); + // Debug.Assert(typeArguments != null, "Invalid type arguments array."); + + // // Build a preliminary list of method candidates that contain roughly the same signature. + // var methodCandidates = this.GetMethodCandidates(methodName, parameterTypes, typeArguments, bindingFlags, modifiers); + + // // Search of ambiguous methods (methods with the same signature). + // MethodInfo[] finalCandidates = new MethodInfo[methodCandidates.Count]; + // methodCandidates.CopyTo(finalCandidates, 0); + + // if ((parameterTypes != null) && (parameterTypes.Length == 0)) + // { + // for (int i = 0; i < finalCandidates.Length; i++) + // { + // MethodInfo methodInfo = finalCandidates[i]; + + // if (!RuntimeTypeHelper.CompareMethodSigAndName(methodInfo, finalCandidates[0])) + // { + // throw new AmbiguousMatchException(); + // } + // } + + // // All the methods have the exact same name and sig so return the most derived one. + // return RuntimeTypeHelper.FindMostDerivedNewSlotMeth(finalCandidates, finalCandidates.Length) as MethodInfo; + // } + + // // Now that we have a preliminary list of candidates, select the most appropriate one. + // return RuntimeTypeHelper.SelectMethod(bindingFlags, finalCandidates, parameterTypes, modifiers) as MethodInfo; + //} + + //private LinkedList GetMethodCandidates(string methodName, Type[] parameterTypes, Type[] typeArguments, BindingFlags bindingFlags, ParameterModifier[] modifiers) + //{ + // Debug.Assert(!string.IsNullOrEmpty(methodName), "methodName should not be null."); + // Debug.Assert(parameterTypes != null, "parameterTypes should not be null."); + // Debug.Assert(typeArguments != null, "typeArguments should not be null."); + + // LinkedList methodCandidates = new LinkedList(); + // LinkedList methods = null; + + // if (!this.GenericMethodCache.TryGetValue(methodName, out methods)) + // { + // return methodCandidates; + // } + + // Debug.Assert(methods != null, "methods should not be null."); + + // foreach (MethodInfo candidate in methods) + // { + // bool paramMatch = true; + // ParameterInfo[] candidateParams = null; + // Type[] genericArgs = candidate.GetGenericArguments(); + // Type sourceParameterType = null; + + // if (genericArgs.Length != typeArguments.Length) + // { + // continue; + // } + + // // Since we can't just get the correct MethodInfo from Reflection, + // // we will just match the number of parameters, their order, and their type + // var methodCandidate = candidate; + // candidateParams = methodCandidate.GetParameters(); + + // if (candidateParams.Length != parameterTypes.Length) + // { + // continue; + // } + + // // Exact binding + // if ((bindingFlags & BindingFlags.ExactBinding) != 0) + // { + // int i = 0; + + // foreach (ParameterInfo candidateParam in candidateParams) + // { + // sourceParameterType = parameterTypes[i++]; + + // if (candidateParam.ParameterType.ContainsGenericParameters) + // { + // // Since we have a generic parameter here, just make sure the IsArray matches. + // if (candidateParam.ParameterType.IsArray != sourceParameterType.IsArray) + // { + // paramMatch = false; + // break; + // } + // } + // else + // { + // if (candidateParam.ParameterType != sourceParameterType) + // { + // paramMatch = false; + // break; + // } + // } + // } + + // if (paramMatch) + // { + // methodCandidates.AddLast(methodCandidate); + // continue; + // } + // } + // else + // { + // methodCandidates.AddLast(methodCandidate); + // } + // } + + // return methodCandidates; + //} + + #endregion + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Utilities/PrivateObjectExtensions.cs b/test/TensorFlowNET.UnitTest/Utilities/PrivateObjectExtensions.cs new file mode 100644 index 00000000..f40cc727 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Utilities/PrivateObjectExtensions.cs @@ -0,0 +1,314 @@ +// +// Copyright (c) 2019 cactuaroid All Rights Reserved +// +// +// Released under the MIT license +// https://github.com/cactuaroid/PrivateObjectExtensions +// + +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Linq; +using System.Reflection; + +namespace System +{ + /// + /// Extension methods for PrivateObject + /// + public static class PrivateObjectExtensions + { + private static readonly BindingFlags Static = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly | BindingFlags.Static; + private static readonly BindingFlags Instance = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly | BindingFlags.Instance; + + /// + /// Get from private (and any other) field/property. + /// If the real type of specified object doesn't contain the specified field/property, + /// base types are searched automatically. + /// + /// The object to get from + /// The name of the field/property + /// The object got from the field/property + /// 'name' is not found. + /// Arguments contain null. + public static object GetPrivate(this object obj, string name) + { + if (obj == null) { throw new ArgumentNullException("obj"); } + + return GetPrivate(obj, name, obj.GetType(), null); + } + + /// + /// Get from private (and any other) field/property. + /// If the real type of specified object doesn't contain the specified field/property, + /// base types are searched automatically. + /// + /// The type of the field/property + /// The object to get from + /// The name of the field/property + /// The object got from the field/property + /// 'name' is not found. + /// Arguments contain null. + public static T GetPrivate(this object obj, string name) + { + if (obj == null) { throw new ArgumentNullException("obj"); } + + return (T)GetPrivate(obj, name, obj.GetType(), typeof(T)); + } + + /// + /// Get from private (and any other) field/property with assuming the specified object as specified type. + /// If the specified type doesn't contain the specified field/property, + /// base types are searched automatically. + /// + /// The object to get from + /// The name of the field/property + /// The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored. + /// The object got from the field/property + /// 'name' is not found. + /// 'objType' is not assignable from 'obj'. + /// Arguments contain null. + public static object GetPrivate(this object obj, string name, Type objType) + { + return GetPrivate(obj, name, objType, null); + } + + /// + /// Get from private (and any other) field/property with assuming the specified object as specified type. + /// If the specified type doesn't contain the specified field/property, + /// base types are searched automatically. + /// + /// The type of the field/property + /// The object to get from + /// The name of the field/property + /// The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored. + /// The object got from the field/property + /// 'name' is not found. + /// 'objType' is not assignable from 'obj'. + /// Arguments contain null. + public static T GetPrivate(this object obj, string name, Type objType) + { + return (T)GetPrivate(obj, name, objType, typeof(T)); + } + + private static object GetPrivate(object obj, string name, Type objType, Type memberType) + { + if (obj == null) { throw new ArgumentNullException("obj"); } + if (name == null) { throw new ArgumentNullException("name"); } + if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } + if (objType == null) { throw new ArgumentNullException("objType"); } + if (!objType.IsAssignableFrom(obj.GetType())) { throw new ArgumentException($"{objType} is not assignable from {obj.GetType()}.", "objType"); } + + bool memberTypeMatching(Type actualType) => actualType == memberType; + + if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Instance, out var ownerType)) + { + return new PrivateObject(obj, new PrivateType(ownerType)).GetFieldOrProperty(name); + } + else if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Static, out ownerType)) + { + return new PrivateType(ownerType).GetStaticFieldOrProperty(name); + } + + throw new ArgumentException(((memberType != null) ? memberType + " " : "") + name + " is not found."); + } + + /// + /// Get from private (and any other) static field/property. + /// + /// The type to get from + /// The name of the static field/property + /// The object got from the static field/property + /// 'name' is not found. + /// Arguments contain null. + public static object GetPrivate(this Type type, string name) + { + return GetPrivate(type, name, null); + } + + /// + /// Get from private (and any other) static field/property. + /// + /// The type of the field/property + /// The type to get from + /// The name of the static field/property + /// The object got from the static field/property + /// 'name' is not found. + /// Arguments contain null. + public static T GetPrivate(this Type type, string name) + { + return (T)GetPrivate(type, name, typeof(T)); + } + + private static object GetPrivate(this Type type, string name, Type memberType) + { + if (type == null) { throw new ArgumentNullException("type"); } + if (name == null) { throw new ArgumentNullException("name"); } + if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } + + bool memberTypeMatching(Type actualType) => actualType == memberType; + + if (type.ContainsFieldOrProperty(name, memberType, memberTypeMatching, Static)) + { + return new PrivateType(type).GetStaticFieldOrProperty(name); + } + + throw new ArgumentException(((memberType != null) ? memberType + " " : "") + name + " is not found."); + } + + /// + /// Set to private (and any other) field/property. + /// If the real type of specified object doesn't contain the specified field/property, + /// base types are searched automatically. + /// + /// The object to set to + /// The name of the field/property + /// The value to set for 'name' + /// 'name' is not found. + /// Arguments contain null. + public static void SetPrivate(this object obj, string name, T value) + { + if (obj == null) { throw new ArgumentNullException("obj"); } + + SetPrivate(obj, name, value, obj.GetType()); + } + + /// + /// Set to private (and any other) field/property with assuming the specified object as specified type. + /// If the specified type doesn't contain the specified field/property, + /// base types are searched automatically. + /// + /// The object to set to + /// The name of the field/property + /// The value to set for 'name' + /// The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored. + /// 'name' is not found. + /// 'objType' is not assignable from 'obj'. + /// Arguments contain null. + public static void SetPrivate(this object obj, string name, T value, Type objType) + { + if (obj == null) { throw new ArgumentNullException("obj"); } + if (name == null) { throw new ArgumentNullException("name"); } + if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } + if (value == null) { throw new ArgumentNullException("value"); } + if (objType == null) { throw new ArgumentNullException("objType"); } + if (!objType.IsAssignableFrom(obj.GetType())) { throw new ArgumentException($"{objType} is not assignable from {obj.GetType()}.", "objType"); } + + if (TrySetPrivate(obj, name, value, objType)) { return; } + + // retry for the case of getter only property + if (TrySetPrivate(obj, GetBackingFieldName(name), value, objType)) { return; } + + throw new ArgumentException($"{typeof(T)} {name} is not found."); + } + + private static bool TrySetPrivate(object obj, string name, T value, Type objType) + { + var memberType = typeof(T); + bool memberTypeMatching(Type actualType) => actualType.IsAssignableFrom(memberType); + + try + { + if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Instance, out var ownerType)) + { + new PrivateObject(obj, new PrivateType(ownerType)).SetFieldOrProperty(name, value); + return true; + } + else if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Static, out ownerType)) + { + new PrivateType(ownerType).SetStaticFieldOrProperty(name, value); + return true; + } + } + catch(MissingMethodException) + { + // When getter only property name is given, the property is found but fails to set. + return false; + } + + return false; + } + + /// + /// Set to private (and any other) static field/property. + /// + /// The type to set to + /// The name of the field/property + /// The value to set for 'name' + /// 'name' is not found. + /// Arguments contain null. + public static void SetPrivate(this Type type, string name, T value) + { + if (type == null) { throw new ArgumentNullException("type"); } + if (name == null) { throw new ArgumentNullException("name"); } + if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } + + if (TrySetPrivate(type, name, value)) { return; } + + // retry for the case of getter only property + if (TrySetPrivate(type, GetBackingFieldName(name), value)) { return; } + + throw new ArgumentException($"{typeof(T)} {name} is not found."); + } + + private static bool TrySetPrivate(this Type type, string name, T value) + { + var memberType = typeof(T); + bool memberTypeMatching(Type actualType) => actualType.IsAssignableFrom(memberType); + + try + { + if (type.ContainsFieldOrProperty(name, memberType, memberTypeMatching, Static)) + { + new PrivateType(type).SetStaticFieldOrProperty(name, value); + return true; + } + } + catch (MissingMethodException) + { + // When getter only property name is given, the property is found but fails to set. + return false; + } + + return false; + } + + private static string GetBackingFieldName(string propertyName) + => $"<{propertyName}>k__BackingField"; // generated backing field name + + private static bool TryFindFieldOrPropertyOwnerType(Type objType, string name, Type memberType, Func memberTypeMatching, BindingFlags bindingFlag, out Type ownerType) + { + ownerType = FindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, bindingFlag); + + return (ownerType != null); + } + + private static Type FindFieldOrPropertyOwnerType(Type objectType, string name, Type memberType, Func memberTypeMatching, BindingFlags bindingFlags) + { + if (objectType == null) { return null; } + + if (objectType.ContainsFieldOrProperty(name, memberType, memberTypeMatching, bindingFlags)) + { + return objectType; + } + + return FindFieldOrPropertyOwnerType(objectType.BaseType, name, memberType, memberTypeMatching, bindingFlags); + } + + private static bool ContainsFieldOrProperty(this Type objectType, string name, Type memberType, Func memberTypeMatching, BindingFlags bindingFlags) + { + var fields = objectType + .GetFields(bindingFlags) + .Select((x) => new { Type = x.FieldType, Member = x as MemberInfo }); + + var properties = objectType + .GetProperties(bindingFlags) + .Select((x) => new { Type = x.PropertyType, Member = x as MemberInfo }); + + var members = fields.Concat(properties); + + return members.Any((actual) => + (memberType == null || memberTypeMatching.Invoke(actual.Type)) + && actual.Member.Name == name); + } + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Utilities/PrivateType.cs b/test/TensorFlowNET.UnitTest/Utilities/PrivateType.cs new file mode 100644 index 00000000..a2d0b3c3 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Utilities/PrivateType.cs @@ -0,0 +1,572 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.VisualStudio.TestTools.UnitTesting +{ + using System; + //using System.Diagnostics; + using System.Globalization; + using System.Reflection; + + /// + /// This class represents a private class for the Private Accessor functionality. + /// + internal class PrivateType + { + /// + /// Binds to everything + /// + private const BindingFlags BindToEveryThing = BindingFlags.Default + | BindingFlags.NonPublic | BindingFlags.Instance + | BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy; + + /// + /// The wrapped type. + /// + private Type type; + + ///// + ///// Initializes a new instance of the class that contains the private type. + ///// + ///// Assembly name + ///// fully qualified name of the + //public PrivateType(string assemblyName, string typeName) + //{ + // Helper.CheckParameterNotNullOrEmpty(assemblyName, "assemblyName", string.Empty); + // Helper.CheckParameterNotNullOrEmpty(typeName, "typeName", string.Empty); + // Assembly asm = Assembly.Load(assemblyName); + + // this.type = asm.GetType(typeName, true); + //} + + /// + /// Initializes a new instance of the class that contains + /// the private type from the type object + /// + /// The wrapped Type to create. + public PrivateType(Type type) + { + if (type == null) + { + throw new ArgumentNullException("type"); + } + + this.type = type; + } + + /// + /// Gets the referenced type + /// + public Type ReferencedType => this.type; + + ///// + ///// Invokes static member + ///// + ///// Name of the member to InvokeHelper + ///// Arguements to the invoction + ///// Result of invocation + //public object InvokeStatic(string name, params object[] args) + //{ + // return this.InvokeStatic(name, null, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes static member + ///// + ///// Name of the member to InvokeHelper + ///// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invoction + ///// Result of invocation + //public object InvokeStatic(string name, Type[] parameterTypes, object[] args) + //{ + // return this.InvokeStatic(name, parameterTypes, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes static member + ///// + ///// Name of the member to InvokeHelper + ///// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invoction + ///// An array of types corresponding to the types of the generic arguments. + ///// Result of invocation + //public object InvokeStatic(string name, Type[] parameterTypes, object[] args, Type[] typeArguments) + //{ + // return this.InvokeStatic(name, BindToEveryThing, parameterTypes, args, CultureInfo.InvariantCulture, typeArguments); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Arguements to the invocation + ///// Culture + ///// Result of invocation + //public object InvokeStatic(string name, object[] args, CultureInfo culture) + //{ + // return this.InvokeStatic(name, null, args, culture); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invocation + ///// Culture info + ///// Result of invocation + //public object InvokeStatic(string name, Type[] parameterTypes, object[] args, CultureInfo culture) + //{ + // return this.InvokeStatic(name, BindingFlags.InvokeMethod, parameterTypes, args, culture); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Additional invocation attributes + ///// Arguements to the invocation + ///// Result of invocation + //public object InvokeStatic(string name, BindingFlags bindingFlags, params object[] args) + //{ + // return this.InvokeStatic(name, bindingFlags, null, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Additional invocation attributes + ///// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invocation + ///// Result of invocation + //public object InvokeStatic(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args) + //{ + // return this.InvokeStatic(name, bindingFlags, parameterTypes, args, CultureInfo.InvariantCulture); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Additional invocation attributes + ///// Arguements to the invocation + ///// Culture + ///// Result of invocation + //public object InvokeStatic(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture) + //{ + // return this.InvokeStatic(name, bindingFlags, null, args, culture); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Additional invocation attributes + ///// /// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invocation + ///// Culture + ///// Result of invocation + //public object InvokeStatic(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture) + //{ + // return this.InvokeStatic(name, bindingFlags, parameterTypes, args, culture, null); + //} + + ///// + ///// Invokes the static method + ///// + ///// Name of the member + ///// Additional invocation attributes + ///// /// An array of objects representing the number, order, and type of the parameters for the method to invoke + ///// Arguements to the invocation + ///// Culture + ///// An array of types corresponding to the types of the generic arguments. + ///// Result of invocation + //public object InvokeStatic(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture, Type[] typeArguments) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // if (parameterTypes != null) + // { + // MethodInfo member = this.type.GetMethod(name, bindingFlags | BindToEveryThing | BindingFlags.Static, null, parameterTypes, null); + // if (member == null) + // { + // throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); + // } + + // try + // { + // if (member.IsGenericMethodDefinition) + // { + // MethodInfo constructed = member.MakeGenericMethod(typeArguments); + // return constructed.Invoke(null, bindingFlags, null, args, culture); + // } + // else + // { + // return member.Invoke(null, bindingFlags, null, args, culture); + // } + // } + // catch (TargetInvocationException e) + // { + // Debug.Assert(e.InnerException != null, "Inner Exception should not be null."); + // if (e.InnerException != null) + // { + // throw e.InnerException; + // } + + // throw; + // } + // } + // else + // { + // return this.InvokeHelperStatic(name, bindingFlags | BindingFlags.InvokeMethod, args, culture); + // } + //} + + ///// + ///// Gets the element in static array + ///// + ///// Name of the array + ///// + ///// A one-dimensional array of 32-bit integers that represent the indexes specifying + ///// the position of the element to get. For instance, to access a[10][11] the indices would be {10,11} + ///// + ///// element at the specified location + //public object GetStaticArrayElement(string name, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.GetStaticArrayElement(name, BindToEveryThing, indices); + //} + + ///// + ///// Sets the memeber of the static array + ///// + ///// Name of the array + ///// value to set + ///// + ///// A one-dimensional array of 32-bit integers that represent the indexes specifying + ///// the position of the element to set. For instance, to access a[10][11] the array would be {10,11} + ///// + //public void SetStaticArrayElement(string name, object value, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // this.SetStaticArrayElement(name, BindToEveryThing, value, indices); + //} + + ///// + ///// Gets the element in satatic array + ///// + ///// Name of the array + ///// Additional InvokeHelper attributes + ///// + ///// A one-dimensional array of 32-bit integers that represent the indexes specifying + ///// the position of the element to get. For instance, to access a[10][11] the array would be {10,11} + ///// + ///// element at the spcified location + //public object GetStaticArrayElement(string name, BindingFlags bindingFlags, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // Array arr = (Array)this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.GetProperty | bindingFlags, null, CultureInfo.InvariantCulture); + // return arr.GetValue(indices); + //} + + ///// + ///// Sets the memeber of the static array + ///// + ///// Name of the array + ///// Additional InvokeHelper attributes + ///// value to set + ///// + ///// A one-dimensional array of 32-bit integers that represent the indexes specifying + ///// the position of the element to set. For instance, to access a[10][11] the array would be {10,11} + ///// + //public void SetStaticArrayElement(string name, BindingFlags bindingFlags, object value, params int[] indices) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // Array arr = (Array)this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.GetProperty | BindingFlags.Static | bindingFlags, null, CultureInfo.InvariantCulture); + // arr.SetValue(value, indices); + //} + + ///// + ///// Gets the static field + ///// + ///// Name of the field + ///// The static field. + //public object GetStaticField(string name) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.GetStaticField(name, BindToEveryThing); + //} + + ///// + ///// Sets the static field + ///// + ///// Name of the field + ///// Arguement to the invocation + //public void SetStaticField(string name, object value) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // this.SetStaticField(name, BindToEveryThing, value); + //} + + ///// + ///// Gets the static field using specified InvokeHelper attributes + ///// + ///// Name of the field + ///// Additional invocation attributes + ///// The static field. + //public object GetStaticField(string name, BindingFlags bindingFlags) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // return this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.Static | bindingFlags, null, CultureInfo.InvariantCulture); + //} + + ///// + ///// Sets the static field using binding attributes + ///// + ///// Name of the field + ///// Additional InvokeHelper attributes + ///// Arguement to the invocation + //public void SetStaticField(string name, BindingFlags bindingFlags, object value) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // this.InvokeHelperStatic(name, BindingFlags.SetField | bindingFlags | BindingFlags.Static, new[] { value }, CultureInfo.InvariantCulture); + //} + + /// + /// Gets the static field or property + /// + /// Name of the field or property + /// The static field or property. + public object GetStaticFieldOrProperty(string name) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + return this.GetStaticFieldOrProperty(name, BindToEveryThing); + } + + /// + /// Sets the static field or property + /// + /// Name of the field or property + /// Value to be set to field or property + public void SetStaticFieldOrProperty(string name, object value) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + this.SetStaticFieldOrProperty(name, BindToEveryThing, value); + } + + /// + /// Gets the static field or property using specified InvokeHelper attributes + /// + /// Name of the field or property + /// Additional invocation attributes + /// The static field or property. + public object GetStaticFieldOrProperty(string name, BindingFlags bindingFlags) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + return this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.GetProperty | BindingFlags.Static | bindingFlags, null, CultureInfo.InvariantCulture); + } + + /// + /// Sets the static field or property using binding attributes + /// + /// Name of the field or property + /// Additional invocation attributes + /// Value to be set to field or property + public void SetStaticFieldOrProperty(string name, BindingFlags bindingFlags, object value) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + this.InvokeHelperStatic(name, BindingFlags.SetField | BindingFlags.SetProperty | bindingFlags | BindingFlags.Static, new[] {value}, CultureInfo.InvariantCulture); + } + + ///// + ///// Gets the static property + ///// + ///// Name of the field or property + ///// Arguements to the invocation + ///// The static property. + //public object GetStaticProperty(string name, params object[] args) + //{ + // return this.GetStaticProperty(name, BindToEveryThing, args); + //} + + ///// + ///// Sets the static property + ///// + ///// Name of the property + ///// Value to be set to field or property + ///// Arguments to pass to the member to invoke. + //public void SetStaticProperty(string name, object value, params object[] args) + //{ + // this.SetStaticProperty(name, BindToEveryThing, value, null, args); + //} + + ///// + ///// Sets the static property + ///// + ///// Name of the property + ///// Value to be set to field or property + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// Arguments to pass to the member to invoke. + //public void SetStaticProperty(string name, object value, Type[] parameterTypes, object[] args) + //{ + // this.SetStaticProperty(name, BindingFlags.SetProperty, value, parameterTypes, args); + //} + + ///// + ///// Gets the static property + ///// + ///// Name of the property + ///// Additional invocation attributes. + ///// Arguments to pass to the member to invoke. + ///// The static property. + //public object GetStaticProperty(string name, BindingFlags bindingFlags, params object[] args) + //{ + // return this.GetStaticProperty(name, BindingFlags.GetProperty | BindingFlags.Static | bindingFlags, null, args); + //} + + ///// + ///// Gets the static property + ///// + ///// Name of the property + ///// Additional invocation attributes. + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// Arguments to pass to the member to invoke. + ///// The static property. + //public object GetStaticProperty(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + // if (parameterTypes != null) + // { + // PropertyInfo pi = this.type.GetProperty(name, bindingFlags | BindingFlags.Static, null, null, parameterTypes, null); + // if (pi == null) + // { + // throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); + // } + + // return pi.GetValue(null, args); + // } + // else + // { + // return this.InvokeHelperStatic(name, bindingFlags | BindingFlags.GetProperty, args, null); + // } + //} + + ///// + ///// Sets the static property + ///// + ///// Name of the property + ///// Additional invocation attributes. + ///// Value to be set to field or property + ///// Optional index values for indexed properties. The indexes of indexed properties are zero-based. This value should be null for non-indexed properties. + //public void SetStaticProperty(string name, BindingFlags bindingFlags, object value, params object[] args) + //{ + // this.SetStaticProperty(name, bindingFlags, value, null, args); + //} + + ///// + ///// Sets the static property + ///// + ///// Name of the property + ///// Additional invocation attributes. + ///// Value to be set to field or property + ///// An array of objects representing the number, order, and type of the parameters for the indexed property. + ///// Arguments to pass to the member to invoke. + //public void SetStaticProperty(string name, BindingFlags bindingFlags, object value, Type[] parameterTypes, object[] args) + //{ + // Helper.CheckParameterNotNull(name, "name", string.Empty); + + // if (parameterTypes != null) + // { + // PropertyInfo pi = this.type.GetProperty(name, bindingFlags | BindingFlags.Static, null, null, parameterTypes, null); + // if (pi == null) + // { + // throw new ArgumentException( + // string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); + // } + + // pi.SetValue(null, value, args); + // } + // else + // { + // object[] pass = new object[(args?.Length ?? 0) + 1]; + // pass[0] = value; + // args?.CopyTo(pass, 1); + // this.InvokeHelperStatic(name, bindingFlags | BindingFlags.SetProperty, pass, null); + // } + //} + + /// + /// Invokes the static method + /// + /// Name of the member + /// Additional invocation attributes + /// Arguements to the invocation + /// Culture + /// Result of invocation + private object InvokeHelperStatic(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture) + { + Helper.CheckParameterNotNull(name, "name", string.Empty); + try + { + return this.type.InvokeMember(name, bindingFlags | BindToEveryThing | BindingFlags.Static, null, null, args, culture); + } catch (TargetInvocationException e) + { + //Debug.Assert(e.InnerException != null, "Inner Exception should not be null."); + if (e.InnerException != null) + { + throw e.InnerException; + } + + throw; + } + } + } + + /// + /// The helper. + /// + internal static class Helper + { + /// + /// The check parameter not null. + /// + /// + /// The parameter. + /// + /// + /// The parameter name. + /// + /// + /// The message. + /// + /// Throws argument null exception when parameter is null. + internal static void CheckParameterNotNull(object param, string parameterName, string message) + { + if (param == null) + { + throw new ArgumentNullException(parameterName, message); + } + } + + /// + /// The check parameter not null or empty. + /// + /// + /// The parameter. + /// + /// + /// The parameter name. + /// + /// + /// The message. + /// + /// Throws ArgumentException when parameter is null. + //internal static void CheckParameterNotNullOrEmpty(string param, string parameterName, string message) + //{ + // if (string.IsNullOrEmpty(param)) + // { + // throw new ArgumentException(message, parameterName); + // } + //} + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 4d9d1059..e1a91560 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -29,7 +29,7 @@ namespace TensorFlowNET.UnitTest } /// - /// https://www.tf.org/api_docs/python/tf/variable_scope + /// https://www.tensorflow.org/api_docs/python/tf/variable_scope /// how to create a new variable /// [TestMethod] diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs index 627d7c2f..988afa17 100644 --- a/test/TensorFlowNET.UnitTest/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -12,42 +12,51 @@ namespace TensorFlowNET.UnitTest { public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") { - var desc = c_api.TF_NewOperation(graph, "AddN", name); - - var inputs = new TF_Output[] + lock (Locks.ProcessWide) { - new TF_Output(l, 0), - new TF_Output(r, 0), - }; + var desc = c_api.TF_NewOperation(graph, "AddN", name); - c_api.TF_AddInputList(desc, inputs, inputs.Length); + var inputs = new TF_Output[] + { + new TF_Output(l, 0), + new TF_Output(r, 0), + }; - var op = c_api.TF_FinishOperation(desc, s); - s.Check(); + c_api.TF_AddInputList(desc, inputs, inputs.Length); - return op; + var op = c_api.TF_FinishOperation(desc, s); + s.Check(); + + return op; + } } [SuppressMessage("ReSharper", "RedundantAssignment")] public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) { - using (var buffer = new Buffer()) + lock (Locks.ProcessWide) { - c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); - attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); - } + using (var buffer = new Buffer()) + { + c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); + attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + } - return s.Code == TF_Code.TF_OK; + return s.Code == TF_Code.TF_OK; + } } public static GraphDef GetGraphDef(Graph graph) { - using (var s = new Status()) - using (var buffer = new Buffer()) + lock (Locks.ProcessWide) { - c_api.TF_GraphToGraphDef(graph, buffer, s); - s.Check(); - return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + using (var s = new Status()) + using (var buffer = new Buffer()) + { + c_api.TF_GraphToGraphDef(graph, buffer, s); + s.Check(); + return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + } } } @@ -58,6 +67,7 @@ namespace TensorFlowNET.UnitTest { return false; } + bool found_t = false; bool found_n = false; foreach (var attr in node_def.Attr) @@ -67,19 +77,16 @@ namespace TensorFlowNET.UnitTest if (attr.Value.Type == DataType.DtInt32) { found_t = true; - } - else + } else { return false; } - } - else if (attr.Key == "N") + } else if (attr.Key == "N") { if (attr.Value.I == n) { found_n = true; - } - else + } else { return false; } @@ -92,7 +99,7 @@ namespace TensorFlowNET.UnitTest public static bool IsNeg(NodeDef node_def, string input) { return node_def.Op == "Neg" && node_def.Name == "neg" && - node_def.Input.Count == 1 && node_def.Input[0] == input; + node_def.Input.Count == 1 && node_def.Input[0] == input; } public static bool IsPlaceholder(NodeDef node_def) @@ -111,13 +118,11 @@ namespace TensorFlowNET.UnitTest if (attr.Value.Type == DataType.DtInt32) { found_dtype = true; - } - else + } else { return false; } - } - else if (attr.Key == "shape") + } else if (attr.Key == "shape") { found_shape = true; } @@ -132,72 +137,82 @@ namespace TensorFlowNET.UnitTest { return false; } + bool found_dtype = false; bool found_value = false; - foreach (var attr in node_def.Attr) { + foreach (var attr in node_def.Attr) + { if (attr.Key == "dtype") { if (attr.Value.Type == DataType.DtInt32) { found_dtype = true; - } - else + } else { return false; } - } - else if (attr.Key == "value") + } else if (attr.Key == "value") { if (attr.Value.Tensor != null && attr.Value.Tensor.IntVal.Count == 1 && attr.Value.Tensor.IntVal[0] == v) { found_value = true; - } - else + } else { return false; } } } + return found_dtype && found_value; } public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg") { - OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); - var neg_input = new TF_Output(n, 0); - c_api.TF_AddInput(desc, neg_input); - var op = c_api.TF_FinishOperation(desc, s); - s.Check(); + lock (Locks.ProcessWide) + { + OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); + var neg_input = new TF_Output(n, 0); + c_api.TF_AddInput(desc, neg_input); + var op = c_api.TF_FinishOperation(desc, s); + s.Check(); - return op; + return op; + } } public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) { - var desc = c_api.TF_NewOperation(graph, "Placeholder", name); - c_api.TF_SetAttrType(desc, "dtype", dtype); - if (dims != null) + lock (Locks.ProcessWide) { - c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); - } - var op = c_api.TF_FinishOperation(desc, s); - s.Check(); + var desc = c_api.TF_NewOperation(graph, "Placeholder", name); + c_api.TF_SetAttrType(desc, "dtype", dtype); + if (dims != null) + { + c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); + } + + var op = c_api.TF_FinishOperation(desc, s); + s.Check(); - return op; + return op; + } } public static Operation Const(Tensor t, Graph graph, Status s, string name) { - var desc = c_api.TF_NewOperation(graph, "Const", name); - c_api.TF_SetAttrTensor(desc, "value", t, s); - s.Check(); - c_api.TF_SetAttrType(desc, "dtype", t.dtype); - var op = c_api.TF_FinishOperation(desc, s); - s.Check(); - - return op; + lock (Locks.ProcessWide) + { + var desc = c_api.TF_NewOperation(graph, "Const", name); + c_api.TF_SetAttrTensor(desc, "value", t, s); + s.Check(); + c_api.TF_SetAttrType(desc, "dtype", t.dtype); + var op = c_api.TF_FinishOperation(desc, s); + s.Check(); + + return op; + } } public static Operation ScalarConst(int v, Graph graph, Status s, string name = "scalar") @@ -205,4 +220,4 @@ namespace TensorFlowNET.UnitTest return Const(new Tensor(v), graph, s, name); } } -} +} \ No newline at end of file