From f9ac0ffa1c003de9ae5767787e77400daa54e1f3 Mon Sep 17 00:00:00 2001 From: Brendan Mulcahy Date: Sun, 1 Dec 2019 18:36:40 -0500 Subject: [PATCH] Add tf.scan test cases --- .../functional_ops_test/ScanTestCase.cs | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 test/TensorFlowNET.UnitTest/functional_ops_test/ScanTestCase.cs diff --git a/test/TensorFlowNET.UnitTest/functional_ops_test/ScanTestCase.cs b/test/TensorFlowNET.UnitTest/functional_ops_test/ScanTestCase.cs new file mode 100644 index 00000000..b9614a33 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/functional_ops_test/ScanTestCase.cs @@ -0,0 +1,39 @@ +using System; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.functional_ops_test +{ + /// + /// https://www.tensorflow.org/api_docs/python/tf/scan + /// + [TestClass] + public class ScanTestCase + { + [TestMethod] + public void ScanForward() + { + var fn = new Func((a, x) => tf.add(a, x)); + + var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6)); + var scan = functional_ops.scan(fn, input); + sess.run(scan, (input, np.array(1,2,3,4,5,6))).Should().Be(np.array(1,3,6,10,15,21)); + } + + [TestMethod] + public void ScanReverse() + { + var fn = new Func((a, x) => tf.add(a, x)); + + var sess = tf.Session().as_default(); + + var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6)); + var scan = functional_ops.scan(fn, input, reverse:true); + sess.run(scan, (input, np.array(1,2,3,4,5,6))).Should().Be(np.array(21,20,18,15,11,6)); + } + } +}