| @@ -119,13 +119,13 @@ namespace Tensorflow.Operations | |||
| return null; | |||
| } | |||
| /// <summary> | |||
| /// Notifies a scope about an operator added to an inner scope. | |||
| /// </summary> | |||
| /// <summary> | |||
| /// Notifies a scope about an operator added to an inner scope. | |||
| /// </summary> | |||
| /// <param name="op"></param> | |||
| public virtual void AddInnerOp(Operation op) | |||
| { | |||
| if (_outer_context != null) | |||
| if (_outer_context != null) | |||
| _outer_context.AddInnerOp(op); | |||
| } | |||
| @@ -164,6 +164,12 @@ namespace Tensorflow.Operations | |||
| var internal_control_inputs = op.control_inputs; | |||
| } | |||
| public object to_proto() | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public void Dispose() | |||
| { | |||
| } | |||
| @@ -11,5 +11,6 @@ namespace Tensorflow | |||
| HashSet<string> values { get; } | |||
| Tensor AddValue(Tensor val); | |||
| void AddInnerOp(Operation resultOp); | |||
| object to_proto(); | |||
| } | |||
| } | |||
| @@ -6,5 +6,14 @@ namespace Tensorflow.Operations | |||
| { | |||
| public class WhileContext : ControlFlowContext | |||
| { | |||
| public static WhileContext from_proto(object proto) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public object to_proto() | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| } | |||
| @@ -490,8 +490,13 @@ namespace Tensorflow | |||
| } | |||
| throw new NotImplementedException("ZerosLikeOutsideLoop"); | |||
| } | |||
| } | |||
| // TODO | |||
| public static void while_loop(Func<Tensor, Tensor> func, Func<Tensor, Tensor> func1, Tensor[] tensors, int? i) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| } | |||
| @@ -240,5 +240,11 @@ namespace Tensorflow | |||
| if (_current_name_scope != null) | |||
| _current_name_scope.Dispose(); | |||
| } | |||
| // TODO for Switch/Case | |||
| public static RefVariable get_variable(string embeddingMatrix, double[,] initializer, bool use_resource) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| } | |||
| @@ -67,5 +67,10 @@ namespace Tensorflow | |||
| else | |||
| return gen_control_flow_ops.no_op(name: name); | |||
| } | |||
| public static Tensor global_variables_initializer() | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| } | |||
| @@ -95,9 +95,15 @@ namespace TensorFlowNET.UnitTest | |||
| Assert.IsTrue(cond); | |||
| } | |||
| public void assertProtoEquals(object toProto, object o) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| #endregion | |||
| #region tensor evaluation | |||
| #region tensor evaluation and test session | |||
| protected object _eval_helper(Tensor[] tensors) | |||
| { | |||
| @@ -166,6 +172,11 @@ namespace TensorFlowNET.UnitTest | |||
| } | |||
| protected Session cached_session() | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| //Returns a TensorFlow Session for use in executing tests. | |||
| public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false) | |||
| { | |||
| @@ -1,4 +1,5 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| @@ -14,24 +15,33 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| [TestMethod] | |||
| public void testResourceReadInLoop() | |||
| { | |||
| //def testResourceReadInLoop(self): | |||
| // embedding_matrix = variable_scope.get_variable( | |||
| // "embedding_matrix", initializer=[[2.0], [3.0]], use_resource=True) | |||
| // | |||
| // def cond(it, _): | |||
| // return it < 5 | |||
| // | |||
| // def body(it, cost): | |||
| // embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) | |||
| // cost += math_ops.reduce_sum(embedding) | |||
| // return it + 1, cost | |||
| // | |||
| // _, cost = control_flow_ops.while_loop( | |||
| // cond, body, [constant_op.constant(0), | |||
| // constant_op.constant(0.0)]) | |||
| // with self.cached_session(): | |||
| // self.evaluate(variables.global_variables_initializer()) | |||
| // self.assertAllEqual(10.0, self.evaluate(cost)) | |||
| var embedding_matrix = variable_scope.get_variable( | |||
| "embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true); | |||
| Tensor cond(Tensor it, Tensor _) | |||
| { | |||
| return it < 5; | |||
| } | |||
| // TODO: below code doesn't compile | |||
| //(Tensor, Tensor) body(Tensor it, Tensor cost) | |||
| //{ | |||
| // var embedding = embedding_ops.embedding_lookup(embedding_matrix, new int[]{0}); | |||
| // cost += math_ops.reduce_sum(embedding); | |||
| // return (it + 1, cost); | |||
| //} | |||
| //var (_, cost1) = control_flow_ops.while_loop( | |||
| // cond, body, new[] | |||
| // { | |||
| // constant_op.constant(0), | |||
| // constant_op.constant(0.0) | |||
| // }); | |||
| //with<Session>(this.cached_session(), sess => | |||
| //{ | |||
| // self.evaluate(variables.global_variables_initializer()); | |||
| // self.assertAllEqual(10.0, self.evaluate(cost1)); | |||
| //}); | |||
| } | |||
| @@ -49,7 +59,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: true); | |||
| } | |||
| private void doTestIndexedSlicesGradientInCondInWhileLoop(bool use_resource= false) | |||
| private void doTestIndexedSlicesGradientInCondInWhileLoop(bool use_resource = false) | |||
| { | |||
| //def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False): | |||
| // embedding_matrix = variable_scope.get_variable( | |||
| @@ -0,0 +1,49 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using Tensorflow.Operations; | |||
| namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| { | |||
| [TestClass] | |||
| public class WhileContextTestCase : PythonTest | |||
| { | |||
| private void _testWhileContextHelper(int? maximum_iterations = null) | |||
| { | |||
| // TODO: implement missing code dependencies | |||
| with<Session>(this.cached_session(), sess => | |||
| { | |||
| var i = constant_op.constant(0, name: "i"); | |||
| var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | |||
| var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | |||
| control_flow_ops.while_loop( | |||
| c, b, new[] { i }, maximum_iterations = maximum_iterations); | |||
| foreach (Operation op in sess.graph.get_operations()) | |||
| { | |||
| var control_flow_context = op._get_control_flow_context(); | |||
| if (control_flow_context != null) | |||
| self.assertProtoEquals(control_flow_context.to_proto(), | |||
| WhileContext.from_proto( | |||
| control_flow_context.to_proto()).to_proto()); | |||
| } | |||
| }); | |||
| } | |||
| [Ignore("TODO")] | |||
| [TestMethod] | |||
| public void testWhileContext() | |||
| { | |||
| _testWhileContextHelper(); | |||
| } | |||
| [Ignore("TODO")] | |||
| [TestMethod] | |||
| public void testWhileContextWithMaximumIterations() | |||
| { | |||
| _testWhileContextHelper(maximum_iterations: 10); | |||
| } | |||
| } | |||
| } | |||