| @@ -526,8 +526,19 @@ namespace Tensorflow | |||||
| var type = data.GetType(); | var type = data.GetType(); | ||||
| switch (data) | switch (data) | ||||
| { | { | ||||
| case Shape shape: | |||||
| case TensorShape: | |||||
| case Shape: | |||||
| return TF_DataType.TF_INT64; | return TF_DataType.TF_INT64; | ||||
| case Axis: | |||||
| return TF_DataType.TF_INT32; | |||||
| case NDArray nd: | |||||
| return nd.dtype; | |||||
| case Tensor tensor: | |||||
| return tensor.dtype; | |||||
| case Tensor[] tensor: | |||||
| return tensor[0].dtype; | |||||
| case ResourceVariable variable: | |||||
| return variable.dtype; | |||||
| default: | default: | ||||
| return type.as_tf_dtype(); | return type.as_tf_dtype(); | ||||
| } | } | ||||
| @@ -142,7 +142,7 @@ namespace Tensorflow.Contexts | |||||
| bool has_graph_arg = !tf.Context.executing_eagerly(); | bool has_graph_arg = !tf.Context.executing_eagerly(); | ||||
| foreach (var el in flatten_args) | foreach (var el in flatten_args) | ||||
| { | { | ||||
| if (el is Tensor tensor && !tensor.IsEagerTensor) | |||||
| if (el is Tensor tensor && tensor.IsCreatedInGraphMode) | |||||
| { | { | ||||
| has_graph_arg = true; | has_graph_arg = true; | ||||
| break; | break; | ||||
| @@ -50,9 +50,6 @@ namespace Tensorflow.Eager | |||||
| public EagerTensor(Shape shape, TF_DataType dtype) : base(shape, dtype) | public EagerTensor(Shape shape, TF_DataType dtype) : base(shape, dtype) | ||||
| => NewEagerTensorHandle(_handle); | => NewEagerTensorHandle(_handle); | ||||
| internal unsafe EagerTensor(string value) : base(value) | |||||
| => NewEagerTensorHandle(_handle); | |||||
| internal unsafe EagerTensor(Array array, Shape shape) : base(array, shape) | internal unsafe EagerTensor(Array array, Shape shape) : base(array, shape) | ||||
| => NewEagerTensorHandle(_handle); | => NewEagerTensorHandle(_handle); | ||||
| @@ -141,7 +141,7 @@ namespace Tensorflow.Functions | |||||
| src_graph: _func_graph); | src_graph: _func_graph); | ||||
| var captures_from_forward = backwards_graph.external_captures | var captures_from_forward = backwards_graph.external_captures | ||||
| .Where(x => !x.IsEagerTensor && x.graph == _func_graph) | |||||
| .Where(x => x.IsCreatedInGraphMode && x.graph == _func_graph) | |||||
| .ToArray(); | .ToArray(); | ||||
| foreach(var capture in captures_from_forward) | foreach(var capture in captures_from_forward) | ||||
| { | { | ||||
| @@ -8,20 +8,47 @@ namespace Tensorflow.NumPy | |||||
| { | { | ||||
| public partial class NDArray | public partial class NDArray | ||||
| { | { | ||||
| public NDArray(bool value) => _tensor = new EagerTensor(value); | |||||
| public NDArray(byte value) => _tensor = new EagerTensor(value); | |||||
| public NDArray(short value) => _tensor = new EagerTensor(value); | |||||
| public NDArray(int value) => _tensor = new EagerTensor(value); | |||||
| public NDArray(long value) => _tensor = new EagerTensor(value); | |||||
| public NDArray(float value) => _tensor = new EagerTensor(value); | |||||
| public NDArray(double value) => _tensor = new EagerTensor(value); | |||||
| public NDArray(bool value) => Init(value); | |||||
| public NDArray(byte value) => Init(value); | |||||
| public NDArray(short value) => Init(value); | |||||
| public NDArray(int value) => Init(value); | |||||
| public NDArray(long value) => Init(value); | |||||
| public NDArray(float value) => Init(value); | |||||
| public NDArray(double value) => Init(value); | |||||
| public NDArray(Array value, Shape? shape = null) => Init(value, shape); | |||||
| public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => Init(shape, dtype: dtype); | |||||
| public NDArray(Tensor value, Shape? shape = null) => Init(value, shape); | |||||
| public NDArray(Array value, Shape? shape = null) => _tensor = new EagerTensor(value, shape); | |||||
| public static NDArray Scalar<T>(T value) where T : unmanaged | |||||
| => value switch | |||||
| { | |||||
| bool val => new NDArray(val), | |||||
| byte val => new NDArray(val), | |||||
| int val => new NDArray(val), | |||||
| float val => new NDArray(val), | |||||
| double val => new NDArray(val), | |||||
| _ => throw new NotImplementedException("") | |||||
| }; | |||||
| void Init<T>(T value) where T : unmanaged | |||||
| { | |||||
| _tensor = new EagerTensor(value); | |||||
| _tensor.SetReferencedByNDArray(); | |||||
| } | |||||
| void Init(Array value, Shape? shape = null) | |||||
| { | |||||
| _tensor = new EagerTensor(value, shape ?? value.GetShape()); | |||||
| _tensor.SetReferencedByNDArray(); | |||||
| } | |||||
| public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) | |||||
| => _tensor = new EagerTensor(shape, dtype: dtype); | |||||
| void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) | |||||
| { | |||||
| _tensor = new EagerTensor(shape, dtype: dtype); | |||||
| _tensor.SetReferencedByNDArray(); | |||||
| } | |||||
| public NDArray(Tensor value, Shape? shape = null) | |||||
| void Init(Tensor value, Shape? shape = null) | |||||
| { | { | ||||
| if (shape is not null) | if (shape is not null) | ||||
| _tensor = tf.reshape(value, shape); | _tensor = tf.reshape(value, shape); | ||||
| @@ -30,18 +57,8 @@ namespace Tensorflow.NumPy | |||||
| if (_tensor.TensorDataPointer == IntPtr.Zero) | if (_tensor.TensorDataPointer == IntPtr.Zero) | ||||
| _tensor = tf.get_default_session().eval(_tensor); | _tensor = tf.get_default_session().eval(_tensor); | ||||
| } | |||||
| public static NDArray Scalar<T>(T value) where T : unmanaged | |||||
| { | |||||
| return value switch | |||||
| { | |||||
| bool val => new NDArray(val), | |||||
| int val => new NDArray(val), | |||||
| float val => new NDArray(val), | |||||
| double val => new NDArray(val), | |||||
| _ => throw new NotImplementedException("") | |||||
| }; | |||||
| _tensor.SetReferencedByNDArray(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -21,6 +21,7 @@ using System.Linq; | |||||
| using System.Numerics; | using System.Numerics; | ||||
| using System.Text; | using System.Text; | ||||
| using static Tensorflow.c_api; | using static Tensorflow.c_api; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -31,7 +32,7 @@ namespace Tensorflow | |||||
| public Tensor() | public Tensor() | ||||
| { | { | ||||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -41,60 +42,7 @@ namespace Tensorflow | |||||
| public Tensor(IntPtr handle) | public Tensor(IntPtr handle) | ||||
| { | { | ||||
| _handle = handle; | _handle = handle; | ||||
| //no need to set AllocationType = AllocationType.None; | |||||
| #if TRACK_TENSOR_LIFE | |||||
| print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); | |||||
| #endif | |||||
| } | |||||
| unsafe internal Tensor(Shape shape, TF_DataType dtype) | |||||
| => _handle = TF_NewTensor(shape, dtype, null); | |||||
| internal Tensor(Array array, Shape? shape = null) | |||||
| => InitTensor(array, shape); | |||||
| unsafe void InitTensor(Array array, Shape? shape = null) | |||||
| { | |||||
| shape = shape ?? array.GetShape(); | |||||
| var dtype = array.GetType().GetElementType().as_tf_dtype(); | |||||
| switch (array) | |||||
| { | |||||
| case bool[] val: | |||||
| fixed (void* addr = &val[0]) | |||||
| _handle = TF_NewTensor(shape, dtype, addr); | |||||
| break; | |||||
| case int[] val: | |||||
| fixed (void* addr = &val[0]) | |||||
| _handle = TF_NewTensor(shape, dtype, addr); | |||||
| break; | |||||
| case int[,] val: | |||||
| fixed (void* addr = &val[0, 0]) | |||||
| _handle = TF_NewTensor(shape, dtype, addr); | |||||
| break; | |||||
| case long[] val: | |||||
| fixed (void* addr = &val[0]) | |||||
| _handle = TF_NewTensor(shape, dtype, addr); | |||||
| break; | |||||
| case float[] val: | |||||
| fixed (void* addr = &val[0]) | |||||
| _handle = TF_NewTensor(shape, dtype, addr); | |||||
| break; | |||||
| case float[,] val: | |||||
| fixed (void* addr = &val[0, 0]) | |||||
| _handle = TF_NewTensor(shape, dtype, addr); | |||||
| break; | |||||
| case double[] val: | |||||
| fixed (void* addr = &val[0]) | |||||
| _handle = TF_NewTensor(shape, dtype, addr); | |||||
| break; | |||||
| case double[,] val: | |||||
| fixed (void* addr = &val[0, 0]) | |||||
| _handle = TF_NewTensor(shape, dtype, addr); | |||||
| break; | |||||
| default: | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -109,22 +57,26 @@ namespace Tensorflow | |||||
| public Tensor(IntPtr data_ptr, long[] shape, TF_DataType dType, int num_bytes) | public Tensor(IntPtr data_ptr, long[] shape, TF_DataType dType, int num_bytes) | ||||
| { | { | ||||
| _handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (ulong)num_bytes); | _handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (ulong)num_bytes); | ||||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | } | ||||
| public unsafe Tensor(NDArray nd) | public unsafe Tensor(NDArray nd) | ||||
| => _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); | |||||
| { | |||||
| _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); | |||||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | |||||
| #region scala | #region scala | ||||
| public Tensor(bool value) => _handle = TF_NewTensor(value); | |||||
| public Tensor(byte value) => _handle = TF_NewTensor(value); | |||||
| public Tensor(sbyte value) => _handle = TF_NewTensor(value); | |||||
| public Tensor(short value) => _handle = TF_NewTensor(value); | |||||
| public Tensor(int value) => _handle = TF_NewTensor(value); | |||||
| public Tensor(uint value) => _handle = TF_NewTensor(value); | |||||
| public Tensor(long value) => _handle = TF_NewTensor(value); | |||||
| public Tensor(ulong value) => _handle = TF_NewTensor(value); | |||||
| public Tensor(float value) => _handle = TF_NewTensor(value); | |||||
| public Tensor(double value) => _handle = TF_NewTensor(value); | |||||
| public Tensor(bool value) => InitTensor(value); | |||||
| public Tensor(byte value) => InitTensor(value); | |||||
| public Tensor(sbyte value) => InitTensor(value); | |||||
| public Tensor(short value) => InitTensor(value); | |||||
| public Tensor(int value) => InitTensor(value); | |||||
| public Tensor(uint value) => InitTensor(value); | |||||
| public Tensor(long value) => InitTensor(value); | |||||
| public Tensor(ulong value) => InitTensor(value); | |||||
| public Tensor(float value) => InitTensor(value); | |||||
| public Tensor(double value) => InitTensor(value); | |||||
| #endregion | #endregion | ||||
| #region 1d array | #region 1d array | ||||
| @@ -142,31 +94,74 @@ namespace Tensorflow | |||||
| public Tensor(Complex[] data, Shape? shape = null) => InitTensor(data, shape); | public Tensor(Complex[] data, Shape? shape = null) => InitTensor(data, shape); | ||||
| #endregion | #endregion | ||||
| /// <summary> | |||||
| /// Create a string Tensor from the given string | |||||
| /// </summary> | |||||
| public Tensor(string str) | |||||
| public Tensor(Operation op, int value_index, TF_DataType dtype) | |||||
| { | |||||
| _op = op; | |||||
| _value_index = value_index; | |||||
| _override_dtype = dtype; | |||||
| _id = ops.uid(); | |||||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | |||||
| internal Tensor(Shape shape, TF_DataType dtype) => InitTensor(shape, dtype); | |||||
| internal Tensor(Array array, Shape? shape = null) => InitTensor(array, shape); | |||||
| internal Tensor(string value) => InitTensor(value); | |||||
| protected unsafe void InitTensor<T>(T data) where T : unmanaged | |||||
| { | { | ||||
| _handle = StringTensor(new string[] { str }, TensorShape.Scalar); | |||||
| #if TRACK_TENSOR_LIFE | |||||
| print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); | |||||
| #endif | |||||
| _handle = TF_NewTensor(data); | |||||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | } | ||||
| public Tensor(string[] strings) | |||||
| protected unsafe void InitTensor(Shape shape, TF_DataType dtype) | |||||
| { | { | ||||
| _handle = StringTensor(strings, new TensorShape(strings.Length)); | |||||
| #if TRACK_TENSOR_LIFE | |||||
| print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); | |||||
| #endif | |||||
| _handle = TF_NewTensor(shape, dtype, null); | |||||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | } | ||||
| public Tensor(Operation op, int value_index, TF_DataType dtype) | |||||
| protected void InitTensor(string value) | |||||
| { | { | ||||
| _op = op; | |||||
| _value_index = value_index; | |||||
| _override_dtype = dtype; | |||||
| _id = ops.uid(); | |||||
| _handle = StringTensor(new[] { value }, TensorShape.Scalar); | |||||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | |||||
| protected unsafe void InitTensor(Array array, Shape? shape = null) | |||||
| { | |||||
| shape = shape ?? array.GetShape(); | |||||
| var dtype = array.GetType().GetElementType().as_tf_dtype(); | |||||
| switch (array) | |||||
| { | |||||
| case bool[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case bool[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case bool[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case bool[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case byte[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case byte[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case byte[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case byte[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case int[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case int[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case int[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case int[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case long[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case long[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case long[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case long[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case float[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case float[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case float[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case float[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case double[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case double[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case double[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case double[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||||
| case string[] val: _handle = StringTensor(val, shape); break; | |||||
| default: | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -23,7 +23,7 @@ namespace Tensorflow | |||||
| public IntPtr StringTensor(byte[][] buffer, TensorShape shape) | public IntPtr StringTensor(byte[][] buffer, TensorShape shape) | ||||
| { | { | ||||
| var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, | var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, | ||||
| shape.ndim == 0 ? null : shape.dims.Select(x => (long)x).ToArray(), | |||||
| shape.ndim == 0 ? null : shape.dims, | |||||
| shape.ndim, | shape.ndim, | ||||
| (ulong)shape.size * TF_TSRING_SIZE); | (ulong)shape.size * TF_TSRING_SIZE); | ||||
| @@ -93,9 +93,13 @@ namespace Tensorflow | |||||
| /// TFE_TensorHandle | /// TFE_TensorHandle | ||||
| /// </summary> | /// </summary> | ||||
| public SafeTensorHandleHandle EagerTensorHandle { get; set; } | public SafeTensorHandleHandle EagerTensorHandle { get; set; } | ||||
| protected bool _createdInGraphMode; | |||||
| public bool CreatedInGraphMode => _createdInGraphMode; | |||||
| public bool IsEagerTensor => this is EagerTensor; | |||||
| protected bool isReferencedByNDArray; | |||||
| public bool IsReferencedByNDArray => isReferencedByNDArray; | |||||
| protected bool isCreatedInGraphMode; | |||||
| public bool IsCreatedInGraphMode => isCreatedInGraphMode; | |||||
| public bool IsSparseTensor => this is SparseTensor; | public bool IsSparseTensor => this is SparseTensor; | ||||
| /// <summary> | /// <summary> | ||||
| @@ -207,6 +211,8 @@ namespace Tensorflow | |||||
| return _tf_output.Value; | return _tf_output.Value; | ||||
| } | } | ||||
| public void SetReferencedByNDArray() => isReferencedByNDArray = true; | |||||
| public Tensor MaybeMove() | public Tensor MaybeMove() | ||||
| { | { | ||||
| var tensor = c_api.TF_TensorMaybeMove(_handle); | var tensor = c_api.TF_TensorMaybeMove(_handle); | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Tensorflow.NumPy; | |||||
| using System.Linq; | |||||
| using Tensorflow.NumPy; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -13,7 +14,7 @@ namespace Tensorflow | |||||
| public static implicit operator TensorShape(Shape shape) => new TensorShape((long[])shape.dims.Clone()); | public static implicit operator TensorShape(Shape shape) => new TensorShape((long[])shape.dims.Clone()); | ||||
| public static implicit operator Shape(TensorShape shape) => shape == null ? null : new Shape((long[])shape.dims.Clone()); | public static implicit operator Shape(TensorShape shape) => shape == null ? null : new Shape((long[])shape.dims.Clone()); | ||||
| public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes | |||||
| public static implicit operator int[](TensorShape shape) => shape == null ? null : shape.dims.Select(x => (int)x).ToArray(); //we clone to avoid any changes | |||||
| public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims); | public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims); | ||||
| public static implicit operator long[](TensorShape shape) => shape == null ? null : (long[])shape.dims.Clone(); //we clone to avoid any changes | public static implicit operator long[](TensorShape shape) => shape == null ? null : (long[])shape.dims.Clone(); //we clone to avoid any changes | ||||
| @@ -21,7 +21,7 @@ namespace Tensorflow | |||||
| public TensorShape shape => items.First().TensorShape; | public TensorShape shape => items.First().TensorShape; | ||||
| public int rank => items.First().rank; | public int rank => items.First().rank; | ||||
| public Graph graph => items.First().graph; | public Graph graph => items.First().graph; | ||||
| public bool IsEagerTensor => items.First().IsEagerTensor; | |||||
| public bool IsCreatedInGraphMode => items.First().IsCreatedInGraphMode; | |||||
| public bool IsList { get; set; } | public bool IsList { get; set; } | ||||
| public int Length => items.Count(); | public int Length => items.Count(); | ||||
| @@ -98,7 +98,6 @@ namespace Tensorflow | |||||
| attrs: attrs, | attrs: attrs, | ||||
| name: name); | name: name); | ||||
| var o = op.outputs; | |||||
| return op.outputs[0]; | return op.outputs[0]; | ||||
| } | } | ||||
| @@ -167,9 +166,9 @@ namespace Tensorflow | |||||
| case TensorShape val: | case TensorShape val: | ||||
| return new EagerTensor(val.dims, ctx.DeviceName); | return new EagerTensor(val.dims, ctx.DeviceName); | ||||
| case string val: | case string val: | ||||
| return new EagerTensor(val); | |||||
| return new EagerTensor(new[] { val }, Shape.Scalar); | |||||
| case string[] val: | case string[] val: | ||||
| return new EagerTensor(val, ctx.DeviceName); | |||||
| return new EagerTensor(val, new Shape(val.Length)); | |||||
| case bool val: | case bool val: | ||||
| return new EagerTensor(new[] { val }, Shape.Scalar); | return new EagerTensor(new[] { val }, Shape.Scalar); | ||||
| case byte val: | case byte val: | ||||
| @@ -75,7 +75,7 @@ namespace Tensorflow | |||||
| case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | ||||
| return typeof(Complex); | return typeof(Complex); | ||||
| default: | default: | ||||
| return null; | |||||
| throw new NotSupportedException($"Unable to convert {type} to a system data type."); | |||||
| } | } | ||||
| } | } | ||||
| @@ -83,24 +83,25 @@ namespace Tensorflow | |||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="type"></param> | /// <param name="type"></param> | ||||
| /// <param name="dtype"></param> | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| /// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="TF_DataType"/></exception> | /// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="TF_DataType"/></exception> | ||||
| public static TF_DataType as_tf_dtype(this Type type, TF_DataType? dtype = null) | |||||
| public static TF_DataType as_tf_dtype(this Type type) | |||||
| { | { | ||||
| while (type.IsArray) | while (type.IsArray) | ||||
| type = type.GetElementType(); | type = type.GetElementType(); | ||||
| TF_DataType dtype = TF_DataType.DtInvalid; | |||||
| switch (type.Name) | switch (type.Name) | ||||
| { | { | ||||
| case "Char": | case "Char": | ||||
| dtype = dtype ?? TF_DataType.TF_UINT8; | |||||
| dtype = TF_DataType.TF_UINT8; | |||||
| break; | break; | ||||
| case "SByte": | case "SByte": | ||||
| dtype = TF_DataType.TF_INT8; | dtype = TF_DataType.TF_INT8; | ||||
| break; | break; | ||||
| case "Byte": | case "Byte": | ||||
| dtype = dtype ?? TF_DataType.TF_UINT8; | |||||
| dtype = TF_DataType.TF_UINT8; | |||||
| break; | break; | ||||
| case "Int16": | case "Int16": | ||||
| dtype = TF_DataType.TF_INT16; | dtype = TF_DataType.TF_INT16; | ||||
| @@ -136,60 +137,32 @@ namespace Tensorflow | |||||
| dtype = TF_DataType.TF_BOOL; | dtype = TF_DataType.TF_BOOL; | ||||
| break; | break; | ||||
| default: | default: | ||||
| throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); | |||||
| throw new NotSupportedException($"Unable to convert {type} to a TensorFlow data type."); | |||||
| } | } | ||||
| return dtype.Value; | |||||
| return dtype; | |||||
| } | } | ||||
| public static TF_DataType tf_dtype_from_name(string name) | public static TF_DataType tf_dtype_from_name(string name) | ||||
| { | { | ||||
| TF_DataType dtype = TF_DataType.DtInvalid; | |||||
| switch (name.ToLower()) | |||||
| TF_DataType dtype = name.ToLower() switch | |||||
| { | { | ||||
| case "char": | |||||
| dtype = TF_DataType.TF_UINT8; | |||||
| break; | |||||
| case "boolean": | |||||
| dtype = TF_DataType.TF_BOOL; | |||||
| break; | |||||
| case "sbyte": | |||||
| dtype = TF_DataType.TF_INT8; | |||||
| break; | |||||
| case "byte": | |||||
| dtype = TF_DataType.TF_UINT8; | |||||
| break; | |||||
| case "int16": | |||||
| dtype = TF_DataType.TF_INT16; | |||||
| break; | |||||
| case "uint16": | |||||
| dtype = TF_DataType.TF_UINT16; | |||||
| break; | |||||
| case "int32": | |||||
| dtype = TF_DataType.TF_INT32; | |||||
| break; | |||||
| case "uint32": | |||||
| dtype = TF_DataType.TF_UINT32; | |||||
| break; | |||||
| case "int64": | |||||
| dtype = TF_DataType.TF_INT64; | |||||
| break; | |||||
| case "uint64": | |||||
| dtype = TF_DataType.TF_UINT64; | |||||
| break; | |||||
| case "single": | |||||
| dtype = TF_DataType.TF_FLOAT; | |||||
| break; | |||||
| case "double": | |||||
| dtype = TF_DataType.TF_DOUBLE; | |||||
| break; | |||||
| case "complex": | |||||
| dtype = TF_DataType.TF_COMPLEX128; | |||||
| break; | |||||
| case "string": | |||||
| dtype = TF_DataType.TF_STRING; | |||||
| break; | |||||
| } | |||||
| "char" => TF_DataType.TF_UINT8, | |||||
| "boolean" => TF_DataType.TF_BOOL, | |||||
| "sbyte" => TF_DataType.TF_INT8, | |||||
| "byte" => TF_DataType.TF_UINT8, | |||||
| "int16" => TF_DataType.TF_INT16, | |||||
| "uint16" => TF_DataType.TF_UINT16, | |||||
| "int32" => TF_DataType.TF_INT32, | |||||
| "uint32" => TF_DataType.TF_UINT32, | |||||
| "int64" => TF_DataType.TF_INT64, | |||||
| "uint64" => TF_DataType.TF_UINT64, | |||||
| "single" => TF_DataType.TF_FLOAT, | |||||
| "double" => TF_DataType.TF_DOUBLE, | |||||
| "complex" => TF_DataType.TF_COMPLEX128, | |||||
| "string" => TF_DataType.TF_STRING, | |||||
| _ => TF_DataType.DtInvalid | |||||
| }; | |||||
| return dtype; | return dtype; | ||||
| } | } | ||||
| @@ -108,7 +108,7 @@ namespace Tensorflow | |||||
| if (values is TensorProto tp) | if (values is TensorProto tp) | ||||
| return tp; | return tp; | ||||
| dtype = values.GetType().as_tf_dtype(); | |||||
| dtype = values.GetDataType(); | |||||
| shape = shape ?? values.GetShape(); | shape = shape ?? values.GetShape(); | ||||
| var tensor_proto = new TensorProto | var tensor_proto = new TensorProto | ||||
| { | { | ||||
| @@ -117,7 +117,13 @@ namespace Tensorflow | |||||
| }; | }; | ||||
| // scalar | // scalar | ||||
| if (!values.GetType().IsArray) | |||||
| if (values is NDArray nd) | |||||
| { | |||||
| var len = nd.dtypesize * nd.size; | |||||
| byte[] bytes = nd.ToByteArray(); | |||||
| tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); | |||||
| } | |||||
| else if (!values.GetType().IsArray) | |||||
| { | { | ||||
| switch (values) | switch (values) | ||||
| { | { | ||||
| @@ -154,7 +160,7 @@ namespace Tensorflow | |||||
| else if (values is byte[] byte_values) | else if (values is byte[] byte_values) | ||||
| tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values); | tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values); | ||||
| } | } | ||||
| else if(values is Array array) | |||||
| else if (values is Array array) | |||||
| { | { | ||||
| // array | // array | ||||
| var len = dtype.get_datatype_size() * (int)shape.size; | var len = dtype.get_datatype_size() * (int)shape.size; | ||||
| @@ -68,7 +68,7 @@ namespace Tensorflow | |||||
| // when this object is garbage collected the deleter will be too. This | // when this object is garbage collected the deleter will be too. This | ||||
| // means ResourceVariables can be part of reference cycles without those | // means ResourceVariables can be part of reference cycles without those | ||||
| // cycles being uncollectable. | // cycles being uncollectable. | ||||
| if (handle.IsEagerTensor) | |||||
| if (!handle.IsCreatedInGraphMode) | |||||
| { | { | ||||
| _handle = handle.EagerTensorHandle.DangerousGetHandle(); | _handle = handle.EagerTensorHandle.DangerousGetHandle(); | ||||
| eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); | eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); | ||||
| @@ -123,7 +123,7 @@ namespace Tensorflow | |||||
| if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
| dtype = preferred_dtype; | dtype = preferred_dtype; | ||||
| if (value is EagerTensor eager_tensor) | |||||
| if (value is EagerTensor eager_tensor && !eager_tensor.IsCreatedInGraphMode) | |||||
| { | { | ||||
| if (tf.executing_eagerly()) | if (tf.executing_eagerly()) | ||||
| { | { | ||||
| @@ -140,7 +140,13 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| else if (value is NDArray nd) | else if (value is NDArray nd) | ||||
| { | |||||
| return nd; | return nd; | ||||
| } | |||||
| else if (value is Tensor tensor && tensor.IsReferencedByNDArray) | |||||
| { | |||||
| return tensor; | |||||
| } | |||||
| // graph mode | // graph mode | ||||
| Tensor ret = value switch | Tensor ret = value switch | ||||
| @@ -115,7 +115,7 @@ namespace Tensorflow.Keras.Engine | |||||
| bool _in_functional_construction_mode(Tensors inputs) | bool _in_functional_construction_mode(Tensors inputs) | ||||
| { | { | ||||
| return tf.Context.executing_eagerly() | return tf.Context.executing_eagerly() | ||||
| && inputs.Count(x => !x.IsEagerTensor) == inputs.Count(); | |||||
| && inputs.Count(x => x.IsCreatedInGraphMode) == inputs.Count(); | |||||
| } | } | ||||
| public void SetConnectivityMetadata(Tensors inputs, Tensors outputs) | public void SetConnectivityMetadata(Tensors inputs, Tensors outputs) | ||||
| @@ -177,7 +177,7 @@ namespace Tensorflow.Keras.Engine | |||||
| tf.init_scope(); | tf.init_scope(); | ||||
| bool need_restore_mode = false; | bool need_restore_mode = false; | ||||
| if (inputs.IsEagerTensor || tf.Context.is_build_function()) | |||||
| if (!inputs.IsCreatedInGraphMode || tf.Context.is_build_function()) | |||||
| { | { | ||||
| need_restore_mode = true; | need_restore_mode = true; | ||||
| tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); | tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); | ||||
| @@ -148,10 +148,10 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| { | { | ||||
| var dataset = tf.data.Dataset.range(10); | var dataset = tf.data.Dataset.range(10); | ||||
| var cardinality = dataset.cardinality(); | var cardinality = dataset.cardinality(); | ||||
| Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | |||||
| Assert.AreEqual(cardinality.numpy(), 10L); | |||||
| dataset = dataset.map(x => x[0] + 1); | dataset = dataset.map(x => x[0] + 1); | ||||
| cardinality = dataset.cardinality(); | cardinality = dataset.cardinality(); | ||||
| Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | |||||
| Assert.AreEqual(cardinality.numpy(), 10L); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -160,7 +160,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| var dataset = tf.data.Dataset.range(10); | var dataset = tf.data.Dataset.range(10); | ||||
| dataset = dataset.map(x => x, num_parallel_calls: -1); | dataset = dataset.map(x => x, num_parallel_calls: -1); | ||||
| var cardinality = dataset.cardinality(); | var cardinality = dataset.cardinality(); | ||||
| Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | |||||
| Assert.AreEqual(cardinality.numpy(), 10L); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -7,7 +7,7 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestClass] | [TestClass] | ||||
| public class MnistModelLoaderTest | public class MnistModelLoaderTest | ||||
| { | { | ||||
| [TestMethod] | |||||
| [TestMethod, Ignore] | |||||
| public async Task TestLoad() | public async Task TestLoad() | ||||
| { | { | ||||
| var loader = new MnistModelLoader(); | var loader = new MnistModelLoader(); | ||||