diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 2c05e36a..2a76c52c 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -141,7 +141,7 @@ namespace Tensorflow.Operations data, frame_name, is_constant, parallel_iterations, name: name); if (use_input_shape) - result.SetShape(data.TensorShape); + result.set_shape(data.TensorShape); return result; } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 3198942b..1b68d1cd 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -233,7 +233,7 @@ namespace Tensorflow.Operations dims.AddRange(x_static_shape.dims.Skip(2)); var shape = new TensorShape(dims.ToArray()); - x_t.SetShape(shape); + x_t.set_shape(shape); return x_t; } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index d3213250..92fe2e3c 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -351,7 +351,7 @@ namespace Tensorflow var input_shape = tensor_util.to_shape(input_tensor.shape); if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) { - var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype()); + var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype()); return constant_op.constant(nd, name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index cbf55861..63e0fca1 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -98,7 +98,7 @@ namespace Tensorflow // float to be selected, hence we use a >= comparison. var keep_mask = random_tensor >= rate; var ret = x * scale * math_ops.cast(keep_mask, x.dtype); - ret.SetShape(x.TensorShape); + ret.set_shape(x.TensorShape); return ret; }); } diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index b6db7a65..42ab1d4b 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -273,7 +273,7 @@ namespace Tensorflow { var tensor = new Tensor(output); NDArray nd = null; - Type type = tensor.dtype.as_numpy_datatype(); + Type type = tensor.dtype.as_numpy_dtype(); var ndims = tensor.shape; var offset = c_api.TF_TensorData(output); @@ -285,7 +285,7 @@ namespace Tensorflow nd = NDArray.Scalar(*(bool*)offset); break; case TF_DataType.TF_STRING: - var bytes = tensor.Data(); + var bytes = tensor.BufferToArray(); // wired, don't know why we have to start from offset 9. // length in the begin var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); @@ -324,7 +324,7 @@ namespace Tensorflow nd = np.array(bools).reshape(ndims); break; case TF_DataType.TF_STRING: - var bytes = tensor.Data(); + var bytes = tensor.BufferToArray(); // wired, don't know why we have to start from offset 9. // length in the begin var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 5fd3dfba..ea58607b 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -549,10 +549,11 @@ namespace Tensorflow this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it var ptr = new IntPtr(arraySlice.Address); int num_bytes = (nd.size * nd.dtypesize); - var dtype = given_dtype ?? ToTFDataType(nd.dtype); + var dtype = given_dtype ?? nd.dtype.as_dtype(); var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); IsMemoryOwner = false; return handle; + } public unsafe Tensor(byte[][] buffer, long[] shape) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 66466b22..798c27b6 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -17,9 +17,16 @@ using NumSharp; using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.Linq; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; +using System.Threading.Tasks; +using NumSharp.Backends; +using NumSharp.Backends.Unmanaged; +using NumSharp.Utilities; using Tensorflow.Framework; using static Tensorflow.Binding; @@ -29,42 +36,68 @@ namespace Tensorflow /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// + [SuppressMessage("ReSharper", "ConvertToAutoProperty")] public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike { - private int _id; - private Operation _op; + private readonly int _id; + private readonly Operation _op; + private readonly int _value_index; + private TF_Output? _tf_output; + private readonly TF_DataType _dtype; public int Id => _id; + + /// + /// The Graph that contains this tensor. + /// public Graph graph => op?.graph; + + /// + /// The Operation that produces this tensor as an output. + /// public Operation op => _op; + public Tensor[] outputs => op.outputs; /// - /// The string name of this tensor. + /// The string name of this tensor. /// public string name => $"{(op == null ? "" : $"{op.name}:{_value_index}")}"; - private int _value_index; + /// + /// The index of this tensor in the outputs of its Operation. + /// public int value_index => _value_index; - private TF_DataType _dtype = TF_DataType.DtInvalid; + /// + /// The DType of elements in this tensor. + /// public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); - public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; - public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); + public int NDims => rank; - private TF_Output? _tf_output; + /// + /// The name of the device on which this tensor will be produced, or null. + /// + public string Device => op.Device; + + public int[] dims => shape; /// - /// used for keep other pointer when do implicit operating + /// Used for keep other pointer when do implicit operating /// public object Tag { get; set; } + + /// + /// Returns the shape of a tensor. + /// + /// https://www.tensorflow.org/api_docs/python/tf/shape public int[] shape { get @@ -76,14 +109,13 @@ namespace Tensorflow var status = new Status(); c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); status.Check(); - } - else + } else { for (int i = 0; i < rank; i++) dims[i] = c_api.TF_Dim(_handle, i); } - return dims.Select(x => Convert.ToInt32(x)).ToArray(); + return dims.Select(x => ((IConvertible) x).ToInt32(CultureInfo.InvariantCulture)).ToArray(); } set @@ -93,38 +125,52 @@ namespace Tensorflow if (value == null) c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); else - c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(x => Convert.ToInt64(x)).ToArray(), value.Length, status); + c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); } } public int[] _shape_tuple() { - if (shape == null) return null; - return shape.Select(x => (int)x).ToArray(); + return (int[]) shape.Clone(); } public TensorShape TensorShape => tensor_util.to_shape(shape); - public void SetShape(TensorShape shape) + /// + /// Updates the shape of this tensor. + /// + public void set_shape(TensorShape shape) + { + this.shape = (int[]) shape.dims.Clone(); + } + + /// + /// Updates the shape of this tensor. + /// + [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] + public void SetShape(TensorShape shape) { - this.shape = shape.dims; + this.shape = (int[]) shape.dims.Clone(); } + /// + /// Updates the shape of this tensor. + /// public void set_shape(Tensor shape) { + // ReSharper disable once MergeConditionalExpression this.shape = shape is null ? null : shape.shape; } - public int[] dims => shape; - /// - /// number of dimensions - /// 0 Scalar (magnitude only) - /// 1 Vector (magnitude and direction) - /// 2 Matrix (table of numbers) - /// 3 3-Tensor (cube of numbers) + /// number of dimensions

+ /// 0 Scalar (magnitude only)

+ /// 1 Vector (magnitude and direction)

+ /// 2 Matrix (table of numbers)

+ /// 3 3-Tensor (cube of numbers)

/// n n-Tensor (you get the idea) ///
+ /// https://www.tensorflow.org/api_docs/python/tf/rank public int rank { get @@ -137,17 +183,15 @@ namespace Tensorflow status.Check(); return ndim; } - else - { - return c_api.TF_NumDims(_handle); - } + + return c_api.TF_NumDims(_handle); } } - public int NDims => rank; - - public string Device => op.Device; - + /// + /// Returns a list of Operations that consume this tensor. + /// + /// public Operation[] consumers() { var output = _as_tf_output(); @@ -157,37 +201,136 @@ namespace Tensorflow public TF_Output _as_tf_output() { - if(!_tf_output.HasValue) + if (!_tf_output.HasValue) _tf_output = new TF_Output(op, value_index); return _tf_output.Value; } - public T[] Data() + [Obsolete("Please use ToArray() instead.", false)] + public T[] Data() where T : unmanaged + { + return ToArray(); + } + + /// + /// + /// + /// + /// + /// When is string + public T[] ToArray() where T : unmanaged { - // Column major order - // https://en.wikipedia.org/wiki/File:Row_and_column_major_order.svg - // matrix:[[1, 2, 3], [4, 5, 6]] - // index: 0 2 4 1 3 5 - // result: 1 4 2 5 3 6 - var data = new T[size]; - - for (ulong i = 0; i < size; i++) + //when T is string + if (typeof(T) == typeof(string)) { - data[i] = Marshal.PtrToStructure(buffer + (int)(i * itemsize)); + if (dtype != TF_DataType.TF_STRING) + throw new ArgumentException($"Given <{typeof(T).Name}> can't be converted to string."); + + return (T[]) (object) StringData(); } - return data; + //Are the types matching? + if (typeof(T).as_dtype() == _dtype) + { + //types match, no need to perform cast + var ret = new T[size]; + unsafe + { + var len = (long) size; + fixed (T* dstRet = ret) + { + T* dst = dstRet; //local stack copy + if (typeof(T).IsPrimitive) + { + var src = (T*) buffer; + len *= ((long) itemsize); + System.Buffer.MemoryCopy(src, dst, len, len); + } else + { + var itemsize = (long) this.itemsize; + var buffer = this.buffer.ToInt64(); + Parallel.For(0L, len, i => dst[i] = Marshal.PtrToStructure(new IntPtr(buffer + i * itemsize))); + } + } + } + + return ret; + } else + { + + //types do not match, need to perform cast + var ret = new T[size]; + unsafe + { + var len = (long) size; + fixed (T* dstRet = ret) + { + T* dst = dstRet; //local stack copy + +#if _REGEN + #region Compute + switch (_dtype.as_numpy_datatype().GetTypeCode()) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1:new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + % + default: + throw new NotSupportedException(); + } + #endregion +#else + #region Compute + switch (_dtype.as_numpy_dtype().GetTypeCode()) + { + case NPTypeCode.Boolean:new UnmanagedMemoryBlock((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Byte:new UnmanagedMemoryBlock((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Int16:new UnmanagedMemoryBlock((short*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.UInt16:new UnmanagedMemoryBlock((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Int32:new UnmanagedMemoryBlock((int*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.UInt32:new UnmanagedMemoryBlock((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Int64:new UnmanagedMemoryBlock((long*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.UInt64:new UnmanagedMemoryBlock((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Char:new UnmanagedMemoryBlock((char*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Double:new UnmanagedMemoryBlock((double*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Single:new UnmanagedMemoryBlock((float*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + default: + throw new NotSupportedException(); + } + #endregion +#endif + + } + } + + return ret; + } } + + /// + /// Copies the memory of current buffer onto newly allocated array. + /// + /// + [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] public byte[] Data() + { + return BufferToArray(); + } + + /// + /// Copies the memory of current buffer onto newly allocated array. + /// + /// + public byte[] BufferToArray() { var data = new byte[bytesize]; - Marshal.Copy(buffer, data, 0, (int)bytesize); + Marshal.Copy(buffer, data, 0, (int) bytesize); return data; } - public unsafe string[] StringData() + /// Used internally in ToArray<T> + private unsafe string[] StringData() { // // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. @@ -199,19 +342,19 @@ namespace Tensorflow var buffer = new byte[size][]; var src = c_api.TF_TensorData(_handle); - var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); - src += (int)(size * 8); + var srcLen = (IntPtr) (src.ToInt64() + (long) bytesize); + src += (int) (size * 8); for (int i = 0; i < buffer.Length; i++) { using (var status = new Status()) { IntPtr dst = IntPtr.Zero; UIntPtr dstLen = UIntPtr.Zero; - var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status); + var read = c_api.TF_StringDecode((byte*) src, (UIntPtr) (srcLen.ToInt64() - src.ToInt64()), (byte**) &dst, &dstLen, status); status.Check(true); - buffer[i] = new byte[(int)dstLen]; + buffer[i] = new byte[(int) dstLen]; Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); - src += (int)read; + src += (int) read; } } @@ -229,51 +372,29 @@ namespace Tensorflow } /// - /// Evaluates this tensor in a `Session`. + /// Evaluates this tensor in a `Session`. /// /// A dictionary that maps `Tensor` objects to feed values. - /// The `Session` to be used to evaluate this tensor. - /// + /// A array corresponding to the value of this tensor. public NDArray eval(params FeedItem[] feed_dict) { return ops._eval_using_default_session(this, feed_dict, graph); } + /// + /// Evaluates this tensor in a `Session`. + /// + /// A dictionary that maps `Tensor` objects to feed values. + /// The `Session` to be used to evaluate this tensor. + /// A array corresponding to the value of this tensor. public NDArray eval(Session session, FeedItem[] feed_dict = null) { return ops._eval_using_default_session(this, feed_dict, graph, session); } - public TF_DataType ToTFDataType(Type type) - { - switch (type.Name) - { - case "Char": - return TF_DataType.TF_UINT8; - case "Int16": - return TF_DataType.TF_INT16; - case "Int32": - return TF_DataType.TF_INT32; - case "Int64": - return TF_DataType.TF_INT64; - case "Single": - return TF_DataType.TF_FLOAT; - case "Double": - return TF_DataType.TF_DOUBLE; - case "Byte": - return TF_DataType.TF_UINT8; - case "String": - return TF_DataType.TF_STRING; - case "Boolean": - return TF_DataType.TF_BOOL; - default: - throw new NotImplementedException("ToTFDataType error"); - } - } - public Tensor slice(Slice slice) { - var slice_spec = new int[] { slice.Start.Value }; + var slice_spec = new int[] {slice.Start.Value}; var begin = new List(); var end = new List(); var strides = new List(); @@ -289,26 +410,26 @@ namespace Tensorflow if (slice.Stop.HasValue) { end.Add(slice.Stop.Value); - } - else + } else { end.Add(0); end_mask |= (1 << index); } + strides.Add(slice.Step); index += 1; } - return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => { string name = scope; if (begin != null) { var (packed_begin, packed_end, packed_strides) = (array_ops.stack(begin.ToArray()), - array_ops.stack(end.ToArray()), - array_ops.stack(strides.ToArray())); + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); return gen_array_ops.strided_slice( this, @@ -320,7 +441,6 @@ namespace Tensorflow shrink_axis_mask: shrink_axis_mask, new_axis_mask: new_axis_mask, ellipsis_mask: ellipsis_mask, - name: name); } @@ -330,7 +450,7 @@ namespace Tensorflow public Tensor slice(int start) { - var slice_spec = new int[] { start }; + var slice_spec = new int[] {start}; var begin = new List(); var end = new List(); var strides = new List(); @@ -349,15 +469,15 @@ namespace Tensorflow index += 1; } - return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => { string name = scope; if (begin != null) { var (packed_begin, packed_end, packed_strides) = (array_ops.stack(begin.ToArray()), - array_ops.stack(end.ToArray()), - array_ops.stack(strides.ToArray())); + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); return gen_array_ops.strided_slice( this, @@ -369,7 +489,6 @@ namespace Tensorflow shrink_axis_mask: shrink_axis_mask, new_axis_mask: new_axis_mask, ellipsis_mask: ellipsis_mask, - name: name); } @@ -392,13 +511,9 @@ namespace Tensorflow return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; } - //protected override void DisposeManagedState() - //{ - //} - protected override void DisposeUnmanagedResources(IntPtr handle) { - if(handle != IntPtr.Zero) + if (handle != IntPtr.Zero) { c_api.TF_DeleteTensor(handle); } @@ -417,4 +532,4 @@ namespace Tensorflow public int tensor_int_val { get; set; } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 807dc6f5..37f1ca61 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -15,6 +15,8 @@ ******************************************************************************/ using System; +using System.Numerics; +using NumSharp.Backends; namespace Tensorflow { @@ -23,35 +25,100 @@ namespace Tensorflow public static TF_DataType int8 = TF_DataType.TF_INT8; public static TF_DataType int32 = TF_DataType.TF_INT32; public static TF_DataType int64 = TF_DataType.TF_INT64; + public static TF_DataType uint8 = TF_DataType.TF_UINT8; + public static TF_DataType uint32 = TF_DataType.TF_UINT32; + public static TF_DataType uint64 = TF_DataType.TF_UINT64; public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? public static TF_DataType float16 = TF_DataType.TF_HALF; public static TF_DataType float64 = TF_DataType.TF_DOUBLE; - public static Type as_numpy_datatype(this TF_DataType type) + /// + /// + /// + /// + /// equivalent to , if none exists, returns null. + public static Type as_numpy_dtype(this TF_DataType type) { switch (type) { case TF_DataType.TF_BOOL: return typeof(bool); + case TF_DataType.TF_UINT8: + return typeof(byte); case TF_DataType.TF_INT64: return typeof(long); + case TF_DataType.TF_UINT64: + return typeof(ulong); case TF_DataType.TF_INT32: return typeof(int); + case TF_DataType.TF_UINT32: + return typeof(uint); case TF_DataType.TF_INT16: return typeof(short); + case TF_DataType.TF_UINT16: + return typeof(ushort); case TF_DataType.TF_FLOAT: return typeof(float); case TF_DataType.TF_DOUBLE: return typeof(double); case TF_DataType.TF_STRING: return typeof(string); + case TF_DataType.TF_COMPLEX128: + case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX + return typeof(Complex); default: return null; } } - // "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex" - public static TF_DataType as_dtype(Type type, TF_DataType? dtype = null) + /// + /// + /// + /// + /// + /// When has no equivalent + public static NPTypeCode as_numpy_typecode(this TF_DataType type) + { + switch (type) + { + case TF_DataType.TF_BOOL: + return NPTypeCode.Boolean; + case TF_DataType.TF_UINT8: + return NPTypeCode.Byte; + case TF_DataType.TF_INT64: + return NPTypeCode.Int64; + case TF_DataType.TF_INT32: + return NPTypeCode.Int32; + case TF_DataType.TF_INT16: + return NPTypeCode.Int16; + case TF_DataType.TF_UINT64: + return NPTypeCode.UInt64; + case TF_DataType.TF_UINT32: + return NPTypeCode.UInt32; + case TF_DataType.TF_UINT16: + return NPTypeCode.UInt16; + case TF_DataType.TF_FLOAT: + return NPTypeCode.Single; + case TF_DataType.TF_DOUBLE: + return NPTypeCode.Double; + case TF_DataType.TF_STRING: + return NPTypeCode.String; + case TF_DataType.TF_COMPLEX128: + case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX + return NPTypeCode.Complex; + default: + throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); + } + } + + /// + /// + /// + /// + /// + /// + /// When has no equivalent + public static TF_DataType as_dtype(this Type type, TF_DataType? dtype = null) { switch (type.Name) { @@ -98,7 +165,7 @@ namespace Tensorflow dtype = TF_DataType.TF_BOOL; break; default: - throw new Exception("as_dtype Not Implemented"); + throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); } return dtype.Value; @@ -106,16 +173,7 @@ namespace Tensorflow public static DataType as_datatype_enum(this TF_DataType type) { - DataType dtype = DataType.DtInvalid; - - switch (type) - { - default: - Enum.TryParse(((int)type).ToString(), out dtype); - break; - } - - return dtype; + return Enum.TryParse(((int) type).ToString(), out DataType dtype) ? dtype : DataType.DtInvalid; } public static TF_DataType as_base_dtype(this TF_DataType type) @@ -132,7 +190,7 @@ namespace Tensorflow public static Type as_numpy_dtype(this DataType type) { - return type.as_tf_dtype().as_numpy_datatype(); + return type.as_tf_dtype().as_numpy_dtype(); } public static DataType as_base_dtype(this DataType type) @@ -144,16 +202,7 @@ namespace Tensorflow public static TF_DataType as_tf_dtype(this DataType type) { - TF_DataType dtype = TF_DataType.DtInvalid; - - switch (type) - { - default: - Enum.TryParse(((int)type).ToString(), out dtype); - break; - } - - return dtype; + return Enum.TryParse(((int) type).ToString(), out TF_DataType dtype) ? dtype : TF_DataType.DtInvalid; } public static TF_DataType as_ref(this TF_DataType type) diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index ded105c7..43848da6 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -17,6 +17,7 @@ using NumSharp; using System; using System.Linq; +using NumSharp.Utilities; namespace Tensorflow { @@ -109,7 +110,7 @@ namespace Tensorflow // We first convert value to a numpy array or scalar. NDArray nparray = null; - var np_dt = dtype.as_numpy_datatype(); + var np_dt = dtype.as_numpy_dtype(); if (values is NDArray nd) { @@ -188,37 +189,37 @@ namespace Tensorflow if (values.GetType().IsArray) nparray = np.array((int[])values, np_dt); else - nparray = Convert.ToInt32(values); + nparray = Converts.ToInt32(values); break; case "Int64": if (values.GetType().IsArray) nparray = np.array((int[])values, np_dt); else - nparray = Convert.ToInt64(values); + nparray = Converts.ToInt64(values); break; case "Single": if (values.GetType().IsArray) nparray = np.array((float[])values, np_dt); else - nparray = Convert.ToSingle(values); + nparray = Converts.ToSingle(values); break; case "Double": if (values.GetType().IsArray) nparray = np.array((double[])values, np_dt); else - nparray = Convert.ToDouble(values); + nparray = Converts.ToDouble(values); break; case "String": if (values.GetType().IsArray) nparray = np.array((string[])values, np_dt); else - nparray = NDArray.FromString(Convert.ToString(values)); + nparray = NDArray.FromString(Converts.ToString(values)); break; case "Boolean": if (values.GetType().IsArray) nparray = np.array((bool[])values, np_dt); else - nparray = Convert.ToBoolean(values); + nparray = Converts.ToBoolean(values); break; default: throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented"); diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs index 17b095a4..c5a06433 100644 --- a/src/TensorFlowNET.Core/ops.GraphKeys.cs +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -29,55 +29,111 @@ namespace Tensorflow /// public class GraphKeys { + #region const + + + /// + /// the subset of `Variable` objects that will be trained by an optimizer. + /// + public const string TRAINABLE_VARIABLES_ = "trainable_variables"; + + /// + /// Trainable resource-style variables. + /// + public const string TRAINABLE_RESOURCE_VARIABLES_ = "trainable_resource_variables"; + + /// + /// Key for streaming model ports. + /// + public const string _STREAMING_MODEL_PORTS_ = "streaming_model_ports"; + + /// + /// Key to collect losses + /// + public const string LOSSES_ = "losses"; + + /// + /// Key to collect Variable objects that are global (shared across machines). + /// Default collection for all variables, except local ones. + /// + public const string GLOBAL_VARIABLES_ = "variables"; + + public const string TRAIN_OP_ = "train_op"; + + public const string GLOBAL_STEP_ = "global_step"; + + public string[] _VARIABLE_COLLECTIONS_ = new string[] { "variables", "trainable_variables", "model_variables" }; + /// + /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. + /// + public const string SAVEABLE_OBJECTS_ = "saveable_objects"; + /// + /// Key to collect update_ops + /// + public const string UPDATE_OPS_ = "update_ops"; + + // Key to collect summaries. + public const string SUMMARIES_ = "summaries"; + + // Used to store v2 summary names. + public const string _SUMMARY_COLLECTION_ = "_SUMMARY_V2"; + + // Key for control flow context. + public const string COND_CONTEXT_ = "cond_context"; + public const string WHILE_CONTEXT_ = "while_context"; + + #endregion + + /// /// the subset of `Variable` objects that will be trained by an optimizer. /// - public string TRAINABLE_VARIABLES = "trainable_variables"; + public string TRAINABLE_VARIABLES => TRAINABLE_VARIABLES_; /// /// Trainable resource-style variables. /// - public string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"; + public string TRAINABLE_RESOURCE_VARIABLES => TRAINABLE_RESOURCE_VARIABLES_; /// /// Key for streaming model ports. /// - public string _STREAMING_MODEL_PORTS = "streaming_model_ports"; + public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_; /// /// Key to collect losses /// - public string LOSSES = "losses"; + public string LOSSES => LOSSES_; /// /// Key to collect Variable objects that are global (shared across machines). /// Default collection for all variables, except local ones. /// - public string GLOBAL_VARIABLES = "variables"; + public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_; - public string TRAIN_OP = "train_op"; + public string TRAIN_OP => TRAIN_OP_; - public string GLOBAL_STEP = "global_step"; + public string GLOBAL_STEP => GLOBAL_STEP_; - public string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" }; + public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_; /// /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. /// - public string SAVEABLE_OBJECTS = "saveable_objects"; + public string SAVEABLE_OBJECTS => SAVEABLE_OBJECTS_; /// /// Key to collect update_ops /// - public string UPDATE_OPS = "update_ops"; + public string UPDATE_OPS => UPDATE_OPS_; // Key to collect summaries. - public string SUMMARIES = "summaries"; + public string SUMMARIES => SUMMARIES_; // Used to store v2 summary names. - public string _SUMMARY_COLLECTION = "_SUMMARY_V2"; + public string _SUMMARY_COLLECTION => _SUMMARY_COLLECTION_; // Key for control flow context. - public string COND_CONTEXT = "cond_context"; - public string WHILE_CONTEXT = "while_context"; + public string COND_CONTEXT => COND_CONTEXT_; + public string WHILE_CONTEXT => WHILE_CONTEXT_; } } } diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs index 9c8485ec..8fd4dc8a 100644 --- a/test/TensorFlowNET.UnitTest/SessionTest.cs +++ b/test/TensorFlowNET.UnitTest/SessionTest.cs @@ -45,7 +45,7 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); EXPECT_EQ(0, outTensor.NDims); ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); - var output_contents = outTensor.Data(); + var output_contents = outTensor.ToArray(); EXPECT_EQ(3 + 2, output_contents[0]); // Add another operation to the graph. @@ -66,7 +66,7 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); EXPECT_EQ(0, outTensor.NDims); // scalar ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); - output_contents = outTensor.Data(); + output_contents = outTensor.ToArray(); EXPECT_EQ(-(7 + 2), output_contents[0]); // Clean up diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index 07da9dca..11557f14 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -112,7 +112,7 @@ namespace TensorFlowNET.UnitTest var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); var tensor = new Tensor(nd); - var array = tensor.Data(); + var array = tensor.ToArray(); EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); EXPECT_EQ(tensor.rank, nd.ndim); diff --git a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs index 1fd7d3aa..3a5515d9 100644 --- a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs +++ b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs @@ -31,7 +31,7 @@ namespace TensorFlowNET.UnitTest.nn_test var y_np = this._ZeroFraction(x_np); var x_tf = constant_op.constant(x_np); - x_tf.SetShape(x_shape); + x_tf.set_shape(x_shape); var y_tf = nn_impl.zero_fraction(x_tf); var y_tf_np = self.evaluate(y_tf);