| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using NumSharp; | |||||
| using System; | using System; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -45,5 +46,21 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
| var r = tf.while_loop(c, b, i); | var r = tf.while_loop(c, b, i); | ||||
| Assert.AreEqual(10, (int)r); | Assert.AreEqual(10, (int)r); | ||||
| } | } | ||||
| [TestMethod, Ignore] | |||||
| public void ScanFunctionGraphMode() | |||||
| { | |||||
| tf.compat.v1.disable_eager_execution(); | |||||
| Func<Tensor, Tensor, Tensor> fn = (prev, current) => tf.add(prev, current); | |||||
| var input = tf.placeholder(TF_DataType.TF_FLOAT, new TensorShape(6)); | |||||
| var scan = tf.scan(fn, input); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| sess.run(tf.global_variables_initializer()); | |||||
| var result = sess.run(scan, new FeedItem(input, np.array(1, 2, 3, 4, 5, 6))); | |||||
| Assert.AreEqual(new float[] { 1, 3, 6, 10, 15, 21 }, result.ToArray<float>()); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||