diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
index d6fda591..9c2881e1 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
@@ -30,15 +30,15 @@ namespace Tensorflow
/// A list of control inputs for the op to be created.
private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops)
{
- var ret = new ITensorOrOperation[0];
+ var ret = new List();
- foreach(var controller in _control_dependencies_stack)
+ foreach (var controller in _control_dependencies_stack)
{
bool dominated = false;
// If any of the input_ops already depends on the inputs from controller,
// we say that the new op is dominated (by that input), and we therefore
// do not need to add control dependencies for this controller's inputs.
- foreach(var op in input_ops)
+ foreach (var op in input_ops)
{
if (controller.op_in_group(op))
{
@@ -48,12 +48,22 @@ namespace Tensorflow
}
if (!dominated)
- ret = controller.control_inputs.Where(x => !input_ops.Contains(x)).ToArray();
+ ret.AddRange(controller.control_inputs.Where(x => !input_ops.Contains(x)));
}
- return ret;
+ return ret.ToArray();
}
+ ///
+ /// Returns a context manager that specifies control dependencies.
+ ///
+ /// Use with the `with` keyword to specify that all operations constructed
+ /// within the context should have control dependencies on
+ /// `control_inputs`.
+ ///
+ public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
+ => control_dependencies(control_inputs == null ? null : control_inputs.OfType().ToArray());
+
///
/// Returns a context manager that specifies control dependencies.
///
@@ -61,7 +71,7 @@ namespace Tensorflow
/// within the context should have control dependencies on
/// `control_inputs`.
///
- public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
+ public _ControlDependenciesController control_dependencies(object[] control_inputs)
{
if (control_inputs == null)
return new _ControlDependenciesController(this, null);
@@ -69,9 +79,26 @@ namespace Tensorflow
var control_ops = new List();
foreach (var c in control_inputs)
{
- control_ops.Add(c);
+ switch (c)
+ {
+ // TODO: implement IndexedSlices
+ //case IndexedSlices islice:
+ // control_ops.Add(islice.op);
+ // break;
+ case Tensor t:
+ control_ops.Add(t.op);
+ break;
+ case Operation op:
+ control_ops.Add(op);
+ break;
+ default:
+ var t1 = _as_graph_element(c);
+ if (t1 == null)
+ throw new TypeError($"Control input must be Operation or Tensor:{c}");
+ control_ops.Add(t1.op);
+ break;
+ }
}
-
return new _ControlDependenciesController(this, control_ops);
}
@@ -103,6 +130,9 @@ namespace Tensorflow
_control_dependencies_stack.Dequeue();
}
+ ///
+ /// Record that the given op depends on all registered control dependencies.
+ ///
public void _record_op_seen_by_control_dependencies(Operation op)
{
foreach (var controller in _control_dependencies_stack)
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
index 883f2b64..2d099292 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
@@ -21,8 +21,14 @@ namespace Tensorflow
public OperationDescription NewOperation(string opType, string opName)
{
return c_api.TF_NewOperation(_handle, opType, opName);
- }
-
+ }
+
+ ///
+ /// Returns the `Operation` with the given `name`.
+ ///
+ /// This method may be called concurrently from multiple threads.
+ ///
+ /// The name of the `Operation` to return.
public Operation get_operation_by_name(string name)
=> as_graph_element(name, allow_tensor: false, allow_operation: true) as Operation;
diff --git a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
index 3887d2a1..8def1417 100644
--- a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
+++ b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
@@ -17,8 +17,29 @@ namespace Tensorflow
private bool _new_stack;
private IControlFlowContext _old_control_flow_context;
- public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray();
-
+ public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray();
+
+ ///
+ /// Create a new `_ControlDependenciesController`.
+ ///
+ /// A `_ControlDependenciesController` is the context manager for
+ /// `with tf.control_dependencies()` blocks.These normally nest,
+ /// as described in the documentation for `control_dependencies()`.
+ ///
+ /// The `control_inputs` argument list control dependencies that must be
+ /// added to the current set of control dependencies.Because of
+ /// uniquification the set can be empty even if the caller passed a list of
+ /// ops.The special value `None` indicates that we want to start a new
+ /// empty set of control dependencies instead of extending the current set.
+ ///
+ /// In that case we also clear the current control flow context, which is an
+ /// additional mechanism to add control dependencies.
+ ///
+ /// The graph that this controller is managing.
+ /// List of ops to use as control inputs in addition
+ /// to the current control dependencies.None to indicate that
+ /// the dependencies should be cleared.
+ ///
public _ControlDependenciesController(Graph graph, List control_inputs)
{
_graph = graph;
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs
index 5db34ce9..9ef89271 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs
@@ -1,68 +1,79 @@
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Runtime.InteropServices;
-using System.Text;
-
-namespace Tensorflow
-{
- public partial class Operation
- {
- public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index));
- public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index));
- public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status);
- public int NumInputs => c_api.TF_OperationNumInputs(_handle);
- private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray();
-
- private InputList _inputs;
- public InputList inputs
- {
- get
- {
- if (_inputs == null)
- {
- var retval = new Tensor[NumInputs];
-
- for (int i = 0; i < NumInputs; i++)
- {
- var tf_outpus = Input(i);
- var op = new Operation(tf_outpus.oper);
- retval[i] = op.outputs[tf_outpus.index];
- }
-
- _inputs = new InputList(retval);
- }
-
- return _inputs;
- }
- }
-
- public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);
-
- public Operation[] control_inputs
- {
- get
- {
- return GetControlInputs();
- }
- }
-
- public unsafe Operation[] GetControlInputs()
- {
- var control_inputs = new Operation[NumControlInputs];
-
- if (NumControlInputs > 0)
- {
- IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlInputs);
- c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs);
- for (int i = 0; i < NumControlInputs; i++)
- {
- var handle = control_input_handle + Marshal.SizeOf() * i;
- control_inputs[i] = new Operation(*(IntPtr*)handle);
- }
- }
-
- return control_inputs;
- }
- }
-}
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Text;
+
+namespace Tensorflow
+{
+
+ // from ops.py
+ public partial class Operation
+ {
+ public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index));
+ public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index));
+ public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status);
+ public int NumInputs => c_api.TF_OperationNumInputs(_handle);
+ private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray();
+
+ private InputList _inputs;
+ public InputList inputs
+ {
+ get
+ {
+ if (_inputs == null)
+ {
+ var retval = new Tensor[NumInputs];
+
+ for (int i = 0; i < NumInputs; i++)
+ {
+ var tf_outpus = Input(i);
+ var op = new Operation(tf_outpus.oper);
+ retval[i] = op.outputs[tf_outpus.index];
+ }
+
+ _inputs = new InputList(retval);
+ }
+
+ return _inputs;
+ }
+ }
+
+ public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);
+
+ ///
+ /// The `Operation` objects on which this op has a control dependency.
+ ///
+ /// Before this op is executed, TensorFlow will ensure that the
+ /// operations in `self.control_inputs` have finished executing.This
+ /// mechanism can be used to run ops sequentially for performance
+ /// reasons, or to ensure that the side effects of an op are observed
+ /// in the correct order.
+ ///
+ public Operation[] control_inputs
+ {
+ get
+ {
+ return GetControlInputs();
+ }
+ }
+
+ public unsafe Operation[] GetControlInputs()
+ {
+ var control_inputs = new Operation[NumControlInputs];
+
+ if (NumControlInputs > 0)
+ {
+ IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlInputs);
+ c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs);
+ for (int i = 0; i < NumControlInputs; i++)
+ {
+ var handle = control_input_handle + Marshal.SizeOf() * i;
+ control_inputs[i] = new Operation(*(IntPtr*)handle);
+ }
+ }
+
+ return control_inputs;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index 65b2ab6a..f4d8727e 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -45,7 +45,6 @@ Bug memory leak issue when allocating Tensor.
-
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index ff41e261..add752ea 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -119,11 +119,14 @@ namespace Tensorflow
/// A context manager that specifies control dependencies for all
/// operations constructed within the context.
///
- public static _ControlDependenciesController control_dependencies(Operation[] control_inputs)
+ public static _ControlDependenciesController control_dependencies(object[] control_inputs)
{
return get_default_graph().control_dependencies(control_inputs);
}
+ public static _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
+ => control_dependencies(control_inputs == null ? null : control_inputs.OfType().ToArray());
+
///
/// Creates a TF_Operation.
///
diff --git a/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs b/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs
index 5146ae57..3be4e80e 100644
--- a/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs
+++ b/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs
@@ -23,7 +23,7 @@ namespace TensorFlowNET.UnitTest
{
a = constant_op.constant(1.0);
b = constant_op.constant(1.0);
- with(g.control_dependencies(new ITensorOrOperation[] { a }), x =>
+ with(g.control_dependencies(new[] { a }), x =>
{
c = constant_op.constant(1.0);
d = array_ops.identity(b);
@@ -36,15 +36,15 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(0, e.op.control_inputs.Length);
}
- [Ignore("Part of this test is not compiling")]
+ [Ignore("Future is not supported yet")]
[TestMethod]
public void TestEager()
{
- Tensor a = null, b = null, c = null, d = null, e = null;
+ Tensor a = null, c = null, d = null, e = null;
+ object b = null;
var calls = 0;
Func future = () =>
{
-
calls += 1;
return constant_op.constant(2.0);
};
@@ -55,26 +55,26 @@ namespace TensorFlowNET.UnitTest
if (context.executing_eagerly())
{
// TODO: make this compile (see original Python code below)
- //a = constant_op.constant(1.0);
- //b = future; // <--- {henon} obviously, this doesn't compile, looks like control_dependencies needs to be able to take callables as well.
- //with(ops.control_dependencies(new Operation[] {a, b}), ctrl =>
- //{
- // return c = constant_op.constant(3.0);
- //});
- //Assert.AreEqual(calls, 1);
+ a = constant_op.constant(1.0);
+ b = future; // <--- {henon} obviously, this doesn't compile, looks like control_dependencies needs to be able to take callables as well.
+ with(ops.control_dependencies(new object[] { a, b }), ctrl =>
+ {
+ return c = constant_op.constant(3.0);
+ });
+ Assert.AreEqual(calls, 1);
}
else
{
- var graph = tf.Graph();
- with(graph.as_default(), g =>
+ var graph = tf.Graph().as_default();
+ with(graph, g =>
{
a = constant_op.constant(1.0);
- b = future();
- with(g.control_dependencies(new ITensorOrOperation[] { a, b }), ctrl =>
- {
- c = constant_op.constant(3.0);
- });
- Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op, b.op }));
+ var b1 = future();
+ with(g.control_dependencies(new [] { a, b}), ctrl =>
+ {
+ c = constant_op.constant(3.0);
+ });
+ Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op, b1.op }));
Assert.AreEqual(1, calls);
});
@@ -106,100 +106,107 @@ namespace TensorFlowNET.UnitTest
}
- // Note: {henon}, all tests below use the function _apply_op which is not really portable in C#, see original source below
- // but I think _apply_op(...) can just be replaced by g.create_op(...).
- /*
-def _apply_op(g, *args, **kwargs):
- op = g.create_op(*args, **kwargs)
- if len(op.outputs) == 1:
- return op.outputs[0]
- else:
- return op.outputs
- */
-
-
- [Ignore("")]
+ [Ignore("How to port the ConvertibleObj?")]
[TestMethod]
public void TestBasicWithConversion()
{
- var g = ops.get_default_graph();
+ var g = tf.Graph().as_default();
// Note: _apply_op can be replaced by g.create_op
var a = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
// TODO: ConvertibleObj, see original source below
/*
- def testBasicWithConversion(self):
- g = ops.Graph()
- a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ def testBasicWithConversion(self):
+ g = ops.Graph()
+ a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- class ConvertibleObj(object):
+ class ConvertibleObj(object):
- def _as_graph_element(self):
- return a
+ def _as_graph_element(self):
+ return a
- with g.control_dependencies([ConvertibleObj()]):
- c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ with g.control_dependencies([ConvertibleObj()]):
+ c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- self.assertEqual(c.op.control_inputs, [a.op])
+ self.assertEqual(c.op.control_inputs, [a.op])
*/
}
- [Ignore("Fails with message: Op type not registered 'FloatOutput' in binary running on ...")]
[TestMethod]
public void TestNested()
{
- var g = ops.get_default_graph();
- var a_1 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
- var a_2 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
- var a_3 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
- var a_4 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
+ var g = tf.Graph().as_default();
+ var a_1 = constant_op.constant(1.0);
+ var a_2 = constant_op.constant(3.0);
+ var a_3 = constant_op.constant(4.0);
+ var a_4 = constant_op.constant(5.0);
Operation b_1 = null, b_2 = null;
- with(g.control_dependencies(new ITensorOrOperation[] { a_1, a_2, a_3, a_4 }), ctrl =>
- {
- b_1 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
- });
- with(g.control_dependencies(new ITensorOrOperation[] { a_1 }), ctrl1 =>
- {
- with(g.control_dependencies(new ITensorOrOperation[] { a_2 }), ctrl2 =>
- {
- with(g.control_dependencies(new ITensorOrOperation[] { a_3 }), ctrl3 =>
- {
- with(g.control_dependencies(new ITensorOrOperation[] { a_4 }), ctrl4 =>
- {
- b_2 = g.create_op("FloatOutput", new Tensor[] { }, new[] { TF_DataType.TF_FLOAT });
- });
- });
- });
- });
- AssertItemsEqual(new[] {a_1.op, a_2.op, a_3.op, a_4.op}, b_1.op.control_inputs);
+ with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl =>
+ {
+ b_1 = constant_op.constant(6.0);
+ });
+ with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
+ {
+ with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
+ {
+ with(g.control_dependencies(new[] { a_3 }), ctrl3 =>
+ {
+ with(g.control_dependencies(new[] { a_4 }), ctrl4 =>
+ {
+ b_2 = constant_op.constant(7.0);
+ });
+ });
+ });
+ });
+ AssertItemsEqual(new[] { a_1.op, a_2.op, a_3.op, a_4.op }, b_1.op.control_inputs);
AssertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs);
- /*
-def testNested(self):
-g = ops.Graph()
-a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
-a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
-a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
-a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
-
-with g.control_dependencies([a_1, a_2, a_3, a_4]):
- b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
-
-with g.control_dependencies([a_1]):
- with g.control_dependencies([a_2]):
- with g.control_dependencies([a_3]):
- with g.control_dependencies([a_4]):
- b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
-
-self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op],
- b_1.op.control_inputs)
-self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)
- */
}
-
- [Ignore("will fail due to unsupported op 'FloatOutput'")]
+ [Ignore("Fails")]
[TestMethod]
public void TestClear()
{
+ var g = tf.Graph().as_default();
+ var a_1 = constant_op.constant(1.0);
+ var a_2 = constant_op.constant(3.0);
+ var a_3 = constant_op.constant(4.0);
+ var a_4 = constant_op.constant(5.0);
+ Operation b_3_4 = null, b_3 = null, b_none = null, b_1 = null, b_1_2 = null, b_none2 = null;
+ with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
+ {
+ with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
+ {
+ with(g.control_dependencies(null), ctrl3 =>
+ {
+ with(g.control_dependencies(new[] { a_3 }), ctrl4 =>
+ {
+ with(g.control_dependencies(new[] { a_4 }), ctrl5 =>
+ {
+ // deps [a_3, a_4]
+ b_3_4 = constant_op.constant(7.0);
+ });
+ // deps = [a_3]
+ b_3 = constant_op.constant(8.0);
+ });
+ // deps back to None
+ b_none = constant_op.constant(9.0);
+ });
+ // deps back to [a_1, a_2]
+ b_1_2 = constant_op.constant(10.0);
+ });
+ // deps back to [a_1]
+ b_1 = constant_op.constant(11.0);
+ with(g.control_dependencies(null), ctrl6 =>
+ {
+ // deps are None again
+ b_none2 = constant_op.constant(12.0);
+ });
+ });
+ AssertItemsEqual(new[] {a_3.op, a_4.op}, b_3_4.op.control_inputs);
+ AssertItemsEqual(new[] {a_3.op}, b_3.op.control_inputs);
+ AssertItemsEqual(new object[0], b_none.op.control_inputs);
+ AssertItemsEqual(new[] {a_1.op, a_2.op}, b_1_2.op.control_inputs);
+ AssertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs);
+ AssertItemsEqual(new object[0], b_none2.op.control_inputs);
/*
def testClear(self):
g = ops.Graph()