diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
index 6dd8e25e..56b38846 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
@@ -119,13 +119,13 @@ namespace Tensorflow.Operations
return null;
}
- ///
- /// Notifies a scope about an operator added to an inner scope.
- ///
+ ///
+ /// Notifies a scope about an operator added to an inner scope.
+ ///
///
public virtual void AddInnerOp(Operation op)
{
- if (_outer_context != null)
+ if (_outer_context != null)
_outer_context.AddInnerOp(op);
}
@@ -164,6 +164,12 @@ namespace Tensorflow.Operations
var internal_control_inputs = op.control_inputs;
}
+ public object to_proto()
+ {
+ throw new NotImplementedException();
+ }
+
+
public void Dispose()
{
}
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
index 5bc34965..7fdd22f5 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
@@ -11,5 +11,6 @@ namespace Tensorflow
HashSet values { get; }
Tensor AddValue(Tensor val);
void AddInnerOp(Operation resultOp);
+ object to_proto();
}
}
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
index a31819dc..d800679b 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
@@ -6,5 +6,14 @@ namespace Tensorflow.Operations
{
public class WhileContext : ControlFlowContext
{
+ public static WhileContext from_proto(object proto)
+ {
+ throw new NotImplementedException();
+ }
+
+ public object to_proto()
+ {
+ throw new NotImplementedException();
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
index dbb7a96e..11950b46 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
@@ -490,8 +490,13 @@ namespace Tensorflow
}
throw new NotImplementedException("ZerosLikeOutsideLoop");
- }
-
-
+ }
+
+ // TODO
+ public static void while_loop(Func func, Func func1, Tensor[] tensors, int? i)
+ {
+ throw new NotImplementedException();
+ }
+
}
}
diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
index beb5e703..bfdfbd24 100644
--- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
+++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
@@ -240,5 +240,11 @@ namespace Tensorflow
if (_current_name_scope != null)
_current_name_scope.Dispose();
}
+
+ // TODO for Switch/Case
+ public static RefVariable get_variable(string embeddingMatrix, double[,] initializer, bool use_resource)
+ {
+ throw new NotImplementedException();
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs
index 4f11a7a8..bd38fc77 100644
--- a/src/TensorFlowNET.Core/Variables/variables.py.cs
+++ b/src/TensorFlowNET.Core/Variables/variables.py.cs
@@ -67,5 +67,10 @@ namespace Tensorflow
else
return gen_control_flow_ops.no_op(name: name);
}
+
+ public static Tensor global_variables_initializer()
+ {
+ throw new NotImplementedException();
+ }
}
}
diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs
index 3761455e..fd24d7d1 100644
--- a/test/TensorFlowNET.UnitTest/PythonTest.cs
+++ b/test/TensorFlowNET.UnitTest/PythonTest.cs
@@ -95,9 +95,15 @@ namespace TensorFlowNET.UnitTest
Assert.IsTrue(cond);
}
+
+ public void assertProtoEquals(object toProto, object o)
+ {
+ throw new NotImplementedException();
+ }
+
#endregion
- #region tensor evaluation
+ #region tensor evaluation and test session
protected object _eval_helper(Tensor[] tensors)
{
@@ -166,6 +172,11 @@ namespace TensorFlowNET.UnitTest
}
+ protected Session cached_session()
+ {
+ throw new NotImplementedException();
+ }
+
//Returns a TensorFlow Session for use in executing tests.
public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
{
diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs
index e86e133b..0e95fdc8 100644
--- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs
+++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs
@@ -1,4 +1,5 @@
-using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
namespace TensorFlowNET.UnitTest.control_flow_ops_test
@@ -14,24 +15,33 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
[TestMethod]
public void testResourceReadInLoop()
{
- //def testResourceReadInLoop(self):
- // embedding_matrix = variable_scope.get_variable(
- // "embedding_matrix", initializer=[[2.0], [3.0]], use_resource=True)
- //
- // def cond(it, _):
- // return it < 5
- //
- // def body(it, cost):
- // embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
- // cost += math_ops.reduce_sum(embedding)
- // return it + 1, cost
- //
- // _, cost = control_flow_ops.while_loop(
- // cond, body, [constant_op.constant(0),
- // constant_op.constant(0.0)])
- // with self.cached_session():
- // self.evaluate(variables.global_variables_initializer())
- // self.assertAllEqual(10.0, self.evaluate(cost))
+
+ var embedding_matrix = variable_scope.get_variable(
+ "embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true);
+
+ Tensor cond(Tensor it, Tensor _)
+ {
+ return it < 5;
+ }
+
+ // TODO: below code doesn't compile
+ //(Tensor, Tensor) body(Tensor it, Tensor cost)
+ //{
+ // var embedding = embedding_ops.embedding_lookup(embedding_matrix, new int[]{0});
+ // cost += math_ops.reduce_sum(embedding);
+ // return (it + 1, cost);
+ //}
+ //var (_, cost1) = control_flow_ops.while_loop(
+ // cond, body, new[]
+ // {
+ // constant_op.constant(0),
+ // constant_op.constant(0.0)
+ // });
+ //with(this.cached_session(), sess =>
+ //{
+ // self.evaluate(variables.global_variables_initializer());
+ // self.assertAllEqual(10.0, self.evaluate(cost1));
+ //});
}
@@ -49,7 +59,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: true);
}
- private void doTestIndexedSlicesGradientInCondInWhileLoop(bool use_resource= false)
+ private void doTestIndexedSlicesGradientInCondInWhileLoop(bool use_resource = false)
{
//def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False):
// embedding_matrix = variable_scope.get_variable(
diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs
new file mode 100644
index 00000000..78927092
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs
@@ -0,0 +1,49 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow;
+using Tensorflow.Operations;
+
+namespace TensorFlowNET.UnitTest.control_flow_ops_test
+{
+ [TestClass]
+ public class WhileContextTestCase : PythonTest
+ {
+ private void _testWhileContextHelper(int? maximum_iterations = null)
+ {
+ // TODO: implement missing code dependencies
+ with(this.cached_session(), sess =>
+ {
+ var i = constant_op.constant(0, name: "i");
+ var c = new Func(x => gen_math_ops.less(x, 10, name: "c"));
+ var b = new Func(x => gen_math_ops.add(x, 1, name: "c"));
+ control_flow_ops.while_loop(
+ c, b, new[] { i }, maximum_iterations = maximum_iterations);
+ foreach (Operation op in sess.graph.get_operations())
+ {
+ var control_flow_context = op._get_control_flow_context();
+ if (control_flow_context != null)
+ self.assertProtoEquals(control_flow_context.to_proto(),
+ WhileContext.from_proto(
+ control_flow_context.to_proto()).to_proto());
+ }
+ });
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testWhileContext()
+ {
+ _testWhileContextHelper();
+ }
+
+ [Ignore("TODO")]
+ [TestMethod]
+ public void testWhileContextWithMaximumIterations()
+ {
+ _testWhileContextHelper(maximum_iterations: 10);
+ }
+ }
+}