You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

OperationsTest.cs 8.5 kB

7 years ago
7 years ago
7 years ago
6 years ago
6 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using NumSharp;
  7. using Tensorflow;
  8. using Buffer = Tensorflow.Buffer;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Tests for the native op registry and for the <c>Add</c> op, exercised both
      /// through <c>tf.add</c> and through the <c>Tensor</c> <c>+</c> operator overloads.
      /// </summary>
      [TestClass]
      public class OperationsTest
      {
          /// <summary>
          /// Port from tensorflow\c\c_api_test.cc
          /// `TEST(CAPI, GetAllOpList)`
          /// </summary>
          /// <remarks>
          /// Parses the serialized <c>OpList</c> returned by <c>TF_GetAllOpList</c> and
          /// checks that the registry contains more than 1000 ops.
          /// </remarks>
          [TestMethod]
          public void GetAllOpList()
          {
              var handle = c_api.TF_GetAllOpList();
              // The handle comes from the native C API; release it once the protobuf
              // has been parsed out of it.
              using (var buffer = new Buffer(handle))
              {
                  var op_list = OpList.Parser.ParseFrom(buffer);
                  // NOTE(review): r1.14 added the `NearestNeighbors` op; consider asserting its
                  // presence once the bundled native runtime is confirmed to be r1.14 or later.
                  Assert.IsTrue(op_list.Op.Count > 1000);
              }
          }

          /// <summary>
          /// Feeds two scalar float32 placeholders and checks that <c>tf.add</c> returns their sum.
          /// </summary>
          [TestMethod]
          public void addInPlaceholder()
          {
              var a = tf.placeholder(tf.float32);
              var b = tf.placeholder(tf.float32);
              var c = tf.add(a, b);

              using (var sess = tf.Session())
              {
                  var o = sess.run(c,
                      new FeedItem(a, 3.0f),
                      new FeedItem(b, 2.0f));
                  Assert.AreEqual(5.0f, (float)o);
              }
          }

          /// <summary>
          /// Checks that <c>tf.add</c> on two float32 constants evaluates to their sum.
          /// </summary>
          [TestMethod]
          public void addInConstant()
          {
              var a = tf.constant(4.0f);
              var b = tf.constant(5.0f);
              var c = tf.add(a, b);

              using (var sess = tf.Session())
              {
                  var o = sess.run(c);
                  Assert.AreEqual(9.0f, (float)o);
              }
          }

          /// <summary>
          /// Exercises addition for int32, float32 and float64 inputs via <c>tf.add</c>,
          /// <c>Tensor + Tensor</c>, <c>Tensor + scalar</c> and <c>scalar + Tensor</c>.
          /// </summary>
          /// <remarks>
          /// Each input is a <c>rows x cols</c> matrix filled with a single constant. The graph
          /// reduces the sum to a scalar, so the expected value is
          /// <c>sum(first) + sum(second)</c>. The scalar variants add <c>second*Val</c> to every
          /// element of <c>a</c>, which totals <c>rows * cols * second*Val == sum(second)</c>.
          /// The same expected value therefore applies to them.
          /// </remarks>
          [TestMethod]
          public void addOpTests()
          {
              const int rows = 2; // to avoid broadcasting effect
              const int cols = 10;

              #region intTest
              const int firstIntVal = 2;
              const int secondIntVal = 3;
              var firstIntFeed = Enumerable.Repeat(firstIntVal, rows * cols).ToArray();
              var secondIntFeed = Enumerable.Repeat(secondIntVal, rows * cols).ToArray();
              var intResult = firstIntFeed.Sum() + secondIntFeed.Sum();
              var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
              var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
              var c = tf.reduce_sum(tf.reduce_sum(tf.add(a, b), 1));
              using (var sess = tf.Session())
              {
                  var o = sess.run(c,
                      new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
                      new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
                  Assert.AreEqual(intResult, (int)o);
              }

              // Testing `operator +(Tensor x, Tensor y)`
              c = tf.reduce_sum(tf.reduce_sum(a + b, 1));
              using (var sess = tf.Session())
              {
                  var o = sess.run(c,
                      new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
                      new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
                  Assert.AreEqual(intResult, (int)o);
              }

              // Testing `operator +(Tensor x, int y)`
              c = tf.reduce_sum(tf.reduce_sum(a + secondIntVal, 1));
              using (var sess = tf.Session())
              {
                  var o = sess.run(c,
                      new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
                  Assert.AreEqual(intResult, (int)o);
              }

              // Testing `operator +(int x, Tensor y)`
              c = tf.reduce_sum(tf.reduce_sum(secondIntVal + a, 1));
              using (var sess = tf.Session())
              {
                  var o = sess.run(c,
                      new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
                  Assert.AreEqual(intResult, (int)o);
              }
              #endregion

              #region floatTest
              const float firstFloatVal = 2.0f;
              const float secondFloatVal = 3.0f;
              var firstFloatFeed = Enumerable.Repeat(firstFloatVal, rows * cols).ToArray();
              var secondFloatFeed = Enumerable.Repeat(secondFloatVal, rows * cols).ToArray();
              var floatResult = firstFloatFeed.Sum() + secondFloatFeed.Sum();
              a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
              b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
              c = tf.reduce_sum(tf.reduce_sum(tf.add(a, b), 1));
              using (var sess = tf.Session())
              {
                  var o = sess.run(c,
                      new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
                      new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
                  Assert.AreEqual(floatResult, (float)o);
              }

              // Testing `operator +(Tensor x, Tensor y)`
              c = tf.reduce_sum(tf.reduce_sum(a + b, 1));
              using (var sess = tf.Session())
              {
                  var o = sess.run(c,
                      new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
                      new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
                  Assert.AreEqual(floatResult, (float)o);
              }

              // Testing `operator +(Tensor x, float y)`
              c = tf.reduce_sum(tf.reduce_sum(a + secondFloatVal, 1));
              using (var sess = tf.Session())
              {
                  var o = sess.run(c,
                      new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
                  Assert.AreEqual(floatResult, (float)o);
              }

              // Testing `operator +(float x, Tensor y)`
              c = tf.reduce_sum(tf.reduce_sum(secondFloatVal + a, 1));
              using (var sess = tf.Session())
              {
                  var o = sess.run(c,
                      new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
                  Assert.AreEqual(floatResult, (float)o);
              }
              #endregion

              #region doubleTest
              const double firstDoubleVal = 2.0;
              const double secondDoubleVal = 3.0;
              var firstDoubleFeed = Enumerable.Repeat(firstDoubleVal, rows * cols).ToArray();
              var secondDoubleFeed = Enumerable.Repeat(secondDoubleVal, rows * cols).ToArray();
              var doubleResult = firstDoubleFeed.Sum() + secondDoubleFeed.Sum();
              a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
              b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
              c = tf.reduce_sum(tf.reduce_sum(tf.add(a, b), 1));
              using (var sess = tf.Session())
              {
                  var o = sess.run(c,
                      new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
                      new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
                  Assert.AreEqual(doubleResult, (double)o);
              }

              // Testing `operator +(Tensor x, Tensor y)`
              c = tf.reduce_sum(tf.reduce_sum(a + b, 1));
              using (var sess = tf.Session())
              {
                  var o = sess.run(c,
                      new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
                      new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
                  Assert.AreEqual(doubleResult, (double)o);
              }

              // Testing `operator +(Tensor x, double y)`.
              // Uses the double constant so the double overload, not the float one, is exercised.
              c = tf.reduce_sum(tf.reduce_sum(a + secondDoubleVal, 1));
              using (var sess = tf.Session())
              {
                  var o = sess.run(c,
                      new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
                  Assert.AreEqual(doubleResult, (double)o);
              }

              // Testing `operator +(double x, Tensor y)`
              c = tf.reduce_sum(tf.reduce_sum(secondDoubleVal + a, 1));
              using (var sess = tf.Session())
              {
                  var o = sess.run(c,
                      new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
                  Assert.AreEqual(doubleResult, (double)o);
              }
              #endregion
          }
      }
  }