diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index bbf240e3..fbf3dd00 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -27,7 +27,7 @@ namespace Tensorflow public static Tensor asin(Tensor x, string name = null) => gen_math_ops.asin(x, name); - public static Tensor add(Tensor a, Tensor b) + public static Tensor add(Tx a, Ty b) => gen_math_ops.add(a, b); /// @@ -251,7 +251,7 @@ namespace Tensorflow public static Tensor minimum(T1 x, T2 y, string name = null) => gen_math_ops.minimum(x, y, name: name); - public static Tensor multiply(Tensor x, Tensor y) + public static Tensor multiply(Tx x, Ty y) => gen_math_ops.mul(x, y); public static Tensor negative(Tensor x, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 9f4280d9..c81ab08c 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -287,7 +287,7 @@ namespace Tensorflow // Reset cached inputs. _inputs = null;// new InputList(new Tensor[] { tensor }); // is this right? original code: self._inputs_val=None // TODO: implement below code dependencies - //c_api.UpdateEdge(_graph._c_graph, output, input); + // c_api.TF_UpdateEdge(graph, output, input, status); } private void _assert_same_graph(Tensor tensor) diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 5e58df45..56477442 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -80,7 +80,7 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor add(Tensor x, Tensor y, string name = null) + public static Tensor add(Tx x, Ty y, string name = null) { var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); @@ -300,7 +300,7 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor mul(Tensor x, Tensor y, string name = null) + public static Tensor mul(Tx x, Ty y, string name = null) { var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs index de391679..e35923e5 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; using Tensorflow; namespace TensorFlowNET.UnitTest.control_flow_ops_test @@ -18,25 +19,54 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test var x = tf.constant(2); var y = tf.constant(5); var z = control_flow_ops.cond(tf.less(x, y), - () => tf.multiply(x, tf.constant(17)), - () => tf.add(y, tf.constant(23))); + () => tf.multiply(x, 17), + () => tf.add(y, 23)); int result = z.eval(sess); assertEquals(result, 34); }); } - [Ignore("Todo")] [TestMethod] public void testCondFalse() { - // def testCondFalse(self): - // x = constant_op.constant(2) - // y = constant_op.constant(1) - // z = control_flow_ops.cond( - // math_ops.less( - // x, - // y), lambda: math_ops.multiply(x, 17), lambda: math_ops.add(y, 23)) - // self.assertEquals(self.evaluate(z), 24) + /* python + * import tensorflow as tf + from tensorflow.python.framework import ops + + def if_true(): + return tf.math.multiply(x, 17) + def if_false(): + return tf.math.add(y, 23) + + with tf.Session() as sess: + x = tf.constant(2) + y = tf.constant(1) + pred = tf.math.less(x,y) + z = tf.cond(pred, if_true, if_false) + result = z.eval() + + print(result == 24) */ + + with(tf.Session(), sess => + { + var x = tf.constant(2); + var y = tf.constant(1); + var pred = tf.less(x, y); + + Func if_true = delegate + { + return tf.multiply(x, 17); + }; + + Func if_false = delegate + { + return tf.add(y, 23); + }; + + var z = control_flow_ops.cond(pred, if_true, if_false); + int result = z.eval(sess); + assertEquals(result, 24); + }); } [Ignore("Todo")]