You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

NDArray.Creation.cs 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow.Eager;
  5. using static Tensorflow.Binding;
  namespace Tensorflow.NumPy
  {
      /// <summary>
      /// Construction entry points for <see cref="NDArray"/>. Every constructor delegates to an
      /// <c>Init</c> overload that builds the backing <c>_tensor</c> and then calls
      /// <c>SetReferencedByNDArray()</c> on it. (<c>_tensor</c> is declared in another part of
      /// this partial class.)
      /// </summary>
      public partial class NDArray
      {
          public NDArray(bool value) => Init(value);
          public NDArray(byte value) => Init(value);
          public NDArray(short value) => Init(value);
          public NDArray(int value) => Init(value);
          public NDArray(long value) => Init(value);
          public NDArray(float value) => Init(value);
          public NDArray(double value) => Init(value);

          /// <summary>Creates an NDArray from <paramref name="value"/>.</summary>
          /// <param name="value">Source array; must not be null.</param>
          /// <param name="shape">Explicit shape; when null, the array's own shape (<c>GetShape()</c>) is used.</param>
          /// <exception cref="ArgumentNullException"><paramref name="value"/> is null.</exception>
          public NDArray(Array value, Shape? shape = null) => Init(value, shape);

          /// <summary>Creates an NDArray of the given <paramref name="shape"/> and <paramref name="dtype"/> (default <c>TF_DOUBLE</c>).</summary>
          public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => Init(shape, dtype: dtype);

          /// <summary>Creates an NDArray backed by the data of <paramref name="value"/>.</summary>
          /// <param name="value">Source tensor; evaluated in the default session first if it has no data pointer.</param>
          /// <param name="shape">Explicit shape; when null, <c>value.shape</c> is used.</param>
          /// <exception cref="ArgumentNullException"><paramref name="value"/> is null.</exception>
          public NDArray(Tensor value, Shape? shape = null) => Init(value, shape);

          /// <summary>Creation from a raw byte buffer. Not implemented yet.</summary>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          public NDArray(byte[] bytes, TF_DataType dtype) => Init(bytes, dtype);

          /// <summary>
          /// Creates an NDArray from a single value by dispatching to the matching
          /// single-value constructor.
          /// </summary>
          /// <exception cref="NotImplementedException">
          /// <typeparamref name="T"/> is not bool, byte, short, int, long, float or double.
          /// </exception>
          public static NDArray Scalar<T>(T value) where T : unmanaged
              => value switch
              {
                  bool val => new NDArray(val),
                  byte val => new NDArray(val),
                  short val => new NDArray(val),
                  int val => new NDArray(val),
                  long val => new NDArray(val),
                  float val => new NDArray(val),
                  double val => new NDArray(val),
                  _ => throw UnsupportedScalarType<T>()
              };

          /// <summary>Builds the exception thrown for element types the scalar paths cannot handle.</summary>
          private static NotImplementedException UnsupportedScalarType<T>()
              => new NotImplementedException($"NDArray does not support scalar values of type {typeof(T).FullName}.");

          /// <summary>Initializes <c>_tensor</c> from a single value of a supported primitive type.</summary>
          /// <exception cref="NotImplementedException"><typeparamref name="T"/> is not a supported type.</exception>
          void Init<T>(T value) where T : unmanaged
          {
              _tensor = value switch
              {
                  bool val => new Tensor(val),
                  byte val => new Tensor(val),
                  // NOTE(review): if Tensor has no short overload, C# widens to Tensor(int);
                  // confirm the resulting dtype is the intended one.
                  short val => new Tensor(val),
                  int val => new Tensor(val),
                  long val => new Tensor(val),
                  float val => new Tensor(val),
                  double val => new Tensor(val),
                  _ => throw UnsupportedScalarType<T>()
              };
              _tensor.SetReferencedByNDArray();
          }

          /// <summary>Initializes <c>_tensor</c> from an array, using its own shape unless one is given.</summary>
          void Init(Array value, Shape? shape = null)
          {
              if (value is null)
              {
                  throw new ArgumentNullException(nameof(value));
              }

              _tensor = new Tensor(value, shape ?? value.GetShape());
              _tensor.SetReferencedByNDArray();
          }

          /// <summary>Initializes <c>_tensor</c> as a new tensor of the given shape and dtype.</summary>
          void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
          {
              _tensor = new Tensor(shape, dtype: dtype);
              _tensor.SetReferencedByNDArray();
          }

          /// <summary>Initializes <c>_tensor</c> over the data of an existing tensor.</summary>
          void Init(Tensor value, Shape? shape = null)
          {
              if (value is null)
              {
                  throw new ArgumentNullException(nameof(value));
              }

              // A zero data pointer means the tensor was created in graph mode; evaluate it in
              // the default session so there is a materialized result to wrap.
              if (value.TensorDataPointer == IntPtr.Zero)
              {
                  value = tf.defaultSession.eval(value);
              }

              // NOTE(review): value's raw data pointer is handed to the Tensor constructor.
              // Whether that constructor copies the buffer is not visible here; confirm the
              // source buffer outlives this NDArray.
              _tensor = new Tensor(value.TensorDataPointer, shape ?? value.shape, value.dtype);
              _tensor.SetReferencedByNDArray();
          }

          /// <summary>Raw byte-buffer initialization; not implemented.</summary>
          void Init(byte[] bytes, TF_DataType dtype)
          {
              // This overload receives no shape, so no tensor is built here. Fail fast instead of
              // returning an NDArray whose _tensor was never assigned.
              throw new NotImplementedException(
                  $"Creating an NDArray from a raw byte buffer (dtype {dtype}) is not implemented.");
          }
      }
  }