@@ -17,9 +17,16 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using NumSharp.Backends;
using NumSharp.Backends.Unmanaged;
using NumSharp.Utilities;
using Tensorflow.Framework;
using static Tensorflow.Binding;
@@ -29,42 +36,68 @@ namespace Tensorflow
/// A tensor is a generalization of vectors and matrices to potentially higher dimensions.
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes.
/// </summary>
[SuppressMessage("ReSharper", "ConvertToAutoProperty")]
public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike
{
private int _id;
private Operation _op;
private readonly int _id;
private readonly Operation _op;
private readonly int _value_index;
private TF_Output? _tf_output;
private readonly TF_DataType _dtype;
public int Id => _id;
/// <summary>
/// The Graph that contains this tensor.
/// </summary>
public Graph graph => op?.graph;
/// <summary>
/// The Operation that produces this tensor as an output.
/// </summary>
public Operation op => _op;
public Tensor[] outputs => op.outputs;
/// <summary>
/// The string name of this tensor.
/// The string name of this tensor.
/// </summary>
public string name => $"{(op == null ? "<unnamed Operation>" : $"{op.name}:{_value_index}")}";
private int _value_index;
/// <summary>
/// The index of this tensor in the outputs of its Operation.
/// </summary>
public int value_index => _value_index;
private TF_DataType _dtype = TF_DataType.DtInvalid;
/// <summary>
/// The DType of elements in this tensor.
/// </summary>
public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle);
public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);
public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
public int NDims => rank;
private TF_Output? _tf_output;
/// <summary>
/// The name of the device on which this tensor will be produced, or null.
/// </summary>
public string Device => op.Device;
public int[] dims => shape;
/// <summary>
/// used for keep other pointer when do implicit operating
/// U sed for keep other pointer when do implicit operating
/// </summary>
public object Tag { get; set; }
/// <summary>
/// Returns the shape of a tensor.
/// </summary>
/// <remarks>https://www.tensorflow.org/api_docs/python/tf/shape</remarks>
public int[] shape
{
get
@@ -76,14 +109,13 @@ namespace Tensorflow
var status = new Status();
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status);
status.Check();
}
else
} else
{
for (int i = 0; i < rank; i++)
dims[i] = c_api.TF_Dim(_handle, i);
}
return dims.Select(x => Convert.ToInt32(x )).ToArray();
return dims.Select(x => ((IConvertible) x).ToInt32(CultureInfo.InvariantCulture )).ToArray();
}
set
@@ -93,38 +125,52 @@ namespace Tensorflow
if (value == null)
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status);
else
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(x => Convert.ToInt64(x) ).ToArray(), value.Length, status);
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status);
}
}
public int[] _shape_tuple()
{
if (shape == null) return null;
return shape.Select(x => (int)x).ToArray();
return (int[]) shape.Clone();
}
public TensorShape TensorShape => tensor_util.to_shape(shape);
public void SetShape(TensorShape shape)
/// <summary>
/// Updates the shape of this tensor.
/// </summary>
public void set_shape(TensorShape shape)
{
this.shape = (int[]) shape.dims.Clone();
}
/// <summary>
/// Updates the shape of this tensor.
/// </summary>
[Obsolete("Please use set_shape(TensorShape shape) instead.", false)]
public void SetShape(TensorShape shape)
{
this.shape = shape.dims;
this.shape = (int[]) shape.dims.Clone() ;
}
/// <summary>
/// Updates the shape of this tensor.
/// </summary>
public void set_shape(Tensor shape)
{
// ReSharper disable once MergeConditionalExpression
this.shape = shape is null ? null : shape.shape;
}
public int[] dims => shape;
/// <summary>
/// number of dimensions
/// 0 Scalar (magnitude only)
/// 1 Vector (magnitude and direction)
/// 2 Matrix (table of numbers)
/// 3 3-Tensor (cube of numbers)
/// number of dimensions <br></br>
/// 0 Scalar (magnitude only) <br></br>
/// 1 Vector (magnitude and direction) <br></br>
/// 2 Matrix (table of numbers) <br></br>
/// 3 3-Tensor (cube of numbers) <br></br>
/// n n-Tensor (you get the idea)
/// </summary>
/// <remarks>https://www.tensorflow.org/api_docs/python/tf/rank</remarks>
public int rank
{
get
@@ -137,17 +183,15 @@ namespace Tensorflow
status.Check();
return ndim;
}
else
{
return c_api.TF_NumDims(_handle);
}
return c_api.TF_NumDims(_handle);
}
}
public int NDims => rank;
public string Device => op.Device;
/// <summary>
/// Returns a list of Operations that consume this tensor.
/// </summary>
/// <returns></returns>
public Operation[] consumers()
{
var output = _as_tf_output();
@@ -157,37 +201,136 @@ namespace Tensorflow
public TF_Output _as_tf_output()
{
if(!_tf_output.HasValue)
if (!_tf_output.HasValue)
_tf_output = new TF_Output(op, value_index);
return _tf_output.Value;
}
public T[] Data<T>()
[Obsolete("Please use ToArray<T>() instead.", false)]
public T[] Data<T>() where T : unmanaged
{
return ToArray<T>();
}
/// <summary>
///
/// </summary>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
/// <exception cref="ArgumentException">When <typeparam name="T"> is string </typeparam></exception>
public T[] ToArray<T>() where T : unmanaged
{
// Column major order
// https://en.wikipedia.org/wiki/File:Row_and_column_major_order.svg
// matrix:[[1, 2, 3], [4, 5, 6]]
// index: 0 2 4 1 3 5
// result: 1 4 2 5 3 6
var data = new T[size];
for (ulong i = 0; i < size; i++)
//when T is string
if (typeof(T) == typeof(string))
{
data[i] = Marshal.PtrToStructure<T>(buffer + (int)(i * itemsize));
if (dtype != TF_DataType.TF_STRING)
throw new ArgumentException($"Given <{typeof(T).Name}> can't be converted to string.");
return (T[]) (object) StringData();
}
return data;
//Are the types matching?
if (typeof(T).as_dtype() == _dtype)
{
//types match, no need to perform cast
var ret = new T[size];
unsafe
{
var len = (long) size;
fixed (T* dstRet = ret)
{
T* dst = dstRet; //local stack copy
if (typeof(T).IsPrimitive)
{
var src = (T*) buffer;
len *= ((long) itemsize);
System.Buffer.MemoryCopy(src, dst, len, len);
} else
{
var itemsize = (long) this.itemsize;
var buffer = this.buffer.ToInt64();
Parallel.For(0L, len, i => dst[i] = Marshal.PtrToStructure<T>(new IntPtr(buffer + i * itemsize)));
}
}
}
return ret;
} else
{
//types do not match, need to perform cast
var ret = new T[size];
unsafe
{
var len = (long) size;
fixed (T* dstRet = ret)
{
T* dst = dstRet; //local stack copy
#if _REGEN
#region Compute
switch (_dtype.as_numpy_datatype().GetTypeCode())
{
%foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1:new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
%
default:
throw new NotSupportedException();
}
#endregion
#else
#region Compute
switch (_dtype.as_numpy_dtype().GetTypeCode())
{
case NPTypeCode.Boolean:new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Byte:new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int16:new UnmanagedMemoryBlock<short>((short*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt16:new UnmanagedMemoryBlock<ushort>((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int32:new UnmanagedMemoryBlock<int>((int*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt32:new UnmanagedMemoryBlock<uint>((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int64:new UnmanagedMemoryBlock<long>((long*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt64:new UnmanagedMemoryBlock<ulong>((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Char:new UnmanagedMemoryBlock<char>((char*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Double:new UnmanagedMemoryBlock<double>((double*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Single:new UnmanagedMemoryBlock<float>((float*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
default:
throw new NotSupportedException();
}
#endregion
#endif
}
}
return ret;
}
}
/// <summary>
/// Copies the memory of current buffer onto newly allocated array.
/// </summary>
/// <returns></returns>
[Obsolete("Please use set_shape(TensorShape shape) instead.", false)]
public byte[] Data()
{
return BufferToArray();
}
/// <summary>
/// Copies the memory of current buffer onto newly allocated array.
/// </summary>
/// <returns></returns>
public byte[] BufferToArray()
{
var data = new byte[bytesize];
Marshal.Copy(buffer, data, 0, (int)bytesize);
Marshal.Copy(buffer, data, 0, (int) bytesize);
return data;
}
public unsafe string[] StringData()
/// Used internally in ToArray<T>
private unsafe string[] StringData()
{
//
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes.
@@ -199,19 +342,19 @@ namespace Tensorflow
var buffer = new byte[size][];
var src = c_api.TF_TensorData(_handle);
var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize);
src += (int)(size * 8);
var srcLen = (IntPtr) (src.ToInt64() + (long) bytesize);
src += (int) (size * 8);
for (int i = 0; i < buffer.Length; i++)
{
using (var status = new Status())
{
IntPtr dst = IntPtr.Zero;
UIntPtr dstLen = UIntPtr.Zero;
var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status);
var read = c_api.TF_StringDecode((byte*) src, (UIntPtr) (srcLen.ToInt64() - src.ToInt64()), (byte**) &dst, &dstLen, status);
status.Check(true);
buffer[i] = new byte[(int)dstLen];
buffer[i] = new byte[(int) dstLen];
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
src += (int)read;
src += (int) read;
}
}
@@ -229,51 +372,29 @@ namespace Tensorflow
}
/// <summary>
/// Evaluates this tensor in a `Session`.
/// Evaluates this tensor in a `Session`.
/// </summary>
/// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param>
/// <param name="session">The `Session` to be used to evaluate this tensor.</param>
/// <returns></returns>
/// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns>
public NDArray eval(params FeedItem[] feed_dict)
{
return ops._eval_using_default_session(this, feed_dict, graph);
}
/// <summary>
/// Evaluates this tensor in a `Session`.
/// </summary>
/// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param>
/// <param name="session">The `Session` to be used to evaluate this tensor.</param>
/// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns>
public NDArray eval(Session session, FeedItem[] feed_dict = null)
{
return ops._eval_using_default_session(this, feed_dict, graph, session);
}
public TF_DataType ToTFDataType(Type type)
{
switch (type.Name)
{
case "Char":
return TF_DataType.TF_UINT8;
case "Int16":
return TF_DataType.TF_INT16;
case "Int32":
return TF_DataType.TF_INT32;
case "Int64":
return TF_DataType.TF_INT64;
case "Single":
return TF_DataType.TF_FLOAT;
case "Double":
return TF_DataType.TF_DOUBLE;
case "Byte":
return TF_DataType.TF_UINT8;
case "String":
return TF_DataType.TF_STRING;
case "Boolean":
return TF_DataType.TF_BOOL;
default:
throw new NotImplementedException("ToTFDataType error");
}
}
public Tensor slice(Slice slice)
{
var slice_spec = new int[] { slice.Start.Value };
var slice_spec = new int[] {slice.Start.Value};
var begin = new List<int>();
var end = new List<int>();
var strides = new List<int>();
@@ -289,26 +410,26 @@ namespace Tensorflow
if (slice.Stop.HasValue)
{
end.Add(slice.Stop.Value);
}
else
} else
{
end.Add(0);
end_mask |= (1 << index);
}
strides.Add(slice.Step);
index += 1;
}
return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope =>
{
string name = scope;
if (begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
array_ops.stack(end.ToArray()),
array_ops.stack(strides.ToArray()));
array_ops.stack(end.ToArray()),
array_ops.stack(strides.ToArray()));
return gen_array_ops.strided_slice(
this,
@@ -320,7 +441,6 @@ namespace Tensorflow
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,
name: name);
}
@@ -330,7 +450,7 @@ namespace Tensorflow
public Tensor slice(int start)
{
var slice_spec = new int[] { start };
var slice_spec = new int[] {start};
var begin = new List<int>();
var end = new List<int>();
var strides = new List<int>();
@@ -349,15 +469,15 @@ namespace Tensorflow
index += 1;
}
return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope =>
{
string name = scope;
if (begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
array_ops.stack(end.ToArray()),
array_ops.stack(strides.ToArray()));
array_ops.stack(end.ToArray()),
array_ops.stack(strides.ToArray()));
return gen_array_ops.strided_slice(
this,
@@ -369,7 +489,6 @@ namespace Tensorflow
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,
name: name);
}
@@ -392,13 +511,9 @@ namespace Tensorflow
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}";
}
//protected override void DisposeManagedState()
//{
//}
protected override void DisposeUnmanagedResources(IntPtr handle)
{
if(handle != IntPtr.Zero)
if (handle != IntPtr.Zero)
{
c_api.TF_DeleteTensor(handle);
}
@@ -417,4 +532,4 @@ namespace Tensorflow
public int tensor_int_val { get; set; }
}
}
}