| @@ -10,7 +10,7 @@ namespace Tensorflow.Operations | |||||
| /// </summary> | /// </summary> | ||||
| public class CondContext : ControlFlowContext | public class CondContext : ControlFlowContext | ||||
| { | { | ||||
| private string _name; | |||||
| /// <summary> | /// <summary> | ||||
| /// The boolean tensor for the cond predicate | /// The boolean tensor for the cond predicate | ||||
| @@ -207,6 +207,9 @@ namespace Tensorflow.Operations | |||||
| _values.Add(real_val.name); | _values.Add(real_val.name); | ||||
| _external_values[real_val.name] = real_val; | _external_values[real_val.name] = real_val; | ||||
| } | } | ||||
| var (t0, t1) = control_flow_ops._SwitchRefOrTensor(real_val, _pred); | |||||
| real_val = new[] {t0, t1}[_branch]; | |||||
| _external_values[val.name] = real_val; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -37,7 +37,8 @@ namespace Tensorflow.Operations | |||||
| _context_stack = new Stack<IControlFlowContext>(); | _context_stack = new Stack<IControlFlowContext>(); | ||||
| } | } | ||||
| public string name { get; set; } | |||||
| public string name { get => _name; } | |||||
| protected string _name; | |||||
| public void __init__() | public void __init__() | ||||
| { | { | ||||
| @@ -279,12 +279,36 @@ namespace Tensorflow | |||||
| /// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | /// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | ||||
| public void _update_input(int index, Tensor tensor) | public void _update_input(int index, Tensor tensor) | ||||
| { | { | ||||
| throw new NotImplementedException("_update_input"); | |||||
| var input = _tf_input(index); | |||||
| var output = tensor._as_tf_output(); | |||||
| _assert_same_graph( tensor); | |||||
| // Reset cached inputs. | |||||
| _inputs=new InputList(new Tensor[]{ tensor }); // is this right? original code: self._inputs_val=None | |||||
| // TODO: implement below code dependencies | // TODO: implement below code dependencies | ||||
| //_assert_same_graph( tensor); | |||||
| //// Reset cached inputs. | |||||
| //_inputs_val = null; | |||||
| //c_api.UpdateEdge(_graph._c_graph, tensor._as_tf_output(), _tf_input(index)); | |||||
| //c_api.UpdateEdge(_graph._c_graph, output, input); | |||||
| } | |||||
| private void _assert_same_graph(Tensor tensor) | |||||
| { | |||||
| //TODO: implement | |||||
| } | |||||
| /// <summary> | |||||
| /// Create and return a new TF_Output for output_idx'th output of this op. | |||||
| /// </summary> | |||||
| public TF_Output _tf_output(int output_idx) | |||||
| { | |||||
| var tf_output = new TF_Output(op, output_idx); | |||||
| return tf_output; | |||||
| } | |||||
| /// <summary> | |||||
| /// Create and return a new TF_Input for input_idx'th input of this op. | |||||
| /// </summary> | |||||
| public TF_Input _tf_input(int input_idx) | |||||
| { | |||||
| var tf_input = new TF_Input(op, input_idx); | |||||
| return tf_input; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -3,6 +3,7 @@ using System; | |||||
| using System.Collections; | using System.Collections; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.ComponentModel; | using System.ComponentModel; | ||||
| using System.Diagnostics; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| @@ -82,7 +83,10 @@ namespace Tensorflow | |||||
| } | } | ||||
| catch (Exception ex) | catch (Exception ex) | ||||
| { | { | ||||
| Console.WriteLine(ex.ToString()); | |||||
| Console.WriteLine(ex.ToString()); | |||||
| #if DEBUG | |||||
| Debugger.Break(); | |||||
| #endif | |||||
| return default(TOut); | return default(TOut); | ||||
| } | } | ||||
| finally | finally | ||||
| @@ -255,16 +255,17 @@ namespace Tensorflow | |||||
| public override string ToString() | public override string ToString() | ||||
| { | { | ||||
| if(NDims == 0) | |||||
| { | |||||
| switch (dtype) | |||||
| { | |||||
| case TF_DataType.TF_INT32: | |||||
| return Data<int>()[0].ToString(); | |||||
| } | |||||
| } | |||||
| return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; | |||||
| // this can throw IndexOutOfRangeException | |||||
| //if(NDims == 0) | |||||
| //{ | |||||
| // switch (dtype) | |||||
| // { | |||||
| // case TF_DataType.TF_INT32: | |||||
| // return Data<int>()[0].ToString(); | |||||
| // } | |||||
| //} | |||||
| return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | |||||
| } | } | ||||
| public void Dispose() | public void Dispose() | ||||
| @@ -64,7 +64,6 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
| }); | }); | ||||
| } | } | ||||
| [Ignore("Switch op gets not inserted correctly in the graph")] | |||||
| [TestMethod] | [TestMethod] | ||||
| public void TestCond() | public void TestCond() | ||||
| { | { | ||||
| @@ -94,42 +93,12 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
| //self.assertEqual(op.outputs, new object[0]); | //self.assertEqual(op.outputs, new object[0]); | ||||
| var op_input = op.inputs[0].op; | var op_input = op.inputs[0].op; | ||||
| self.assertEqual(op_input.type, "Switch"); | self.assertEqual(op_input.type, "Switch"); | ||||
| self.assertEqual(op_input.inputs[0], x); | |||||
| self.assertEqual(op_input.inputs[0].name, x.name); | |||||
| self.assertEqual(op.graph, g); | self.assertEqual(op.graph, g); | ||||
| self.assertIsNotNone(op._get_control_flow_context()); | self.assertIsNotNone(op._get_control_flow_context()); | ||||
| self.assertEqual((op._get_control_flow_context() as ControlFlowContext).name, "cond/cond_text"); | |||||
| var cond_text = op._get_control_flow_context() as ControlFlowContext; | |||||
| self.assertEqual(cond_text.name, "cond/cond_text"); | |||||
| }); | }); | ||||
| /* | |||||
| @test_util.run_v1_only("b/120545219") | |||||
| def testCond(self): | |||||
| g = ops.Graph() | |||||
| with g.as_default(): | |||||
| x = test_ops.int_output() | |||||
| def true_fn(): | |||||
| ops._create_c_op(ops.get_default_graph(), | |||||
| ops._NodeDef("IntInput", "cond/myop"), [x], []) | |||||
| new_ops = g._add_new_tf_operations() | |||||
| self.assertEqual(len(new_ops), 1) | |||||
| return x | |||||
| control_flow_ops.cond(x < 10, true_fn, lambda: x) | |||||
| op = g.get_operation_by_name("cond/myop") | |||||
| self.assertIsNotNone(op) | |||||
| self.assertEqual(op.name, "cond/myop") | |||||
| self.assertEqual(op.type, "IntInput") | |||||
| self.assertEqual(op.outputs, []) | |||||
| op_input = op.inputs[0].op | |||||
| self.assertEqual(op_input.type, "Switch") | |||||
| self.assertEqual(op_input.inputs[0], x) | |||||
| self.assertEqual(op.graph, g) | |||||
| # pylint: disable=protected-access | |||||
| self.assertIsNotNone(op._get_control_flow_context()) | |||||
| self.assertEqual(op._get_control_flow_context().name, | |||||
| "cond/cond_text") | |||||
| # pylint: enable=protected-access | |||||
| */ | |||||
| } | } | ||||
| [Ignore("Todo: Port")] | [Ignore("Todo: Port")] | ||||