diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
index 2c05e36a..2a76c52c 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
@@ -141,7 +141,7 @@ namespace Tensorflow.Operations
data, frame_name, is_constant, parallel_iterations, name: name);
if (use_input_shape)
- result.SetShape(data.TensorShape);
+ result.set_shape(data.TensorShape);
return result;
}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
index 3198942b..1b68d1cd 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
@@ -233,7 +233,7 @@ namespace Tensorflow.Operations
dims.AddRange(x_static_shape.dims.Skip(2));
var shape = new TensorShape(dims.ToArray());
- x_t.SetShape(shape);
+ x_t.set_shape(shape);
return x_t;
}
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index d3213250..92fe2e3c 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -351,7 +351,7 @@ namespace Tensorflow
var input_shape = tensor_util.to_shape(input_tensor.shape);
if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined())
{
- var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype());
+ var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype());
return constant_op.constant(nd, name: name);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs
index cbf55861..63e0fca1 100644
--- a/src/TensorFlowNET.Core/Operations/nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs
@@ -98,7 +98,7 @@ namespace Tensorflow
// float to be selected, hence we use a >= comparison.
var keep_mask = random_tensor >= rate;
var ret = x * scale * math_ops.cast(keep_mask, x.dtype);
- ret.SetShape(x.TensorShape);
+ ret.set_shape(x.TensorShape);
return ret;
});
}
diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
index b6db7a65..42ab1d4b 100644
--- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs
+++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
@@ -273,7 +273,7 @@ namespace Tensorflow
{
var tensor = new Tensor(output);
NDArray nd = null;
- Type type = tensor.dtype.as_numpy_datatype();
+ Type type = tensor.dtype.as_numpy_dtype();
var ndims = tensor.shape;
var offset = c_api.TF_TensorData(output);
@@ -285,7 +285,7 @@ namespace Tensorflow
nd = NDArray.Scalar(*(bool*)offset);
break;
case TF_DataType.TF_STRING:
- var bytes = tensor.Data();
+ var bytes = tensor.BufferToArray();
// wired, don't know why we have to start from offset 9.
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
@@ -324,7 +324,7 @@ namespace Tensorflow
nd = np.array(bools).reshape(ndims);
break;
case TF_DataType.TF_STRING:
- var bytes = tensor.Data();
+ var bytes = tensor.BufferToArray();
// wired, don't know why we have to start from offset 9.
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
index 5fd3dfba..ea58607b 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
@@ -549,10 +549,11 @@ namespace Tensorflow
this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it
var ptr = new IntPtr(arraySlice.Address);
int num_bytes = (nd.size * nd.dtypesize);
- var dtype = given_dtype ?? ToTFDataType(nd.dtype);
+ var dtype = given_dtype ?? nd.dtype.as_dtype();
var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs);
IsMemoryOwner = false;
return handle;
+
}
public unsafe Tensor(byte[][] buffer, long[] shape)
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 66466b22..798c27b6 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -17,9 +17,16 @@
using NumSharp;
using System;
using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.Globalization;
using System.Linq;
+using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
+using System.Threading.Tasks;
+using NumSharp.Backends;
+using NumSharp.Backends.Unmanaged;
+using NumSharp.Utilities;
using Tensorflow.Framework;
using static Tensorflow.Binding;
@@ -29,42 +36,68 @@ namespace Tensorflow
/// A tensor is a generalization of vectors and matrices to potentially higher dimensions.
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes.
///
+ [SuppressMessage("ReSharper", "ConvertToAutoProperty")]
public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike
{
- private int _id;
- private Operation _op;
+ private readonly int _id;
+ private readonly Operation _op;
+ private readonly int _value_index;
+ private TF_Output? _tf_output;
+ private readonly TF_DataType _dtype;
public int Id => _id;
+
+ ///
+ /// The Graph that contains this tensor.
+ ///
public Graph graph => op?.graph;
+
+ ///
+ /// The Operation that produces this tensor as an output.
+ ///
public Operation op => _op;
+
public Tensor[] outputs => op.outputs;
///
- /// The string name of this tensor.
+ /// The string name of this tensor.
///
public string name => $"{(op == null ? "" : $"{op.name}:{_value_index}")}";
- private int _value_index;
+ ///
+ /// The index of this tensor in the outputs of its Operation.
+ ///
public int value_index => _value_index;
- private TF_DataType _dtype = TF_DataType.DtInvalid;
+ ///
+ /// The DType of elements in this tensor.
+ ///
public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle);
public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);
-
public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;
-
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
+ public int NDims => rank;
- private TF_Output? _tf_output;
+ ///
+ /// The name of the device on which this tensor will be produced, or null.
+ ///
+ public string Device => op.Device;
+
+ public int[] dims => shape;
///
- /// used for keep other pointer when do implicit operating
+ /// Used for keep other pointer when do implicit operating
///
public object Tag { get; set; }
+
+ ///
+ /// Returns the shape of a tensor.
+ ///
+ /// https://www.tensorflow.org/api_docs/python/tf/shape
public int[] shape
{
get
@@ -76,14 +109,13 @@ namespace Tensorflow
var status = new Status();
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status);
status.Check();
- }
- else
+ } else
{
for (int i = 0; i < rank; i++)
dims[i] = c_api.TF_Dim(_handle, i);
}
- return dims.Select(x => Convert.ToInt32(x)).ToArray();
+ return dims.Select(x => ((IConvertible) x).ToInt32(CultureInfo.InvariantCulture)).ToArray();
}
set
@@ -93,38 +125,52 @@ namespace Tensorflow
if (value == null)
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status);
else
- c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(x => Convert.ToInt64(x)).ToArray(), value.Length, status);
+ c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status);
}
}
public int[] _shape_tuple()
{
- if (shape == null) return null;
- return shape.Select(x => (int)x).ToArray();
+ return (int[]) shape.Clone();
}
public TensorShape TensorShape => tensor_util.to_shape(shape);
- public void SetShape(TensorShape shape)
+ ///
+ /// Updates the shape of this tensor.
+ ///
+ public void set_shape(TensorShape shape)
+ {
+ this.shape = (int[]) shape.dims.Clone();
+ }
+
+ ///
+ /// Updates the shape of this tensor.
+ ///
+ [Obsolete("Please use set_shape(TensorShape shape) instead.", false)]
+ public void SetShape(TensorShape shape)
{
- this.shape = shape.dims;
+ this.shape = (int[]) shape.dims.Clone();
}
+ ///
+ /// Updates the shape of this tensor.
+ ///
public void set_shape(Tensor shape)
{
+ // ReSharper disable once MergeConditionalExpression
this.shape = shape is null ? null : shape.shape;
}
- public int[] dims => shape;
-
///
- /// number of dimensions
- /// 0 Scalar (magnitude only)
- /// 1 Vector (magnitude and direction)
- /// 2 Matrix (table of numbers)
- /// 3 3-Tensor (cube of numbers)
+ /// number of dimensions
+ /// 0 Scalar (magnitude only)
+ /// 1 Vector (magnitude and direction)
+ /// 2 Matrix (table of numbers)
+ /// 3 3-Tensor (cube of numbers)
/// n n-Tensor (you get the idea)
///
+ /// https://www.tensorflow.org/api_docs/python/tf/rank
public int rank
{
get
@@ -137,17 +183,15 @@ namespace Tensorflow
status.Check();
return ndim;
}
- else
- {
- return c_api.TF_NumDims(_handle);
- }
+
+ return c_api.TF_NumDims(_handle);
}
}
- public int NDims => rank;
-
- public string Device => op.Device;
-
+ ///
+ /// Returns a list of Operations that consume this tensor.
+ ///
+ ///
public Operation[] consumers()
{
var output = _as_tf_output();
@@ -157,37 +201,136 @@ namespace Tensorflow
public TF_Output _as_tf_output()
{
- if(!_tf_output.HasValue)
+ if (!_tf_output.HasValue)
_tf_output = new TF_Output(op, value_index);
return _tf_output.Value;
}
- public T[] Data()
+ [Obsolete("Please use ToArray() instead.", false)]
+ public T[] Data() where T : unmanaged
+ {
+ return ToArray();
+ }
+
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// When is string
+ public T[] ToArray() where T : unmanaged
{
- // Column major order
- // https://en.wikipedia.org/wiki/File:Row_and_column_major_order.svg
- // matrix:[[1, 2, 3], [4, 5, 6]]
- // index: 0 2 4 1 3 5
- // result: 1 4 2 5 3 6
- var data = new T[size];
-
- for (ulong i = 0; i < size; i++)
+ //when T is string
+ if (typeof(T) == typeof(string))
{
- data[i] = Marshal.PtrToStructure(buffer + (int)(i * itemsize));
+ if (dtype != TF_DataType.TF_STRING)
+ throw new ArgumentException($"Given <{typeof(T).Name}> can't be converted to string.");
+
+ return (T[]) (object) StringData();
}
- return data;
+ //Are the types matching?
+ if (typeof(T).as_dtype() == _dtype)
+ {
+ //types match, no need to perform cast
+ var ret = new T[size];
+ unsafe
+ {
+ var len = (long) size;
+ fixed (T* dstRet = ret)
+ {
+ T* dst = dstRet; //local stack copy
+ if (typeof(T).IsPrimitive)
+ {
+ var src = (T*) buffer;
+ len *= ((long) itemsize);
+ System.Buffer.MemoryCopy(src, dst, len, len);
+ } else
+ {
+ var itemsize = (long) this.itemsize;
+ var buffer = this.buffer.ToInt64();
+ Parallel.For(0L, len, i => dst[i] = Marshal.PtrToStructure(new IntPtr(buffer + i * itemsize)));
+ }
+ }
+ }
+
+ return ret;
+ } else
+ {
+
+ //types do not match, need to perform cast
+ var ret = new T[size];
+ unsafe
+ {
+ var len = (long) size;
+ fixed (T* dstRet = ret)
+ {
+ T* dst = dstRet; //local stack copy
+
+#if _REGEN
+ #region Compute
+ switch (_dtype.as_numpy_datatype().GetTypeCode())
+ {
+ %foreach supported_dtypes,supported_dtypes_lowercase%
+ case NPTypeCode.#1:new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ %
+ default:
+ throw new NotSupportedException();
+ }
+ #endregion
+#else
+ #region Compute
+ switch (_dtype.as_numpy_dtype().GetTypeCode())
+ {
+ case NPTypeCode.Boolean:new UnmanagedMemoryBlock((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.Byte:new UnmanagedMemoryBlock((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.Int16:new UnmanagedMemoryBlock((short*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.UInt16:new UnmanagedMemoryBlock((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.Int32:new UnmanagedMemoryBlock((int*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.UInt32:new UnmanagedMemoryBlock((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.Int64:new UnmanagedMemoryBlock((long*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.UInt64:new UnmanagedMemoryBlock((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.Char:new UnmanagedMemoryBlock((char*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.Double:new UnmanagedMemoryBlock((double*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.Single:new UnmanagedMemoryBlock((float*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ default:
+ throw new NotSupportedException();
+ }
+ #endregion
+#endif
+
+ }
+ }
+
+ return ret;
+ }
}
+
+ ///
+ /// Copies the memory of current buffer onto newly allocated array.
+ ///
+ ///
+ [Obsolete("Please use set_shape(TensorShape shape) instead.", false)]
public byte[] Data()
+ {
+ return BufferToArray();
+ }
+
+ ///
+ /// Copies the memory of current buffer onto newly allocated array.
+ ///
+ ///
+ public byte[] BufferToArray()
{
var data = new byte[bytesize];
- Marshal.Copy(buffer, data, 0, (int)bytesize);
+ Marshal.Copy(buffer, data, 0, (int) bytesize);
return data;
}
- public unsafe string[] StringData()
+ /// Used internally in ToArray<T>
+ private unsafe string[] StringData()
{
//
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes.
@@ -199,19 +342,19 @@ namespace Tensorflow
var buffer = new byte[size][];
var src = c_api.TF_TensorData(_handle);
- var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize);
- src += (int)(size * 8);
+ var srcLen = (IntPtr) (src.ToInt64() + (long) bytesize);
+ src += (int) (size * 8);
for (int i = 0; i < buffer.Length; i++)
{
using (var status = new Status())
{
IntPtr dst = IntPtr.Zero;
UIntPtr dstLen = UIntPtr.Zero;
- var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status);
+ var read = c_api.TF_StringDecode((byte*) src, (UIntPtr) (srcLen.ToInt64() - src.ToInt64()), (byte**) &dst, &dstLen, status);
status.Check(true);
- buffer[i] = new byte[(int)dstLen];
+ buffer[i] = new byte[(int) dstLen];
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
- src += (int)read;
+ src += (int) read;
}
}
@@ -229,51 +372,29 @@ namespace Tensorflow
}
///
- /// Evaluates this tensor in a `Session`.
+ /// Evaluates this tensor in a `Session`.
///
/// A dictionary that maps `Tensor` objects to feed values.
- /// The `Session` to be used to evaluate this tensor.
- ///
+ /// A array corresponding to the value of this tensor.
public NDArray eval(params FeedItem[] feed_dict)
{
return ops._eval_using_default_session(this, feed_dict, graph);
}
+ ///
+ /// Evaluates this tensor in a `Session`.
+ ///
+ /// A dictionary that maps `Tensor` objects to feed values.
+ /// The `Session` to be used to evaluate this tensor.
+ /// A array corresponding to the value of this tensor.
public NDArray eval(Session session, FeedItem[] feed_dict = null)
{
return ops._eval_using_default_session(this, feed_dict, graph, session);
}
- public TF_DataType ToTFDataType(Type type)
- {
- switch (type.Name)
- {
- case "Char":
- return TF_DataType.TF_UINT8;
- case "Int16":
- return TF_DataType.TF_INT16;
- case "Int32":
- return TF_DataType.TF_INT32;
- case "Int64":
- return TF_DataType.TF_INT64;
- case "Single":
- return TF_DataType.TF_FLOAT;
- case "Double":
- return TF_DataType.TF_DOUBLE;
- case "Byte":
- return TF_DataType.TF_UINT8;
- case "String":
- return TF_DataType.TF_STRING;
- case "Boolean":
- return TF_DataType.TF_BOOL;
- default:
- throw new NotImplementedException("ToTFDataType error");
- }
- }
-
public Tensor slice(Slice slice)
{
- var slice_spec = new int[] { slice.Start.Value };
+ var slice_spec = new int[] {slice.Start.Value};
var begin = new List();
var end = new List();
var strides = new List();
@@ -289,26 +410,26 @@ namespace Tensorflow
if (slice.Stop.HasValue)
{
end.Add(slice.Stop.Value);
- }
- else
+ } else
{
end.Add(0);
end_mask |= (1 << index);
}
+
strides.Add(slice.Step);
index += 1;
}
- return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
+ return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope =>
{
string name = scope;
if (begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
- array_ops.stack(end.ToArray()),
- array_ops.stack(strides.ToArray()));
+ array_ops.stack(end.ToArray()),
+ array_ops.stack(strides.ToArray()));
return gen_array_ops.strided_slice(
this,
@@ -320,7 +441,6 @@ namespace Tensorflow
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,
-
name: name);
}
@@ -330,7 +450,7 @@ namespace Tensorflow
public Tensor slice(int start)
{
- var slice_spec = new int[] { start };
+ var slice_spec = new int[] {start};
var begin = new List();
var end = new List();
var strides = new List();
@@ -349,15 +469,15 @@ namespace Tensorflow
index += 1;
}
- return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
+ return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope =>
{
string name = scope;
if (begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
- array_ops.stack(end.ToArray()),
- array_ops.stack(strides.ToArray()));
+ array_ops.stack(end.ToArray()),
+ array_ops.stack(strides.ToArray()));
return gen_array_ops.strided_slice(
this,
@@ -369,7 +489,6 @@ namespace Tensorflow
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,
-
name: name);
}
@@ -392,13 +511,9 @@ namespace Tensorflow
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}";
}
- //protected override void DisposeManagedState()
- //{
- //}
-
protected override void DisposeUnmanagedResources(IntPtr handle)
{
- if(handle != IntPtr.Zero)
+ if (handle != IntPtr.Zero)
{
c_api.TF_DeleteTensor(handle);
}
@@ -417,4 +532,4 @@ namespace Tensorflow
public int tensor_int_val { get; set; }
}
-}
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs
index 807dc6f5..37f1ca61 100644
--- a/src/TensorFlowNET.Core/Tensors/dtypes.cs
+++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs
@@ -15,6 +15,8 @@
******************************************************************************/
using System;
+using System.Numerics;
+using NumSharp.Backends;
namespace Tensorflow
{
@@ -23,35 +25,100 @@ namespace Tensorflow
public static TF_DataType int8 = TF_DataType.TF_INT8;
public static TF_DataType int32 = TF_DataType.TF_INT32;
public static TF_DataType int64 = TF_DataType.TF_INT64;
+ public static TF_DataType uint8 = TF_DataType.TF_UINT8;
+ public static TF_DataType uint32 = TF_DataType.TF_UINT32;
+ public static TF_DataType uint64 = TF_DataType.TF_UINT64;
public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32?
public static TF_DataType float16 = TF_DataType.TF_HALF;
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
- public static Type as_numpy_datatype(this TF_DataType type)
+ ///
+ ///
+ ///
+ ///
+ /// equivalent to , if none exists, returns null.
+ public static Type as_numpy_dtype(this TF_DataType type)
{
switch (type)
{
case TF_DataType.TF_BOOL:
return typeof(bool);
+ case TF_DataType.TF_UINT8:
+ return typeof(byte);
case TF_DataType.TF_INT64:
return typeof(long);
+ case TF_DataType.TF_UINT64:
+ return typeof(ulong);
case TF_DataType.TF_INT32:
return typeof(int);
+ case TF_DataType.TF_UINT32:
+ return typeof(uint);
case TF_DataType.TF_INT16:
return typeof(short);
+ case TF_DataType.TF_UINT16:
+ return typeof(ushort);
case TF_DataType.TF_FLOAT:
return typeof(float);
case TF_DataType.TF_DOUBLE:
return typeof(double);
case TF_DataType.TF_STRING:
return typeof(string);
+ case TF_DataType.TF_COMPLEX128:
+ case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX
+ return typeof(Complex);
default:
return null;
}
}
- // "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"
- public static TF_DataType as_dtype(Type type, TF_DataType? dtype = null)
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// When has no equivalent
+ public static NPTypeCode as_numpy_typecode(this TF_DataType type)
+ {
+ switch (type)
+ {
+ case TF_DataType.TF_BOOL:
+ return NPTypeCode.Boolean;
+ case TF_DataType.TF_UINT8:
+ return NPTypeCode.Byte;
+ case TF_DataType.TF_INT64:
+ return NPTypeCode.Int64;
+ case TF_DataType.TF_INT32:
+ return NPTypeCode.Int32;
+ case TF_DataType.TF_INT16:
+ return NPTypeCode.Int16;
+ case TF_DataType.TF_UINT64:
+ return NPTypeCode.UInt64;
+ case TF_DataType.TF_UINT32:
+ return NPTypeCode.UInt32;
+ case TF_DataType.TF_UINT16:
+ return NPTypeCode.UInt16;
+ case TF_DataType.TF_FLOAT:
+ return NPTypeCode.Single;
+ case TF_DataType.TF_DOUBLE:
+ return NPTypeCode.Double;
+ case TF_DataType.TF_STRING:
+ return NPTypeCode.String;
+ case TF_DataType.TF_COMPLEX128:
+ case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX
+ return NPTypeCode.Complex;
+ default:
+ throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode.");
+ }
+ }
+
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// When has no equivalent
+ public static TF_DataType as_dtype(this Type type, TF_DataType? dtype = null)
{
switch (type.Name)
{
@@ -98,7 +165,7 @@ namespace Tensorflow
dtype = TF_DataType.TF_BOOL;
break;
default:
- throw new Exception("as_dtype Not Implemented");
+ throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode.");
}
return dtype.Value;
@@ -106,16 +173,7 @@ namespace Tensorflow
public static DataType as_datatype_enum(this TF_DataType type)
{
- DataType dtype = DataType.DtInvalid;
-
- switch (type)
- {
- default:
- Enum.TryParse(((int)type).ToString(), out dtype);
- break;
- }
-
- return dtype;
+ return Enum.TryParse(((int) type).ToString(), out DataType dtype) ? dtype : DataType.DtInvalid;
}
public static TF_DataType as_base_dtype(this TF_DataType type)
@@ -132,7 +190,7 @@ namespace Tensorflow
public static Type as_numpy_dtype(this DataType type)
{
- return type.as_tf_dtype().as_numpy_datatype();
+ return type.as_tf_dtype().as_numpy_dtype();
}
public static DataType as_base_dtype(this DataType type)
@@ -144,16 +202,7 @@ namespace Tensorflow
public static TF_DataType as_tf_dtype(this DataType type)
{
- TF_DataType dtype = TF_DataType.DtInvalid;
-
- switch (type)
- {
- default:
- Enum.TryParse(((int)type).ToString(), out dtype);
- break;
- }
-
- return dtype;
+ return Enum.TryParse(((int) type).ToString(), out TF_DataType dtype) ? dtype : TF_DataType.DtInvalid;
}
public static TF_DataType as_ref(this TF_DataType type)
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index ded105c7..43848da6 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -17,6 +17,7 @@
using NumSharp;
using System;
using System.Linq;
+using NumSharp.Utilities;
namespace Tensorflow
{
@@ -109,7 +110,7 @@ namespace Tensorflow
// We first convert value to a numpy array or scalar.
NDArray nparray = null;
- var np_dt = dtype.as_numpy_datatype();
+ var np_dt = dtype.as_numpy_dtype();
if (values is NDArray nd)
{
@@ -188,37 +189,37 @@ namespace Tensorflow
if (values.GetType().IsArray)
nparray = np.array((int[])values, np_dt);
else
- nparray = Convert.ToInt32(values);
+ nparray = Converts.ToInt32(values);
break;
case "Int64":
if (values.GetType().IsArray)
nparray = np.array((int[])values, np_dt);
else
- nparray = Convert.ToInt64(values);
+ nparray = Converts.ToInt64(values);
break;
case "Single":
if (values.GetType().IsArray)
nparray = np.array((float[])values, np_dt);
else
- nparray = Convert.ToSingle(values);
+ nparray = Converts.ToSingle(values);
break;
case "Double":
if (values.GetType().IsArray)
nparray = np.array((double[])values, np_dt);
else
- nparray = Convert.ToDouble(values);
+ nparray = Converts.ToDouble(values);
break;
case "String":
if (values.GetType().IsArray)
nparray = np.array((string[])values, np_dt);
else
- nparray = NDArray.FromString(Convert.ToString(values));
+ nparray = NDArray.FromString(Converts.ToString(values));
break;
case "Boolean":
if (values.GetType().IsArray)
nparray = np.array((bool[])values, np_dt);
else
- nparray = Convert.ToBoolean(values);
+ nparray = Converts.ToBoolean(values);
break;
default:
throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented");
diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs
index 17b095a4..c5a06433 100644
--- a/src/TensorFlowNET.Core/ops.GraphKeys.cs
+++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs
@@ -29,55 +29,111 @@ namespace Tensorflow
///
public class GraphKeys
{
+ #region const
+
+
+ ///
+ /// the subset of `Variable` objects that will be trained by an optimizer.
+ ///
+ public const string TRAINABLE_VARIABLES_ = "trainable_variables";
+
+ ///
+ /// Trainable resource-style variables.
+ ///
+ public const string TRAINABLE_RESOURCE_VARIABLES_ = "trainable_resource_variables";
+
+ ///
+ /// Key for streaming model ports.
+ ///
+ public const string _STREAMING_MODEL_PORTS_ = "streaming_model_ports";
+
+ ///
+ /// Key to collect losses
+ ///
+ public const string LOSSES_ = "losses";
+
+ ///
+ /// Key to collect Variable objects that are global (shared across machines).
+ /// Default collection for all variables, except local ones.
+ ///
+ public const string GLOBAL_VARIABLES_ = "variables";
+
+ public const string TRAIN_OP_ = "train_op";
+
+ public const string GLOBAL_STEP_ = "global_step";
+
+ public string[] _VARIABLE_COLLECTIONS_ = new string[] { "variables", "trainable_variables", "model_variables" };
+ ///
+ /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
+ ///
+ public const string SAVEABLE_OBJECTS_ = "saveable_objects";
+ ///
+ /// Key to collect update_ops
+ ///
+ public const string UPDATE_OPS_ = "update_ops";
+
+ // Key to collect summaries.
+ public const string SUMMARIES_ = "summaries";
+
+ // Used to store v2 summary names.
+ public const string _SUMMARY_COLLECTION_ = "_SUMMARY_V2";
+
+ // Key for control flow context.
+ public const string COND_CONTEXT_ = "cond_context";
+ public const string WHILE_CONTEXT_ = "while_context";
+
+ #endregion
+
+
///
/// the subset of `Variable` objects that will be trained by an optimizer.
///
- public string TRAINABLE_VARIABLES = "trainable_variables";
+ public string TRAINABLE_VARIABLES => TRAINABLE_VARIABLES_;
///
/// Trainable resource-style variables.
///
- public string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables";
+ public string TRAINABLE_RESOURCE_VARIABLES => TRAINABLE_RESOURCE_VARIABLES_;
///
/// Key for streaming model ports.
///
- public string _STREAMING_MODEL_PORTS = "streaming_model_ports";
+ public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_;
///
/// Key to collect losses
///
- public string LOSSES = "losses";
+ public string LOSSES => LOSSES_;
///
/// Key to collect Variable objects that are global (shared across machines).
/// Default collection for all variables, except local ones.
///
- public string GLOBAL_VARIABLES = "variables";
+ public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_;
- public string TRAIN_OP = "train_op";
+ public string TRAIN_OP => TRAIN_OP_;
- public string GLOBAL_STEP = "global_step";
+ public string GLOBAL_STEP => GLOBAL_STEP_;
- public string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" };
+ public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_;
///
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
///
- public string SAVEABLE_OBJECTS = "saveable_objects";
+ public string SAVEABLE_OBJECTS => SAVEABLE_OBJECTS_;
///
/// Key to collect update_ops
///
- public string UPDATE_OPS = "update_ops";
+ public string UPDATE_OPS => UPDATE_OPS_;
// Key to collect summaries.
- public string SUMMARIES = "summaries";
+ public string SUMMARIES => SUMMARIES_;
// Used to store v2 summary names.
- public string _SUMMARY_COLLECTION = "_SUMMARY_V2";
+ public string _SUMMARY_COLLECTION => _SUMMARY_COLLECTION_;
// Key for control flow context.
- public string COND_CONTEXT = "cond_context";
- public string WHILE_CONTEXT = "while_context";
+ public string COND_CONTEXT => COND_CONTEXT_;
+ public string WHILE_CONTEXT => WHILE_CONTEXT_;
}
}
}
diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs
index 9c8485ec..8fd4dc8a 100644
--- a/test/TensorFlowNET.UnitTest/SessionTest.cs
+++ b/test/TensorFlowNET.UnitTest/SessionTest.cs
@@ -45,7 +45,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims);
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
- var output_contents = outTensor.Data();
+ var output_contents = outTensor.ToArray();
EXPECT_EQ(3 + 2, output_contents[0]);
// Add another operation to the graph.
@@ -66,7 +66,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims); // scalar
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
- output_contents = outTensor.Data();
+ output_contents = outTensor.ToArray();
EXPECT_EQ(-(7 + 2), output_contents[0]);
// Clean up
diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs
index 07da9dca..11557f14 100644
--- a/test/TensorFlowNET.UnitTest/TensorTest.cs
+++ b/test/TensorFlowNET.UnitTest/TensorTest.cs
@@ -112,7 +112,7 @@ namespace TensorFlowNET.UnitTest
var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);
var tensor = new Tensor(nd);
- var array = tensor.Data();
+ var array = tensor.ToArray();
EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT);
EXPECT_EQ(tensor.rank, nd.ndim);
diff --git a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
index 1fd7d3aa..3a5515d9 100644
--- a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
+++ b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
@@ -31,7 +31,7 @@ namespace TensorFlowNET.UnitTest.nn_test
var y_np = this._ZeroFraction(x_np);
var x_tf = constant_op.constant(x_np);
- x_tf.SetShape(x_shape);
+ x_tf.set_shape(x_shape);
var y_tf = nn_impl.zero_fraction(x_tf);
var y_tf_np = self.evaluate(y_tf);