| @@ -10,7 +10,7 @@ namespace Tensorflow.Operations | |||
| /// </summary> | |||
| public class CondContext : ControlFlowContext | |||
| { | |||
| private string _name; | |||
| /// <summary> | |||
| /// The boolean tensor for the cond predicate | |||
| @@ -207,6 +207,9 @@ namespace Tensorflow.Operations | |||
| _values.Add(real_val.name); | |||
| _external_values[real_val.name] = real_val; | |||
| } | |||
| var (t0, t1) = control_flow_ops._SwitchRefOrTensor(real_val, _pred); | |||
| real_val = new[] {t0, t1}[_branch]; | |||
| _external_values[val.name] = real_val; | |||
| } | |||
| else | |||
| { | |||
| @@ -37,7 +37,8 @@ namespace Tensorflow.Operations | |||
| _context_stack = new Stack<IControlFlowContext>(); | |||
| } | |||
| public string name { get; set; } | |||
| public string name { get => _name; } | |||
| protected string _name; | |||
| public void __init__() | |||
| { | |||
| @@ -279,12 +279,36 @@ namespace Tensorflow | |||
| /// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | |||
| public void _update_input(int index, Tensor tensor) | |||
| { | |||
| throw new NotImplementedException("_update_input"); | |||
| var input = _tf_input(index); | |||
| var output = tensor._as_tf_output(); | |||
| _assert_same_graph( tensor); | |||
| // Reset cached inputs. | |||
| _inputs=new InputList(new Tensor[]{ tensor }); // is this right? original code: self._inputs_val=None | |||
| // TODO: implement below code dependencies | |||
| //_assert_same_graph( tensor); | |||
| //// Reset cached inputs. | |||
| //_inputs_val = null; | |||
| //c_api.UpdateEdge(_graph._c_graph, tensor._as_tf_output(), _tf_input(index)); | |||
| //c_api.UpdateEdge(_graph._c_graph, output, input); | |||
| } | |||
| private void _assert_same_graph(Tensor tensor) | |||
| { | |||
| //TODO: implement | |||
| } | |||
| /// <summary> | |||
| /// Create and return a new TF_Output for output_idx'th output of this op. | |||
| /// </summary> | |||
| public TF_Output _tf_output(int output_idx) | |||
| { | |||
| var tf_output = new TF_Output(op, output_idx); | |||
| return tf_output; | |||
| } | |||
| /// <summary> | |||
| /// Create and return a new TF_Input for input_idx'th input of this op. | |||
| /// </summary> | |||
| public TF_Input _tf_input(int input_idx) | |||
| { | |||
| var tf_input = new TF_Input(op, input_idx); | |||
| return tf_input; | |||
| } | |||
| } | |||
| } | |||
| @@ -3,6 +3,7 @@ using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.ComponentModel; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Text; | |||
| @@ -82,7 +83,10 @@ namespace Tensorflow | |||
| } | |||
| catch (Exception ex) | |||
| { | |||
| Console.WriteLine(ex.ToString()); | |||
| Console.WriteLine(ex.ToString()); | |||
| #if DEBUG | |||
| Debugger.Break(); | |||
| #endif | |||
| return default(TOut); | |||
| } | |||
| finally | |||
| @@ -255,16 +255,17 @@ namespace Tensorflow | |||
| public override string ToString() | |||
| { | |||
| if(NDims == 0) | |||
| { | |||
| switch (dtype) | |||
| { | |||
| case TF_DataType.TF_INT32: | |||
| return Data<int>()[0].ToString(); | |||
| } | |||
| } | |||
| return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; | |||
| // this can throw IndexOutOfRangeException | |||
| //if(NDims == 0) | |||
| //{ | |||
| // switch (dtype) | |||
| // { | |||
| // case TF_DataType.TF_INT32: | |||
| // return Data<int>()[0].ToString(); | |||
| // } | |||
| //} | |||
| return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | |||
| } | |||
| public void Dispose() | |||
| @@ -64,7 +64,6 @@ namespace TensorFlowNET.UnitTest.ops_test | |||
| }); | |||
| } | |||
| [Ignore("Switch op gets not inserted correctly in the graph")] | |||
| [TestMethod] | |||
| public void TestCond() | |||
| { | |||
| @@ -94,42 +93,12 @@ namespace TensorFlowNET.UnitTest.ops_test | |||
| //self.assertEqual(op.outputs, new object[0]); | |||
| var op_input = op.inputs[0].op; | |||
| self.assertEqual(op_input.type, "Switch"); | |||
| self.assertEqual(op_input.inputs[0], x); | |||
| self.assertEqual(op_input.inputs[0].name, x.name); | |||
| self.assertEqual(op.graph, g); | |||
| self.assertIsNotNone(op._get_control_flow_context()); | |||
| self.assertEqual((op._get_control_flow_context() as ControlFlowContext).name, "cond/cond_text"); | |||
| var cond_text = op._get_control_flow_context() as ControlFlowContext; | |||
| self.assertEqual(cond_text.name, "cond/cond_text"); | |||
| }); | |||
| /* | |||
| @test_util.run_v1_only("b/120545219") | |||
| def testCond(self): | |||
| g = ops.Graph() | |||
| with g.as_default(): | |||
| x = test_ops.int_output() | |||
| def true_fn(): | |||
| ops._create_c_op(ops.get_default_graph(), | |||
| ops._NodeDef("IntInput", "cond/myop"), [x], []) | |||
| new_ops = g._add_new_tf_operations() | |||
| self.assertEqual(len(new_ops), 1) | |||
| return x | |||
| control_flow_ops.cond(x < 10, true_fn, lambda: x) | |||
| op = g.get_operation_by_name("cond/myop") | |||
| self.assertIsNotNone(op) | |||
| self.assertEqual(op.name, "cond/myop") | |||
| self.assertEqual(op.type, "IntInput") | |||
| self.assertEqual(op.outputs, []) | |||
| op_input = op.inputs[0].op | |||
| self.assertEqual(op_input.type, "Switch") | |||
| self.assertEqual(op_input.inputs[0], x) | |||
| self.assertEqual(op.graph, g) | |||
| # pylint: disable=protected-access | |||
| self.assertIsNotNone(op._get_control_flow_context()) | |||
| self.assertEqual(op._get_control_flow_context().name, | |||
| "cond/cond_text") | |||
| # pylint: enable=protected-access | |||
| */ | |||
| } | |||
| [Ignore("Todo: Port")] | |||