You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

TensorTest.cs 8.2 kB

7 years ago
7 years ago
7 years ago
7 years ago
(line-number gutter: lines 1–222)
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Runtime.InteropServices;
  7. using System.Text;
  8. using System.Threading;
  9. using System.Threading.Tasks;
  10. using Tensorflow;
  11. using static Tensorflow.Python;
  12. namespace TensorFlowNET.UnitTest
  13. {
  /// <summary>
  /// Unit tests for Tensor construction, disposal and graph shape handling.
  /// Several of them are ports of tests from TensorFlow's c_api_test.cc.
  /// </summary>
  [TestClass]
  public class TensorTest : CApiTest
  {
      /// <summary>
      /// Disposes the same 1000 tensors from two threads at once and checks that every
      /// tensor ends up disposed, i.e. concurrent and repeated Dispose calls are tolerated.
      /// </summary>
      [TestMethod]
      public void TensorDeallocationThreadSafety()
      {
          var tensors = new Tensor[1000];
          for (int i = 0; i < tensors.Length; i++)
          {
              tensors[i] = new Tensor(new int[1000]);
          }

          // Both workers block on `start`, so a single Release(2) lets them run their
          // Dispose loops over the same tensors at (approximately) the same time.
          using (var start = new SemaphoreSlim(0, 2))
          {
              void DisposeAllWhenSignalled()
              {
                  start.Wait();
                  foreach (var tensor in tensors)
                  {
                      tensor.Dispose();
                  }
              }

              var worker1 = new Thread(DisposeAllWhenSignalled);
              var worker2 = new Thread(DisposeAllWhenSignalled);
              worker1.Start();
              worker2.Start();
              start.Release(2);

              // Joining guarantees both loops have finished before the semaphore is
              // disposed and before the assertions below run.
              worker1.Join();
              worker2.Join();
          }

          foreach (var tensor in tensors)
          {
              Assert.IsTrue(tensor.IsDisposed);
          }
      }

      /// <summary>
      /// Wraps pinned managed memory in a Tensor through the (IntPtr, shape, dtype, byte size)
      /// constructor. It does this first for a 500-element slice of the array, then for the
      /// whole array. Each tensor must be live, must not report itself as the memory owner,
      /// and must report the expected byte size.
      /// </summary>
      [TestMethod]
      public unsafe void TensorFromFixed()
      {
          var array = new float[1000];
          var span = new Span<float>(array, 100, 500);

          fixed (float* ptr = &MemoryMarshal.GetReference(span))
          {
              using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32, sizeof(float) * span.Length))
              {
                  Assert.IsFalse(t.IsDisposed);
                  Assert.IsFalse(t.IsMemoryOwner);
                  Assert.AreEqual(2000, (int)t.bytesize);
              }
          }

          fixed (float* ptr = &array[0])
          {
              using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, sizeof(float) * array.Length))
              {
                  Assert.IsFalse(t.IsDisposed);
                  Assert.IsFalse(t.IsMemoryOwner);
                  Assert.AreEqual(4000, (int)t.bytesize);
              }
          }
      }

      /// <summary>
      /// Intended port of `TEST(CAPI, AllocateTensor)` from c_api_test.cc; not implemented yet.
      /// </summary>
      [TestMethod]
      public void AllocateTensor()
      {
          // The port below is still disabled. The test reports Inconclusive so that it does not
          // show up as a silently passing test.
          /*ulong num_bytes = 6 * sizeof(float);
          long[] dims = { 2, 3 };
          Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes);
          EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype);
          EXPECT_EQ(2, t.NDims);
          Assert.IsTrue(Enumerable.SequenceEqual(dims, t.shape));
          EXPECT_EQ(num_bytes, t.bytesize);
          t.Dispose();*/
          Assert.Inconclusive("AllocateTensor port is not implemented yet.");
      }

      /// <summary>
      /// Port from c_api_test.cc
      /// `TEST(CAPI, MaybeMove)`
      /// </summary>
      [TestMethod]
      public void MaybeMove()
      {
          NDArray nd = np.array(2, 3);
          using (Tensor t = new Tensor(nd))
          {
              Tensor o = t.MaybeMove();
              // It is unsafe to move memory TF might not own.
              ASSERT_TRUE(o == IntPtr.Zero);
          }
      }

      /// <summary>
      /// Port from c_api_test.cc
      /// `TEST(CAPI, Tensor)`
      /// </summary>
      [TestMethod]
      public void Tensor()
      {
          var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);
          using (var tensor = new Tensor(nd))
          {
              var array = tensor.Data<float>();

              EXPECT_EQ(TF_DataType.TF_FLOAT, tensor.dtype);
              EXPECT_EQ(nd.ndim, tensor.rank);
              EXPECT_EQ(nd.shape[0], (int)tensor.shape[0]);
              EXPECT_EQ(nd.shape[1], (int)tensor.shape[1]);
              EXPECT_EQ((ulong)nd.size * sizeof(float), tensor.bytesize);

              // Verify the data read back from the tensor itself, not the source NDArray.
              Assert.IsTrue(Enumerable.SequenceEqual(array, new float[] { 1, 2, 3, 4, 5, 6 }));
          }
      }

      /// <summary>
      /// Port from tensorflow\c\c_api_test.cc
      /// `TEST(CAPI, SetShape)`
      /// </summary>
      [TestMethod]
      public void SetShape()
      {
          using (var s = new Status())
          {
              var graph = new Graph();
              var feed = c_test_util.Placeholder(graph, s);
              ExpectCode(TF_Code.TF_OK, s);
              var feed_out_0 = new TF_Output(feed, 0);

              // Fetch the shape, it should be completely unknown.
              int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
              ExpectCode(TF_Code.TF_OK, s);
              EXPECT_EQ(-1, num_dims);

              // Set the shape to be unknown, expect no change.
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s);
              ExpectCode(TF_Code.TF_OK, s);
              num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
              EXPECT_EQ(-1, num_dims);

              // Set the shape to be 2 x Unknown.
              long[] dims = { 2, -1 };
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
              ExpectCode(TF_Code.TF_OK, s);
              num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
              EXPECT_EQ(2, num_dims);

              // Get the dimension vector appropriately.
              var returned_dims = new long[dims.Length];
              c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
              ExpectCode(TF_Code.TF_OK, s);
              Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

              // Set to a new valid shape: [2, 3].
              dims[1] = 3;
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
              ExpectCode(TF_Code.TF_OK, s);

              // Fetch and see that the new value is returned.
              c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
              ExpectCode(TF_Code.TF_OK, s);
              Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));

              // Try to set 'unknown' with unknown rank on the shape and see that
              // it doesn't change.
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s);
              ExpectCode(TF_Code.TF_OK, s);
              c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
              ExpectCode(TF_Code.TF_OK, s);
              EXPECT_EQ(2, num_dims);
              EXPECT_EQ(2, (int)returned_dims[0]);
              EXPECT_EQ(3, (int)returned_dims[1]);

              // Try to set 'unknown' with same rank on the shape and see that
              // it doesn't change.
              dims[0] = -1;
              dims[1] = -1;
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
              ExpectCode(TF_Code.TF_OK, s);
              c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
              ExpectCode(TF_Code.TF_OK, s);
              EXPECT_EQ(2, num_dims);
              EXPECT_EQ(2, (int)returned_dims[0]);
              EXPECT_EQ(3, (int)returned_dims[1]);

              // Try to fetch a shape with the wrong num_dims.
              c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s);
              ExpectCode(TF_Code.TF_INVALID_ARGUMENT, s);

              // Try to set an invalid shape (cannot change 2x3 to a 2x5).
              dims[1] = 5;
              c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
              ExpectCode(TF_Code.TF_INVALID_ARGUMENT, s);

              // Test for a scalar.
              var three = c_test_util.ScalarConst(3, graph, s);
              ExpectCode(TF_Code.TF_OK, s);
              var three_out_0 = new TF_Output(three, 0);
              num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s);
              ExpectCode(TF_Code.TF_OK, s);
              EXPECT_EQ(0, num_dims);

              // Query the scalar output itself. feed_out_0 has rank 2, so asking it for
              // rank 0 would fail. A rank-0 shape needs no dims buffer, hence null.
              c_api.TF_GraphGetTensorShape(graph, three_out_0, null, num_dims, s);
              ExpectCode(TF_Code.TF_OK, s);

              // NOTE(review): `graph.Dispose()` was disabled in the original port; confirm
              // whether the graph can safely be disposed here.
          }
      }

      /// <summary>
      /// Asserts that <paramref name="status"/> carries the <paramref name="expected"/> code
      /// after the preceding C API call.
      /// </summary>
      private static void ExpectCode(TF_Code expected, Status status)
      {
          Assert.AreEqual(expected, status.Code);
      }
  }
  194. }