- Cleanup and docs to all Tensor.cs files - Changed all uses of System.Convert to NumSharp.Utilities.Converts - Added all missing types in dtypes.cs - Renamed tensor.Data<T> to tensor.ToArray<T>, added obsolete message - Renamed tensor.Data() to tensor.BufferToArray(), added obsolete message - Made GraphKeys to use const string instead allocating strings at every use of GraphKeys.tags/v0.12
| @@ -141,7 +141,7 @@ namespace Tensorflow.Operations | |||||
| data, frame_name, is_constant, parallel_iterations, name: name); | data, frame_name, is_constant, parallel_iterations, name: name); | ||||
| if (use_input_shape) | if (use_input_shape) | ||||
| result.SetShape(data.TensorShape); | |||||
| result.set_shape(data.TensorShape); | |||||
| return result; | return result; | ||||
| } | } | ||||
| @@ -233,7 +233,7 @@ namespace Tensorflow.Operations | |||||
| dims.AddRange(x_static_shape.dims.Skip(2)); | dims.AddRange(x_static_shape.dims.Skip(2)); | ||||
| var shape = new TensorShape(dims.ToArray()); | var shape = new TensorShape(dims.ToArray()); | ||||
| x_t.SetShape(shape); | |||||
| x_t.set_shape(shape); | |||||
| return x_t; | return x_t; | ||||
| } | } | ||||
| @@ -351,7 +351,7 @@ namespace Tensorflow | |||||
| var input_shape = tensor_util.to_shape(input_tensor.shape); | var input_shape = tensor_util.to_shape(input_tensor.shape); | ||||
| if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) | if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) | ||||
| { | { | ||||
| var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype()); | |||||
| var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype()); | |||||
| return constant_op.constant(nd, name: name); | return constant_op.constant(nd, name: name); | ||||
| } | } | ||||
| } | } | ||||
| @@ -98,7 +98,7 @@ namespace Tensorflow | |||||
| // float to be selected, hence we use a >= comparison. | // float to be selected, hence we use a >= comparison. | ||||
| var keep_mask = random_tensor >= rate; | var keep_mask = random_tensor >= rate; | ||||
| var ret = x * scale * math_ops.cast(keep_mask, x.dtype); | var ret = x * scale * math_ops.cast(keep_mask, x.dtype); | ||||
| ret.SetShape(x.TensorShape); | |||||
| ret.set_shape(x.TensorShape); | |||||
| return ret; | return ret; | ||||
| }); | }); | ||||
| } | } | ||||
| @@ -273,7 +273,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| var tensor = new Tensor(output); | var tensor = new Tensor(output); | ||||
| NDArray nd = null; | NDArray nd = null; | ||||
| Type type = tensor.dtype.as_numpy_datatype(); | |||||
| Type type = tensor.dtype.as_numpy_dtype(); | |||||
| var ndims = tensor.shape; | var ndims = tensor.shape; | ||||
| var offset = c_api.TF_TensorData(output); | var offset = c_api.TF_TensorData(output); | ||||
| @@ -285,7 +285,7 @@ namespace Tensorflow | |||||
| nd = NDArray.Scalar(*(bool*)offset); | nd = NDArray.Scalar(*(bool*)offset); | ||||
| break; | break; | ||||
| case TF_DataType.TF_STRING: | case TF_DataType.TF_STRING: | ||||
| var bytes = tensor.Data(); | |||||
| var bytes = tensor.BufferToArray(); | |||||
| // wired, don't know why we have to start from offset 9. | // wired, don't know why we have to start from offset 9. | ||||
| // length in the begin | // length in the begin | ||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | ||||
| @@ -324,7 +324,7 @@ namespace Tensorflow | |||||
| nd = np.array(bools).reshape(ndims); | nd = np.array(bools).reshape(ndims); | ||||
| break; | break; | ||||
| case TF_DataType.TF_STRING: | case TF_DataType.TF_STRING: | ||||
| var bytes = tensor.Data(); | |||||
| var bytes = tensor.BufferToArray(); | |||||
| // wired, don't know why we have to start from offset 9. | // wired, don't know why we have to start from offset 9. | ||||
| // length in the begin | // length in the begin | ||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | ||||
| @@ -549,10 +549,11 @@ namespace Tensorflow | |||||
| this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it | this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it | ||||
| var ptr = new IntPtr(arraySlice.Address); | var ptr = new IntPtr(arraySlice.Address); | ||||
| int num_bytes = (nd.size * nd.dtypesize); | int num_bytes = (nd.size * nd.dtypesize); | ||||
| var dtype = given_dtype ?? ToTFDataType(nd.dtype); | |||||
| var dtype = given_dtype ?? nd.dtype.as_dtype(); | |||||
| var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); | var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); | ||||
| IsMemoryOwner = false; | IsMemoryOwner = false; | ||||
| return handle; | return handle; | ||||
| } | } | ||||
| public unsafe Tensor(byte[][] buffer, long[] shape) | public unsafe Tensor(byte[][] buffer, long[] shape) | ||||
| @@ -17,9 +17,16 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using System.Globalization; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.CompilerServices; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| using System.Threading.Tasks; | |||||
| using NumSharp.Backends; | |||||
| using NumSharp.Backends.Unmanaged; | |||||
| using NumSharp.Utilities; | |||||
| using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -29,42 +36,68 @@ namespace Tensorflow | |||||
| /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | ||||
| /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | ||||
| /// </summary> | /// </summary> | ||||
| [SuppressMessage("ReSharper", "ConvertToAutoProperty")] | |||||
| public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike | public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike | ||||
| { | { | ||||
| private int _id; | |||||
| private Operation _op; | |||||
| private readonly int _id; | |||||
| private readonly Operation _op; | |||||
| private readonly int _value_index; | |||||
| private TF_Output? _tf_output; | |||||
| private readonly TF_DataType _dtype; | |||||
| public int Id => _id; | public int Id => _id; | ||||
| /// <summary> | |||||
| /// The Graph that contains this tensor. | |||||
| /// </summary> | |||||
| public Graph graph => op?.graph; | public Graph graph => op?.graph; | ||||
| /// <summary> | |||||
| /// The Operation that produces this tensor as an output. | |||||
| /// </summary> | |||||
| public Operation op => _op; | public Operation op => _op; | ||||
| public Tensor[] outputs => op.outputs; | public Tensor[] outputs => op.outputs; | ||||
| /// <summary> | /// <summary> | ||||
| /// The string name of this tensor. | |||||
| /// The string name of this tensor. | |||||
| /// </summary> | /// </summary> | ||||
| public string name => $"{(op == null ? "<unnamed Operation>" : $"{op.name}:{_value_index}")}"; | public string name => $"{(op == null ? "<unnamed Operation>" : $"{op.name}:{_value_index}")}"; | ||||
| private int _value_index; | |||||
| /// <summary> | |||||
| /// The index of this tensor in the outputs of its Operation. | |||||
| /// </summary> | |||||
| public int value_index => _value_index; | public int value_index => _value_index; | ||||
| private TF_DataType _dtype = TF_DataType.DtInvalid; | |||||
| /// <summary> | |||||
| /// The DType of elements in this tensor. | |||||
| /// </summary> | |||||
| public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); | public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); | ||||
| public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | ||||
| public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | ||||
| public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | ||||
| public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | ||||
| public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | ||||
| public int NDims => rank; | |||||
| private TF_Output? _tf_output; | |||||
| /// <summary> | |||||
| /// The name of the device on which this tensor will be produced, or null. | |||||
| /// </summary> | |||||
| public string Device => op.Device; | |||||
| public int[] dims => shape; | |||||
| /// <summary> | /// <summary> | ||||
| /// used for keep other pointer when do implicit operating | |||||
| /// Used for keep other pointer when do implicit operating | |||||
| /// </summary> | /// </summary> | ||||
| public object Tag { get; set; } | public object Tag { get; set; } | ||||
| /// <summary> | |||||
| /// Returns the shape of a tensor. | |||||
| /// </summary> | |||||
| /// <remarks>https://www.tensorflow.org/api_docs/python/tf/shape</remarks> | |||||
| public int[] shape | public int[] shape | ||||
| { | { | ||||
| get | get | ||||
| @@ -76,14 +109,13 @@ namespace Tensorflow | |||||
| var status = new Status(); | var status = new Status(); | ||||
| c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | ||||
| status.Check(); | status.Check(); | ||||
| } | |||||
| else | |||||
| } else | |||||
| { | { | ||||
| for (int i = 0; i < rank; i++) | for (int i = 0; i < rank; i++) | ||||
| dims[i] = c_api.TF_Dim(_handle, i); | dims[i] = c_api.TF_Dim(_handle, i); | ||||
| } | } | ||||
| return dims.Select(x => Convert.ToInt32(x)).ToArray(); | |||||
| return dims.Select(x => ((IConvertible) x).ToInt32(CultureInfo.InvariantCulture)).ToArray(); | |||||
| } | } | ||||
| set | set | ||||
| @@ -93,38 +125,52 @@ namespace Tensorflow | |||||
| if (value == null) | if (value == null) | ||||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | ||||
| else | else | ||||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(x => Convert.ToInt64(x)).ToArray(), value.Length, status); | |||||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||||
| } | } | ||||
| } | } | ||||
| public int[] _shape_tuple() | public int[] _shape_tuple() | ||||
| { | { | ||||
| if (shape == null) return null; | |||||
| return shape.Select(x => (int)x).ToArray(); | |||||
| return (int[]) shape.Clone(); | |||||
| } | } | ||||
| public TensorShape TensorShape => tensor_util.to_shape(shape); | public TensorShape TensorShape => tensor_util.to_shape(shape); | ||||
| public void SetShape(TensorShape shape) | |||||
| /// <summary> | |||||
| /// Updates the shape of this tensor. | |||||
| /// </summary> | |||||
| public void set_shape(TensorShape shape) | |||||
| { | |||||
| this.shape = (int[]) shape.dims.Clone(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Updates the shape of this tensor. | |||||
| /// </summary> | |||||
| [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] | |||||
| public void SetShape(TensorShape shape) | |||||
| { | { | ||||
| this.shape = shape.dims; | |||||
| this.shape = (int[]) shape.dims.Clone(); | |||||
| } | } | ||||
| /// <summary> | |||||
| /// Updates the shape of this tensor. | |||||
| /// </summary> | |||||
| public void set_shape(Tensor shape) | public void set_shape(Tensor shape) | ||||
| { | { | ||||
| // ReSharper disable once MergeConditionalExpression | |||||
| this.shape = shape is null ? null : shape.shape; | this.shape = shape is null ? null : shape.shape; | ||||
| } | } | ||||
| public int[] dims => shape; | |||||
| /// <summary> | /// <summary> | ||||
| /// number of dimensions | |||||
| /// 0 Scalar (magnitude only) | |||||
| /// 1 Vector (magnitude and direction) | |||||
| /// 2 Matrix (table of numbers) | |||||
| /// 3 3-Tensor (cube of numbers) | |||||
| /// number of dimensions <br></br> | |||||
| /// 0 Scalar (magnitude only) <br></br> | |||||
| /// 1 Vector (magnitude and direction) <br></br> | |||||
| /// 2 Matrix (table of numbers) <br></br> | |||||
| /// 3 3-Tensor (cube of numbers) <br></br> | |||||
| /// n n-Tensor (you get the idea) | /// n n-Tensor (you get the idea) | ||||
| /// </summary> | /// </summary> | ||||
| /// <remarks>https://www.tensorflow.org/api_docs/python/tf/rank</remarks> | |||||
| public int rank | public int rank | ||||
| { | { | ||||
| get | get | ||||
| @@ -137,17 +183,15 @@ namespace Tensorflow | |||||
| status.Check(); | status.Check(); | ||||
| return ndim; | return ndim; | ||||
| } | } | ||||
| else | |||||
| { | |||||
| return c_api.TF_NumDims(_handle); | |||||
| } | |||||
| return c_api.TF_NumDims(_handle); | |||||
| } | } | ||||
| } | } | ||||
| public int NDims => rank; | |||||
| public string Device => op.Device; | |||||
| /// <summary> | |||||
| /// Returns a list of Operations that consume this tensor. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public Operation[] consumers() | public Operation[] consumers() | ||||
| { | { | ||||
| var output = _as_tf_output(); | var output = _as_tf_output(); | ||||
| @@ -157,37 +201,136 @@ namespace Tensorflow | |||||
| public TF_Output _as_tf_output() | public TF_Output _as_tf_output() | ||||
| { | { | ||||
| if(!_tf_output.HasValue) | |||||
| if (!_tf_output.HasValue) | |||||
| _tf_output = new TF_Output(op, value_index); | _tf_output = new TF_Output(op, value_index); | ||||
| return _tf_output.Value; | return _tf_output.Value; | ||||
| } | } | ||||
| public T[] Data<T>() | |||||
| [Obsolete("Please use ToArray<T>() instead.", false)] | |||||
| public T[] Data<T>() where T : unmanaged | |||||
| { | |||||
| return ToArray<T>(); | |||||
| } | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="ArgumentException">When <typeparam name="T"> is string </typeparam></exception> | |||||
| public T[] ToArray<T>() where T : unmanaged | |||||
| { | { | ||||
| // Column major order | |||||
| // https://en.wikipedia.org/wiki/File:Row_and_column_major_order.svg | |||||
| // matrix:[[1, 2, 3], [4, 5, 6]] | |||||
| // index: 0 2 4 1 3 5 | |||||
| // result: 1 4 2 5 3 6 | |||||
| var data = new T[size]; | |||||
| for (ulong i = 0; i < size; i++) | |||||
| //when T is string | |||||
| if (typeof(T) == typeof(string)) | |||||
| { | { | ||||
| data[i] = Marshal.PtrToStructure<T>(buffer + (int)(i * itemsize)); | |||||
| if (dtype != TF_DataType.TF_STRING) | |||||
| throw new ArgumentException($"Given <{typeof(T).Name}> can't be converted to string."); | |||||
| return (T[]) (object) StringData(); | |||||
| } | } | ||||
| return data; | |||||
| //Are the types matching? | |||||
| if (typeof(T).as_dtype() == _dtype) | |||||
| { | |||||
| //types match, no need to perform cast | |||||
| var ret = new T[size]; | |||||
| unsafe | |||||
| { | |||||
| var len = (long) size; | |||||
| fixed (T* dstRet = ret) | |||||
| { | |||||
| T* dst = dstRet; //local stack copy | |||||
| if (typeof(T).IsPrimitive) | |||||
| { | |||||
| var src = (T*) buffer; | |||||
| len *= ((long) itemsize); | |||||
| System.Buffer.MemoryCopy(src, dst, len, len); | |||||
| } else | |||||
| { | |||||
| var itemsize = (long) this.itemsize; | |||||
| var buffer = this.buffer.ToInt64(); | |||||
| Parallel.For(0L, len, i => dst[i] = Marshal.PtrToStructure<T>(new IntPtr(buffer + i * itemsize))); | |||||
| } | |||||
| } | |||||
| } | |||||
| return ret; | |||||
| } else | |||||
| { | |||||
| //types do not match, need to perform cast | |||||
| var ret = new T[size]; | |||||
| unsafe | |||||
| { | |||||
| var len = (long) size; | |||||
| fixed (T* dstRet = ret) | |||||
| { | |||||
| T* dst = dstRet; //local stack copy | |||||
| #if _REGEN | |||||
| #region Compute | |||||
| switch (_dtype.as_numpy_datatype().GetTypeCode()) | |||||
| { | |||||
| %foreach supported_dtypes,supported_dtypes_lowercase% | |||||
| case NPTypeCode.#1:new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| % | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #else | |||||
| #region Compute | |||||
| switch (_dtype.as_numpy_dtype().GetTypeCode()) | |||||
| { | |||||
| case NPTypeCode.Boolean:new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Byte:new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Int16:new UnmanagedMemoryBlock<short>((short*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.UInt16:new UnmanagedMemoryBlock<ushort>((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Int32:new UnmanagedMemoryBlock<int>((int*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.UInt32:new UnmanagedMemoryBlock<uint>((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Int64:new UnmanagedMemoryBlock<long>((long*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.UInt64:new UnmanagedMemoryBlock<ulong>((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Char:new UnmanagedMemoryBlock<char>((char*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Double:new UnmanagedMemoryBlock<double>((double*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Single:new UnmanagedMemoryBlock<float>((float*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #endif | |||||
| } | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| } | } | ||||
| /// <summary> | |||||
| /// Copies the memory of current buffer onto newly allocated array. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] | |||||
| public byte[] Data() | public byte[] Data() | ||||
| { | |||||
| return BufferToArray(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Copies the memory of current buffer onto newly allocated array. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public byte[] BufferToArray() | |||||
| { | { | ||||
| var data = new byte[bytesize]; | var data = new byte[bytesize]; | ||||
| Marshal.Copy(buffer, data, 0, (int)bytesize); | |||||
| Marshal.Copy(buffer, data, 0, (int) bytesize); | |||||
| return data; | return data; | ||||
| } | } | ||||
| public unsafe string[] StringData() | |||||
| /// Used internally in ToArray<T> | |||||
| private unsafe string[] StringData() | |||||
| { | { | ||||
| // | // | ||||
| // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. | // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. | ||||
| @@ -199,19 +342,19 @@ namespace Tensorflow | |||||
| var buffer = new byte[size][]; | var buffer = new byte[size][]; | ||||
| var src = c_api.TF_TensorData(_handle); | var src = c_api.TF_TensorData(_handle); | ||||
| var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); | |||||
| src += (int)(size * 8); | |||||
| var srcLen = (IntPtr) (src.ToInt64() + (long) bytesize); | |||||
| src += (int) (size * 8); | |||||
| for (int i = 0; i < buffer.Length; i++) | for (int i = 0; i < buffer.Length; i++) | ||||
| { | { | ||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| { | { | ||||
| IntPtr dst = IntPtr.Zero; | IntPtr dst = IntPtr.Zero; | ||||
| UIntPtr dstLen = UIntPtr.Zero; | UIntPtr dstLen = UIntPtr.Zero; | ||||
| var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status); | |||||
| var read = c_api.TF_StringDecode((byte*) src, (UIntPtr) (srcLen.ToInt64() - src.ToInt64()), (byte**) &dst, &dstLen, status); | |||||
| status.Check(true); | status.Check(true); | ||||
| buffer[i] = new byte[(int)dstLen]; | |||||
| buffer[i] = new byte[(int) dstLen]; | |||||
| Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | ||||
| src += (int)read; | |||||
| src += (int) read; | |||||
| } | } | ||||
| } | } | ||||
| @@ -229,51 +372,29 @@ namespace Tensorflow | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Evaluates this tensor in a `Session`. | |||||
| /// Evaluates this tensor in a `Session`. | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | ||||
| /// <param name="session">The `Session` to be used to evaluate this tensor.</param> | |||||
| /// <returns></returns> | |||||
| /// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns> | |||||
| public NDArray eval(params FeedItem[] feed_dict) | public NDArray eval(params FeedItem[] feed_dict) | ||||
| { | { | ||||
| return ops._eval_using_default_session(this, feed_dict, graph); | return ops._eval_using_default_session(this, feed_dict, graph); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Evaluates this tensor in a `Session`. | |||||
| /// </summary> | |||||
| /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | |||||
| /// <param name="session">The `Session` to be used to evaluate this tensor.</param> | |||||
| /// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns> | |||||
| public NDArray eval(Session session, FeedItem[] feed_dict = null) | public NDArray eval(Session session, FeedItem[] feed_dict = null) | ||||
| { | { | ||||
| return ops._eval_using_default_session(this, feed_dict, graph, session); | return ops._eval_using_default_session(this, feed_dict, graph, session); | ||||
| } | } | ||||
| public TF_DataType ToTFDataType(Type type) | |||||
| { | |||||
| switch (type.Name) | |||||
| { | |||||
| case "Char": | |||||
| return TF_DataType.TF_UINT8; | |||||
| case "Int16": | |||||
| return TF_DataType.TF_INT16; | |||||
| case "Int32": | |||||
| return TF_DataType.TF_INT32; | |||||
| case "Int64": | |||||
| return TF_DataType.TF_INT64; | |||||
| case "Single": | |||||
| return TF_DataType.TF_FLOAT; | |||||
| case "Double": | |||||
| return TF_DataType.TF_DOUBLE; | |||||
| case "Byte": | |||||
| return TF_DataType.TF_UINT8; | |||||
| case "String": | |||||
| return TF_DataType.TF_STRING; | |||||
| case "Boolean": | |||||
| return TF_DataType.TF_BOOL; | |||||
| default: | |||||
| throw new NotImplementedException("ToTFDataType error"); | |||||
| } | |||||
| } | |||||
| public Tensor slice(Slice slice) | public Tensor slice(Slice slice) | ||||
| { | { | ||||
| var slice_spec = new int[] { slice.Start.Value }; | |||||
| var slice_spec = new int[] {slice.Start.Value}; | |||||
| var begin = new List<int>(); | var begin = new List<int>(); | ||||
| var end = new List<int>(); | var end = new List<int>(); | ||||
| var strides = new List<int>(); | var strides = new List<int>(); | ||||
| @@ -289,26 +410,26 @@ namespace Tensorflow | |||||
| if (slice.Stop.HasValue) | if (slice.Stop.HasValue) | ||||
| { | { | ||||
| end.Add(slice.Stop.Value); | end.Add(slice.Stop.Value); | ||||
| } | |||||
| else | |||||
| } else | |||||
| { | { | ||||
| end.Add(0); | end.Add(0); | ||||
| end_mask |= (1 << index); | end_mask |= (1 << index); | ||||
| } | } | ||||
| strides.Add(slice.Step); | strides.Add(slice.Step); | ||||
| index += 1; | index += 1; | ||||
| } | } | ||||
| return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | |||||
| return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => | |||||
| { | { | ||||
| string name = scope; | string name = scope; | ||||
| if (begin != null) | if (begin != null) | ||||
| { | { | ||||
| var (packed_begin, packed_end, packed_strides) = | var (packed_begin, packed_end, packed_strides) = | ||||
| (array_ops.stack(begin.ToArray()), | (array_ops.stack(begin.ToArray()), | ||||
| array_ops.stack(end.ToArray()), | |||||
| array_ops.stack(strides.ToArray())); | |||||
| array_ops.stack(end.ToArray()), | |||||
| array_ops.stack(strides.ToArray())); | |||||
| return gen_array_ops.strided_slice( | return gen_array_ops.strided_slice( | ||||
| this, | this, | ||||
| @@ -320,7 +441,6 @@ namespace Tensorflow | |||||
| shrink_axis_mask: shrink_axis_mask, | shrink_axis_mask: shrink_axis_mask, | ||||
| new_axis_mask: new_axis_mask, | new_axis_mask: new_axis_mask, | ||||
| ellipsis_mask: ellipsis_mask, | ellipsis_mask: ellipsis_mask, | ||||
| name: name); | name: name); | ||||
| } | } | ||||
| @@ -330,7 +450,7 @@ namespace Tensorflow | |||||
| public Tensor slice(int start) | public Tensor slice(int start) | ||||
| { | { | ||||
| var slice_spec = new int[] { start }; | |||||
| var slice_spec = new int[] {start}; | |||||
| var begin = new List<int>(); | var begin = new List<int>(); | ||||
| var end = new List<int>(); | var end = new List<int>(); | ||||
| var strides = new List<int>(); | var strides = new List<int>(); | ||||
| @@ -349,15 +469,15 @@ namespace Tensorflow | |||||
| index += 1; | index += 1; | ||||
| } | } | ||||
| return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | |||||
| return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => | |||||
| { | { | ||||
| string name = scope; | string name = scope; | ||||
| if (begin != null) | if (begin != null) | ||||
| { | { | ||||
| var (packed_begin, packed_end, packed_strides) = | var (packed_begin, packed_end, packed_strides) = | ||||
| (array_ops.stack(begin.ToArray()), | (array_ops.stack(begin.ToArray()), | ||||
| array_ops.stack(end.ToArray()), | |||||
| array_ops.stack(strides.ToArray())); | |||||
| array_ops.stack(end.ToArray()), | |||||
| array_ops.stack(strides.ToArray())); | |||||
| return gen_array_ops.strided_slice( | return gen_array_ops.strided_slice( | ||||
| this, | this, | ||||
| @@ -369,7 +489,6 @@ namespace Tensorflow | |||||
| shrink_axis_mask: shrink_axis_mask, | shrink_axis_mask: shrink_axis_mask, | ||||
| new_axis_mask: new_axis_mask, | new_axis_mask: new_axis_mask, | ||||
| ellipsis_mask: ellipsis_mask, | ellipsis_mask: ellipsis_mask, | ||||
| name: name); | name: name); | ||||
| } | } | ||||
| @@ -392,13 +511,9 @@ namespace Tensorflow | |||||
| return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | ||||
| } | } | ||||
| //protected override void DisposeManagedState() | |||||
| //{ | |||||
| //} | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
| { | { | ||||
| if(handle != IntPtr.Zero) | |||||
| if (handle != IntPtr.Zero) | |||||
| { | { | ||||
| c_api.TF_DeleteTensor(handle); | c_api.TF_DeleteTensor(handle); | ||||
| } | } | ||||
| @@ -417,4 +532,4 @@ namespace Tensorflow | |||||
| public int tensor_int_val { get; set; } | public int tensor_int_val { get; set; } | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -15,6 +15,8 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.Numerics; | |||||
| using NumSharp.Backends; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -23,35 +25,100 @@ namespace Tensorflow | |||||
| public static TF_DataType int8 = TF_DataType.TF_INT8; | public static TF_DataType int8 = TF_DataType.TF_INT8; | ||||
| public static TF_DataType int32 = TF_DataType.TF_INT32; | public static TF_DataType int32 = TF_DataType.TF_INT32; | ||||
| public static TF_DataType int64 = TF_DataType.TF_INT64; | public static TF_DataType int64 = TF_DataType.TF_INT64; | ||||
| public static TF_DataType uint8 = TF_DataType.TF_UINT8; | |||||
| public static TF_DataType uint32 = TF_DataType.TF_UINT32; | |||||
| public static TF_DataType uint64 = TF_DataType.TF_UINT64; | |||||
| public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? | public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? | ||||
| public static TF_DataType float16 = TF_DataType.TF_HALF; | public static TF_DataType float16 = TF_DataType.TF_HALF; | ||||
| public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | ||||
| public static Type as_numpy_datatype(this TF_DataType type) | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="type"></param> | |||||
| /// <returns><see cref="System.Type"/> equivalent to <paramref name="type"/>, if none exists, returns null.</returns> | |||||
| public static Type as_numpy_dtype(this TF_DataType type) | |||||
| { | { | ||||
| switch (type) | switch (type) | ||||
| { | { | ||||
| case TF_DataType.TF_BOOL: | case TF_DataType.TF_BOOL: | ||||
| return typeof(bool); | return typeof(bool); | ||||
| case TF_DataType.TF_UINT8: | |||||
| return typeof(byte); | |||||
| case TF_DataType.TF_INT64: | case TF_DataType.TF_INT64: | ||||
| return typeof(long); | return typeof(long); | ||||
| case TF_DataType.TF_UINT64: | |||||
| return typeof(ulong); | |||||
| case TF_DataType.TF_INT32: | case TF_DataType.TF_INT32: | ||||
| return typeof(int); | return typeof(int); | ||||
| case TF_DataType.TF_UINT32: | |||||
| return typeof(uint); | |||||
| case TF_DataType.TF_INT16: | case TF_DataType.TF_INT16: | ||||
| return typeof(short); | return typeof(short); | ||||
| case TF_DataType.TF_UINT16: | |||||
| return typeof(ushort); | |||||
| case TF_DataType.TF_FLOAT: | case TF_DataType.TF_FLOAT: | ||||
| return typeof(float); | return typeof(float); | ||||
| case TF_DataType.TF_DOUBLE: | case TF_DataType.TF_DOUBLE: | ||||
| return typeof(double); | return typeof(double); | ||||
| case TF_DataType.TF_STRING: | case TF_DataType.TF_STRING: | ||||
| return typeof(string); | return typeof(string); | ||||
| case TF_DataType.TF_COMPLEX128: | |||||
| case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | |||||
| return typeof(Complex); | |||||
| default: | default: | ||||
| return null; | return null; | ||||
| } | } | ||||
| } | } | ||||
| // "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex" | |||||
| public static TF_DataType as_dtype(Type type, TF_DataType? dtype = null) | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="type"></param> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="NPTypeCode"/></exception> | |||||
| public static NPTypeCode as_numpy_typecode(this TF_DataType type) | |||||
| { | |||||
| switch (type) | |||||
| { | |||||
| case TF_DataType.TF_BOOL: | |||||
| return NPTypeCode.Boolean; | |||||
| case TF_DataType.TF_UINT8: | |||||
| return NPTypeCode.Byte; | |||||
| case TF_DataType.TF_INT64: | |||||
| return NPTypeCode.Int64; | |||||
| case TF_DataType.TF_INT32: | |||||
| return NPTypeCode.Int32; | |||||
| case TF_DataType.TF_INT16: | |||||
| return NPTypeCode.Int16; | |||||
| case TF_DataType.TF_UINT64: | |||||
| return NPTypeCode.UInt64; | |||||
| case TF_DataType.TF_UINT32: | |||||
| return NPTypeCode.UInt32; | |||||
| case TF_DataType.TF_UINT16: | |||||
| return NPTypeCode.UInt16; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| return NPTypeCode.Single; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| return NPTypeCode.Double; | |||||
| case TF_DataType.TF_STRING: | |||||
| return NPTypeCode.String; | |||||
| case TF_DataType.TF_COMPLEX128: | |||||
| case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | |||||
| return NPTypeCode.Complex; | |||||
| default: | |||||
| throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="type"></param> | |||||
| /// <param name="dtype"></param> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="TF_DataType"/></exception> | |||||
| public static TF_DataType as_dtype(this Type type, TF_DataType? dtype = null) | |||||
| { | { | ||||
| switch (type.Name) | switch (type.Name) | ||||
| { | { | ||||
| @@ -98,7 +165,7 @@ namespace Tensorflow | |||||
| dtype = TF_DataType.TF_BOOL; | dtype = TF_DataType.TF_BOOL; | ||||
| break; | break; | ||||
| default: | default: | ||||
| throw new Exception("as_dtype Not Implemented"); | |||||
| throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); | |||||
| } | } | ||||
| return dtype.Value; | return dtype.Value; | ||||
| @@ -106,16 +173,7 @@ namespace Tensorflow | |||||
| public static DataType as_datatype_enum(this TF_DataType type) | public static DataType as_datatype_enum(this TF_DataType type) | ||||
| { | { | ||||
| DataType dtype = DataType.DtInvalid; | |||||
| switch (type) | |||||
| { | |||||
| default: | |||||
| Enum.TryParse(((int)type).ToString(), out dtype); | |||||
| break; | |||||
| } | |||||
| return dtype; | |||||
| return Enum.TryParse(((int) type).ToString(), out DataType dtype) ? dtype : DataType.DtInvalid; | |||||
| } | } | ||||
| public static TF_DataType as_base_dtype(this TF_DataType type) | public static TF_DataType as_base_dtype(this TF_DataType type) | ||||
| @@ -132,7 +190,7 @@ namespace Tensorflow | |||||
| public static Type as_numpy_dtype(this DataType type) | public static Type as_numpy_dtype(this DataType type) | ||||
| { | { | ||||
| return type.as_tf_dtype().as_numpy_datatype(); | |||||
| return type.as_tf_dtype().as_numpy_dtype(); | |||||
| } | } | ||||
| public static DataType as_base_dtype(this DataType type) | public static DataType as_base_dtype(this DataType type) | ||||
| @@ -144,16 +202,7 @@ namespace Tensorflow | |||||
| public static TF_DataType as_tf_dtype(this DataType type) | public static TF_DataType as_tf_dtype(this DataType type) | ||||
| { | { | ||||
| TF_DataType dtype = TF_DataType.DtInvalid; | |||||
| switch (type) | |||||
| { | |||||
| default: | |||||
| Enum.TryParse(((int)type).ToString(), out dtype); | |||||
| break; | |||||
| } | |||||
| return dtype; | |||||
| return Enum.TryParse(((int) type).ToString(), out TF_DataType dtype) ? dtype : TF_DataType.DtInvalid; | |||||
| } | } | ||||
| public static TF_DataType as_ref(this TF_DataType type) | public static TF_DataType as_ref(this TF_DataType type) | ||||
| @@ -17,6 +17,7 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using NumSharp.Utilities; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -109,7 +110,7 @@ namespace Tensorflow | |||||
| // We first convert value to a numpy array or scalar. | // We first convert value to a numpy array or scalar. | ||||
| NDArray nparray = null; | NDArray nparray = null; | ||||
| var np_dt = dtype.as_numpy_datatype(); | |||||
| var np_dt = dtype.as_numpy_dtype(); | |||||
| if (values is NDArray nd) | if (values is NDArray nd) | ||||
| { | { | ||||
| @@ -188,37 +189,37 @@ namespace Tensorflow | |||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((int[])values, np_dt); | nparray = np.array((int[])values, np_dt); | ||||
| else | else | ||||
| nparray = Convert.ToInt32(values); | |||||
| nparray = Converts.ToInt32(values); | |||||
| break; | break; | ||||
| case "Int64": | case "Int64": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((int[])values, np_dt); | nparray = np.array((int[])values, np_dt); | ||||
| else | else | ||||
| nparray = Convert.ToInt64(values); | |||||
| nparray = Converts.ToInt64(values); | |||||
| break; | break; | ||||
| case "Single": | case "Single": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((float[])values, np_dt); | nparray = np.array((float[])values, np_dt); | ||||
| else | else | ||||
| nparray = Convert.ToSingle(values); | |||||
| nparray = Converts.ToSingle(values); | |||||
| break; | break; | ||||
| case "Double": | case "Double": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((double[])values, np_dt); | nparray = np.array((double[])values, np_dt); | ||||
| else | else | ||||
| nparray = Convert.ToDouble(values); | |||||
| nparray = Converts.ToDouble(values); | |||||
| break; | break; | ||||
| case "String": | case "String": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((string[])values, np_dt); | nparray = np.array((string[])values, np_dt); | ||||
| else | else | ||||
| nparray = NDArray.FromString(Convert.ToString(values)); | |||||
| nparray = NDArray.FromString(Converts.ToString(values)); | |||||
| break; | break; | ||||
| case "Boolean": | case "Boolean": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((bool[])values, np_dt); | nparray = np.array((bool[])values, np_dt); | ||||
| else | else | ||||
| nparray = Convert.ToBoolean(values); | |||||
| nparray = Converts.ToBoolean(values); | |||||
| break; | break; | ||||
| default: | default: | ||||
| throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented"); | throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented"); | ||||
| @@ -29,55 +29,111 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public class GraphKeys | public class GraphKeys | ||||
| { | { | ||||
| #region const | |||||
| /// <summary> | |||||
| /// the subset of `Variable` objects that will be trained by an optimizer. | |||||
| /// </summary> | |||||
| public const string TRAINABLE_VARIABLES_ = "trainable_variables"; | |||||
| /// <summary> | |||||
| /// Trainable resource-style variables. | |||||
| /// </summary> | |||||
| public const string TRAINABLE_RESOURCE_VARIABLES_ = "trainable_resource_variables"; | |||||
| /// <summary> | |||||
| /// Key for streaming model ports. | |||||
| /// </summary> | |||||
| public const string _STREAMING_MODEL_PORTS_ = "streaming_model_ports"; | |||||
| /// <summary> | |||||
| /// Key to collect losses | |||||
| /// </summary> | |||||
| public const string LOSSES_ = "losses"; | |||||
| /// <summary> | |||||
| /// Key to collect Variable objects that are global (shared across machines). | |||||
| /// Default collection for all variables, except local ones. | |||||
| /// </summary> | |||||
| public const string GLOBAL_VARIABLES_ = "variables"; | |||||
| public const string TRAIN_OP_ = "train_op"; | |||||
| public const string GLOBAL_STEP_ = "global_step"; | |||||
| public string[] _VARIABLE_COLLECTIONS_ = new string[] { "variables", "trainable_variables", "model_variables" }; | |||||
| /// <summary> | |||||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | |||||
| /// </summary> | |||||
| public const string SAVEABLE_OBJECTS_ = "saveable_objects"; | |||||
| /// <summary> | |||||
| /// Key to collect update_ops | |||||
| /// </summary> | |||||
| public const string UPDATE_OPS_ = "update_ops"; | |||||
| // Key to collect summaries. | |||||
| public const string SUMMARIES_ = "summaries"; | |||||
| // Used to store v2 summary names. | |||||
| public const string _SUMMARY_COLLECTION_ = "_SUMMARY_V2"; | |||||
| // Key for control flow context. | |||||
| public const string COND_CONTEXT_ = "cond_context"; | |||||
| public const string WHILE_CONTEXT_ = "while_context"; | |||||
| #endregion | |||||
| /// <summary> | /// <summary> | ||||
| /// the subset of `Variable` objects that will be trained by an optimizer. | /// the subset of `Variable` objects that will be trained by an optimizer. | ||||
| /// </summary> | /// </summary> | ||||
| public string TRAINABLE_VARIABLES = "trainable_variables"; | |||||
| public string TRAINABLE_VARIABLES => TRAINABLE_VARIABLES_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Trainable resource-style variables. | /// Trainable resource-style variables. | ||||
| /// </summary> | /// </summary> | ||||
| public string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"; | |||||
| public string TRAINABLE_RESOURCE_VARIABLES => TRAINABLE_RESOURCE_VARIABLES_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key for streaming model ports. | /// Key for streaming model ports. | ||||
| /// </summary> | /// </summary> | ||||
| public string _STREAMING_MODEL_PORTS = "streaming_model_ports"; | |||||
| public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect losses | /// Key to collect losses | ||||
| /// </summary> | /// </summary> | ||||
| public string LOSSES = "losses"; | |||||
| public string LOSSES => LOSSES_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect Variable objects that are global (shared across machines). | /// Key to collect Variable objects that are global (shared across machines). | ||||
| /// Default collection for all variables, except local ones. | /// Default collection for all variables, except local ones. | ||||
| /// </summary> | /// </summary> | ||||
| public string GLOBAL_VARIABLES = "variables"; | |||||
| public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_; | |||||
| public string TRAIN_OP = "train_op"; | |||||
| public string TRAIN_OP => TRAIN_OP_; | |||||
| public string GLOBAL_STEP = "global_step"; | |||||
| public string GLOBAL_STEP => GLOBAL_STEP_; | |||||
| public string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" }; | |||||
| public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | ||||
| /// </summary> | /// </summary> | ||||
| public string SAVEABLE_OBJECTS = "saveable_objects"; | |||||
| public string SAVEABLE_OBJECTS => SAVEABLE_OBJECTS_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect update_ops | /// Key to collect update_ops | ||||
| /// </summary> | /// </summary> | ||||
| public string UPDATE_OPS = "update_ops"; | |||||
| public string UPDATE_OPS => UPDATE_OPS_; | |||||
| // Key to collect summaries. | // Key to collect summaries. | ||||
| public string SUMMARIES = "summaries"; | |||||
| public string SUMMARIES => SUMMARIES_; | |||||
| // Used to store v2 summary names. | // Used to store v2 summary names. | ||||
| public string _SUMMARY_COLLECTION = "_SUMMARY_V2"; | |||||
| public string _SUMMARY_COLLECTION => _SUMMARY_COLLECTION_; | |||||
| // Key for control flow context. | // Key for control flow context. | ||||
| public string COND_CONTEXT = "cond_context"; | |||||
| public string WHILE_CONTEXT = "while_context"; | |||||
| public string COND_CONTEXT => COND_CONTEXT_; | |||||
| public string WHILE_CONTEXT => WHILE_CONTEXT_; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -45,7 +45,7 @@ namespace TensorFlowNET.UnitTest | |||||
| EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | ||||
| EXPECT_EQ(0, outTensor.NDims); | EXPECT_EQ(0, outTensor.NDims); | ||||
| ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | ||||
| var output_contents = outTensor.Data<int>(); | |||||
| var output_contents = outTensor.ToArray<int>(); | |||||
| EXPECT_EQ(3 + 2, output_contents[0]); | EXPECT_EQ(3 + 2, output_contents[0]); | ||||
| // Add another operation to the graph. | // Add another operation to the graph. | ||||
| @@ -66,7 +66,7 @@ namespace TensorFlowNET.UnitTest | |||||
| EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | ||||
| EXPECT_EQ(0, outTensor.NDims); // scalar | EXPECT_EQ(0, outTensor.NDims); // scalar | ||||
| ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | ||||
| output_contents = outTensor.Data<int>(); | |||||
| output_contents = outTensor.ToArray<int>(); | |||||
| EXPECT_EQ(-(7 + 2), output_contents[0]); | EXPECT_EQ(-(7 + 2), output_contents[0]); | ||||
| // Clean up | // Clean up | ||||
| @@ -112,7 +112,7 @@ namespace TensorFlowNET.UnitTest | |||||
| var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | ||||
| var tensor = new Tensor(nd); | var tensor = new Tensor(nd); | ||||
| var array = tensor.Data<float>(); | |||||
| var array = tensor.ToArray<float>(); | |||||
| EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); | EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); | ||||
| EXPECT_EQ(tensor.rank, nd.ndim); | EXPECT_EQ(tensor.rank, nd.ndim); | ||||
| @@ -31,7 +31,7 @@ namespace TensorFlowNET.UnitTest.nn_test | |||||
| var y_np = this._ZeroFraction(x_np); | var y_np = this._ZeroFraction(x_np); | ||||
| var x_tf = constant_op.constant(x_np); | var x_tf = constant_op.constant(x_np); | ||||
| x_tf.SetShape(x_shape); | |||||
| x_tf.set_shape(x_shape); | |||||
| var y_tf = nn_impl.zero_fraction(x_tf); | var y_tf = nn_impl.zero_fraction(x_tf); | ||||
| var y_tf_np = self.evaluate<NDArray>(y_tf); | var y_tf_np = self.evaluate<NDArray>(y_tf); | ||||