| @@ -32,5 +32,25 @@ namespace Tensorflow.UnitTest.TF_API | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, a.shape)); | Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, a.shape)); | ||||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, b.shape)); | Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, b.shape)); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void ConcatTest() | |||||
| { | |||||
| var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } }); | |||||
| var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } }); | |||||
| var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } }); | |||||
| var concatValue = tf.concat(new[] { a, b, c }, axis: 0); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); | |||||
| } | |||||
| [TestMethod] | |||||
| public void ConcatDoubleTest() | |||||
| {//double type has some error | |||||
| var a = tf.constant(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } }); | |||||
| var b = tf.constant(new[,] { { 5.0, 6.0 }, { 7.0, 8.0 } }); | |||||
| var c = tf.constant(new[,] { { 9.0, 10.0 }, { 11.0, 12.0 } }); | |||||
| var concatValue = tf.concat(new[] { a, b, c }, axis: 0); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||