| @@ -526,8 +526,19 @@ namespace Tensorflow | |||
| var type = data.GetType(); | |||
| switch (data) | |||
| { | |||
| case Shape shape: | |||
| case TensorShape: | |||
| case Shape: | |||
| return TF_DataType.TF_INT64; | |||
| case Axis: | |||
| return TF_DataType.TF_INT32; | |||
| case NDArray nd: | |||
| return nd.dtype; | |||
| case Tensor tensor: | |||
| return tensor.dtype; | |||
| case Tensor[] tensor: | |||
| return tensor[0].dtype; | |||
| case ResourceVariable variable: | |||
| return variable.dtype; | |||
| default: | |||
| return type.as_tf_dtype(); | |||
| } | |||
| @@ -142,7 +142,7 @@ namespace Tensorflow.Contexts | |||
| bool has_graph_arg = !tf.Context.executing_eagerly(); | |||
| foreach (var el in flatten_args) | |||
| { | |||
| if (el is Tensor tensor && !tensor.IsEagerTensor) | |||
| if (el is Tensor tensor && tensor.IsCreatedInGraphMode) | |||
| { | |||
| has_graph_arg = true; | |||
| break; | |||
| @@ -50,9 +50,6 @@ namespace Tensorflow.Eager | |||
| public EagerTensor(Shape shape, TF_DataType dtype) : base(shape, dtype) | |||
| => NewEagerTensorHandle(_handle); | |||
| internal unsafe EagerTensor(string value) : base(value) | |||
| => NewEagerTensorHandle(_handle); | |||
| internal unsafe EagerTensor(Array array, Shape shape) : base(array, shape) | |||
| => NewEagerTensorHandle(_handle); | |||
| @@ -141,7 +141,7 @@ namespace Tensorflow.Functions | |||
| src_graph: _func_graph); | |||
| var captures_from_forward = backwards_graph.external_captures | |||
| .Where(x => !x.IsEagerTensor && x.graph == _func_graph) | |||
| .Where(x => x.IsCreatedInGraphMode && x.graph == _func_graph) | |||
| .ToArray(); | |||
| foreach(var capture in captures_from_forward) | |||
| { | |||
| @@ -8,20 +8,47 @@ namespace Tensorflow.NumPy | |||
| { | |||
| public partial class NDArray | |||
| { | |||
| public NDArray(bool value) => _tensor = new EagerTensor(value); | |||
| public NDArray(byte value) => _tensor = new EagerTensor(value); | |||
| public NDArray(short value) => _tensor = new EagerTensor(value); | |||
| public NDArray(int value) => _tensor = new EagerTensor(value); | |||
| public NDArray(long value) => _tensor = new EagerTensor(value); | |||
| public NDArray(float value) => _tensor = new EagerTensor(value); | |||
| public NDArray(double value) => _tensor = new EagerTensor(value); | |||
| public NDArray(bool value) => Init(value); | |||
| public NDArray(byte value) => Init(value); | |||
| public NDArray(short value) => Init(value); | |||
| public NDArray(int value) => Init(value); | |||
| public NDArray(long value) => Init(value); | |||
| public NDArray(float value) => Init(value); | |||
| public NDArray(double value) => Init(value); | |||
| public NDArray(Array value, Shape? shape = null) => Init(value, shape); | |||
| public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => Init(shape, dtype: dtype); | |||
| public NDArray(Tensor value, Shape? shape = null) => Init(value, shape); | |||
| public NDArray(Array value, Shape? shape = null) => _tensor = new EagerTensor(value, shape); | |||
| public static NDArray Scalar<T>(T value) where T : unmanaged | |||
| => value switch | |||
| { | |||
| bool val => new NDArray(val), | |||
| byte val => new NDArray(val), | |||
| int val => new NDArray(val), | |||
| float val => new NDArray(val), | |||
| double val => new NDArray(val), | |||
| _ => throw new NotImplementedException("") | |||
| }; | |||
| void Init<T>(T value) where T : unmanaged | |||
| { | |||
| _tensor = new EagerTensor(value); | |||
| _tensor.SetReferencedByNDArray(); | |||
| } | |||
| void Init(Array value, Shape? shape = null) | |||
| { | |||
| _tensor = new EagerTensor(value, shape ?? value.GetShape()); | |||
| _tensor.SetReferencedByNDArray(); | |||
| } | |||
| public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) | |||
| => _tensor = new EagerTensor(shape, dtype: dtype); | |||
| void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) | |||
| { | |||
| _tensor = new EagerTensor(shape, dtype: dtype); | |||
| _tensor.SetReferencedByNDArray(); | |||
| } | |||
| public NDArray(Tensor value, Shape? shape = null) | |||
| void Init(Tensor value, Shape? shape = null) | |||
| { | |||
| if (shape is not null) | |||
| _tensor = tf.reshape(value, shape); | |||
| @@ -30,18 +57,8 @@ namespace Tensorflow.NumPy | |||
| if (_tensor.TensorDataPointer == IntPtr.Zero) | |||
| _tensor = tf.get_default_session().eval(_tensor); | |||
| } | |||
| public static NDArray Scalar<T>(T value) where T : unmanaged | |||
| { | |||
| return value switch | |||
| { | |||
| bool val => new NDArray(val), | |||
| int val => new NDArray(val), | |||
| float val => new NDArray(val), | |||
| double val => new NDArray(val), | |||
| _ => throw new NotImplementedException("") | |||
| }; | |||
| _tensor.SetReferencedByNDArray(); | |||
| } | |||
| } | |||
| } | |||
| @@ -21,6 +21,7 @@ using System.Linq; | |||
| using System.Numerics; | |||
| using System.Text; | |||
| using static Tensorflow.c_api; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -31,7 +32,7 @@ namespace Tensorflow | |||
| public Tensor() | |||
| { | |||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||
| } | |||
| /// <summary> | |||
| @@ -41,60 +42,7 @@ namespace Tensorflow | |||
| public Tensor(IntPtr handle) | |||
| { | |||
| _handle = handle; | |||
| //no need to set AllocationType = AllocationType.None; | |||
| #if TRACK_TENSOR_LIFE | |||
| print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); | |||
| #endif | |||
| } | |||
| unsafe internal Tensor(Shape shape, TF_DataType dtype) | |||
| => _handle = TF_NewTensor(shape, dtype, null); | |||
| internal Tensor(Array array, Shape? shape = null) | |||
| => InitTensor(array, shape); | |||
| unsafe void InitTensor(Array array, Shape? shape = null) | |||
| { | |||
| shape = shape ?? array.GetShape(); | |||
| var dtype = array.GetType().GetElementType().as_tf_dtype(); | |||
| switch (array) | |||
| { | |||
| case bool[] val: | |||
| fixed (void* addr = &val[0]) | |||
| _handle = TF_NewTensor(shape, dtype, addr); | |||
| break; | |||
| case int[] val: | |||
| fixed (void* addr = &val[0]) | |||
| _handle = TF_NewTensor(shape, dtype, addr); | |||
| break; | |||
| case int[,] val: | |||
| fixed (void* addr = &val[0, 0]) | |||
| _handle = TF_NewTensor(shape, dtype, addr); | |||
| break; | |||
| case long[] val: | |||
| fixed (void* addr = &val[0]) | |||
| _handle = TF_NewTensor(shape, dtype, addr); | |||
| break; | |||
| case float[] val: | |||
| fixed (void* addr = &val[0]) | |||
| _handle = TF_NewTensor(shape, dtype, addr); | |||
| break; | |||
| case float[,] val: | |||
| fixed (void* addr = &val[0, 0]) | |||
| _handle = TF_NewTensor(shape, dtype, addr); | |||
| break; | |||
| case double[] val: | |||
| fixed (void* addr = &val[0]) | |||
| _handle = TF_NewTensor(shape, dtype, addr); | |||
| break; | |||
| case double[,] val: | |||
| fixed (void* addr = &val[0, 0]) | |||
| _handle = TF_NewTensor(shape, dtype, addr); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException(""); | |||
| } | |||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||
| } | |||
| /// <summary> | |||
| @@ -109,22 +57,26 @@ namespace Tensorflow | |||
| public Tensor(IntPtr data_ptr, long[] shape, TF_DataType dType, int num_bytes) | |||
| { | |||
| _handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (ulong)num_bytes); | |||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||
| } | |||
| public unsafe Tensor(NDArray nd) | |||
| => _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); | |||
| { | |||
| _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); | |||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||
| } | |||
| #region scala | |||
| public Tensor(bool value) => _handle = TF_NewTensor(value); | |||
| public Tensor(byte value) => _handle = TF_NewTensor(value); | |||
| public Tensor(sbyte value) => _handle = TF_NewTensor(value); | |||
| public Tensor(short value) => _handle = TF_NewTensor(value); | |||
| public Tensor(int value) => _handle = TF_NewTensor(value); | |||
| public Tensor(uint value) => _handle = TF_NewTensor(value); | |||
| public Tensor(long value) => _handle = TF_NewTensor(value); | |||
| public Tensor(ulong value) => _handle = TF_NewTensor(value); | |||
| public Tensor(float value) => _handle = TF_NewTensor(value); | |||
| public Tensor(double value) => _handle = TF_NewTensor(value); | |||
| public Tensor(bool value) => InitTensor(value); | |||
| public Tensor(byte value) => InitTensor(value); | |||
| public Tensor(sbyte value) => InitTensor(value); | |||
| public Tensor(short value) => InitTensor(value); | |||
| public Tensor(int value) => InitTensor(value); | |||
| public Tensor(uint value) => InitTensor(value); | |||
| public Tensor(long value) => InitTensor(value); | |||
| public Tensor(ulong value) => InitTensor(value); | |||
| public Tensor(float value) => InitTensor(value); | |||
| public Tensor(double value) => InitTensor(value); | |||
| #endregion | |||
| #region 1d array | |||
| @@ -142,31 +94,74 @@ namespace Tensorflow | |||
| public Tensor(Complex[] data, Shape? shape = null) => InitTensor(data, shape); | |||
| #endregion | |||
| /// <summary> | |||
| /// Create a string Tensor from the given string | |||
| /// </summary> | |||
| public Tensor(string str) | |||
| public Tensor(Operation op, int value_index, TF_DataType dtype) | |||
| { | |||
| _op = op; | |||
| _value_index = value_index; | |||
| _override_dtype = dtype; | |||
| _id = ops.uid(); | |||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||
| } | |||
| internal Tensor(Shape shape, TF_DataType dtype) => InitTensor(shape, dtype); | |||
| internal Tensor(Array array, Shape? shape = null) => InitTensor(array, shape); | |||
| internal Tensor(string value) => InitTensor(value); | |||
| protected unsafe void InitTensor<T>(T data) where T : unmanaged | |||
| { | |||
| _handle = StringTensor(new string[] { str }, TensorShape.Scalar); | |||
| #if TRACK_TENSOR_LIFE | |||
| print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); | |||
| #endif | |||
| _handle = TF_NewTensor(data); | |||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||
| } | |||
| public Tensor(string[] strings) | |||
| protected unsafe void InitTensor(Shape shape, TF_DataType dtype) | |||
| { | |||
| _handle = StringTensor(strings, new TensorShape(strings.Length)); | |||
| #if TRACK_TENSOR_LIFE | |||
| print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); | |||
| #endif | |||
| _handle = TF_NewTensor(shape, dtype, null); | |||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||
| } | |||
| public Tensor(Operation op, int value_index, TF_DataType dtype) | |||
| protected void InitTensor(string value) | |||
| { | |||
| _op = op; | |||
| _value_index = value_index; | |||
| _override_dtype = dtype; | |||
| _id = ops.uid(); | |||
| _handle = StringTensor(new[] { value }, TensorShape.Scalar); | |||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||
| } | |||
| protected unsafe void InitTensor(Array array, Shape? shape = null) | |||
| { | |||
| shape = shape ?? array.GetShape(); | |||
| var dtype = array.GetType().GetElementType().as_tf_dtype(); | |||
| switch (array) | |||
| { | |||
| case bool[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case bool[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case bool[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case bool[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case byte[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case byte[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case byte[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case byte[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case int[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case int[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case int[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case int[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case long[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case long[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case long[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case long[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case float[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case float[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case float[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case float[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case double[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case double[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case double[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case double[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break; | |||
| case string[] val: _handle = StringTensor(val, shape); break; | |||
| default: | |||
| throw new NotImplementedException(""); | |||
| } | |||
| isCreatedInGraphMode = !tf.executing_eagerly(); | |||
| } | |||
| } | |||
| } | |||
| @@ -23,7 +23,7 @@ namespace Tensorflow | |||
| public IntPtr StringTensor(byte[][] buffer, TensorShape shape) | |||
| { | |||
| var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, | |||
| shape.ndim == 0 ? null : shape.dims.Select(x => (long)x).ToArray(), | |||
| shape.ndim == 0 ? null : shape.dims, | |||
| shape.ndim, | |||
| (ulong)shape.size * TF_TSRING_SIZE); | |||
| @@ -93,9 +93,13 @@ namespace Tensorflow | |||
| /// TFE_TensorHandle | |||
| /// </summary> | |||
| public SafeTensorHandleHandle EagerTensorHandle { get; set; } | |||
| protected bool _createdInGraphMode; | |||
| public bool CreatedInGraphMode => _createdInGraphMode; | |||
| public bool IsEagerTensor => this is EagerTensor; | |||
| protected bool isReferencedByNDArray; | |||
| public bool IsReferencedByNDArray => isReferencedByNDArray; | |||
| protected bool isCreatedInGraphMode; | |||
| public bool IsCreatedInGraphMode => isCreatedInGraphMode; | |||
| public bool IsSparseTensor => this is SparseTensor; | |||
| /// <summary> | |||
| @@ -207,6 +211,8 @@ namespace Tensorflow | |||
| return _tf_output.Value; | |||
| } | |||
| public void SetReferencedByNDArray() => isReferencedByNDArray = true; | |||
| public Tensor MaybeMove() | |||
| { | |||
| var tensor = c_api.TF_TensorMaybeMove(_handle); | |||
| @@ -1,4 +1,5 @@ | |||
| using Tensorflow.NumPy; | |||
| using System.Linq; | |||
| using Tensorflow.NumPy; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -13,7 +14,7 @@ namespace Tensorflow | |||
| public static implicit operator TensorShape(Shape shape) => new TensorShape((long[])shape.dims.Clone()); | |||
| public static implicit operator Shape(TensorShape shape) => shape == null ? null : new Shape((long[])shape.dims.Clone()); | |||
| public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes | |||
| public static implicit operator int[](TensorShape shape) => shape == null ? null : shape.dims.Select(x => (int)x).ToArray(); //we clone to avoid any changes | |||
| public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims); | |||
| public static implicit operator long[](TensorShape shape) => shape == null ? null : (long[])shape.dims.Clone(); //we clone to avoid any changes | |||
| @@ -21,7 +21,7 @@ namespace Tensorflow | |||
| public TensorShape shape => items.First().TensorShape; | |||
| public int rank => items.First().rank; | |||
| public Graph graph => items.First().graph; | |||
| public bool IsEagerTensor => items.First().IsEagerTensor; | |||
| public bool IsCreatedInGraphMode => items.First().IsCreatedInGraphMode; | |||
| public bool IsList { get; set; } | |||
| public int Length => items.Count(); | |||
| @@ -98,7 +98,6 @@ namespace Tensorflow | |||
| attrs: attrs, | |||
| name: name); | |||
| var o = op.outputs; | |||
| return op.outputs[0]; | |||
| } | |||
| @@ -167,9 +166,9 @@ namespace Tensorflow | |||
| case TensorShape val: | |||
| return new EagerTensor(val.dims, ctx.DeviceName); | |||
| case string val: | |||
| return new EagerTensor(val); | |||
| return new EagerTensor(new[] { val }, Shape.Scalar); | |||
| case string[] val: | |||
| return new EagerTensor(val, ctx.DeviceName); | |||
| return new EagerTensor(val, new Shape(val.Length)); | |||
| case bool val: | |||
| return new EagerTensor(new[] { val }, Shape.Scalar); | |||
| case byte val: | |||
| @@ -75,7 +75,7 @@ namespace Tensorflow | |||
| case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | |||
| return typeof(Complex); | |||
| default: | |||
| return null; | |||
| throw new NotSupportedException($"Unable to convert {type} to a system data type."); | |||
| } | |||
| } | |||
| @@ -83,24 +83,25 @@ namespace Tensorflow | |||
| /// | |||
| /// </summary> | |||
| /// <param name="type"></param> | |||
| /// <param name="dtype"></param> | |||
| /// <returns></returns> | |||
| /// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="TF_DataType"/></exception> | |||
| public static TF_DataType as_tf_dtype(this Type type, TF_DataType? dtype = null) | |||
| public static TF_DataType as_tf_dtype(this Type type) | |||
| { | |||
| while (type.IsArray) | |||
| type = type.GetElementType(); | |||
| TF_DataType dtype = TF_DataType.DtInvalid; | |||
| switch (type.Name) | |||
| { | |||
| case "Char": | |||
| dtype = dtype ?? TF_DataType.TF_UINT8; | |||
| dtype = TF_DataType.TF_UINT8; | |||
| break; | |||
| case "SByte": | |||
| dtype = TF_DataType.TF_INT8; | |||
| break; | |||
| case "Byte": | |||
| dtype = dtype ?? TF_DataType.TF_UINT8; | |||
| dtype = TF_DataType.TF_UINT8; | |||
| break; | |||
| case "Int16": | |||
| dtype = TF_DataType.TF_INT16; | |||
| @@ -136,60 +137,32 @@ namespace Tensorflow | |||
| dtype = TF_DataType.TF_BOOL; | |||
| break; | |||
| default: | |||
| throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); | |||
| throw new NotSupportedException($"Unable to convert {type} to a TensorFlow data type."); | |||
| } | |||
| return dtype.Value; | |||
| return dtype; | |||
| } | |||
| public static TF_DataType tf_dtype_from_name(string name) | |||
| { | |||
| TF_DataType dtype = TF_DataType.DtInvalid; | |||
| switch (name.ToLower()) | |||
| TF_DataType dtype = name.ToLower() switch | |||
| { | |||
| case "char": | |||
| dtype = TF_DataType.TF_UINT8; | |||
| break; | |||
| case "boolean": | |||
| dtype = TF_DataType.TF_BOOL; | |||
| break; | |||
| case "sbyte": | |||
| dtype = TF_DataType.TF_INT8; | |||
| break; | |||
| case "byte": | |||
| dtype = TF_DataType.TF_UINT8; | |||
| break; | |||
| case "int16": | |||
| dtype = TF_DataType.TF_INT16; | |||
| break; | |||
| case "uint16": | |||
| dtype = TF_DataType.TF_UINT16; | |||
| break; | |||
| case "int32": | |||
| dtype = TF_DataType.TF_INT32; | |||
| break; | |||
| case "uint32": | |||
| dtype = TF_DataType.TF_UINT32; | |||
| break; | |||
| case "int64": | |||
| dtype = TF_DataType.TF_INT64; | |||
| break; | |||
| case "uint64": | |||
| dtype = TF_DataType.TF_UINT64; | |||
| break; | |||
| case "single": | |||
| dtype = TF_DataType.TF_FLOAT; | |||
| break; | |||
| case "double": | |||
| dtype = TF_DataType.TF_DOUBLE; | |||
| break; | |||
| case "complex": | |||
| dtype = TF_DataType.TF_COMPLEX128; | |||
| break; | |||
| case "string": | |||
| dtype = TF_DataType.TF_STRING; | |||
| break; | |||
| } | |||
| "char" => TF_DataType.TF_UINT8, | |||
| "boolean" => TF_DataType.TF_BOOL, | |||
| "sbyte" => TF_DataType.TF_INT8, | |||
| "byte" => TF_DataType.TF_UINT8, | |||
| "int16" => TF_DataType.TF_INT16, | |||
| "uint16" => TF_DataType.TF_UINT16, | |||
| "int32" => TF_DataType.TF_INT32, | |||
| "uint32" => TF_DataType.TF_UINT32, | |||
| "int64" => TF_DataType.TF_INT64, | |||
| "uint64" => TF_DataType.TF_UINT64, | |||
| "single" => TF_DataType.TF_FLOAT, | |||
| "double" => TF_DataType.TF_DOUBLE, | |||
| "complex" => TF_DataType.TF_COMPLEX128, | |||
| "string" => TF_DataType.TF_STRING, | |||
| _ => TF_DataType.DtInvalid | |||
| }; | |||
| return dtype; | |||
| } | |||
| @@ -108,7 +108,7 @@ namespace Tensorflow | |||
| if (values is TensorProto tp) | |||
| return tp; | |||
| dtype = values.GetType().as_tf_dtype(); | |||
| dtype = values.GetDataType(); | |||
| shape = shape ?? values.GetShape(); | |||
| var tensor_proto = new TensorProto | |||
| { | |||
| @@ -117,7 +117,13 @@ namespace Tensorflow | |||
| }; | |||
| // scalar | |||
| if (!values.GetType().IsArray) | |||
| if (values is NDArray nd) | |||
| { | |||
| var len = nd.dtypesize * nd.size; | |||
| byte[] bytes = nd.ToByteArray(); | |||
| tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes); | |||
| } | |||
| else if (!values.GetType().IsArray) | |||
| { | |||
| switch (values) | |||
| { | |||
| @@ -154,7 +160,7 @@ namespace Tensorflow | |||
| else if (values is byte[] byte_values) | |||
| tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values); | |||
| } | |||
| else if(values is Array array) | |||
| else if (values is Array array) | |||
| { | |||
| // array | |||
| var len = dtype.get_datatype_size() * (int)shape.size; | |||
| @@ -68,7 +68,7 @@ namespace Tensorflow | |||
| // when this object is garbage collected the deleter will be too. This | |||
| // means ResourceVariables can be part of reference cycles without those | |||
| // cycles being uncollectable. | |||
| if (handle.IsEagerTensor) | |||
| if (!handle.IsCreatedInGraphMode) | |||
| { | |||
| _handle = handle.EagerTensorHandle.DangerousGetHandle(); | |||
| eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); | |||
| @@ -123,7 +123,7 @@ namespace Tensorflow | |||
| if (dtype == TF_DataType.DtInvalid) | |||
| dtype = preferred_dtype; | |||
| if (value is EagerTensor eager_tensor) | |||
| if (value is EagerTensor eager_tensor && !eager_tensor.IsCreatedInGraphMode) | |||
| { | |||
| if (tf.executing_eagerly()) | |||
| { | |||
| @@ -140,7 +140,13 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| else if (value is NDArray nd) | |||
| { | |||
| return nd; | |||
| } | |||
| else if (value is Tensor tensor && tensor.IsReferencedByNDArray) | |||
| { | |||
| return tensor; | |||
| } | |||
| // graph mode | |||
| Tensor ret = value switch | |||
| @@ -115,7 +115,7 @@ namespace Tensorflow.Keras.Engine | |||
| bool _in_functional_construction_mode(Tensors inputs) | |||
| { | |||
| return tf.Context.executing_eagerly() | |||
| && inputs.Count(x => !x.IsEagerTensor) == inputs.Count(); | |||
| && inputs.Count(x => x.IsCreatedInGraphMode) == inputs.Count(); | |||
| } | |||
| public void SetConnectivityMetadata(Tensors inputs, Tensors outputs) | |||
| @@ -177,7 +177,7 @@ namespace Tensorflow.Keras.Engine | |||
| tf.init_scope(); | |||
| bool need_restore_mode = false; | |||
| if (inputs.IsEagerTensor || tf.Context.is_build_function()) | |||
| if (!inputs.IsCreatedInGraphMode || tf.Context.is_build_function()) | |||
| { | |||
| need_restore_mode = true; | |||
| tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); | |||
| @@ -148,10 +148,10 @@ namespace TensorFlowNET.UnitTest.Dataset | |||
| { | |||
| var dataset = tf.data.Dataset.range(10); | |||
| var cardinality = dataset.cardinality(); | |||
| Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | |||
| Assert.AreEqual(cardinality.numpy(), 10L); | |||
| dataset = dataset.map(x => x[0] + 1); | |||
| cardinality = dataset.cardinality(); | |||
| Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | |||
| Assert.AreEqual(cardinality.numpy(), 10L); | |||
| } | |||
| [TestMethod] | |||
| @@ -160,7 +160,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||
| var dataset = tf.data.Dataset.range(10); | |||
| dataset = dataset.map(x => x, num_parallel_calls: -1); | |||
| var cardinality = dataset.cardinality(); | |||
| Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | |||
| Assert.AreEqual(cardinality.numpy(), 10L); | |||
| } | |||
| [TestMethod] | |||
| @@ -7,7 +7,7 @@ namespace TensorFlowNET.UnitTest | |||
| [TestClass] | |||
| public class MnistModelLoaderTest | |||
| { | |||
| [TestMethod] | |||
| [TestMethod, Ignore] | |||
| public async Task TestLoad() | |||
| { | |||
| var loader = new MnistModelLoader(); | |||