You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TensorTest.cs 8.1 kB

7 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using System.Threading;
  7. using Tensorflow;
  8. using static Tensorflow.Python;
  9. using static Tensorflow.Binding;
  10. namespace TensorFlowNET.UnitTest
  11. {
  /// <summary>
  /// Unit tests covering <c>Tensor</c> construction and disposal, plus the
  /// graph shape-inference entry points of the TensorFlow C API bindings.
  /// </summary>
  [TestClass]
  public class TensorTest : CApiTest
  {
      /// <summary>
      /// Disposes the same set of tensors from two threads at the same time and
      /// verifies that every tensor reports itself as disposed afterwards.
      /// </summary>
      /// <remarks>
      /// The method is marked as a test so that the <c>[Ignore]</c> reason is
      /// actually reported by the test runner; without <c>[TestMethod]</c> the
      /// method is never discovered and the ignore attribute has no effect.
      /// </remarks>
      [TestMethod]
      [Ignore("Not for mult-thread")]
      public void TensorDeallocationThreadSafety()
      {
          var tensors = new Tensor[1000];
          for (int i = 0; i < tensors.Length; i++)
          {
              tensors[i] = new Tensor(new int[1000]);
          }

          using (var start = new SemaphoreSlim(0, 2))
          using (var done = new SemaphoreSlim(0, 2))
          {
              // Both workers block on `start` so that they begin disposing together.
              ThreadStart disposeAll = () =>
              {
                  start.Wait();
                  foreach (var tensor in tensors)
                  {
                      tensor.Dispose();
                  }
                  done.Release();
              };

              var t1 = new Thread(disposeAll);
              var t2 = new Thread(disposeAll);
              t1.Start();
              t2.Start();

              start.Release(2);
              done.Wait();
              done.Wait();

              // Join before the semaphores are disposed so no worker is still inside Release().
              t1.Join();
              t2.Join();
          }

          foreach (var tensor in tensors)
          {
              Assert.IsTrue(tensor.IsDisposed);
          }
      }

      /// <summary>
      /// Builds tensors on top of pinned managed memory (a slice of an array and
      /// the whole array) and checks that the tensor does not claim ownership of
      /// that memory and reports the expected byte size.
      /// </summary>
      [TestMethod]
      public unsafe void TensorFromFixed()
      {
          var array = new float[1000];
          var span = new Span<float>(array, 100, 500);

          // 500 floats taken from the middle of the array.
          fixed (float* ptr = &MemoryMarshal.GetReference(span))
          {
              using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32, sizeof(float) * span.Length))
              {
                  Assert.IsFalse(t.IsDisposed);
                  Assert.IsFalse(t.IsMemoryOwner);
                  Assert.AreEqual(2000, (int)t.bytesize);
              }
          }

          // The complete 1000-float array.
          fixed (float* ptr = &array[0])
          {
              using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, sizeof(float) * array.Length))
              {
                  Assert.IsFalse(t.IsDisposed);
                  Assert.IsFalse(t.IsMemoryOwner);
                  Assert.AreEqual(4000, (int)t.bytesize);
              }
          }
      }

      /// <summary>
      /// Allocates a 2 x 3 float tensor through <c>TF_AllocateTensor</c> and checks
      /// its data type, rank, dimensions and byte size.
      /// </summary>
      [TestMethod]
      public void AllocateTensor()
      {
          ulong num_bytes = 6 * sizeof(float);
          long[] dims = { 2, 3 };

          // `using` guarantees the native tensor is released even when an expectation fails.
          using (Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes))
          {
              EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype);
              EXPECT_EQ(2, t.NDims);
              EXPECT_EQ((int)dims[0], t.shape[0]);
              EXPECT_EQ((int)dims[1], t.shape[1]);
              EXPECT_EQ(num_bytes, t.bytesize);
          }
      }

      /// <summary>
      /// Port from c_api_test.cc
      /// `TEST(CAPI, MaybeMove)`
      /// </summary>
      [TestMethod]
      public void MaybeMove()
      {
          NDArray nd = np.array(2, 3);

          // `using` guarantees the tensor is released even when the assertion fails.
          using (var t = new Tensor(nd))
          {
              Tensor o = t.MaybeMove();
              ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own.
          }
      }

      /// <summary>
      /// Port from c_api_test.cc
      /// `TEST(CAPI, Tensor)`
      /// </summary>
      [TestMethod]
      public void Tensor()
      {
          var expected = new float[] { 1, 2, 3, 4, 5, 6 };
          var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);

          using (var tensor = new Tensor(nd))
          {
              var array = tensor.Data<float>();

              EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT);
              EXPECT_EQ(tensor.rank, nd.ndim);
              EXPECT_EQ((int)tensor.shape[0], nd.shape[0]);
              EXPECT_EQ((int)tensor.shape[1], nd.shape[1]);
              EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float));

              // The source array must be unchanged ...
              Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), expected));
              // ... and the values read back from the tensor itself must match it.
              Assert.IsTrue(Enumerable.SequenceEqual(array, expected));
          }
      }

      /// <summary>
      /// Port from tensorflow\c\c_api_test.cc
      /// `TEST(CAPI, SetShape)`
      /// </summary>
      [TestMethod]
      public void SetShape()
      {
          var s = new Status();
          try
          {
              var graph = new Graph();
              var feed = c_test_util.Placeholder(graph, s);
              Assert.IsTrue(s.Code == TF_Code.TF_OK);
              var feed_out_0 = new TF_Output(feed, 0);

              // Fetch the shape, it should be completely unknown.
              int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
              Assert.IsTrue(s.Code == TF_Code.TF_OK);
              EXPECT_EQ(-1, num_dims);

              // Set the shape to be unknown, expect no change.
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s);
              Assert.IsTrue(s.Code == TF_Code.TF_OK);
              num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
              EXPECT_EQ(-1, num_dims);

              // Set the shape to be 2 x Unknown
              long[] dims = { 2, -1 };
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
              Assert.IsTrue(s.Code == TF_Code.TF_OK);
              num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
              EXPECT_EQ(2, num_dims);

              // Get the dimension vector appropriately.
              var returned_dims = new long[dims.Length];
              c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
              Assert.IsTrue(s.Code == TF_Code.TF_OK);
              Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

              // Set to a new valid shape: [2, 3]
              dims[1] = 3;
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
              Assert.IsTrue(s.Code == TF_Code.TF_OK);

              // Fetch and see that the new value is returned.
              c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
              Assert.IsTrue(s.Code == TF_Code.TF_OK);
              Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

              // Try to set 'unknown' with unknown rank on the shape and see that
              // it doesn't change.
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s);
              Assert.IsTrue(s.Code == TF_Code.TF_OK);
              c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
              Assert.IsTrue(s.Code == TF_Code.TF_OK);
              EXPECT_EQ(2, num_dims);
              EXPECT_EQ(2, (int)returned_dims[0]);
              EXPECT_EQ(3, (int)returned_dims[1]);

              // Try to set 'unknown' with same rank on the shape and see that
              // it doesn't change.
              dims[0] = -1;
              dims[1] = -1;
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
              Assert.IsTrue(s.Code == TF_Code.TF_OK);
              c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
              Assert.IsTrue(s.Code == TF_Code.TF_OK);
              EXPECT_EQ(2, num_dims);
              EXPECT_EQ(2, (int)returned_dims[0]);
              EXPECT_EQ(3, (int)returned_dims[1]);

              // Try to fetch a shape with the wrong num_dims
              c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s);
              Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);

              // Try to set an invalid shape (cannot change 2x3 to a 2x5).
              dims[1] = 5;
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
              Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);

              // Test for a scalar.
              var three = c_test_util.ScalarConst(3, graph, s);
              Assert.IsTrue(s.Code == TF_Code.TF_OK);
              var three_out_0 = new TF_Output(three, 0);

              num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s);
              Assert.IsTrue(s.Code == TF_Code.TF_OK);
              EXPECT_EQ(0, num_dims);

              // Query the scalar's own output (rank 0); the placeholder output has
              // rank 2 at this point and would be rejected for num_dims == 0.
              c_api.TF_GraphGetTensorShape(graph, three_out_0, null, num_dims, s);
              Assert.IsTrue(s.Code == TF_Code.TF_OK);

              // NOTE(review): the graph is not disposed here; the Dispose call was
              // commented out in the original test - confirm whether it can be re-enabled.
          }
          finally
          {
              // Release the native status even when an assertion above throws.
              s.Dispose();
          }
      }
  }
  192. }