diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 6dd8e25e..56b38846 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -119,13 +119,13 @@ namespace Tensorflow.Operations return null; } - /// - /// Notifies a scope about an operator added to an inner scope. - /// + /// + /// Notifies a scope about an operator added to an inner scope. + /// /// public virtual void AddInnerOp(Operation op) { - if (_outer_context != null) + if (_outer_context != null) _outer_context.AddInnerOp(op); } @@ -164,6 +164,12 @@ namespace Tensorflow.Operations var internal_control_inputs = op.control_inputs; } + public object to_proto() + { + throw new NotImplementedException(); + } + + public void Dispose() { } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs index 5bc34965..7fdd22f5 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs @@ -11,5 +11,6 @@ namespace Tensorflow HashSet values { get; } Tensor AddValue(Tensor val); void AddInnerOp(Operation resultOp); + object to_proto(); } } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs index a31819dc..d800679b 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs @@ -6,5 +6,14 @@ namespace Tensorflow.Operations { public class WhileContext : ControlFlowContext { + public static WhileContext from_proto(object proto) + { + throw new NotImplementedException(); + } + + public object to_proto() + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index dbb7a96e..11950b46 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -490,8 +490,13 @@ namespace Tensorflow } throw new NotImplementedException("ZerosLikeOutsideLoop"); - } - - + } + + // TODO + public static void while_loop(Func func, Func func1, Tensor[] tensors, int? i) + { + throw new NotImplementedException(); + } + } } diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index beb5e703..bfdfbd24 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -240,5 +240,11 @@ namespace Tensorflow if (_current_name_scope != null) _current_name_scope.Dispose(); } + + // TODO for Switch/Case + public static RefVariable get_variable(string embeddingMatrix, double[,] initializer, bool use_resource) + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs index 4f11a7a8..bd38fc77 100644 --- a/src/TensorFlowNET.Core/Variables/variables.py.cs +++ b/src/TensorFlowNET.Core/Variables/variables.py.cs @@ -67,5 +67,10 @@ namespace Tensorflow else return gen_control_flow_ops.no_op(name: name); } + + public static Tensor global_variables_initializer() + { + throw new NotImplementedException(); + } } } diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index 3761455e..fd24d7d1 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -95,9 +95,15 @@ namespace TensorFlowNET.UnitTest Assert.IsTrue(cond); } + + public void assertProtoEquals(object toProto, object o) + { + throw new NotImplementedException(); + } + #endregion - #region tensor evaluation + #region tensor evaluation and test session protected object _eval_helper(Tensor[] tensors) { @@ -166,6 +172,11 @@ namespace TensorFlowNET.UnitTest } + protected Session cached_session() + { + throw new NotImplementedException(); + } + //Returns a TensorFlow Session for use in executing tests. public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false) { diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs index e86e133b..0e95fdc8 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs @@ -1,4 +1,5 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Microsoft.VisualStudio.TestTools.UnitTesting; using Tensorflow; namespace TensorFlowNET.UnitTest.control_flow_ops_test @@ -14,24 +15,33 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test [TestMethod] public void testResourceReadInLoop() { - //def testResourceReadInLoop(self): - // embedding_matrix = variable_scope.get_variable( - // "embedding_matrix", initializer=[[2.0], [3.0]], use_resource=True) - // - // def cond(it, _): - // return it < 5 - // - // def body(it, cost): - // embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) - // cost += math_ops.reduce_sum(embedding) - // return it + 1, cost - // - // _, cost = control_flow_ops.while_loop( - // cond, body, [constant_op.constant(0), - // constant_op.constant(0.0)]) - // with self.cached_session(): - // self.evaluate(variables.global_variables_initializer()) - // self.assertAllEqual(10.0, self.evaluate(cost)) + + var embedding_matrix = variable_scope.get_variable( + "embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true); + + Tensor cond(Tensor it, Tensor _) + { + return it < 5; + } + + // TODO: below code doesn't compile + //(Tensor, Tensor) body(Tensor it, Tensor cost) + //{ + // var embedding = embedding_ops.embedding_lookup(embedding_matrix, new int[]{0}); + // cost += math_ops.reduce_sum(embedding); + // return (it + 1, cost); + //} + //var (_, cost1) = control_flow_ops.while_loop( + // cond, body, new[] + // { + // constant_op.constant(0), + // constant_op.constant(0.0) + // }); + //with(this.cached_session(), sess => + //{ + // self.evaluate(variables.global_variables_initializer()); + // self.assertAllEqual(10.0, self.evaluate(cost1)); + //}); } @@ -49,7 +59,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: true); } - private void doTestIndexedSlicesGradientInCondInWhileLoop(bool use_resource= false) + private void doTestIndexedSlicesGradientInCondInWhileLoop(bool use_resource = false) { //def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False): // embedding_matrix = variable_scope.get_variable( diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs new file mode 100644 index 00000000..78927092 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs @@ -0,0 +1,49 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; +using Tensorflow.Operations; + +namespace TensorFlowNET.UnitTest.control_flow_ops_test +{ + [TestClass] + public class WhileContextTestCase : PythonTest + { + private void _testWhileContextHelper(int? maximum_iterations = null) + { + // TODO: implement missing code dependencies + with(this.cached_session(), sess => + { + var i = constant_op.constant(0, name: "i"); + var c = new Func(x => gen_math_ops.less(x, 10, name: "c")); + var b = new Func(x => gen_math_ops.add(x, 1, name: "c")); + control_flow_ops.while_loop( + c, b, new[] { i }, maximum_iterations = maximum_iterations); + foreach (Operation op in sess.graph.get_operations()) + { + var control_flow_context = op._get_control_flow_context(); + if (control_flow_context != null) + self.assertProtoEquals(control_flow_context.to_proto(), + WhileContext.from_proto( + control_flow_context.to_proto()).to_proto()); + } + }); + } + + [Ignore("TODO")] + [TestMethod] + public void testWhileContext() + { + _testWhileContextHelper(); + } + + [Ignore("TODO")] + [TestMethod] + public void testWhileContextWithMaximumIterations() + { + _testWhileContextHelper(maximum_iterations: 10); + } + } +}