* Refactored DisposableObject * Added different build directory for TensorflowNET.Examples.GPU * _FetchHandler: Switched to NPTypeCode * gfile.cs, Walk(...): Handle case when directory top doesn't exist. * Tensor.Creation: Perf-opted when creating tensor from NDArray of string * Graph.cs: refactor and added docs * Tensor.Creation.cs: perf-ops * Tensor.Explicit.cs: perf-ops * Copied globals.regen from NumSharp - Added supported_numericals_TF_DataType * Tensor perf-ops and cleanup, Revamped dtypes.cs, some renames. - Cleanup and docs to all Tensor.cs files - Changed all uses of System.Convert to NumSharp.Utilities.Converts - Added all missing types in dtypes.cs - Renamed tensor.Data<T> to tensor.ToArray<T>, added obsolete message - Renamed tensor.Data() to tensor.BufferToArray(), added obsolete message - Made GraphKeys to use const string instead allocating strings at every use of GraphKeys. * Tensor: Added guards for explicit casts. * Tensor: Added explicit cast to string * Tensor.ToArray<T>(): Added support for cases when tensor is scalar. * Tensor.BufferToArray(): Fixed to use long instead of int. * TensorShape: Revamped and documented. * BaseSession: Added Session.run(ITensorOrOperation fetche, params FeedItem[] feed_dict) * Tensor: renamed _dtype to _override_dtype - Fixed all locations _dtype is used incorrectly. * Fixed unit tests * Tensor.Operations: Reverted commit * DisposableObject: sorted internal_dispose to properly handle Dispose() calls * Tensor.DisposeUnmanagedResources: Nullify _handle after delete. * TensorShape.this[...]: fixed guard check. * DisposableObject #362tags/v0.12
| @@ -59,6 +59,6 @@ namespace Tensorflow | |||||
| } | } | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern IntPtr TF_Version(); | |||||
| public static extern IntPtr TF_Version(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -308,15 +308,14 @@ namespace Tensorflow | |||||
| public static IEnumerable TupleToEnumerable(object tuple) | public static IEnumerable TupleToEnumerable(object tuple) | ||||
| { | { | ||||
| Type t = tuple.GetType(); | Type t = tuple.GetType(); | ||||
| if(t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) | |||||
| if (t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) | |||||
| { | { | ||||
| var flds = t.GetFields(); | var flds = t.GetFields(); | ||||
| for(int i = 0; i < flds.Length;i++) | |||||
| for (int i = 0; i < flds.Length; i++) | |||||
| { | { | ||||
| yield return flds[i].GetValue(tuple); | yield return flds[i].GetValue(tuple); | ||||
| } | } | ||||
| } | |||||
| else | |||||
| } else | |||||
| { | { | ||||
| throw new System.Exception("Expected Tuple."); | throw new System.Exception("Expected Tuple."); | ||||
| } | } | ||||
| @@ -329,12 +328,9 @@ namespace Tensorflow | |||||
| public static bool isinstance(object Item1, object tuple) | public static bool isinstance(object Item1, object tuple) | ||||
| { | { | ||||
| var tup = TupleToEnumerable(tuple); | |||||
| foreach(var t in tup) | |||||
| { | |||||
| if(isinstance(Item1, (Type)t)) | |||||
| foreach (var t in TupleToEnumerable(tuple)) | |||||
| if (isinstance(Item1, (Type) t)) | |||||
| return true; | return true; | ||||
| } | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -66,7 +66,7 @@ namespace Tensorflow | |||||
| return buffer.Data; | return buffer.Data; | ||||
| } | } | ||||
| protected override void DisposeUnManagedState(IntPtr handle) | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => c_api.TF_DeleteBuffer(handle); | => c_api.TF_DeleteBuffer(handle); | ||||
| } | } | ||||
| } | } | ||||
| @@ -29,18 +29,10 @@ namespace Tensorflow | |||||
| protected DisposableObject() { } | protected DisposableObject() { } | ||||
| public DisposableObject(IntPtr handle) | |||||
| { | |||||
| _handle = handle; | |||||
| } | |||||
| protected virtual void DisposeManagedState() | |||||
| { | |||||
| } | |||||
| protected DisposableObject(IntPtr handle) | |||||
| => _handle = handle; | |||||
| protected abstract void DisposeUnManagedState(IntPtr handle); | |||||
| protected virtual void Dispose(bool disposing) | |||||
| private void internal_dispose(bool disposing) | |||||
| { | { | ||||
| if (disposing) | if (disposing) | ||||
| { | { | ||||
| @@ -48,30 +40,43 @@ namespace Tensorflow | |||||
| if (_handle != IntPtr.Zero) | if (_handle != IntPtr.Zero) | ||||
| { | { | ||||
| // dispose managed state (managed objects). | // dispose managed state (managed objects). | ||||
| DisposeManagedState(); | |||||
| DisposeManagedResources(); | |||||
| // set large fields to null. | // set large fields to null. | ||||
| DisposeUnManagedState(_handle); | |||||
| DisposeUnmanagedResources(_handle); | |||||
| _handle = IntPtr.Zero; | _handle = IntPtr.Zero; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Dispose any managed resources. | |||||
| /// </summary> | |||||
| /// <remarks>Equivalent to what you would perform inside <see cref="Dispose()"/></remarks> | |||||
| protected virtual void DisposeManagedResources() | |||||
| { | |||||
| } | |||||
| /// <summary> | |||||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||||
| /// </summary> | |||||
| protected abstract void DisposeUnmanagedResources(IntPtr handle); | |||||
| // override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources. | // override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources. | ||||
| ~DisposableObject() | ~DisposableObject() | ||||
| { | { | ||||
| // Do not change this code. Put cleanup code in Dispose(bool disposing) above. | // Do not change this code. Put cleanup code in Dispose(bool disposing) above. | ||||
| Dispose(false); | |||||
| internal_dispose(false); | |||||
| } | } | ||||
| // This code added to correctly implement the disposable pattern. | // This code added to correctly implement the disposable pattern. | ||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| // Do not change this code. Put cleanup code in Dispose(bool disposing) above. | // Do not change this code. Put cleanup code in Dispose(bool disposing) above. | ||||
| Dispose(true); | |||||
| internal_dispose(true); | |||||
| // uncomment the following line if the finalizer is overridden above. | // uncomment the following line if the finalizer is overridden above. | ||||
| GC.SuppressFinalize(this); | GC.SuppressFinalize(this); | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -1,8 +1,9 @@ | |||||
| using System; | using System; | ||||
| using System.IO; | |||||
| namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
| { | { | ||||
| public class ContextOptions : IDisposable | |||||
| public class ContextOptions : IDisposable //TODO! Eli: Shouldn't this inherieting DisposableObject? | |||||
| { | { | ||||
| private IntPtr _handle; | private IntPtr _handle; | ||||
| @@ -23,57 +23,58 @@ using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| /* | |||||
| A TensorFlow computation, represented as a dataflow graph. | |||||
| A `Graph` contains a set of | |||||
| `tf.Operation` objects, | |||||
| which represent units of computation; and | |||||
| `tf.Tensor` objects, which represent | |||||
| the units of data that flow between operations. | |||||
| A default `Graph` is always registered, and accessible by calling | |||||
| `tf.get_default_graph`. | |||||
| To add an operation to the default graph, simply call one of the functions | |||||
| that defines a new `Operation`: | |||||
| ```python | |||||
| c = tf.constant(4.0) | |||||
| assert c.graph is tf.get_default_graph() | |||||
| ``` | |||||
| Another typical usage involves the | |||||
| `tf.Graph.as_default` | |||||
| context manager, which overrides the current default graph for the | |||||
| lifetime of the context: | |||||
| ```python | |||||
| g = tf.Graph() | |||||
| with g.as_default(): | |||||
| # Define operations and tensors in `g`. | |||||
| c = tf.constant(30.0) | |||||
| assert c.graph is g | |||||
| ``` | |||||
| Important note: This class *is not* thread-safe for graph construction. All | |||||
| operations should be created from a single thread, or external | |||||
| synchronization must be provided. Unless otherwise specified, all methods | |||||
| are not thread-safe. | |||||
| A `Graph` instance supports an arbitrary number of "collections" | |||||
| that are identified by name. For convenience when building a large | |||||
| graph, collections can store groups of related objects: for | |||||
| example, the `tf.Variable` uses a collection (named | |||||
| `tf.GraphKeys.GLOBAL_VARIABLES`) for | |||||
| all variables that are created during the construction of a graph. The caller | |||||
| may define additional collections by specifying a new name. | |||||
| */ | |||||
| /// <summary> | /// <summary> | ||||
| /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations. | |||||
| /// This leads to a low-level programming model in which you first define the dataflow graph, | |||||
| /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | |||||
| /// https://www.tensorflow.org/guide/graphs | |||||
| /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations. | |||||
| /// This leads to a low-level programming model in which you first define the dataflow graph, | |||||
| /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | |||||
| /// </summary> | /// </summary> | ||||
| /* | |||||
| A TensorFlow computation, represented as a dataflow graph. | |||||
| A `Graph` contains a set of | |||||
| `tf.Operation` objects, | |||||
| which represent units of computation; and | |||||
| `tf.Tensor` objects, which represent | |||||
| the units of data that flow between operations. | |||||
| A default `Graph` is always registered, and accessible by calling | |||||
| `tf.get_default_graph`. | |||||
| To add an operation to the default graph, simply call one of the functions | |||||
| that defines a new `Operation`: | |||||
| ```python | |||||
| c = tf.constant(4.0) | |||||
| assert c.graph is tf.get_default_graph() | |||||
| ``` | |||||
| Another typical usage involves the | |||||
| `tf.Graph.as_default` | |||||
| context manager, which overrides the current default graph for the | |||||
| lifetime of the context: | |||||
| ```python | |||||
| g = tf.Graph() | |||||
| with g.as_default(): | |||||
| # Define operations and tensors in `g`. | |||||
| c = tf.constant(30.0) | |||||
| assert c.graph is g | |||||
| ``` | |||||
| Important note: This class *is not* thread-safe for graph construction. All | |||||
| operations should be created from a single thread, or external | |||||
| synchronization must be provided. Unless otherwise specified, all methods | |||||
| are not thread-safe. | |||||
| A `Graph` instance supports an arbitrary number of "collections" | |||||
| that are identified by name. For convenience when building a large | |||||
| graph, collections can store groups of related objects: for | |||||
| example, the `tf.Variable` uses a collection (named | |||||
| `tf.GraphKeys.GLOBAL_VARIABLES`) for | |||||
| all variables that are created during the construction of a graph. The caller | |||||
| may define additional collections by specifying a new name. | |||||
| */ | |||||
| /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | |||||
| public partial class Graph : DisposableObject, IEnumerable<Operation> | public partial class Graph : DisposableObject, IEnumerable<Operation> | ||||
| { | { | ||||
| private Dictionary<int, ITensorOrOperation> _nodes_by_id; | private Dictionary<int, ITensorOrOperation> _nodes_by_id; | ||||
| @@ -439,12 +440,12 @@ namespace Tensorflow | |||||
| _unfetchable_ops.Add(op); | _unfetchable_ops.Add(op); | ||||
| } | } | ||||
| protected override void DisposeManagedState() | |||||
| protected override void DisposeManagedResources() | |||||
| { | { | ||||
| ops.default_graph_stack.remove(this); | ops.default_graph_stack.remove(this); | ||||
| } | } | ||||
| protected override void DisposeUnManagedState(IntPtr handle) | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| { | { | ||||
| c_api.TF_DeleteGraph(handle); | c_api.TF_DeleteGraph(handle); | ||||
| } | } | ||||
| @@ -37,7 +37,7 @@ namespace Tensorflow | |||||
| c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | ||||
| } | } | ||||
| protected override void DisposeUnManagedState(IntPtr handle) | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => c_api.TF_DeleteImportGraphDefOptions(handle); | => c_api.TF_DeleteImportGraphDefOptions(handle); | ||||
| public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; | public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; | ||||
| @@ -16,6 +16,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | using System.IO; | ||||
| using System.Linq; | |||||
| namespace Tensorflow.IO | namespace Tensorflow.IO | ||||
| { | { | ||||
| @@ -28,6 +29,9 @@ namespace Tensorflow.IO | |||||
| /// <param name="in_order">Traverse in order if True, post order if False.</param> | /// <param name="in_order">Traverse in order if True, post order if False.</param> | ||||
| public IEnumerable<(string, string[], string[])> Walk(string top, bool in_order = true) | public IEnumerable<(string, string[], string[])> Walk(string top, bool in_order = true) | ||||
| { | { | ||||
| if (!Directory.Exists(top)) | |||||
| return Enumerable.Empty<(string, string[], string[])>(); | |||||
| return walk_v2(top, in_order); | return walk_v2(top, in_order); | ||||
| } | } | ||||
| @@ -141,7 +141,7 @@ namespace Tensorflow.Operations | |||||
| data, frame_name, is_constant, parallel_iterations, name: name); | data, frame_name, is_constant, parallel_iterations, name: name); | ||||
| if (use_input_shape) | if (use_input_shape) | ||||
| result.SetShape(data.TensorShape); | |||||
| result.set_shape(data.TensorShape); | |||||
| return result; | return result; | ||||
| } | } | ||||
| @@ -233,7 +233,7 @@ namespace Tensorflow.Operations | |||||
| dims.AddRange(x_static_shape.dims.Skip(2)); | dims.AddRange(x_static_shape.dims.Skip(2)); | ||||
| var shape = new TensorShape(dims.ToArray()); | var shape = new TensorShape(dims.ToArray()); | ||||
| x_t.SetShape(shape); | |||||
| x_t.set_shape(shape); | |||||
| return x_t; | return x_t; | ||||
| } | } | ||||
| @@ -351,7 +351,7 @@ namespace Tensorflow | |||||
| var input_shape = tensor_util.to_shape(input_tensor.shape); | var input_shape = tensor_util.to_shape(input_tensor.shape); | ||||
| if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) | if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) | ||||
| { | { | ||||
| var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype()); | |||||
| var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype()); | |||||
| return constant_op.constant(nd, name: name); | return constant_op.constant(nd, name: name); | ||||
| } | } | ||||
| } | } | ||||
| @@ -98,7 +98,7 @@ namespace Tensorflow | |||||
| // float to be selected, hence we use a >= comparison. | // float to be selected, hence we use a >= comparison. | ||||
| var keep_mask = random_tensor >= rate; | var keep_mask = random_tensor >= rate; | ||||
| var ret = x * scale * math_ops.cast(keep_mask, x.dtype); | var ret = x * scale * math_ops.cast(keep_mask, x.dtype); | ||||
| ret.SetShape(x.TensorShape); | |||||
| ret.set_shape(x.TensorShape); | |||||
| return ret; | return ret; | ||||
| }); | }); | ||||
| } | } | ||||
| @@ -49,7 +49,7 @@ namespace Tensorflow | |||||
| // dispose newOpts | // dispose newOpts | ||||
| if (opts == null) | if (opts == null) | ||||
| c_api.TF_DeleteSessionOptions(newOpts); | |||||
| newOpts.Dispose(); | |||||
| status.Check(true); | status.Check(true); | ||||
| } | } | ||||
| @@ -64,6 +64,11 @@ namespace Tensorflow | |||||
| return _run(fetche, feed_dict)[0]; | return _run(fetche, feed_dict)[0]; | ||||
| } | } | ||||
| public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) | |||||
| { | |||||
| return _run(fetche, feed_dict)[0]; | |||||
| } | |||||
| public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | ||||
| { | { | ||||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | ||||
| @@ -273,7 +278,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| var tensor = new Tensor(output); | var tensor = new Tensor(output); | ||||
| NDArray nd = null; | NDArray nd = null; | ||||
| Type type = tensor.dtype.as_numpy_datatype(); | |||||
| Type type = tensor.dtype.as_numpy_dtype(); | |||||
| var ndims = tensor.shape; | var ndims = tensor.shape; | ||||
| var offset = c_api.TF_TensorData(output); | var offset = c_api.TF_TensorData(output); | ||||
| @@ -285,7 +290,7 @@ namespace Tensorflow | |||||
| nd = NDArray.Scalar(*(bool*)offset); | nd = NDArray.Scalar(*(bool*)offset); | ||||
| break; | break; | ||||
| case TF_DataType.TF_STRING: | case TF_DataType.TF_STRING: | ||||
| var bytes = tensor.Data(); | |||||
| var bytes = tensor.BufferToArray(); | |||||
| // wired, don't know why we have to start from offset 9. | // wired, don't know why we have to start from offset 9. | ||||
| // length in the begin | // length in the begin | ||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | ||||
| @@ -324,7 +329,7 @@ namespace Tensorflow | |||||
| nd = np.array(bools).reshape(ndims); | nd = np.array(bools).reshape(ndims); | ||||
| break; | break; | ||||
| case TF_DataType.TF_STRING: | case TF_DataType.TF_STRING: | ||||
| var bytes = tensor.Data(); | |||||
| var bytes = tensor.BufferToArray(); | |||||
| // wired, don't know why we have to start from offset 9. | // wired, don't know why we have to start from offset 9. | ||||
| // length in the begin | // length in the begin | ||||
| var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | ||||
| @@ -396,7 +401,7 @@ namespace Tensorflow | |||||
| Dispose(); | Dispose(); | ||||
| } | } | ||||
| protected override void DisposeUnManagedState(IntPtr handle) | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| { | { | ||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| { | { | ||||
| @@ -32,7 +32,7 @@ namespace Tensorflow | |||||
| _handle = handle; | _handle = handle; | ||||
| } | } | ||||
| protected override void DisposeUnManagedState(IntPtr handle) | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => c_api.TF_DeleteSessionOptions(handle); | => c_api.TF_DeleteSessionOptions(handle); | ||||
| public void SetConfig(ConfigProto config) | public void SetConfig(ConfigProto config) | ||||
| @@ -17,6 +17,7 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using NumSharp.Backends; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -71,18 +72,18 @@ namespace Tensorflow | |||||
| { | { | ||||
| if(tensor_values.Length > 0) | if(tensor_values.Length > 0) | ||||
| { | { | ||||
| switch (tensor_values[0].dtype.Name) | |||||
| switch (tensor_values[0].typecode) | |||||
| { | { | ||||
| case "Int32": | |||||
| case NPTypeCode.Int32: | |||||
| full_values.Add(float.NaN); | full_values.Add(float.NaN); | ||||
| break; | break; | ||||
| case "Single": | |||||
| case NPTypeCode.Single: | |||||
| full_values.Add(float.NaN); | full_values.Add(float.NaN); | ||||
| break; | break; | ||||
| case "String": | |||||
| case NPTypeCode.String: | |||||
| full_values.Add(float.NaN); | full_values.Add(float.NaN); | ||||
| break; | break; | ||||
| case "Char": | |||||
| case NPTypeCode.Char: | |||||
| full_values.Add(float.NaN); | full_values.Add(float.NaN); | ||||
| break; | break; | ||||
| default: | default: | ||||
| @@ -100,21 +101,21 @@ namespace Tensorflow | |||||
| j += 1; | j += 1; | ||||
| if (value.ndim == 0) | if (value.ndim == 0) | ||||
| { | { | ||||
| switch (value.dtype.Name) | |||||
| switch (value.typecode) | |||||
| { | { | ||||
| case "Int16": | |||||
| case NPTypeCode.Int16: | |||||
| full_values.Add(value.GetValue<short>(0)); | full_values.Add(value.GetValue<short>(0)); | ||||
| break; | break; | ||||
| case "Int32": | |||||
| case NPTypeCode.Int32: | |||||
| full_values.Add(value.GetValue<int>(0)); | full_values.Add(value.GetValue<int>(0)); | ||||
| break; | break; | ||||
| case "Int64": | |||||
| case NPTypeCode.Int64: | |||||
| full_values.Add(value.GetValue<long>(0)); | full_values.Add(value.GetValue<long>(0)); | ||||
| break; | break; | ||||
| case "Single": | |||||
| case NPTypeCode.Single: | |||||
| full_values.Add(value.GetValue<float>(0)); | full_values.Add(value.GetValue<float>(0)); | ||||
| break; | break; | ||||
| case "Double": | |||||
| case NPTypeCode.Double: | |||||
| full_values.Add(value.GetValue<double>(0)); | full_values.Add(value.GetValue<double>(0)); | ||||
| break; | break; | ||||
| /*case "String": | /*case "String": | ||||
| @@ -50,7 +50,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public void Check(bool throwException = false) | public void Check(bool throwException = false) | ||||
| { | { | ||||
| if(Code != TF_Code.TF_OK) | |||||
| if (Code != TF_Code.TF_OK) | |||||
| { | { | ||||
| Console.WriteLine(Message); | Console.WriteLine(Message); | ||||
| if (throwException) | if (throwException) | ||||
| @@ -65,7 +65,7 @@ namespace Tensorflow | |||||
| return status._handle; | return status._handle; | ||||
| } | } | ||||
| protected override void DisposeUnManagedState(IntPtr handle) | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => c_api.TF_DeleteStatus(handle); | => c_api.TF_DeleteStatus(handle); | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -16,11 +16,13 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Numerics; | using System.Numerics; | ||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| using NumSharp.Backends; | |||||
| using NumSharp.Backends.Unmanaged; | using NumSharp.Backends.Unmanaged; | ||||
| using static Tensorflow.c_api; | using static Tensorflow.c_api; | ||||
| @@ -462,7 +464,7 @@ namespace Tensorflow | |||||
| *v = value; | *v = value; | ||||
| _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); | _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); | ||||
| IsMemoryOwner=true; | IsMemoryOwner=true; | ||||
| } | |||||
| } | |||||
| #endif | #endif | ||||
| /// <summary> | /// <summary> | ||||
| @@ -477,7 +479,7 @@ namespace Tensorflow | |||||
| IntPtr tensor = c_api.TF_TensorData(handle); | IntPtr tensor = c_api.TF_TensorData(handle); | ||||
| Marshal.WriteInt64(tensor, 0); | Marshal.WriteInt64(tensor, 0); | ||||
| fixed (byte* src = &buffer[0]) | |||||
| fixed (byte* src = buffer) | |||||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | ||||
| _handle = handle; | _handle = handle; | ||||
| status.Check(true); | status.Check(true); | ||||
| @@ -486,35 +488,55 @@ namespace Tensorflow | |||||
| public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) | public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) | ||||
| { | { | ||||
| // todo: handle nd of type "String" here too | // todo: handle nd of type "String" here too | ||||
| if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte") | |||||
| if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte) | |||||
| { | { | ||||
| var buffer = nd.ToArray<byte>(); | |||||
| var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | |||||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | |||||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||||
| Marshal.WriteInt64(tensor, 0); | |||||
| if (nd.Unsafe.Storage.Shape.IsContiguous) | |||||
| { | |||||
| var bytesLength = (UIntPtr)nd.size; | |||||
| var size = c_api.TF_StringEncodedSize(bytesLength); | |||||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); | |||||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||||
| Marshal.WriteInt64(tensor, 0); | |||||
| var status = new Status(); | |||||
| c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||||
| status.Check(true); | |||||
| _handle = handle; | |||||
| IsMemoryOwner = false; | |||||
| } | |||||
| else | |||||
| { | |||||
| var buffer = nd.ToArray<byte>(); | |||||
| var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); | |||||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); | |||||
| IntPtr tensor = c_api.TF_TensorData(handle); | |||||
| Marshal.WriteInt64(tensor, 0); | |||||
| var status = new Status(); | |||||
| fixed (byte* src = buffer) | |||||
| c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||||
| status.Check(true); | |||||
| _handle = handle; | |||||
| IsMemoryOwner = false; | |||||
| } | |||||
| var status = new Status(); | |||||
| fixed (byte* src = &buffer[0]) | |||||
| c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | |||||
| status.Check(true); | |||||
| _handle=handle; | |||||
| IsMemoryOwner = false; | |||||
| return; | return; | ||||
| } | } | ||||
| _handle = CreateTensorFromNDArray(nd, tensorDType); | _handle = CreateTensorFromNDArray(nd, tensorDType); | ||||
| IsMemoryOwner = true; | IsMemoryOwner = true; | ||||
| } | } | ||||
| private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) | private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) | ||||
| { | { | ||||
| if (nd.dtype.Name == "String") | |||||
| throw new NotImplementedException("Support for NDArray of type string not implemented yet"); | |||||
| if (nd.dtype.Name == "String") | |||||
| throw new NotImplementedException("Support for NDArray of type string not implemented yet"); | |||||
| IArraySlice arraySlice; | IArraySlice arraySlice; | ||||
| var shape = nd.Unsafe.Storage.Shape; | |||||
| if (shape.IsSliced || shape.IsBroadcasted) | |||||
| if (nd.Unsafe.Storage.Shape.IsContiguous == false) | |||||
| { | { | ||||
| // the memory is NOT contiguous, so we have to copy the view into a contiguous memory block. | // the memory is NOT contiguous, so we have to copy the view into a contiguous memory block. | ||||
| arraySlice = nd.CloneData(); | arraySlice = nd.CloneData(); | ||||
| @@ -527,51 +549,52 @@ namespace Tensorflow | |||||
| this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it | this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it | ||||
| var ptr = new IntPtr(arraySlice.Address); | var ptr = new IntPtr(arraySlice.Address); | ||||
| int num_bytes = (nd.size * nd.dtypesize); | int num_bytes = (nd.size * nd.dtypesize); | ||||
| var dtype = given_dtype ?? ToTFDataType(nd.dtype); | |||||
| var dtype = given_dtype ?? nd.dtype.as_dtype(); | |||||
| var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); | var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); | ||||
| IsMemoryOwner = false; | IsMemoryOwner = false; | ||||
| return handle; | return handle; | ||||
| } | |||||
| public unsafe Tensor(byte[][] buffer, long[] shape) | |||||
| { | |||||
| int size = 0; | |||||
| foreach (var b in buffer) | |||||
| { | |||||
| size += (int)TF_StringEncodedSize((UIntPtr)b.Length); | |||||
| } | |||||
| int totalSize = size + buffer.Length * 8; | |||||
| ulong offset = 0; | |||||
| IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize); | |||||
| // Clear offset table | |||||
| IntPtr pOffset = TF_TensorData(handle); | |||||
| IntPtr dst = pOffset + buffer.Length * 8; | |||||
| IntPtr dstLimit = pOffset + totalSize; | |||||
| for (int i = 0; i < buffer.Length; i++) | |||||
| { | |||||
| Marshal.WriteInt64(pOffset, (long)offset); | |||||
| using (var status = new Status()) | |||||
| { | |||||
| fixed (byte* src = &buffer[i][0]) | |||||
| { | |||||
| var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status); | |||||
| status.Check(true); | |||||
| pOffset += 8; | |||||
| dst += (int)written; | |||||
| offset += written; | |||||
| } | |||||
| } | |||||
| } | |||||
| _handle = handle; | |||||
| } | |||||
| public unsafe Tensor(byte[][] buffer, long[] shape) | |||||
| { | |||||
| int size = 0; | |||||
| foreach (var b in buffer) | |||||
| { | |||||
| size += (int)TF_StringEncodedSize((UIntPtr)b.Length); | |||||
| } | |||||
| int totalSize = size + buffer.Length * 8; | |||||
| ulong offset = 0; | |||||
| IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize); | |||||
| // Clear offset table | |||||
| IntPtr pOffset = TF_TensorData(handle); | |||||
| IntPtr dst = pOffset + buffer.Length * 8; | |||||
| IntPtr dstLimit = pOffset + totalSize; | |||||
| for (int i = 0; i < buffer.Length; i++) | |||||
| { | |||||
| Marshal.WriteInt64(pOffset, (long)offset); | |||||
| using (var status = new Status()) | |||||
| { | |||||
| fixed (byte* src = &buffer[i][0]) | |||||
| { | |||||
| var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status); | |||||
| status.Check(true); | |||||
| pOffset += 8; | |||||
| dst += (int)written; | |||||
| offset += written; | |||||
| } | |||||
| } | |||||
| } | |||||
| _handle = handle; | |||||
| } | } | ||||
| public Tensor(Operation op, int value_index, TF_DataType dtype) | public Tensor(Operation op, int value_index, TF_DataType dtype) | ||||
| { | { | ||||
| _op = op; | _op = op; | ||||
| _value_index = value_index; | _value_index = value_index; | ||||
| _dtype = dtype; | |||||
| _override_dtype = dtype; | |||||
| _id = ops.uid(); | _id = ops.uid(); | ||||
| } | } | ||||
| @@ -589,11 +612,11 @@ namespace Tensorflow | |||||
| /// specified dimensions. | /// specified dimensions. | ||||
| /// </remarks> | /// </remarks> | ||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | [MethodImpl(MethodImplOptions.AggressiveInlining)] | ||||
| [SuppressMessage("ReSharper", "LocalVariableHidesMember")] | |||||
| protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size) | protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size) | ||||
| { | { | ||||
| if (dt == TF_DataType.TF_STRING && data is byte[]) | |||||
| if (dt == TF_DataType.TF_STRING && data is byte[] buffer) | |||||
| { | { | ||||
| var buffer = (byte[])data; | |||||
| var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | ||||
| var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | using System; | ||||
| using System.Runtime.CompilerServices; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -6,86 +7,142 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static explicit operator bool(Tensor tensor) | public static explicit operator bool(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<bool>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_BOOL); | |||||
| return *(bool*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator sbyte(Tensor tensor) | public static explicit operator sbyte(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<sbyte>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT8); | |||||
| return *(sbyte*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator byte(Tensor tensor) | public static explicit operator byte(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<byte>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT8); | |||||
| return *(byte*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator ushort(Tensor tensor) | public static explicit operator ushort(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<ushort>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT16); | |||||
| return *(ushort*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator short(Tensor tensor) | public static explicit operator short(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<short>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT16); | |||||
| return *(short*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator int(Tensor tensor) | public static explicit operator int(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<int>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT32); | |||||
| return *(int*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator uint(Tensor tensor) | public static explicit operator uint(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<uint>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT32); | |||||
| return *(uint*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator long(Tensor tensor) | public static explicit operator long(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<long>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_INT64); | |||||
| return *(long*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator ulong(Tensor tensor) | public static explicit operator ulong(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<ulong>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_UINT64); | |||||
| return *(ulong*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator float(Tensor tensor) | public static explicit operator float(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<float>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_FLOAT); | |||||
| return *(float*) tensor.buffer; | |||||
| } | |||||
| } | } | ||||
| public static explicit operator double(Tensor tensor) | public static explicit operator double(Tensor tensor) | ||||
| { | { | ||||
| EnsureScalar(tensor); | |||||
| return tensor.Data<double>()[0]; | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_DOUBLE); | |||||
| return *(double*) tensor.buffer; | |||||
| } | |||||
| } | |||||
| public static explicit operator string(Tensor tensor) | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| EnsureScalar(tensor); | |||||
| EnsureDType(tensor, TF_DataType.TF_STRING); | |||||
| return new string((char*) tensor.buffer, 0, (int) tensor.size); | |||||
| } | |||||
| } | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| private static void EnsureDType(Tensor tensor, TF_DataType @is) | |||||
| { | |||||
| if (tensor.dtype != @is) | |||||
| throw new InvalidCastException($"Unable to cast scalar tensor {tensor.dtype} to {@is}"); | |||||
| } | } | ||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| private static void EnsureScalar(Tensor tensor) | private static void EnsureScalar(Tensor tensor) | ||||
| { | { | ||||
| if (tensor == null) | if (tensor == null) | ||||
| { | |||||
| throw new ArgumentNullException(nameof(tensor)); | throw new ArgumentNullException(nameof(tensor)); | ||||
| } | |||||
| if (tensor.TensorShape.ndim != 0) | if (tensor.TensorShape.ndim != 0) | ||||
| { | |||||
| throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar"); | throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar"); | ||||
| } | |||||
| if (tensor.TensorShape.size != 1) | if (tensor.TensorShape.size != 1) | ||||
| { | |||||
| throw new ArgumentException("Tensor must have size 1 in order to convert to scalar"); | throw new ArgumentException("Tensor must have size 1 in order to convert to scalar"); | ||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -69,11 +69,12 @@ namespace Tensorflow | |||||
| TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, | TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, | ||||
| TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 | TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 | ||||
| }; | }; | ||||
| public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y); | public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y); | ||||
| public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y); | public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y); | ||||
| public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y); | public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y); | ||||
| public static Tensor operator /(Tensor x, Tensor y) => | public static Tensor operator /(Tensor x, Tensor y) => | ||||
| _intTfDataTypes.Contains(x._dtype) | |||||
| _intTfDataTypes.Contains(x.dtype) | |||||
| ? BinaryOpWrapper("floordiv", x, y) | ? BinaryOpWrapper("floordiv", x, y) | ||||
| : BinaryOpWrapper("truediv", x, y); | : BinaryOpWrapper("truediv", x, y); | ||||
| public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y); | public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y); | ||||
| @@ -122,8 +123,7 @@ namespace Tensorflow | |||||
| if (y is Tensor tr) | if (y is Tensor tr) | ||||
| dtype = tr.dtype.as_base_dtype(); | dtype = tr.dtype.as_base_dtype(); | ||||
| var namescope = ops.name_scope(null, name, new { x, y }); | |||||
| return tf_with(namescope, scope => | |||||
| return tf_with(ops.name_scope(null, name, new { x, y }), scope => | |||||
| { | { | ||||
| Tensor result = null; | Tensor result = null; | ||||
| var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); | var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); | ||||
| @@ -155,7 +155,6 @@ namespace Tensorflow | |||||
| return result; | return result; | ||||
| }); | }); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -17,9 +17,16 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using System.Globalization; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.CompilerServices; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| using System.Threading.Tasks; | |||||
| using NumSharp.Backends; | |||||
| using NumSharp.Backends.Unmanaged; | |||||
| using NumSharp.Utilities; | |||||
| using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -29,42 +36,68 @@ namespace Tensorflow | |||||
| /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | ||||
| /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | ||||
| /// </summary> | /// </summary> | ||||
| [SuppressMessage("ReSharper", "ConvertToAutoProperty")] | |||||
| public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike | public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike | ||||
| { | { | ||||
| private int _id; | |||||
| private Operation _op; | |||||
| private readonly int _id; | |||||
| private readonly Operation _op; | |||||
| private readonly int _value_index; | |||||
| private TF_Output? _tf_output; | |||||
| private readonly TF_DataType _override_dtype; | |||||
| public int Id => _id; | public int Id => _id; | ||||
| /// <summary> | |||||
| /// The Graph that contains this tensor. | |||||
| /// </summary> | |||||
| public Graph graph => op?.graph; | public Graph graph => op?.graph; | ||||
| /// <summary> | |||||
| /// The Operation that produces this tensor as an output. | |||||
| /// </summary> | |||||
| public Operation op => _op; | public Operation op => _op; | ||||
| public Tensor[] outputs => op.outputs; | public Tensor[] outputs => op.outputs; | ||||
| /// <summary> | /// <summary> | ||||
| /// The string name of this tensor. | |||||
| /// The string name of this tensor. | |||||
| /// </summary> | /// </summary> | ||||
| public string name => $"{(op == null ? "<unnamed Operation>" : $"{op.name}:{_value_index}")}"; | public string name => $"{(op == null ? "<unnamed Operation>" : $"{op.name}:{_value_index}")}"; | ||||
| private int _value_index; | |||||
| /// <summary> | |||||
| /// The index of this tensor in the outputs of its Operation. | |||||
| /// </summary> | |||||
| public int value_index => _value_index; | public int value_index => _value_index; | ||||
| private TF_DataType _dtype = TF_DataType.DtInvalid; | |||||
| public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); | |||||
| /// <summary> | |||||
| /// The DType of elements in this tensor. | |||||
| /// </summary> | |||||
| public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle); | |||||
| public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | ||||
| public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | ||||
| public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | ||||
| public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | ||||
| public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | ||||
| public int NDims => rank; | |||||
| private TF_Output? _tf_output; | |||||
| /// <summary> | |||||
| /// The name of the device on which this tensor will be produced, or null. | |||||
| /// </summary> | |||||
| public string Device => op.Device; | |||||
| public int[] dims => shape; | |||||
| /// <summary> | /// <summary> | ||||
| /// used for keep other pointer when do implicit operating | |||||
| /// Used for keep other pointer when do implicit operating | |||||
| /// </summary> | /// </summary> | ||||
| public object Tag { get; set; } | public object Tag { get; set; } | ||||
| /// <summary> | |||||
| /// Returns the shape of a tensor. | |||||
| /// </summary> | |||||
| /// <remarks>https://www.tensorflow.org/api_docs/python/tf/shape</remarks> | |||||
| public int[] shape | public int[] shape | ||||
| { | { | ||||
| get | get | ||||
| @@ -76,14 +109,13 @@ namespace Tensorflow | |||||
| var status = new Status(); | var status = new Status(); | ||||
| c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | ||||
| status.Check(); | status.Check(); | ||||
| } | |||||
| else | |||||
| } else | |||||
| { | { | ||||
| for (int i = 0; i < rank; i++) | for (int i = 0; i < rank; i++) | ||||
| dims[i] = c_api.TF_Dim(_handle, i); | dims[i] = c_api.TF_Dim(_handle, i); | ||||
| } | } | ||||
| return dims.Select(x => Convert.ToInt32(x)).ToArray(); | |||||
| return dims.Select(x => ((IConvertible) x).ToInt32(CultureInfo.InvariantCulture)).ToArray(); | |||||
| } | } | ||||
| set | set | ||||
| @@ -93,38 +125,52 @@ namespace Tensorflow | |||||
| if (value == null) | if (value == null) | ||||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | ||||
| else | else | ||||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(x => Convert.ToInt64(x)).ToArray(), value.Length, status); | |||||
| c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||||
| } | } | ||||
| } | } | ||||
| public int[] _shape_tuple() | public int[] _shape_tuple() | ||||
| { | { | ||||
| if (shape == null) return null; | |||||
| return shape.Select(x => (int)x).ToArray(); | |||||
| return (int[]) shape.Clone(); | |||||
| } | } | ||||
| public TensorShape TensorShape => tensor_util.to_shape(shape); | public TensorShape TensorShape => tensor_util.to_shape(shape); | ||||
| public void SetShape(TensorShape shape) | |||||
| /// <summary> | |||||
| /// Updates the shape of this tensor. | |||||
| /// </summary> | |||||
| public void set_shape(TensorShape shape) | |||||
| { | { | ||||
| this.shape = shape.dims; | |||||
| this.shape = (int[]) shape.dims.Clone(); | |||||
| } | } | ||||
| /// <summary> | |||||
| /// Updates the shape of this tensor. | |||||
| /// </summary> | |||||
| [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] | |||||
| public void SetShape(TensorShape shape) | |||||
| { | |||||
| this.shape = (int[]) shape.dims.Clone(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Updates the shape of this tensor. | |||||
| /// </summary> | |||||
| public void set_shape(Tensor shape) | public void set_shape(Tensor shape) | ||||
| { | { | ||||
| // ReSharper disable once MergeConditionalExpression | |||||
| this.shape = shape is null ? null : shape.shape; | this.shape = shape is null ? null : shape.shape; | ||||
| } | } | ||||
| public int[] dims => shape; | |||||
| /// <summary> | /// <summary> | ||||
| /// number of dimensions | |||||
| /// 0 Scalar (magnitude only) | |||||
| /// 1 Vector (magnitude and direction) | |||||
| /// 2 Matrix (table of numbers) | |||||
| /// 3 3-Tensor (cube of numbers) | |||||
| /// number of dimensions <br></br> | |||||
| /// 0 Scalar (magnitude only) <br></br> | |||||
| /// 1 Vector (magnitude and direction) <br></br> | |||||
| /// 2 Matrix (table of numbers) <br></br> | |||||
| /// 3 3-Tensor (cube of numbers) <br></br> | |||||
| /// n n-Tensor (you get the idea) | /// n n-Tensor (you get the idea) | ||||
| /// </summary> | /// </summary> | ||||
| /// <remarks>https://www.tensorflow.org/api_docs/python/tf/rank</remarks> | |||||
| public int rank | public int rank | ||||
| { | { | ||||
| get | get | ||||
| @@ -137,17 +183,15 @@ namespace Tensorflow | |||||
| status.Check(); | status.Check(); | ||||
| return ndim; | return ndim; | ||||
| } | } | ||||
| else | |||||
| { | |||||
| return c_api.TF_NumDims(_handle); | |||||
| } | |||||
| return c_api.TF_NumDims(_handle); | |||||
| } | } | ||||
| } | } | ||||
| public int NDims => rank; | |||||
| public string Device => op.Device; | |||||
| /// <summary> | |||||
| /// Returns a list of Operations that consume this tensor. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public Operation[] consumers() | public Operation[] consumers() | ||||
| { | { | ||||
| var output = _as_tf_output(); | var output = _as_tf_output(); | ||||
| @@ -157,37 +201,191 @@ namespace Tensorflow | |||||
| public TF_Output _as_tf_output() | public TF_Output _as_tf_output() | ||||
| { | { | ||||
| if(!_tf_output.HasValue) | |||||
| if (!_tf_output.HasValue) | |||||
| _tf_output = new TF_Output(op, value_index); | _tf_output = new TF_Output(op, value_index); | ||||
| return _tf_output.Value; | return _tf_output.Value; | ||||
| } | } | ||||
| public T[] Data<T>() | |||||
| [Obsolete("Please use ToArray<T>() instead.", false)] | |||||
| public T[] Data<T>() where T : unmanaged | |||||
| { | |||||
| return ToArray<T>(); | |||||
| } | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="ArgumentException">When <typeparam name="T"> is string </typeparam></exception> | |||||
| public T[] ToArray<T>() where T : unmanaged | |||||
| { | { | ||||
| // Column major order | |||||
| // https://en.wikipedia.org/wiki/File:Row_and_column_major_order.svg | |||||
| // matrix:[[1, 2, 3], [4, 5, 6]] | |||||
| // index: 0 2 4 1 3 5 | |||||
| // result: 1 4 2 5 3 6 | |||||
| var data = new T[size]; | |||||
| for (ulong i = 0; i < size; i++) | |||||
| //when T is string | |||||
| if (typeof(T) == typeof(string)) | |||||
| { | { | ||||
| data[i] = Marshal.PtrToStructure<T>(buffer + (int)(i * itemsize)); | |||||
| if (dtype != TF_DataType.TF_STRING) | |||||
| throw new ArgumentException($"Given <{typeof(T).Name}> can't be converted to string."); | |||||
| return (T[]) (object) StringData(); | |||||
| } | } | ||||
| return data; | |||||
| //Are the types matching? | |||||
| if (typeof(T).as_dtype() == dtype) | |||||
| { | |||||
| if (NDims == 0 && size == 1) //is it a scalar? | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| return new T[] {*(T*) buffer}; | |||||
| } | |||||
| } | |||||
| //types match, no need to perform cast | |||||
| var ret = new T[size]; | |||||
| unsafe | |||||
| { | |||||
| var len = (long) size; | |||||
| fixed (T* dstRet = ret) | |||||
| { | |||||
| T* dst = dstRet; //local stack copy | |||||
| if (typeof(T).IsPrimitive) | |||||
| { | |||||
| var src = (T*) buffer; | |||||
| len *= ((long) itemsize); | |||||
| System.Buffer.MemoryCopy(src, dst, len, len); | |||||
| } else | |||||
| { | |||||
| var itemsize = (long) this.itemsize; | |||||
| var buffer = this.buffer.ToInt64(); | |||||
| Parallel.For(0L, len, i => dst[i] = Marshal.PtrToStructure<T>(new IntPtr(buffer + i * itemsize))); | |||||
| } | |||||
| } | |||||
| } | |||||
| return ret; | |||||
| } else | |||||
| { | |||||
| //types do not match, need to perform cast | |||||
| if (NDims == 0 && size == 1) //is it a scalar? | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| #if _REGEN | |||||
| #region Compute | |||||
| switch (dtype.as_numpy_dtype().GetTypeCode()) | |||||
| { | |||||
| %foreach supported_dtypes,supported_dtypes_lowercase% | |||||
| case NPTypeCode.#1: return new T[] {Converts.ChangeType<T>(*(#2*) buffer, NPTypeCode.#1)}; | |||||
| % | |||||
| case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)}; | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #else | |||||
| #region Compute | |||||
| switch (dtype.as_numpy_dtype()?.GetTypeCode()) | |||||
| { | |||||
| case NPTypeCode.Boolean: return new T[] {Converts.ChangeType<T>(*(bool*) buffer, NPTypeCode.Boolean)}; | |||||
| case NPTypeCode.Byte: return new T[] {Converts.ChangeType<T>(*(byte*) buffer, NPTypeCode.Byte)}; | |||||
| case NPTypeCode.Int16: return new T[] {Converts.ChangeType<T>(*(short*) buffer, NPTypeCode.Int16)}; | |||||
| case NPTypeCode.UInt16: return new T[] {Converts.ChangeType<T>(*(ushort*) buffer, NPTypeCode.UInt16)}; | |||||
| case NPTypeCode.Int32: return new T[] {Converts.ChangeType<T>(*(int*) buffer, NPTypeCode.Int32)}; | |||||
| case NPTypeCode.UInt32: return new T[] {Converts.ChangeType<T>(*(uint*) buffer, NPTypeCode.UInt32)}; | |||||
| case NPTypeCode.Int64: return new T[] {Converts.ChangeType<T>(*(long*) buffer, NPTypeCode.Int64)}; | |||||
| case NPTypeCode.UInt64: return new T[] {Converts.ChangeType<T>(*(ulong*) buffer, NPTypeCode.UInt64)}; | |||||
| case NPTypeCode.Char: return new T[] {Converts.ChangeType<T>(*(char*) buffer, NPTypeCode.Char)}; | |||||
| case NPTypeCode.Double: return new T[] {Converts.ChangeType<T>(*(double*) buffer, NPTypeCode.Double)}; | |||||
| case NPTypeCode.Single: return new T[] {Converts.ChangeType<T>(*(float*) buffer, NPTypeCode.Single)}; | |||||
| case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)}; | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #endif | |||||
| } | |||||
| } | |||||
| var ret = new T[size]; | |||||
| unsafe | |||||
| { | |||||
| var len = (long) size; | |||||
| fixed (T* dstRet = ret) | |||||
| { | |||||
| T* dst = dstRet; //local stack copy | |||||
| #if _REGEN | |||||
| #region Compute | |||||
| switch (dtype.as_numpy_dtype().GetTypeCode()) | |||||
| { | |||||
| %foreach supported_dtypes,supported_dtypes_lowercase% | |||||
| case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| % | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #else | |||||
| #region Compute | |||||
| switch (dtype.as_numpy_dtype().GetTypeCode()) | |||||
| { | |||||
| case NPTypeCode.Boolean: new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Byte: new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Int16: new UnmanagedMemoryBlock<short>((short*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.UInt16: new UnmanagedMemoryBlock<ushort>((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Int32: new UnmanagedMemoryBlock<int>((int*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.UInt32: new UnmanagedMemoryBlock<uint>((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Int64: new UnmanagedMemoryBlock<long>((long*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.UInt64: new UnmanagedMemoryBlock<ulong>((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Char: new UnmanagedMemoryBlock<char>((char*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Double: new UnmanagedMemoryBlock<double>((double*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.Single: new UnmanagedMemoryBlock<float>((float*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||||
| case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To<T> | |||||
| default: | |||||
| throw new NotSupportedException(); | |||||
| } | |||||
| #endregion | |||||
| #endif | |||||
| } | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| } | } | ||||
| /// <summary> | |||||
| /// Copies the memory of current buffer onto newly allocated array. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] | |||||
| public byte[] Data() | public byte[] Data() | ||||
| { | { | ||||
| var data = new byte[bytesize]; | |||||
| Marshal.Copy(buffer, data, 0, (int)bytesize); | |||||
| return data; | |||||
| return BufferToArray(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Copies the memory of current buffer onto newly allocated array. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public byte[] BufferToArray() | |||||
| { | |||||
| unsafe | |||||
| { | |||||
| // ReSharper disable once LocalVariableHidesMember | |||||
| var bytesize = (long) this.bytesize; | |||||
| var data = new byte[bytesize]; | |||||
| fixed (byte* dst = data) | |||||
| System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize); | |||||
| return data; | |||||
| } | |||||
| } | } | ||||
| public unsafe string[] StringData() | |||||
| /// Used internally in ToArray<T> | |||||
| private unsafe string[] StringData() | |||||
| { | { | ||||
| // | // | ||||
| // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. | // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. | ||||
| @@ -199,19 +397,19 @@ namespace Tensorflow | |||||
| var buffer = new byte[size][]; | var buffer = new byte[size][]; | ||||
| var src = c_api.TF_TensorData(_handle); | var src = c_api.TF_TensorData(_handle); | ||||
| var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); | |||||
| src += (int)(size * 8); | |||||
| var srcLen = (IntPtr) (src.ToInt64() + (long) bytesize); | |||||
| src += (int) (size * 8); | |||||
| for (int i = 0; i < buffer.Length; i++) | for (int i = 0; i < buffer.Length; i++) | ||||
| { | { | ||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| { | { | ||||
| IntPtr dst = IntPtr.Zero; | IntPtr dst = IntPtr.Zero; | ||||
| UIntPtr dstLen = UIntPtr.Zero; | UIntPtr dstLen = UIntPtr.Zero; | ||||
| var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status); | |||||
| var read = c_api.TF_StringDecode((byte*) src, (UIntPtr) (srcLen.ToInt64() - src.ToInt64()), (byte**) &dst, &dstLen, status); | |||||
| status.Check(true); | status.Check(true); | ||||
| buffer[i] = new byte[(int)dstLen]; | |||||
| buffer[i] = new byte[(int) dstLen]; | |||||
| Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | ||||
| src += (int)read; | |||||
| src += (int) read; | |||||
| } | } | ||||
| } | } | ||||
| @@ -229,51 +427,29 @@ namespace Tensorflow | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Evaluates this tensor in a `Session`. | |||||
| /// Evaluates this tensor in a `Session`. | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | ||||
| /// <param name="session">The `Session` to be used to evaluate this tensor.</param> | |||||
| /// <returns></returns> | |||||
| /// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns> | |||||
| public NDArray eval(params FeedItem[] feed_dict) | public NDArray eval(params FeedItem[] feed_dict) | ||||
| { | { | ||||
| return ops._eval_using_default_session(this, feed_dict, graph); | return ops._eval_using_default_session(this, feed_dict, graph); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Evaluates this tensor in a `Session`. | |||||
| /// </summary> | |||||
| /// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | |||||
| /// <param name="session">The `Session` to be used to evaluate this tensor.</param> | |||||
| /// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns> | |||||
| public NDArray eval(Session session, FeedItem[] feed_dict = null) | public NDArray eval(Session session, FeedItem[] feed_dict = null) | ||||
| { | { | ||||
| return ops._eval_using_default_session(this, feed_dict, graph, session); | return ops._eval_using_default_session(this, feed_dict, graph, session); | ||||
| } | } | ||||
| public TF_DataType ToTFDataType(Type type) | |||||
| { | |||||
| switch (type.Name) | |||||
| { | |||||
| case "Char": | |||||
| return TF_DataType.TF_UINT8; | |||||
| case "Int16": | |||||
| return TF_DataType.TF_INT16; | |||||
| case "Int32": | |||||
| return TF_DataType.TF_INT32; | |||||
| case "Int64": | |||||
| return TF_DataType.TF_INT64; | |||||
| case "Single": | |||||
| return TF_DataType.TF_FLOAT; | |||||
| case "Double": | |||||
| return TF_DataType.TF_DOUBLE; | |||||
| case "Byte": | |||||
| return TF_DataType.TF_UINT8; | |||||
| case "String": | |||||
| return TF_DataType.TF_STRING; | |||||
| case "Boolean": | |||||
| return TF_DataType.TF_BOOL; | |||||
| default: | |||||
| throw new NotImplementedException("ToTFDataType error"); | |||||
| } | |||||
| } | |||||
| public Tensor slice(Slice slice) | public Tensor slice(Slice slice) | ||||
| { | { | ||||
| var slice_spec = new int[] { slice.Start.Value }; | |||||
| var slice_spec = new int[] {slice.Start.Value}; | |||||
| var begin = new List<int>(); | var begin = new List<int>(); | ||||
| var end = new List<int>(); | var end = new List<int>(); | ||||
| var strides = new List<int>(); | var strides = new List<int>(); | ||||
| @@ -289,26 +465,26 @@ namespace Tensorflow | |||||
| if (slice.Stop.HasValue) | if (slice.Stop.HasValue) | ||||
| { | { | ||||
| end.Add(slice.Stop.Value); | end.Add(slice.Stop.Value); | ||||
| } | |||||
| else | |||||
| } else | |||||
| { | { | ||||
| end.Add(0); | end.Add(0); | ||||
| end_mask |= (1 << index); | end_mask |= (1 << index); | ||||
| } | } | ||||
| strides.Add(slice.Step); | strides.Add(slice.Step); | ||||
| index += 1; | index += 1; | ||||
| } | } | ||||
| return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | |||||
| return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => | |||||
| { | { | ||||
| string name = scope; | string name = scope; | ||||
| if (begin != null) | if (begin != null) | ||||
| { | { | ||||
| var (packed_begin, packed_end, packed_strides) = | var (packed_begin, packed_end, packed_strides) = | ||||
| (array_ops.stack(begin.ToArray()), | (array_ops.stack(begin.ToArray()), | ||||
| array_ops.stack(end.ToArray()), | |||||
| array_ops.stack(strides.ToArray())); | |||||
| array_ops.stack(end.ToArray()), | |||||
| array_ops.stack(strides.ToArray())); | |||||
| return gen_array_ops.strided_slice( | return gen_array_ops.strided_slice( | ||||
| this, | this, | ||||
| @@ -320,7 +496,6 @@ namespace Tensorflow | |||||
| shrink_axis_mask: shrink_axis_mask, | shrink_axis_mask: shrink_axis_mask, | ||||
| new_axis_mask: new_axis_mask, | new_axis_mask: new_axis_mask, | ||||
| ellipsis_mask: ellipsis_mask, | ellipsis_mask: ellipsis_mask, | ||||
| name: name); | name: name); | ||||
| } | } | ||||
| @@ -330,7 +505,7 @@ namespace Tensorflow | |||||
| public Tensor slice(int start) | public Tensor slice(int start) | ||||
| { | { | ||||
| var slice_spec = new int[] { start }; | |||||
| var slice_spec = new int[] {start}; | |||||
| var begin = new List<int>(); | var begin = new List<int>(); | ||||
| var end = new List<int>(); | var end = new List<int>(); | ||||
| var strides = new List<int>(); | var strides = new List<int>(); | ||||
| @@ -349,15 +524,15 @@ namespace Tensorflow | |||||
| index += 1; | index += 1; | ||||
| } | } | ||||
| return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | |||||
| return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => | |||||
| { | { | ||||
| string name = scope; | string name = scope; | ||||
| if (begin != null) | if (begin != null) | ||||
| { | { | ||||
| var (packed_begin, packed_end, packed_strides) = | var (packed_begin, packed_end, packed_strides) = | ||||
| (array_ops.stack(begin.ToArray()), | (array_ops.stack(begin.ToArray()), | ||||
| array_ops.stack(end.ToArray()), | |||||
| array_ops.stack(strides.ToArray())); | |||||
| array_ops.stack(end.ToArray()), | |||||
| array_ops.stack(strides.ToArray())); | |||||
| return gen_array_ops.strided_slice( | return gen_array_ops.strided_slice( | ||||
| this, | this, | ||||
| @@ -369,7 +544,6 @@ namespace Tensorflow | |||||
| shrink_axis_mask: shrink_axis_mask, | shrink_axis_mask: shrink_axis_mask, | ||||
| new_axis_mask: new_axis_mask, | new_axis_mask: new_axis_mask, | ||||
| ellipsis_mask: ellipsis_mask, | ellipsis_mask: ellipsis_mask, | ||||
| name: name); | name: name); | ||||
| } | } | ||||
| @@ -392,15 +566,12 @@ namespace Tensorflow | |||||
| return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | ||||
| } | } | ||||
| protected override void DisposeManagedState() | |||||
| { | |||||
| } | |||||
| protected override void DisposeUnManagedState(IntPtr handle) | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| { | { | ||||
| if(handle != IntPtr.Zero) | |||||
| if (handle != IntPtr.Zero) | |||||
| { | { | ||||
| c_api.TF_DeleteTensor(handle); | c_api.TF_DeleteTensor(handle); | ||||
| _handle = IntPtr.Zero; | |||||
| } | } | ||||
| } | } | ||||
| @@ -417,4 +588,4 @@ namespace Tensorflow | |||||
| public int tensor_int_val { get; set; } | public int tensor_int_val { get; set; } | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -1,35 +1,84 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Diagnostics.CodeAnalysis; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.CompilerServices; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Represents the shape of a `Tensor`. | |||||
| /// Represents the shape of a `Tensor`. | |||||
| /// </summary> | /// </summary> | ||||
| /// <remarks>https://www.tensorflow.org/api_docs/python/tf/TensorShape</remarks> | |||||
| public class TensorShape | public class TensorShape | ||||
| { | { | ||||
| private Shape shape; | |||||
| private readonly Shape shape; | |||||
| /// <summary> | |||||
| /// Returns a list of Dimensions, or None if the shape is unspecified. | |||||
| /// </summary> | |||||
| public int[] dims => shape.Dimensions; | public int[] dims => shape.Dimensions; | ||||
| /// <summary> | |||||
| /// Returns the rank of this shape. | |||||
| /// </summary> | |||||
| public int ndim => shape.NDim; | public int ndim => shape.NDim; | ||||
| /// <summary> | |||||
| /// Returns the rank of this shape. | |||||
| /// </summary> | |||||
| public int rank => shape.NDim; | |||||
| /// <summary> | |||||
| /// Returns the size this shape represents. | |||||
| /// </summary> | |||||
| public int size => shape.Size; | public int size => shape.Size; | ||||
| public TensorShape(TensorShapeProto proto) | public TensorShape(TensorShapeProto proto) | ||||
| { | { | ||||
| if (proto.UnknownRank) return; | if (proto.UnknownRank) return; | ||||
| switch (proto.Dim.Count) | |||||
| { | |||||
| case 0: shape = new Shape(new int[0]); break; | |||||
| case 1: shape = Shape.Vector((int) proto.Dim[0].Size); break; | |||||
| case 2: shape = Shape.Matrix((int) proto.Dim[0].Size, (int) proto.Dim[1].Size); break; | |||||
| default: | |||||
| var protodims = proto.Dim; | |||||
| var len = protodims.Count; | |||||
| var dims = new int[len]; | |||||
| for (int i = 0; i < len; i++) | |||||
| dims[i] = (int) protodims[i].Size; | |||||
| shape.reshape(proto.Dim.Select(x => (int)x.Size).ToArray()); | |||||
| shape = new Shape(dims); break; | |||||
| } | |||||
| } | } | ||||
| public TensorShape(params int[] dims) | public TensorShape(params int[] dims) | ||||
| { | { | ||||
| shape = new Shape(dims); | |||||
| switch (dims.Length) | |||||
| { | |||||
| case 0: shape = new Shape(new int[0]); break; | |||||
| case 1: shape = Shape.Vector((int) dims[0]); break; | |||||
| case 2: shape = Shape.Matrix(dims[0], dims[1]); break; | |||||
| default: shape = new Shape(dims); break; | |||||
| } | |||||
| } | } | ||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="slice"></param> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="ArgumentException">When <see cref="Slice"/> is not an Index.</exception> | |||||
| [SuppressMessage("ReSharper", "PossibleInvalidOperationException")] | |||||
| public TensorShape this[Slice slice] | public TensorShape this[Slice slice] | ||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| if (slice.Start.HasValue == false || slice.Length.HasValue == false) | |||||
| throw new ArgumentException("Slice must has Start and Length."); | |||||
| return new TensorShape(dims.Skip(slice.Start.Value) | return new TensorShape(dims.Skip(slice.Start.Value) | ||||
| .Take(slice.Length.Value) | .Take(slice.Length.Value) | ||||
| .ToArray()); | .ToArray()); | ||||
| @@ -37,7 +86,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Returns True iff `self` is fully defined in every dimension. | |||||
| /// Returns True iff `self` is fully defined in every dimension. | |||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public bool is_fully_defined() | public bool is_fully_defined() | ||||
| @@ -50,6 +99,7 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("TensorShape is_compatible_with"); | throw new NotImplementedException("TensorShape is_compatible_with"); | ||||
| } | } | ||||
| [SuppressMessage("ReSharper", "ParameterHidesMember")] | |||||
| public TensorShape with_rank_at_least(int rank) | public TensorShape with_rank_at_least(int rank) | ||||
| { | { | ||||
| if (rank != ndim) | if (rank != ndim) | ||||
| @@ -59,35 +109,63 @@ namespace Tensorflow | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Returns the concatenation of the dimension in `self` and `other`. | |||||
| /// Returns the concatenation of the dimension in `self` and `other`. | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="other"></param> | /// <param name="other"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public TensorShape concatenate(int[] other_) | |||||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | |||||
| public TensorShape concatenate(int[] other) | |||||
| { | { | ||||
| var other = new TensorShape(other_); | |||||
| return concatenate(new TensorShape(other)); | |||||
| } | |||||
| if (ndim < 0 || other.ndim < 0) | |||||
| /// <summary> | |||||
| /// Returns the concatenation of the dimension in `self` and `other`. | |||||
| /// </summary> | |||||
| /// <param name="other"></param> | |||||
| /// <returns></returns> | |||||
| public TensorShape concatenate(TensorShape other) | |||||
| { | |||||
| var otherShape = other; | |||||
| if (ndim < 0 || otherShape.ndim < 0) | |||||
| return new TensorShape(); | return new TensorShape(); | ||||
| else | else | ||||
| { | { | ||||
| var concatenate_dims = new int[ndim + other.ndim]; | |||||
| var concatenate_dims = new int[ndim + otherShape.ndim]; | |||||
| for (int i = 0; i < ndim; i++) | for (int i = 0; i < ndim; i++) | ||||
| concatenate_dims[i] = dims[i]; | concatenate_dims[i] = dims[i]; | ||||
| for (int i = 0; i < other.ndim; i++) | |||||
| concatenate_dims[ndim + i] = other.dims[i]; | |||||
| for (int i = 0; i < otherShape.ndim; i++) | |||||
| concatenate_dims[ndim + i] = otherShape.dims[i]; | |||||
| return new TensorShape(concatenate_dims); | return new TensorShape(concatenate_dims); | ||||
| } | } | ||||
| } | } | ||||
| public static implicit operator TensorShape(Shape shape) => new TensorShape(shape.Dimensions); | |||||
| public static implicit operator Shape(TensorShape shape) => new Shape(shape.dims); | |||||
| public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); | |||||
| public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); | |||||
| public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes | |||||
| public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); | public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); | ||||
| public static implicit operator int[](TensorShape shape) => shape.dims; | |||||
| public static explicit operator int(TensorShape shape) => shape.size; | |||||
| public static explicit operator TensorShape(int dim) => new TensorShape(dim); | |||||
| public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); | |||||
| public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | ||||
| public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0); | |||||
| public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); | public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); | ||||
| public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0); | |||||
| public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); | public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); | ||||
| public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0); | |||||
| public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); | |||||
| public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0); | |||||
| public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); | |||||
| } | } | ||||
| } | } | ||||
| @@ -15,6 +15,8 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.Numerics; | |||||
| using NumSharp.Backends; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -23,35 +25,100 @@ namespace Tensorflow | |||||
| public static TF_DataType int8 = TF_DataType.TF_INT8; | public static TF_DataType int8 = TF_DataType.TF_INT8; | ||||
| public static TF_DataType int32 = TF_DataType.TF_INT32; | public static TF_DataType int32 = TF_DataType.TF_INT32; | ||||
| public static TF_DataType int64 = TF_DataType.TF_INT64; | public static TF_DataType int64 = TF_DataType.TF_INT64; | ||||
| public static TF_DataType uint8 = TF_DataType.TF_UINT8; | |||||
| public static TF_DataType uint32 = TF_DataType.TF_UINT32; | |||||
| public static TF_DataType uint64 = TF_DataType.TF_UINT64; | |||||
| public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? | public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? | ||||
| public static TF_DataType float16 = TF_DataType.TF_HALF; | public static TF_DataType float16 = TF_DataType.TF_HALF; | ||||
| public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | ||||
| public static Type as_numpy_datatype(this TF_DataType type) | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="type"></param> | |||||
| /// <returns><see cref="System.Type"/> equivalent to <paramref name="type"/>, if none exists, returns null.</returns> | |||||
| public static Type as_numpy_dtype(this TF_DataType type) | |||||
| { | { | ||||
| switch (type) | switch (type) | ||||
| { | { | ||||
| case TF_DataType.TF_BOOL: | case TF_DataType.TF_BOOL: | ||||
| return typeof(bool); | return typeof(bool); | ||||
| case TF_DataType.TF_UINT8: | |||||
| return typeof(byte); | |||||
| case TF_DataType.TF_INT64: | case TF_DataType.TF_INT64: | ||||
| return typeof(long); | return typeof(long); | ||||
| case TF_DataType.TF_UINT64: | |||||
| return typeof(ulong); | |||||
| case TF_DataType.TF_INT32: | case TF_DataType.TF_INT32: | ||||
| return typeof(int); | return typeof(int); | ||||
| case TF_DataType.TF_UINT32: | |||||
| return typeof(uint); | |||||
| case TF_DataType.TF_INT16: | case TF_DataType.TF_INT16: | ||||
| return typeof(short); | return typeof(short); | ||||
| case TF_DataType.TF_UINT16: | |||||
| return typeof(ushort); | |||||
| case TF_DataType.TF_FLOAT: | case TF_DataType.TF_FLOAT: | ||||
| return typeof(float); | return typeof(float); | ||||
| case TF_DataType.TF_DOUBLE: | case TF_DataType.TF_DOUBLE: | ||||
| return typeof(double); | return typeof(double); | ||||
| case TF_DataType.TF_STRING: | case TF_DataType.TF_STRING: | ||||
| return typeof(string); | return typeof(string); | ||||
| case TF_DataType.TF_COMPLEX128: | |||||
| case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | |||||
| return typeof(Complex); | |||||
| default: | default: | ||||
| return null; | return null; | ||||
| } | } | ||||
| } | } | ||||
| // "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex" | |||||
| public static TF_DataType as_dtype(Type type, TF_DataType? dtype = null) | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="type"></param> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="NPTypeCode"/></exception> | |||||
| public static NPTypeCode as_numpy_typecode(this TF_DataType type) | |||||
| { | |||||
| switch (type) | |||||
| { | |||||
| case TF_DataType.TF_BOOL: | |||||
| return NPTypeCode.Boolean; | |||||
| case TF_DataType.TF_UINT8: | |||||
| return NPTypeCode.Byte; | |||||
| case TF_DataType.TF_INT64: | |||||
| return NPTypeCode.Int64; | |||||
| case TF_DataType.TF_INT32: | |||||
| return NPTypeCode.Int32; | |||||
| case TF_DataType.TF_INT16: | |||||
| return NPTypeCode.Int16; | |||||
| case TF_DataType.TF_UINT64: | |||||
| return NPTypeCode.UInt64; | |||||
| case TF_DataType.TF_UINT32: | |||||
| return NPTypeCode.UInt32; | |||||
| case TF_DataType.TF_UINT16: | |||||
| return NPTypeCode.UInt16; | |||||
| case TF_DataType.TF_FLOAT: | |||||
| return NPTypeCode.Single; | |||||
| case TF_DataType.TF_DOUBLE: | |||||
| return NPTypeCode.Double; | |||||
| case TF_DataType.TF_STRING: | |||||
| return NPTypeCode.String; | |||||
| case TF_DataType.TF_COMPLEX128: | |||||
| case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | |||||
| return NPTypeCode.Complex; | |||||
| default: | |||||
| throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="type"></param> | |||||
| /// <param name="dtype"></param> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="TF_DataType"/></exception> | |||||
| public static TF_DataType as_dtype(this Type type, TF_DataType? dtype = null) | |||||
| { | { | ||||
| switch (type.Name) | switch (type.Name) | ||||
| { | { | ||||
| @@ -98,7 +165,7 @@ namespace Tensorflow | |||||
| dtype = TF_DataType.TF_BOOL; | dtype = TF_DataType.TF_BOOL; | ||||
| break; | break; | ||||
| default: | default: | ||||
| throw new Exception("as_dtype Not Implemented"); | |||||
| throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); | |||||
| } | } | ||||
| return dtype.Value; | return dtype.Value; | ||||
| @@ -106,16 +173,7 @@ namespace Tensorflow | |||||
| public static DataType as_datatype_enum(this TF_DataType type) | public static DataType as_datatype_enum(this TF_DataType type) | ||||
| { | { | ||||
| DataType dtype = DataType.DtInvalid; | |||||
| switch (type) | |||||
| { | |||||
| default: | |||||
| Enum.TryParse(((int)type).ToString(), out dtype); | |||||
| break; | |||||
| } | |||||
| return dtype; | |||||
| return Enum.TryParse(((int) type).ToString(), out DataType dtype) ? dtype : DataType.DtInvalid; | |||||
| } | } | ||||
| public static TF_DataType as_base_dtype(this TF_DataType type) | public static TF_DataType as_base_dtype(this TF_DataType type) | ||||
| @@ -132,7 +190,7 @@ namespace Tensorflow | |||||
| public static Type as_numpy_dtype(this DataType type) | public static Type as_numpy_dtype(this DataType type) | ||||
| { | { | ||||
| return type.as_tf_dtype().as_numpy_datatype(); | |||||
| return type.as_tf_dtype().as_numpy_dtype(); | |||||
| } | } | ||||
| public static DataType as_base_dtype(this DataType type) | public static DataType as_base_dtype(this DataType type) | ||||
| @@ -144,16 +202,7 @@ namespace Tensorflow | |||||
| public static TF_DataType as_tf_dtype(this DataType type) | public static TF_DataType as_tf_dtype(this DataType type) | ||||
| { | { | ||||
| TF_DataType dtype = TF_DataType.DtInvalid; | |||||
| switch (type) | |||||
| { | |||||
| default: | |||||
| Enum.TryParse(((int)type).ToString(), out dtype); | |||||
| break; | |||||
| } | |||||
| return dtype; | |||||
| return Enum.TryParse(((int) type).ToString(), out TF_DataType dtype) ? dtype : TF_DataType.DtInvalid; | |||||
| } | } | ||||
| public static TF_DataType as_ref(this TF_DataType type) | public static TF_DataType as_ref(this TF_DataType type) | ||||
| @@ -17,6 +17,7 @@ | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using NumSharp.Utilities; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -109,7 +110,7 @@ namespace Tensorflow | |||||
| // We first convert value to a numpy array or scalar. | // We first convert value to a numpy array or scalar. | ||||
| NDArray nparray = null; | NDArray nparray = null; | ||||
| var np_dt = dtype.as_numpy_datatype(); | |||||
| var np_dt = dtype.as_numpy_dtype(); | |||||
| if (values is NDArray nd) | if (values is NDArray nd) | ||||
| { | { | ||||
| @@ -188,37 +189,37 @@ namespace Tensorflow | |||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((int[])values, np_dt); | nparray = np.array((int[])values, np_dt); | ||||
| else | else | ||||
| nparray = Convert.ToInt32(values); | |||||
| nparray = Converts.ToInt32(values); | |||||
| break; | break; | ||||
| case "Int64": | case "Int64": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((int[])values, np_dt); | nparray = np.array((int[])values, np_dt); | ||||
| else | else | ||||
| nparray = Convert.ToInt64(values); | |||||
| nparray = Converts.ToInt64(values); | |||||
| break; | break; | ||||
| case "Single": | case "Single": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((float[])values, np_dt); | nparray = np.array((float[])values, np_dt); | ||||
| else | else | ||||
| nparray = Convert.ToSingle(values); | |||||
| nparray = Converts.ToSingle(values); | |||||
| break; | break; | ||||
| case "Double": | case "Double": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((double[])values, np_dt); | nparray = np.array((double[])values, np_dt); | ||||
| else | else | ||||
| nparray = Convert.ToDouble(values); | |||||
| nparray = Converts.ToDouble(values); | |||||
| break; | break; | ||||
| case "String": | case "String": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((string[])values, np_dt); | nparray = np.array((string[])values, np_dt); | ||||
| else | else | ||||
| nparray = NDArray.FromString(Convert.ToString(values)); | |||||
| nparray = NDArray.FromString(Converts.ToString(values)); | |||||
| break; | break; | ||||
| case "Boolean": | case "Boolean": | ||||
| if (values.GetType().IsArray) | if (values.GetType().IsArray) | ||||
| nparray = np.array((bool[])values, np_dt); | nparray = np.array((bool[])values, np_dt); | ||||
| else | else | ||||
| nparray = Convert.ToBoolean(values); | |||||
| nparray = Converts.ToBoolean(values); | |||||
| break; | break; | ||||
| default: | default: | ||||
| throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented"); | throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented"); | ||||
| @@ -0,0 +1,38 @@ | |||||
| %all_dtypes = ["NDArray","Complex","Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"] | |||||
| %all_dtypes_lowercase = ["NDArray","Complex","bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"] | |||||
| %supported_primitives = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"] | |||||
| %supported_primitives_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"] | |||||
| %supported_numericals = ["Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] | |||||
| %supported_numericals_lowercase = ["byte","short","ushort","int","uint","long","ulong","char","double","float"] | |||||
| %supported_numericals_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||||
| %supported_numericals_onevales = ["1","1","1","1","1u","1L","1UL",1,"1d","1f"] | |||||
| %supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||||
| //this is the type we use in summerizing/reducting: | |||||
| %supported_numericals_accumulatingType = ["UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] | |||||
| %supported_numericals_accumulatingType_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||||
| %supported_numericals_signed = ["Int16","Int32","Int64","Double","Single"] | |||||
| %supported_numericals_signed_lowercase = ["short","int","long","double","float"] | |||||
| %supported_numericals_signed_defaultvals = ["0","0","0L","0d","0f"] | |||||
| %supported_numericals_signed_onevales = ["1","1","1L","1d","1f"] | |||||
| %supported_numericals_unsigned = ["Byte","UInt16","UInt32","UInt64","Char"] | |||||
| %supported_numericals_unsigned_lowercase = ["byte","ushort","uint","ulong","char"] | |||||
| %supported_numericals_unsigned_defaultvals = ["0","0","0U","0UL","'\0'"] | |||||
| %supported_numericals_unsigned_onevales = ["1","1","1U","1UL","'\1'"] | |||||
| %supported_dtypes = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] | |||||
| %supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||||
| %supported_dtypes_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float"] | |||||
| %supported_dtypes_defaultvals = [false,"0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||||
| %supported_dtypes_onevales = [true,"1","1","1","1","1u","1L","1UL","'\1'","1d","1f"] | |||||
| %supported_dtypes_dtype = ["bool","uint8","int16","uint16","int32","uint32","int64","uint64","uint8","float64","float32"] | |||||
| //this is the type we use in summerizing/reducting: | |||||
| %supported_dtypes_accumulatingType = ["Int32","UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] | |||||
| %supported_dtypes_accumulatingType_defaultvals = [false, "0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||||
| @@ -29,55 +29,111 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public class GraphKeys | public class GraphKeys | ||||
| { | { | ||||
| #region const | |||||
| /// <summary> | |||||
| /// the subset of `Variable` objects that will be trained by an optimizer. | |||||
| /// </summary> | |||||
| public const string TRAINABLE_VARIABLES_ = "trainable_variables"; | |||||
| /// <summary> | |||||
| /// Trainable resource-style variables. | |||||
| /// </summary> | |||||
| public const string TRAINABLE_RESOURCE_VARIABLES_ = "trainable_resource_variables"; | |||||
| /// <summary> | |||||
| /// Key for streaming model ports. | |||||
| /// </summary> | |||||
| public const string _STREAMING_MODEL_PORTS_ = "streaming_model_ports"; | |||||
| /// <summary> | |||||
| /// Key to collect losses | |||||
| /// </summary> | |||||
| public const string LOSSES_ = "losses"; | |||||
| /// <summary> | |||||
| /// Key to collect Variable objects that are global (shared across machines). | |||||
| /// Default collection for all variables, except local ones. | |||||
| /// </summary> | |||||
| public const string GLOBAL_VARIABLES_ = "variables"; | |||||
| public const string TRAIN_OP_ = "train_op"; | |||||
| public const string GLOBAL_STEP_ = "global_step"; | |||||
| public string[] _VARIABLE_COLLECTIONS_ = new string[] { "variables", "trainable_variables", "model_variables" }; | |||||
| /// <summary> | |||||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | |||||
| /// </summary> | |||||
| public const string SAVEABLE_OBJECTS_ = "saveable_objects"; | |||||
| /// <summary> | |||||
| /// Key to collect update_ops | |||||
| /// </summary> | |||||
| public const string UPDATE_OPS_ = "update_ops"; | |||||
| // Key to collect summaries. | |||||
| public const string SUMMARIES_ = "summaries"; | |||||
| // Used to store v2 summary names. | |||||
| public const string _SUMMARY_COLLECTION_ = "_SUMMARY_V2"; | |||||
| // Key for control flow context. | |||||
| public const string COND_CONTEXT_ = "cond_context"; | |||||
| public const string WHILE_CONTEXT_ = "while_context"; | |||||
| #endregion | |||||
| /// <summary> | /// <summary> | ||||
| /// the subset of `Variable` objects that will be trained by an optimizer. | /// the subset of `Variable` objects that will be trained by an optimizer. | ||||
| /// </summary> | /// </summary> | ||||
| public string TRAINABLE_VARIABLES = "trainable_variables"; | |||||
| public string TRAINABLE_VARIABLES => TRAINABLE_VARIABLES_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Trainable resource-style variables. | /// Trainable resource-style variables. | ||||
| /// </summary> | /// </summary> | ||||
| public string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"; | |||||
| public string TRAINABLE_RESOURCE_VARIABLES => TRAINABLE_RESOURCE_VARIABLES_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key for streaming model ports. | /// Key for streaming model ports. | ||||
| /// </summary> | /// </summary> | ||||
| public string _STREAMING_MODEL_PORTS = "streaming_model_ports"; | |||||
| public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect losses | /// Key to collect losses | ||||
| /// </summary> | /// </summary> | ||||
| public string LOSSES = "losses"; | |||||
| public string LOSSES => LOSSES_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect Variable objects that are global (shared across machines). | /// Key to collect Variable objects that are global (shared across machines). | ||||
| /// Default collection for all variables, except local ones. | /// Default collection for all variables, except local ones. | ||||
| /// </summary> | /// </summary> | ||||
| public string GLOBAL_VARIABLES = "variables"; | |||||
| public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_; | |||||
| public string TRAIN_OP = "train_op"; | |||||
| public string TRAIN_OP => TRAIN_OP_; | |||||
| public string GLOBAL_STEP = "global_step"; | |||||
| public string GLOBAL_STEP => GLOBAL_STEP_; | |||||
| public string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" }; | |||||
| public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | ||||
| /// </summary> | /// </summary> | ||||
| public string SAVEABLE_OBJECTS = "saveable_objects"; | |||||
| public string SAVEABLE_OBJECTS => SAVEABLE_OBJECTS_; | |||||
| /// <summary> | /// <summary> | ||||
| /// Key to collect update_ops | /// Key to collect update_ops | ||||
| /// </summary> | /// </summary> | ||||
| public string UPDATE_OPS = "update_ops"; | |||||
| public string UPDATE_OPS => UPDATE_OPS_; | |||||
| // Key to collect summaries. | // Key to collect summaries. | ||||
| public string SUMMARIES = "summaries"; | |||||
| public string SUMMARIES => SUMMARIES_; | |||||
| // Used to store v2 summary names. | // Used to store v2 summary names. | ||||
| public string _SUMMARY_COLLECTION = "_SUMMARY_V2"; | |||||
| public string _SUMMARY_COLLECTION => _SUMMARY_COLLECTION_; | |||||
| // Key for control flow context. | // Key for control flow context. | ||||
| public string COND_CONTEXT = "cond_context"; | |||||
| public string WHILE_CONTEXT = "while_context"; | |||||
| public string COND_CONTEXT => COND_CONTEXT_; | |||||
| public string WHILE_CONTEXT => WHILE_CONTEXT_; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -6,6 +6,14 @@ | |||||
| <GeneratePackageOnBuild>false</GeneratePackageOnBuild> | <GeneratePackageOnBuild>false</GeneratePackageOnBuild> | ||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||||
| <OutputPath>bin\debug-gpu</OutputPath> | |||||
| </PropertyGroup> | |||||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | |||||
| <OutputPath>bin\release-gpu</OutputPath> | |||||
| </PropertyGroup> | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Colorful.Console" Version="1.2.9" /> | <PackageReference Include="Colorful.Console" Version="1.2.9" /> | ||||
| <PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | <PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | ||||
| @@ -98,9 +98,9 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| var result = sess.run(tensor); | var result = sess.run(tensor); | ||||
| Assert.AreEqual(result[0].shape[0], 3); | |||||
| Assert.AreEqual(result[0].shape[1], 2); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result[0].Data<int>())); | |||||
| Assert.AreEqual(result.shape[0], 3); | |||||
| Assert.AreEqual(result.shape[1], 2); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result.Data<int>())); | |||||
| } | } | ||||
| // big size | // big size | ||||
| @@ -109,13 +109,13 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| var result = sess.run(tensor); | var result = sess.run(tensor); | ||||
| Assert.AreEqual(result[0].shape[0], 200); | |||||
| Assert.AreEqual(result[0].shape[1], 100); | |||||
| Assert.AreEqual(result.shape[0], 200); | |||||
| Assert.AreEqual(result.shape[1], 100); | |||||
| var data = result[0].Data<int>(); | |||||
| var data = result.Data<int>(); | |||||
| Assert.AreEqual(0, data[0]); | Assert.AreEqual(0, data[0]); | ||||
| Assert.AreEqual(0, data[500]); | Assert.AreEqual(0, data[500]); | ||||
| Assert.AreEqual(0, data[result[0].size - 1]); | |||||
| Assert.AreEqual(0, data[result.size - 1]); | |||||
| } | } | ||||
| } | } | ||||
| @@ -127,9 +127,9 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| var result = sess.run(ones); | var result = sess.run(ones); | ||||
| Assert.AreEqual(result[0].shape[0], 3); | |||||
| Assert.AreEqual(result[0].shape[1], 2); | |||||
| Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result[0].Data<int>())); | |||||
| Assert.AreEqual(result.shape[0], 3); | |||||
| Assert.AreEqual(result.shape[1], 2); | |||||
| Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result.Data<int>())); | |||||
| } | } | ||||
| } | } | ||||
| @@ -142,9 +142,9 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| var result = sess.run(halfes); | var result = sess.run(halfes); | ||||
| Assert.AreEqual(result[0].shape[0], 3); | |||||
| Assert.AreEqual(result[0].shape[1], 2); | |||||
| Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result[0].Data<double>())); | |||||
| Assert.AreEqual(result.shape[0], 3); | |||||
| Assert.AreEqual(result.shape[1], 2); | |||||
| Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result.Data<double>())); | |||||
| } | } | ||||
| } | } | ||||
| @@ -161,10 +161,10 @@ namespace TensorFlowNET.UnitTest | |||||
| using (var sess = tf.Session()) | using (var sess = tf.Session()) | ||||
| { | { | ||||
| var result = sess.run(tensor); | var result = sess.run(tensor); | ||||
| var data = result[0].Data<int>(); | |||||
| var data = result.Data<int>(); | |||||
| Assert.AreEqual(result[0].shape[0], 2); | |||||
| Assert.AreEqual(result[0].shape[1], 3); | |||||
| Assert.AreEqual(result.shape[0], 2); | |||||
| Assert.AreEqual(result.shape[1], 3); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data)); | Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data)); | ||||
| } | } | ||||
| } | } | ||||
| @@ -177,7 +177,7 @@ namespace TensorFlowNET.UnitTest | |||||
| var c = a * b; | var c = a * b; | ||||
| var sess = tf.Session(); | var sess = tf.Session(); | ||||
| double result = sess.run(c)[0]; | |||||
| double result = sess.run(c); | |||||
| sess.close(); | sess.close(); | ||||
| Assert.AreEqual(6.0, result); | Assert.AreEqual(6.0, result); | ||||
| @@ -41,7 +41,7 @@ namespace TensorFlowNET.UnitTest | |||||
| var grad = tf.gradients(y, x); | var grad = tf.gradients(y, x); | ||||
| Assert.AreEqual(grad[0].name, "gradients/AddN:0"); | Assert.AreEqual(grad[0].name, "gradients/AddN:0"); | ||||
| float r = sess.run(grad[0])[0]; | |||||
| float r = sess.run(grad[0]); | |||||
| Assert.AreEqual(r, 1.4f); | Assert.AreEqual(r, 1.4f); | ||||
| } | } | ||||
| } | } | ||||
| @@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest | |||||
| var grad = tf.gradients(y, x); | var grad = tf.gradients(y, x); | ||||
| Assert.AreEqual(grad[0].name, "gradients/AddN:0"); | Assert.AreEqual(grad[0].name, "gradients/AddN:0"); | ||||
| float r = sess.run(grad[0])[0]; | |||||
| float r = sess.run(grad[0]); | |||||
| Assert.AreEqual(r, 14.700001f); | Assert.AreEqual(r, 14.700001f); | ||||
| }); | }); | ||||
| } | } | ||||
| @@ -94,7 +94,7 @@ namespace TensorFlowNET.UnitTest | |||||
| using (var sess = tf.Session(graph)) | using (var sess = tf.Session(graph)) | ||||
| { | { | ||||
| var r = sess.run(slice)[0]; | |||||
| var r = sess.run(slice); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(r.shape, new[] { 2, 1, 2 })); | Assert.IsTrue(Enumerable.SequenceEqual(r.shape, new[] { 2, 1, 2 })); | ||||
| Assert.IsTrue(Enumerable.SequenceEqual(r[0].GetData<int>(), new[] { 11, 13 })); | Assert.IsTrue(Enumerable.SequenceEqual(r[0].GetData<int>(), new[] { 11, 13 })); | ||||
| @@ -17,7 +17,7 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| var result = sess.run(y, | var result = sess.run(y, | ||||
| new FeedItem(x, 2)); | new FeedItem(x, 2)); | ||||
| Assert.AreEqual((int)result[0], 6); | |||||
| Assert.AreEqual((int)result, 6); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -45,7 +45,7 @@ namespace TensorFlowNET.UnitTest | |||||
| EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | ||||
| EXPECT_EQ(0, outTensor.NDims); | EXPECT_EQ(0, outTensor.NDims); | ||||
| ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | ||||
| var output_contents = outTensor.Data<int>(); | |||||
| var output_contents = outTensor.ToArray<int>(); | |||||
| EXPECT_EQ(3 + 2, output_contents[0]); | EXPECT_EQ(3 + 2, output_contents[0]); | ||||
| // Add another operation to the graph. | // Add another operation to the graph. | ||||
| @@ -66,7 +66,7 @@ namespace TensorFlowNET.UnitTest | |||||
| EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | ||||
| EXPECT_EQ(0, outTensor.NDims); // scalar | EXPECT_EQ(0, outTensor.NDims); // scalar | ||||
| ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | ||||
| output_contents = outTensor.Data<int>(); | |||||
| output_contents = outTensor.ToArray<int>(); | |||||
| EXPECT_EQ(-(7 + 2), output_contents[0]); | EXPECT_EQ(-(7 + 2), output_contents[0]); | ||||
| // Clean up | // Clean up | ||||
| @@ -112,7 +112,7 @@ namespace TensorFlowNET.UnitTest | |||||
| var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | ||||
| var tensor = new Tensor(nd); | var tensor = new Tensor(nd); | ||||
| var array = tensor.Data<float>(); | |||||
| var array = tensor.ToArray<float>(); | |||||
| EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); | EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); | ||||
| EXPECT_EQ(tensor.rank, nd.ndim); | EXPECT_EQ(tensor.rank, nd.ndim); | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using NumSharp; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -16,7 +17,7 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| session.run(x.initializer); | session.run(x.initializer); | ||||
| var result = session.run(x); | var result = session.run(x); | ||||
| Assert.AreEqual(10, (int)result[0]); | |||||
| Assert.AreEqual(10, (int)result); | |||||
| } | } | ||||
| } | } | ||||
| @@ -81,7 +82,7 @@ namespace TensorFlowNET.UnitTest | |||||
| using (var session = tf.Session()) | using (var session = tf.Session()) | ||||
| { | { | ||||
| session.run(model); | session.run(model); | ||||
| int result = session.run(y)[0]; | |||||
| int result = session.run(y); | |||||
| Assert.AreEqual(result, 4); | Assert.AreEqual(result, 4); | ||||
| } | } | ||||
| } | } | ||||
| @@ -97,12 +98,12 @@ namespace TensorFlowNET.UnitTest | |||||
| var sess = tf.Session(graph); | var sess = tf.Session(graph); | ||||
| sess.run(init); | sess.run(init); | ||||
| var result = sess.run(variable); | |||||
| Assert.IsTrue((int)result[0] == 31); | |||||
| NDArray result = sess.run(variable); | |||||
| Assert.IsTrue((int)result == 31); | |||||
| var assign = variable.assign(12); | var assign = variable.assign(12); | ||||
| result = sess.run(assign); | result = sess.run(assign); | ||||
| Assert.IsTrue((int)result[0] == 12); | |||||
| Assert.IsTrue((int)result == 12); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||
| @@ -139,7 +140,7 @@ namespace TensorFlowNET.UnitTest | |||||
| for(int i = 0; i < 5; i++) | for(int i = 0; i < 5; i++) | ||||
| { | { | ||||
| x = x + 1; | x = x + 1; | ||||
| result = session.run(x)[0]; | |||||
| result = session.run(x); | |||||
| print(result); | print(result); | ||||
| } | } | ||||
| } | } | ||||
| @@ -31,7 +31,7 @@ namespace TensorFlowNET.UnitTest.nn_test | |||||
| var y_np = this._ZeroFraction(x_np); | var y_np = this._ZeroFraction(x_np); | ||||
| var x_tf = constant_op.constant(x_np); | var x_tf = constant_op.constant(x_np); | ||||
| x_tf.SetShape(x_shape); | |||||
| x_tf.set_shape(x_shape); | |||||
| var y_tf = nn_impl.zero_fraction(x_tf); | var y_tf = nn_impl.zero_fraction(x_tf); | ||||
| var y_tf_np = self.evaluate<NDArray>(y_tf); | var y_tf_np = self.evaluate<NDArray>(y_tf); | ||||