You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TensorTest.cs 2.2 kB

7 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp.Core;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Runtime.InteropServices;
  7. using System.Text;
  8. using Tensorflow;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Tests for creating and inspecting tensors through the raw TensorFlow C API (<c>c_api</c>).
      /// </summary>
      [TestClass]
      public class TensorTest
      {
          /// <summary>
          /// Wraps a 2x3 float buffer allocated with <see cref="Marshal.AllocHGlobal(int)"/> in a tensor via
          /// <c>c_api.TF_NewTensor</c>, checks its dtype, rank, dimensions, byte size and contents,
          /// and verifies that the supplied deallocator has run once the tensor has been deleted.
          /// </summary>
          [TestMethod]
          public unsafe void TF_NewTensor()
          {
              var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);
              var data = Marshal.AllocHGlobal(sizeof(float) * nd.size);
              Marshal.Copy(nd.Data<float>(), 0, data, nd.size);

              // Flag written by the deallocator. AllocHGlobal does not zero memory, so it must be
              // cleared explicitly; reading it uninitialized gives an arbitrary value.
              var deallocator_called = Marshal.AllocHGlobal(sizeof(bool));
              *(bool*)deallocator_called = false;

              var handle = IntPtr.Zero;
              try
              {
                  // NOTE(review): the lambda below is marshaled to a native function pointer that
                  // TensorFlow may invoke as late as TF_DeleteTensor. Nothing keeps the delegate
                  // alive until then; consider storing it in a local of the delegate type and
                  // calling GC.KeepAlive on it after the tensor is deleted.
                  handle = c_api.TF_NewTensor(TF_DataType.TF_FLOAT,
                      nd.shape.Select(x => (long)x).ToArray(), // shape
                      nd.ndim,
                      data,
                      (UIntPtr)(nd.size * sizeof(float)),
                      (IntPtr values, IntPtr len, IntPtr closure) =>
                      {
                          // Free the original buffer and set flag
                          Marshal.FreeHGlobal(data);
                          *(bool*)closure = true;
                      },
                      deallocator_called);
                  Assert.AreNotEqual(IntPtr.Zero, handle);

                  var tensor = new Tensor(handle);
                  Assert.AreEqual(TF_DataType.TF_FLOAT, tensor.dtype);
                  Assert.AreEqual(nd.ndim, tensor.ndim);
                  Assert.AreEqual(nd.shape[0], c_api.TF_Dim(handle, 0));
                  Assert.AreEqual(nd.shape[1], c_api.TF_Dim(handle, 1));
                  Assert.AreEqual((uint)nd.size * sizeof(float), tensor.bytesize);

                  // Whether the deallocator has already run at this point is deliberately not
                  // asserted: TF_NewTensor only frees the caller's buffer up front when it copies
                  // a buffer that is not aligned for TensorFlow, and AllocHGlobal's alignment is
                  // not guaranteed.

                  // The tensor's data must match the NDArray's flat (row-major) data element for
                  // element: matrix [[1, 2, 3], [4, 5, 6]] -> 1 2 3 4 5 6.
                  var array = tensor.Data<float>();
                  Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), array));

                  c_api.TF_DeleteTensor(handle);
                  handle = IntPtr.Zero;

                  // Once the tensor is deleted the deallocator must have run, in either case above.
                  Assert.IsTrue(*(bool*)deallocator_called);
              }
              finally
              {
                  if (handle != IntPtr.Zero)
                  {
                      // An assertion failed while the tensor was alive; deleting it runs the
                      // deallocator, which frees `data` and writes the flag (so free the flag last).
                      c_api.TF_DeleteTensor(handle);
                  }
                  else if (!*(bool*)deallocator_called)
                  {
                      // TensorFlow never released the buffer (no tensor was created, or the
                      // deallocator was not invoked), so it is still ours to free.
                      Marshal.FreeHGlobal(data);
                  }

                  Marshal.FreeHGlobal(deallocator_called);
              }
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。

Contributors (1)